// // Copyright (C) 2024 veypi // 2025-07-24 15:27:31 // Distributed under terms of the MIT license. // package oauth import ( "crypto/sha256" "encoding/base64" "fmt" "github.com/veypi/OneAuth/cfg" "github.com/vyes-ai/vigo" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" "time" ) // TokenRequest 令牌请求参数 type TokenRequest struct { GrantType string `form:"grant_type" binding:"required"` Code string `form:"code"` RedirectURI string `form:"redirect_uri"` ClientID string `form:"client_id"` ClientSecret string `form:"client_secret"` RefreshToken string `form:"refresh_token"` CodeVerifier string `form:"code_verifier"` Username string `form:"username"` // for password grant Password string `form:"password"` // for password grant Scope string `form:"scope"` // for password grant } // TokenResponse 令牌响应 type TokenResponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` ExpiresIn int64 `json:"expires_in"` RefreshToken string `json:"refresh_token,omitempty"` Scope string `json:"scope,omitempty"` } // handleToken 处理OAuth令牌请求 func handleToken(x *vigo.X) error { args := &TokenRequest{} if err := x.Parse(args); err != nil { return vigo.NewError("参数解析失败").WithError(err).WithCode(400) } db := cfg.DB() switch args.GrantType { case GrantTypeAuthorizationCode: return handleAuthorizationCodeGrant(db, x, args) case GrantTypeRefreshToken: return handleRefreshTokenGrant(db, x, args) case GrantTypePassword: return handlePasswordGrant(db, x, args) default: return vigo.NewError("不支持的授权类型").WithCode(400) } } // handleAuthorizationCodeGrant 处理授权码授权类型 func handleAuthorizationCodeGrant(db *gorm.DB, x *vigo.X, args *TokenRequest) error { // 1. 验证授权码 var authCode OAuthAuthorizationCode if err := db.Where("code = ? AND used = ?", args.Code, false).First(&authCode).Error; err != nil { return vigo.NewError("无效的授权码").WithCode(400) } // 2. 检查授权码是否过期 if authCode.IsExpired() { return vigo.NewError("授权码已过期").WithCode(400) } // 3. 验证客户端 var client OAuthClient if err := db.Where("id = ? AND client_id = ?", authCode.ClientID, args.ClientID).First(&client).Error; err != nil { return vigo.NewError("无效的客户端").WithCode(400) } // 4. 验证客户端密钥(对于机密客户端) if !client.IsPublic && client.ClientSecret != args.ClientSecret { return vigo.NewError("无效的客户端凭据").WithCode(400) } // 5. 验证重定向URI if authCode.RedirectURI != args.RedirectURI { return vigo.NewError("重定向URI不匹配").WithCode(400) } // 6. 验证PKCE(如果使用) if authCode.CodeChallenge != "" { if err := validatePKCE(authCode.CodeChallenge, authCode.CodeChallengeMethod, args.CodeVerifier); err != nil { return vigo.NewError("PKCE验证失败").WithError(err).WithCode(400) } } // 7. 标记授权码为已使用 if err := db.Model(&authCode).Update("used", true).Error; err != nil { return vigo.NewError("授权码更新失败").WithError(err).WithCode(500) } // 8. 生成访问令牌 accessToken, err := generateAccessToken(db, &client, authCode.UserID, authCode.Scope) if err != nil { return vigo.NewError("访问令牌生成失败").WithError(err).WithCode(500) } // 9. 生成刷新令牌 refreshToken, err := generateRefreshToken(db, accessToken, &client, authCode.UserID, authCode.Scope) if err != nil { return vigo.NewError("刷新令牌生成失败").WithError(err).WithCode(500) } return x.JSON(&TokenResponse{ AccessToken: accessToken.Token, TokenType: TokenTypeBearer, ExpiresIn: int64(DefaultAccessTokenExpiry.Seconds()), RefreshToken: refreshToken.Token, Scope: authCode.Scope, }) } // handleRefreshTokenGrant 处理刷新令牌授权类型 func handleRefreshTokenGrant(db *gorm.DB, x *vigo.X, args *TokenRequest) error { // 1. 验证刷新令牌 var refreshToken OAuthRefreshToken if err := db.Where("token = ? AND revoked = ?", args.RefreshToken, false).First(&refreshToken).Error; err != nil { return vigo.NewError("无效的刷新令牌").WithCode(400) } // 2. 检查刷新令牌是否过期 if refreshToken.IsExpired() { return vigo.NewError("刷新令牌已过期").WithCode(400) } // 3. 验证客户端 var client OAuthClient if err := db.Where("id = ? AND client_id = ?", refreshToken.ClientID, args.ClientID).First(&client).Error; err != nil { return vigo.NewError("无效的客户端").WithCode(400) } // 4. 撤销旧的访问令牌 if err := db.Model(&OAuthAccessToken{}).Where("id = ?", refreshToken.AccessTokenID).Update("revoked", true).Error; err != nil { return vigo.NewError("旧令牌撤销失败").WithError(err).WithCode(500) } // 5. 生成新的访问令牌 accessToken, err := generateAccessToken(db, &client, refreshToken.UserID, refreshToken.Scope) if err != nil { return vigo.NewError("访问令牌生成失败").WithError(err).WithCode(500) } // 6. 更新刷新令牌关联 if err := db.Model(&refreshToken).Update("access_token_id", accessToken.ID).Error; err != nil { return vigo.NewError("刷新令牌更新失败").WithError(err).WithCode(500) } return x.JSON(&TokenResponse{ AccessToken: accessToken.Token, TokenType: TokenTypeBearer, ExpiresIn: int64(DefaultAccessTokenExpiry.Seconds()), RefreshToken: refreshToken.Token, Scope: refreshToken.Scope, }) } // handlePasswordGrant 处理密码授权类型 func handlePasswordGrant(db *gorm.DB, x *vigo.X, args *TokenRequest) error { // 1. 验证必要参数 if args.Username == "" || args.Password == "" { return vigo.NewError("用户名和密码不能为空").WithCode(400) } // 2. 验证客户端 var client OAuthClient if err := db.Where("client_id = ?", args.ClientID).First(&client).Error; err != nil { return vigo.NewError("无效的客户端").WithCode(400) } // 3. 验证客户端密钥(对于机密客户端) if !client.IsPublic && client.ClientSecret != args.ClientSecret { return vigo.NewError("无效的客户端凭据").WithCode(400) } // 4. 验证用户凭据 var user User if err := db.Where("username = ?", args.Username).First(&user).Error; err != nil { return vigo.NewError("用户名或密码错误").WithCode(400) } // 5. 验证密码 if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(args.Password)); err != nil { return vigo.NewError("用户名或密码错误").WithCode(400) } // 6. 处理权限范围 scope := args.Scope if scope == "" { // scope = DefaultScope // 默认权限范围 } // 7. 生成访问令牌 accessToken, err := generateAccessToken(db, &client, user.ID, scope) if err != nil { return vigo.NewError("访问令牌生成失败").WithError(err).WithCode(500) } // 8. 生成刷新令牌 refreshToken, err := generateRefreshToken(db, accessToken, &client, user.ID, scope) if err != nil { return vigo.NewError("刷新令牌生成失败").WithError(err).WithCode(500) } return x.JSON(&TokenResponse{ AccessToken: accessToken.Token, TokenType: TokenTypeBearer, ExpiresIn: int64(DefaultAccessTokenExpiry.Seconds()), RefreshToken: refreshToken.Token, Scope: scope, }) } // 辅助函数 func generateAccessToken(db *gorm.DB, client *OAuthClient, userID, scope string) (*OAuthAccessToken, error) { token, err := generateRandomString(64) if err != nil { return nil, err } accessToken := &OAuthAccessToken{ Token: token, ClientID: client.ID, UserID: userID, Scope: scope, ExpiresAt: time.Now().Add(DefaultAccessTokenExpiry), Revoked: false, } if err := db.Create(accessToken).Error; err != nil { return nil, err } return accessToken, nil } func generateRefreshToken(db *gorm.DB, accessToken *OAuthAccessToken, client *OAuthClient, userID, scope string) (*OAuthRefreshToken, error) { token, err := generateRandomString(64) if err != nil { return nil, err } refreshToken := &OAuthRefreshToken{ Token: token, AccessTokenID: accessToken.ID, ClientID: client.ID, UserID: userID, Scope: scope, ExpiresAt: time.Now().Add(DefaultRefreshTokenExpiry), Revoked: false, } if err := db.Create(refreshToken).Error; err != nil { return nil, err } return refreshToken, nil } func validatePKCE(codeChallenge, method, codeVerifier string) error { if codeVerifier == "" { return fmt.Errorf("code verifier required") } switch method { case CodeChallengeMethodPlain: if codeChallenge != codeVerifier { return fmt.Errorf("invalid code verifier") } case CodeChallengeMethodS256: h := sha256.Sum256([]byte(codeVerifier)) expected := base64.RawURLEncoding.EncodeToString(h[:]) if codeChallenge != expected { return fmt.Errorf("invalid code verifier") } default: return fmt.Errorf("unsupported code challenge method") } return nil }