You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/auth/auth.go

537 lines
13 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-02-14 16:08:06
// Distributed under terms of the MIT license.
//
package auth
import (
"context"
"errors"
"fmt"
"strings"
"github.com/veypi/vbase/cfg"
"github.com/veypi/vbase/libs/jwt"
"github.com/veypi/vbase/models"
"github.com/veypi/vigo"
pub "github.com/veypi/vigo/contrib/auth"
"github.com/veypi/vigo/contrib/event"
"gorm.io/gorm"
)
const (
	// CtxKeyUserID is the vigo.X key under which UserID stores the ID of the
	// authenticated user once the request token has been parsed successfully.
	CtxKeyUserID = "auth:user_id"

	// Permission levels. The binary pattern of each value is shown on its line.
	LevelNone      = 0
	LevelCreate    = 1 << 0                               // 001: create, only valid on odd-depth IDs (resource types)
	LevelRead      = 1 << 1                               // 010: read, only valid on even-depth IDs (resource instances)
	LevelWrite     = 1 << 2                               // 100: write, only valid on even-depth IDs
	LevelReadWrite = LevelRead | LevelWrite               // 110: read + write, even-depth IDs
	LevelAdmin     = LevelCreate | LevelRead | LevelWrite // 111: full control
)
// ctxKeyTokenParsed is the vigo.X key UserID sets once it has attempted to
// parse the request token, whether or not that succeeded. Later calls in the
// same request then skip parsing.
const ctxKeyTokenParsed = "_token_parsed"

type (
	// PermFunc is the permission-check function type, aliased from the shared
	// vigo auth package.
	PermFunc = pub.PermFunc
	// Provider aliases the shared auth.Provider interface for implementers.
	Provider = pub.Provider
)

// Factory is the package-wide authFactory that hands out one Provider per scope.
var Factory = &authFactory{apps: map[string]Provider{}}

// Compile-time assertion that *vbaseProvider satisfies pub.Provider.
var _ pub.Provider = (*vbaseProvider)(nil)

// init registers Factory.init as the callback for the "vb.init.auth" event.
// That callback writes the roles declared on every provider created through
// Factory.New to the database.
func init() {
	event.Add("vb.init.auth", Factory.init)
}
// authFactory caches one Provider per scope and initialises them together.
// The map is accessed without locking.
type authFactory struct {
	apps map[string]Provider // scope -> provider, filled by New
}

// New returns the Provider registered for scope. On the first call for a scope
// it creates a vbaseProvider and caches it, so later calls for the same scope
// share the same instance and role definitions.
func (f *authFactory) New(scope string) Provider {
	provider, found := f.apps[scope]
	if !found {
		provider = &vbaseProvider{
			scope:    scope,
			roleDefs: map[string]roleDefinition{},
		}
		f.apps[scope] = provider
	}
	return provider
}

// init runs the database initialisation of every cached vbaseProvider and
// stops at the first failure, wrapping it with the offending scope.
// Providers of other concrete types are skipped.
func (f *authFactory) init() error {
	for scope, provider := range f.apps {
		vp, isVbase := provider.(*vbaseProvider)
		if !isVbase {
			continue
		}
		if err := vp.init(); err != nil {
			return fmt.Errorf("failed to init auth for %s: %w", scope, err)
		}
	}
	return nil
}
// roleDefinition is a role declared through AddRole. init writes it to the
// database.
type roleDefinition struct {
	code     string
	name     string
	policies []string // each entry has the form "permissionID:level"
}

// vbaseProvider implements the Provider interface for a single scope. Every
// permission row it reads or writes is filtered by scope.
type vbaseProvider struct {
	scope    string
	roleDefs map[string]roleDefinition // role code -> definition, filled by AddRole
}
// ========== Provider 接口实现 ==========
// UserID returns the ID of the user authenticated by the request's access
// token. It returns "" in each of these cases:
//   - the request carries no token;
//   - jwt.ParseToken rejects the token;
//   - the token is not an access token.
//
// The outcome is memoised on x. ctxKeyTokenParsed records that parsing was
// attempted (successfully or not), and CtxKeyUserID holds the ID on success.
// Later calls within the same request therefore never re-parse the token.
func (a *vbaseProvider) UserID(x *vigo.X) string {
	if _, attempted := x.Get(ctxKeyTokenParsed).(bool); attempted {
		cached, _ := x.Get(CtxKeyUserID).(string)
		return cached
	}

	uid := ""
	if token := extractToken(x); token != "" {
		if claims, err := jwt.ParseToken(token); err == nil && jwt.IsAccessToken(claims) {
			uid = claims.UserID
			x.Set(CtxKeyUserID, uid)
		}
	}
	x.Set(ctxKeyTokenParsed, true)
	return uid
}
// Grant gives userID a direct permission on permissionID at the given level
// within this provider's scope.
//
// It returns the validatePermission error when the (permissionID, level) pair
// is invalid. When a matching direct grant already exists, its level is
// overwritten (it may be raised or lowered). Otherwise a new row is inserted.
// Database errors are returned unchanged. That includes a failure of the
// existence check, which previously was ignored and fell through to an insert.
//
// NOTE(review): the check-then-insert is not atomic; concurrent Grants for the
// same triple can still insert duplicates unless the table has a unique index.
func (a *vbaseProvider) Grant(ctx context.Context, userID, permissionID string, level int) error {
	if err := validatePermission(permissionID, level); err != nil {
		return err
	}

	const match = "user_id = ? AND permission_id = ? AND scope = ?"

	var count int64
	if err := cfg.DB().Model(&models.Permission{}).
		Where(match, userID, permissionID, a.scope).
		Count(&count).Error; err != nil {
		return err
	}
	if count > 0 {
		return cfg.DB().Model(&models.Permission{}).
			Where(match, userID, permissionID, a.scope).
			Update("level", level).Error
	}

	perm := models.Permission{
		Scope:        a.scope,
		UserID:       &userID,
		PermissionID: permissionID,
		Level:        level,
	}
	return cfg.DB().Create(&perm).Error
}
// Revoke deletes every direct grant of permissionID held by userID in this
// provider's scope. Permissions the user receives through roles are not
// touched. ctx is currently unused.
func (a *vbaseProvider) Revoke(ctx context.Context, userID, permissionID string) error {
	target := cfg.DB().Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope)
	return target.Delete(&models.Permission{}).Error
}
// GrantRole assigns the role identified by roleCode to userID. Role
// membership is not scoped.
//
// It returns the lookup error (e.g. gorm.ErrRecordNotFound) when no role has
// that code. It returns nil without writing when the user already has the
// role. Otherwise it inserts a models.UserRole row. An error from the
// membership check is returned instead of being ignored. Previously such an
// error led to an attempted duplicate insert.
func (a *vbaseProvider) GrantRole(ctx context.Context, userID, roleCode string) error {
	var role models.Role
	if err := cfg.DB().Where("code = ?", roleCode).First(&role).Error; err != nil {
		return err
	}

	var count int64
	if err := cfg.DB().Model(&models.UserRole{}).
		Where("user_id = ? AND role_id = ?", userID, role.ID).
		Count(&count).Error; err != nil {
		return err
	}
	if count > 0 {
		return nil // already a member
	}

	userRole := models.UserRole{
		UserID: userID,
		RoleID: role.ID,
	}
	return cfg.DB().Create(&userRole).Error
}
// RevokeRole removes userID's membership in the role identified by roleCode.
// It returns the role lookup error (e.g. gorm.ErrRecordNotFound) when no role
// has that code. Otherwise it returns the result of deleting the matching
// models.UserRole rows.
func (a *vbaseProvider) RevokeRole(ctx context.Context, userID, roleCode string) error {
	var role models.Role
	err := cfg.DB().Where("code = ?", roleCode).First(&role).Error
	if err == nil {
		err = cfg.DB().
			Where("user_id = ? AND role_id = ?", userID, role.ID).
			Delete(&models.UserRole{}).Error
	}
	return err
}
// Check reports whether userID holds level on permissionID in this scope. The
// permission may come from a direct grant or from one of the user's roles;
// the matching rules are those of checkPermissionLevel.
//
// Check panics when (permissionID, level) fails validatePermission. It
// returns false when the user's permissions cannot be loaded.
//
// NOTE(review): if permissionID is built from request values (e.g. via
// parsePermissionID), a value containing ':' changes its depth and would
// trigger this panic. Consider failing closed (false) instead.
func (a *vbaseProvider) Check(ctx context.Context, userID, permissionID string, level int) bool {
	if err := validatePermission(permissionID, level); err != nil {
		panic(err)
	}
	perms, err := a.getUserPermissions(userID)
	return err == nil && checkPermissionLevel(perms, permissionID, level)
}
// ListResources returns, for each instance of resourceType on which the user
// holds some permission in this scope, the level the user has there. The
// result is keyed by instance ID.
//
// Only permission IDs of the form "<resourceType>:<instanceID>[:...]" count.
// The global "*" grant never carries that prefix and is therefore ignored. A
// permission on the instance itself contributes its own level. A permission
// on something nested below the instance contributes LevelRead. When several
// permissions hit the same instance, the highest level wins.
func (a *vbaseProvider) ListResources(ctx context.Context, userID, resourceType string) (map[string]int, error) {
	perms, err := a.getUserPermissions(userID)
	if err != nil {
		return nil, err
	}

	prefix := resourceType + ":"
	levels := make(map[string]int)
	for _, p := range perms {
		if !strings.HasPrefix(p.PermissionID, prefix) {
			continue
		}
		rest := p.PermissionID[len(prefix):]
		instanceID, granted := rest, p.Level
		if cut := strings.IndexByte(rest, ':'); cut >= 0 {
			// The grant targets something below the instance: expose read access only.
			instanceID, granted = rest[:cut], LevelRead
		}
		if prev, seen := levels[instanceID]; seen && prev >= granted {
			continue
		}
		levels[instanceID] = granted
	}
	return levels, nil
}
// ListUsers returns every user who can act on permissionID in this scope. Each
// user maps to the highest level that reaches them.
//
// A permission row counts when it is one of:
//   - a grant on permissionID itself, at any level;
//   - a LevelAdmin grant on one of its ancestors (see getAllParents);
//   - a LevelAdmin grant on "*".
//
// Rows attached to a role are expanded to the role's members through
// models.UserRole. Members of all matching roles are fetched with a single
// batched query. Previously there was one query per role row, and its error
// was silently ignored.
func (a *vbaseProvider) ListUsers(ctx context.Context, permissionID string) (map[string]int, error) {
	db := cfg.DB()

	conditions := []string{"permission_id = ?"}
	args := []interface{}{permissionID}
	if parents := getAllParents(permissionID); len(parents) > 0 {
		conditions = append(conditions, "(permission_id IN ? AND level = ?)")
		args = append(args, parents, LevelAdmin)
	}
	conditions = append(conditions, "(permission_id = ? AND level = ?)")
	args = append(args, "*", LevelAdmin)

	var perms []models.Permission
	// The OR group is parenthesised explicitly so it can never bind looser
	// than the scope filter.
	if err := db.Where("scope = ?", a.scope).
		Where("("+strings.Join(conditions, " OR ")+")", args...).
		Find(&perms).Error; err != nil {
		return nil, err
	}

	result := make(map[string]int)
	raise := func(uid string, level int) {
		if cur, ok := result[uid]; !ok || level > cur {
			result[uid] = level
		}
	}

	// Highest level granted to each role, so members need only one lookup.
	roleLevels := make(map[string]int)
	for _, p := range perms {
		if p.UserID != nil {
			raise(*p.UserID, p.Level)
		}
		if p.RoleID != nil {
			if cur, ok := roleLevels[*p.RoleID]; !ok || p.Level > cur {
				roleLevels[*p.RoleID] = p.Level
			}
		}
	}
	if len(roleLevels) == 0 {
		return result, nil
	}

	roleIDs := make([]string, 0, len(roleLevels))
	for id := range roleLevels {
		roleIDs = append(roleIDs, id)
	}
	var userRoles []models.UserRole
	if err := db.Where("role_id IN ?", roleIDs).Find(&userRoles).Error; err != nil {
		return nil, err
	}
	for _, ur := range userRoles {
		raise(ur.UserID, roleLevels[ur.RoleID])
	}
	return result, nil
}
// getAllParents returns every proper ancestor of a ':'-separated permission
// ID, shortest first. For example, "a:b:c" yields ["a", "a:b"]. The result is
// nil when permID contains no ':'.
func getAllParents(permID string) []string {
	var ancestors []string
	for i := 0; i < len(permID); i++ {
		// ':' is a single byte and never occurs inside a multi-byte UTF-8 sequence.
		if permID[i] == ':' {
			ancestors = append(ancestors, permID[:i])
		}
	}
	return ancestors
}
// AddRole declares a role for this provider. Each policy has the form
// "permissionID:level". The definition is only kept in memory here; init
// writes it to the database. A later call with the same code replaces the
// earlier definition. AddRole always returns nil.
func (a *vbaseProvider) AddRole(code, name string, policies ...string) error {
	def := roleDefinition{code: code, name: name, policies: policies}
	a.roleDefs[code] = def
	return nil
}
// init writes every role registered through AddRole to the database. It also
// synchronises each role's permission rows in this provider's scope with the
// role's declared policies.
//
// For every role code it:
//  1. looks up the models.Role row by code, creating it (IsSystem, Status 1)
//     when gorm reports ErrRecordNotFound;
//  2. parses each policy "permissionID:level", splitting at the last ':';
//  3. hard-deletes (Unscoped) the role's rows in this scope that are no longer
//     declared, then creates the missing ones.
//
// Permission row IDs are deterministic:
// "<scope>:<role code>:<permissionID>:<level>". The role code is part of the
// ID so that two roles declaring the same policy in the same scope no longer
// collide on the primary key. Rows stored under the older
// "<scope>:<permissionID>:<level>" IDs are treated as stale and replaced.
//
// Initialisation aborts on the first database error. It also aborts on a
// policy whose level is not an integer in [LevelNone, LevelAdmin]; such a
// policy used to be stored silently as level 0. Policies without any ':' are
// skipped.
func (a *vbaseProvider) init() error {
	for code, def := range a.roleDefs {
		role, err := ensureSystemRole(code, def.name)
		if err != nil {
			return err
		}
		if err := a.syncRolePolicies(role, code, def.policies); err != nil {
			return err
		}
	}
	return nil
}

// ensureSystemRole returns the role identified by code. If no such row exists,
// it first creates one named name with IsSystem set and Status 1.
func ensureSystemRole(code, name string) (*models.Role, error) {
	var role models.Role
	err := cfg.DB().Where("code = ?", code).First(&role).Error
	if err == nil {
		return &role, nil
	}
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, err
	}
	role = models.Role{
		Code:     code,
		Name:     name,
		IsSystem: true,
		Status:   1,
	}
	if err := cfg.DB().Create(&role).Error; err != nil {
		return nil, err
	}
	return &role, nil
}

// syncRolePolicies makes role's permission rows in a.scope match policies exactly.
func (a *vbaseProvider) syncRolePolicies(role *models.Role, code string, policies []string) error {
	db := cfg.DB()

	// Desired rows keyed by deterministic ID. ids keeps declaration order and
	// drops repeated policies, which would otherwise be inserted twice and
	// fail on the primary key.
	wanted := make(map[string]models.Permission, len(policies))
	ids := make([]string, 0, len(policies))
	for _, policy := range policies {
		sep := strings.LastIndex(policy, ":")
		if sep < 0 {
			continue // not of the form "permissionID:level"
		}
		permID := policy[:sep]
		var level int
		if _, err := fmt.Sscanf(policy[sep+1:], "%d", &level); err != nil {
			return fmt.Errorf("role %q: invalid level in policy %q: %w", code, policy, err)
		}
		if level < LevelNone || level > LevelAdmin {
			return fmt.Errorf("role %q: level %d out of range in policy %q", code, level, policy)
		}

		id := fmt.Sprintf("%s:%s:%s:%d", a.scope, code, permID, level)
		if _, dup := wanted[id]; dup {
			continue
		}
		perm := models.Permission{
			Scope:        a.scope,
			RoleID:       &role.ID,
			PermissionID: permID,
			Level:        level,
		}
		perm.ID = id
		wanted[id] = perm
		ids = append(ids, id)
	}

	// Remove undeclared rows first so re-keyed rows never coexist with their
	// replacements. With no policies at all, every row of the role is removed.
	stale := db.Unscoped().Where("role_id = ? AND scope = ?", role.ID, a.scope)
	if len(ids) > 0 {
		stale = stale.Where("id NOT IN ?", ids)
	}
	if err := stale.Delete(&models.Permission{}).Error; err != nil {
		return err
	}

	var existing []string
	if err := db.Model(&models.Permission{}).
		Where("role_id = ? AND scope = ?", role.ID, a.scope).
		Pluck("id", &existing).Error; err != nil {
		return err
	}
	present := make(map[string]bool, len(existing))
	for _, id := range existing {
		present[id] = true
	}

	for _, id := range ids {
		if present[id] {
			continue
		}
		perm := wanted[id]
		if err := db.Create(&perm).Error; err != nil {
			return err
		}
	}
	return nil
}
// ========== 内部辅助方法 ==========
// getUserPermissions collects every permission row in a.scope that applies to
// userID. Direct grants come first, followed by grants attached to any of the
// user's roles. Role membership (models.UserRole) carries no scope, so it is
// looked up globally; only the resulting permission rows are filtered by
// scope.
func (a *vbaseProvider) getUserPermissions(userID string) ([]models.Permission, error) {
	db := cfg.DB()

	var direct []models.Permission
	if err := db.Where("user_id = ? AND scope = ?", userID, a.scope).Find(&direct).Error; err != nil {
		return nil, err
	}

	var roleIDs []string
	err := db.Model(&models.UserRole{}).
		Where("user_id = ?", userID).
		Pluck("role_id", &roleIDs).Error
	if err != nil {
		return nil, err
	}
	if len(roleIDs) == 0 {
		return direct, nil
	}

	var inherited []models.Permission
	if err := db.Where("role_id IN ? AND scope = ?", roleIDs, a.scope).Find(&inherited).Error; err != nil {
		return nil, err
	}
	return append(direct, inherited...), nil
}
// checkPermissionLevel reports whether any grant in perms satisfies
// requiredLevel on targetPermID.
//
// A grant matches in either of two cases:
//   - It is LevelAdmin on "*", on targetPermID itself, or on an ancestor of
//     targetPermID. An ancestor must end at a ':' segment boundary, so admin
//     on "app:1" covers "app:1:doc:2" but NOT "app:10". The old plain string
//     prefix test wrongly granted the latter.
//   - It is exactly on targetPermID with Level >= requiredLevel.
func checkPermissionLevel(perms []models.Permission, targetPermID string, requiredLevel int) bool {
	for _, p := range perms {
		if p.Level == LevelAdmin && (p.PermissionID == "*" || permissionCovers(p.PermissionID, targetPermID)) {
			return true
		}
		if p.PermissionID == targetPermID && p.Level >= requiredLevel {
			return true
		}
	}
	return false
}

// permissionCovers reports whether permission ID parent equals child or is an
// ancestor of child at a ':' segment boundary.
func permissionCovers(parent, child string) bool {
	if !strings.HasPrefix(child, parent) {
		return false
	}
	return len(child) == len(parent) || child[len(parent)] == ':'
}
// parsePermissionID expands dynamic placeholders in a permission template.
//
// Every "{key}" or "{key@source}" segment in code is replaced by a value read
// from the current request. source is one of "path" (the default), "query",
// "header" or "ctx" (see placeholderValue). All placeholders are expanded;
// previously only the first one was, and later ones were left as literal text.
// Substituted values are never re-scanned for placeholders. A code without
// placeholders is returned unchanged.
//
// An error is returned when a '}' is missing or precedes the next '{', or when
// a placeholder resolves to "".
//
// NOTE(review): substituted values are not checked for ':'. A request value
// containing ':' changes the depth of the resulting permission ID. Consider
// rejecting such values.
func parsePermissionID(x *vigo.X, code string) (string, error) {
	if !strings.Contains(code, "{") {
		return code, nil
	}

	var out strings.Builder
	rest := code
	for {
		start := strings.Index(rest, "{")
		if start == -1 {
			out.WriteString(rest)
			return out.String(), nil
		}
		end := strings.Index(rest, "}")
		if end == -1 || end < start {
			return "", fmt.Errorf("invalid permission format")
		}

		parts := strings.Split(rest[start+1:end], "@")
		key, source := parts[0], "path"
		if len(parts) > 1 {
			source = parts[1]
		}
		val := placeholderValue(x, key, source)
		if val == "" {
			return "", fmt.Errorf("param %s not found in %s", key, source)
		}

		out.WriteString(rest[:start])
		out.WriteString(val)
		rest = rest[end+1:]
	}
}

// placeholderValue reads key from the part of the request named by source.
// Any source other than "query", "header" or "ctx" reads the path parameters.
func placeholderValue(x *vigo.X, key, source string) string {
	switch source {
	case "query":
		return x.Request.URL.Query().Get(key)
	case "header":
		return x.Request.Header.Get(key)
	case "ctx":
		v, _ := x.Get(key).(string)
		return v
	default: // path
		return x.PathParams.Get(key)
	}
}
// extractToken returns the request's bearer token. It first tries the value
// after "Bearer " in the Authorization header, as long as that value is
// non-empty. Otherwise it falls back to the "access_token" query parameter,
// which may itself be "".
func extractToken(x *vigo.X) string {
	const bearer = "Bearer "
	header := x.Request.Header.Get("Authorization")
	if len(header) > len(bearer) && strings.HasPrefix(header, bearer) {
		return header[len(bearer):]
	}
	return x.Request.URL.Query().Get("access_token")
}
// validatePermission checks that level is allowed at the depth of code, where
// depth is the number of ':'-separated segments:
//   - "*" is only valid with LevelAdmin;
//   - LevelCreate is only valid at odd depths (resource types);
//   - every other level is only valid at even depths (resource instances).
//
// NOTE(review): the numeric range of level itself is not checked here.
func validatePermission(code string, level int) error {
	if code == "*" {
		if level == LevelAdmin {
			return nil
		}
		return fmt.Errorf("wildcard * requires LevelAdmin")
	}

	depth := strings.Count(code, ":") + 1
	oddDepth := depth%2 == 1
	switch {
	case level == LevelCreate && !oddDepth:
		return fmt.Errorf("LevelCreate requires odd depth (resource type), got %d for %s", depth, code)
	case level != LevelCreate && oddDepth:
		return fmt.Errorf("Level %d requires even depth (resource instance), got %d for %s", level, depth, code)
	}
	return nil
}