You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/internal/api/oauth/oauth.go

586 lines
18 KiB
Go

2 weeks ago
package oauth
import (
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"encoding/hex"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/veypi/vbase/internal/model"
	"github.com/veypi/vigo"
)
// AuthorizeRequest holds the parameters of an authorization request
// (GET /oauth/authorize). Every field is tagged `src:"query"`, presumably so
// vigo binds it from the URL query string (confirm against vigo's binder).
//
// Fields:
//   - ResponseType: "code" or "token".
//   - ClientID: the registered client identifier.
//   - RedirectURI: callback URL; must match one registered for the client.
//   - Scope: space-separated list of requested scopes.
//   - State: opaque value echoed back to the client (CSRF protection).
//   - CodeChallenge: PKCE code challenge.
//   - CodeChallengeMethod: PKCE method, "S256" or "plain".
type AuthorizeRequest struct {
	ResponseType        string `json:"response_type" src:"query" desc:"响应类型: code/token"`
	ClientID            string `json:"client_id" src:"query" desc:"客户端ID"`
	RedirectURI         string `json:"redirect_uri" src:"query" desc:"回调地址"`
	Scope               string `json:"scope" src:"query" desc:"请求的权限范围"`
	State               string `json:"state" src:"query" desc:"状态值(防CSRF)"`
	CodeChallenge       string `json:"code_challenge" src:"query" desc:"PKCE挑战码"`
	CodeChallengeMethod string `json:"code_challenge_method" src:"query" desc:"PKCE方法: S256/plain"`
}
// AuthorizeResponse describes the parameters returned to the client on a
// successful authorization-code redirect: the issued code and the echoed state.
// NOTE(review): Authorize in this file builds the redirect query directly and
// does not construct this type; it may exist only for API documentation.
type AuthorizeResponse struct {
	Code  string `json:"code"`
	State string `json:"state"`
}
// Authorize is the OAuth 2.0 authorization endpoint (RFC 6749 §3.1).
//
// GET /oauth/authorize
//
// Validation happens in two phases.
//
// Errors about the client or the redirect URI are returned directly and never
// redirected. Until both are verified, the redirect target is supplied by an
// untrusted caller, and redirecting to it would make this endpoint an open
// redirector (RFC 6749 §4.1.2.1, §10.15).
//
// Once redirect_uri is known to be registered for the client, the remaining
// errors are delivered to it via oauthError. These cover response_type,
// scope, PKCE parameters and server failures.
//
// If no non-empty string is stored under "current_user" on x, the user agent
// is sent to /login with the current URL as the return target. Otherwise:
//   - response_type=code: an authorization code is persisted, and the user
//     agent is redirected to redirect_uri with code (and state, if given).
//   - response_type=token: handled by handleImplicitGrant.
func Authorize(x *vigo.X, req *AuthorizeRequest) error {
	if req.ClientID == "" {
		return oauthError(x, "", "invalid_request", "client_id is required", "")
	}
	var client model.OAuthClient
	if err := model.DB.Where("client_id = ? AND status = ?", req.ClientID, 1).First(&client).Error; err != nil {
		return oauthError(x, "", "invalid_client", "client not found", "")
	}

	// redirect_uri must exactly match one of the client's comma-separated
	// registered URIs. When omitted, the first registered URI is used.
	var registered []string
	for _, uri := range strings.Split(client.RedirectURIs, ",") {
		if uri = strings.TrimSpace(uri); uri != "" {
			registered = append(registered, uri)
		}
	}
	switch {
	case req.RedirectURI != "":
		if !contains(registered, req.RedirectURI) {
			return oauthError(x, "", "invalid_request", "redirect_uri mismatch", "")
		}
	case len(registered) > 0:
		req.RedirectURI = registered[0]
	default:
		return oauthError(x, "", "invalid_request", "redirect_uri required", "")
	}
	redirectURL, err := url.Parse(req.RedirectURI)
	if err != nil {
		return oauthError(x, "", "invalid_request", "invalid redirect_uri", "")
	}

	// From here on, redirect_uri is trusted and errors may be redirected to it.
	if req.ResponseType == "" {
		return oauthError(x, req.RedirectURI, "invalid_request", "response_type is required", req.State)
	}
	// Require exact list membership. A substring test would also accept
	// fragments such as "cod".
	if !contains(strings.Fields(strings.ReplaceAll(client.ResponseTypes, ",", " ")), req.ResponseType) {
		return oauthError(x, req.RedirectURI, "unsupported_response_type", "", req.State)
	}

	// Every requested scope must be allowed for this client. strings.Fields
	// tolerates repeated spaces without producing empty scope names.
	allowedScopes := strings.Fields(client.AllowedScopes)
	for _, scope := range strings.Fields(req.Scope) {
		if !contains(allowedScopes, scope) {
			return oauthError(x, req.RedirectURI, "invalid_scope", "scope not allowed: "+scope, req.State)
		}
	}

	// RFC 7636 §4.3: code_challenge_method defaults to "plain" when a
	// challenge is sent without one. Unknown methods are rejected up front
	// rather than producing a code that can never be redeemed.
	if req.CodeChallenge != "" {
		switch req.CodeChallengeMethod {
		case "":
			req.CodeChallengeMethod = "plain"
		case "plain", "S256":
		default:
			return oauthError(x, req.RedirectURI, "invalid_request", "unsupported code_challenge_method", req.State)
		}
	}

	userID, _ := x.Get("current_user").(string)
	if userID == "" {
		// Not logged in: send the user agent to the login page and come back here afterwards.
		loginURL := "/login?redirect=" + url.QueryEscape(x.Request.URL.String())
		x.ResponseWriter().Header().Set("Location", loginURL)
		x.ResponseWriter().WriteHeader(http.StatusFound)
		return nil
	}

	// The org comes from the X-Org-ID header, falling back to the client's org.
	// NOTE(review): the header value is not checked against the user's
	// memberships here — confirm that is enforced elsewhere.
	orgID := x.Request.Header.Get("X-Org-ID")
	if orgID == "" {
		orgID = client.OrgID
	}

	switch req.ResponseType {
	case model.ResponseTypeCode:
		code, err := generateAuthorizationCode(&client, userID, orgID, req)
		if err != nil {
			return oauthError(x, req.RedirectURI, "server_error", "", req.State)
		}
		q := redirectURL.Query()
		q.Set("code", code)
		if req.State != "" {
			q.Set("state", req.State)
		}
		redirectURL.RawQuery = q.Encode()
		x.ResponseWriter().Header().Set("Location", redirectURL.String())
		x.ResponseWriter().WriteHeader(http.StatusFound)
		return nil
	case model.ResponseTypeToken:
		// Implicit flow: the token is returned directly in the URL fragment.
		// It is considerably less secure than the code flow and should only
		// be used where unavoidable.
		return handleImplicitGrant(x, &client, userID, orgID, req)
	default:
		return oauthError(x, req.RedirectURI, "unsupported_response_type", "", req.State)
	}
}
// TokenRequest holds the parameters of a token request (POST /oauth/token).
// Every field is tagged `src:"form"`, presumably so vigo binds it from the
// form body (confirm against vigo's binder).
//
// Which fields matter depends on GrantType:
//   - authorization_code: Code, RedirectURI, CodeVerifier (PKCE).
//   - refresh_token: RefreshToken, optional Scope.
//   - client_credentials: Scope.
//   - password: Username/Password (the grant is disabled in this file).
//
// ClientID/ClientSecret are overwritten by authenticateClient when the
// request carries HTTP Basic credentials.
type TokenRequest struct {
	GrantType    string `json:"grant_type" src:"form" desc:"授权类型"`
	Code         string `json:"code" src:"form" desc:"授权码"`
	RedirectURI  string `json:"redirect_uri" src:"form" desc:"回调地址"`
	ClientID     string `json:"client_id" src:"form" desc:"客户端ID"`
	ClientSecret string `json:"client_secret" src:"form" desc:"客户端密钥"`
	RefreshToken string `json:"refresh_token" src:"form" desc:"刷新令牌"`
	Scope        string `json:"scope" src:"form" desc:"权限范围"`
	CodeVerifier string `json:"code_verifier" src:"form" desc:"PKCE验证器"`
	Username     string `json:"username" src:"form" desc:"用户名(密码模式)"`
	Password     string `json:"password" src:"form" desc:"密码(密码模式)"`
}
// TokenResponse is the JSON body returned by the token endpoint.
// ExpiresIn is the access-token lifetime in seconds. RefreshToken, Scope and
// IDToken are omitted from the JSON when empty.
type TokenResponse struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int    `json:"expires_in"`
	RefreshToken string `json:"refresh_token,omitempty"`
	Scope        string `json:"scope,omitempty"`
	IDToken      string `json:"id_token,omitempty"` // OpenID Connect ID token; not populated anywhere in this file
}
// Token is the OAuth 2.0 token endpoint (RFC 6749 §3.2).
//
// POST /oauth/token
//
// The client is authenticated first, via HTTP Basic or form credentials (see
// authenticateClient). grant_type must then be one of the grant types
// registered for the client. The request is dispatched to the matching
// handler.
//
// Failures are reported as:
//   - vigo.ErrNotAuthorized ("invalid_client") when authentication fails.
//   - vigo.ErrArgInvalid ("unsupported_grant_type") for a grant type that is
//     unregistered or unknown.
func Token(x *vigo.X, req *TokenRequest) (*TokenResponse, error) {
	client, err := authenticateClient(x, req)
	if err != nil {
		return nil, vigo.ErrNotAuthorized.WithString("invalid_client")
	}

	// Registered grant types form a comma/space-separated list. Require an
	// exact entry; a substring test would also accept fragments such as
	// "client" or the empty string.
	registered := strings.Fields(strings.ReplaceAll(client.GrantTypes, ",", " "))
	if !contains(registered, req.GrantType) {
		return nil, vigo.ErrArgInvalid.WithString("unsupported_grant_type")
	}

	switch req.GrantType {
	case model.GrantTypeAuthorizationCode:
		return handleAuthorizationCodeGrant(x, client, req)
	case model.GrantTypeRefreshToken:
		return handleRefreshTokenGrant(x, client, req)
	case model.GrantTypeClientCredentials:
		return handleClientCredentialsGrant(x, client, req)
	case model.GrantTypePassword:
		return handlePasswordGrant(x, client, req)
	default:
		return nil, vigo.ErrArgInvalid.WithString("unsupported_grant_type")
	}
}
// handleAuthorizationCodeGrant implements the authorization_code grant
// (RFC 6749 §4.1.3).
//
// The code must exist, belong to client, and be unused and unexpired. If
// redirect_uri is supplied, it must equal the one stored with the code. If the
// code was issued with a PKCE challenge, code_verifier must satisfy it (see
// verifyPKCE). The code is then consumed exactly once, and an access/refresh
// token pair is issued for the code's user, org and scope.
//
// Client-side failures return vigo.ErrArgInvalid ("invalid_grant..."). Storage
// failures return vigo.ErrInternalServer.
func handleAuthorizationCodeGrant(x *vigo.X, client *model.OAuthClient, req *TokenRequest) (*TokenResponse, error) {
	if req.Code == "" {
		return nil, vigo.ErrArgInvalid.WithString("code is required")
	}

	var auth model.OAuthAuthorization
	if err := model.DB.Where("code = ? AND client_id = ? AND used = ?", req.Code, client.ClientID, false).First(&auth).Error; err != nil {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant")
	}
	if time.Now().After(auth.ExpiresAt) {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant: code expired")
	}
	if req.RedirectURI != "" && req.RedirectURI != auth.RedirectURI {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant: redirect_uri mismatch")
	}
	if auth.CodeChallenge != "" {
		if req.CodeVerifier == "" {
			return nil, vigo.ErrArgInvalid.WithString("invalid_grant: code_verifier required")
		}
		if !verifyPKCE(req.CodeVerifier, auth.CodeChallenge, auth.CodeChallengeMethod) {
			return nil, vigo.ErrArgInvalid.WithString("invalid_grant: code_verifier mismatch")
		}
	}

	// Consume the code with a conditional UPDATE ... WHERE used = false.
	// Two concurrent exchanges of the same code both pass the lookup above,
	// but only the one whose update actually flips the flag may proceed
	// (RFC 6749 §4.1.2: a code MUST NOT be used more than once).
	res := model.DB.Model(&auth).Where("used = ?", false).Updates(map[string]interface{}{
		"used":    true,
		"used_at": time.Now(),
	})
	if res.Error != nil {
		return nil, vigo.ErrInternalServer.WithError(res.Error)
	}
	if res.RowsAffected == 0 {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant")
	}

	tokenResp, err := generateTokenPair(client, auth.UserID, auth.OrgID, auth.Scope, client.TokenExpiry, client.RefreshExpiry)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	return tokenResp, nil
}
// handleRefreshTokenGrant implements the refresh_token grant (RFC 6749 §6).
//
// Steps:
//  1. Look up a non-revoked token row issued to client whose refresh_token
//     matches req.RefreshToken, and reject it if expired.
//  2. Check the optional narrowed scope.
//  3. Revoke the old row atomically (refresh-token rotation).
//  4. Issue a fresh access/refresh pair for the same user and org.
//
// Client-side failures return vigo.ErrArgInvalid. Storage failures return
// vigo.ErrInternalServer.
//
// NOTE(review): the expiry check reads token.ExpiresAt. generateTokenPair
// fills that field from the access-token lifetime only, so refresh tokens
// appear to expire together with their access token. Confirm whether
// model.OAuthToken has a dedicated refresh-expiry column.
func handleRefreshTokenGrant(x *vigo.X, client *model.OAuthClient, req *TokenRequest) (*TokenResponse, error) {
	if req.RefreshToken == "" {
		return nil, vigo.ErrArgInvalid.WithString("refresh_token is required")
	}

	var token model.OAuthToken
	if err := model.DB.Where("refresh_token = ? AND client_id = ? AND revoked = ?", req.RefreshToken, client.ClientID, false).First(&token).Error; err != nil {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant")
	}
	if time.Now().After(token.ExpiresAt) {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant: token expired")
	}

	// RFC 6749 §6: a supplied scope may only narrow the originally granted
	// scope. When omitted, it is treated as equal to the original.
	scope := token.Scope
	if requested := strings.Fields(req.Scope); len(requested) > 0 {
		granted := strings.Fields(token.Scope)
		for _, s := range requested {
			if !contains(granted, s) {
				return nil, vigo.ErrArgInvalid.WithString("invalid_scope: " + s)
			}
		}
		scope = strings.Join(requested, " ")
	}

	// Revoke with a conditional UPDATE ... WHERE revoked = false. If two
	// requests race with the same refresh token, only the one that flips the
	// flag receives new tokens.
	res := model.DB.Model(&token).Where("revoked = ?", false).Updates(map[string]interface{}{
		"revoked":    true,
		"revoked_at": time.Now(),
	})
	if res.Error != nil {
		return nil, vigo.ErrInternalServer.WithError(res.Error)
	}
	if res.RowsAffected == 0 {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant")
	}

	tokenResp, err := generateTokenPair(client, token.UserID, token.OrgID, scope, client.TokenExpiry, client.RefreshExpiry)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	return tokenResp, nil
}
// handleClientCredentialsGrant implements the client_credentials grant
// (RFC 6749 §4.4), which has no end user and is typically used for
// service-to-service calls.
//
// Scope handling:
//   - Only the requested scopes equal to "service" or prefixed with
//     "service:" are kept; all others are silently dropped.
//   - If none remain, the single scope "service" is granted.
//
// The token is stored with an empty UserID and the client's OrgID. Its
// lifetime is client.TokenExpiry seconds, and no refresh token is issued.
// Storage failures return vigo.ErrInternalServer.
func handleClientCredentialsGrant(x *vigo.X, client *model.OAuthClient, req *TokenRequest) (*TokenResponse, error) {
	granted := []string{}
	for _, s := range parseScopes(req.Scope) {
		isService := s == "service" || strings.HasPrefix(s, "service:")
		if !isService {
			continue
		}
		granted = append(granted, s)
	}
	if len(granted) == 0 {
		granted = []string{"service"}
	}
	grantedScope := strings.Join(granted, " ")

	bearer := generateRandomToken(32)
	lifetime := time.Duration(client.TokenExpiry) * time.Second
	record := &model.OAuthToken{
		UserID:      "", // this grant has no resource owner
		ClientID:    client.ClientID,
		OrgID:       client.OrgID,
		AccessToken: bearer,
		TokenType:   "Bearer",
		Scope:       grantedScope,
		ExpiresAt:   time.Now().Add(lifetime),
	}
	if err := model.DB.Create(record).Error; err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	resp := &TokenResponse{
		AccessToken: bearer,
		TokenType:   "Bearer",
		ExpiresIn:   client.TokenExpiry,
		Scope:       grantedScope,
	}
	return resp, nil
}
// handlePasswordGrant is the entry point for the resource-owner password
// grant. The grant is discouraged and kept only for compatibility. It is
// currently disabled and always rejects the request with
// vigo.ErrArgInvalid; no user credentials are checked.
func handlePasswordGrant(_ *vigo.X, _ *model.OAuthClient, _ *TokenRequest) (*TokenResponse, error) {
	const disabled = "unsupported_grant_type: password grant is disabled"
	return nil, vigo.ErrArgInvalid.WithString(disabled)
}
// handleImplicitGrant completes the implicit flow (RFC 6749 §4.2).
//
// It persists a new Bearer access token and redirects (302) to
// req.RedirectURI. The token parameters go in the URL fragment: access_token,
// token_type, expires_in, plus scope and state when present. No refresh token
// is issued.
//
// Callers must have validated req.RedirectURI and req.Scope already; Authorize
// does this.
func handleImplicitGrant(x *vigo.X, client *model.OAuthClient, userID, orgID string, req *AuthorizeRequest) error {
	// Parse first: without a usable redirect target the token has nowhere to go.
	redirectURL, err := url.Parse(req.RedirectURI)
	if err != nil {
		return oauthError(x, "", "invalid_request", "invalid redirect_uri", "")
	}

	accessToken := generateRandomToken(32)
	expiresAt := time.Now().Add(time.Duration(client.TokenExpiry) * time.Second)
	token := &model.OAuthToken{
		UserID:      userID,
		ClientID:    client.ClientID,
		OrgID:       orgID,
		AccessToken: accessToken,
		TokenType:   "Bearer",
		Scope:       req.Scope,
		ExpiresAt:   expiresAt,
	}
	if err := model.DB.Create(token).Error; err != nil {
		return oauthError(x, req.RedirectURI, "server_error", "", req.State)
	}

	fragment := url.Values{}
	fragment.Set("access_token", accessToken)
	fragment.Set("token_type", "Bearer")
	// expires_in is a decimal number of seconds. string(rune(n)) would emit
	// the Unicode character with code point n instead of its digits.
	fragment.Set("expires_in", strconv.Itoa(client.TokenExpiry))
	if req.Scope != "" {
		fragment.Set("scope", req.Scope)
	}
	if req.State != "" {
		fragment.Set("state", req.State)
	}

	// url.URL.Fragment holds the *decoded* fragment, and String() re-escapes
	// it. Assigning the already-encoded form would turn every %XX into %25XX.
	// So the encoded fragment is appended by hand. Any fragment on the
	// redirect URI is dropped; RFC 6749 §3.1.2 forbids one there.
	redirectURL.Fragment = ""
	redirectURL.RawFragment = ""
	location := redirectURL.String() + "#" + fragment.Encode()

	x.ResponseWriter().Header().Set("Location", location)
	x.ResponseWriter().WriteHeader(http.StatusFound)
	return nil
}
// RevokeRequest holds the parameters of a token revocation request
// (POST /oauth/revoke, RFC 7009), tagged `src:"form"` for form-body binding.
// Token is the access or refresh token to revoke. TokenTypeHint
// ("access_token"/"refresh_token") is accepted but not read by Revoke in this file.
type RevokeRequest struct {
	Token         string `json:"token" src:"form" desc:"要撤销的令牌"`
	TokenTypeHint string `json:"token_type_hint" src:"form" desc:"令牌类型提示: access_token/refresh_token"`
}
// Revoke is the token revocation endpoint (RFC 7009).
//
// POST /oauth/revoke
//
// The submitted token is first matched against stored access tokens and then
// against refresh tokens. The first matching row is marked revoked, with
// revoked_at set to now. An empty or unknown token is not an error: RFC 7009
// §2.2 says the endpoint answers success either way, so nil is always
// returned. The result of the revoking UPDATE is not checked.
//
// NOTE(review): RFC 7009 §2.1 expects the caller to authenticate as a client,
// and only that client's tokens to be revocable. This handler does neither
// and ignores token_type_hint. Confirm this is intentional.
func Revoke(x *vigo.X, req *RevokeRequest) error {
	if req.Token == "" {
		return nil
	}
	for _, lookup := range []string{"access_token = ?", "refresh_token = ?"} {
		var token model.OAuthToken
		if model.DB.Where(lookup, req.Token).First(&token).Error != nil {
			continue
		}
		model.DB.Model(&token).Updates(map[string]interface{}{
			"revoked":    true,
			"revoked_at": time.Now(),
		})
		return nil
	}
	return nil
}
// IntrospectRequest holds the parameters of a token introspection request
// (POST /oauth/introspect, RFC 7662), tagged `src:"form"` for form-body binding.
// TokenTypeHint is accepted but not read by Introspect in this file.
type IntrospectRequest struct {
	Token         string `json:"token" src:"form" desc:"要内省的令牌"`
	TokenTypeHint string `json:"token_type_hint" src:"form" desc:"令牌类型提示"`
}
// IntrospectResponse is the JSON body of an introspection response; the field
// names follow RFC 7662 §2.2. Only Active is always emitted; the rest are
// omitted when empty.
//
// Introspect in this file fills Active, Scope, ClientID, Username, TokenType,
// Exp (Unix seconds) and Sub. Iat, Aud, Iss and Jti are never set here.
type IntrospectResponse struct {
	Active    bool   `json:"active"`
	Scope     string `json:"scope,omitempty"`
	ClientID  string `json:"client_id,omitempty"`
	Username  string `json:"username,omitempty"`
	TokenType string `json:"token_type,omitempty"`
	Exp       int64  `json:"exp,omitempty"`
	Iat       int64  `json:"iat,omitempty"`
	Sub       string `json:"sub,omitempty"`
	Aud       string `json:"aud,omitempty"`
	Iss       string `json:"iss,omitempty"`
	Jti       string `json:"jti,omitempty"`
}
// Introspect is the token introspection endpoint (RFC 7662).
//
// POST /oauth/introspect
//
// It reports {active:false} in three cases: an empty token, a token that is
// not a known non-revoked access token, or one past its ExpiresAt. Otherwise
// it reports active=true with the token's scope, client, type, expiry and
// subject (user ID). When a user with that ID exists, the username is
// included as well. The error return is always nil.
//
// NOTE(review): RFC 7662 §2.1 requires the caller to be authorized to use
// this endpoint; no such check is visible here. Confirm it is applied by
// middleware.
func Introspect(x *vigo.X, req *IntrospectRequest) (*IntrospectResponse, error) {
	if req.Token == "" {
		return &IntrospectResponse{Active: false}, nil
	}

	var token model.OAuthToken
	lookupErr := model.DB.Where("access_token = ? AND revoked = ?", req.Token, false).First(&token).Error
	if lookupErr != nil || time.Now().After(token.ExpiresAt) {
		return &IntrospectResponse{Active: false}, nil
	}

	result := &IntrospectResponse{
		Active:    true,
		Scope:     token.Scope,
		ClientID:  token.ClientID,
		TokenType: token.TokenType,
		Exp:       token.ExpiresAt.Unix(),
		Sub:       token.UserID,
	}
	var owner model.User
	if model.DB.First(&owner, "id = ?", token.UserID).Error == nil {
		result.Username = owner.Username
	}
	return result, nil
}
// --- helper functions ---

// generateAuthorizationCode creates and persists a single-use authorization
// code for client and userID. The code is valid for 10 minutes.
//
// The stored record also captures orgID and the request's scope, state, PKCE
// challenge/method and redirect URI, which are checked later at the token
// endpoint. It returns the code (64 hex characters) or the storage error.
func generateAuthorizationCode(client *model.OAuthClient, userID, orgID string, req *AuthorizeRequest) (string, error) {
	const codeTTL = 10 * time.Minute

	newCode := generateRandomToken(32)
	record := model.OAuthAuthorization{
		UserID:              userID,
		ClientID:            client.ClientID,
		OrgID:               orgID,
		Code:                newCode,
		Scope:               req.Scope,
		State:               req.State,
		CodeChallenge:       req.CodeChallenge,
		CodeChallengeMethod: req.CodeChallengeMethod,
		RedirectURI:         req.RedirectURI,
		ExpiresAt:           time.Now().Add(codeTTL),
	}
	if err := model.DB.Create(&record).Error; err != nil {
		return "", err
	}
	return newCode, nil
}
// generateTokenPair creates and persists a new Bearer access/refresh token
// pair for client, bound to userID, orgID and scope. The stored row expires
// accessExpiry seconds from now.
//
// It returns the corresponding TokenResponse, with expires_in = accessExpiry,
// or the storage error.
//
// NOTE(review): refreshExpiry is accepted but never used. The single stored
// ExpiresAt comes from accessExpiry alone, so the refresh token seems to share
// the access token's lifetime. Confirm against model.OAuthToken's schema.
func generateTokenPair(client *model.OAuthClient, userID, orgID, scope string, accessExpiry, refreshExpiry int) (*TokenResponse, error) {
	access, refresh := generateRandomToken(32), generateRandomToken(32)
	accessTTL := time.Duration(accessExpiry) * time.Second

	record := &model.OAuthToken{
		UserID:       userID,
		ClientID:     client.ClientID,
		OrgID:        orgID,
		AccessToken:  access,
		RefreshToken: refresh,
		TokenType:    "Bearer",
		Scope:        scope,
		ExpiresAt:    time.Now().Add(accessTTL),
	}
	if err := model.DB.Create(record).Error; err != nil {
		return nil, err
	}

	resp := &TokenResponse{
		AccessToken:  access,
		TokenType:    "Bearer",
		ExpiresIn:    accessExpiry,
		RefreshToken: refresh,
		Scope:        scope,
	}
	return resp, nil
}
// generateRandomToken returns length bytes from crypto/rand, hex-encoded. The
// result is 2*length lowercase hex characters, suitable for codes and tokens.
//
// It panics if the system CSPRNG fails. Otherwise the zero-filled buffer would
// be returned as a predictable, guessable credential. crypto/rand.Read itself
// crashes on failure since Go 1.24.
func generateRandomToken(length int) string {
	b := make([]byte, length)
	if _, err := rand.Read(b); err != nil {
		panic("oauth: crypto/rand failed: " + err.Error())
	}
	return hex.EncodeToString(b)
}
// authenticateClient identifies and authenticates the calling OAuth client
// for the token endpoint.
//
// Credentials come from the HTTP Basic Authorization header when present,
// which then overwrites req.ClientID/req.ClientSecret. Otherwise the
// client_id/client_secret form fields are used. The client must exist with
// status 1, and the secret must match.
//
// It returns the client on success, and a non-nil error for missing, malformed
// or wrong credentials or an unknown client.
func authenticateClient(x *vigo.X, req *TokenRequest) (*model.OAuthClient, error) {
	if id, secret, ok := x.Request.BasicAuth(); ok {
		// RFC 6749 §2.3.1: both values are form-urlencoded before being put
		// into the Basic header. Common client libraries (e.g. x/oauth2)
		// escape them, so they must be decoded here.
		var err error
		if req.ClientID, err = url.QueryUnescape(id); err != nil {
			return nil, vigo.ErrNotAuthorized
		}
		if req.ClientSecret, err = url.QueryUnescape(secret); err != nil {
			return nil, vigo.ErrNotAuthorized
		}
	}
	if req.ClientID == "" || req.ClientSecret == "" {
		return nil, vigo.ErrNotAuthorized
	}

	var client model.OAuthClient
	if err := model.DB.Where("client_id = ? AND status = ?", req.ClientID, 1).First(&client).Error; err != nil {
		return nil, err
	}
	// Compare the secret in constant time in Go, not via SQL equality, so
	// response timing does not leak how much of the secret matched.
	if subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(req.ClientSecret)) != 1 {
		return nil, vigo.ErrNotAuthorized
	}
	return &client, nil
}
func verifyPKCE(verifier, challenge, method string) bool {
switch method {
case "S256":
hash := sha256.Sum256([]byte(verifier))
encoded := base64.RawURLEncoding.EncodeToString(hash[:])
return encoded == challenge
case "plain":
return verifier == challenge
default:
return false
}
}
// parseScopes splits a space-delimited OAuth scope string (RFC 6749 §3.3)
// into individual scope names.
//
// Runs of whitespace and leading/trailing spaces are ignored, so no empty
// names are produced; strings.Split(scope, " ") would yield "" for "a  b".
// An empty input yields an empty, non-nil slice.
func parseScopes(scope string) []string {
	if scope == "" {
		return []string{}
	}
	return strings.Fields(scope)
}
// contains reports whether item appears in arr, using exact string equality.
// A nil or empty slice contains nothing.
func contains(arr []string, item string) bool {
	for i := range arr {
		if arr[i] != item {
			continue
		}
		return true
	}
	return false
}
// oauthError reports an OAuth error in one of two ways.
//
// If redirectURI is empty or unparseable, it returns
// vigo.ErrArgInvalid carrying "<errorCode>: <errorDescription>".
//
// Otherwise it writes a 302 redirect to redirectURI, with error (and
// error_description/state when non-empty) merged into the existing query
// string, and returns nil. Callers must only pass a redirectURI that has been
// validated for the client.
func oauthError(x *vigo.X, redirectURI, errorCode, errorDescription, state string) error {
	var target *url.URL
	if redirectURI != "" {
		if parsed, err := url.Parse(redirectURI); err == nil {
			target = parsed
		}
	}
	if target == nil {
		return vigo.ErrArgInvalid.WithString(errorCode + ": " + errorDescription)
	}

	params := target.Query()
	params.Set("error", errorCode)
	if errorDescription != "" {
		params.Set("error_description", errorDescription)
	}
	if state != "" {
		params.Set("state", state)
	}
	target.RawQuery = params.Encode()

	x.ResponseWriter().Header().Set("Location", target.String())
	x.ResponseWriter().WriteHeader(http.StatusFound)
	return nil
}