// // Copyright (C) 2024 veypi // 2025-02-14 16:08:06 // Distributed under terms of the MIT license. // package auth import ( "context" "errors" "fmt" "strings" "github.com/veypi/vbase/cfg" "github.com/veypi/vbase/libs/cache" "github.com/veypi/vbase/libs/jwt" "github.com/veypi/vbase/models" "github.com/veypi/vigo" pub "github.com/veypi/vigo/contrib/auth" "github.com/veypi/vigo/contrib/event" "gorm.io/gorm" ) const ( // CtxKeyUserID 用户ID上下文键 CtxKeyUserID = "auth:user_id" // 权限等级 LevelNone = 0 LevelCreate = 1 // 001 创建 (检查奇数层) LevelRead = 2 // 010 读取 (检查偶数层) LevelWrite = 4 // 100 写入 (检查偶数层) LevelReadWrite = 6 // 110 读写 (检查偶数层) LevelAdmin = 7 // 111 管理员 (完全控制) ) // PermFunc 权限检查函数类型 type PermFunc = pub.PermFunc // Factory 全局 Auth 工厂 var Factory = &authFactory{ apps: make(map[string]*appAuth), } // VBaseAuth vbase 自身的权限管理实例 var VBaseAuth = Factory.New("vb") var _ pub.Auth = &appAuth{} func init() { // 注册权限初始化回调 event.Add("vb.init.auth", Factory.init) } type authFactory struct { apps map[string]*appAuth } // New 创建权限管理实例 func (f *authFactory) New(scope string) pub.Auth { if auth, exists := f.apps[scope]; exists { return auth } auth := &appAuth{ scope: scope, roleDefs: make(map[string]roleDefinition), } f.apps[scope] = auth return auth } func (f *authFactory) init() error { for appKey, auth := range f.apps { if err := auth.init(); err != nil { return fmt.Errorf("failed to init auth for %s: %w", appKey, err) } } return nil } // roleDefinition 角色定义 type roleDefinition struct { code string name string policies []string // 格式: "permissionID:level" } // appAuth 实现 Auth 接口 type appAuth struct { scope string roleDefs map[string]roleDefinition } // ========== 接口实现 ========== func (a *appAuth) UserID(x *vigo.X) string { if uid, ok := x.Get(CtxKeyUserID).(string); ok { return uid } return "" } // Login 登录检查中间件 func (a *appAuth) Login() PermFunc { return func(x *vigo.X) error { // 1. 提取 token tokenString := extractToken(x) if tokenString == "" { return vigo.ErrUnauthorized.WithString("missing token") } // 2. 解析 token claims, err := jwt.ParseToken(tokenString) if err != nil { if err == jwt.ErrExpiredToken { return vigo.ErrTokenExpired } return vigo.ErrTokenInvalid } // 3. 检查黑名单 if cache.IsEnabled() { blacklisted, _ := cache.IsTokenBlacklisted(claims.ID) if blacklisted { return vigo.ErrUnauthorized.WithString("token has been revoked") } } // 4. 设置 UserID x.Set(CtxKeyUserID, claims.UserID) return nil } } func (a *appAuth) Perm(code string, level int) PermFunc { return func(x *vigo.X) error { userID := a.UserID(x) if userID == "" { // 尝试先运行 Login 逻辑 if err := a.Login()(x); err != nil { return err } userID = a.UserID(x) } // 解析动态参数 permID, err := parsePermissionID(x, code) if err != nil { return vigo.ErrInvalidArg.WithError(err) } // 检查权限 if err := validatePermission(permID, level); err != nil { panic(err) } if !a.Check(x.Context(), userID, permID, level) { return vigo.ErrNoPermission.WithString(fmt.Sprintf("requires permission: %s (level %d)", permID, level)) } return nil } } func (a *appAuth) PermCreate(code string) PermFunc { if err := validatePermission(code, LevelCreate); err != nil { panic(err) } return a.Perm(code, LevelCreate) } func (a *appAuth) PermRead(code string) PermFunc { if err := validatePermission(code, LevelRead); err != nil { panic(err) } return a.Perm(code, LevelRead) } func (a *appAuth) PermWrite(code string) PermFunc { if err := validatePermission(code, LevelWrite); err != nil { panic(err) } return a.Perm(code, LevelWrite) } func (a *appAuth) PermAdmin(code string) PermFunc { if err := validatePermission(code, LevelAdmin); err != nil { panic(err) } return a.Perm(code, LevelAdmin) } // Grant 授予权限 func (a *appAuth) Grant(ctx context.Context, userID, permissionID string, level int) error { if err := validatePermission(permissionID, level); err != nil { return err } // 检查是否存在 var count int64 cfg.DB().Model(&models.Permission{}). Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope). Count(&count) if count > 0 { // 更新等级 return cfg.DB().Model(&models.Permission{}). Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope). Update("level", level).Error } // 创建 perm := models.Permission{ Scope: a.scope, UserID: &userID, PermissionID: permissionID, Level: level, } return cfg.DB().Create(&perm).Error } // Revoke 撤销权限 func (a *appAuth) Revoke(ctx context.Context, userID, permissionID string) error { return cfg.DB().Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope). Delete(&models.Permission{}).Error } // GrantRole 授予角色 func (a *appAuth) GrantRole(ctx context.Context, userID, roleCode string) error { var role models.Role if err := cfg.DB().Where("code = ?", roleCode).First(&role).Error; err != nil { return err } var count int64 cfg.DB().Model(&models.UserRole{}). Where("user_id = ? AND role_id = ?", userID, role.ID). Count(&count) if count > 0 { return nil // 已经有该角色 } userRole := models.UserRole{ UserID: userID, RoleID: role.ID, } return cfg.DB().Create(&userRole).Error } // RevokeRole 撤销角色 func (a *appAuth) RevokeRole(ctx context.Context, userID, roleCode string) error { var role models.Role if err := cfg.DB().Where("code = ?", roleCode).First(&role).Error; err != nil { return err } return cfg.DB().Where("user_id = ? AND role_id = ?", userID, role.ID). Delete(&models.UserRole{}).Error } // Check 检查权限 func (a *appAuth) Check(ctx context.Context, userID, permissionID string, level int) bool { if err := validatePermission(permissionID, level); err != nil { panic(err) } // 1. 获取用户在该作用域下的所有权限(直接权限 + 角色权限) perms, err := a.getUserPermissions(userID) if err != nil { return false } // 2. 检查 return checkPermissionLevel(perms, permissionID, level) } // ListResources 查询用户在特定资源类型下的详细权限信息 func (a *appAuth) ListResources(ctx context.Context, userID, resourceType string) (map[string]int, error) { perms, err := a.getUserPermissions(userID) if err != nil { return nil, err } result := make(map[string]int) prefix := resourceType + ":" for _, p := range perms { if p.PermissionID == "*" && p.Level == LevelAdmin { continue } if strings.HasPrefix(p.PermissionID, prefix) { suffix := p.PermissionID[len(prefix):] parts := strings.Split(suffix, ":") if len(parts) > 0 { instanceID := parts[0] level := p.Level // If permission is deeper, assume LevelRead for parent unless explicit if len(parts) > 1 { level = LevelRead } if currentLevel, ok := result[instanceID]; !ok || level > currentLevel { result[instanceID] = level } } } } return result, nil } // ListUsers 查询特定资源的所有协作者及其权限 func (a *appAuth) ListUsers(ctx context.Context, permissionID string) (map[string]int, error) { parents := getAllParents(permissionID) var perms []models.Permission db := cfg.DB().Where("scope = ?", a.scope) conditions := []string{"permission_id = ?"} args := []interface{}{permissionID} if len(parents) > 0 { conditions = append(conditions, "(permission_id IN ? AND level = ?)") args = append(args, parents, LevelAdmin) } conditions = append(conditions, "(permission_id = ? AND level = ?)") args = append(args, "*", LevelAdmin) query := db.Where(strings.Join(conditions, " OR "), args...) if err := query.Find(&perms).Error; err != nil { return nil, err } result := make(map[string]int) for _, p := range perms { if p.UserID != nil { uid := *p.UserID if l, ok := result[uid]; !ok || p.Level > l { result[uid] = p.Level } } if p.RoleID != nil { var userRoles []models.UserRole cfg.DB().Where("role_id = ?", *p.RoleID).Find(&userRoles) for _, ur := range userRoles { if l, ok := result[ur.UserID]; !ok || p.Level > l { result[ur.UserID] = p.Level } } } } return result, nil } func getAllParents(permID string) []string { parts := strings.Split(permID, ":") var parents []string for i := 1; i < len(parts); i++ { parents = append(parents, strings.Join(parts[:i], ":")) } return parents } // AddRole 添加角色定义 func (a *appAuth) AddRole(code, name string, policies ...string) error { a.roleDefs[code] = roleDefinition{ code: code, name: name, policies: policies, } return nil } // init 初始化角色到数据库 func (a *appAuth) init() error { db := cfg.DB() for code, def := range a.roleDefs { // 1. 确保角色存在 var role models.Role err := db.Where("code = ?", code).First(&role).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { role = models.Role{ Code: code, Name: def.name, IsSystem: true, Status: 1, } if err := db.Create(&role).Error; err != nil { return err } } else { return err } } // 2. 同步角色权限 (Diff Sync) // ID格式: scoped:permission:level (用于唯一标识和 Diff Sync) var targetIDs []string // 获取该角色当前scope下的所有权限ID,用于快速比对 var existingIDs []string if err := db.Model(&models.Permission{}).Where("role_id = ? AND scope = ?", role.ID, a.scope).Pluck("id", &existingIDs).Error; err != nil { return err } existingMap := make(map[string]bool) for _, id := range existingIDs { existingMap[id] = true } for _, policy := range def.policies { // policy 格式: "permissionID:level" parts := strings.Split(policy, ":") if len(parts) < 2 { continue } levelStr := parts[len(parts)-1] permID := strings.Join(parts[:len(parts)-1], ":") var level int fmt.Sscanf(levelStr, "%d", &level) // 生成确定性 ID: scoped:permission:level id := fmt.Sprintf("%s:%s:%d", a.scope, permID, level) targetIDs = append(targetIDs, id) // 检查是否存在 if !existingMap[id] { // 不存在,创建新权限 newPerm := models.Permission{ Scope: a.scope, RoleID: &role.ID, PermissionID: permID, Level: level, } newPerm.ID = id if err := db.Create(&newPerm).Error; err != nil { return err } } } // 3. 清理不再需要的权限 if len(targetIDs) > 0 { if err := db.Unscoped().Where("role_id = ? AND scope = ? AND id NOT IN ?", role.ID, a.scope, targetIDs). Delete(&models.Permission{}).Error; err != nil { return err } } else { // 如果没有策略,删除所有 if err := db.Unscoped().Where("role_id = ? AND scope = ?", role.ID, a.scope). Delete(&models.Permission{}).Error; err != nil { return err } } } return nil } // ========== 内部辅助方法 ========== // getUserPermissions 获取用户的所有权限(聚合 Role 和 Direct Permission) func (a *appAuth) getUserPermissions(userID string) ([]models.Permission, error) { var perms []models.Permission db := cfg.DB() // 1. 直接权限 if err := db.Where("user_id = ? AND scope = ?", userID, a.scope).Find(&perms).Error; err != nil { return nil, err } // 2. 角色权限 // 查用户角色 // UserRole 关联的是 RoleID // Role 表已经没有 Scope,所以这里查出用户拥有的所有角色ID var roleIDs []string if err := db.Model(&models.UserRole{}). Where("user_id = ?", userID). Pluck("role_id", &roleIDs).Error; err != nil { return nil, err } if len(roleIDs) > 0 { var rolePerms []models.Permission // 查询这些角色在当前 scope 下拥有的权限 if err := db.Where("role_id IN ? AND scope = ?", roleIDs, a.scope).Find(&rolePerms).Error; err != nil { return nil, err } perms = append(perms, rolePerms...) } return perms, nil } // checkPermissionLevel 核心鉴权逻辑 func checkPermissionLevel(perms []models.Permission, targetPermID string, requiredLevel int) bool { for _, p := range perms { // 1. 管理员特权 (Level 7 且是父级或同级) if p.Level == LevelAdmin { // 如果拥有的权限是 target 的前缀,或者是 * if p.PermissionID == "*" || strings.HasPrefix(targetPermID, p.PermissionID) { return true } } // 2. 普通权限匹配 if p.Level >= requiredLevel { if p.PermissionID == targetPermID { return true } } } return false } // parsePermissionID 解析动态权限ID func parsePermissionID(x *vigo.X, code string) (string, error) { // 简单实现,支持 {key}, {key@query}, {key@header} start := strings.Index(code, "{") if start == -1 { return code, nil } end := strings.Index(code, "}") if end == -1 || end < start { return "", fmt.Errorf("invalid permission format") } raw := code[start+1 : end] parts := strings.Split(raw, "@") key := parts[0] source := "path" if len(parts) > 1 { source = parts[1] } var val string switch source { case "query": val = x.Request.URL.Query().Get(key) case "header": val = x.Request.Header.Get(key) case "ctx": if v, ok := x.Get(key).(string); ok { val = v } default: // path val = x.PathParams.Get(key) } if val == "" { return "", fmt.Errorf("param %s not found in %s", key, source) } return code[:start] + val + code[end+1:], nil } // extractToken 辅助函数 func extractToken(x *vigo.X) string { auth := x.Request.Header.Get("Authorization") if auth != "" { if len(auth) > 7 && strings.HasPrefix(auth, "Bearer ") { return auth[7:] } } return x.Request.URL.Query().Get("access_token") } func validatePermission(code string, level int) error { if code == "*" { if level != LevelAdmin { return fmt.Errorf("wildcard * requires LevelAdmin") } return nil } parts := strings.Split(code, ":") depth := len(parts) if level == LevelCreate { if depth%2 == 0 { return fmt.Errorf("LevelCreate requires odd depth (resource type), got %d for %s", depth, code) } } else { // Level 2, 4, 6, 7 if depth%2 != 0 { return fmt.Errorf("Level %d requires even depth (resource instance), got %d for %s", level, depth, code) } } return nil }