// // Copyright (C) 2024 veypi // 2025-02-14 16:08:06 // Distributed under terms of the MIT license. // package auth import ( "context" "errors" "fmt" "strings" "github.com/veypi/vbase/cfg" "github.com/veypi/vbase/libs/jwt" "github.com/veypi/vbase/models" "github.com/veypi/vigo" pub "github.com/veypi/vigo/contrib/auth" "github.com/veypi/vigo/contrib/event" "gorm.io/gorm" ) const ( // CtxKeyUserID 用户ID上下文键 CtxKeyUserID = "auth:user_id" // 权限等级 LevelNone = 0 LevelCreate = 1 // 001 创建 (检查奇数层) LevelRead = 2 // 010 读取 (检查偶数层) LevelWrite = 4 // 100 写入 (检查偶数层) LevelReadWrite = 6 // 110 读写 (检查偶数层) LevelAdmin = 7 // 111 管理员 (完全控制) ) // ctxKeyTokenParsed 标记token是否已解析(请求级别,避免重复解析) const ctxKeyTokenParsed = "_token_parsed" // PermFunc 权限检查函数类型 type PermFunc = pub.PermFunc // Provider 是 auth.Provider 的别名,用于实现端 type Provider = pub.Provider // Factory 全局 Auth 工厂 var Factory = &authFactory{ apps: make(map[string]Provider), } var _ pub.Provider = &vbaseProvider{} func init() { // 注册权限初始化回调 event.Add("vb.init.auth", Factory.init) } type authFactory struct { apps map[string]Provider } // New 创建权限 Provider 实例 func (f *authFactory) New(scope string) Provider { if p, exists := f.apps[scope]; exists { return p } p := &vbaseProvider{ scope: scope, roleDefs: make(map[string]roleDefinition), } f.apps[scope] = p return p } func (f *authFactory) init() error { for appKey, p := range f.apps { if vp, ok := p.(*vbaseProvider); ok { if err := vp.init(); err != nil { return fmt.Errorf("failed to init auth for %s: %w", appKey, err) } } } return nil } // roleDefinition 角色定义 type roleDefinition struct { code string name string policies []string // 格式: "permissionID:level" } // vbaseProvider 实现 Provider 接口 type vbaseProvider struct { scope string roleDefs map[string]roleDefinition } // ========== Provider 接口实现 ========== func (a *vbaseProvider) UserID(x *vigo.X) string { // 1. 检查是否已解析过(无论成功与否,避免重复解析) if _, parsed := x.Get(ctxKeyTokenParsed).(bool); parsed { if uid, ok := x.Get(CtxKeyUserID).(string); ok { return uid } return "" } // 2. 惰性解析:从请求中提取 token tokenStr := extractToken(x) if tokenStr == "" { x.Set(ctxKeyTokenParsed, true) return "" } // 3. 解析并验证 token claims, err := jwt.ParseToken(tokenStr) if err != nil { x.Set(ctxKeyTokenParsed, true) return "" } // 确保是 access token if !jwt.IsAccessToken(claims) { x.Set(ctxKeyTokenParsed, true) return "" } // 4. 设置到上下文中,供后续调用使用 x.Set(CtxKeyUserID, claims.UserID) x.Set(ctxKeyTokenParsed, true) return claims.UserID } // Grant 授予权限 func (a *vbaseProvider) Grant(ctx context.Context, userID, permissionID string, level int) error { if err := validatePermission(permissionID, level); err != nil { return err } // 检查是否存在 var count int64 cfg.DB().Model(&models.Permission{}). Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope). Count(&count) if count > 0 { // 更新等级 return cfg.DB().Model(&models.Permission{}). Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope). Update("level", level).Error } // 创建 perm := models.Permission{ Scope: a.scope, UserID: &userID, PermissionID: permissionID, Level: level, } return cfg.DB().Create(&perm).Error } // Revoke 撤销权限 func (a *vbaseProvider) Revoke(ctx context.Context, userID, permissionID string) error { return cfg.DB().Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope). Delete(&models.Permission{}).Error } // GrantRole 授予角色 func (a *vbaseProvider) GrantRole(ctx context.Context, userID, roleCode string) error { var role models.Role if err := cfg.DB().Where("code = ?", roleCode).First(&role).Error; err != nil { return err } var count int64 cfg.DB().Model(&models.UserRole{}). Where("user_id = ? AND role_id = ?", userID, role.ID). Count(&count) if count > 0 { return nil // 已经有该角色 } userRole := models.UserRole{ UserID: userID, RoleID: role.ID, } return cfg.DB().Create(&userRole).Error } // RevokeRole 撤销角色 func (a *vbaseProvider) RevokeRole(ctx context.Context, userID, roleCode string) error { var role models.Role if err := cfg.DB().Where("code = ?", roleCode).First(&role).Error; err != nil { return err } return cfg.DB().Where("user_id = ? AND role_id = ?", userID, role.ID). Delete(&models.UserRole{}).Error } // Check 检查权限 func (a *vbaseProvider) Check(ctx context.Context, userID, permissionID string, level int) bool { if err := validatePermission(permissionID, level); err != nil { panic(err) } // 1. 获取用户在该作用域下的所有权限(直接权限 + 角色权限) perms, err := a.getUserPermissions(userID) if err != nil { return false } // 2. 检查 return checkPermissionLevel(perms, permissionID, level) } // ListResources 查询用户在特定资源类型下的详细权限信息 func (a *vbaseProvider) ListResources(ctx context.Context, userID, resourceType string) (map[string]int, error) { perms, err := a.getUserPermissions(userID) if err != nil { return nil, err } result := make(map[string]int) prefix := resourceType + ":" for _, p := range perms { if p.PermissionID == "*" && p.Level == LevelAdmin { continue } if strings.HasPrefix(p.PermissionID, prefix) { suffix := p.PermissionID[len(prefix):] parts := strings.Split(suffix, ":") if len(parts) > 0 { instanceID := parts[0] level := p.Level // If permission is deeper, assume LevelRead for parent unless explicit if len(parts) > 1 { level = LevelRead } if currentLevel, ok := result[instanceID]; !ok || level > currentLevel { result[instanceID] = level } } } } return result, nil } // ListUsers 查询特定资源的所有协作者及其权限 func (a *vbaseProvider) ListUsers(ctx context.Context, permissionID string) (map[string]int, error) { parents := getAllParents(permissionID) var perms []models.Permission db := cfg.DB().Where("scope = ?", a.scope) conditions := []string{"permission_id = ?"} args := []interface{}{permissionID} if len(parents) > 0 { conditions = append(conditions, "(permission_id IN ? AND level = ?)") args = append(args, parents, LevelAdmin) } conditions = append(conditions, "(permission_id = ? AND level = ?)") args = append(args, "*", LevelAdmin) query := db.Where(strings.Join(conditions, " OR "), args...) if err := query.Find(&perms).Error; err != nil { return nil, err } result := make(map[string]int) for _, p := range perms { if p.UserID != nil { uid := *p.UserID if l, ok := result[uid]; !ok || p.Level > l { result[uid] = p.Level } } if p.RoleID != nil { var userRoles []models.UserRole cfg.DB().Where("role_id = ?", *p.RoleID).Find(&userRoles) for _, ur := range userRoles { if l, ok := result[ur.UserID]; !ok || p.Level > l { result[ur.UserID] = p.Level } } } } return result, nil } func getAllParents(permID string) []string { parts := strings.Split(permID, ":") var parents []string for i := 1; i < len(parts); i++ { parents = append(parents, strings.Join(parts[:i], ":")) } return parents } // AddRole 添加角色定义 func (a *vbaseProvider) AddRole(code, name string, policies ...string) error { a.roleDefs[code] = roleDefinition{ code: code, name: name, policies: policies, } return nil } // init 初始化角色到数据库 func (a *vbaseProvider) init() error { db := cfg.DB() for code, def := range a.roleDefs { // 1. 确保角色存在 var role models.Role err := db.Where("code = ?", code).First(&role).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { role = models.Role{ Code: code, Name: def.name, IsSystem: true, Status: 1, } if err := db.Create(&role).Error; err != nil { return err } } else { return err } } // 2. 同步角色权限 (Diff Sync) // ID格式: scoped:permission:level (用于唯一标识和 Diff Sync) var targetIDs []string // 获取该角色当前scope下的所有权限ID,用于快速比对 var existingIDs []string if err := db.Model(&models.Permission{}).Where("role_id = ? AND scope = ?", role.ID, a.scope).Pluck("id", &existingIDs).Error; err != nil { return err } existingMap := make(map[string]bool) for _, id := range existingIDs { existingMap[id] = true } for _, policy := range def.policies { // policy 格式: "permissionID:level" parts := strings.Split(policy, ":") if len(parts) < 2 { continue } levelStr := parts[len(parts)-1] permID := strings.Join(parts[:len(parts)-1], ":") var level int fmt.Sscanf(levelStr, "%d", &level) // 生成确定性 ID: scoped:permission:level id := fmt.Sprintf("%s:%s:%d", a.scope, permID, level) targetIDs = append(targetIDs, id) // 检查是否存在 if !existingMap[id] { // 不存在,创建新权限 newPerm := models.Permission{ Scope: a.scope, RoleID: &role.ID, PermissionID: permID, Level: level, } newPerm.ID = id if err := db.Create(&newPerm).Error; err != nil { return err } } } // 3. 清理不再需要的权限 if len(targetIDs) > 0 { if err := db.Unscoped().Where("role_id = ? AND scope = ? AND id NOT IN ?", role.ID, a.scope, targetIDs). Delete(&models.Permission{}).Error; err != nil { return err } } else { // 如果没有策略,删除所有 if err := db.Unscoped().Where("role_id = ? AND scope = ?", role.ID, a.scope). Delete(&models.Permission{}).Error; err != nil { return err } } } return nil } // ========== 内部辅助方法 ========== // getUserPermissions 获取用户的所有权限(聚合 Role 和 Direct Permission) func (a *vbaseProvider) getUserPermissions(userID string) ([]models.Permission, error) { var perms []models.Permission db := cfg.DB() // 1. 直接权限 if err := db.Where("user_id = ? AND scope = ?", userID, a.scope).Find(&perms).Error; err != nil { return nil, err } // 2. 角色权限 // 查用户角色 // UserRole 关联的是 RoleID // Role 表已经没有 Scope,所以这里查出用户拥有的所有角色ID var roleIDs []string if err := db.Model(&models.UserRole{}). Where("user_id = ?", userID). Pluck("role_id", &roleIDs).Error; err != nil { return nil, err } if len(roleIDs) > 0 { var rolePerms []models.Permission // 查询这些角色在当前 scope 下拥有的权限 if err := db.Where("role_id IN ? AND scope = ?", roleIDs, a.scope).Find(&rolePerms).Error; err != nil { return nil, err } perms = append(perms, rolePerms...) } return perms, nil } // checkPermissionLevel 核心鉴权逻辑 func checkPermissionLevel(perms []models.Permission, targetPermID string, requiredLevel int) bool { for _, p := range perms { // 1. 管理员特权 (Level 7 且是父级或同级) if p.Level == LevelAdmin { // 如果拥有的权限是 target 的前缀,或者是 * if p.PermissionID == "*" || strings.HasPrefix(targetPermID, p.PermissionID) { return true } } // 2. 普通权限匹配 if p.Level >= requiredLevel { if p.PermissionID == targetPermID { return true } } } return false } // parsePermissionID 解析动态权限ID func parsePermissionID(x *vigo.X, code string) (string, error) { // 简单实现,支持 {key}, {key@query}, {key@header} start := strings.Index(code, "{") if start == -1 { return code, nil } end := strings.Index(code, "}") if end == -1 || end < start { return "", fmt.Errorf("invalid permission format") } raw := code[start+1 : end] parts := strings.Split(raw, "@") key := parts[0] source := "path" if len(parts) > 1 { source = parts[1] } var val string switch source { case "query": val = x.Request.URL.Query().Get(key) case "header": val = x.Request.Header.Get(key) case "ctx": if v, ok := x.Get(key).(string); ok { val = v } default: // path val = x.PathParams.Get(key) } if val == "" { return "", fmt.Errorf("param %s not found in %s", key, source) } return code[:start] + val + code[end+1:], nil } // extractToken 辅助函数 func extractToken(x *vigo.X) string { auth := x.Request.Header.Get("Authorization") if auth != "" { if len(auth) > 7 && strings.HasPrefix(auth, "Bearer ") { return auth[7:] } } return x.Request.URL.Query().Get("access_token") } func validatePermission(code string, level int) error { if code == "*" { if level != LevelAdmin { return fmt.Errorf("wildcard * requires LevelAdmin") } return nil } parts := strings.Split(code, ":") depth := len(parts) if level == LevelCreate { if depth%2 == 0 { return fmt.Errorf("LevelCreate requires odd depth (resource type), got %d for %s", depth, code) } } else { // Level 2, 4, 6, 7 if depth%2 != 0 { return fmt.Errorf("Level %d requires even depth (resource instance), got %d for %s", level, depth, code) } } return nil }