package middleware

import (
	"errors"
	"net/http"
	"strings"

	"github.com/veypi/vbase/internal/cache"
	"github.com/veypi/vbase/internal/pkg/jwt"
	"github.com/veypi/vigo"
)

// Context keys under which the middleware in this file stores per-request
// values via x.Set.
const (
	// ContextKeyUser holds the authenticated user's ID (claims.UserID).
	ContextKeyUser = "current_user"
	// ContextKeyClaims holds the parsed JWT claims.
	ContextKeyClaims = "jwt_claims"
	// ContextKeyOrgID holds the organization ID selected by the client.
	ContextKeyOrgID = "org_id"
	// ContextKeyIsAdmin is not set by any middleware in this file.
	// NOTE(review): confirm where this flag is populated.
	ContextKeyIsAdmin = "is_admin"
)

// AuthRequired returns a middleware that rejects requests lacking a valid,
// unrevoked access token.
//
// skips lists request paths, compared exactly against URL.Path, that bypass
// authentication entirely.
//
// On success the parsed claims are stored under ContextKeyClaims and
// claims.UserID under ContextKeyUser before the next handler runs. Failures
// return vigo.ErrNotAuthorized (missing, expired, invalid, wrong-type or
// revoked token), or vigo.ErrInternalServer when the blacklist lookup fails.
func AuthRequired(skips ...string) func(*vigo.X) (any, error) {
	skipMap := make(map[string]struct{}, len(skips))
	for _, s := range skips {
		skipMap[s] = struct{}{}
	}

	return func(x *vigo.X) (any, error) {
		if _, skip := skipMap[x.Request.URL.Path]; skip {
			x.Next()
			return nil, nil
		}

		tokenString := extractToken(x.Request)
		if tokenString == "" {
			return nil, vigo.ErrNotAuthorized.WithString("missing token")
		}

		claims, err := jwt.ParseToken(tokenString)
		if err != nil {
			// errors.Is also matches when ParseToken wraps the sentinel.
			if errors.Is(err, jwt.ErrExpiredToken) {
				return nil, vigo.ErrNotAuthorized.WithString("token expired")
			}
			return nil, vigo.ErrNotAuthorized.WithString("invalid token")
		}

		// Only access tokens may authenticate requests.
		if !jwt.IsAccessToken(claims) {
			return nil, vigo.ErrNotAuthorized.WithString("invalid token type")
		}

		// The blacklist is only consulted when the cache is enabled.
		if cache.IsEnabled() {
			isRevoked, err := cache.IsTokenBlacklisted(claims.ID)
			if err != nil {
				return nil, vigo.ErrInternalServer.WithError(err)
			}
			if isRevoked {
				return nil, vigo.ErrNotAuthorized.WithString("token revoked")
			}
		}

		x.Set(ContextKeyClaims, claims)
		x.Set(ContextKeyUser, claims.UserID)
		x.Next()
		return nil, nil
	}
}

// OptionalAuth returns a middleware that authenticates the request when a
// usable access token is present, and otherwise lets it through anonymously.
//
// A token populates ContextKeyClaims and ContextKeyUser only if it parses, is
// an access token, and (when the cache is enabled) is not blacklisted. It
// never rejects the request.
func OptionalAuth() func(*vigo.X) (any, error) {
	return func(x *vigo.X) (any, error) {
		tokenString := extractToken(x.Request)
		if tokenString != "" {
			claims, err := jwt.ParseToken(tokenString)
			if err == nil && jwt.IsAccessToken(claims) {
				revoked := false
				if cache.IsEnabled() {
					blacklisted, cerr := cache.IsTokenBlacklisted(claims.ID)
					// A failed lookup is treated like a revoked token.
					// The request continues anonymously rather than trusting a
					// token whose status is unknown, or failing a request that
					// does not require authentication.
					revoked = cerr != nil || blacklisted
				}
				if !revoked {
					x.Set(ContextKeyClaims, claims)
					x.Set(ContextKeyUser, claims.UserID)
				}
			}
		}
		x.Next()
		return nil, nil
	}
}

// OrgContext returns a middleware that records the organization selected by
// the client. The ID is read from the X-Org-ID header, falling back to the
// org_id query parameter, and stored under ContextKeyOrgID when non-empty.
//
// NOTE(review): the value is client-supplied and is not validated here;
// handlers must verify that the current user actually belongs to the org.
func OrgContext() func(*vigo.X) (any, error) {
	return func(x *vigo.X) (any, error) {
		orgID := x.Request.Header.Get("X-Org-ID")
		if orgID == "" {
			orgID = x.Request.URL.Query().Get("org_id")
		}
		if orgID != "" {
			x.Set(ContextKeyOrgID, orgID)
		}
		x.Next()
		return nil, nil
	}
}

// extractToken returns the bearer token carried by r, or "" if there is none.
//
// The Authorization header ("Bearer <token>", scheme matched
// case-insensitively) takes precedence. If the header is absent or uses a
// different scheme, the access_token query parameter is used instead. Query
// tokens can leak into access logs and Referer headers, so the header should
// be preferred where clients can set it.
func extractToken(r *http.Request) string {
	if auth := r.Header.Get("Authorization"); auth != "" {
		scheme, token, ok := strings.Cut(auth, " ")
		if ok && strings.EqualFold(scheme, "Bearer") {
			// Tolerate extra whitespace between the scheme and the token.
			return strings.TrimSpace(token)
		}
	}
	return r.URL.Query().Get("access_token")
}

// CurrentUser returns the user ID stored by AuthRequired/OptionalAuth, or ""
// when the request is unauthenticated.
func CurrentUser(x *vigo.X) string {
	if uid, ok := x.Get(ContextKeyUser).(string); ok {
		return uid
	}
	return ""
}

// CurrentClaims returns the JWT claims stored by AuthRequired/OptionalAuth,
// or nil when the request is unauthenticated.
func CurrentClaims(x *vigo.X) *jwt.Claims {
	if claims, ok := x.Get(ContextKeyClaims).(*jwt.Claims); ok {
		return claims
	}
	return nil
}

// CurrentOrgID returns the organization ID stored by OrgContext, or "" when
// none was supplied.
func CurrentOrgID(x *vigo.X) string {
	if orgID, ok := x.Get(ContextKeyOrgID).(string); ok {
		return orgID
	}
	return ""
}