You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/oauth/token.go

295 lines
9.1 KiB
Go

3 months ago
//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-07-24 15:27:31
// Distributed under terms of the MIT license.
//
package oauth
import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"time"

	"github.com/veypi/OneAuth/cfg"
	"github.com/vyes-ai/vigo"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)
// TokenRequest holds the form parameters accepted by the OAuth token endpoint
// (parsed by handleToken via x.Parse). Which fields are read depends on
// GrantType; unused fields are simply ignored by the selected grant handler.
type TokenRequest struct {
GrantType string `form:"grant_type" binding:"required"` // selects the grant handler in handleToken
Code string `form:"code"` // authorization_code grant: the code being exchanged
RedirectURI string `form:"redirect_uri"` // authorization_code grant: must equal the URI stored with the code
ClientID string `form:"client_id"` // public client identifier; looked up by every grant
ClientSecret string `form:"client_secret"` // client secret, compared for confidential (non-public) clients
RefreshToken string `form:"refresh_token"` // refresh_token grant: the token being exchanged
CodeVerifier string `form:"code_verifier"` // authorization_code grant: PKCE verifier, checked when the code has a challenge
Username string `form:"username"` // password grant: resource owner's username
Password string `form:"password"` // password grant: resource owner's plaintext password, checked against a bcrypt hash
Scope string `form:"scope"` // password grant: requested scope, passed through unchanged
}
// TokenResponse is the JSON body returned by the token endpoint on success.
// All grant handlers in this file fill TokenType with TokenTypeBearer and
// ExpiresIn with DefaultAccessTokenExpiry expressed in whole seconds.
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"` // access-token lifetime in seconds
RefreshToken string `json:"refresh_token,omitempty"` // omitted from the JSON when empty
Scope string `json:"scope,omitempty"` // omitted from the JSON when empty
}
// handleToken is the OAuth token endpoint. It parses the request form into a
// TokenRequest and hands it, together with the shared database handle, to the
// handler for the requested grant_type.
//
// A parse failure yields a 400 error; an unknown or empty grant_type yields a
// 400 "unsupported grant type" error.
func handleToken(x *vigo.X) error {
	req := new(TokenRequest)
	if parseErr := x.Parse(req); parseErr != nil {
		return vigo.NewError("参数解析失败").WithError(parseErr).WithCode(400)
	}

	// The DB handle is obtained before dispatch, matching the original order.
	conn := cfg.DB()
	grant := req.GrantType

	switch {
	case grant == GrantTypeAuthorizationCode:
		return handleAuthorizationCodeGrant(conn, x, req)
	case grant == GrantTypeRefreshToken:
		return handleRefreshTokenGrant(conn, x, req)
	case grant == GrantTypePassword:
		return handlePasswordGrant(conn, x, req)
	}
	return vigo.NewError("不支持的授权类型").WithCode(400)
}
// handleAuthorizationCodeGrant exchanges an authorization code for an access
// token and a refresh token (grant_type=authorization_code).
//
// Checks, in order:
//   - the code exists and is unused
//   - it is not expired
//   - it was issued to the requesting client_id
//   - confidential clients present the correct secret
//   - redirect_uri equals the one stored with the code
//   - the PKCE verifier matches, when a challenge was recorded
//
// The code is then consumed with a conditional UPDATE so that only one of
// several concurrent requests presenting the same code can succeed.
//
// Validation failures return 400 errors; database or token-generation
// failures return 500 errors.
func handleAuthorizationCodeGrant(db *gorm.DB, x *vigo.X, args *TokenRequest) error {
	// 1. Look up the authorization code; only unused codes are accepted.
	var authCode OAuthAuthorizationCode
	if err := db.Where("code = ? AND used = ?", args.Code, false).First(&authCode).Error; err != nil {
		return vigo.NewError("无效的授权码").WithCode(400)
	}
	// 2. Reject expired codes.
	if authCode.IsExpired() {
		return vigo.NewError("授权码已过期").WithCode(400)
	}
	// 3. The code must have been issued to the client making this request.
	var client OAuthClient
	if err := db.Where("id = ? AND client_id = ?", authCode.ClientID, args.ClientID).First(&client).Error; err != nil {
		return vigo.NewError("无效的客户端").WithCode(400)
	}
	// 4. Confidential clients must present their secret. Constant-time
	// comparison avoids leaking how much of the secret matched via timing.
	if !client.IsPublic && subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(args.ClientSecret)) != 1 {
		return vigo.NewError("无效的客户端凭据").WithCode(400)
	}
	// 5. redirect_uri must match the value bound to the code.
	if authCode.RedirectURI != args.RedirectURI {
		return vigo.NewError("重定向URI不匹配").WithCode(400)
	}
	// 6. Verify PKCE only when a challenge was stored with the code.
	if authCode.CodeChallenge != "" {
		if err := validatePKCE(authCode.CodeChallenge, authCode.CodeChallengeMethod, args.CodeVerifier); err != nil {
			return vigo.NewError("PKCE验证失败").WithError(err).WithCode(400)
		}
	}
	// 7. Consume the code atomically. The extra "used = false" condition means
	// a concurrent request that already consumed this code makes this UPDATE
	// match zero rows. That request is then rejected instead of also
	// receiving tokens, because authorization codes are single-use.
	res := db.Model(&authCode).Where("used = ?", false).Update("used", true)
	if res.Error != nil {
		return vigo.NewError("授权码更新失败").WithError(res.Error).WithCode(500)
	}
	if res.RowsAffected == 0 {
		return vigo.NewError("无效的授权码").WithCode(400)
	}
	// 8. Issue the access token.
	accessToken, err := generateAccessToken(db, &client, authCode.UserID, authCode.Scope)
	if err != nil {
		return vigo.NewError("访问令牌生成失败").WithError(err).WithCode(500)
	}
	// 9. Issue a refresh token linked to the new access token.
	refreshToken, err := generateRefreshToken(db, accessToken, &client, authCode.UserID, authCode.Scope)
	if err != nil {
		return vigo.NewError("刷新令牌生成失败").WithError(err).WithCode(500)
	}
	return x.JSON(&TokenResponse{
		AccessToken:  accessToken.Token,
		TokenType:    TokenTypeBearer,
		ExpiresIn:    int64(DefaultAccessTokenExpiry.Seconds()),
		RefreshToken: refreshToken.Token,
		Scope:        authCode.Scope,
	})
}
// handleRefreshTokenGrant issues a new access token in exchange for a valid
// refresh token (grant_type=refresh_token).
//
// Requirements on the refresh token:
//   - it exists and is not revoked
//   - it is not expired
//   - it belongs to the requesting client_id
//
// Confidential clients must also present their client secret (RFC 6749 §6).
// On success, the access token previously linked to the refresh token is
// revoked, a new one is created and linked, and the same refresh token value
// is returned; it is not rotated.
//
// Validation failures return 400 errors; database or token-generation
// failures return 500 errors.
func handleRefreshTokenGrant(db *gorm.DB, x *vigo.X, args *TokenRequest) error {
	// 1. Look up the refresh token; revoked tokens are rejected.
	var refreshToken OAuthRefreshToken
	if err := db.Where("token = ? AND revoked = ?", args.RefreshToken, false).First(&refreshToken).Error; err != nil {
		return vigo.NewError("无效的刷新令牌").WithCode(400)
	}
	// 2. Reject expired refresh tokens.
	if refreshToken.IsExpired() {
		return vigo.NewError("刷新令牌已过期").WithCode(400)
	}
	// 3. The refresh token must have been issued to the requesting client.
	var client OAuthClient
	if err := db.Where("id = ? AND client_id = ?", refreshToken.ClientID, args.ClientID).First(&client).Error; err != nil {
		return vigo.NewError("无效的客户端").WithCode(400)
	}
	// 4. Confidential clients must authenticate. Without this check, anyone
	// holding a leaked refresh token plus the non-secret client_id could
	// mint access tokens. The comparison is constant-time.
	if !client.IsPublic && subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(args.ClientSecret)) != 1 {
		return vigo.NewError("无效的客户端凭据").WithCode(400)
	}
	// 5. Revoke the access token currently linked to this refresh token.
	if err := db.Model(&OAuthAccessToken{}).Where("id = ?", refreshToken.AccessTokenID).Update("revoked", true).Error; err != nil {
		return vigo.NewError("旧令牌撤销失败").WithError(err).WithCode(500)
	}
	// 6. Issue a new access token with the refresh token's user and scope.
	accessToken, err := generateAccessToken(db, &client, refreshToken.UserID, refreshToken.Scope)
	if err != nil {
		return vigo.NewError("访问令牌生成失败").WithError(err).WithCode(500)
	}
	// 7. Re-link the refresh token to the new access token.
	if err := db.Model(&refreshToken).Update("access_token_id", accessToken.ID).Error; err != nil {
		return vigo.NewError("刷新令牌更新失败").WithError(err).WithCode(500)
	}
	return x.JSON(&TokenResponse{
		AccessToken:  accessToken.Token,
		TokenType:    TokenTypeBearer,
		ExpiresIn:    int64(DefaultAccessTokenExpiry.Seconds()),
		RefreshToken: refreshToken.Token,
		Scope:        refreshToken.Scope,
	})
}
// handlePasswordGrant authenticates a resource owner by username and password
// and issues an access token plus a refresh token (grant_type=password).
//
// Steps:
//   - the client is looked up by client_id
//   - confidential clients must present the correct secret
//   - the user is looked up by username
//   - the password is checked against the stored bcrypt hash
//
// An unknown user and a wrong password produce the same error message.
// The requested scope is used as-is, including when empty.
//
// Validation failures return 400 errors; token-generation failures return
// 500 errors.
func handlePasswordGrant(db *gorm.DB, x *vigo.X, args *TokenRequest) error {
	// 1. Both credentials are mandatory for this grant.
	if args.Username == "" || args.Password == "" {
		return vigo.NewError("用户名和密码不能为空").WithCode(400)
	}
	// 2. Look up the client.
	var client OAuthClient
	if err := db.Where("client_id = ?", args.ClientID).First(&client).Error; err != nil {
		return vigo.NewError("无效的客户端").WithCode(400)
	}
	// 3. Confidential clients must present their secret. The comparison is
	// constant-time so response timing does not reveal partial matches.
	if !client.IsPublic && subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(args.ClientSecret)) != 1 {
		return vigo.NewError("无效的客户端凭据").WithCode(400)
	}
	// 4. Look up the user; report failure with the same message as a bad
	// password so callers cannot distinguish the two cases by message.
	var user User
	if err := db.Where("username = ?", args.Username).First(&user).Error; err != nil {
		return vigo.NewError("用户名或密码错误").WithCode(400)
	}
	// 5. Check the password against the stored bcrypt hash.
	if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(args.Password)); err != nil {
		return vigo.NewError("用户名或密码错误").WithCode(400)
	}
	// 6. Scope is passed through unchanged. A DefaultScope fallback for an
	// empty scope was sketched here but is not enabled.
	// TODO: decide whether password grants need a default scope.
	scope := args.Scope
	// 7. Issue the access token.
	accessToken, err := generateAccessToken(db, &client, user.ID, scope)
	if err != nil {
		return vigo.NewError("访问令牌生成失败").WithError(err).WithCode(500)
	}
	// 8. Issue a refresh token linked to the new access token.
	refreshToken, err := generateRefreshToken(db, accessToken, &client, user.ID, scope)
	if err != nil {
		return vigo.NewError("刷新令牌生成失败").WithError(err).WithCode(500)
	}
	return x.JSON(&TokenResponse{
		AccessToken:  accessToken.Token,
		TokenType:    TokenTypeBearer,
		ExpiresIn:    int64(DefaultAccessTokenExpiry.Seconds()),
		RefreshToken: refreshToken.Token,
		Scope:        scope,
	})
}
// Helper functions.

// generateAccessToken creates and persists a new, non-revoked access token
// for the given client, user and scope. The token value comes from
// generateRandomString(64), and the token expires DefaultAccessTokenExpiry
// from now.
//
// Returns the stored record, or the first error from token generation or
// the database insert (in which case the record is nil).
func generateAccessToken(db *gorm.DB, client *OAuthClient, userID, scope string) (*OAuthAccessToken, error) {
	secret, err := generateRandomString(64)
	if err != nil {
		return nil, err
	}

	record := new(OAuthAccessToken)
	record.Token = secret
	record.ClientID = client.ID
	record.UserID = userID
	record.Scope = scope
	record.ExpiresAt = time.Now().Add(DefaultAccessTokenExpiry)
	record.Revoked = false

	if err = db.Create(record).Error; err != nil {
		return nil, err
	}
	return record, nil
}
// generateRefreshToken creates and persists a new, non-revoked refresh token
// linked to accessToken (via its ID) for the given client, user and scope.
// The token value comes from generateRandomString(64), and the token expires
// DefaultRefreshTokenExpiry from now. accessToken must be non-nil.
//
// Returns the stored record, or the first error from token generation or
// the database insert (in which case the record is nil).
func generateRefreshToken(db *gorm.DB, accessToken *OAuthAccessToken, client *OAuthClient, userID, scope string) (*OAuthRefreshToken, error) {
	value, genErr := generateRandomString(64)
	if genErr != nil {
		return nil, genErr
	}

	expiry := time.Now().Add(DefaultRefreshTokenExpiry)
	rt := &OAuthRefreshToken{
		Token:         value,
		Scope:         scope,
		UserID:        userID,
		ClientID:      client.ID,
		AccessTokenID: accessToken.ID,
		ExpiresAt:     expiry,
		Revoked:       false,
	}

	if createErr := db.Create(rt).Error; createErr != nil {
		return nil, createErr
	}
	return rt, nil
}
// validatePKCE checks a PKCE code_verifier against the code_challenge stored
// with an authorization code (RFC 7636 §4.6).
//
// method is the code_challenge_method recorded at authorization time:
//   - CodeChallengeMethodPlain: the verifier must equal the challenge
//   - CodeChallengeMethodS256: BASE64URL(SHA256(verifier)), unpadded, must
//     equal the challenge
//
// An empty method is treated as plain, the RFC 7636 §4.3 default when the
// client omits the parameter. Any other method is rejected. The final
// comparison is constant-time so timing does not reveal partial matches.
//
// Returns nil on success, or an error describing why verification failed.
func validatePKCE(codeChallenge, method, codeVerifier string) error {
	if codeVerifier == "" {
		return fmt.Errorf("code verifier required")
	}
	if method == "" {
		method = CodeChallengeMethodPlain
	}

	var derived string
	switch method {
	case CodeChallengeMethodPlain:
		derived = codeVerifier
	case CodeChallengeMethodS256:
		h := sha256.Sum256([]byte(codeVerifier))
		derived = base64.RawURLEncoding.EncodeToString(h[:])
	default:
		return fmt.Errorf("unsupported code challenge method")
	}

	if subtle.ConstantTimeCompare([]byte(derived), []byte(codeChallenge)) != 1 {
		return fmt.Errorf("invalid code verifier")
	}
	return nil
}