You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/auth/auth.go

537 lines
13 KiB
Go

3 months ago
//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-02-14 16:08:06
// Distributed under terms of the MIT license.
//
package auth
import (
	"context"
	"errors"
	"fmt"
	"strconv"
	"strings"

	"github.com/veypi/vbase/cfg"
	"github.com/veypi/vbase/libs/jwt"
	"github.com/veypi/vbase/models"
	"github.com/veypi/vigo"
	pub "github.com/veypi/vigo/contrib/auth"
	"github.com/veypi/vigo/contrib/event"
	"gorm.io/gorm"
)
// CtxKeyUserID is the request-context key that holds the authenticated user's ID.
const CtxKeyUserID = "auth:user_id"

// ctxKeyTokenParsed marks that the request's token has already been parsed
// (successfully or not), so it is not parsed again within the same request.
const ctxKeyTokenParsed = "_token_parsed"

// Permission levels, written as bit patterns. LevelCreate applies to resource
// types (odd-depth IDs such as "app"); the remaining levels apply to resource
// instances (even-depth IDs such as "app:1"). ReadWrite and Admin are the
// unions of the individual capability bits.
const (
	LevelNone      = 0                                    // no permission
	LevelCreate    = 0b001                                // create (odd depth)
	LevelRead      = 0b010                                // read (even depth)
	LevelWrite     = 0b100                                // write (even depth)
	LevelReadWrite = LevelRead | LevelWrite               // 0b110 read + write (even depth)
	LevelAdmin     = LevelCreate | LevelRead | LevelWrite // 0b111 full control
)
// PermFunc 权限检查函数类型
type PermFunc = pub.PermFunc
3 months ago
// Provider 是 auth.Provider 的别名,用于实现端
type Provider = pub.Provider
// Factory 全局 Auth 工厂
3 months ago
var Factory = &authFactory{
apps: make(map[string]Provider),
3 months ago
}
var _ pub.Provider = &vbaseProvider{}
func init() {
// 注册权限初始化回调
event.Add("vb.init.auth", Factory.init)
}
3 months ago
type authFactory struct {
apps map[string]Provider
3 months ago
}
// New 创建权限 Provider 实例
func (f *authFactory) New(scope string) Provider {
if p, exists := f.apps[scope]; exists {
return p
3 months ago
}
p := &vbaseProvider{
scope: scope,
roleDefs: make(map[string]roleDefinition),
}
f.apps[scope] = p
return p
3 months ago
}
func (f *authFactory) init() error {
for appKey, p := range f.apps {
if vp, ok := p.(*vbaseProvider); ok {
if err := vp.init(); err != nil {
return fmt.Errorf("failed to init auth for %s: %w", appKey, err)
}
3 months ago
}
}
return nil
}
// roleDefinition is a role registered through AddRole and kept in memory until
// the provider's init writes it to the database.
type roleDefinition struct {
	code     string
	name     string
	policies []string // each entry has the form "permissionID:level"
}

// vbaseProvider implements Provider on top of the vbase database. Its queries on
// the Permission table are filtered by scope.
type vbaseProvider struct {
	scope    string                    // scope value used to filter Permission rows
	roleDefs map[string]roleDefinition // role code -> definition registered via AddRole
}
// ========== Provider 接口实现 ==========
func (a *vbaseProvider) UserID(x *vigo.X) string {
// 1. 检查是否已解析过(无论成功与否,避免重复解析)
if _, parsed := x.Get(ctxKeyTokenParsed).(bool); parsed {
if uid, ok := x.Get(CtxKeyUserID).(string); ok {
return uid
}
return ""
3 months ago
}
// 2. 惰性解析:从请求中提取 token
tokenStr := extractToken(x)
if tokenStr == "" {
x.Set(ctxKeyTokenParsed, true)
return ""
}
// 3. 解析并验证 token
claims, err := jwt.ParseToken(tokenStr)
if err != nil {
x.Set(ctxKeyTokenParsed, true)
return ""
}
// 确保是 access token
if !jwt.IsAccessToken(claims) {
x.Set(ctxKeyTokenParsed, true)
return ""
}
// 4. 设置到上下文中,供后续调用使用
x.Set(CtxKeyUserID, claims.UserID)
x.Set(ctxKeyTokenParsed, true)
return claims.UserID
3 months ago
}
// Grant 授予权限
func (a *vbaseProvider) Grant(ctx context.Context, userID, permissionID string, level int) error {
if err := validatePermission(permissionID, level); err != nil {
return err
}
// 检查是否存在
var count int64
cfg.DB().Model(&models.Permission{}).
Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
Count(&count)
3 months ago
if count > 0 {
// 更新等级
return cfg.DB().Model(&models.Permission{}).
Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
Update("level", level).Error
3 months ago
}
3 months ago
// 创建
perm := models.Permission{
Scope: a.scope,
UserID: &userID,
PermissionID: permissionID,
Level: level,
3 months ago
}
return cfg.DB().Create(&perm).Error
3 months ago
}
// Revoke 撤销权限
func (a *vbaseProvider) Revoke(ctx context.Context, userID, permissionID string) error {
return cfg.DB().Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
Delete(&models.Permission{}).Error
}
3 months ago
// GrantRole 授予角色
func (a *vbaseProvider) GrantRole(ctx context.Context, userID, roleCode string) error {
3 months ago
var role models.Role
if err := cfg.DB().Where("code = ?", roleCode).First(&role).Error; err != nil {
return err
3 months ago
}
var count int64
cfg.DB().Model(&models.UserRole{}).
Where("user_id = ? AND role_id = ?", userID, role.ID).
Count(&count)
3 months ago
if count > 0 {
return nil // 已经有该角色
3 months ago
}
userRole := models.UserRole{
UserID: userID,
RoleID: role.ID,
3 months ago
}
return cfg.DB().Create(&userRole).Error
}
3 months ago
// RevokeRole 撤销角色
func (a *vbaseProvider) RevokeRole(ctx context.Context, userID, roleCode string) error {
var role models.Role
if err := cfg.DB().Where("code = ?", roleCode).First(&role).Error; err != nil {
return err
}
3 months ago
return cfg.DB().Where("user_id = ? AND role_id = ?", userID, role.ID).
Delete(&models.UserRole{}).Error
}
// Check reports whether userID holds level on permissionID in this scope,
// considering both direct grants and grants made to the user's roles.
//
// An invalid permissionID/level pair (per validatePermission) panics rather than
// returning false. If the user's permissions cannot be loaded, Check returns false.
func (a *vbaseProvider) Check(ctx context.Context, userID, permissionID string, level int) bool {
	if err := validatePermission(permissionID, level); err != nil {
		panic(err)
	}
	perms, err := a.getUserPermissions(userID)
	return err == nil && checkPermissionLevel(perms, permissionID, level)
}
// ListResources 查询用户在特定资源类型下的详细权限信息
func (a *vbaseProvider) ListResources(ctx context.Context, userID, resourceType string) (map[string]int, error) {
perms, err := a.getUserPermissions(userID)
if err != nil {
return nil, err
}
result := make(map[string]int)
prefix := resourceType + ":"
3 months ago
for _, p := range perms {
if p.PermissionID == "*" && p.Level == LevelAdmin {
continue
3 months ago
}
3 months ago
if strings.HasPrefix(p.PermissionID, prefix) {
suffix := p.PermissionID[len(prefix):]
parts := strings.Split(suffix, ":")
if len(parts) > 0 {
instanceID := parts[0]
level := p.Level
// If permission is deeper, assume LevelRead for parent unless explicit
if len(parts) > 1 {
level = LevelRead
}
3 months ago
if currentLevel, ok := result[instanceID]; !ok || level > currentLevel {
result[instanceID] = level
}
}
}
}
return result, nil
}
// ListUsers 查询特定资源的所有协作者及其权限
func (a *vbaseProvider) ListUsers(ctx context.Context, permissionID string) (map[string]int, error) {
parents := getAllParents(permissionID)
3 months ago
var perms []models.Permission
db := cfg.DB().Where("scope = ?", a.scope)
3 months ago
conditions := []string{"permission_id = ?"}
args := []interface{}{permissionID}
if len(parents) > 0 {
conditions = append(conditions, "(permission_id IN ? AND level = ?)")
args = append(args, parents, LevelAdmin)
3 months ago
}
conditions = append(conditions, "(permission_id = ? AND level = ?)")
args = append(args, "*", LevelAdmin)
query := db.Where(strings.Join(conditions, " OR "), args...)
3 months ago
if err := query.Find(&perms).Error; err != nil {
return nil, err
3 months ago
}
result := make(map[string]int)
3 months ago
for _, p := range perms {
if p.UserID != nil {
uid := *p.UserID
if l, ok := result[uid]; !ok || p.Level > l {
result[uid] = p.Level
}
}
if p.RoleID != nil {
var userRoles []models.UserRole
cfg.DB().Where("role_id = ?", *p.RoleID).Find(&userRoles)
for _, ur := range userRoles {
if l, ok := result[ur.UserID]; !ok || p.Level > l {
result[ur.UserID] = p.Level
}
}
}
}
return result, nil
3 months ago
}
// getAllParents returns every proper ancestor of permID, shallowest first:
// "a:b:c" yields ["a", "a:b"]. An ID without ":" yields nil.
func getAllParents(permID string) []string {
	var parents []string
	for off := 0; ; {
		i := strings.IndexByte(permID[off:], ':')
		if i < 0 {
			return parents
		}
		off += i
		parents = append(parents, permID[:off])
		off++ // step past the ':' just found
	}
}
// AddRole 添加角色定义
func (a *vbaseProvider) AddRole(code, name string, policies ...string) error {
a.roleDefs[code] = roleDefinition{
code: code,
name: name,
policies: policies,
3 months ago
}
return nil
}
// init 初始化角色到数据库
func (a *vbaseProvider) init() error {
db := cfg.DB()
for code, def := range a.roleDefs {
// 1. 确保角色存在
var role models.Role
err := db.Where("code = ?", code).First(&role).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
role = models.Role{
Code: code,
Name: def.name,
IsSystem: true,
Status: 1,
}
if err := db.Create(&role).Error; err != nil {
return err
}
} else {
return err
}
3 months ago
}
// 2. 同步角色权限 (Diff Sync)
// ID格式: scoped:permission:level (用于唯一标识和 Diff Sync)
var targetIDs []string
// 获取该角色当前scope下的所有权限ID用于快速比对
var existingIDs []string
if err := db.Model(&models.Permission{}).Where("role_id = ? AND scope = ?", role.ID, a.scope).Pluck("id", &existingIDs).Error; err != nil {
return err
3 months ago
}
existingMap := make(map[string]bool)
for _, id := range existingIDs {
existingMap[id] = true
}
3 months ago
for _, policy := range def.policies {
// policy 格式: "permissionID:level"
parts := strings.Split(policy, ":")
if len(parts) < 2 {
continue
}
levelStr := parts[len(parts)-1]
permID := strings.Join(parts[:len(parts)-1], ":")
var level int
fmt.Sscanf(levelStr, "%d", &level)
// 生成确定性 ID: scoped:permission:level
id := fmt.Sprintf("%s:%s:%d", a.scope, permID, level)
targetIDs = append(targetIDs, id)
// 检查是否存在
if !existingMap[id] {
// 不存在,创建新权限
newPerm := models.Permission{
Scope: a.scope,
RoleID: &role.ID,
PermissionID: permID,
Level: level,
}
newPerm.ID = id
if err := db.Create(&newPerm).Error; err != nil {
return err
}
}
}
// 3. 清理不再需要的权限
if len(targetIDs) > 0 {
if err := db.Unscoped().Where("role_id = ? AND scope = ? AND id NOT IN ?", role.ID, a.scope, targetIDs).
Delete(&models.Permission{}).Error; err != nil {
return err
}
} else {
// 如果没有策略,删除所有
if err := db.Unscoped().Where("role_id = ? AND scope = ?", role.ID, a.scope).
Delete(&models.Permission{}).Error; err != nil {
return err
}
}
3 months ago
}
return nil
}
3 months ago
// ========== 内部辅助方法 ==========
3 months ago
// getUserPermissions 获取用户的所有权限(聚合 Role 和 Direct Permission
func (a *vbaseProvider) getUserPermissions(userID string) ([]models.Permission, error) {
var perms []models.Permission
db := cfg.DB()
3 months ago
// 1. 直接权限
if err := db.Where("user_id = ? AND scope = ?", userID, a.scope).Find(&perms).Error; err != nil {
return nil, err
3 months ago
}
// 2. 角色权限
// 查用户角色
// UserRole 关联的是 RoleID
// Role 表已经没有 Scope所以这里查出用户拥有的所有角色ID
3 months ago
var roleIDs []string
if err := db.Model(&models.UserRole{}).
Where("user_id = ?", userID).
Pluck("role_id", &roleIDs).Error; err != nil {
3 months ago
return nil, err
}
if len(roleIDs) > 0 {
var rolePerms []models.Permission
// 查询这些角色在当前 scope 下拥有的权限
if err := db.Where("role_id IN ? AND scope = ?", roleIDs, a.scope).Find(&rolePerms).Error; err != nil {
3 months ago
return nil, err
}
perms = append(perms, rolePerms...)
}
3 months ago
return perms, nil
}
// checkPermissionLevel 核心鉴权逻辑
func checkPermissionLevel(perms []models.Permission, targetPermID string, requiredLevel int) bool {
for _, p := range perms {
// 1. 管理员特权 (Level 7 且是父级或同级)
if p.Level == LevelAdmin {
// 如果拥有的权限是 target 的前缀,或者是 *
if p.PermissionID == "*" || strings.HasPrefix(targetPermID, p.PermissionID) {
return true
}
3 months ago
}
// 2. 普通权限匹配
if p.Level >= requiredLevel {
if p.PermissionID == targetPermID {
return true
}
}
3 months ago
}
return false
3 months ago
}
// parsePermissionID 解析动态权限ID
func parsePermissionID(x *vigo.X, code string) (string, error) {
// 简单实现,支持 {key}, {key@query}, {key@header}
start := strings.Index(code, "{")
if start == -1 {
return code, nil
3 months ago
}
end := strings.Index(code, "}")
if end == -1 || end < start {
return "", fmt.Errorf("invalid permission format")
3 months ago
}
raw := code[start+1 : end]
parts := strings.Split(raw, "@")
key := parts[0]
source := "path"
if len(parts) > 1 {
source = parts[1]
3 months ago
}
var val string
switch source {
case "query":
val = x.Request.URL.Query().Get(key)
case "header":
val = x.Request.Header.Get(key)
case "ctx":
if v, ok := x.Get(key).(string); ok {
val = v
}
default: // path
val = x.PathParams.Get(key)
3 months ago
}
if val == "" {
return "", fmt.Errorf("param %s not found in %s", key, source)
}
return code[:start] + val + code[end+1:], nil
}
// extractToken returns the bearer token from the Authorization header. If the
// header does not carry a non-empty bearer token, it falls back to the
// "access_token" query parameter ("" if absent). The "Bearer" scheme name is
// matched case-insensitively, as HTTP authentication schemes are
// (RFC 7235 §2.1), and surrounding whitespace is trimmed from the token.
func extractToken(x *vigo.X) string {
	const scheme = "Bearer "
	if h := x.Request.Header.Get("Authorization"); len(h) > len(scheme) && strings.EqualFold(h[:len(scheme)], scheme) {
		if tok := strings.TrimSpace(h[len(scheme):]); tok != "" {
			return tok
		}
	}
	return x.Request.URL.Query().Get("access_token")
}
// validatePermission checks that level is used at a compatible depth of code,
// where depth is the number of ":"-separated segments:
//   - "*" is only valid with LevelAdmin;
//   - LevelCreate needs an odd depth (a resource type, e.g. "app");
//   - any other level needs an even depth (a resource instance, e.g. "app:1").
//
// It does not check that level is one of the defined Level constants.
func validatePermission(code string, level int) error {
	if code == "*" {
		if level == LevelAdmin {
			return nil
		}
		return fmt.Errorf("wildcard * requires LevelAdmin")
	}

	depth := strings.Count(code, ":") + 1
	oddDepth := depth%2 != 0
	switch {
	case level == LevelCreate && !oddDepth:
		return fmt.Errorf("LevelCreate requires odd depth (resource type), got %d for %s", depth, code)
	case level != LevelCreate && oddDepth:
		return fmt.Errorf("Level %d requires even depth (resource instance), got %d for %s", level, depth, code)
	}
	return nil
}