You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/auth/auth.go

567 lines
14 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-02-14 16:08:06
// Distributed under terms of the MIT license.
//
package auth
import (
"context"
"errors"
"fmt"
"strings"
"github.com/veypi/vbase/cfg"
"github.com/veypi/vbase/libs/cache"
"github.com/veypi/vbase/libs/jwt"
"github.com/veypi/vbase/models"
"github.com/veypi/vigo"
pub "github.com/veypi/vigo/contrib/auth"
"github.com/veypi/vigo/contrib/event"
"gorm.io/gorm"
)
// Context keys used by this package.
const (
	// CtxKeyUserID is the vigo.X context key under which Login stores the
	// authenticated user's ID.
	CtxKeyUserID = "auth:user_id"
)

// Permission levels form a 3-bit mask: create (001), read (010), write (100).
// LevelCreate is only valid on odd-depth IDs (resource types); every other
// level is only valid on even-depth IDs (resource instances).
const (
	LevelNone      = 0
	LevelCreate    = 1 << 0                                // 001 create
	LevelRead      = 1 << 1                                // 010 read
	LevelWrite     = 1 << 2                                // 100 write
	LevelReadWrite = LevelRead | LevelWrite                // 110 read + write
	LevelAdmin     = LevelCreate | LevelRead | LevelWrite  // 111 full control
)
// PermFunc is the permission-check middleware signature, aliased from the
// vigo auth contrib package.
type PermFunc = pub.PermFunc

var (
	// Factory is the global registry of per-scope Auth instances.
	Factory = &authFactory{apps: make(map[string]*appAuth)}

	// VBaseAuth is the Auth instance for vbase's own "vb" scope.
	VBaseAuth = Factory.New("vb")
)

// Compile-time assertion that *appAuth implements pub.Auth.
var _ pub.Auth = (*appAuth)(nil)
// init registers Factory.init as the handler for the "vb.init.auth" event,
// so role definitions of all registered scopes are synced when it fires.
func init() {
	const authInitEvent = "vb.init.auth"
	event.Add(authInitEvent, Factory.init)
}
// authFactory keeps exactly one appAuth per scope, so repeated New calls for
// the same scope share the same role definitions.
type authFactory struct {
	apps map[string]*appAuth
}

// New returns the Auth instance registered for scope, creating and
// registering an empty one on first use.
// NOTE(review): apps is not guarded by a mutex; presumably New is only called
// during package initialisation — confirm before calling it concurrently.
func (f *authFactory) New(scope string) pub.Auth {
	app, found := f.apps[scope]
	if !found {
		app = &appAuth{
			scope:    scope,
			roleDefs: make(map[string]roleDefinition),
		}
		f.apps[scope] = app
	}
	return app
}

// init syncs the role definitions of every registered scope to the
// database, stopping at (and returning) the first failure.
func (f *authFactory) init() error {
	for scope, app := range f.apps {
		err := app.init()
		if err == nil {
			continue
		}
		return fmt.Errorf("failed to init auth for %s: %w", scope, err)
	}
	return nil
}
// roleDefinition is an in-memory role declaration registered via AddRole and
// written to the database by appAuth.init.
type roleDefinition struct {
code string // role code; looked up together with the scope on the Role table
name string // display name, stored only when the Role row is first created
policies []string // format: "permissionID:level" (split at the last ':')
}
// appAuth implements the Auth interface for a single scope.
type appAuth struct {
scope string // value written to and filtered on Role/Permission rows
roleDefs map[string]roleDefinition // role code -> definition, populated by AddRole
}
// ========== 接口实现 ==========
// UserID returns the user ID that Login stored in the request context, or ""
// when no string value is stored under CtxKeyUserID.
func (a *appAuth) UserID(x *vigo.X) string {
	uid, _ := x.Get(CtxKeyUserID).(string)
	return uid
}
// Login returns a middleware that authenticates the request.
//
// The token is taken from the "Authorization: Bearer" header or the
// access_token query parameter (see extractToken) and parsed as a JWT. When the
// cache is enabled, the token ID is also checked against the revocation
// blacklist. On success the user ID from the claims is stored under
// CtxKeyUserID, where UserID reads it.
func (a *appAuth) Login() PermFunc {
	return func(x *vigo.X) error {
		tokenString := extractToken(x)
		if tokenString == "" {
			return vigo.ErrUnauthorized.WithString("missing token")
		}

		claims, err := jwt.ParseToken(tokenString)
		if err != nil {
			// errors.Is (rather than ==) still recognises expiry when
			// ParseToken wraps the sentinel error.
			if errors.Is(err, jwt.ErrExpiredToken) {
				return vigo.ErrTokenExpired
			}
			return vigo.ErrTokenInvalid
		}

		if cache.IsEnabled() {
			// The lookup error is deliberately ignored: the blacklist check is
			// best-effort, so a cache failure is treated as "not revoked".
			if revoked, _ := cache.IsTokenBlacklisted(claims.ID); revoked {
				return vigo.ErrUnauthorized.WithString("token has been revoked")
			}
		}

		x.Set(CtxKeyUserID, claims.UserID)
		return nil
	}
}
// Perm returns a middleware that requires the current user to hold the
// permission described by code at the given level.
//
// code may contain "{key}", "{key@query}", "{key@header}" or "{key@ctx}"
// placeholders, which are resolved per request by parsePermissionID. If the
// request is not yet authenticated, Login runs first.
//
// A resolved ID that violates validatePermission's depth rules is handled
// differently for static and dynamic templates:
//   - template with placeholders: the bad ID comes from caller-controlled
//     input, so the request is rejected with ErrInvalidArg instead of
//     panicking;
//   - static template: the bad ID is a programming error and still panics,
//     as before.
func (a *appAuth) Perm(code string, level int) PermFunc {
	dynamic := strings.Contains(code, "{")
	return func(x *vigo.X) error {
		userID := a.UserID(x)
		if userID == "" {
			// Not authenticated yet: run the Login logic first.
			if err := a.Login()(x); err != nil {
				return err
			}
			userID = a.UserID(x)
		}

		permID, err := parsePermissionID(x, code)
		if err != nil {
			return vigo.ErrInvalidArg.WithError(err)
		}

		if err := validatePermission(permID, level); err != nil {
			if dynamic {
				return vigo.ErrInvalidArg.WithError(err)
			}
			panic(err)
		}

		if !a.Check(x.Context(), userID, permID, level) {
			return vigo.ErrNoPermission.WithString(fmt.Sprintf("requires permission: %s (level %d)", permID, level))
		}
		return nil
	}
}
// mustPerm validates the (unresolved) permission template against level and
// panics on a mismatch, so a misconfigured check fails when the middleware is
// built rather than on each request. Otherwise it delegates to Perm.
func (a *appAuth) mustPerm(code string, level int) PermFunc {
	if err := validatePermission(code, level); err != nil {
		panic(err)
	}
	return a.Perm(code, level)
}

// PermCreate requires LevelCreate on code (an odd-depth resource-type ID).
func (a *appAuth) PermCreate(code string) PermFunc { return a.mustPerm(code, LevelCreate) }

// PermRead requires LevelRead on code (an even-depth resource-instance ID).
func (a *appAuth) PermRead(code string) PermFunc { return a.mustPerm(code, LevelRead) }

// PermWrite requires LevelWrite on code (an even-depth resource-instance ID).
func (a *appAuth) PermWrite(code string) PermFunc { return a.mustPerm(code, LevelWrite) }

// PermAdmin requires LevelAdmin on code (an even-depth resource-instance ID, or "*").
func (a *appAuth) PermAdmin(code string) PermFunc { return a.mustPerm(code, LevelAdmin) }
// Grant gives userID a direct grant of permissionID at level in this scope.
// If a grant for the same user/permission/scope already exists, only its level
// is updated. It returns validatePermission's error for an invalid
// permissionID/level combination, and any database error (including a failed
// existence check, which previously fell through to an unconditional insert).
// ctx is currently unused.
// NOTE(review): check-then-insert is not atomic; concurrent Grants for the same
// triple may still insert duplicates unless a unique index exists — confirm.
func (a *appAuth) Grant(ctx context.Context, userID, permissionID string, level int) error {
	if err := validatePermission(permissionID, level); err != nil {
		return err
	}

	var count int64
	if err := cfg.DB().Model(&models.Permission{}).
		Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
		Count(&count).Error; err != nil {
		return fmt.Errorf("checking existing grant: %w", err)
	}

	if count > 0 {
		// Already granted: just update the level.
		return cfg.DB().Model(&models.Permission{}).
			Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
			Update("level", level).Error
	}

	perm := models.Permission{
		Scope:        a.scope,
		UserID:       &userID,
		PermissionID: permissionID,
		Level:        level,
	}
	return cfg.DB().Create(&perm).Error
}
// Revoke deletes userID's direct grant of permissionID in this scope.
// Permissions the user receives through roles are not touched. ctx is
// currently unused.
func (a *appAuth) Revoke(ctx context.Context, userID, permissionID string) error {
	res := cfg.DB().
		Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
		Delete(&models.Permission{})
	return res.Error
}
// GrantRole assigns the role identified by roleCode (in this scope) to userID.
// It is a no-op when the user already has the role. It returns the lookup
// error if the role does not exist, and any database error, including a
// failed existence check (previously ignored, which led to an unconditional
// insert). ctx is currently unused.
func (a *appAuth) GrantRole(ctx context.Context, userID, roleCode string) error {
	var role models.Role
	if err := cfg.DB().Where("code = ? AND scope = ?", roleCode, a.scope).First(&role).Error; err != nil {
		return err
	}

	var count int64
	if err := cfg.DB().Model(&models.UserRole{}).
		Where("user_id = ? AND role_id = ?", userID, role.ID).
		Count(&count).Error; err != nil {
		return fmt.Errorf("checking existing role assignment: %w", err)
	}
	if count > 0 {
		return nil // user already has this role
	}

	userRole := models.UserRole{
		UserID: userID,
		RoleID: role.ID,
	}
	return cfg.DB().Create(&userRole).Error
}
// RevokeRole removes the role identified by roleCode (in this scope) from
// userID. It returns the lookup error if the role does not exist. ctx is
// currently unused.
func (a *appAuth) RevokeRole(ctx context.Context, userID, roleCode string) error {
	var role models.Role
	lookup := cfg.DB().Where("code = ? AND scope = ?", roleCode, a.scope).First(&role)
	if lookup.Error != nil {
		return lookup.Error
	}
	removal := cfg.DB().
		Where("user_id = ? AND role_id = ?", userID, role.ID).
		Delete(&models.UserRole{})
	return removal.Error
}
// Check reports whether userID holds permissionID at level in this scope,
// either directly or through one of the user's roles (see
// checkPermissionLevel for the matching rules). It panics if
// permissionID/level violate validatePermission, and reports false when the
// user's permissions cannot be loaded. ctx is currently unused.
func (a *appAuth) Check(ctx context.Context, userID, permissionID string, level int) bool {
	if err := validatePermission(permissionID, level); err != nil {
		panic(err)
	}
	granted, loadErr := a.getUserPermissions(userID)
	return loadErr == nil && checkPermissionLevel(granted, permissionID, level)
}
// ListResources returns, for every instance of resourceType the user has a
// grant on (directly or via roles), the highest level found.
//
// A grant on "<resourceType>:<id>" contributes its own level. A grant on a
// deeper ID ("<resourceType>:<id>:...") contributes LevelRead for <id>. A
// global "*" admin grant is skipped because it names no concrete instance.
// ctx is currently unused.
func (a *appAuth) ListResources(ctx context.Context, userID, resourceType string) (map[string]int, error) {
	granted, err := a.getUserPermissions(userID)
	if err != nil {
		return nil, err
	}

	prefix := resourceType + ":"
	levels := make(map[string]int)
	for _, p := range granted {
		if p.PermissionID == "*" && p.Level == LevelAdmin {
			continue
		}
		if !strings.HasPrefix(p.PermissionID, prefix) {
			continue
		}
		rest := p.PermissionID[len(prefix):]
		instanceID, lvl := rest, p.Level
		if sep := strings.IndexByte(rest, ':'); sep >= 0 {
			// A grant on a descendant implies read access to the instance itself.
			instanceID, lvl = rest[:sep], LevelRead
		}
		if cur, seen := levels[instanceID]; !seen || lvl > cur {
			levels[instanceID] = lvl
		}
	}
	return levels, nil
}
// ListUsers returns every user with access to permissionID in this scope,
// mapped to the highest matching level. A user is included when they hold,
// directly or through a role:
//   - permissionID itself at any level;
//   - LevelAdmin on one of its ancestors (see getAllParents);
//   - LevelAdmin on "*".
//
// Database errors, including failures while expanding role memberships
// (previously ignored), are returned. ctx is currently unused.
func (a *appAuth) ListUsers(ctx context.Context, permissionID string) (map[string]int, error) {
	parents := getAllParents(permissionID)

	conditions := []string{"permission_id = ?"}
	args := []interface{}{permissionID}
	if len(parents) > 0 {
		conditions = append(conditions, "(permission_id IN ? AND level = ?)")
		args = append(args, parents, LevelAdmin)
	}
	conditions = append(conditions, "(permission_id = ? AND level = ?)")
	args = append(args, "*", LevelAdmin)

	// Parenthesise the OR chain explicitly so the scope filter always applies to
	// all alternatives, regardless of how gorm combines separate Where clauses.
	var perms []models.Permission
	if err := cfg.DB().Where("scope = ?", a.scope).
		Where("("+strings.Join(conditions, " OR ")+")", args...).
		Find(&perms).Error; err != nil {
		return nil, err
	}

	result := make(map[string]int)
	raise := func(uid string, level int) {
		if cur, ok := result[uid]; !ok || level > cur {
			result[uid] = level
		}
	}
	for _, p := range perms {
		if p.UserID != nil {
			raise(*p.UserID, p.Level)
		}
		if p.RoleID != nil {
			// Role grants apply to every user assigned that role.
			var userRoles []models.UserRole
			if err := cfg.DB().Where("role_id = ?", *p.RoleID).Find(&userRoles).Error; err != nil {
				return nil, err
			}
			for _, ur := range userRoles {
				raise(ur.UserID, p.Level)
			}
		}
	}
	return result, nil
}
// getAllParents returns every proper ancestor of permID, from shortest to
// longest. Each ancestor is the prefix ending just before one of its ':'
// separators; for example "a:b:c" yields ["a", "a:b"]. It returns nil when
// permID contains no ':'.
func getAllParents(permID string) []string {
	var parents []string
	for off := 0; ; off++ {
		sep := strings.IndexByte(permID[off:], ':')
		if sep < 0 {
			return parents
		}
		off += sep
		parents = append(parents, permID[:off])
	}
}
// AddRole registers (or replaces) the in-memory definition of role code in
// this scope. The definition is written to the database later, by init. Each
// policy has the form "permissionID:level".
//
// The policies slice is copied, so a caller that passes an existing slice
// (AddRole(code, name, list...)) and later modifies it cannot alter the stored
// definition. It currently always returns nil.
func (a *appAuth) AddRole(code, name string, policies ...string) error {
	a.roleDefs[code] = roleDefinition{
		code:     code,
		name:     name,
		policies: append([]string(nil), policies...),
	}
	return nil
}
// init writes every role registered via AddRole to the database for this
// scope. For each role it:
//   - creates the Role row if it is missing (as a system role with status 1);
//   - replaces that role's Permission rows with the rows described by its
//     policies.
//
// Each role is synced in its own transaction, so a failure cannot leave a role
// with its old permissions deleted and the new ones only partially inserted.
// A malformed policy (no ':' or a non-numeric level) is now reported as an
// error instead of being skipped or stored as level 0.
func (a *appAuth) init() error {
	for code, def := range a.roleDefs {
		err := cfg.DB().Transaction(func(tx *gorm.DB) error {
			return a.syncRole(tx, code, def)
		})
		if err != nil {
			return fmt.Errorf("syncing role %s: %w", code, err)
		}
	}
	return nil
}

// syncRole ensures the Role row for code exists and rewrites its Permission
// rows from def.policies, issuing every statement through tx.
func (a *appAuth) syncRole(tx *gorm.DB, code string, def roleDefinition) error {
	var role models.Role
	err := tx.Where("code = ? AND scope = ?", code, a.scope).First(&role).Error
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		role = models.Role{
			Scope:    a.scope,
			Code:     code,
			Name:     def.name,
			IsSystem: true,
			Status:   1,
		}
		if err := tx.Create(&role).Error; err != nil {
			return err
		}
	case err != nil:
		return err
	}

	// The Permission table also holds direct user grants, so delete only the
	// rows owned by this role before re-inserting its policies.
	if err := tx.Where("role_id = ?", role.ID).Delete(&models.Permission{}).Error; err != nil {
		return err
	}

	for _, policy := range def.policies {
		permID, level, err := parsePolicy(policy)
		if err != nil {
			return err
		}
		perm := models.Permission{
			Scope:        a.scope,
			RoleID:       &role.ID,
			PermissionID: permID,
			Level:        level,
		}
		if err := tx.Create(&perm).Error; err != nil {
			return err
		}
	}
	return nil
}

// parsePolicy splits a "permissionID:level" policy at its last ':' and parses
// the level as a decimal integer.
func parsePolicy(policy string) (string, int, error) {
	sep := strings.LastIndex(policy, ":")
	if sep < 0 {
		return "", 0, fmt.Errorf("invalid policy %q: expected \"permissionID:level\"", policy)
	}
	var level int
	if _, err := fmt.Sscanf(policy[sep+1:], "%d", &level); err != nil {
		return "", 0, fmt.Errorf("invalid policy %q: bad level: %w", policy, err)
	}
	return policy[:sep], level, nil
}
// ========== 内部辅助方法 ==========
// getUserPermissions collects the user's permissions in this scope: the
// user's direct grants, followed by the grants of every role the user holds
// whose Role row belongs to this scope. Role permission rows are selected by
// role_id only; their own scope column is not filtered.
func (a *appAuth) getUserPermissions(userID string) ([]models.Permission, error) {
	db := cfg.DB()

	var direct []models.Permission
	if err := db.Where("user_id = ? AND scope = ?", userID, a.scope).Find(&direct).Error; err != nil {
		return nil, err
	}

	// Role IDs of the user's roles in this scope (user_roles joined to roles).
	var roleIDs []string
	err := db.Table("user_roles").
		Joins("JOIN roles ON roles.id = user_roles.role_id").
		Where("user_roles.user_id = ? AND roles.scope = ?", userID, a.scope).
		Pluck("user_roles.role_id", &roleIDs).Error
	if err != nil {
		return nil, err
	}
	if len(roleIDs) == 0 {
		return direct, nil
	}

	var viaRoles []models.Permission
	if err := db.Where("role_id IN ?", roleIDs).Find(&viaRoles).Error; err != nil {
		return nil, err
	}
	return append(direct, viaRoles...), nil
}
// checkPermissionLevel is the core authorization decision. It reports whether
// any of perms grants targetPermID at requiredLevel:
//   - a LevelAdmin grant on "*", on targetPermID itself, or on an ancestor
//     of it always grants access;
//   - any other grant must match targetPermID exactly with a level >=
//     requiredLevel.
//
// Ancestry is decided on ':' segment boundaries (see permissionCovers). The
// previous raw string-prefix test let admin on "app:1" also grant "app:10",
// and let an empty permission ID grant everything.
func checkPermissionLevel(perms []models.Permission, targetPermID string, requiredLevel int) bool {
	for _, p := range perms {
		if p.Level == LevelAdmin && permissionCovers(p.PermissionID, targetPermID) {
			return true
		}
		if p.Level >= requiredLevel && p.PermissionID == targetPermID {
			return true
		}
	}
	return false
}

// permissionCovers reports whether granted is "*", equal to target, or an
// ancestor of target on a ':' segment boundary (the same ancestors that
// getAllParents produces).
func permissionCovers(granted, target string) bool {
	if granted == "*" || granted == target {
		return true
	}
	return strings.HasPrefix(target, granted+":")
}
// parsePermissionID resolves the "{...}" placeholders in a permission template
// against the current request. Supported forms:
//   - {key}        path parameter (the default for any unknown source)
//   - {key@query}  URL query parameter
//   - {key@header} request header
//   - {key@ctx}    string value stored in the vigo.X context
//
// All placeholders are substituted. Previously only the first was, and later
// ones were left in the ID literally. The substituted values are not rescanned
// for placeholders. It returns an error for an unterminated or out-of-order
// brace, or when a referenced value is empty or missing. A template with no
// '{' is returned unchanged.
func parsePermissionID(x *vigo.X, code string) (string, error) {
	if !strings.Contains(code, "{") {
		return code, nil
	}
	var b strings.Builder
	rest := code
	for {
		start := strings.Index(rest, "{")
		if start == -1 {
			b.WriteString(rest)
			return b.String(), nil
		}
		end := strings.Index(rest, "}")
		if end == -1 || end < start {
			return "", fmt.Errorf("invalid permission format")
		}
		val, err := lookupPermissionParam(x, rest[start+1:end])
		if err != nil {
			return "", err
		}
		b.WriteString(rest[:start])
		b.WriteString(val)
		rest = rest[end+1:]
	}
}

// lookupPermissionParam resolves one placeholder body ("key" or "key@source")
// to its request value, returning an error when the value is empty or missing.
func lookupPermissionParam(x *vigo.X, raw string) (string, error) {
	parts := strings.Split(raw, "@")
	key := parts[0]
	source := "path"
	if len(parts) > 1 {
		source = parts[1]
	}
	var val string
	switch source {
	case "query":
		val = x.Request.URL.Query().Get(key)
	case "header":
		val = x.Request.Header.Get(key)
	case "ctx":
		if v, ok := x.Get(key).(string); ok {
			val = v
		}
	default: // path
		val = x.PathParams.Get(key)
	}
	if val == "" {
		return "", fmt.Errorf("param %s not found in %s", key, source)
	}
	return val, nil
}
// extractToken returns the bearer token from the Authorization header.
// The "Bearer" scheme is matched case-insensitively, because HTTP
// authentication schemes are case-insensitive (RFC 7235 §2.1). If the header is
// absent or uses a different scheme, the access_token query parameter is
// returned instead, which may be "".
func extractToken(x *vigo.X) string {
	const prefix = "Bearer "
	h := x.Request.Header.Get("Authorization")
	if len(h) > len(prefix) && strings.EqualFold(h[:len(prefix)], prefix) {
		return h[len(prefix):]
	}
	return x.Request.URL.Query().Get("access_token")
}
// validatePermission checks that level is allowed on code, where the depth of
// code is its number of ':'-separated segments:
//   - "*" is only allowed with LevelAdmin;
//   - LevelCreate requires an odd depth (a resource type, e.g. "app");
//   - every other level requires an even depth (a resource instance, e.g.
//     "app:1").
//
// Level values themselves are not range-checked.
func validatePermission(code string, level int) error {
	if code == "*" {
		if level == LevelAdmin {
			return nil
		}
		return fmt.Errorf("wildcard * requires LevelAdmin")
	}

	depth := strings.Count(code, ":") + 1
	oddDepth := depth%2 == 1
	switch {
	case level == LevelCreate && !oddDepth:
		return fmt.Errorf("LevelCreate requires odd depth (resource type), got %d for %s", depth, code)
	case level != LevelCreate && oddDepth:
		return fmt.Errorf("Level %d requires even depth (resource instance), got %d for %s", level, depth, code)
	}
	return nil
}