You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/auth/auth.go

578 lines
15 KiB
Go

1 week ago
//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-02-14 16:08:06
// Distributed under terms of the MIT license.
//
package auth
import (
"context"
"errors"
1 week ago
"fmt"
"strings"
"github.com/veypi/vbase/cfg"
"github.com/veypi/vbase/libs/cache"
"github.com/veypi/vbase/libs/jwt"
1 week ago
"github.com/veypi/vbase/models"
"github.com/veypi/vigo"
"github.com/veypi/vigo/contrib/event"
"gorm.io/gorm"
1 week ago
)
// CtxKeyUserID is the key under which Login stores the authenticated user's ID in the
// request context (read back by UserID).
const CtxKeyUserID = "auth:user_id"

// Permission levels, written as 3-bit patterns. LevelCreate is checked against
// odd-depth permission IDs (resource types such as "org"); read/write levels are
// checked against even-depth IDs (instances such as "org:orgA"); LevelAdmin means
// full control.
const (
	LevelNone      = 0                                    // 000
	LevelCreate    = 1                                    // 001
	LevelRead      = 2                                    // 010
	LevelWrite     = 4                                    // 100
	LevelReadWrite = LevelRead | LevelWrite               // 110 (6)
	LevelAdmin     = LevelCreate | LevelRead | LevelWrite // 111 (7)
)
// PermFunc 权限检查函数类型
type PermFunc func(x *vigo.X) error
// Auth 权限管理接口
type Auth interface {
// ========== 上下文 ==========
// UserID 获取当前用户ID
UserID(x *vigo.X) string
// ========== 登录检查 ==========
// Login 检查用户是否登录
Login() PermFunc
// ========== 权限检查 ==========
// Perm 检查权限
// code: 权限码,支持动态解析
// - 固定写法: "org:orgA"
// - 动态解析: "org:{orgID}" 从 path 获取
// "org:{orgID@query}" 从 query 获取
// "org:{orgID@header}" 从 header 获取
// "org:{orgID@ctx}" 从 ctx 获取
// level: 需要的权限等级
Perm(code string, level int) PermFunc
// ========== 快捷方法 ==========
// PermCreate 检查创建权限 (level 1检查奇数层)
PermCreate(code string) PermFunc
// PermRead 检查读取权限 (level 2检查偶数层)
PermRead(code string) PermFunc
// PermWrite 检查更新权限 (level 4检查偶数层)
PermWrite(code string) PermFunc
// PermAdmin 检查管理员权限 (level 7检查偶数层)
PermAdmin(code string) PermFunc
// ========== 权限授予(业务调用) ==========
// Grant 授予权限
// 在创建资源、被授权等业务逻辑中调用
// permissionID: 权限码,如 "org:orgA"
// level: 权限等级
Grant(ctx context.Context, userID, permissionID string, level int) error
// Revoke 撤销权限
Revoke(ctx context.Context, userID, permissionID string) error
// ========== 权限查询 ==========
// Check 检查权限 不支持动态解析
// permissionID: 完整的权限码,如 "org:orgA"
Check(ctx context.Context, userID, permissionID string, level int) bool
// ListResources 查询用户在特定资源类型下的详细权限信息
// 用于解决 "查询我有权限的 org 列表" 等场景
// userID: 用户ID
// resourceType: 资源类型 (奇数层),如 "org" 或 "org:{orgID}:project"
// 返回: map[实例ID]权限等级 (如 {"orgA": 2, "orgB": 7})
ListResources(ctx context.Context, userID, resourceType string) (map[string]int, error)
// ListUsers 查询特定资源的所有协作者及其权限
// 用于解决 "查看这个项目有哪些成员" 等场景
// permissionID: 资源实例权限码,如 "org:orgA"
// 返回: map[用户ID]权限等级 (如 {"user1": 2, "user2": 7})
ListUsers(ctx context.Context, permissionID string) (map[string]int, error)
// GrantRole 授予角色
GrantRole(ctx context.Context, userID, roleCode string) error
// RevokeRole 撤销角色
RevokeRole(ctx context.Context, userID, roleCode string) error
// AddRole 添加角色定义 (用于初始化)
AddRole(code, name string, policies ...string) error
1 week ago
}
// Factory 全局 Auth 工厂
1 week ago
var Factory = &authFactory{
apps: make(map[string]*appAuth),
}
// VBaseAuth vbase 自身的权限管理实例
var VBaseAuth = Factory.New("vb")
func init() {
// 注册权限初始化回调
event.Add("vb.init.auth", Factory.init)
}
1 week ago
type authFactory struct {
apps map[string]*appAuth
1 week ago
}
// New 创建权限管理实例
func (f *authFactory) New(scope string) Auth {
if auth, exists := f.apps[scope]; exists {
return auth
}
1 week ago
auth := &appAuth{
scope: scope,
roleDefs: make(map[string]roleDefinition),
}
f.apps[scope] = auth
1 week ago
return auth
}
func (f *authFactory) init() error {
1 week ago
for appKey, auth := range f.apps {
if err := auth.init(); err != nil {
return fmt.Errorf("failed to init auth for %s: %w", appKey, err)
}
}
return nil
}
type (
	// roleDefinition is an in-memory role description registered via AddRole and
	// written to the database by appAuth.init.
	roleDefinition struct {
		code     string
		name     string
		policies []string // each entry has the form "permissionID:level"
	}

	// appAuth implements Auth for a single scope.
	appAuth struct {
		scope    string                    // value stored in / filtered on the scope columns
		roleDefs map[string]roleDefinition // role code -> definition
	}
)
// ========== 接口实现 ==========
func (a *appAuth) UserID(x *vigo.X) string {
if uid, ok := x.Get(CtxKeyUserID).(string); ok {
return uid
1 week ago
}
return ""
1 week ago
}
// Login 登录检查中间件
func (a *appAuth) Login() PermFunc {
return func(x *vigo.X) error {
// 1. 提取 token
tokenString := extractToken(x)
if tokenString == "" {
return vigo.ErrUnauthorized.WithString("missing token")
1 week ago
}
// 2. 解析 token
claims, err := jwt.ParseToken(tokenString)
if err != nil {
if err == jwt.ErrExpiredToken {
return vigo.ErrTokenExpired
}
return vigo.ErrTokenInvalid
}
// 3. 检查黑名单
if cache.IsEnabled() {
blacklisted, _ := cache.IsTokenBlacklisted(claims.ID)
if blacklisted {
return vigo.ErrUnauthorized.WithString("token has been revoked")
}
}
// 4. 设置 UserID
x.Set(CtxKeyUserID, claims.UserID)
return nil
}
1 week ago
}
func (a *appAuth) Perm(code string, level int) PermFunc {
1 week ago
return func(x *vigo.X) error {
userID := a.UserID(x)
1 week ago
if userID == "" {
// 尝试先运行 Login 逻辑
if err := a.Login()(x); err != nil {
return err
}
userID = a.UserID(x)
}
1 week ago
// 解析动态参数
permID, err := parsePermissionID(x, code)
if err != nil {
return vigo.ErrInvalidArg.WithError(err)
1 week ago
}
// 检查权限
if !a.Check(x.Context(), userID, permID, level) {
return vigo.ErrNoPermission.WithString(fmt.Sprintf("requires permission: %s (level %d)", permID, level))
1 week ago
}
return nil
}
}
func (a *appAuth) PermCreate(code string) PermFunc { return a.Perm(code, LevelCreate) }
func (a *appAuth) PermRead(code string) PermFunc { return a.Perm(code, LevelRead) }
func (a *appAuth) PermWrite(code string) PermFunc { return a.Perm(code, LevelWrite) }
func (a *appAuth) PermAdmin(code string) PermFunc { return a.Perm(code, LevelAdmin) }
1 week ago
// Grant 授予权限
func (a *appAuth) Grant(ctx context.Context, userID, permissionID string, level int) error {
// 检查是否存在
var count int64
cfg.DB().Model(&models.Permission{}).
Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
Count(&count)
1 week ago
if count > 0 {
// 更新等级
return cfg.DB().Model(&models.Permission{}).
Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
Update("level", level).Error
}
1 week ago
// 创建
perm := models.Permission{
Scope: a.scope,
UserID: &userID,
PermissionID: permissionID,
Level: level,
1 week ago
}
return cfg.DB().Create(&perm).Error
1 week ago
}
// Revoke 撤销权限
func (a *appAuth) Revoke(ctx context.Context, userID, permissionID string) error {
return cfg.DB().Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
Delete(&models.Permission{}).Error
}
1 week ago
// GrantRole 授予角色
func (a *appAuth) GrantRole(ctx context.Context, userID, roleCode string) error {
1 week ago
var role models.Role
if err := cfg.DB().Where("code = ? AND scope = ?", roleCode, a.scope).First(&role).Error; err != nil {
return err
1 week ago
}
var count int64
cfg.DB().Model(&models.UserRole{}).
Where("user_id = ? AND role_id = ?", userID, role.ID).
Count(&count)
1 week ago
if count > 0 {
return nil // 已经有该角色
1 week ago
}
userRole := models.UserRole{
UserID: userID,
RoleID: role.ID,
1 week ago
}
return cfg.DB().Create(&userRole).Error
}
1 week ago
// RevokeRole 撤销角色
func (a *appAuth) RevokeRole(ctx context.Context, userID, roleCode string) error {
var role models.Role
if err := cfg.DB().Where("code = ? AND scope = ?", roleCode, a.scope).First(&role).Error; err != nil {
return err
}
1 week ago
return cfg.DB().Where("user_id = ? AND role_id = ?", userID, role.ID).
Delete(&models.UserRole{}).Error
}
// Check reports whether userID holds level on the fully resolved permissionID,
// considering both direct grants and grants through roles (see getUserPermissions
// and checkPermissionLevel). A lookup error is treated as "denied".
// ctx is currently unused.
func (a *appAuth) Check(ctx context.Context, userID, permissionID string, level int) bool {
	perms, err := a.getUserPermissions(userID)
	return err == nil && checkPermissionLevel(perms, permissionID, level)
}
// ListResources 查询用户在特定资源类型下的详细权限信息
func (a *appAuth) ListResources(ctx context.Context, userID, resourceType string) (map[string]int, error) {
perms, err := a.getUserPermissions(userID)
if err != nil {
return nil, err
}
result := make(map[string]int)
prefix := resourceType + ":"
1 week ago
for _, p := range perms {
if p.PermissionID == "*" && p.Level == LevelAdmin {
continue
1 week ago
}
1 week ago
if strings.HasPrefix(p.PermissionID, prefix) {
suffix := p.PermissionID[len(prefix):]
parts := strings.Split(suffix, ":")
if len(parts) > 0 {
instanceID := parts[0]
level := p.Level
// If permission is deeper, assume LevelRead for parent unless explicit
if len(parts) > 1 {
level = LevelRead
}
1 week ago
if currentLevel, ok := result[instanceID]; !ok || level > currentLevel {
result[instanceID] = level
}
}
}
}
return result, nil
}
// ListUsers 查询特定资源的所有协作者及其权限
func (a *appAuth) ListUsers(ctx context.Context, permissionID string) (map[string]int, error) {
parents := getAllParents(permissionID)
1 week ago
var perms []models.Permission
db := cfg.DB().Where("scope = ?", a.scope)
1 week ago
query := db.Where("permission_id = ?", permissionID)
if len(parents) > 0 {
query = query.Or("permission_id IN ? AND level = ?", parents, LevelAdmin)
1 week ago
}
query = query.Or("permission_id = ? AND level = ?", "*", LevelAdmin)
1 week ago
if err := query.Find(&perms).Error; err != nil {
return nil, err
1 week ago
}
result := make(map[string]int)
1 week ago
for _, p := range perms {
if p.UserID != nil {
uid := *p.UserID
if l, ok := result[uid]; !ok || p.Level > l {
result[uid] = p.Level
}
}
if p.RoleID != nil {
var userRoles []models.UserRole
cfg.DB().Where("role_id = ?", *p.RoleID).Find(&userRoles)
for _, ur := range userRoles {
if l, ok := result[ur.UserID]; !ok || p.Level > l {
result[ur.UserID] = p.Level
}
}
}
}
return result, nil
1 week ago
}
// getAllParents returns every proper ancestor of a ':'-separated permission ID,
// shortest first: the text before each ':' in permID.
// Example: "org:orgA:project" -> ["org", "org:orgA"]. Returns nil when permID has no ':'.
func getAllParents(permID string) []string {
	var ancestors []string
	pos := 0
	for {
		idx := strings.IndexByte(permID[pos:], ':')
		if idx < 0 {
			return ancestors
		}
		pos += idx
		ancestors = append(ancestors, permID[:pos])
		pos++ // skip the separator itself
	}
}
// AddRole 添加角色定义
func (a *appAuth) AddRole(code, name string, policies ...string) error {
a.roleDefs[code] = roleDefinition{
code: code,
name: name,
policies: policies,
1 week ago
}
return nil
}
// init 初始化角色到数据库
func (a *appAuth) init() error {
db := cfg.DB()
for code, def := range a.roleDefs {
// 1. 确保角色存在
var role models.Role
err := db.Where("code = ? AND scope = ?", code, a.scope).First(&role).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
role = models.Role{
Scope: a.scope,
Code: code,
Name: def.name,
IsSystem: true,
Status: 1,
}
if err := db.Create(&role).Error; err != nil {
return err
}
} else {
return err
}
1 week ago
}
// 2. 同步角色权限
// 简单起见,先清除旧的,再插入新的(生产环境可能需要更精细的 diff
// 但 Permission 表是 mixed 的,不能随便删。
// 这里我们需要根据 RoleID 删除该角色的所有权限
if err := db.Where("role_id = ?", role.ID).Delete(&models.Permission{}).Error; err != nil {
return err
1 week ago
}
for _, policy := range def.policies {
// policy 格式: "permissionID:level"
parts := strings.Split(policy, ":")
if len(parts) < 2 {
continue
}
// 最后一个部分是 level前面是 permissionID
levelStr := parts[len(parts)-1]
permID := strings.Join(parts[:len(parts)-1], ":")
var level int
fmt.Sscanf(levelStr, "%d", &level)
perm := models.Permission{
Scope: a.scope,
RoleID: &role.ID,
PermissionID: permID,
Level: level,
}
if err := db.Create(&perm).Error; err != nil {
return err
}
}
1 week ago
}
return nil
}
1 week ago
// ========== 内部辅助方法 ==========
1 week ago
// getUserPermissions 获取用户的所有权限(聚合 Role 和 Direct Permission
func (a *appAuth) getUserPermissions(userID string) ([]models.Permission, error) {
var perms []models.Permission
db := cfg.DB()
1 week ago
// 1. 直接权限
if err := db.Where("user_id = ? AND scope = ?", userID, a.scope).Find(&perms).Error; err != nil {
return nil, err
1 week ago
}
// 2. 角色权限
// 查用户角色
// UserRole 关联的是 RoleIDRole 表有 Scope
// 我们需要关联查询: UserRole -> Role (where scope=a.scope)
1 week ago
var roleIDs []string
if err := db.Table("user_roles").
Joins("JOIN roles ON roles.id = user_roles.role_id").
Where("user_roles.user_id = ? AND roles.scope = ?", userID, a.scope).
Pluck("user_roles.role_id", &roleIDs).Error; err != nil {
1 week ago
return nil, err
}
if len(roleIDs) > 0 {
var rolePerms []models.Permission
if err := db.Where("role_id IN ?", roleIDs).Find(&rolePerms).Error; err != nil {
1 week ago
return nil, err
}
perms = append(perms, rolePerms...)
}
1 week ago
return perms, nil
}
// checkPermissionLevel 核心鉴权逻辑
func checkPermissionLevel(perms []models.Permission, targetPermID string, requiredLevel int) bool {
for _, p := range perms {
// 1. 管理员特权 (Level 7 且是父级或同级)
if p.Level == LevelAdmin {
// 如果拥有的权限是 target 的前缀,或者是 *
if p.PermissionID == "*" || strings.HasPrefix(targetPermID, p.PermissionID) {
return true
}
1 week ago
}
// 2. 普通权限匹配
if p.Level >= requiredLevel {
if p.PermissionID == targetPermID {
return true
}
}
1 week ago
}
return false
1 week ago
}
// parsePermissionID 解析动态权限ID
func parsePermissionID(x *vigo.X, code string) (string, error) {
// 简单实现,支持 {key}, {key@query}, {key@header}
start := strings.Index(code, "{")
if start == -1 {
return code, nil
1 week ago
}
end := strings.Index(code, "}")
if end == -1 || end < start {
return "", fmt.Errorf("invalid permission format")
1 week ago
}
raw := code[start+1 : end]
parts := strings.Split(raw, "@")
key := parts[0]
source := "path"
if len(parts) > 1 {
source = parts[1]
1 week ago
}
var val string
switch source {
case "query":
val = x.Request.URL.Query().Get(key)
case "header":
val = x.Request.Header.Get(key)
case "ctx":
if v, ok := x.Get(key).(string); ok {
val = v
}
default: // path
val = x.PathParams.Get(key)
1 week ago
}
if val == "" {
return "", fmt.Errorf("param %s not found in %s", key, source)
}
return code[:start] + val + code[end+1:], nil
}
// extractToken returns the credential from an "Authorization: Bearer <token>" header.
// It falls back to the "access_token" query parameter when the header is absent, uses
// another scheme, or carries an empty token. The "Bearer" scheme name is matched
// case-insensitively, because HTTP auth schemes are case-insensitive (RFC 7235 §2.1).
//
// NOTE(review): tokens passed in query strings tend to leak into access logs and
// Referer headers; consider whether the fallback is still needed.
func extractToken(x *vigo.X) string {
	const bearer = "Bearer "
	h := x.Request.Header.Get("Authorization")
	if len(h) > len(bearer) && strings.EqualFold(h[:len(bearer)], bearer) {
		return h[len(bearer):]
	}
	return x.Request.URL.Query().Get("access_token")
}