You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/oauth/token.go

291 lines
9.4 KiB
Go

7 months ago
//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-07-24 15:27:31
// Distributed under terms of the MIT license.
//
package oauth
import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"time"

	"github.com/veypi/OneAuth/cfg"
	"github.com/vyes-ai/vigo"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)
// TokenRequest holds the parameters of an OAuth 2.0 token request.
//
// Every field is tagged src:"form" (presumably bound from the request form by
// vigo — confirm against vigo's binding rules); which fields are needed depends
// on GrantType. Username, Password and Scope are read only by the password
// grant in this file.
type TokenRequest struct {
	GrantType    string `json:"grant_type" src:"form" desc:"授权类型"`
	Code         string `json:"code" src:"form" desc:"授权码"`
	RedirectURI  string `json:"redirect_uri" src:"form" desc:"重定向URI"`
	ClientID     string `json:"client_id" src:"form" desc:"客户端ID"`
	ClientSecret string `json:"client_secret" src:"form" desc:"客户端密钥"`
	RefreshToken string `json:"refresh_token" src:"form" desc:"刷新令牌"`
	CodeVerifier string `json:"code_verifier" src:"form" desc:"PKCE验证码"`
	Username     string `json:"username" src:"form" desc:"用户名"` // for password grant
	Password     string `json:"password" src:"form" desc:"密码"`  // for password grant
	Scope        string `json:"scope" src:"form" desc:"权限范围"`   // for password grant
}
// TokenResponse is the JSON body returned by the token endpoint.
//
// RefreshToken and Scope carry `omitempty`, so they are left out of the JSON
// output when empty. The handlers in this file fill ExpiresIn with
// DefaultAccessTokenExpiry expressed in seconds.
type TokenResponse struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int64  `json:"expires_in"`
	RefreshToken string `json:"refresh_token,omitempty"`
	Scope        string `json:"scope,omitempty"`
}
// handleToken 处理OAuth令牌请求
func handleToken(x *vigo.X, args *TokenRequest) (*TokenResponse, error) {
7 months ago
db := cfg.DB()
switch args.GrantType {
case GrantTypeAuthorizationCode:
return handleAuthorizationCodeGrant(db, x, args)
case GrantTypeRefreshToken:
return handleRefreshTokenGrant(db, x, args)
case GrantTypePassword:
return handlePasswordGrant(db, x, args)
default:
return nil, vigo.NewError("不支持的授权类型").WithCode(400)
7 months ago
}
}
// handleAuthorizationCodeGrant 处理授权码授权类型
func handleAuthorizationCodeGrant(db *gorm.DB, x *vigo.X, args *TokenRequest) (*TokenResponse, error) {
7 months ago
// 1. 验证授权码
var authCode OAuthAuthorizationCode
if err := db.Where("code = ? AND used = ?", args.Code, false).First(&authCode).Error; err != nil {
return nil, vigo.NewError("无效的授权码").WithCode(400)
7 months ago
}
// 2. 检查授权码是否过期
if authCode.IsExpired() {
return nil, vigo.NewError("授权码已过期").WithCode(400)
7 months ago
}
// 3. 验证客户端
var client OAuthClient
if err := db.Where("id = ? AND client_id = ?", authCode.ClientID, args.ClientID).First(&client).Error; err != nil {
return nil, vigo.NewError("无效的客户端").WithCode(400)
7 months ago
}
// 4. 验证客户端密钥(对于机密客户端)
if !client.IsPublic && client.ClientSecret != args.ClientSecret {
return nil, vigo.NewError("无效的客户端凭据").WithCode(400)
7 months ago
}
// 5. 验证重定向URI
if authCode.RedirectURI != args.RedirectURI {
return nil, vigo.NewError("重定向URI不匹配").WithCode(400)
7 months ago
}
// 6. 验证PKCE如果使用
if authCode.CodeChallenge != "" {
if err := validatePKCE(authCode.CodeChallenge, authCode.CodeChallengeMethod, args.CodeVerifier); err != nil {
return nil, vigo.NewError("PKCE验证失败").WithError(err).WithCode(400)
7 months ago
}
}
// 7. 标记授权码为已使用
if err := db.Model(&authCode).Update("used", true).Error; err != nil {
return nil, vigo.NewError("授权码更新失败").WithError(err).WithCode(500)
7 months ago
}
// 8. 生成访问令牌
accessToken, err := generateAccessToken(db, &client, authCode.UserID, authCode.Scope)
if err != nil {
return nil, vigo.NewError("访问令牌生成失败").WithError(err).WithCode(500)
7 months ago
}
// 9. 生成刷新令牌
refreshToken, err := generateRefreshToken(db, accessToken, &client, authCode.UserID, authCode.Scope)
if err != nil {
return nil, vigo.NewError("刷新令牌生成失败").WithError(err).WithCode(500)
7 months ago
}
return &TokenResponse{
7 months ago
AccessToken: accessToken.Token,
TokenType: TokenTypeBearer,
ExpiresIn: int64(DefaultAccessTokenExpiry.Seconds()),
RefreshToken: refreshToken.Token,
Scope: authCode.Scope,
}, nil
7 months ago
}
// handleRefreshTokenGrant 处理刷新令牌授权类型
func handleRefreshTokenGrant(db *gorm.DB, x *vigo.X, args *TokenRequest) (*TokenResponse, error) {
7 months ago
// 1. 验证刷新令牌
var refreshToken OAuthRefreshToken
if err := db.Where("token = ? AND revoked = ?", args.RefreshToken, false).First(&refreshToken).Error; err != nil {
return nil, vigo.NewError("无效的刷新令牌").WithCode(400)
7 months ago
}
// 2. 检查刷新令牌是否过期
if refreshToken.IsExpired() {
return nil, vigo.NewError("刷新令牌已过期").WithCode(400)
7 months ago
}
// 3. 验证客户端
var client OAuthClient
if err := db.Where("id = ? AND client_id = ?", refreshToken.ClientID, args.ClientID).First(&client).Error; err != nil {
return nil, vigo.NewError("无效的客户端").WithCode(400)
7 months ago
}
// 4. 撤销旧的访问令牌
if err := db.Model(&OAuthAccessToken{}).Where("id = ?", refreshToken.AccessTokenID).Update("revoked", true).Error; err != nil {
return nil, vigo.NewError("旧令牌撤销失败").WithError(err).WithCode(500)
7 months ago
}
// 5. 生成新的访问令牌
accessToken, err := generateAccessToken(db, &client, refreshToken.UserID, refreshToken.Scope)
if err != nil {
return nil, vigo.NewError("访问令牌生成失败").WithError(err).WithCode(500)
7 months ago
}
// 6. 更新刷新令牌关联
if err := db.Model(&refreshToken).Update("access_token_id", accessToken.ID).Error; err != nil {
return nil, vigo.NewError("刷新令牌更新失败").WithError(err).WithCode(500)
7 months ago
}
return &TokenResponse{
7 months ago
AccessToken: accessToken.Token,
TokenType: TokenTypeBearer,
ExpiresIn: int64(DefaultAccessTokenExpiry.Seconds()),
RefreshToken: refreshToken.Token,
Scope: refreshToken.Scope,
}, nil
7 months ago
}
// handlePasswordGrant 处理密码授权类型
func handlePasswordGrant(db *gorm.DB, x *vigo.X, args *TokenRequest) (*TokenResponse, error) {
7 months ago
// 1. 验证必要参数
if args.Username == "" || args.Password == "" {
return nil, vigo.NewError("用户名和密码不能为空").WithCode(400)
7 months ago
}
// 2. 验证客户端
var client OAuthClient
if err := db.Where("client_id = ?", args.ClientID).First(&client).Error; err != nil {
return nil, vigo.NewError("无效的客户端").WithCode(400)
7 months ago
}
// 3. 验证客户端密钥(对于机密客户端)
if !client.IsPublic && client.ClientSecret != args.ClientSecret {
return nil, vigo.NewError("无效的客户端凭据").WithCode(400)
7 months ago
}
// 4. 验证用户凭据
var user User
if err := db.Where("username = ?", args.Username).First(&user).Error; err != nil {
return nil, vigo.NewError("用户名或密码错误").WithCode(400)
7 months ago
}
// 5. 验证密码
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(args.Password)); err != nil {
return nil, vigo.NewError("用户名或密码错误").WithCode(400)
7 months ago
}
// 6. 处理权限范围
scope := args.Scope
if scope == "" {
// scope = DefaultScope // 默认权限范围
}
// 7. 生成访问令牌
accessToken, err := generateAccessToken(db, &client, user.ID, scope)
if err != nil {
return nil, vigo.NewError("访问令牌生成失败").WithError(err).WithCode(500)
7 months ago
}
// 8. 生成刷新令牌
refreshToken, err := generateRefreshToken(db, accessToken, &client, user.ID, scope)
if err != nil {
return nil, vigo.NewError("刷新令牌生成失败").WithError(err).WithCode(500)
7 months ago
}
return &TokenResponse{
7 months ago
AccessToken: accessToken.Token,
TokenType: TokenTypeBearer,
ExpiresIn: int64(DefaultAccessTokenExpiry.Seconds()),
RefreshToken: refreshToken.Token,
Scope: scope,
}, nil
7 months ago
}
// Helper functions.

// generateAccessToken creates a new access token for the given client, user
// and scope, persists it with db.Create, and returns the stored record.
//
// The token value comes from generateRandomString(64), which is defined
// elsewhere. The record expires DefaultAccessTokenExpiry from now and starts
// out not revoked. On error the record is discarded and only the error is
// returned.
func generateAccessToken(db *gorm.DB, client *OAuthClient, userID, scope string) (*OAuthAccessToken, error) {
	value, err := generateRandomString(64)
	if err != nil {
		return nil, err
	}

	record := &OAuthAccessToken{
		Token:     value,
		ClientID:  client.ID,
		UserID:    userID,
		Scope:     scope,
		ExpiresAt: time.Now().Add(DefaultAccessTokenExpiry),
		Revoked:   false,
	}
	if res := db.Create(record); res.Error != nil {
		return nil, res.Error
	}
	return record, nil
}
// generateRefreshToken creates a refresh token for the given client, user and
// scope, persists it with db.Create, and returns the stored record.
//
// The token is linked to accessToken through AccessTokenID. Its value comes
// from generateRandomString(64), which is defined elsewhere. It expires
// DefaultRefreshTokenExpiry from now and starts out not revoked. On error only
// the error is returned.
func generateRefreshToken(db *gorm.DB, accessToken *OAuthAccessToken, client *OAuthClient, userID, scope string) (*OAuthRefreshToken, error) {
	value, err := generateRandomString(64)
	if err != nil {
		return nil, err
	}

	record := &OAuthRefreshToken{
		Token:         value,
		AccessTokenID: accessToken.ID,
		ClientID:      client.ID,
		UserID:        userID,
		Scope:         scope,
		ExpiresAt:     time.Now().Add(DefaultRefreshTokenExpiry),
		Revoked:       false,
	}
	if res := db.Create(record); res.Error != nil {
		return nil, res.Error
	}
	return record, nil
}
// validatePKCE checks a PKCE code_verifier against the code_challenge stored
// with an authorization code (RFC 7636 §4.6).
//
// method is the stored code_challenge_method. An empty method is treated as
// CodeChallengeMethodPlain, the default RFC 7636 §4.3 prescribes when the
// client omitted it; previously such codes could never be redeemed. With
// CodeChallengeMethodS256, the challenge must equal the unpadded base64url
// encoding of SHA-256(codeVerifier). The final comparison is constant-time.
//
// It returns nil on a match. Otherwise it returns an error when the verifier
// is empty, the method is unsupported, or the verifier does not match.
func validatePKCE(codeChallenge, method, codeVerifier string) error {
	if codeVerifier == "" {
		return fmt.Errorf("code verifier required")
	}
	if method == "" {
		method = CodeChallengeMethodPlain
	}

	var expected string
	switch method {
	case CodeChallengeMethodPlain:
		expected = codeVerifier
	case CodeChallengeMethodS256:
		sum := sha256.Sum256([]byte(codeVerifier))
		expected = base64.RawURLEncoding.EncodeToString(sum[:])
	default:
		return fmt.Errorf("unsupported code challenge method")
	}

	if subtle.ConstantTimeCompare([]byte(codeChallenge), []byte(expected)) != 1 {
		return fmt.Errorf("invalid code verifier")
	}
	return nil
}