You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/internal/api/oauth/oauth.go

586 lines
18 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package oauth
import (
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"encoding/hex"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/veypi/vbase/internal/model"
	"github.com/veypi/vigo"
)
// AuthorizeRequest holds the parameters accepted by Authorize
// (GET /oauth/authorize). Every field is bound from the URL query string
// (src:"query"); desc tags carry the human-readable field descriptions.
type AuthorizeRequest struct {
	// ResponseType selects the flow: "code" (authorization code) or "token" (implicit).
	ResponseType string `json:"response_type" src:"query" desc:"响应类型: code/token"`
	// ClientID identifies the registered OAuth client.
	ClientID string `json:"client_id" src:"query" desc:"客户端ID"`
	// RedirectURI must equal one of the client's registered redirect URIs;
	// when empty, Authorize falls back to the first registered URI.
	RedirectURI string `json:"redirect_uri" src:"query" desc:"回调地址"`
	// Scope is the space-separated list of requested scopes.
	Scope string `json:"scope" src:"query" desc:"请求的权限范围"`
	// State is an opaque value echoed back on redirect (CSRF protection).
	State string `json:"state" src:"query" desc:"状态值(防CSRF)"`
	// CodeChallenge is the PKCE code challenge stored with the authorization code.
	CodeChallenge string `json:"code_challenge" src:"query" desc:"PKCE挑战码"`
	// CodeChallengeMethod is the PKCE transform, "S256" or "plain".
	CodeChallengeMethod string `json:"code_challenge_method" src:"query" desc:"PKCE方法: S256/plain"`
}
// AuthorizeResponse describes the authorization response parameters that are
// sent back to the client on redirect.
// NOTE(review): not referenced in this file — Authorize writes code and state
// directly into the redirect URL's query string.
type AuthorizeResponse struct {
	// Code is the issued authorization code.
	Code string `json:"code"`
	// State echoes the state value from the authorization request.
	State string `json:"state"`
}
// Authorize implements the OAuth 2.0 authorization endpoint.
// GET /oauth/authorize
//
// Validation happens in two phases:
//   - Client and redirect URI. Until the redirect URI has been verified
//     against the client's registered list, errors are returned directly and
//     never redirected (RFC 6749 §4.1.2.1). Redirecting to an unverified URI
//     would turn this endpoint into an open redirector.
//   - response_type, scope and PKCE parameters. These errors are reported to
//     the verified redirect URI through oauthError, together with state.
//
// Unauthenticated users (no "current_user" string on x) are redirected to
// /login carrying the original request URL. Otherwise an authorization code
// is issued (response_type=code), or the implicit flow is delegated to
// handleImplicitGrant (response_type=token).
func Authorize(x *vigo.X, req *AuthorizeRequest) error {
	if req.ClientID == "" {
		return oauthError(x, "", "invalid_request", "client_id is required", "")
	}

	var client model.OAuthClient
	if err := model.DB.Where("client_id = ? AND status = ?", req.ClientID, 1).First(&client).Error; err != nil {
		return oauthError(x, "", "invalid_client", "client not found", "")
	}

	redirectURI, problem := resolveAuthorizeRedirectURI(client.RedirectURIs, req.RedirectURI)
	if problem != "" {
		return oauthError(x, "", "invalid_request", problem, "")
	}
	// generateAuthorizationCode and handleImplicitGrant read the redirect URI
	// from req, so store the resolved value there.
	req.RedirectURI = redirectURI

	// From here on the redirect URI is trusted, so errors may be redirected.
	if req.ResponseType == "" {
		return oauthError(x, req.RedirectURI, "invalid_request", "response_type is required", req.State)
	}
	// Exact membership, not a substring test: "token" must not be accepted
	// just because the registered list contains e.g. "id_token".
	if !contains(splitRegisteredList(client.ResponseTypes), req.ResponseType) {
		return oauthError(x, req.RedirectURI, "unsupported_response_type", "", req.State)
	}

	allowedScopes := parseScopes(client.AllowedScopes)
	for _, scope := range parseScopes(req.Scope) {
		if !contains(allowedScopes, scope) {
			return oauthError(x, req.RedirectURI, "invalid_scope", "scope not allowed: "+scope, req.State)
		}
	}

	if req.CodeChallenge != "" {
		switch req.CodeChallengeMethod {
		case "":
			// RFC 7636 §4.3: the method defaults to "plain" when omitted.
			// Persisting "" would make the later verifier check impossible.
			req.CodeChallengeMethod = "plain"
		case "plain", "S256":
		default:
			return oauthError(x, req.RedirectURI, "invalid_request", "unsupported code_challenge_method", req.State)
		}
	}

	userID, _ := x.Get("current_user").(string)
	if userID == "" {
		// Not logged in: send the user to the login page and come back here afterwards.
		loginURL := "/login?redirect=" + url.QueryEscape(x.Request.URL.String())
		x.ResponseWriter().Header().Set("Location", loginURL)
		x.ResponseWriter().WriteHeader(http.StatusFound)
		return nil
	}

	// Organization: taken from the X-Org-ID header, else the client's own org.
	// NOTE(review): X-Org-ID is caller-controlled and is not checked against the
	// user's organization memberships here — confirm this is enforced elsewhere.
	orgID := x.Request.Header.Get("X-Org-ID")
	if orgID == "" {
		orgID = client.OrgID
	}

	switch req.ResponseType {
	case model.ResponseTypeCode:
		code, err := generateAuthorizationCode(&client, userID, orgID, req)
		if err != nil {
			return oauthError(x, req.RedirectURI, "server_error", "", req.State)
		}
		redirectURL, err := url.Parse(req.RedirectURI)
		if err != nil {
			// Unreachable after resolveAuthorizeRedirectURI, but never dereference a nil URL.
			return oauthError(x, "", "server_error", "invalid redirect_uri", "")
		}
		q := redirectURL.Query()
		q.Set("code", code)
		if req.State != "" {
			q.Set("state", req.State)
		}
		redirectURL.RawQuery = q.Encode()
		x.ResponseWriter().Header().Set("Location", redirectURL.String())
		x.ResponseWriter().WriteHeader(http.StatusFound)
		return nil
	case model.ResponseTypeToken:
		// Implicit flow returns the token directly; it is less secure and
		// should only be used when necessary.
		return handleImplicitGrant(x, &client, userID, orgID, req)
	default:
		return oauthError(x, req.RedirectURI, "unsupported_response_type", "", req.State)
	}
}

// resolveAuthorizeRedirectURI chooses the redirect URI for an authorization
// request from the client's comma-separated registered list. A requested URI
// must exactly match one registered entry (entries are whitespace-trimmed).
// With no requested URI, the first non-empty registered entry is used. A
// non-empty problem string describes why no usable URI could be determined.
func resolveAuthorizeRedirectURI(registered, requested string) (uri, problem string) {
	var candidates []string
	for _, entry := range strings.Split(registered, ",") {
		if entry = strings.TrimSpace(entry); entry != "" {
			candidates = append(candidates, entry)
		}
	}

	switch {
	case requested != "":
		if !contains(candidates, requested) {
			return "", "redirect_uri mismatch"
		}
		uri = requested
	case len(candidates) > 0:
		uri = candidates[0]
	default:
		return "", "redirect_uri required"
	}

	if _, err := url.Parse(uri); err != nil {
		return "", "invalid redirect_uri"
	}
	return uri, ""
}

// splitRegisteredList splits a stored multi-value client setting (such as
// OAuthClient.ResponseTypes) into its individual values. Commas, whitespace
// and JSON array punctuation are all treated as separators, so "a,b", "a b"
// and `["a","b"]` all yield [a b].
// NOTE(review): confirm the actual storage format of these columns.
func splitRegisteredList(list string) []string {
	return strings.FieldsFunc(list, func(r rune) bool {
		return strings.ContainsRune(", \t\r\n\"[]", r)
	})
}
// TokenRequest holds the form parameters accepted by Token (POST /oauth/token).
// Every field is bound from the form body (src:"form"). Which fields matter
// depends on GrantType.
type TokenRequest struct {
	// GrantType selects the grant handler; it must be one of the model.GrantType*
	// values handled by Token and be enabled in the client's GrantTypes.
	GrantType string `json:"grant_type" src:"form" desc:"授权类型"`
	// Code is the authorization code (authorization_code grant).
	Code string `json:"code" src:"form" desc:"授权码"`
	// RedirectURI, when set, must equal the redirect URI stored with the code.
	RedirectURI string `json:"redirect_uri" src:"form" desc:"回调地址"`
	// ClientID and ClientSecret authenticate the client; authenticateClient
	// overwrites both from HTTP Basic credentials when those are present.
	ClientID     string `json:"client_id" src:"form" desc:"客户端ID"`
	ClientSecret string `json:"client_secret" src:"form" desc:"客户端密钥"`
	// RefreshToken is the token to exchange (refresh_token grant).
	RefreshToken string `json:"refresh_token" src:"form" desc:"刷新令牌"`
	// Scope is the space-separated requested scope (read by the client_credentials grant).
	Scope string `json:"scope" src:"form" desc:"权限范围"`
	// CodeVerifier is the PKCE verifier checked against the stored challenge.
	CodeVerifier string `json:"code_verifier" src:"form" desc:"PKCE验证器"`
	// Username and Password are for the password grant, which handlePasswordGrant rejects.
	Username string `json:"username" src:"form" desc:"用户名(密码模式)"`
	Password string `json:"password" src:"form" desc:"密码(密码模式)"`
}
// TokenResponse is the JSON body returned by the token endpoint.
type TokenResponse struct {
	// AccessToken is the issued bearer token.
	AccessToken string `json:"access_token"`
	// TokenType is always "Bearer" in this file.
	TokenType string `json:"token_type"`
	// ExpiresIn is the access-token lifetime in seconds.
	ExpiresIn int `json:"expires_in"`
	// RefreshToken is omitted for grants that do not issue one (client_credentials).
	RefreshToken string `json:"refresh_token,omitempty"`
	// Scope is the granted, space-separated scope.
	Scope string `json:"scope,omitempty"`
	// IDToken is reserved for an OpenID Connect ID token; nothing in this file sets it.
	IDToken string `json:"id_token,omitempty"` // OIDC
}
// Token is the OAuth 2.0 token endpoint.
// POST /oauth/token
//
// The calling client is authenticated first; any failure there is reported as
// invalid_client. The requested grant_type must be both known to this server
// and present in the client's GrantTypes setting. Otherwise the request is
// rejected with unsupported_grant_type. The matching grant handler then
// produces the response.
func Token(x *vigo.X, req *TokenRequest) (*TokenResponse, error) {
	client, authErr := authenticateClient(x, req)
	if authErr != nil {
		return nil, vigo.ErrNotAuthorized.WithString("invalid_client")
	}

	grantHandlers := map[string]func(*vigo.X, *model.OAuthClient, *TokenRequest) (*TokenResponse, error){
		model.GrantTypeAuthorizationCode: handleAuthorizationCodeGrant,
		model.GrantTypeRefreshToken:      handleRefreshTokenGrant,
		model.GrantTypeClientCredentials: handleClientCredentialsGrant,
		model.GrantTypePassword:          handlePasswordGrant,
	}

	enabled := strings.Contains(client.GrantTypes, req.GrantType)
	handle, known := grantHandlers[req.GrantType]
	if !enabled || !known {
		return nil, vigo.ErrArgInvalid.WithString("unsupported_grant_type")
	}
	return handle(x, client, req)
}
// handleAuthorizationCodeGrant exchanges an authorization code for an
// access/refresh token pair.
//
// The code must belong to the authenticated client, be unused and unexpired.
// It must match the redirect_uri when the request supplies one, and it must
// pass PKCE verification when a challenge was stored with it. The code is
// consumed with a conditional UPDATE (used=false → true). If two requests
// race with the same code, only the one whose UPDATE affects a row gets
// tokens; the other receives invalid_grant.
func handleAuthorizationCodeGrant(x *vigo.X, client *model.OAuthClient, req *TokenRequest) (*TokenResponse, error) {
	if req.Code == "" {
		return nil, vigo.ErrArgInvalid.WithString("code is required")
	}

	var auth model.OAuthAuthorization
	if err := model.DB.Where("code = ? AND client_id = ? AND used = ?", req.Code, client.ClientID, false).First(&auth).Error; err != nil {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant")
	}

	if time.Now().After(auth.ExpiresAt) {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant: code expired")
	}

	// NOTE(review): RFC 6749 §4.1.3 requires redirect_uri here whenever it was
	// sent to the authorization endpoint. That fact is not recorded, so an
	// omitted redirect_uri is accepted.
	if req.RedirectURI != "" && req.RedirectURI != auth.RedirectURI {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant: redirect_uri mismatch")
	}

	if auth.CodeChallenge != "" {
		if req.CodeVerifier == "" {
			return nil, vigo.ErrArgInvalid.WithString("invalid_grant: code_verifier required")
		}
		if !verifyPKCE(req.CodeVerifier, auth.CodeChallenge, auth.CodeChallengeMethod) {
			return nil, vigo.ErrArgInvalid.WithString("invalid_grant: code_verifier mismatch")
		}
	}

	// Consume the code atomically. The extra "used = false" condition makes the
	// UPDATE a no-op if another request has already redeemed this code.
	res := model.DB.Model(&auth).Where("used = ?", false).Updates(map[string]interface{}{
		"used":    true,
		"used_at": time.Now(),
	})
	if res.Error != nil {
		return nil, vigo.ErrInternalServer.WithError(res.Error)
	}
	if res.RowsAffected == 0 {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant")
	}

	tokenResp, err := generateTokenPair(client, auth.UserID, auth.OrgID, auth.Scope, client.TokenExpiry, client.RefreshExpiry)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	return tokenResp, nil
}
// handleRefreshTokenGrant implements the refresh_token grant.
//
// The presented refresh token must belong to the authenticated client and be
// neither revoked nor expired. Tokens are rotated: the old row is revoked
// (it also holds the old access token). A new pair is then issued with the
// same user, org and scope. Revocation is a conditional UPDATE, so two
// concurrent requests with the same refresh token cannot both obtain new
// tokens.
func handleRefreshTokenGrant(x *vigo.X, client *model.OAuthClient, req *TokenRequest) (*TokenResponse, error) {
	if req.RefreshToken == "" {
		return nil, vigo.ErrArgInvalid.WithString("refresh_token is required")
	}

	var token model.OAuthToken
	if err := model.DB.Where("refresh_token = ? AND client_id = ? AND revoked = ?", req.RefreshToken, client.ClientID, false).First(&token).Error; err != nil {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant")
	}

	// generateTokenPair stores only the access-token deadline in ExpiresAt.
	// Checking ExpiresAt alone would reject every refresh once the access
	// token had expired, which is precisely when clients need to refresh.
	// So extend the deadline by the client's extra refresh lifetime. This
	// assumes client.TokenExpiry is unchanged since issuance, and it never
	// makes the check stricter than ExpiresAt.
	deadline := token.ExpiresAt
	if extra := client.RefreshExpiry - client.TokenExpiry; extra > 0 {
		deadline = deadline.Add(time.Duration(extra) * time.Second)
	}
	if time.Now().After(deadline) {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant: token expired")
	}

	// Revoke the old token pair; only the request that flips revoked=false wins.
	res := model.DB.Model(&token).Where("revoked = ?", false).Updates(map[string]interface{}{
		"revoked":    true,
		"revoked_at": time.Now(),
	})
	if res.Error != nil {
		return nil, vigo.ErrInternalServer.WithError(res.Error)
	}
	if res.RowsAffected == 0 {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant")
	}

	tokenResp, err := generateTokenPair(client, token.UserID, token.OrgID, token.Scope, client.TokenExpiry, client.RefreshExpiry)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	return tokenResp, nil
}
// handleClientCredentialsGrant implements the client_credentials grant for
// service-to-service calls. The issued token has no user (UserID is empty)
// and no refresh token.
//
// A requested scope is granted only if both conditions hold:
//   - it is "service" or starts with "service:", i.e. it is not user-related;
//   - it appears in the client's AllowedScopes (the list Authorize enforces).
//
// Other requested scopes are dropped silently. If nothing remains, the token
// receives the baseline "service" scope. The granted scope is returned in the
// response.
func handleClientCredentialsGrant(x *vigo.X, client *model.OAuthClient, req *TokenRequest) (*TokenResponse, error) {
	allowedScopes := parseScopes(client.AllowedScopes)
	var granted []string
	for _, scope := range parseScopes(req.Scope) {
		isServiceScope := scope == "service" || strings.HasPrefix(scope, "service:")
		if isServiceScope && contains(allowedScopes, scope) {
			granted = append(granted, scope)
		}
	}
	if len(granted) == 0 {
		granted = []string{"service"}
	}
	scopeStr := strings.Join(granted, " ")

	accessToken := generateRandomToken(32)
	token := &model.OAuthToken{
		UserID:      "", // no user in the client-credentials flow
		ClientID:    client.ClientID,
		OrgID:       client.OrgID,
		AccessToken: accessToken,
		TokenType:   "Bearer",
		Scope:       scopeStr,
		ExpiresAt:   time.Now().Add(time.Duration(client.TokenExpiry) * time.Second),
	}
	if err := model.DB.Create(token).Error; err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	return &TokenResponse{
		AccessToken: accessToken,
		TokenType:   "Bearer",
		ExpiresIn:   client.TokenExpiry,
		Scope:       scopeStr,
	}, nil
}
// handlePasswordGrant is the resource-owner password grant handler. It is
// kept only so the grant type has a handler (the grant is discouraged). It is
// deliberately disabled: every request is rejected with unsupported_grant_type,
// and the credentials in req are never inspected.
func handlePasswordGrant(x *vigo.X, client *model.OAuthClient, req *TokenRequest) (*TokenResponse, error) {
	const disabledMsg = "unsupported_grant_type: password grant is disabled"
	return nil, vigo.ErrArgInvalid.WithString(disabledMsg)
}
// handleImplicitGrant completes the implicit flow (response_type=token).
// It persists a new access token (no refresh token) and then redirects with
// 302 to req.RedirectURI. The token, token_type, expires_in and the optional
// scope and state are placed in the URL fragment.
//
// req.RedirectURI is expected to have been validated by Authorize already.
func handleImplicitGrant(x *vigo.X, client *model.OAuthClient, userID, orgID string, req *AuthorizeRequest) error {
	// Parse before persisting anything, so a bad URI never leaves an orphan
	// token and never causes a nil dereference below.
	redirectURL, err := url.Parse(req.RedirectURI)
	if err != nil {
		return oauthError(x, "", "server_error", "invalid redirect_uri", "")
	}

	accessToken := generateRandomToken(32)
	token := &model.OAuthToken{
		UserID:      userID,
		ClientID:    client.ClientID,
		OrgID:       orgID,
		AccessToken: accessToken,
		TokenType:   "Bearer",
		Scope:       req.Scope,
		ExpiresAt:   time.Now().Add(time.Duration(client.TokenExpiry) * time.Second),
	}
	if err := model.DB.Create(token).Error; err != nil {
		return oauthError(x, req.RedirectURI, "server_error", "", req.State)
	}

	fragment := url.Values{}
	fragment.Set("access_token", accessToken)
	fragment.Set("token_type", "Bearer")
	// expires_in is a decimal count of seconds. string(rune(n)) would instead
	// yield the single Unicode character whose code point is n.
	fragment.Set("expires_in", strconv.Itoa(client.TokenExpiry))
	if req.Scope != "" {
		fragment.Set("scope", req.Scope)
	}
	if req.State != "" {
		fragment.Set("state", req.State)
	}

	redirectURL.Fragment = fragment.Encode()
	x.ResponseWriter().Header().Set("Location", redirectURL.String())
	x.ResponseWriter().WriteHeader(http.StatusFound)
	return nil
}
// RevokeRequest holds the form parameters accepted by Revoke (POST /oauth/revoke).
type RevokeRequest struct {
	// Token is the access or refresh token to revoke.
	Token string `json:"token" src:"form" desc:"要撤销的令牌"`
	// TokenTypeHint ("access_token" or "refresh_token") is accepted but not read by Revoke.
	TokenTypeHint string `json:"token_type_hint" src:"form" desc:"令牌类型提示: access_token/refresh_token"`
}
// Revoke implements the token revocation endpoint.
// POST /oauth/revoke
//
// The token is matched against both the access_token and refresh_token
// columns in one conditional UPDATE. token_type_hint is therefore not needed
// and is ignored. Rows that are already revoked are left untouched, which
// preserves their original revoked_at. Empty, unknown and already-revoked
// tokens all return success, as RFC 7009 requires for invalid tokens. Only a
// database failure is reported as an error.
//
// NOTE(review): RFC 7009 §2.1 requires client authentication and a check that
// the token belongs to the requesting client; neither is enforced here.
func Revoke(x *vigo.X, req *RevokeRequest) error {
	if req.Token == "" {
		return nil
	}

	err := model.DB.Model(&model.OAuthToken{}).
		Where("(access_token = ? OR refresh_token = ?) AND revoked = ?", req.Token, req.Token, false).
		Updates(map[string]interface{}{
			"revoked":    true,
			"revoked_at": time.Now(),
		}).Error
	if err != nil {
		return vigo.ErrInternalServer.WithError(err)
	}
	return nil
}
// IntrospectRequest holds the form parameters accepted by Introspect
// (POST /oauth/introspect).
type IntrospectRequest struct {
	// Token is the token to inspect; Introspect looks it up as an access token.
	Token string `json:"token" src:"form" desc:"要内省的令牌"`
	// TokenTypeHint is accepted but not read by Introspect.
	TokenTypeHint string `json:"token_type_hint" src:"form" desc:"令牌类型提示"`
}
// IntrospectResponse is the introspection result (RFC 7662 member names).
// For inactive tokens only Active=false is set. Iat, Aud, Iss and Jti are
// defined but not populated by Introspect in this file.
type IntrospectResponse struct {
	// Active reports whether the token is known, unrevoked and unexpired.
	Active bool `json:"active"`
	// Scope is the token's space-separated scope.
	Scope string `json:"scope,omitempty"`
	// ClientID is the client the token was issued to.
	ClientID string `json:"client_id,omitempty"`
	// Username is the owning user's username, when the user record exists.
	Username string `json:"username,omitempty"`
	// TokenType is the stored token type (e.g. "Bearer").
	TokenType string `json:"token_type,omitempty"`
	// Exp is the expiry as a Unix timestamp (seconds).
	Exp int64 `json:"exp,omitempty"`
	// Iat is the issue time as a Unix timestamp (seconds).
	Iat int64 `json:"iat,omitempty"`
	// Sub is the subject: the owning user's ID.
	Sub string `json:"sub,omitempty"`
	// Aud is the intended audience.
	Aud string `json:"aud,omitempty"`
	// Iss is the issuer.
	Iss string `json:"iss,omitempty"`
	// Jti is a unique token identifier.
	Jti string `json:"jti,omitempty"`
}
// Introspect implements token introspection (RFC 7662).
// POST /oauth/introspect
//
// Only access tokens are looked up. Empty, unknown, revoked and expired tokens
// (and lookup failures) all produce {"active": false} with a nil error. For an
// active token the response carries its scope, client, token type, expiry and
// subject (user ID), plus the username when the user record is found.
//
// NOTE(review): RFC 7662 §2.1 requires the endpoint to authenticate its
// caller; this handler does not — confirm a middleware protects the route.
func Introspect(x *vigo.X, req *IntrospectRequest) (*IntrospectResponse, error) {
	if req.Token == "" {
		return &IntrospectResponse{Active: false}, nil
	}

	var token model.OAuthToken
	if err := model.DB.Where("access_token = ? AND revoked = ?", req.Token, false).First(&token).Error; err != nil {
		return &IntrospectResponse{Active: false}, nil
	}
	if time.Now().After(token.ExpiresAt) {
		return &IntrospectResponse{Active: false}, nil
	}

	username := ""
	// Tokens from the client-credentials grant have no user. Skip the lookup
	// for them instead of querying for id = "" (which can only fail).
	if token.UserID != "" {
		var user model.User
		if err := model.DB.First(&user, "id = ?", token.UserID).Error; err == nil {
			username = user.Username
		}
	}

	return &IntrospectResponse{
		Active:    true,
		Scope:     token.Scope,
		ClientID:  token.ClientID,
		Username:  username,
		TokenType: token.TokenType,
		Exp:       token.ExpiresAt.Unix(),
		Sub:       token.UserID,
	}, nil
}
// --- helper functions ---

// generateAuthorizationCode creates and persists a one-time authorization
// code for client/user/org that is valid for 10 minutes. The request's scope,
// state, PKCE challenge and method, and redirect URI are stored with it so
// the token endpoint can check them later. It returns the code, or the
// database error if the record could not be created.
func generateAuthorizationCode(client *model.OAuthClient, userID, orgID string, req *AuthorizeRequest) (string, error) {
	const codeLifetime = 10 * time.Minute

	code := generateRandomToken(32)
	record := model.OAuthAuthorization{
		Code:                code,
		ClientID:            client.ClientID,
		UserID:              userID,
		OrgID:               orgID,
		Scope:               req.Scope,
		State:               req.State,
		RedirectURI:         req.RedirectURI,
		CodeChallenge:       req.CodeChallenge,
		CodeChallengeMethod: req.CodeChallengeMethod,
		ExpiresAt:           time.Now().Add(codeLifetime),
	}
	if err := model.DB.Create(&record).Error; err != nil {
		return "", err
	}
	return code, nil
}
// generateTokenPair issues a new Bearer access/refresh token pair for the
// given client, user and org, and persists it as one OAuthToken row. It
// returns the pair as a TokenResponse, or the database error from Create.
//
// accessExpiry is the access-token lifetime in seconds. It sets the row's
// ExpiresAt and is echoed as ExpiresIn.
// NOTE(review): refreshExpiry is accepted but not persisted anywhere. The row
// has a single ExpiresAt, so the refresh token's deadline has to be derived
// by callers. Consider a dedicated refresh-expiry column.
func generateTokenPair(client *model.OAuthClient, userID, orgID, scope string, accessExpiry, refreshExpiry int) (*TokenResponse, error) {
	resp := &TokenResponse{
		AccessToken:  generateRandomToken(32),
		RefreshToken: generateRandomToken(32),
		TokenType:    "Bearer",
		ExpiresIn:    accessExpiry,
		Scope:        scope,
	}

	row := &model.OAuthToken{
		UserID:       userID,
		ClientID:     client.ClientID,
		OrgID:        orgID,
		AccessToken:  resp.AccessToken,
		RefreshToken: resp.RefreshToken,
		TokenType:    resp.TokenType,
		Scope:        scope,
		ExpiresAt:    time.Now().Add(time.Duration(accessExpiry) * time.Second),
	}
	if err := model.DB.Create(row).Error; err != nil {
		return nil, err
	}
	return resp, nil
}
// generateRandomToken returns length bytes from crypto/rand, hex-encoded, so
// the result is 2*length lowercase hex characters.
//
// A failure of the system CSPRNG is treated as fatal. Ignoring the error
// would silently hand out all-zero, predictable tokens. The signature has no
// error return, and Go 1.24+ crypto/rand already crashes in that situation.
func generateRandomToken(length int) string {
	b := make([]byte, length)
	if _, err := rand.Read(b); err != nil {
		panic("oauth: crypto/rand failed: " + err.Error())
	}
	return hex.EncodeToString(b)
}
// authenticateClient authenticates the caller of the token endpoint.
//
// HTTP Basic credentials take precedence. When present, they overwrite
// req.ClientID and req.ClientSecret. Both values must be non-empty. The client
// is loaded by client_id among active clients (status = 1), and the secret is
// then compared in constant time.
//
// The secret is deliberately not placed in the SQL WHERE clause. That would
// expose it to SQL logging, and the database comparison is not constant-time.
// It returns vigo.ErrNotAuthorized for missing or wrong credentials, or the
// lookup error when no active client matches.
// NOTE(review): RFC 6749 §2.3.1 form-url-encodes Basic credentials; they are
// used verbatim here.
func authenticateClient(x *vigo.X, req *TokenRequest) (*model.OAuthClient, error) {
	if clientID, clientSecret, ok := x.Request.BasicAuth(); ok {
		req.ClientID = clientID
		req.ClientSecret = clientSecret
	}
	if req.ClientID == "" || req.ClientSecret == "" {
		return nil, vigo.ErrNotAuthorized
	}

	var client model.OAuthClient
	if err := model.DB.Where("client_id = ? AND status = ?", req.ClientID, 1).First(&client).Error; err != nil {
		return nil, err
	}
	if subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(req.ClientSecret)) != 1 {
		return nil, vigo.ErrNotAuthorized
	}
	return &client, nil
}
func verifyPKCE(verifier, challenge, method string) bool {
switch method {
case "S256":
hash := sha256.Sum256([]byte(verifier))
encoded := base64.RawURLEncoding.EncodeToString(hash[:])
return encoded == challenge
case "plain":
return verifier == challenge
default:
return false
}
}
// parseScopes splits a space-separated scope string into individual scope
// tokens. Runs of whitespace and leading or trailing whitespace never produce
// empty tokens. An empty or blank input yields an empty, non-nil slice.
func parseScopes(scope string) []string {
	fields := strings.Fields(scope)
	if len(fields) == 0 {
		return []string{}
	}
	return fields
}
// contains reports whether item appears in arr, using exact string equality.
// A nil or empty arr never contains anything.
func contains(arr []string, item string) bool {
	for i := range arr {
		if arr[i] != item {
			continue
		}
		return true
	}
	return false
}
// oauthError reports an OAuth error.
//
// With an empty or unparsable redirectURI, it returns vigo.ErrArgInvalid with
// the message "<errorCode>" or "<errorCode>: <errorDescription>". Callers
// must pass "" whenever the redirect URI has not been verified, so that no
// redirect happens. Otherwise it writes a 302 to redirectURI with error, and
// the optional error_description and state, added to the query string. It
// returns nil because the response has already been written.
// NOTE(review): the implicit flow expects errors in the fragment, not the
// query; this helper always uses the query.
func oauthError(x *vigo.X, redirectURI, errorCode, errorDescription, state string) error {
	message := errorCode
	if errorDescription != "" {
		message += ": " + errorDescription
	}

	if redirectURI == "" {
		return vigo.ErrArgInvalid.WithString(message)
	}
	u, err := url.Parse(redirectURI)
	if err != nil {
		return vigo.ErrArgInvalid.WithString(message)
	}

	q := u.Query()
	q.Set("error", errorCode)
	if errorDescription != "" {
		q.Set("error_description", errorDescription)
	}
	if state != "" {
		q.Set("state", state)
	}
	u.RawQuery = q.Encode()
	x.ResponseWriter().Header().Set("Location", u.String())
	x.ResponseWriter().WriteHeader(http.StatusFound)
	return nil
}