mirror of https://github.com/veypi/OneAuth.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
142 lines
3.1 KiB
Go
142 lines
3.1 KiB
Go
|
2 weeks ago
|
package middleware
|
||
|
|
|
||
|
|
import (
	"errors"
	"net/http"
	"strings"

	"github.com/veypi/vbase/internal/cache"
	"github.com/veypi/vbase/internal/pkg/jwt"
	"github.com/veypi/vigo"
)
|
||
|
|
|
||
|
|
// Keys under which the middleware in this package stores per-request values
// via x.Set. Read them back with the Current* helpers rather than x.Get.
const (
	// ContextKeyClaims holds the parsed JWT claims of an authenticated request.
	ContextKeyClaims = "jwt_claims"
	// ContextKeyUser holds the authenticated user's ID (claims.UserID).
	ContextKeyUser = "current_user"
	// ContextKeyOrgID holds the organization ID selected by OrgContext.
	ContextKeyOrgID = "org_id"
	// ContextKeyIsAdmin is not set by any middleware visible in this file;
	// presumably populated elsewhere — confirm before relying on it.
	ContextKeyIsAdmin = "is_admin"
)
|
||
|
|
|
||
|
|
// AuthRequired 认证中间件
|
||
|
|
func AuthRequired(skips ...string) func(*vigo.X) (any, error) {
|
||
|
|
skipMap := make(map[string]bool)
|
||
|
|
for _, s := range skips {
|
||
|
|
skipMap[s] = true
|
||
|
|
}
|
||
|
|
|
||
|
|
return func(x *vigo.X) (any, error) {
|
||
|
|
// 检查是否跳过
|
||
|
|
if skipMap[x.Request.URL.Path] {
|
||
|
|
x.Next()
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// 提取token
|
||
|
|
tokenString := extractToken(x.Request)
|
||
|
|
if tokenString == "" {
|
||
|
|
return nil, vigo.ErrNotAuthorized.WithString("missing token")
|
||
|
|
}
|
||
|
|
|
||
|
|
// 解析token
|
||
|
|
claims, err := jwt.ParseToken(tokenString)
|
||
|
|
if err != nil {
|
||
|
|
if err == jwt.ErrExpiredToken {
|
||
|
|
return nil, vigo.ErrNotAuthorized.WithString("token expired")
|
||
|
|
}
|
||
|
|
return nil, vigo.ErrNotAuthorized.WithString("invalid token")
|
||
|
|
}
|
||
|
|
|
||
|
|
// 必须是access token
|
||
|
|
if !jwt.IsAccessToken(claims) {
|
||
|
|
return nil, vigo.ErrNotAuthorized.WithString("invalid token type")
|
||
|
|
}
|
||
|
|
|
||
|
|
// 检查黑名单
|
||
|
|
if cache.IsEnabled() {
|
||
|
|
isRevoked, err := cache.IsTokenBlacklisted(claims.ID)
|
||
|
|
if err != nil {
|
||
|
|
return nil, vigo.ErrInternalServer.WithError(err)
|
||
|
|
}
|
||
|
|
if isRevoked {
|
||
|
|
return nil, vigo.ErrNotAuthorized.WithString("token revoked")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// 设置上下文
|
||
|
|
x.Set(ContextKeyClaims, claims)
|
||
|
|
x.Set(ContextKeyUser, claims.UserID)
|
||
|
|
|
||
|
|
x.Next()
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// OptionalAuth 可选认证
|
||
|
|
func OptionalAuth() func(*vigo.X) (any, error) {
|
||
|
|
return func(x *vigo.X) (any, error) {
|
||
|
|
tokenString := extractToken(x.Request)
|
||
|
|
if tokenString != "" {
|
||
|
|
claims, err := jwt.ParseToken(tokenString)
|
||
|
|
if err == nil && jwt.IsAccessToken(claims) {
|
||
|
|
x.Set(ContextKeyClaims, claims)
|
||
|
|
x.Set(ContextKeyUser, claims.UserID)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
x.Next()
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// OrgContext 组织上下文中间件
|
||
|
|
func OrgContext() func(*vigo.X) (any, error) {
|
||
|
|
return func(x *vigo.X) (any, error) {
|
||
|
|
orgID := x.Request.Header.Get("X-Org-ID")
|
||
|
|
if orgID == "" {
|
||
|
|
orgID = x.Request.URL.Query().Get("org_id")
|
||
|
|
}
|
||
|
|
if orgID != "" {
|
||
|
|
x.Set(ContextKeyOrgID, orgID)
|
||
|
|
}
|
||
|
|
x.Next()
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// extractToken returns the raw token carried by r, or "" if none is present.
//
// The "Authorization" header is consulted first. A value of the form
// "Bearer <token>" yields <token> with surrounding whitespace removed; the
// scheme is matched case-insensitively. If the header is absent, uses
// another scheme, or carries an empty bearer token, the "access_token" query
// parameter is returned instead, which may itself be "".
//
// NOTE(review): tokens in query strings can leak into access logs and
// Referer headers; presumably this fallback serves clients that cannot set
// headers (e.g. browser WebSockets) — confirm it is still required.
func extractToken(r *http.Request) string {
	scheme, token, ok := strings.Cut(r.Header.Get("Authorization"), " ")
	if ok && strings.EqualFold(scheme, "Bearer") {
		// Tolerate extra spaces such as "Bearer  abc "; an untrimmed token
		// would fail parsing downstream.
		if token = strings.TrimSpace(token); token != "" {
			return token
		}
	}

	return r.URL.Query().Get("access_token")
}
|
||
|
|
|
||
|
|
// CurrentUser 获取当前用户ID
|
||
|
|
func CurrentUser(x *vigo.X) string {
|
||
|
|
if uid, ok := x.Get(ContextKeyUser).(string); ok {
|
||
|
|
return uid
|
||
|
|
}
|
||
|
|
return ""
|
||
|
|
}
|
||
|
|
|
||
|
|
// CurrentClaims 获取当前JWT Claims
|
||
|
|
func CurrentClaims(x *vigo.X) *jwt.Claims {
|
||
|
|
if claims, ok := x.Get(ContextKeyClaims).(*jwt.Claims); ok {
|
||
|
|
return claims
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// CurrentOrgID 获取当前组织ID
|
||
|
|
func CurrentOrgID(x *vigo.X) string {
|
||
|
|
if orgID, ok := x.Get(ContextKeyOrgID).(string); ok {
|
||
|
|
return orgID
|
||
|
|
}
|
||
|
|
return ""
|
||
|
|
}
|