// Package oauth implements the OAuth 2.0 HTTP handlers of this service:
// the authorization endpoint, the token endpoint, token revocation and
// token introspection.
package oauth

import (
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"encoding/hex"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/veypi/vbase/internal/model"
	"github.com/veypi/vigo"
)

// AuthorizeRequest carries the query parameters of an authorization request.
type AuthorizeRequest struct {
	ResponseType        string `json:"response_type" src:"query" desc:"响应类型: code/token"`
	ClientID            string `json:"client_id" src:"query" desc:"客户端ID"`
	RedirectURI         string `json:"redirect_uri" src:"query" desc:"回调地址"`
	Scope               string `json:"scope" src:"query" desc:"请求的权限范围"`
	State               string `json:"state" src:"query" desc:"状态值(防CSRF)"`
	CodeChallenge       string `json:"code_challenge" src:"query" desc:"PKCE挑战码"`
	CodeChallengeMethod string `json:"code_challenge_method" src:"query" desc:"PKCE方法: S256/plain"`
}

// AuthorizeResponse describes the values appended to the redirect URI after a
// successful authorization-code request.
type AuthorizeResponse struct {
	Code  string `json:"code"`
	State string `json:"state"`
}

// Authorize is the authorization endpoint (GET /oauth/authorize).
//
// It validates the client, the redirect URI, the response type, the PKCE
// method and the requested scopes. An unauthenticated user is redirected to
// /login; otherwise the handler issues an authorization code (query string)
// or, for the implicit flow, an access token (URL fragment) and redirects the
// user agent back to the client.
//
// Errors detected before the client and redirect URI have been verified are
// returned directly instead of being redirected: redirecting to an unverified
// redirect_uri would turn this endpoint into an open redirector
// (RFC 6749 section 4.1.2.1). Later errors are reported on the redirect URI.
func Authorize(x *vigo.X, req *AuthorizeRequest) error {
	if req.ClientID == "" {
		return oauthError(x, "", "invalid_request", "client_id is required", "")
	}

	// Look up the active client.
	var client model.OAuthClient
	if err := model.DB.Where("client_id = ? AND status = ?", req.ClientID, 1).First(&client).Error; err != nil {
		return oauthError(x, "", "invalid_client", "client not found", "")
	}

	// Validate redirect_uri, falling back to the first registered one.
	redirectURI, problem := resolveRedirectURI(client.RedirectURIs, req.RedirectURI)
	if problem != "" {
		return oauthError(x, "", "invalid_request", problem, "")
	}
	redirectURL, err := url.Parse(redirectURI)
	if err != nil {
		return oauthError(x, "", "invalid_request", "redirect_uri is malformed", "")
	}
	req.RedirectURI = redirectURI

	// From here on the redirect URI is trusted, so errors go back to the client.
	if req.ResponseType == "" {
		return oauthError(x, req.RedirectURI, "invalid_request", "response_type is required", req.State)
	}
	if !delimitedListContains(client.ResponseTypes, req.ResponseType) {
		return oauthError(x, req.RedirectURI, "unsupported_response_type", "", req.State)
	}

	// Validate the PKCE method up front so a bad value fails here rather than
	// producing a code that can never be exchanged.
	if req.CodeChallenge != "" {
		switch req.CodeChallengeMethod {
		case "":
			// RFC 7636 section 4.3: the method defaults to "plain".
			req.CodeChallengeMethod = "plain"
		case "plain", "S256":
		default:
			return oauthError(x, req.RedirectURI, "invalid_request", "unsupported code_challenge_method", req.State)
		}
	}

	// Every requested scope must be allowed for this client.
	allowedScopes := parseScopes(client.AllowedScopes)
	for _, scope := range parseScopes(req.Scope) {
		if !contains(allowedScopes, scope) {
			return oauthError(x, req.RedirectURI, "invalid_scope", "scope not allowed: "+scope, req.State)
		}
	}

	// Resolve the current user; unauthenticated users are sent to the login
	// page with the original request URL as the return target.
	userID, _ := x.Get("current_user").(string)
	if userID == "" {
		sendRedirect(x, "/login?redirect="+url.QueryEscape(x.Request.URL.String()))
		return nil
	}

	// Organization comes from the request header, else from the client.
	// NOTE(review): X-Org-ID is caller-supplied and is not checked against the
	// user's memberships here — confirm that is enforced elsewhere.
	orgID := x.Request.Header.Get("X-Org-ID")
	if orgID == "" && client.OrgID != "" {
		orgID = client.OrgID
	}

	switch req.ResponseType {
	case model.ResponseTypeCode:
		code, err := generateAuthorizationCode(&client, userID, orgID, req)
		if err != nil {
			return oauthError(x, req.RedirectURI, "server_error", "", req.State)
		}
		q := redirectURL.Query()
		q.Set("code", code)
		if req.State != "" {
			q.Set("state", req.State)
		}
		redirectURL.RawQuery = q.Encode()
		sendRedirect(x, redirectURL.String())
		return nil
	case model.ResponseTypeToken:
		// Implicit flow: the token is returned directly in the fragment. It is
		// less secure than the code flow and should only be used when needed.
		return handleImplicitGrant(x, &client, userID, orgID, req)
	default:
		return oauthError(x, req.RedirectURI, "unsupported_response_type", "", req.State)
	}
}

// TokenRequest carries the form parameters of a token request.
type TokenRequest struct {
	GrantType    string `json:"grant_type" src:"form" desc:"授权类型"`
	Code         string `json:"code" src:"form" desc:"授权码"`
	RedirectURI  string `json:"redirect_uri" src:"form" desc:"回调地址"`
	ClientID     string `json:"client_id" src:"form" desc:"客户端ID"`
	ClientSecret string `json:"client_secret" src:"form" desc:"客户端密钥"`
	RefreshToken string `json:"refresh_token" src:"form" desc:"刷新令牌"`
	Scope        string `json:"scope" src:"form" desc:"权限范围"`
	CodeVerifier string `json:"code_verifier" src:"form" desc:"PKCE验证器"`
	Username     string `json:"username" src:"form" desc:"用户名(密码模式)"`
	Password     string `json:"password" src:"form" desc:"密码(密码模式)"`
}

// TokenResponse is the JSON body returned by the token endpoint.
type TokenResponse struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int    `json:"expires_in"`
	RefreshToken string `json:"refresh_token,omitempty"`
	Scope        string `json:"scope,omitempty"`
	IDToken      string `json:"id_token,omitempty"` // OIDC
}

// Token is the token endpoint (POST /oauth/token).
//
// It authenticates the client, checks that the requested grant type is
// enabled for that client and dispatches to the matching grant handler.
func Token(x *vigo.X, req *TokenRequest) (*TokenResponse, error) {
	client, err := authenticateClient(x, req)
	if err != nil {
		return nil, vigo.ErrNotAuthorized.WithString("invalid_client")
	}

	// Exact item match: a substring test would accept "" and partial names.
	if !delimitedListContains(client.GrantTypes, req.GrantType) {
		return nil, vigo.ErrArgInvalid.WithString("unsupported_grant_type")
	}

	switch req.GrantType {
	case model.GrantTypeAuthorizationCode:
		return handleAuthorizationCodeGrant(x, client, req)
	case model.GrantTypeRefreshToken:
		return handleRefreshTokenGrant(x, client, req)
	case model.GrantTypeClientCredentials:
		return handleClientCredentialsGrant(x, client, req)
	case model.GrantTypePassword:
		return handlePasswordGrant(x, client, req)
	default:
		return nil, vigo.ErrArgInvalid.WithString("unsupported_grant_type")
	}
}

// handleAuthorizationCodeGrant exchanges an authorization code for a token
// pair. The code must belong to the client, be unused and unexpired, match
// the redirect URI when one is supplied, and pass PKCE verification when a
// challenge was stored with it.
func handleAuthorizationCodeGrant(x *vigo.X, client *model.OAuthClient, req *TokenRequest) (*TokenResponse, error) {
	if req.Code == "" {
		return nil, vigo.ErrArgInvalid.WithString("code is required")
	}

	var auth model.OAuthAuthorization
	if err := model.DB.Where("code = ? AND client_id = ? AND used = ?", req.Code, client.ClientID, false).First(&auth).Error; err != nil {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant")
	}

	if time.Now().After(auth.ExpiresAt) {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant: code expired")
	}

	// NOTE(review): the check is skipped when the client omits redirect_uri;
	// RFC 6749 section 4.1.3 requires it whenever it was part of the
	// authorization request. Not enforced here because the stored value may be
	// the server-side default rather than one the client sent.
	if req.RedirectURI != "" && req.RedirectURI != auth.RedirectURI {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant: redirect_uri mismatch")
	}

	if auth.CodeChallenge != "" {
		if req.CodeVerifier == "" {
			return nil, vigo.ErrArgInvalid.WithString("invalid_grant: code_verifier required")
		}
		if !verifyPKCE(req.CodeVerifier, auth.CodeChallenge, auth.CodeChallengeMethod) {
			return nil, vigo.ErrArgInvalid.WithString("invalid_grant: code_verifier mismatch")
		}
	}

	// Consume the code with a conditional update so that two concurrent
	// exchanges of the same code cannot both succeed.
	res := model.DB.Model(&auth).Where("code = ? AND used = ?", auth.Code, false).Updates(map[string]interface{}{
		"used":    true,
		"used_at": time.Now(),
	})
	if res.Error != nil {
		return nil, vigo.ErrInternalServer.WithError(res.Error)
	}
	if res.RowsAffected == 0 {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant")
	}

	tokenResp, err := generateTokenPair(client, auth.UserID, auth.OrgID, auth.Scope, client.TokenExpiry, client.RefreshExpiry)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	return tokenResp, nil
}

// handleRefreshTokenGrant rotates a refresh token: the presented token is
// revoked and a new token pair with the same user, organization and scope is
// issued.
func handleRefreshTokenGrant(x *vigo.X, client *model.OAuthClient, req *TokenRequest) (*TokenResponse, error) {
	if req.RefreshToken == "" {
		return nil, vigo.ErrArgInvalid.WithString("refresh_token is required")
	}

	var token model.OAuthToken
	if err := model.DB.Where("refresh_token = ? AND client_id = ? AND revoked = ?", req.RefreshToken, client.ClientID, false).First(&token).Error; err != nil {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant")
	}

	// NOTE(review): ExpiresAt is written by generateTokenPair from the
	// access-token lifetime, so a refresh token stops working as soon as its
	// access token expires. Confirm whether the model has a separate
	// refresh-expiry column that should be checked here instead.
	if time.Now().After(token.ExpiresAt) {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant: token expired")
	}

	// Revoke the old token with a conditional update so that a refresh token
	// cannot be redeemed twice by concurrent requests.
	res := model.DB.Model(&token).Where("refresh_token = ? AND revoked = ?", token.RefreshToken, false).Updates(map[string]interface{}{
		"revoked":    true,
		"revoked_at": time.Now(),
	})
	if res.Error != nil {
		return nil, vigo.ErrInternalServer.WithError(res.Error)
	}
	if res.RowsAffected == 0 {
		return nil, vigo.ErrArgInvalid.WithString("invalid_grant")
	}

	tokenResp, err := generateTokenPair(client, token.UserID, token.OrgID, token.Scope, client.TokenExpiry, client.RefreshExpiry)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	return tokenResp, nil
}

// handleClientCredentialsGrant issues an access token without a user context
// (service-to-service calls). Only "service" and "service:*" scopes are kept
// from the request; when none remain the scope defaults to "service". No
// refresh token is issued.
func handleClientCredentialsGrant(x *vigo.X, client *model.OAuthClient, req *TokenRequest) (*TokenResponse, error) {
	var granted []string
	for _, scope := range parseScopes(req.Scope) {
		if scope == "service" || strings.HasPrefix(scope, "service:") {
			granted = append(granted, scope)
		}
	}
	if len(granted) == 0 {
		granted = []string{"service"}
	}
	scopeStr := strings.Join(granted, " ")

	accessToken := generateRandomToken(32)
	token := &model.OAuthToken{
		UserID:      "", // no user in the client-credentials flow
		ClientID:    client.ClientID,
		OrgID:       client.OrgID,
		AccessToken: accessToken,
		TokenType:   "Bearer",
		Scope:       scopeStr,
		ExpiresAt:   time.Now().Add(time.Duration(client.TokenExpiry) * time.Second),
	}
	if err := model.DB.Create(token).Error; err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	return &TokenResponse{
		AccessToken: accessToken,
		TokenType:   "Bearer",
		ExpiresIn:   client.TokenExpiry,
		Scope:       scopeStr,
	}, nil
}

// handlePasswordGrant handles the resource-owner password grant. The grant is
// disabled: the handler always returns an error.
func handlePasswordGrant(x *vigo.X, client *model.OAuthClient, req *TokenRequest) (*TokenResponse, error) {
	return nil, vigo.ErrArgInvalid.WithString("unsupported_grant_type: password grant is disabled")
}

// handleImplicitGrant issues an access token for the implicit flow and
// redirects to req.RedirectURI with the token parameters in the URL fragment.
func handleImplicitGrant(x *vigo.X, client *model.OAuthClient, userID, orgID string, req *AuthorizeRequest) error {
	// Parse first so that no token is persisted if the redirect cannot be built.
	redirectURL, err := url.Parse(req.RedirectURI)
	if err != nil {
		return oauthError(x, "", "invalid_request", "redirect_uri is malformed", "")
	}

	accessToken := generateRandomToken(32)
	token := &model.OAuthToken{
		UserID:      userID,
		ClientID:    client.ClientID,
		OrgID:       orgID,
		AccessToken: accessToken,
		TokenType:   "Bearer",
		Scope:       req.Scope,
		ExpiresAt:   time.Now().Add(time.Duration(client.TokenExpiry) * time.Second),
	}
	if err := model.DB.Create(token).Error; err != nil {
		return oauthError(x, req.RedirectURI, "server_error", "", req.State)
	}

	fragment := url.Values{}
	fragment.Set("access_token", accessToken)
	fragment.Set("token_type", "Bearer")
	// strconv.Itoa yields the decimal lifetime; string(rune(n)) would yield
	// the Unicode character with code point n.
	fragment.Set("expires_in", strconv.Itoa(client.TokenExpiry))
	if req.Scope != "" {
		fragment.Set("scope", req.Scope)
	}
	if req.State != "" {
		fragment.Set("state", req.State)
	}

	redirectURL.Fragment = fragment.Encode()
	sendRedirect(x, redirectURL.String())
	return nil
}

// RevokeRequest carries the form parameters of a revocation request.
type RevokeRequest struct {
	Token         string `json:"token" src:"form" desc:"要撤销的令牌"`
	TokenTypeHint string `json:"token_type_hint" src:"form" desc:"令牌类型提示: access_token/refresh_token"`
}

// Revoke is the token revocation endpoint (POST /oauth/revoke).
//
// The token is looked up first as an access token, then as a refresh token,
// and the matching record is marked revoked. Unknown or empty tokens are not
// an error (RFC 7009: invalid tokens still yield a successful response); a
// failure to persist the revocation is returned as an internal error.
//
// NOTE(review): no client authentication happens in this handler, and
// TokenTypeHint is ignored — confirm callers are restricted elsewhere.
func Revoke(x *vigo.X, req *RevokeRequest) error {
	if req.Token == "" {
		return nil
	}

	for _, cond := range []string{"access_token = ?", "refresh_token = ?"} {
		var token model.OAuthToken
		if err := model.DB.Where(cond, req.Token).First(&token).Error; err != nil {
			continue // not found under this column; try the next one
		}
		err := model.DB.Model(&token).Updates(map[string]interface{}{
			"revoked":    true,
			"revoked_at": time.Now(),
		}).Error
		if err != nil {
			return vigo.ErrInternalServer.WithError(err)
		}
		return nil
	}
	return nil
}

// IntrospectRequest carries the form parameters of an introspection request.
type IntrospectRequest struct {
	Token         string `json:"token" src:"form" desc:"要内省的令牌"`
	TokenTypeHint string `json:"token_type_hint" src:"form" desc:"令牌类型提示"`
}

// IntrospectResponse is the JSON body returned by the introspection endpoint.
type IntrospectResponse struct {
	Active    bool   `json:"active"`
	Scope     string `json:"scope,omitempty"`
	ClientID  string `json:"client_id,omitempty"`
	Username  string `json:"username,omitempty"`
	TokenType string `json:"token_type,omitempty"`
	Exp       int64  `json:"exp,omitempty"`
	Iat       int64  `json:"iat,omitempty"`
	Sub       string `json:"sub,omitempty"`
	Aud       string `json:"aud,omitempty"`
	Iss       string `json:"iss,omitempty"`
	Jti       string `json:"jti,omitempty"`
}

// Introspect is the token introspection endpoint (RFC 7662,
// POST /oauth/introspect).
//
// Only access tokens are looked up. An empty, unknown, revoked or expired
// token yields {"active": false}; otherwise the token's scope, client, type,
// expiry and subject are returned, plus the username when the user exists.
//
// NOTE(review): no caller authentication happens in this handler — confirm it
// is enforced by middleware.
func Introspect(x *vigo.X, req *IntrospectRequest) (*IntrospectResponse, error) {
	if req.Token == "" {
		return &IntrospectResponse{Active: false}, nil
	}

	var token model.OAuthToken
	if err := model.DB.Where("access_token = ? AND revoked = ?", req.Token, false).First(&token).Error; err != nil {
		return &IntrospectResponse{Active: false}, nil
	}

	if time.Now().After(token.ExpiresAt) {
		return &IntrospectResponse{Active: false}, nil
	}

	// Client-credentials tokens have no user, so skip the lookup for them.
	username := ""
	if token.UserID != "" {
		var user model.User
		if err := model.DB.First(&user, "id = ?", token.UserID).Error; err == nil {
			username = user.Username
		}
	}

	return &IntrospectResponse{
		Active:    true,
		Scope:     token.Scope,
		ClientID:  token.ClientID,
		Username:  username,
		TokenType: token.TokenType,
		Exp:       token.ExpiresAt.Unix(),
		Sub:       token.UserID,
	}, nil
}

// Helper functions.

// generateAuthorizationCode creates and stores an authorization code that is
// valid for 10 minutes, recording the request's scope, state, PKCE challenge
// and redirect URI alongside it. It returns the code.
func generateAuthorizationCode(client *model.OAuthClient, userID, orgID string, req *AuthorizeRequest) (string, error) {
	code := generateRandomToken(32)
	auth := &model.OAuthAuthorization{
		UserID:              userID,
		ClientID:            client.ClientID,
		OrgID:               orgID,
		Code:                code,
		Scope:               req.Scope,
		State:               req.State,
		CodeChallenge:       req.CodeChallenge,
		CodeChallengeMethod: req.CodeChallengeMethod,
		RedirectURI:         req.RedirectURI,
		ExpiresAt:           time.Now().Add(10 * time.Minute),
	}
	if err := model.DB.Create(auth).Error; err != nil {
		return "", err
	}
	return code, nil
}

// generateTokenPair creates and stores a new access/refresh token pair and
// returns the corresponding token response. accessExpiry is the access-token
// lifetime in seconds.
//
// NOTE(review): refreshExpiry is accepted but not used; the stored ExpiresAt
// is derived from accessExpiry only.
func generateTokenPair(client *model.OAuthClient, userID, orgID, scope string, accessExpiry, refreshExpiry int) (*TokenResponse, error) {
	accessToken := generateRandomToken(32)
	refreshToken := generateRandomToken(32)

	token := &model.OAuthToken{
		UserID:       userID,
		ClientID:     client.ClientID,
		OrgID:        orgID,
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		TokenType:    "Bearer",
		Scope:        scope,
		ExpiresAt:    time.Now().Add(time.Duration(accessExpiry) * time.Second),
	}
	if err := model.DB.Create(token).Error; err != nil {
		return nil, err
	}

	return &TokenResponse{
		AccessToken:  accessToken,
		TokenType:    "Bearer",
		ExpiresIn:    accessExpiry,
		RefreshToken: refreshToken,
		Scope:        scope,
	}, nil
}

// generateRandomToken returns length cryptographically random bytes encoded
// as a hex string (2*length characters). It panics if the system random
// source fails, because continuing would hand out a predictable credential.
func generateRandomToken(length int) string {
	b := make([]byte, length)
	if _, err := rand.Read(b); err != nil {
		panic("oauth: reading random bytes: " + err.Error())
	}
	return hex.EncodeToString(b)
}

// authenticateClient identifies the calling client. Credentials from HTTP
// Basic auth take precedence and overwrite req.ClientID / req.ClientSecret;
// otherwise the form fields are used. Both values are required and must match
// an active client.
func authenticateClient(x *vigo.X, req *TokenRequest) (*model.OAuthClient, error) {
	if clientID, clientSecret, ok := x.Request.BasicAuth(); ok {
		req.ClientID = clientID
		req.ClientSecret = clientSecret
	}

	if req.ClientID == "" || req.ClientSecret == "" {
		return nil, vigo.ErrNotAuthorized
	}

	var client model.OAuthClient
	if err := model.DB.Where("client_id = ? AND client_secret = ? AND status = ?", req.ClientID, req.ClientSecret, 1).First(&client).Error; err != nil {
		return nil, err
	}
	return &client, nil
}

// verifyPKCE reports whether verifier matches challenge under method.
// "S256" compares the unpadded base64url SHA-256 of the verifier; "plain" —
// and the empty method, which RFC 7636 defines as "plain" — compares the
// verifier itself. Any other method fails. Comparisons are constant-time.
func verifyPKCE(verifier, challenge, method string) bool {
	var computed string
	switch method {
	case "S256":
		sum := sha256.Sum256([]byte(verifier))
		computed = base64.RawURLEncoding.EncodeToString(sum[:])
	case "plain", "":
		computed = verifier
	default:
		return false
	}
	return subtle.ConstantTimeCompare([]byte(computed), []byte(challenge)) == 1
}

// parseScopes splits a whitespace-separated scope string into its scopes.
// Repeated or surrounding whitespace produces no empty entries; an empty
// string yields an empty slice.
func parseScopes(scope string) []string {
	return strings.Fields(scope)
}

// contains reports whether item is an element of arr.
func contains(arr []string, item string) bool {
	for _, a := range arr {
		if a == item {
			return true
		}
	}
	return false
}

// delimitedListContains reports whether item is exactly one of the entries of
// list, where entries are separated by commas and/or whitespace. The empty
// item never matches.
func delimitedListContains(list, item string) bool {
	if item == "" {
		return false
	}
	entries := strings.FieldsFunc(list, func(r rune) bool {
		return r == ',' || r == ' ' || r == '\t' || r == '\n' || r == '\r'
	})
	return contains(entries, item)
}

// resolveRedirectURI selects the redirect URI for an authorization request.
// registered is the client's comma-separated list of allowed URIs. A
// non-empty requested URI must equal one of them exactly; an empty one falls
// back to the first registered URI. On failure uri is empty and problem holds
// the error description.
func resolveRedirectURI(registered, requested string) (uri, problem string) {
	var allowed []string
	for _, u := range strings.Split(registered, ",") {
		if u = strings.TrimSpace(u); u != "" {
			allowed = append(allowed, u)
		}
	}

	if requested != "" {
		if contains(allowed, requested) {
			return requested, ""
		}
		return "", "redirect_uri mismatch"
	}
	if len(allowed) == 0 {
		return "", "redirect_uri required"
	}
	return allowed[0], ""
}

// sendRedirect answers the request with 302 Found and the given Location.
func sendRedirect(x *vigo.X, location string) {
	x.ResponseWriter().Header().Set("Location", location)
	x.ResponseWriter().WriteHeader(http.StatusFound)
}

// oauthError reports an OAuth error. With a usable redirectURI the user agent
// is redirected there with error, error_description and state added to the
// query string, and nil is returned. With an empty or unparsable redirectURI
// the error is returned to the caller instead.
func oauthError(x *vigo.X, redirectURI, errorCode, errorDescription, state string) error {
	if redirectURI == "" {
		return vigo.ErrArgInvalid.WithString(errorCode + ": " + errorDescription)
	}

	u, err := url.Parse(redirectURI)
	if err != nil {
		return vigo.ErrArgInvalid.WithString(errorCode + ": " + errorDescription)
	}

	q := u.Query()
	q.Set("error", errorCode)
	if errorDescription != "" {
		q.Set("error_description", errorDescription)
	}
	if state != "" {
		q.Set("state", state)
	}
	u.RawQuery = q.Encode()

	sendRedirect(x, u.String())
	return nil
}