|
|
//
|
|
|
// Copyright (C) 2024 veypi <i@veypi.com>
|
|
|
// 2025-07-24 15:27:31
|
|
|
// Distributed under terms of the MIT license.
|
|
|
//
|
|
|
|
|
|
package oauth
|
|
|
|
|
|
import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"time"

	"github.com/veypi/OneAuth/cfg"
	"github.com/vyes/vigo"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)
|
|
|
|
|
|
// TokenRequest holds the parameters of a token endpoint request, bound from
// form fields by x.Parse in handleToken. Which fields are consulted depends
// on GrantType.
type TokenRequest struct {
	// GrantType selects the grant handler in handleToken. The binding tag
	// marks it required — presumably enforced by the request parser; confirm.
	GrantType    string `form:"grant_type" binding:"required"`
	// Code is the authorization code exchanged by the authorization-code grant.
	Code         string `form:"code"`
	// RedirectURI must equal the redirect URI stored with the authorization code.
	RedirectURI  string `form:"redirect_uri"`
	// ClientID identifies the OAuth client; every grant looks the client up by it.
	ClientID     string `form:"client_id"`
	// ClientSecret is compared against the stored secret for non-public clients.
	ClientSecret string `form:"client_secret"`
	// RefreshToken is the token exchanged by the refresh-token grant.
	RefreshToken string `form:"refresh_token"`
	// CodeVerifier is the PKCE verifier, checked when the authorization code
	// was issued with a code challenge.
	CodeVerifier string `form:"code_verifier"`
	Username     string `form:"username"` // for password grant
	Password     string `form:"password"` // for password grant
	Scope        string `form:"scope"`    // for password grant
}
|
|
|
|
|
|
// TokenResponse is the JSON body the grant handlers return after issuing
// tokens.
type TokenResponse struct {
	// AccessToken is the newly issued access token value.
	AccessToken  string `json:"access_token"`
	// TokenType is set to TokenTypeBearer by the grant handlers.
	TokenType    string `json:"token_type"`
	// ExpiresIn is the access token lifetime in seconds; the handlers derive
	// it from DefaultAccessTokenExpiry.
	ExpiresIn    int64  `json:"expires_in"`
	// RefreshToken is the refresh token value; omitted from the JSON when empty.
	RefreshToken string `json:"refresh_token,omitempty"`
	// Scope is the scope attached to the issued tokens; omitted when empty.
	Scope        string `json:"scope,omitempty"`
}
|
|
|
|
|
|
// handleToken 处理OAuth令牌请求
|
|
|
func handleToken(x *vigo.X) error {
|
|
|
args := &TokenRequest{}
|
|
|
if err := x.Parse(args); err != nil {
|
|
|
return vigo.NewError("参数解析失败").WithError(err).WithCode(400)
|
|
|
}
|
|
|
|
|
|
db := cfg.DB()
|
|
|
|
|
|
switch args.GrantType {
|
|
|
case GrantTypeAuthorizationCode:
|
|
|
return handleAuthorizationCodeGrant(db, x, args)
|
|
|
case GrantTypeRefreshToken:
|
|
|
return handleRefreshTokenGrant(db, x, args)
|
|
|
case GrantTypePassword:
|
|
|
return handlePasswordGrant(db, x, args)
|
|
|
default:
|
|
|
return vigo.NewError("不支持的授权类型").WithCode(400)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// handleAuthorizationCodeGrant handles the authorization-code grant: it
// exchanges a one-time authorization code for an access/refresh token pair.
//
// Checks, in order: the code exists and is unused, it has not expired, the
// presented client_id is the client the code was issued to, non-public
// clients present the matching client_secret, redirect_uri matches the one
// stored with the code, and the PKCE verifier is valid when the code carries
// a challenge. Validation failures return HTTP 400; storage failures 500.
//
// The code is redeemed with a conditional UPDATE (used false -> true) inside
// the same transaction that issues the tokens. Two concurrent requests
// therefore cannot both redeem one code, and a failure while issuing tokens
// rolls the redemption back instead of burning the code.
func handleAuthorizationCodeGrant(db *gorm.DB, x *vigo.X, args *TokenRequest) error {
	// 1. Look up an unused authorization code.
	if args.Code == "" {
		return vigo.NewError("无效的授权码").WithCode(400)
	}
	var authCode OAuthAuthorizationCode
	if err := db.Where("code = ? AND used = ?", args.Code, false).First(&authCode).Error; err != nil {
		return vigo.NewError("无效的授权码").WithCode(400)
	}

	// 2. Reject expired codes.
	if authCode.IsExpired() {
		return vigo.NewError("授权码已过期").WithCode(400)
	}

	// 3. The presented client_id must belong to the client the code was issued to.
	var client OAuthClient
	if err := db.Where("id = ? AND client_id = ?", authCode.ClientID, args.ClientID).First(&client).Error; err != nil {
		return vigo.NewError("无效的客户端").WithCode(400)
	}

	// 4. Non-public (confidential) clients must present their secret.
	// ConstantTimeCompare avoids leaking through response timing how much
	// of a guessed secret matched.
	if !client.IsPublic && subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(args.ClientSecret)) != 1 {
		return vigo.NewError("无效的客户端凭据").WithCode(400)
	}

	// 5. redirect_uri must match the one stored with the code.
	if authCode.RedirectURI != args.RedirectURI {
		return vigo.NewError("重定向URI不匹配").WithCode(400)
	}

	// 6. Verify PKCE when the code was issued with a challenge.
	if authCode.CodeChallenge != "" {
		if err := validatePKCE(authCode.CodeChallenge, authCode.CodeChallengeMethod, args.CodeVerifier); err != nil {
			return vigo.NewError("PKCE验证失败").WithError(err).WithCode(400)
		}
	}

	// 7-9. Redeem the code and issue tokens atomically.
	var (
		resp    *TokenResponse
		failure error // API error raised inside the transaction; returned as-is
	)
	txErr := db.Transaction(func(tx *gorm.DB) error {
		// Model(&authCode) scopes the UPDATE to this code's primary key; the
		// extra "used = false" condition lets only one request flip it.
		res := tx.Model(&authCode).Where("used = ?", false).Update("used", true)
		if res.Error != nil {
			failure = vigo.NewError("授权码更新失败").WithError(res.Error).WithCode(500)
			return failure
		}
		if res.RowsAffected == 0 {
			// Another request redeemed the code after our lookup.
			failure = vigo.NewError("无效的授权码").WithCode(400)
			return failure
		}

		accessToken, err := generateAccessToken(tx, &client, authCode.UserID, authCode.Scope)
		if err != nil {
			failure = vigo.NewError("访问令牌生成失败").WithError(err).WithCode(500)
			return failure
		}

		refreshToken, err := generateRefreshToken(tx, accessToken, &client, authCode.UserID, authCode.Scope)
		if err != nil {
			failure = vigo.NewError("刷新令牌生成失败").WithError(err).WithCode(500)
			return failure
		}

		resp = &TokenResponse{
			AccessToken:  accessToken.Token,
			TokenType:    TokenTypeBearer,
			ExpiresIn:    int64(DefaultAccessTokenExpiry.Seconds()),
			RefreshToken: refreshToken.Token,
			Scope:        authCode.Scope,
		}
		return nil
	})
	if failure != nil {
		return failure
	}
	if txErr != nil {
		// Begin/commit failure: nothing was persisted.
		return vigo.NewError("令牌签发失败").WithError(txErr).WithCode(500)
	}

	return x.JSON(resp)
}
|
|
|
|
|
|
// handleRefreshTokenGrant handles the refresh-token grant: it exchanges a
// valid refresh token for a new access token.
//
// The refresh token must exist, be unrevoked and unexpired, and belong to
// the client identified by client_id. Non-public clients must also present
// the matching client_secret, as for the other grants (RFC 6749 §6 requires
// confidential clients to authenticate on refresh). Validation failures
// return HTTP 400; storage failures 500.
//
// On success, three steps run in one transaction:
//   - the access token currently linked to the refresh token is revoked;
//   - a new access token is issued;
//   - the refresh token is re-linked to the new access token.
//
// The refresh token value itself is reused and returned unchanged. The
// re-link is conditional on the old link, so two concurrent refreshes of
// the same token cannot both succeed.
func handleRefreshTokenGrant(db *gorm.DB, x *vigo.X, args *TokenRequest) error {
	// 1. Look up an unrevoked refresh token.
	if args.RefreshToken == "" {
		return vigo.NewError("无效的刷新令牌").WithCode(400)
	}
	var refreshToken OAuthRefreshToken
	if err := db.Where("token = ? AND revoked = ?", args.RefreshToken, false).First(&refreshToken).Error; err != nil {
		return vigo.NewError("无效的刷新令牌").WithCode(400)
	}

	// 2. Reject expired refresh tokens.
	if refreshToken.IsExpired() {
		return vigo.NewError("刷新令牌已过期").WithCode(400)
	}

	// 3. The presented client_id must belong to the token's client.
	var client OAuthClient
	if err := db.Where("id = ? AND client_id = ?", refreshToken.ClientID, args.ClientID).First(&client).Error; err != nil {
		return vigo.NewError("无效的客户端").WithCode(400)
	}

	// 4. Non-public (confidential) clients must present their secret;
	// compared in constant time to avoid a timing oracle.
	if !client.IsPublic && subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(args.ClientSecret)) != 1 {
		return vigo.NewError("无效的客户端凭据").WithCode(400)
	}

	oldAccessTokenID := refreshToken.AccessTokenID
	var (
		accessToken *OAuthAccessToken
		failure     error // API error raised inside the transaction; returned as-is
	)
	txErr := db.Transaction(func(tx *gorm.DB) error {
		// 5. Revoke the access token currently paired with this refresh token.
		if err := tx.Model(&OAuthAccessToken{}).Where("id = ?", oldAccessTokenID).Update("revoked", true).Error; err != nil {
			failure = vigo.NewError("旧令牌撤销失败").WithError(err).WithCode(500)
			return failure
		}

		// 6. Issue the replacement access token.
		issued, err := generateAccessToken(tx, &client, refreshToken.UserID, refreshToken.Scope)
		if err != nil {
			failure = vigo.NewError("访问令牌生成失败").WithError(err).WithCode(500)
			return failure
		}

		// 7. Re-link the refresh token. Model(&refreshToken) scopes the
		// UPDATE to its primary key. The extra condition on the old link
		// detects a concurrent refresh that already re-linked it.
		res := tx.Model(&refreshToken).Where("access_token_id = ?", oldAccessTokenID).Update("access_token_id", issued.ID)
		if res.Error != nil {
			failure = vigo.NewError("刷新令牌更新失败").WithError(res.Error).WithCode(500)
			return failure
		}
		if res.RowsAffected == 0 {
			failure = vigo.NewError("无效的刷新令牌").WithCode(400)
			return failure
		}

		accessToken = issued
		return nil
	})
	if failure != nil {
		return failure
	}
	if txErr != nil {
		// Begin/commit failure: nothing was persisted.
		return vigo.NewError("令牌签发失败").WithError(txErr).WithCode(500)
	}

	return x.JSON(&TokenResponse{
		AccessToken:  accessToken.Token,
		TokenType:    TokenTypeBearer,
		ExpiresIn:    int64(DefaultAccessTokenExpiry.Seconds()),
		RefreshToken: refreshToken.Token,
		Scope:        refreshToken.Scope,
	})
}
|
|
|
|
|
|
// handlePasswordGrant handles the resource-owner password grant.
//
// It requires a username and password and looks up the client by client_id.
// Non-public clients must present the matching client_secret, compared in
// constant time. It then verifies the user's password against the stored
// bcrypt hash and issues an access/refresh token pair for the requested
// scope. Validation failures return HTTP 400; storage failures 500.
//
// Both tokens are created in one transaction, so a failure creating the
// refresh token does not leave an orphaned, still-valid access token.
func handlePasswordGrant(db *gorm.DB, x *vigo.X, args *TokenRequest) error {
	// 1. Required parameters.
	if args.Username == "" || args.Password == "" {
		return vigo.NewError("用户名和密码不能为空").WithCode(400)
	}

	// 2. Look up the client.
	var client OAuthClient
	if err := db.Where("client_id = ?", args.ClientID).First(&client).Error; err != nil {
		return vigo.NewError("无效的客户端").WithCode(400)
	}

	// 3. Non-public (confidential) clients must present their secret.
	// ConstantTimeCompare avoids leaking through response timing how much
	// of a guessed secret matched.
	if !client.IsPublic && subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(args.ClientSecret)) != 1 {
		return vigo.NewError("无效的客户端凭据").WithCode(400)
	}

	// 4. Look up the user. "No such user" and "wrong password" share one
	// message so the response body does not reveal which check failed.
	// NOTE(review): the unknown-user path skips bcrypt and so returns
	// faster, which may still reveal whether a username exists.
	var user User
	if err := db.Where("username = ?", args.Username).First(&user).Error; err != nil {
		return vigo.NewError("用户名或密码错误").WithCode(400)
	}

	// 5. Verify the password against the stored bcrypt hash.
	if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(args.Password)); err != nil {
		return vigo.NewError("用户名或密码错误").WithCode(400)
	}

	// 6. The requested scope is used verbatim. An empty scope is passed
	// through unchanged; no default scope is applied.
	scope := args.Scope

	// 7-8. Issue the token pair atomically.
	var (
		resp    *TokenResponse
		failure error // API error raised inside the transaction; returned as-is
	)
	txErr := db.Transaction(func(tx *gorm.DB) error {
		accessToken, err := generateAccessToken(tx, &client, user.ID, scope)
		if err != nil {
			failure = vigo.NewError("访问令牌生成失败").WithError(err).WithCode(500)
			return failure
		}

		refreshToken, err := generateRefreshToken(tx, accessToken, &client, user.ID, scope)
		if err != nil {
			failure = vigo.NewError("刷新令牌生成失败").WithError(err).WithCode(500)
			return failure
		}

		resp = &TokenResponse{
			AccessToken:  accessToken.Token,
			TokenType:    TokenTypeBearer,
			ExpiresIn:    int64(DefaultAccessTokenExpiry.Seconds()),
			RefreshToken: refreshToken.Token,
			Scope:        scope,
		}
		return nil
	})
	if failure != nil {
		return failure
	}
	if txErr != nil {
		// Begin/commit failure: nothing was persisted.
		return vigo.NewError("令牌签发失败").WithError(txErr).WithCode(500)
	}

	return x.JSON(resp)
}
|
|
|
|
|
|
// Helper functions.

// generateAccessToken issues a new access token for client and userID with
// the given scope, and persists it through db.
//
// The token value comes from generateRandomString(64). The record expires
// DefaultAccessTokenExpiry from now and starts unrevoked. It returns the
// stored record, or nil plus the error from token generation or db.Create.
func generateAccessToken(db *gorm.DB, client *OAuthClient, userID, scope string) (*OAuthAccessToken, error) {
	value, genErr := generateRandomString(64)
	if genErr != nil {
		return nil, genErr
	}

	tok := new(OAuthAccessToken)
	tok.Token = value
	tok.ClientID = client.ID
	tok.UserID = userID
	tok.Scope = scope
	tok.ExpiresAt = time.Now().Add(DefaultAccessTokenExpiry)
	tok.Revoked = false

	if res := db.Create(tok); res.Error != nil {
		return nil, res.Error
	}
	return tok, nil
}
|
|
|
|
|
|
// generateRefreshToken issues a refresh token paired with accessToken for
// client and userID with the given scope, and persists it through db.
//
// The token value comes from generateRandomString(64). The record links to
// accessToken.ID, expires DefaultRefreshTokenExpiry from now, and starts
// unrevoked. It returns the stored record, or nil plus the error from token
// generation or db.Create.
func generateRefreshToken(db *gorm.DB, accessToken *OAuthAccessToken, client *OAuthClient, userID, scope string) (*OAuthRefreshToken, error) {
	value, genErr := generateRandomString(64)
	if genErr != nil {
		return nil, genErr
	}

	rt := new(OAuthRefreshToken)
	rt.Token = value
	rt.AccessTokenID = accessToken.ID
	rt.ClientID = client.ID
	rt.UserID = userID
	rt.Scope = scope
	rt.ExpiresAt = time.Now().Add(DefaultRefreshTokenExpiry)
	rt.Revoked = false

	if res := db.Create(rt); res.Error != nil {
		return nil, res.Error
	}
	return rt, nil
}
|
|
|
|
|
|
func validatePKCE(codeChallenge, method, codeVerifier string) error {
|
|
|
if codeVerifier == "" {
|
|
|
return fmt.Errorf("code verifier required")
|
|
|
}
|
|
|
|
|
|
switch method {
|
|
|
case CodeChallengeMethodPlain:
|
|
|
if codeChallenge != codeVerifier {
|
|
|
return fmt.Errorf("invalid code verifier")
|
|
|
}
|
|
|
case CodeChallengeMethodS256:
|
|
|
h := sha256.Sum256([]byte(codeVerifier))
|
|
|
expected := base64.RawURLEncoding.EncodeToString(h[:])
|
|
|
if codeChallenge != expected {
|
|
|
return fmt.Errorf("invalid code verifier")
|
|
|
}
|
|
|
default:
|
|
|
return fmt.Errorf("unsupported code challenge method")
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
}
|