You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/auth/auth.go

594 lines
14 KiB
Go

1 month ago
//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-02-14 16:08:06
// Distributed under terms of the MIT license.
//
package auth
import (
"context"
"errors"
1 month ago
"fmt"
"strings"
"github.com/veypi/vbase/cfg"
"github.com/veypi/vbase/libs/cache"
"github.com/veypi/vbase/libs/jwt"
1 month ago
"github.com/veypi/vbase/models"
"github.com/veypi/vigo"
pub "github.com/veypi/vigo/contrib/auth"
"github.com/veypi/vigo/contrib/event"
"gorm.io/gorm"
1 month ago
)
const (
	// CtxKeyUserID is the context key under which Login stores the
	// authenticated user's ID.
	CtxKeyUserID = "auth:user_id"
)

// Permission levels form a small bit set. LevelCreate applies to
// resource-type permission IDs (an odd number of ":"-separated segments);
// every other level applies to resource-instance IDs (an even number of
// segments). See validatePermission.
const (
	LevelNone      = 0
	LevelCreate    = 1 << 0                               // 001 create
	LevelRead      = 1 << 1                               // 010 read
	LevelWrite     = 1 << 2                               // 100 write
	LevelReadWrite = LevelRead | LevelWrite               // 110 read + write
	LevelAdmin     = LevelCreate | LevelRead | LevelWrite // 111 full control
)
// PermFunc 权限检查函数类型
type PermFunc = pub.PermFunc
1 month ago
// Factory 全局 Auth 工厂
1 month ago
var Factory = &authFactory{
apps: make(map[string]*appAuth),
}
// VBaseAuth vbase 自身的权限管理实例
var VBaseAuth = Factory.New("vb")
var _ pub.Auth = &appAuth{}
func init() {
// 注册权限初始化回调
event.Add("vb.init.auth", Factory.init)
}
1 month ago
type authFactory struct {
apps map[string]*appAuth
1 month ago
}
// New 创建权限管理实例
func (f *authFactory) New(scope string) pub.Auth {
if auth, exists := f.apps[scope]; exists {
return auth
}
1 month ago
auth := &appAuth{
scope: scope,
roleDefs: make(map[string]roleDefinition),
}
f.apps[scope] = auth
1 month ago
return auth
}
func (f *authFactory) init() error {
1 month ago
for appKey, auth := range f.apps {
if err := auth.init(); err != nil {
return fmt.Errorf("failed to init auth for %s: %w", appKey, err)
}
}
return nil
}
type (
	// roleDefinition is a role declared in code through AddRole; init
	// writes it to the database.
	roleDefinition struct {
		code     string
		name     string
		policies []string // each entry has the form "permissionID:level"
	}

	// appAuth implements the Auth interface for a single scope.
	appAuth struct {
		scope    string                    // stored on, and used to filter, Permission rows
		roleDefs map[string]roleDefinition // role code -> definition, filled by AddRole
	}
)
// ========== 接口实现 ==========
func (a *appAuth) UserID(x *vigo.X) string {
if uid, ok := x.Get(CtxKeyUserID).(string); ok {
return uid
1 month ago
}
return ""
1 month ago
}
// Login 登录检查中间件
func (a *appAuth) Login() PermFunc {
return func(x *vigo.X) error {
// 1. 提取 token
tokenString := extractToken(x)
if tokenString == "" {
return vigo.ErrUnauthorized.WithString("missing token")
1 month ago
}
// 2. 解析 token
claims, err := jwt.ParseToken(tokenString)
if err != nil {
if err == jwt.ErrExpiredToken {
return vigo.ErrTokenExpired
}
return vigo.ErrTokenInvalid
}
// 3. 检查黑名单
if cache.IsEnabled() {
blacklisted, _ := cache.IsTokenBlacklisted(claims.ID)
if blacklisted {
return vigo.ErrUnauthorized.WithString("token has been revoked")
}
}
// 4. 设置 UserID
x.Set(CtxKeyUserID, claims.UserID)
return nil
}
1 month ago
}
func (a *appAuth) Perm(code string, level int) PermFunc {
1 month ago
return func(x *vigo.X) error {
userID := a.UserID(x)
1 month ago
if userID == "" {
// 尝试先运行 Login 逻辑
if err := a.Login()(x); err != nil {
return err
}
userID = a.UserID(x)
}
1 month ago
// 解析动态参数
permID, err := parsePermissionID(x, code)
if err != nil {
return vigo.ErrInvalidArg.WithError(err)
1 month ago
}
// 检查权限
if err := validatePermission(permID, level); err != nil {
panic(err)
}
if !a.Check(x.Context(), userID, permID, level) {
return vigo.ErrNoPermission.WithString(fmt.Sprintf("requires permission: %s (level %d)", permID, level))
1 month ago
}
return nil
}
}
func (a *appAuth) PermCreate(code string) PermFunc {
if err := validatePermission(code, LevelCreate); err != nil {
panic(err)
}
return a.Perm(code, LevelCreate)
}
func (a *appAuth) PermRead(code string) PermFunc {
if err := validatePermission(code, LevelRead); err != nil {
panic(err)
}
return a.Perm(code, LevelRead)
}
func (a *appAuth) PermWrite(code string) PermFunc {
if err := validatePermission(code, LevelWrite); err != nil {
panic(err)
}
return a.Perm(code, LevelWrite)
}
func (a *appAuth) PermAdmin(code string) PermFunc {
if err := validatePermission(code, LevelAdmin); err != nil {
panic(err)
}
return a.Perm(code, LevelAdmin)
}
1 month ago
// Grant 授予权限
func (a *appAuth) Grant(ctx context.Context, userID, permissionID string, level int) error {
if err := validatePermission(permissionID, level); err != nil {
return err
}
// 检查是否存在
var count int64
cfg.DB().Model(&models.Permission{}).
Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
Count(&count)
1 month ago
if count > 0 {
// 更新等级
return cfg.DB().Model(&models.Permission{}).
Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
Update("level", level).Error
}
1 month ago
// 创建
perm := models.Permission{
Scope: a.scope,
UserID: &userID,
PermissionID: permissionID,
Level: level,
1 month ago
}
return cfg.DB().Create(&perm).Error
1 month ago
}
// Revoke 撤销权限
func (a *appAuth) Revoke(ctx context.Context, userID, permissionID string) error {
return cfg.DB().Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
Delete(&models.Permission{}).Error
}
1 month ago
// GrantRole 授予角色
func (a *appAuth) GrantRole(ctx context.Context, userID, roleCode string) error {
1 month ago
var role models.Role
if err := cfg.DB().Where("code = ?", roleCode).First(&role).Error; err != nil {
return err
1 month ago
}
var count int64
cfg.DB().Model(&models.UserRole{}).
Where("user_id = ? AND role_id = ?", userID, role.ID).
Count(&count)
1 month ago
if count > 0 {
return nil // 已经有该角色
1 month ago
}
userRole := models.UserRole{
UserID: userID,
RoleID: role.ID,
1 month ago
}
return cfg.DB().Create(&userRole).Error
}
1 month ago
// RevokeRole 撤销角色
func (a *appAuth) RevokeRole(ctx context.Context, userID, roleCode string) error {
var role models.Role
if err := cfg.DB().Where("code = ?", roleCode).First(&role).Error; err != nil {
return err
}
1 month ago
return cfg.DB().Where("user_id = ? AND role_id = ?", userID, role.ID).
Delete(&models.UserRole{}).Error
}
// Check reports whether userID is granted permissionID at level in this
// scope, counting both direct grants and grants through the user's roles
// (see checkPermissionLevel for the matching rules).
//
// It fails closed: an invalid permissionID/level pair or a database error
// yields false. ctx is currently unused because getUserPermissions does not
// accept one.
func (a *appAuth) Check(ctx context.Context, userID, permissionID string, level int) bool {
	if err := validatePermission(permissionID, level); err != nil {
		// IDs may be built from request data; deny rather than panic inside
		// a request handler.
		return false
	}

	perms, err := a.getUserPermissions(userID)
	if err != nil {
		return false
	}
	return checkPermissionLevel(perms, permissionID, level)
}
// ListResources 查询用户在特定资源类型下的详细权限信息
func (a *appAuth) ListResources(ctx context.Context, userID, resourceType string) (map[string]int, error) {
perms, err := a.getUserPermissions(userID)
if err != nil {
return nil, err
}
result := make(map[string]int)
prefix := resourceType + ":"
1 month ago
for _, p := range perms {
if p.PermissionID == "*" && p.Level == LevelAdmin {
continue
1 month ago
}
1 month ago
if strings.HasPrefix(p.PermissionID, prefix) {
suffix := p.PermissionID[len(prefix):]
parts := strings.Split(suffix, ":")
if len(parts) > 0 {
instanceID := parts[0]
level := p.Level
// If permission is deeper, assume LevelRead for parent unless explicit
if len(parts) > 1 {
level = LevelRead
}
1 month ago
if currentLevel, ok := result[instanceID]; !ok || level > currentLevel {
result[instanceID] = level
}
}
}
}
return result, nil
}
// ListUsers 查询特定资源的所有协作者及其权限
func (a *appAuth) ListUsers(ctx context.Context, permissionID string) (map[string]int, error) {
parents := getAllParents(permissionID)
1 month ago
var perms []models.Permission
db := cfg.DB().Where("scope = ?", a.scope)
1 month ago
conditions := []string{"permission_id = ?"}
args := []interface{}{permissionID}
if len(parents) > 0 {
conditions = append(conditions, "(permission_id IN ? AND level = ?)")
args = append(args, parents, LevelAdmin)
1 month ago
}
conditions = append(conditions, "(permission_id = ? AND level = ?)")
args = append(args, "*", LevelAdmin)
query := db.Where(strings.Join(conditions, " OR "), args...)
1 month ago
if err := query.Find(&perms).Error; err != nil {
return nil, err
1 month ago
}
result := make(map[string]int)
1 month ago
for _, p := range perms {
if p.UserID != nil {
uid := *p.UserID
if l, ok := result[uid]; !ok || p.Level > l {
result[uid] = p.Level
}
}
if p.RoleID != nil {
var userRoles []models.UserRole
cfg.DB().Where("role_id = ?", *p.RoleID).Find(&userRoles)
for _, ur := range userRoles {
if l, ok := result[ur.UserID]; !ok || p.Level > l {
result[ur.UserID] = p.Level
}
}
}
}
return result, nil
1 month ago
}
// getAllParents returns every proper ancestor of permID, shortest first:
// each prefix that ends just before a ":" separator. For "app:1:doc:2" that
// is ["app", "app:1", "app:1:doc"]. An ID without ":" yields nil.
func getAllParents(permID string) []string {
	var ancestors []string
	for i := 0; i < len(permID); i++ {
		if permID[i] == ':' {
			ancestors = append(ancestors, permID[:i])
		}
	}
	return ancestors
}
// AddRole 添加角色定义
func (a *appAuth) AddRole(code, name string, policies ...string) error {
a.roleDefs[code] = roleDefinition{
code: code,
name: name,
policies: policies,
1 month ago
}
return nil
}
// init 初始化角色到数据库
func (a *appAuth) init() error {
db := cfg.DB()
for code, def := range a.roleDefs {
// 1. 确保角色存在
var role models.Role
err := db.Where("code = ?", code).First(&role).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
role = models.Role{
Code: code,
Name: def.name,
IsSystem: true,
Status: 1,
}
if err := db.Create(&role).Error; err != nil {
return err
}
} else {
return err
}
1 month ago
}
// 2. 同步角色权限 (Diff Sync)
// ID格式: scoped:permission:level (用于唯一标识和 Diff Sync)
var targetIDs []string
// 获取该角色当前scope下的所有权限ID用于快速比对
var existingIDs []string
if err := db.Model(&models.Permission{}).Where("role_id = ? AND scope = ?", role.ID, a.scope).Pluck("id", &existingIDs).Error; err != nil {
return err
1 month ago
}
existingMap := make(map[string]bool)
for _, id := range existingIDs {
existingMap[id] = true
}
1 month ago
for _, policy := range def.policies {
// policy 格式: "permissionID:level"
parts := strings.Split(policy, ":")
if len(parts) < 2 {
continue
}
levelStr := parts[len(parts)-1]
permID := strings.Join(parts[:len(parts)-1], ":")
var level int
fmt.Sscanf(levelStr, "%d", &level)
// 生成确定性 ID: scoped:permission:level
id := fmt.Sprintf("%s:%s:%d", a.scope, permID, level)
targetIDs = append(targetIDs, id)
// 检查是否存在
if !existingMap[id] {
// 不存在,创建新权限
newPerm := models.Permission{
Scope: a.scope,
RoleID: &role.ID,
PermissionID: permID,
Level: level,
}
newPerm.ID = id
if err := db.Create(&newPerm).Error; err != nil {
return err
}
}
}
// 3. 清理不再需要的权限
if len(targetIDs) > 0 {
if err := db.Unscoped().Where("role_id = ? AND scope = ? AND id NOT IN ?", role.ID, a.scope, targetIDs).
Delete(&models.Permission{}).Error; err != nil {
return err
}
} else {
// 如果没有策略,删除所有
if err := db.Unscoped().Where("role_id = ? AND scope = ?", role.ID, a.scope).
Delete(&models.Permission{}).Error; err != nil {
return err
}
}
1 month ago
}
return nil
}
1 month ago
// ========== 内部辅助方法 ==========
1 month ago
// getUserPermissions 获取用户的所有权限(聚合 Role 和 Direct Permission
func (a *appAuth) getUserPermissions(userID string) ([]models.Permission, error) {
var perms []models.Permission
db := cfg.DB()
1 month ago
// 1. 直接权限
if err := db.Where("user_id = ? AND scope = ?", userID, a.scope).Find(&perms).Error; err != nil {
return nil, err
1 month ago
}
// 2. 角色权限
// 查用户角色
// UserRole 关联的是 RoleID
// Role 表已经没有 Scope所以这里查出用户拥有的所有角色ID
1 month ago
var roleIDs []string
if err := db.Model(&models.UserRole{}).
Where("user_id = ?", userID).
Pluck("role_id", &roleIDs).Error; err != nil {
1 month ago
return nil, err
}
if len(roleIDs) > 0 {
var rolePerms []models.Permission
// 查询这些角色在当前 scope 下拥有的权限
if err := db.Where("role_id IN ? AND scope = ?", roleIDs, a.scope).Find(&rolePerms).Error; err != nil {
1 month ago
return nil, err
}
perms = append(perms, rolePerms...)
}
1 month ago
return perms, nil
}
// checkPermissionLevel 核心鉴权逻辑
func checkPermissionLevel(perms []models.Permission, targetPermID string, requiredLevel int) bool {
for _, p := range perms {
// 1. 管理员特权 (Level 7 且是父级或同级)
if p.Level == LevelAdmin {
// 如果拥有的权限是 target 的前缀,或者是 *
if p.PermissionID == "*" || strings.HasPrefix(targetPermID, p.PermissionID) {
return true
}
1 month ago
}
// 2. 普通权限匹配
if p.Level >= requiredLevel {
if p.PermissionID == targetPermID {
return true
}
}
1 month ago
}
return false
1 month ago
}
// parsePermissionID 解析动态权限ID
func parsePermissionID(x *vigo.X, code string) (string, error) {
// 简单实现,支持 {key}, {key@query}, {key@header}
start := strings.Index(code, "{")
if start == -1 {
return code, nil
1 month ago
}
end := strings.Index(code, "}")
if end == -1 || end < start {
return "", fmt.Errorf("invalid permission format")
1 month ago
}
raw := code[start+1 : end]
parts := strings.Split(raw, "@")
key := parts[0]
source := "path"
if len(parts) > 1 {
source = parts[1]
1 month ago
}
var val string
switch source {
case "query":
val = x.Request.URL.Query().Get(key)
case "header":
val = x.Request.Header.Get(key)
case "ctx":
if v, ok := x.Get(key).(string); ok {
val = v
}
default: // path
val = x.PathParams.Get(key)
1 month ago
}
if val == "" {
return "", fmt.Errorf("param %s not found in %s", key, source)
}
return code[:start] + val + code[end+1:], nil
}
// extractToken returns the bearer token from the Authorization header
// ("Bearer <token>", scheme matched case-insensitively as HTTP auth schemes
// are). It falls back to the access_token query parameter when the header
// is absent, uses another scheme, or carries an empty token.
func extractToken(x *vigo.X) string {
	const scheme = "Bearer "
	h := x.Request.Header.Get("Authorization")
	if len(h) > len(scheme) && strings.EqualFold(h[:len(scheme)], scheme) {
		if token := strings.TrimSpace(h[len(scheme):]); token != "" {
			return token
		}
	}
	return x.Request.URL.Query().Get("access_token")
}
// validatePermission checks that level fits the depth of code, where depth
// is the number of ":"-separated segments:
//   - "*" is only valid with LevelAdmin;
//   - LevelCreate needs an odd depth (a resource type, e.g. "app");
//   - every other level needs an even depth (a resource instance, e.g.
//     "app:1").
//
// The level value itself is not range-checked.
func validatePermission(code string, level int) error {
	if code == "*" {
		if level == LevelAdmin {
			return nil
		}
		return fmt.Errorf("wildcard * requires LevelAdmin")
	}

	depth := strings.Count(code, ":") + 1
	oddDepth := depth%2 == 1
	switch {
	case level == LevelCreate && !oddDepth:
		return fmt.Errorf("LevelCreate requires odd depth (resource type), got %d for %s", depth, code)
	case level != LevelCreate && oddDepth:
		return fmt.Errorf("Level %d requires even depth (resource instance), got %d for %s", level, depth, code)
	}
	return nil
}