// // Copyright (C) 2024 veypi // 2025-02-14 16:08:06 // Distributed under terms of the MIT license. // package auth import ( "context" "errors" "fmt" "strconv" "strings" "time" "github.com/veypi/vbase/cfg" "github.com/veypi/vbase/libs/jwt" "github.com/veypi/vbase/models" "github.com/veypi/vigo" pub "github.com/veypi/vigo/contrib/auth" "github.com/veypi/vigo/contrib/event" "github.com/veypi/vigo/logv" "gorm.io/gorm" ) const ( // CtxKeyUserID 用户ID上下文键 CtxKeyUserID = "auth:user_id" // CtxKeySessionID 会话ID上下文键 CtxKeySessionID = "auth:session_id" // 权限等级 LevelNone = 0 LevelCreate = 1 // 001 创建 (检查奇数层) LevelRead = 2 // 010 读取 (检查偶数层) LevelWrite = 4 // 100 写入 (检查偶数层) LevelReadWrite = 6 // 110 读写 (检查偶数层) LevelAdmin = 7 // 111 管理员 (完全控制) ) // ctxKeyTokenParsed 标记token是否已解析(请求级别,避免重复解析) const ctxKeyTokenParsed = "_token_parsed" // PermFunc 权限检查函数类型 type PermFunc = pub.PermFunc // Provider 是 auth.Provider 的别名,用于实现端 type Provider = pub.Provider // Factory 全局 Auth 工厂 var Factory = &authFactory{ apps: make(map[string]Provider), } var _ pub.Provider = &vbaseProvider{} func init() { // 注册权限初始化回调 event.Add("vb.init.auth", Factory.init) } type authFactory struct { apps map[string]Provider } // New 创建权限 Provider 实例 func (f *authFactory) New(scope string) Provider { if p, exists := f.apps[scope]; exists { return p } p := &vbaseProvider{ scope: scope, roleDefs: make(map[string]roleDefinition), } f.apps[scope] = p return p } func (f *authFactory) init() error { for appKey, p := range f.apps { if vp, ok := p.(*vbaseProvider); ok { if err := vp.init(); err != nil { return fmt.Errorf("failed to init auth for %s: %w", appKey, err) } } } return nil } // roleDefinition 角色定义 type roleDefinition struct { code string name string policies []string // 格式: "permissionID:level" } // vbaseProvider 实现 Provider 接口 type vbaseProvider struct { scope string roleDefs map[string]roleDefinition } // ========== Provider 接口实现 ========== func (a *vbaseProvider) UserID(x *vigo.X) string { // 1. 检查是否已解析过(无论成功与否,避免重复解析) if _, parsed := x.Get(ctxKeyTokenParsed).(bool); parsed { if uid, ok := x.Get(CtxKeyUserID).(string); ok { return uid } return "" } // 2. 惰性解析:从请求中提取 token tokenStr := extractToken(x) if tokenStr == "" { x.Set(ctxKeyTokenParsed, true) return "" } // 3. 解析并验证 token claims, err := jwt.ParseToken(tokenStr) if err != nil { x.Set(ctxKeyTokenParsed, true) return "" } // 确保是 access token if !jwt.IsAccessToken(claims) { x.Set(ctxKeyTokenParsed, true) return "" } // 验证 access token 对应的 session(允许当前版本或上一版本,防止多 tab 并发刷新互踢) if !ValidateAccessSession(claims.SessionID, claims.UserID, claims.Version) { x.Set(ctxKeyTokenParsed, true) return "" } // 5. 设置到上下文中,供后续调用使用 x.Set(CtxKeyUserID, claims.UserID) x.Set(CtxKeySessionID, claims.SessionID) x.Set(ctxKeyTokenParsed, true) return claims.UserID } // Grant 授予权限 func (a *vbaseProvider) Grant(ctx context.Context, userID, permissionID string, level int) error { if err := validatePermission(permissionID, level); err != nil { return err } // 检查是否存在 var count int64 cfg.DB().Model(&models.Permission{}). Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope). Count(&count) if count > 0 { // 更新等级 return cfg.DB().Model(&models.Permission{}). Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope). Update("level", level).Error } // 创建 perm := models.Permission{ Scope: a.scope, UserID: &userID, PermissionID: permissionID, Level: level, } return cfg.DB().Create(&perm).Error } // Revoke 撤销权限 func (a *vbaseProvider) Revoke(ctx context.Context, userID, permissionID string) error { return cfg.DB().Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope). Delete(&models.Permission{}).Error } // GrantRole 授予角色 func (a *vbaseProvider) GrantRole(ctx context.Context, userID, roleCode string) error { var role models.Role if err := cfg.DB().Where("code = ?", roleCode).First(&role).Error; err != nil { return err } var count int64 cfg.DB().Model(&models.UserRole{}). Where("user_id = ? AND role_id = ?", userID, role.ID). Count(&count) if count > 0 { return nil // 已经有该角色 } userRole := models.UserRole{ UserID: userID, RoleID: role.ID, } return cfg.DB().Create(&userRole).Error } // RevokeRole 撤销角色 func (a *vbaseProvider) RevokeRole(ctx context.Context, userID, roleCode string) error { var role models.Role if err := cfg.DB().Where("code = ?", roleCode).First(&role).Error; err != nil { return err } return cfg.DB().Where("user_id = ? AND role_id = ?", userID, role.ID). Delete(&models.UserRole{}).Error } // Check 检查权限 func (a *vbaseProvider) Check(ctx context.Context, userID, permissionID string, level int) bool { if err := validatePermission(permissionID, level); err != nil { panic(err) } // 1. 获取用户在该作用域下的所有权限(直接权限 + 角色权限) perms, err := a.getUserPermissions(userID) if err != nil { return false } // 2. 检查 return checkPermissionLevel(perms, permissionID, level) } // ListResources 查询用户在特定资源类型下的详细权限信息 func (a *vbaseProvider) ListResources(ctx context.Context, userID, resourceType string) (map[string]int, error) { perms, err := a.getUserPermissions(userID) if err != nil { return nil, err } result := make(map[string]int) prefix := resourceType + ":" for _, p := range perms { if p.PermissionID == "*" && p.Level == LevelAdmin { continue } if strings.HasPrefix(p.PermissionID, prefix) { suffix := p.PermissionID[len(prefix):] parts := strings.Split(suffix, ":") if len(parts) > 0 { instanceID := parts[0] level := p.Level // If permission is deeper, assume LevelRead for parent unless explicit if len(parts) > 1 { level = LevelRead } if currentLevel, ok := result[instanceID]; !ok || level > currentLevel { result[instanceID] = level } } } } return result, nil } // ListUsers 查询特定资源的所有协作者及其权限 func (a *vbaseProvider) ListUsers(ctx context.Context, permissionID string) (map[string]int, error) { parents := getAllParents(permissionID) var perms []models.Permission db := cfg.DB().Where("scope = ?", a.scope) conditions := []string{"permission_id = ?"} args := []interface{}{permissionID} if len(parents) > 0 { conditions = append(conditions, "(permission_id IN ? AND level = ?)") args = append(args, parents, LevelAdmin) } conditions = append(conditions, "(permission_id = ? AND level = ?)") args = append(args, "*", LevelAdmin) query := db.Where(strings.Join(conditions, " OR "), args...) if err := query.Find(&perms).Error; err != nil { return nil, err } result := make(map[string]int) for _, p := range perms { if p.UserID != nil { uid := *p.UserID if l, ok := result[uid]; !ok || p.Level > l { result[uid] = p.Level } } if p.RoleID != nil { var userRoles []models.UserRole cfg.DB().Where("role_id = ?", *p.RoleID).Find(&userRoles) for _, ur := range userRoles { if l, ok := result[ur.UserID]; !ok || p.Level > l { result[ur.UserID] = p.Level } } } } return result, nil } func getAllParents(permID string) []string { parts := strings.Split(permID, ":") var parents []string for i := 1; i < len(parts); i++ { parents = append(parents, strings.Join(parts[:i], ":")) } return parents } // AddRole 添加角色定义 func (a *vbaseProvider) AddRole(code, name string, policies ...string) error { a.roleDefs[code] = roleDefinition{ code: code, name: name, policies: policies, } return nil } // init 初始化角色到数据库 func (a *vbaseProvider) init() error { db := cfg.DB() for code, def := range a.roleDefs { // 1. 确保角色存在 var role models.Role err := db.Where("code = ?", code).First(&role).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { role = models.Role{ Code: code, Name: def.name, IsSystem: true, Status: 1, } if err := db.Create(&role).Error; err != nil { return err } } else { return err } } // 2. 同步角色权限 (Diff Sync) // ID格式: scoped:permission:level (用于唯一标识和 Diff Sync) var targetIDs []string // 获取该角色当前scope下的所有权限ID,用于快速比对 var existingIDs []string if err := db.Model(&models.Permission{}).Where("role_id = ? AND scope = ?", role.ID, a.scope).Pluck("id", &existingIDs).Error; err != nil { return err } existingMap := make(map[string]bool) for _, id := range existingIDs { existingMap[id] = true } for _, policy := range def.policies { // policy 格式: "permissionID:level" parts := strings.Split(policy, ":") if len(parts) < 2 { continue } levelStr := parts[len(parts)-1] permID := strings.Join(parts[:len(parts)-1], ":") var level int fmt.Sscanf(levelStr, "%d", &level) // 生成确定性 ID: scoped:permission:level id := fmt.Sprintf("%s:%s:%d", a.scope, permID, level) targetIDs = append(targetIDs, id) // 检查是否存在 if !existingMap[id] { // 不存在,创建新权限 newPerm := models.Permission{ Scope: a.scope, RoleID: &role.ID, PermissionID: permID, Level: level, } newPerm.ID = id if err := db.Create(&newPerm).Error; err != nil { return err } } } // 3. 清理不再需要的权限 if len(targetIDs) > 0 { if err := db.Unscoped().Where("role_id = ? AND scope = ? AND id NOT IN ?", role.ID, a.scope, targetIDs). Delete(&models.Permission{}).Error; err != nil { return err } } else { // 如果没有策略,删除所有 if err := db.Unscoped().Where("role_id = ? AND scope = ?", role.ID, a.scope). Delete(&models.Permission{}).Error; err != nil { return err } } } return nil } // ========== 内部辅助方法 ========== // getUserPermissions 获取用户的所有权限(聚合 Role 和 Direct Permission) func (a *vbaseProvider) getUserPermissions(userID string) ([]models.Permission, error) { var perms []models.Permission db := cfg.DB() // 1. 直接权限 if err := db.Where("user_id = ? AND scope = ?", userID, a.scope).Find(&perms).Error; err != nil { return nil, err } // 2. 角色权限 // 查用户角色 // UserRole 关联的是 RoleID // Role 表已经没有 Scope,所以这里查出用户拥有的所有角色ID var roleIDs []string if err := db.Model(&models.UserRole{}). Where("user_id = ?", userID). Pluck("role_id", &roleIDs).Error; err != nil { return nil, err } if len(roleIDs) > 0 { var rolePerms []models.Permission // 查询这些角色在当前 scope 下拥有的权限 if err := db.Where("role_id IN ? AND scope = ?", roleIDs, a.scope).Find(&rolePerms).Error; err != nil { return nil, err } perms = append(perms, rolePerms...) } return perms, nil } // checkPermissionLevel 核心鉴权逻辑 func checkPermissionLevel(perms []models.Permission, targetPermID string, requiredLevel int) bool { for _, p := range perms { // 1. 管理员特权 (Level 7 且是父级或同级) if p.Level == LevelAdmin { // 如果拥有的权限是 target 的前缀,或者是 * if p.PermissionID == "*" || strings.HasPrefix(targetPermID, p.PermissionID) { return true } } // 2. 普通权限匹配 if p.Level >= requiredLevel { if p.PermissionID == targetPermID { return true } } } return false } // parsePermissionID 解析动态权限ID func parsePermissionID(x *vigo.X, code string) (string, error) { // 简单实现,支持 {key}, {key@query}, {key@header} start := strings.Index(code, "{") if start == -1 { return code, nil } end := strings.Index(code, "}") if end == -1 || end < start { return "", fmt.Errorf("invalid permission format") } raw := code[start+1 : end] parts := strings.Split(raw, "@") key := parts[0] source := "path" if len(parts) > 1 { source = parts[1] } var val string switch source { case "query": val = x.Request.URL.Query().Get(key) case "header": val = x.Request.Header.Get(key) case "ctx": if v, ok := x.Get(key).(string); ok { val = v } default: // path val = x.PathParams.Get(key) } if val == "" { return "", fmt.Errorf("param %s not found in %s", key, source) } return code[:start] + val + code[end+1:], nil } // extractToken 从请求中提取 token,优先级: Cookie > Authorization Header > Query func extractToken(x *vigo.X) string { // 1. Cookie (HttpOnly,浏览器自动携带) if c, err := x.Request.Cookie(cfg.Global.JWT.CookiePrefix + "access"); err == nil && c.Value != "" { return c.Value } // 2. Authorization Header auth := x.Request.Header.Get("Authorization") if auth != "" && len(auth) > 7 && strings.HasPrefix(auth, "Bearer ") { return auth[7:] } // 3. Query 参数 return x.Request.URL.Query().Get("access_token") } // ========== Session 管理 ========== func sessionKey(sid string) string { return fmt.Sprintf("vb:session:%s", sid) } func userSessionsKey(uid string) string { return fmt.Sprintf("vb:user_sessions:%s", uid) } // CreateSession 创建登录会话(DB + Redis) func CreateSession(userID, deviceInfo, ip string, expiresAt time.Time) (*models.Session, error) { session := &models.Session{ UserID: userID, Version: 1, DeviceInfo: deviceInfo, IP: ip, ExpiresAt: expiresAt, } if err := cfg.DB().Create(session).Error; err != nil { return nil, err } // 写 Redis 缓存 fillSessionCache(session) return session, nil } // GetCurrentSessionID 从请求上下文获取当前 session ID func GetCurrentSessionID(x *vigo.X) string { if sid, ok := x.Get(CtxKeySessionID).(string); ok { return sid } return "" } // ValidateAccessSession 验证 access token 对应的 session(允许当前版本或上一版本) func ValidateAccessSession(sessionID, userID string, tokenVer int64) bool { // 先查 Redis rds := cfg.Redis() if rds != nil { revoked, err := rds.HGet(context.Background(), sessionKey(sessionID), "revoked").Result() if err == nil { verStr, _ := rds.HGet(context.Background(), sessionKey(sessionID), "version").Result() suid, _ := rds.HGet(context.Background(), sessionKey(sessionID), "user_id").Result() if revoked == "true" || suid != userID { return false } ver, err := strconv.ParseInt(verStr, 10, 64) if err != nil { return false } return tokenVer == ver || tokenVer == ver-1 } } // Redis 未命中,查 DB var session models.Session if err := cfg.DB().Where("id = ? AND user_id = ?", sessionID, userID).First(&session).Error; err != nil { return false } if session.Revoked { return false } // 回填 Redis fillSessionCache(&session) return tokenVer == session.Version || tokenVer == session.Version-1 } // ValidateRefreshSession 验证 refresh token 并旋转版本号,返回新版本 func ValidateRefreshSession(userID, sessionID string, tokenVer int64) (int64, error) { var session models.Session if err := cfg.DB().Where("id = ? AND user_id = ?", sessionID, userID).First(&session).Error; err != nil { return 0, fmt.Errorf("session not found") } if session.Revoked { return 0, fmt.Errorf("session revoked") } if tokenVer != session.Version { return 0, fmt.Errorf("refresh token version mismatch: expected %d, got %d", session.Version, tokenVer) } // 版本 +1,同时延长 session 过期时间 newVer := session.Version + 1 newExpiresAt := time.Now().Add(cfg.Global.JWT.RefreshExpiry) if err := cfg.DB().Model(&session).Updates(map[string]interface{}{ "version": newVer, "expires_at": newExpiresAt, }).Error; err != nil { return 0, err } // 更新 Redis rds := cfg.Redis() if rds != nil { ttl := time.Until(newExpiresAt) if ttl > 0 { pipe := rds.Pipeline() pipe.HSet(context.Background(), sessionKey(sessionID), "version", newVer) pipe.Expire(context.Background(), sessionKey(sessionID), ttl) pipe.Expire(context.Background(), userSessionsKey(userID), ttl) pipe.Exec(context.Background()) } } return newVer, nil } // RevokeSession 撤销指定会话 func RevokeSession(userID, sessionID string) error { now := time.Now() res := cfg.DB().Model(&models.Session{}).Where("id = ? AND user_id = ?", sessionID, userID).Updates(map[string]interface{}{ "revoked": true, "revoked_at": now, }) if res.Error != nil { return res.Error } // 更新 Redis rds := cfg.Redis() if rds != nil { if err := rds.HSet(context.Background(), sessionKey(sessionID), "revoked", "true").Err(); err != nil { logv.Warn().Msgf("RevokeSession: redis HSet failed: %v", err) } if err := rds.SRem(context.Background(), userSessionsKey(userID), sessionID).Err(); err != nil { logv.Warn().Msgf("RevokeSession: redis SRem failed: %v", err) } } return nil } // RevokeOtherSessions 撤销用户除当前外的其他所有会话 func RevokeOtherSessions(userID, currentSessionID string) error { now := time.Now() if err := cfg.DB().Model(&models.Session{}).Where("user_id = ? AND id != ? AND revoked = ?", userID, currentSessionID, false).Updates(map[string]interface{}{ "revoked": true, "revoked_at": now, }).Error; err != nil { return err } rds := cfg.Redis() if rds != nil { members, err := rds.SMembers(context.Background(), userSessionsKey(userID)).Result() if err == nil { for _, sid := range members { if sid != currentSessionID { rds.HSet(context.Background(), sessionKey(sid), "revoked", "true") rds.SRem(context.Background(), userSessionsKey(userID), sid) } } } else { logv.Warn().Msgf("RevokeOtherSessions: redis SMembers failed: %v", err) } } return nil } // RevokeAllSessions 撤销用户所有会话 func RevokeAllSessions(userID string) error { now := time.Now() if err := cfg.DB().Model(&models.Session{}).Where("user_id = ? AND revoked = ?", userID, false).Updates(map[string]interface{}{ "revoked": true, "revoked_at": now, }).Error; err != nil { return err } // 清理 Redis rds := cfg.Redis() if rds != nil { members, err := rds.SMembers(context.Background(), userSessionsKey(userID)).Result() if err == nil { for _, sid := range members { rds.Del(context.Background(), sessionKey(sid)) } } else { logv.Warn().Msgf("RevokeAllSessions: redis SMembers failed: %v", err) } if err := rds.Del(context.Background(), userSessionsKey(userID)).Err(); err != nil { logv.Warn().Msgf("RevokeAllSessions: redis Del failed: %v", err) } } return nil } // ListSessions 列出用户所有活跃会话 func ListSessions(userID string) ([]models.Session, error) { var sessions []models.Session if err := cfg.DB().Where("user_id = ? AND revoked = ? AND expires_at > ?", userID, false, time.Now()).Order("created_at DESC").Find(&sessions).Error; err != nil { return nil, err } return sessions, nil } // fillSessionCache 回填 Redis 缓存 func fillSessionCache(session *models.Session) { rds := cfg.Redis() if rds == nil { return } ttl := time.Until(session.ExpiresAt) if ttl <= 0 { return } revoked := "false" if session.Revoked { revoked = "true" } pipe := rds.Pipeline() pipe.HSet(context.Background(), sessionKey(session.ID), "version", session.Version, "revoked", revoked, "user_id", session.UserID, ) pipe.Expire(context.Background(), sessionKey(session.ID), ttl) pipe.SAdd(context.Background(), userSessionsKey(session.UserID), session.ID) pipe.Expire(context.Background(), userSessionsKey(session.UserID), ttl) if _, err := pipe.Exec(context.Background()); err != nil { logv.Warn().Msgf("fillSessionCache: redis pipeline failed: %v", err) } } func validatePermission(code string, level int) error { if code == "*" { if level != LevelAdmin { return fmt.Errorf("wildcard * requires LevelAdmin") } return nil } parts := strings.Split(code, ":") depth := len(parts) if level == LevelCreate { if depth%2 == 0 { return fmt.Errorf("LevelCreate requires odd depth (resource type), got %d for %s", depth, code) } } else { // Level 2, 4, 6, 7 if depth%2 != 0 { return fmt.Errorf("Level %d requires even depth (resource instance), got %d for %s", level, depth, code) } } return nil }