// // Copyright (C) 2024 veypi // 2025-02-14 16:08:06 // Distributed under terms of the MIT license. // package auth import ( "context" "errors" "fmt" "strings" "github.com/veypi/vbase/cfg" "github.com/veypi/vbase/libs/cache" "github.com/veypi/vbase/libs/jwt" "github.com/veypi/vbase/models" "github.com/veypi/vigo" "github.com/veypi/vigo/contrib/event" "gorm.io/gorm" ) const ( // CtxKeyUserID 用户ID上下文键 CtxKeyUserID = "auth:user_id" // 权限等级 LevelNone = 0 LevelCreate = 1 // 001 创建 (检查奇数层) LevelRead = 2 // 010 读取 (检查偶数层) LevelWrite = 4 // 100 写入 (检查偶数层) LevelReadWrite = 6 // 110 读写 (检查偶数层) LevelAdmin = 7 // 111 管理员 (完全控制) ) // PermFunc 权限检查函数类型 type PermFunc func(x *vigo.X) error // Auth 权限管理接口 type Auth interface { // ========== 上下文 ========== // UserID 获取当前用户ID UserID(x *vigo.X) string // ========== 登录检查 ========== // Login 检查用户是否登录 Login() PermFunc // ========== 权限检查 ========== // Perm 检查权限 // code: 权限码,支持动态解析 // - 固定写法: "org:orgA" // - 动态解析: "org:{orgID}" 从 path 获取 // "org:{orgID@query}" 从 query 获取 // "org:{orgID@header}" 从 header 获取 // "org:{orgID@ctx}" 从 ctx 获取 // level: 需要的权限等级 Perm(code string, level int) PermFunc // ========== 快捷方法 ========== // PermCreate 检查创建权限 (level 1,检查奇数层) PermCreate(code string) PermFunc // PermRead 检查读取权限 (level 2,检查偶数层) PermRead(code string) PermFunc // PermWrite 检查更新权限 (level 4,检查偶数层) PermWrite(code string) PermFunc // PermAdmin 检查管理员权限 (level 7,检查偶数层) PermAdmin(code string) PermFunc // ========== 权限授予(业务调用) ========== // Grant 授予权限 // 在创建资源、被授权等业务逻辑中调用 // permissionID: 权限码,如 "org:orgA" // level: 权限等级 Grant(ctx context.Context, userID, permissionID string, level int) error // Revoke 撤销权限 Revoke(ctx context.Context, userID, permissionID string) error // ========== 权限查询 ========== // Check 检查权限 不支持动态解析 // permissionID: 完整的权限码,如 "org:orgA" Check(ctx context.Context, userID, permissionID string, level int) bool // ListResources 查询用户在特定资源类型下的详细权限信息 // 用于解决 "查询我有权限的 org 列表" 等场景 // userID: 用户ID // resourceType: 资源类型 (奇数层),如 "org" 或 "org:{orgID}:project" // 返回: map[实例ID]权限等级 (如 {"orgA": 2, "orgB": 7}) ListResources(ctx context.Context, userID, resourceType string) (map[string]int, error) // ListUsers 查询特定资源的所有协作者及其权限 // 用于解决 "查看这个项目有哪些成员" 等场景 // permissionID: 资源实例权限码,如 "org:orgA" // 返回: map[用户ID]权限等级 (如 {"user1": 2, "user2": 7}) ListUsers(ctx context.Context, permissionID string) (map[string]int, error) // GrantRole 授予角色 GrantRole(ctx context.Context, userID, roleCode string) error // RevokeRole 撤销角色 RevokeRole(ctx context.Context, userID, roleCode string) error // AddRole 添加角色定义 (用于初始化) AddRole(code, name string, policies ...string) error } // Factory 全局 Auth 工厂 var Factory = &authFactory{ apps: make(map[string]*appAuth), } // VBaseAuth vbase 自身的权限管理实例 var VBaseAuth = Factory.New("vb") func init() { // 注册权限初始化回调 event.Add("vb.init.auth", Factory.init) } type authFactory struct { apps map[string]*appAuth } // New 创建权限管理实例 func (f *authFactory) New(scope string) Auth { if auth, exists := f.apps[scope]; exists { return auth } auth := &appAuth{ scope: scope, roleDefs: make(map[string]roleDefinition), } f.apps[scope] = auth return auth } func (f *authFactory) init() error { for appKey, auth := range f.apps { if err := auth.init(); err != nil { return fmt.Errorf("failed to init auth for %s: %w", appKey, err) } } return nil } // roleDefinition 角色定义 type roleDefinition struct { code string name string policies []string // 格式: "permissionID:level" } // appAuth 实现 Auth 接口 type appAuth struct { scope string roleDefs map[string]roleDefinition } // ========== 接口实现 ========== func (a *appAuth) UserID(x *vigo.X) string { if uid, ok := x.Get(CtxKeyUserID).(string); ok { return uid } return "" } // Login 登录检查中间件 func (a *appAuth) Login() PermFunc { return func(x *vigo.X) error { // 1. 提取 token tokenString := extractToken(x) if tokenString == "" { return vigo.ErrUnauthorized.WithString("missing token") } // 2. 解析 token claims, err := jwt.ParseToken(tokenString) if err != nil { if err == jwt.ErrExpiredToken { return vigo.ErrTokenExpired } return vigo.ErrTokenInvalid } // 3. 检查黑名单 if cache.IsEnabled() { blacklisted, _ := cache.IsTokenBlacklisted(claims.ID) if blacklisted { return vigo.ErrUnauthorized.WithString("token has been revoked") } } // 4. 设置 UserID x.Set(CtxKeyUserID, claims.UserID) return nil } } func (a *appAuth) Perm(code string, level int) PermFunc { return func(x *vigo.X) error { userID := a.UserID(x) if userID == "" { // 尝试先运行 Login 逻辑 if err := a.Login()(x); err != nil { return err } userID = a.UserID(x) } // 解析动态参数 permID, err := parsePermissionID(x, code) if err != nil { return vigo.ErrInvalidArg.WithError(err) } // 检查权限 if !a.Check(x.Context(), userID, permID, level) { return vigo.ErrNoPermission.WithString(fmt.Sprintf("requires permission: %s (level %d)", permID, level)) } return nil } } func (a *appAuth) PermCreate(code string) PermFunc { return a.Perm(code, LevelCreate) } func (a *appAuth) PermRead(code string) PermFunc { return a.Perm(code, LevelRead) } func (a *appAuth) PermWrite(code string) PermFunc { return a.Perm(code, LevelWrite) } func (a *appAuth) PermAdmin(code string) PermFunc { return a.Perm(code, LevelAdmin) } // Grant 授予权限 func (a *appAuth) Grant(ctx context.Context, userID, permissionID string, level int) error { // 检查是否存在 var count int64 cfg.DB().Model(&models.Permission{}). Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope). Count(&count) if count > 0 { // 更新等级 return cfg.DB().Model(&models.Permission{}). Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope). Update("level", level).Error } // 创建 perm := models.Permission{ Scope: a.scope, UserID: &userID, PermissionID: permissionID, Level: level, } return cfg.DB().Create(&perm).Error } // Revoke 撤销权限 func (a *appAuth) Revoke(ctx context.Context, userID, permissionID string) error { return cfg.DB().Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope). Delete(&models.Permission{}).Error } // GrantRole 授予角色 func (a *appAuth) GrantRole(ctx context.Context, userID, roleCode string) error { var role models.Role if err := cfg.DB().Where("code = ? AND scope = ?", roleCode, a.scope).First(&role).Error; err != nil { return err } var count int64 cfg.DB().Model(&models.UserRole{}). Where("user_id = ? AND role_id = ?", userID, role.ID). Count(&count) if count > 0 { return nil // 已经有该角色 } userRole := models.UserRole{ UserID: userID, RoleID: role.ID, } return cfg.DB().Create(&userRole).Error } // RevokeRole 撤销角色 func (a *appAuth) RevokeRole(ctx context.Context, userID, roleCode string) error { var role models.Role if err := cfg.DB().Where("code = ? AND scope = ?", roleCode, a.scope).First(&role).Error; err != nil { return err } return cfg.DB().Where("user_id = ? AND role_id = ?", userID, role.ID). Delete(&models.UserRole{}).Error } // Check 检查权限 func (a *appAuth) Check(ctx context.Context, userID, permissionID string, level int) bool { // 1. 获取用户在该作用域下的所有权限(直接权限 + 角色权限) perms, err := a.getUserPermissions(userID) if err != nil { return false } // 2. 检查 return checkPermissionLevel(perms, permissionID, level) } // ListResources 查询用户在特定资源类型下的详细权限信息 func (a *appAuth) ListResources(ctx context.Context, userID, resourceType string) (map[string]int, error) { perms, err := a.getUserPermissions(userID) if err != nil { return nil, err } result := make(map[string]int) prefix := resourceType + ":" for _, p := range perms { if p.PermissionID == "*" && p.Level == LevelAdmin { continue } if strings.HasPrefix(p.PermissionID, prefix) { suffix := p.PermissionID[len(prefix):] parts := strings.Split(suffix, ":") if len(parts) > 0 { instanceID := parts[0] level := p.Level // If permission is deeper, assume LevelRead for parent unless explicit if len(parts) > 1 { level = LevelRead } if currentLevel, ok := result[instanceID]; !ok || level > currentLevel { result[instanceID] = level } } } } return result, nil } // ListUsers 查询特定资源的所有协作者及其权限 func (a *appAuth) ListUsers(ctx context.Context, permissionID string) (map[string]int, error) { parents := getAllParents(permissionID) var perms []models.Permission db := cfg.DB().Where("scope = ?", a.scope) query := db.Where("permission_id = ?", permissionID) if len(parents) > 0 { query = query.Or("permission_id IN ? AND level = ?", parents, LevelAdmin) } query = query.Or("permission_id = ? AND level = ?", "*", LevelAdmin) if err := query.Find(&perms).Error; err != nil { return nil, err } result := make(map[string]int) for _, p := range perms { if p.UserID != nil { uid := *p.UserID if l, ok := result[uid]; !ok || p.Level > l { result[uid] = p.Level } } if p.RoleID != nil { var userRoles []models.UserRole cfg.DB().Where("role_id = ?", *p.RoleID).Find(&userRoles) for _, ur := range userRoles { if l, ok := result[ur.UserID]; !ok || p.Level > l { result[ur.UserID] = p.Level } } } } return result, nil } func getAllParents(permID string) []string { parts := strings.Split(permID, ":") var parents []string for i := 1; i < len(parts); i++ { parents = append(parents, strings.Join(parts[:i], ":")) } return parents } // AddRole 添加角色定义 func (a *appAuth) AddRole(code, name string, policies ...string) error { a.roleDefs[code] = roleDefinition{ code: code, name: name, policies: policies, } return nil } // init 初始化角色到数据库 func (a *appAuth) init() error { db := cfg.DB() for code, def := range a.roleDefs { // 1. 确保角色存在 var role models.Role err := db.Where("code = ? AND scope = ?", code, a.scope).First(&role).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { role = models.Role{ Scope: a.scope, Code: code, Name: def.name, IsSystem: true, Status: 1, } if err := db.Create(&role).Error; err != nil { return err } } else { return err } } // 2. 同步角色权限 // 简单起见,先清除旧的,再插入新的(生产环境可能需要更精细的 diff) // 但 Permission 表是 mixed 的,不能随便删。 // 这里我们需要根据 RoleID 删除该角色的所有权限 if err := db.Where("role_id = ?", role.ID).Delete(&models.Permission{}).Error; err != nil { return err } for _, policy := range def.policies { // policy 格式: "permissionID:level" parts := strings.Split(policy, ":") if len(parts) < 2 { continue } // 最后一个部分是 level,前面是 permissionID levelStr := parts[len(parts)-1] permID := strings.Join(parts[:len(parts)-1], ":") var level int fmt.Sscanf(levelStr, "%d", &level) perm := models.Permission{ Scope: a.scope, RoleID: &role.ID, PermissionID: permID, Level: level, } if err := db.Create(&perm).Error; err != nil { return err } } } return nil } // ========== 内部辅助方法 ========== // getUserPermissions 获取用户的所有权限(聚合 Role 和 Direct Permission) func (a *appAuth) getUserPermissions(userID string) ([]models.Permission, error) { var perms []models.Permission db := cfg.DB() // 1. 直接权限 if err := db.Where("user_id = ? AND scope = ?", userID, a.scope).Find(&perms).Error; err != nil { return nil, err } // 2. 角色权限 // 查用户角色 // UserRole 关联的是 RoleID,Role 表有 Scope // 我们需要关联查询: UserRole -> Role (where scope=a.scope) var roleIDs []string if err := db.Table("user_roles"). Joins("JOIN roles ON roles.id = user_roles.role_id"). Where("user_roles.user_id = ? AND roles.scope = ?", userID, a.scope). Pluck("user_roles.role_id", &roleIDs).Error; err != nil { return nil, err } if len(roleIDs) > 0 { var rolePerms []models.Permission if err := db.Where("role_id IN ?", roleIDs).Find(&rolePerms).Error; err != nil { return nil, err } perms = append(perms, rolePerms...) } return perms, nil } // checkPermissionLevel 核心鉴权逻辑 func checkPermissionLevel(perms []models.Permission, targetPermID string, requiredLevel int) bool { for _, p := range perms { // 1. 管理员特权 (Level 7 且是父级或同级) if p.Level == LevelAdmin { // 如果拥有的权限是 target 的前缀,或者是 * if p.PermissionID == "*" || strings.HasPrefix(targetPermID, p.PermissionID) { return true } } // 2. 普通权限匹配 if p.Level >= requiredLevel { if p.PermissionID == targetPermID { return true } } } return false } // parsePermissionID 解析动态权限ID func parsePermissionID(x *vigo.X, code string) (string, error) { // 简单实现,支持 {key}, {key@query}, {key@header} start := strings.Index(code, "{") if start == -1 { return code, nil } end := strings.Index(code, "}") if end == -1 || end < start { return "", fmt.Errorf("invalid permission format") } raw := code[start+1 : end] parts := strings.Split(raw, "@") key := parts[0] source := "path" if len(parts) > 1 { source = parts[1] } var val string switch source { case "query": val = x.Request.URL.Query().Get(key) case "header": val = x.Request.Header.Get(key) case "ctx": if v, ok := x.Get(key).(string); ok { val = v } default: // path val = x.PathParams.Get(key) } if val == "" { return "", fmt.Errorf("param %s not found in %s", key, source) } return code[:start] + val + code[end+1:], nil } // extractToken 辅助函数 func extractToken(x *vigo.X) string { auth := x.Request.Header.Get("Authorization") if auth != "" { if len(auth) > 7 && strings.HasPrefix(auth, "Bearer ") { return auth[7:] } } return x.Request.URL.Query().Get("access_token") }