You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/oauth/token.go

295 lines
9.1 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-07-24 15:27:31
// Distributed under terms of the MIT license.
//
package oauth
import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"time"

	"github.com/veypi/OneAuth/cfg"
	"github.com/vyes/vigo"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)
// TokenRequest carries the form fields posted to the token endpoint.
// Only GrantType is mandatory; which of the other fields are read depends on
// the grant type being requested.
type TokenRequest struct {
	// GrantType selects the grant handler that processes the request.
	GrantType string `form:"grant_type" binding:"required"`

	// Authorization-code grant: the code being redeemed and the redirect URI
	// that must match the one stored with it.
	Code        string `form:"code"`
	RedirectURI string `form:"redirect_uri"`

	// Client identifier and client secret supplied by the caller.
	ClientID     string `form:"client_id"`
	ClientSecret string `form:"client_secret"`

	// Refresh-token grant: the refresh token being exchanged.
	RefreshToken string `form:"refresh_token"`

	// Authorization-code grant: the PKCE code verifier.
	CodeVerifier string `form:"code_verifier"`

	// Password grant: resource-owner credentials and the requested scope.
	Username string `form:"username"`
	Password string `form:"password"`
	Scope    string `form:"scope"`
}
// TokenResponse is the JSON body returned by a successful token request.
//
//   - AccessToken: the newly issued access token.
//   - TokenType: the token type (the handlers in this file set TokenTypeBearer).
//   - ExpiresIn: access-token lifetime in seconds.
//   - RefreshToken: the refresh token; omitted from the JSON when empty.
//   - Scope: the granted scope; omitted from the JSON when empty.
type TokenResponse struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int64  `json:"expires_in"`
	RefreshToken string `json:"refresh_token,omitempty"`
	Scope        string `json:"scope,omitempty"`
}
// handleToken serves the OAuth token endpoint.
//
// It parses the request form into a TokenRequest (a parse failure yields
// HTTP 400), obtains the database handle, and hands both to the handler that
// matches grant_type. An unrecognised grant type is rejected with HTTP 400.
func handleToken(x *vigo.X) error {
	req := new(TokenRequest)
	if parseErr := x.Parse(req); parseErr != nil {
		return vigo.NewError("参数解析失败").WithError(parseErr).WithCode(400)
	}

	conn := cfg.DB()

	var grant func(*gorm.DB, *vigo.X, *TokenRequest) error
	switch req.GrantType {
	case GrantTypeAuthorizationCode:
		grant = handleAuthorizationCodeGrant
	case GrantTypeRefreshToken:
		grant = handleRefreshTokenGrant
	case GrantTypePassword:
		grant = handlePasswordGrant
	default:
		return vigo.NewError("不支持的授权类型").WithCode(400)
	}
	return grant(conn, x, req)
}
// handleAuthorizationCodeGrant implements the authorization_code grant.
//
// It exchanges args.Code for a new access/refresh token pair after checking that:
//   - the code exists, is unused and has not expired;
//   - it was issued to the client identified by args.ClientID;
//   - non-public clients present the matching secret (constant-time compare);
//   - args.RedirectURI equals the redirect URI stored with the code;
//   - args.CodeVerifier satisfies the PKCE challenge, if one was recorded.
//
// The code is consumed with a conditional UPDATE, so two concurrent requests
// cannot both redeem it. Validation failures return HTTP 400; storage
// failures return HTTP 500.
func handleAuthorizationCodeGrant(db *gorm.DB, x *vigo.X, args *TokenRequest) error {
	// 1. Look up the code; codes already marked as used are not matched.
	var authCode OAuthAuthorizationCode
	if err := db.Where("code = ? AND used = ?", args.Code, false).First(&authCode).Error; err != nil {
		return vigo.NewError("无效的授权码").WithCode(400)
	}

	// 2. Reject expired codes.
	if authCode.IsExpired() {
		return vigo.NewError("授权码已过期").WithCode(400)
	}

	// 3. The code must belong to the requesting client. authCode.ClientID holds
	//    the client's row id; args.ClientID is its public client_id.
	var client OAuthClient
	if err := db.Where("id = ? AND client_id = ?", authCode.ClientID, args.ClientID).First(&client).Error; err != nil {
		return vigo.NewError("无效的客户端").WithCode(400)
	}

	// 4. Confidential clients must authenticate. A constant-time comparison
	//    keeps response latency from revealing how much of the secret matched.
	if !client.IsPublic && subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(args.ClientSecret)) != 1 {
		return vigo.NewError("无效的客户端凭据").WithCode(400)
	}

	// 5. redirect_uri must exactly match the value stored with the code.
	if authCode.RedirectURI != args.RedirectURI {
		return vigo.NewError("重定向URI不匹配").WithCode(400)
	}

	// 6. If a PKCE challenge was recorded, the verifier must satisfy it.
	if authCode.CodeChallenge != "" {
		if err := validatePKCE(authCode.CodeChallenge, authCode.CodeChallengeMethod, args.CodeVerifier); err != nil {
			return vigo.NewError("PKCE验证失败").WithError(err).WithCode(400)
		}
	}

	// 7. Consume the code atomically. The extra "used = false" condition turns
	//    the UPDATE into a compare-and-set, so only one of several concurrent
	//    requests can flip the flag (an authorization code must not be usable
	//    more than once).
	res := db.Model(&authCode).Where("used = ?", false).Update("used", true)
	if res.Error != nil {
		return vigo.NewError("授权码更新失败").WithError(res.Error).WithCode(500)
	}
	if res.RowsAffected == 0 {
		// Another request redeemed this code between step 1 and now.
		return vigo.NewError("无效的授权码").WithCode(400)
	}

	// 8. Issue the access token.
	accessToken, err := generateAccessToken(db, &client, authCode.UserID, authCode.Scope)
	if err != nil {
		return vigo.NewError("访问令牌生成失败").WithError(err).WithCode(500)
	}

	// 9. Issue a refresh token linked to that access token.
	refreshToken, err := generateRefreshToken(db, accessToken, &client, authCode.UserID, authCode.Scope)
	if err != nil {
		return vigo.NewError("刷新令牌生成失败").WithError(err).WithCode(500)
	}

	return x.JSON(&TokenResponse{
		AccessToken:  accessToken.Token,
		TokenType:    TokenTypeBearer,
		ExpiresIn:    int64(DefaultAccessTokenExpiry.Seconds()),
		RefreshToken: refreshToken.Token,
		Scope:        authCode.Scope,
	})
}
// handleRefreshTokenGrant implements the refresh_token grant.
//
// It does the following, in order:
//   - accepts a non-revoked, unexpired refresh token issued to the client
//     identified by args.ClientID;
//   - requires non-public clients to present their secret, compared in
//     constant time, as the other grants do (RFC 6749 §6 requires client
//     authentication for confidential clients);
//   - revokes the access token the refresh token currently points at;
//   - issues a new access token and re-links the refresh token to it.
//
// The refresh token is not rotated: the same string is returned in the
// response. Validation failures return HTTP 400; storage failures return HTTP 500.
func handleRefreshTokenGrant(db *gorm.DB, x *vigo.X, args *TokenRequest) error {
	// 1. Look up a refresh token that has not been revoked.
	var refreshToken OAuthRefreshToken
	if err := db.Where("token = ? AND revoked = ?", args.RefreshToken, false).First(&refreshToken).Error; err != nil {
		return vigo.NewError("无效的刷新令牌").WithCode(400)
	}

	// 2. Reject expired refresh tokens.
	if refreshToken.IsExpired() {
		return vigo.NewError("刷新令牌已过期").WithCode(400)
	}

	// 3. The token must belong to the requesting client (row id + client_id).
	var client OAuthClient
	if err := db.Where("id = ? AND client_id = ?", refreshToken.ClientID, args.ClientID).First(&client).Error; err != nil {
		return vigo.NewError("无效的客户端").WithCode(400)
	}

	// 4. Confidential clients must authenticate; without this check a leaked
	//    refresh token plus the (public) client_id would be enough to mint
	//    new access tokens.
	if !client.IsPublic && subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(args.ClientSecret)) != 1 {
		return vigo.NewError("无效的客户端凭据").WithCode(400)
	}

	// 5. Revoke the access token currently associated with this refresh token.
	if err := db.Model(&OAuthAccessToken{}).Where("id = ?", refreshToken.AccessTokenID).Update("revoked", true).Error; err != nil {
		return vigo.NewError("旧令牌撤销失败").WithError(err).WithCode(500)
	}

	// 6. Issue a replacement access token with the same user and scope.
	accessToken, err := generateAccessToken(db, &client, refreshToken.UserID, refreshToken.Scope)
	if err != nil {
		return vigo.NewError("访问令牌生成失败").WithError(err).WithCode(500)
	}

	// 7. Point the refresh token at the new access token.
	if err := db.Model(&refreshToken).Update("access_token_id", accessToken.ID).Error; err != nil {
		return vigo.NewError("刷新令牌更新失败").WithError(err).WithCode(500)
	}

	return x.JSON(&TokenResponse{
		AccessToken:  accessToken.Token,
		TokenType:    TokenTypeBearer,
		ExpiresIn:    int64(DefaultAccessTokenExpiry.Seconds()),
		RefreshToken: refreshToken.Token,
		Scope:        refreshToken.Scope,
	})
}
// handlePasswordGrant implements the resource-owner password credentials grant.
//
// It does the following, in order:
//   - requires a non-empty username and password;
//   - resolves the client by args.ClientID, and for non-public clients checks
//     args.ClientSecret in constant time;
//   - verifies the password against the user's bcrypt hash;
//   - issues an access/refresh token pair for args.Scope.
//
// Validation failures return HTTP 400; storage failures return HTTP 500.
func handlePasswordGrant(db *gorm.DB, x *vigo.X, args *TokenRequest) error {
	// 1. Both credentials are mandatory.
	if args.Username == "" || args.Password == "" {
		return vigo.NewError("用户名和密码不能为空").WithCode(400)
	}

	// 2. Resolve the client by its public client_id.
	var client OAuthClient
	if err := db.Where("client_id = ?", args.ClientID).First(&client).Error; err != nil {
		return vigo.NewError("无效的客户端").WithCode(400)
	}

	// 3. Confidential clients must authenticate; a constant-time comparison
	//    avoids leaking how many leading bytes of the secret matched.
	if !client.IsPublic && subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(args.ClientSecret)) != 1 {
		return vigo.NewError("无效的客户端凭据").WithCode(400)
	}

	// 4. Look up the user. "No such user" and "wrong password" share one
	//    message so the response body does not reveal which check failed.
	//    NOTE(review): this path returns before bcrypt runs, so response
	//    latency may still hint at whether a username exists.
	var user User
	if err := db.Where("username = ?", args.Username).First(&user).Error; err != nil {
		return vigo.NewError("用户名或密码错误").WithCode(400)
	}

	// 5. Verify the password against the stored bcrypt hash.
	if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(args.Password)); err != nil {
		return vigo.NewError("用户名或密码错误").WithCode(400)
	}

	// 6. The requested scope is granted as-is; an empty scope stays empty.
	//    TODO: apply a default scope here when none is requested (a
	//    DefaultScope fallback was sketched but left disabled).
	scope := args.Scope

	// 7. Issue the access token.
	accessToken, err := generateAccessToken(db, &client, user.ID, scope)
	if err != nil {
		return vigo.NewError("访问令牌生成失败").WithError(err).WithCode(500)
	}

	// 8. Issue a refresh token linked to that access token.
	refreshToken, err := generateRefreshToken(db, accessToken, &client, user.ID, scope)
	if err != nil {
		return vigo.NewError("刷新令牌生成失败").WithError(err).WithCode(500)
	}

	return x.JSON(&TokenResponse{
		AccessToken:  accessToken.Token,
		TokenType:    TokenTypeBearer,
		ExpiresIn:    int64(DefaultAccessTokenExpiry.Seconds()),
		RefreshToken: refreshToken.Token,
		Scope:        scope,
	})
}
// 辅助函数
// generateAccessToken creates and stores a new, non-revoked access token for
// userID on behalf of client, carrying the given scope and expiring
// DefaultAccessTokenExpiry from now. The token string comes from
// generateRandomString(64). It returns the stored record, or nil together
// with the first error from token generation or the insert.
func generateAccessToken(db *gorm.DB, client *OAuthClient, userID, scope string) (*OAuthAccessToken, error) {
	value, err := generateRandomString(64)
	if err != nil {
		return nil, err
	}

	record := new(OAuthAccessToken)
	record.Token = value
	record.ClientID = client.ID
	record.UserID = userID
	record.Scope = scope
	record.ExpiresAt = time.Now().Add(DefaultAccessTokenExpiry)
	record.Revoked = false

	if err = db.Create(record).Error; err != nil {
		return nil, err
	}
	return record, nil
}
// generateRefreshToken creates and stores a new, non-revoked refresh token for
// userID and client. It is linked to accessToken through AccessTokenID,
// carries the given scope, and expires DefaultRefreshTokenExpiry from now.
// The token string comes from generateRandomString(64). It returns the stored
// record, or nil together with the first error from token generation or the insert.
func generateRefreshToken(db *gorm.DB, accessToken *OAuthAccessToken, client *OAuthClient, userID, scope string) (*OAuthRefreshToken, error) {
	value, err := generateRandomString(64)
	if err != nil {
		return nil, err
	}

	record := new(OAuthRefreshToken)
	record.Token = value
	record.AccessTokenID = accessToken.ID
	record.ClientID = client.ID
	record.UserID = userID
	record.Scope = scope
	record.ExpiresAt = time.Now().Add(DefaultRefreshTokenExpiry)
	record.Revoked = false

	if err = db.Create(record).Error; err != nil {
		return nil, err
	}
	return record, nil
}
// validatePKCE checks a PKCE code_verifier against the code_challenge stored
// with an authorization code.
//
// The method argument selects how the verifier is compared:
//   - CodeChallengeMethodS256: codeChallenge is compared with the unpadded
//     base64url encoding of SHA-256(codeVerifier).
//   - CodeChallengeMethodPlain: codeChallenge is compared with codeVerifier
//     directly.
//   - "" (empty): treated as plain, because RFC 7636 §4.3 makes "plain" the
//     default when the client omits code_challenge_method.
//
// Any other method is rejected. The final comparison is constant-time. It
// returns nil on success, or an error describing why verification failed.
func validatePKCE(codeChallenge, method, codeVerifier string) error {
	if codeVerifier == "" {
		return fmt.Errorf("code verifier required")
	}
	if method == "" {
		method = CodeChallengeMethodPlain
	}

	var expected string
	switch method {
	case CodeChallengeMethodPlain:
		expected = codeVerifier
	case CodeChallengeMethodS256:
		sum := sha256.Sum256([]byte(codeVerifier))
		expected = base64.RawURLEncoding.EncodeToString(sum[:])
	default:
		return fmt.Errorf("unsupported code challenge method")
	}

	if subtle.ConstantTimeCompare([]byte(codeChallenge), []byte(expected)) != 1 {
		return fmt.Errorf("invalid code verifier")
	}
	return nil
}