You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

142 lines
3.1 KiB
Go

package middleware
import (
	"errors"
	"net/http"
	"strings"

	"github.com/veypi/vbase/internal/cache"
	"github.com/veypi/vbase/internal/pkg/jwt"
	"github.com/veypi/vigo"
)
// Context keys used with x.Set / x.Get to carry per-request auth state
// between the middlewares and the handlers behind them.
const (
	// ContextKeyClaims holds the parsed token claims; CurrentClaims reads it
	// back as *jwt.Claims.
	ContextKeyClaims = "jwt_claims"
	// ContextKeyUser holds claims.UserID of the authenticated caller;
	// CurrentUser reads it back as a string.
	ContextKeyUser = "current_user"
	// ContextKeyOrgID holds the organization ID that OrgContext takes from the
	// X-Org-ID header or the org_id query parameter.
	ContextKeyOrgID = "org_id"
	// ContextKeyIsAdmin is an admin-flag key. Nothing in this file sets it;
	// presumably another middleware does — confirm before relying on it.
	ContextKeyIsAdmin = "is_admin"
)
// AuthRequired returns a middleware that demands a valid, non-revoked access
// token on every request whose path is not listed in skips.
//
// skips are compared against the exact request path (x.Request.URL.Path). A
// matching request is passed straight to the next handler without any token
// check.
//
// On success the parsed claims are stored under ContextKeyClaims and the user
// ID under ContextKeyUser, then the next handler runs. Failures are returned as
// vigo.ErrNotAuthorized (missing, expired, invalid, wrong-type or revoked
// token) or vigo.ErrInternalServer (the blacklist lookup itself failed).
func AuthRequired(skips ...string) func(*vigo.X) (any, error) {
	skipSet := make(map[string]struct{}, len(skips))
	for _, s := range skips {
		skipSet[s] = struct{}{}
	}
	return func(x *vigo.X) (any, error) {
		if _, skip := skipSet[x.Request.URL.Path]; skip {
			x.Next()
			return nil, nil
		}

		tokenString := extractToken(x.Request)
		if tokenString == "" {
			return nil, vigo.ErrNotAuthorized.WithString("missing token")
		}

		claims, err := jwt.ParseToken(tokenString)
		if err != nil {
			// errors.Is (rather than ==) still recognises the expiry sentinel
			// if ParseToken wraps it.
			if errors.Is(err, jwt.ErrExpiredToken) {
				return nil, vigo.ErrNotAuthorized.WithString("token expired")
			}
			return nil, vigo.ErrNotAuthorized.WithString("invalid token")
		}

		// Only access tokens are accepted here; any other token type is
		// rejected even if it parsed successfully.
		if !jwt.IsAccessToken(claims) {
			return nil, vigo.ErrNotAuthorized.WithString("invalid token type")
		}

		// The blacklist is only consulted when the cache is enabled. With the
		// cache disabled, revocation is not enforced by this middleware.
		if cache.IsEnabled() {
			revoked, lookupErr := cache.IsTokenBlacklisted(claims.ID)
			if lookupErr != nil {
				return nil, vigo.ErrInternalServer.WithError(lookupErr)
			}
			if revoked {
				return nil, vigo.ErrNotAuthorized.WithString("token revoked")
			}
		}

		x.Set(ContextKeyClaims, claims)
		x.Set(ContextKeyUser, claims.UserID)
		x.Next()
		return nil, nil
	}
}
// OptionalAuth returns a middleware that authenticates the request when a
// usable access token is present, but never rejects the request.
//
// The token is used only if it parses, is an access token and, when the cache
// is enabled, is not blacklisted. In that case its claims and user ID are
// stored under ContextKeyClaims / ContextKeyUser, as AuthRequired does. Any
// other outcome leaves the request anonymous: no token, an invalid or expired
// token, the wrong token type, a revoked token, or a failed blacklist lookup.
// The next handler is always called.
func OptionalAuth() func(*vigo.X) (any, error) {
	return func(x *vigo.X) (any, error) {
		if tokenString := extractToken(x.Request); tokenString != "" {
			if claims, err := jwt.ParseToken(tokenString); err == nil && jwt.IsAccessToken(claims) {
				usable := true
				if cache.IsEnabled() {
					// Without this check a revoked (e.g. logged-out) token would
					// still identify the user here. A lookup error is treated
					// like a revoked token: the request continues anonymously
					// rather than failing.
					revoked, lookupErr := cache.IsTokenBlacklisted(claims.ID)
					usable = lookupErr == nil && !revoked
				}
				if usable {
					x.Set(ContextKeyClaims, claims)
					x.Set(ContextKeyUser, claims.UserID)
				}
			}
		}
		x.Next()
		return nil, nil
	}
}
// OrgContext returns a middleware that records which organization the request
// targets. A non-empty ID is stored under ContextKeyOrgID; when no ID is
// supplied nothing is stored. The next handler is always called.
func OrgContext() func(*vigo.X) (any, error) {
	return func(x *vigo.X) (any, error) {
		if id := requestOrgID(x.Request); id != "" {
			x.Set(ContextKeyOrgID, id)
		}
		x.Next()
		return nil, nil
	}
}

// requestOrgID returns the X-Org-ID header if it is non-empty, otherwise the
// org_id query parameter (which may itself be "").
func requestOrgID(r *http.Request) string {
	if fromHeader := r.Header.Get("X-Org-ID"); fromHeader != "" {
		return fromHeader
	}
	return r.URL.Query().Get("org_id")
}
// extractToken returns the bearer token carried by r, or "" if there is none.
//
// It checks the Authorization header first. The scheme is matched
// case-insensitively against "Bearer", and whitespace around the token is
// trimmed, since RFC 7235 allows more than one space after the scheme. If the
// header is absent or uses another scheme, the access_token query parameter
// is returned instead.
//
// NOTE(review): tokens passed in the query string tend to end up in access
// logs and Referer headers; consider restricting this fallback to the routes
// that need it (e.g. WebSocket upgrades).
func extractToken(r *http.Request) string {
	if scheme, token, ok := strings.Cut(r.Header.Get("Authorization"), " "); ok && strings.EqualFold(scheme, "Bearer") {
		return strings.TrimSpace(token)
	}
	return r.URL.Query().Get("access_token")
}
// CurrentUser returns the user ID stored under ContextKeyUser. It returns ""
// when no value is stored or the stored value is not a string.
func CurrentUser(x *vigo.X) string {
	// A failed comma-ok assertion yields the zero value, i.e. "".
	userID, _ := x.Get(ContextKeyUser).(string)
	return userID
}
// CurrentClaims returns the JWT claims stored under ContextKeyClaims. It
// returns nil when no value is stored or the stored value is not a
// *jwt.Claims.
func CurrentClaims(x *vigo.X) *jwt.Claims {
	// A failed comma-ok assertion yields the zero value, i.e. nil.
	c, _ := x.Get(ContextKeyClaims).(*jwt.Claims)
	return c
}
// CurrentOrgID returns the organization ID stored under ContextKeyOrgID. It
// returns "" when no value is stored or the stored value is not a string.
func CurrentOrgID(x *vigo.X) string {
	// A failed comma-ok assertion yields the zero value, i.e. "".
	id, _ := x.Get(ContextKeyOrgID).(string)
	return id
}