You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/auth/auth.go

578 lines
15 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-02-14 16:08:06
// Distributed under terms of the MIT license.
//
package auth
import (
	"context"
	"errors"
	"fmt"
	"strconv"
	"strings"

	"github.com/veypi/vbase/cfg"
	"github.com/veypi/vbase/libs/cache"
	"github.com/veypi/vbase/libs/jwt"
	"github.com/veypi/vbase/models"
	"github.com/veypi/vigo"
	"github.com/veypi/vigo/contrib/event"
	"gorm.io/gorm"
)
// Context key and permission levels used throughout the auth package.
const (
	// CtxKeyUserID is the key under which the authenticated user's ID is
	// stored on the request context.
	CtxKeyUserID = "auth:user_id"

	// Permission levels, laid out as a three-bit mask.
	LevelNone      = 0
	LevelCreate    = 1 << 0                               // 001 create (checked on odd layers)
	LevelRead      = 1 << 1                               // 010 read (checked on even layers)
	LevelWrite     = 1 << 2                               // 100 write (checked on even layers)
	LevelReadWrite = LevelRead | LevelWrite               // 110 read + write (checked on even layers)
	LevelAdmin     = LevelCreate | LevelRead | LevelWrite // 111 administrator (full control)
)
// PermFunc is the function type of a permission check: it receives the
// request context x and returns a non-nil error when the check fails.
type PermFunc func(x *vigo.X) error
// Auth is the permission-management interface.
type Auth interface {
// ========== Context ==========
// UserID returns the ID of the current user.
UserID(x *vigo.X) string
// ========== Login check ==========
// Login checks whether the user is logged in.
Login() PermFunc
// ========== Permission checks ==========
// Perm checks a permission.
// code: permission code; dynamic placeholders are supported
// - fixed form: "org:orgA"
// - dynamic form: "org:{orgID}" takes orgID from the path
// "org:{orgID@query}" takes it from the query string
// "org:{orgID@header}" takes it from a header
// "org:{orgID@ctx}" takes it from the ctx
// level: the permission level that is required
Perm(code string, level int) PermFunc
// ========== Shortcuts ==========
// PermCreate checks the create permission (level 1, checked on odd layers).
PermCreate(code string) PermFunc
// PermRead checks the read permission (level 2, checked on even layers).
PermRead(code string) PermFunc
// PermWrite checks the update permission (level 4, checked on even layers).
PermWrite(code string) PermFunc
// PermAdmin checks the administrator permission (level 7, checked on even layers).
PermAdmin(code string) PermFunc
// ========== Granting permissions (called by business logic) ==========
// Grant grants a permission.
// It is meant to be called from business logic such as resource creation
// or explicit authorisation.
// permissionID: permission code, e.g. "org:orgA"
// level: permission level
Grant(ctx context.Context, userID, permissionID string, level int) error
// Revoke revokes a permission.
Revoke(ctx context.Context, userID, permissionID string) error
// ========== Permission queries ==========
// Check checks a permission; dynamic placeholders are not supported.
// permissionID: the complete permission code, e.g. "org:orgA"
Check(ctx context.Context, userID, permissionID string, level int) bool
// ListResources returns the user's detailed permissions for one resource type.
// It covers use cases such as "list the orgs I have access to".
// userID: user ID
// resourceType: resource type (an odd layer), e.g. "org" or "org:{orgID}:project"
// returns: map[instance ID]permission level (e.g. {"orgA": 2, "orgB": 7})
ListResources(ctx context.Context, userID, resourceType string) (map[string]int, error)
// ListUsers returns all collaborators of a resource together with their permissions.
// It covers use cases such as "show the members of this project".
// permissionID: permission code of the resource instance, e.g. "org:orgA"
// returns: map[user ID]permission level (e.g. {"user1": 2, "user2": 7})
ListUsers(ctx context.Context, permissionID string) (map[string]int, error)
// GrantRole grants a role.
GrantRole(ctx context.Context, userID, roleCode string) error
// RevokeRole revokes a role.
RevokeRole(ctx context.Context, userID, roleCode string) error
// AddRole adds a role definition (used during initialisation).
AddRole(code, name string, policies ...string) error
}
// Factory is the global Auth factory. It starts with an empty scope-to-instance map.
var Factory = &authFactory{
apps: make(map[string]*appAuth),
}
// VBaseAuth is vbase's own permission-management instance, obtained from
// Factory for the "vb" scope.
var VBaseAuth = Factory.New("vb")
// init registers Factory.init with the event package under the
// "vb.init.auth" key.
func init() {
// Register the permission-initialisation callback.
event.Add("vb.init.auth", Factory.init)
}
// authFactory hands out one appAuth per scope.
type authFactory struct {
// apps maps a scope name to the appAuth created for it.
apps map[string]*appAuth
}
// New returns the permission manager for scope. An instance already
// registered for that scope is reused; otherwise a new one is created,
// remembered in the factory and returned.
func (f *authFactory) New(scope string) Auth {
	manager, found := f.apps[scope]
	if !found {
		manager = &appAuth{
			scope:    scope,
			roleDefs: make(map[string]roleDefinition),
		}
		f.apps[scope] = manager
	}
	return manager
}
// init runs the initialisation of every registered appAuth and stops at the
// first failure, returning that error wrapped with the scope it belongs to.
func (f *authFactory) init() error {
	for scope, manager := range f.apps {
		err := manager.init()
		if err == nil {
			continue
		}
		return fmt.Errorf("failed to init auth for %s: %w", scope, err)
	}
	return nil
}
// roleDefinition is a role declared in code (see AddRole) that init later
// writes to the database.
type roleDefinition struct {
code string
name string
policies []string // format: "permissionID:level"
}
// appAuth implements the Auth interface for a single scope.
type appAuth struct {
// scope is the value every database query of this instance filters on.
scope string
// roleDefs holds the role definitions registered through AddRole, keyed by role code.
roleDefs map[string]roleDefinition
}
// ========== Interface implementation ==========

// UserID returns the user ID stored on x under CtxKeyUserID, or "" when no
// string value is stored there.
func (a *appAuth) UserID(x *vigo.X) string {
	uid, isString := x.Get(CtxKeyUserID).(string)
	if !isString {
		return ""
	}
	return uid
}
// Login returns the login-check middleware.
//
// The returned PermFunc extracts the token from the request, parses it,
// rejects blacklisted tokens when the cache is enabled, and finally stores
// the token's user ID on x under CtxKeyUserID.
//
// Errors returned by the PermFunc:
//   - vigo.ErrUnauthorized ("missing token") when no token is present
//   - vigo.ErrTokenExpired when parsing reports jwt.ErrExpiredToken
//   - vigo.ErrTokenInvalid for any other parse failure
//   - vigo.ErrUnauthorized ("token has been revoked") for blacklisted tokens
func (a *appAuth) Login() PermFunc {
	return func(x *vigo.X) error {
		// 1. Extract the token.
		tokenString := extractToken(x)
		if tokenString == "" {
			return vigo.ErrUnauthorized.WithString("missing token")
		}

		// 2. Parse the token.
		claims, err := jwt.ParseToken(tokenString)
		if err != nil {
			// errors.Is also recognises an ErrExpiredToken that was wrapped
			// on its way up; a plain == comparison would report it as invalid.
			if errors.Is(err, jwt.ErrExpiredToken) {
				return vigo.ErrTokenExpired
			}
			return vigo.ErrTokenInvalid
		}

		// 3. Check the blacklist.
		if cache.IsEnabled() {
			// NOTE(review): the lookup error is discarded, so a failing cache
			// lets the token through (fail-open) — confirm this is intended.
			blacklisted, _ := cache.IsTokenBlacklisted(claims.ID)
			if blacklisted {
				return vigo.ErrUnauthorized.WithString("token has been revoked")
			}
		}

		// 4. Publish the user ID for later handlers.
		x.Set(CtxKeyUserID, claims.UserID)
		return nil
	}
}
// Perm returns a PermFunc that requires the given level on code.
//
// When no user ID is present on x yet, the Login check is run first. The
// placeholders in code are then resolved against the request, and the
// resulting permission ID is checked for the current user.
func (a *appAuth) Perm(code string, level int) PermFunc {
	return func(x *vigo.X) error {
		uid := a.UserID(x)
		if uid == "" {
			// Not authenticated yet: try the login logic first.
			if loginErr := a.Login()(x); loginErr != nil {
				return loginErr
			}
			uid = a.UserID(x)
		}

		// Resolve dynamic parameters.
		resolved, parseErr := parsePermissionID(x, code)
		if parseErr != nil {
			return vigo.ErrInvalidArg.WithError(parseErr)
		}

		// Check the permission.
		if a.Check(x.Context(), uid, resolved, level) {
			return nil
		}
		msg := fmt.Sprintf("requires permission: %s (level %d)", resolved, level)
		return vigo.ErrNoPermission.WithString(msg)
	}
}
// PermCreate returns a PermFunc that requires LevelCreate on code.
func (a *appAuth) PermCreate(code string) PermFunc {
	return a.Perm(code, LevelCreate)
}

// PermRead returns a PermFunc that requires LevelRead on code.
func (a *appAuth) PermRead(code string) PermFunc {
	return a.Perm(code, LevelRead)
}

// PermWrite returns a PermFunc that requires LevelWrite on code.
func (a *appAuth) PermWrite(code string) PermFunc {
	return a.Perm(code, LevelWrite)
}

// PermAdmin returns a PermFunc that requires LevelAdmin on code.
func (a *appAuth) PermAdmin(code string) PermFunc {
	return a.Perm(code, LevelAdmin)
}
// Grant gives userID the permission permissionID at the given level within
// this instance's scope. An existing row for the same user, permission and
// scope has its level updated; otherwise a new row is created.
//
// It returns the database error of whichever step failed. The existence
// check's error is no longer ignored: previously a failed lookup left the
// count at zero and fell through to an insert.
//
// NOTE(review): ctx is not passed to the database calls, and the
// check-then-write sequence is not atomic — confirm whether a unique index
// protects against concurrent duplicate inserts.
func (a *appAuth) Grant(ctx context.Context, userID, permissionID string, level int) error {
	const match = "user_id = ? AND permission_id = ? AND scope = ?"

	// Does the permission already exist?
	var count int64
	if err := cfg.DB().Model(&models.Permission{}).
		Where(match, userID, permissionID, a.scope).
		Count(&count).Error; err != nil {
		return err
	}

	if count > 0 {
		// Update the level of the existing row.
		return cfg.DB().Model(&models.Permission{}).
			Where(match, userID, permissionID, a.scope).
			Update("level", level).Error
	}

	// Create a new row.
	perm := models.Permission{
		Scope:        a.scope,
		UserID:       &userID,
		PermissionID: permissionID,
		Level:        level,
	}
	return cfg.DB().Create(&perm).Error
}
// Revoke deletes the permission rows matching userID, permissionID and this
// instance's scope, and returns the error reported by the delete.
func (a *appAuth) Revoke(ctx context.Context, userID, permissionID string) error {
	scoped := cfg.DB().Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope)
	outcome := scoped.Delete(&models.Permission{})
	return outcome.Error
}
// GrantRole assigns the role identified by roleCode (within this instance's
// scope) to userID. Assigning a role the user already holds is a no-op.
//
// It returns the lookup error when the role cannot be loaded, and otherwise
// the database error of whichever step failed. The error of the
// "already assigned?" count is no longer ignored: previously a failed count
// fell through to an insert.
//
// NOTE(review): ctx is not passed to the database calls.
func (a *appAuth) GrantRole(ctx context.Context, userID, roleCode string) error {
	var role models.Role
	if err := cfg.DB().Where("code = ? AND scope = ?", roleCode, a.scope).First(&role).Error; err != nil {
		return err
	}

	var count int64
	if err := cfg.DB().Model(&models.UserRole{}).
		Where("user_id = ? AND role_id = ?", userID, role.ID).
		Count(&count).Error; err != nil {
		return err
	}
	if count > 0 {
		return nil // the user already has this role
	}

	userRole := models.UserRole{
		UserID: userID,
		RoleID: role.ID,
	}
	return cfg.DB().Create(&userRole).Error
}
// RevokeRole removes the role identified by roleCode (within this
// instance's scope) from userID. It returns the lookup error when the role
// cannot be loaded, otherwise the error reported by the delete.
func (a *appAuth) RevokeRole(ctx context.Context, userID, roleCode string) error {
	var target models.Role
	lookup := cfg.DB().Where("code = ? AND scope = ?", roleCode, a.scope).First(&target)
	if lookup.Error != nil {
		return lookup.Error
	}

	removal := cfg.DB().Where("user_id = ? AND role_id = ?", userID, target.ID).Delete(&models.UserRole{})
	return removal.Error
}
// Check reports whether userID holds permissionID at the given level. It
// loads all of the user's permissions in this scope (direct and via roles)
// and evaluates them; a failure to load them counts as "no permission".
func (a *appAuth) Check(ctx context.Context, userID, permissionID string, level int) bool {
	granted, err := a.getUserPermissions(userID)
	return err == nil && checkPermissionLevel(granted, permissionID, level)
}
// ListResources returns, for every instance of resourceType the user holds
// a permission on, the highest level found. The result maps instance ID to
// level.
//
// Only permission IDs starting with resourceType + ":" are considered. The
// instance ID is the segment right after that prefix; a permission on
// something deeper than the instance itself counts as LevelRead for the
// instance. The global "*" admin permission is skipped.
func (a *appAuth) ListResources(ctx context.Context, userID, resourceType string) (map[string]int, error) {
	granted, err := a.getUserPermissions(userID)
	if err != nil {
		return nil, err
	}

	typePrefix := resourceType + ":"
	levels := make(map[string]int)
	for _, perm := range granted {
		if perm.PermissionID == "*" && perm.Level == LevelAdmin {
			continue
		}
		if !strings.HasPrefix(perm.PermissionID, typePrefix) {
			continue
		}

		rest := perm.PermissionID[len(typePrefix):]
		instanceID, effective := rest, perm.Level
		if cut := strings.IndexByte(rest, ':'); cut >= 0 {
			// A deeper permission implies read access to the parent
			// instance unless an explicit entry says otherwise.
			instanceID, effective = rest[:cut], LevelRead
		}

		if known, seen := levels[instanceID]; seen && effective <= known {
			continue
		}
		levels[instanceID] = effective
	}
	return levels, nil
}
// ListUsers returns every user that has access to the resource identified
// by permissionID, mapped to the highest level found for that user.
//
// A permission row grants access when, within this instance's scope, it is
//   - a row for permissionID itself (any level), or
//   - an admin-level row on one of permissionID's parents, or
//   - an admin-level row on "*".
//
// Rows attached to a role are expanded to the users holding that role.
//
// Two defects of the previous version are fixed here:
//   - the scope filter was chained with .Or(), which rendered as
//     "scope = ? AND permission_id = ? OR (...) OR (...)"; the admin
//     branches therefore matched rows of every scope. The alternatives are
//     now grouped in explicit parentheses behind the scope filter;
//   - the error of the per-role user lookup was ignored, silently dropping
//     members from the result. It is now returned.
//
// NOTE(review): ctx is not passed to the database calls.
func (a *appAuth) ListUsers(ctx context.Context, permissionID string) (map[string]int, error) {
	alternatives := "permission_id = ? OR (permission_id = ? AND level = ?)"
	args := []interface{}{a.scope, permissionID, "*", LevelAdmin}
	if parents := getAllParents(permissionID); len(parents) > 0 {
		alternatives += " OR (permission_id IN ? AND level = ?)"
		args = append(args, parents, LevelAdmin)
	}

	var perms []models.Permission
	if err := cfg.DB().
		Where("scope = ? AND ("+alternatives+")", args...).
		Find(&perms).Error; err != nil {
		return nil, err
	}

	result := make(map[string]int)
	// raise records level for uid unless a higher one is already known.
	raise := func(uid string, level int) {
		if known, ok := result[uid]; !ok || level > known {
			result[uid] = level
		}
	}

	for _, p := range perms {
		if p.UserID != nil {
			raise(*p.UserID, p.Level)
		}
		if p.RoleID != nil {
			// Expand a role-level permission to the users holding the role.
			var userRoles []models.UserRole
			if err := cfg.DB().Where("role_id = ?", *p.RoleID).Find(&userRoles).Error; err != nil {
				return nil, err
			}
			for _, ur := range userRoles {
				raise(ur.UserID, p.Level)
			}
		}
	}
	return result, nil
}
// getAllParents returns every proper ':'-separated prefix of permID, from
// shortest to longest. For "org:orgA:project" that is "org" and "org:orgA".
// A permID without ':' has no parents and yields nil.
func getAllParents(permID string) []string {
	var ancestors []string
	for offset := 0; ; offset++ {
		next := strings.IndexByte(permID[offset:], ':')
		if next < 0 {
			return ancestors
		}
		// Everything before this separator is one ancestor.
		offset += next
		ancestors = append(ancestors, permID[:offset])
	}
}
// AddRole records a role definition under code, replacing any definition
// previously stored for the same code. It only updates the in-memory
// definitions and always returns nil.
func (a *appAuth) AddRole(code, name string, policies ...string) error {
	def := roleDefinition{code: code, name: name, policies: policies}
	a.roleDefs[code] = def
	return nil
}
// rolePolicy is one parsed "permissionID:level" policy entry.
type rolePolicy struct {
	permID string
	level  int
}

// parseRolePolicies splits each policy at its last ':' into a permission ID
// and a numeric level. Entries without any ':' are skipped; an entry whose
// level part is not an integer is reported as an error.
func parseRolePolicies(policies []string) ([]rolePolicy, error) {
	parsed := make([]rolePolicy, 0, len(policies))
	for _, policy := range policies {
		// The last segment is the level, everything before it the permission ID.
		sep := strings.LastIndex(policy, ":")
		if sep < 0 {
			continue
		}
		level, err := strconv.Atoi(strings.TrimSpace(policy[sep+1:]))
		if err != nil {
			return nil, fmt.Errorf("invalid level in policy %q: %w", policy, err)
		}
		parsed = append(parsed, rolePolicy{permID: policy[:sep], level: level})
	}
	return parsed, nil
}

// init writes the registered role definitions to the database.
//
// For every definition it makes sure the role row exists (creating it as an
// active system role when missing) and then replaces the role's permission
// rows with the ones described by the definition's policies.
//
// Policies are parsed before the database is touched. Previously the level
// was read with an unchecked fmt.Sscanf after the role's old permissions had
// already been deleted, so a malformed policy was silently stored as level 0.
// A malformed level now aborts initialisation with an error instead.
//
// NOTE(review): the delete and the inserts are not wrapped in a
// transaction; a failing insert leaves the role with a partial permission
// set.
func (a *appAuth) init() error {
	db := cfg.DB()
	for code, def := range a.roleDefs {
		policies, err := parseRolePolicies(def.policies)
		if err != nil {
			return fmt.Errorf("role %q: %w", code, err)
		}

		// 1. Make sure the role exists.
		var role models.Role
		err = db.Where("code = ? AND scope = ?", code, a.scope).First(&role).Error
		switch {
		case errors.Is(err, gorm.ErrRecordNotFound):
			role = models.Role{
				Scope:    a.scope,
				Code:     code,
				Name:     def.name,
				IsSystem: true,
				Status:   1,
			}
			if err := db.Create(&role).Error; err != nil {
				return err
			}
		case err != nil:
			return err
		}

		// 2. Synchronise the role's permissions: drop the old rows, then
		// insert the current ones. The Permission table also holds direct
		// user permissions, so the delete is restricted to this role's ID.
		if err := db.Where("role_id = ?", role.ID).Delete(&models.Permission{}).Error; err != nil {
			return err
		}
		for _, p := range policies {
			perm := models.Permission{
				Scope:        a.scope,
				RoleID:       &role.ID,
				PermissionID: p.permID,
				Level:        p.level,
			}
			if err := db.Create(&perm).Error; err != nil {
				return err
			}
		}
	}
	return nil
}
// ========== Internal helpers ==========

// getUserPermissions returns all permissions of userID in this scope: the
// rows attached to the user directly, followed by the rows attached to any
// role the user holds in this scope.
func (a *appAuth) getUserPermissions(userID string) ([]models.Permission, error) {
	conn := cfg.DB()

	// 1. Direct permissions.
	var collected []models.Permission
	direct := conn.Where("user_id = ? AND scope = ?", userID, a.scope).Find(&collected)
	if direct.Error != nil {
		return nil, direct.Error
	}

	// 2. Role permissions. user_roles only stores the role ID while the
	// scope lives on roles, so the two tables are joined to keep only the
	// roles of this scope.
	var roleIDs []string
	roleLookup := conn.Table("user_roles").
		Joins("JOIN roles ON roles.id = user_roles.role_id").
		Where("user_roles.user_id = ? AND roles.scope = ?", userID, a.scope).
		Pluck("user_roles.role_id", &roleIDs)
	if roleLookup.Error != nil {
		return nil, roleLookup.Error
	}
	if len(roleIDs) == 0 {
		return collected, nil
	}

	var viaRoles []models.Permission
	if err := conn.Where("role_id IN ?", roleIDs).Find(&viaRoles).Error; err != nil {
		return nil, err
	}
	return append(collected, viaRoles...), nil
}
// checkPermissionLevel is the core authorisation rule. It reports whether
// any entry of perms satisfies requiredLevel on targetPermID:
//
//   - an admin-level entry grants access when it is "*", the target itself,
//     or a parent of the target;
//   - any other entry grants access only when it names the target exactly
//     and its level is at least requiredLevel.
//
// "Parent" is decided on ':' segment boundaries. The previous raw string
// prefix test let admin on "org:orgA" also cover the unrelated "org:orgAB",
// and let an entry with an empty permission ID cover everything.
func checkPermissionLevel(perms []models.Permission, targetPermID string, requiredLevel int) bool {
	for _, p := range perms {
		exact := p.PermissionID == targetPermID

		// 1. Admin privilege: level 7 on the target, a parent of it, or "*".
		if p.Level == LevelAdmin {
			isParent := p.PermissionID != "" && strings.HasPrefix(targetPermID, p.PermissionID+":")
			if p.PermissionID == "*" || exact || isParent {
				return true
			}
		}

		// 2. Ordinary permission: exact ID with a sufficient level.
		if exact && p.Level >= requiredLevel {
			return true
		}
	}
	return false
}
// parsePermissionID resolves the dynamic placeholders of a permission code
// against the request x and returns the concrete permission ID.
//
// A placeholder has the form {key} or {key@source}. The source selects where
// the value is read from: "query", "header" or "ctx"; anything else,
// including no source at all, reads the path parameter. Code without
// placeholders is returned unchanged.
//
// Every placeholder in code is resolved, so codes such as
// "org:{orgID}:project:{projectID}" work; the previous version replaced only
// the first one and left the rest in the result verbatim. Substituted values
// are never rescanned, so a value containing braces cannot trigger another
// lookup.
//
// An error is returned when a '{' has no closing '}' or when a placeholder
// resolves to an empty value.
func parsePermissionID(x *vigo.X, code string) (string, error) {
	var resolved strings.Builder
	rest := code
	for {
		open := strings.Index(rest, "{")
		if open == -1 {
			break
		}
		width := strings.Index(rest[open:], "}")
		if width == -1 {
			return "", fmt.Errorf("invalid permission format")
		}

		// Split "key@source"; the source defaults to the path.
		parts := strings.Split(rest[open+1:open+width], "@")
		key, source := parts[0], "path"
		if len(parts) > 1 {
			source = parts[1]
		}

		var val string
		switch source {
		case "query":
			val = x.Request.URL.Query().Get(key)
		case "header":
			val = x.Request.Header.Get(key)
		case "ctx":
			if v, ok := x.Get(key).(string); ok {
				val = v
			}
		default: // path
			val = x.PathParams.Get(key)
		}
		if val == "" {
			return "", fmt.Errorf("param %s not found in %s", key, source)
		}

		resolved.WriteString(rest[:open])
		resolved.WriteString(val)
		rest = rest[open+width+1:]
	}
	resolved.WriteString(rest)
	return resolved.String(), nil
}
// extractToken returns the access token carried by the request: the
// credentials of a "Bearer" Authorization header when one is present,
// otherwise the "access_token" query parameter ("" when that is absent too).
//
// The scheme name is matched case-insensitively, so "bearer <token>" and
// "BEARER <token>" are accepted as well; HTTP authentication scheme names
// are case-insensitive, and the previous exact match pushed such requests
// onto the query-parameter fallback.
func extractToken(x *vigo.X) string {
	const scheme = "Bearer "

	header := x.Request.Header.Get("Authorization")
	if len(header) > len(scheme) && strings.EqualFold(header[:len(scheme)], scheme) {
		return header[len(scheme):]
	}
	return x.Request.URL.Query().Get("access_token")
}