You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/auth/auth.go

782 lines
20 KiB
Go

4 months ago
//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-02-14 16:08:06
// Distributed under terms of the MIT license.
//
package auth
import (
"context"
"errors"
4 months ago
"fmt"
"strconv"
4 months ago
"strings"
"time"
4 months ago
"github.com/veypi/vbase/cfg"
"github.com/veypi/vbase/libs/jwt"
4 months ago
"github.com/veypi/vbase/models"
"github.com/veypi/vigo"
pub "github.com/veypi/vigo/contrib/auth"
"github.com/veypi/vigo/contrib/event"
"github.com/veypi/vigo/logv"
"gorm.io/gorm"
4 months ago
)
const (
	// CtxKeyUserID is the request-context key under which the user ID is stored.
	CtxKeyUserID = "auth:user_id"
	// CtxKeySessionID is the request-context key under which the session ID is stored.
	CtxKeySessionID = "auth:session_id"
	// Permission levels.
	LevelNone = 0
	LevelCreate = 1 // 001 create (checked at odd depths)
	LevelRead = 2 // 010 read (checked at even depths)
	LevelWrite = 4 // 100 write (checked at even depths)
	LevelReadWrite = 6 // 110 read+write (checked at even depths)
	LevelAdmin = 7 // 111 admin (full control)
)
// ctxKeyTokenParsed is the request-context key that marks the token as already
// parsed for this request, so parsing is not repeated.
const ctxKeyTokenParsed = "_token_parsed"
// PermFunc is the permission-check function type (alias of pub.PermFunc).
type PermFunc = pub.PermFunc
4 months ago
// Provider is an alias of pub.Provider, used on the implementing side.
type Provider = pub.Provider
// Factory 全局 Auth 工厂
4 months ago
var Factory = &authFactory{
apps: make(map[string]Provider),
4 months ago
}
// Compile-time check that *vbaseProvider satisfies pub.Provider.
var _ pub.Provider = (*vbaseProvider)(nil)
// init registers Factory.init with event.Add under the name "vb.init.auth".
func init() {
	// Register the permission-initialisation callback.
	event.Add("vb.init.auth", Factory.init)
}
4 months ago
type authFactory struct {
apps map[string]Provider
4 months ago
}
// New 创建权限 Provider 实例
func (f *authFactory) New(scope string) Provider {
if p, exists := f.apps[scope]; exists {
return p
4 months ago
}
p := &vbaseProvider{
scope: scope,
roleDefs: make(map[string]roleDefinition),
}
f.apps[scope] = p
return p
4 months ago
}
func (f *authFactory) init() error {
for appKey, p := range f.apps {
if vp, ok := p.(*vbaseProvider); ok {
if err := vp.init(); err != nil {
return fmt.Errorf("failed to init auth for %s: %w", appKey, err)
}
4 months ago
}
}
return nil
}
// roleDefinition describes a role registered through AddRole.
type roleDefinition struct {
	code string
	name string
	policies []string // format: "permissionID:level"
}
// vbaseProvider implements the Provider interface for a single scope.
type vbaseProvider struct {
	scope string
	roleDefs map[string]roleDefinition
}
// ========== Provider 接口实现 ==========
func (a *vbaseProvider) UserID(x *vigo.X) string {
// 1. 检查是否已解析过(无论成功与否,避免重复解析)
if _, parsed := x.Get(ctxKeyTokenParsed).(bool); parsed {
if uid, ok := x.Get(CtxKeyUserID).(string); ok {
return uid
}
return ""
4 months ago
}
// 2. 惰性解析:从请求中提取 token
tokenStr := extractToken(x)
if tokenStr == "" {
x.Set(ctxKeyTokenParsed, true)
return ""
}
// 3. 解析并验证 token
claims, err := jwt.ParseToken(tokenStr)
if err != nil {
x.Set(ctxKeyTokenParsed, true)
return ""
}
// 确保是 access token
if !jwt.IsAccessToken(claims) {
x.Set(ctxKeyTokenParsed, true)
return ""
}
// 验证 access token 对应的 session允许当前版本或上一版本防止多 tab 并发刷新互踢)
if !ValidateAccessSession(claims.SessionID, claims.UserID, claims.Version) {
x.Set(ctxKeyTokenParsed, true)
return ""
}
// 5. 设置到上下文中,供后续调用使用
x.Set(CtxKeyUserID, claims.UserID)
x.Set(CtxKeySessionID, claims.SessionID)
x.Set(ctxKeyTokenParsed, true)
return claims.UserID
4 months ago
}
// Grant 授予权限
func (a *vbaseProvider) Grant(ctx context.Context, userID, permissionID string, level int) error {
if err := validatePermission(permissionID, level); err != nil {
return err
}
// 检查是否存在
var count int64
cfg.DB().Model(&models.Permission{}).
Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
Count(&count)
4 months ago
if count > 0 {
// 更新等级
return cfg.DB().Model(&models.Permission{}).
Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
Update("level", level).Error
4 months ago
}
4 months ago
// 创建
perm := models.Permission{
Scope: a.scope,
UserID: &userID,
PermissionID: permissionID,
Level: level,
4 months ago
}
return cfg.DB().Create(&perm).Error
4 months ago
}
// Revoke deletes the direct permission permissionID of userID in this
// provider's scope and returns the database error, if any.
func (a *vbaseProvider) Revoke(ctx context.Context, userID, permissionID string) error {
	query := cfg.DB().Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope)
	result := query.Delete(&models.Permission{})
	return result.Error
}
4 months ago
// GrantRole 授予角色
func (a *vbaseProvider) GrantRole(ctx context.Context, userID, roleCode string) error {
4 months ago
var role models.Role
if err := cfg.DB().Where("code = ?", roleCode).First(&role).Error; err != nil {
return err
4 months ago
}
var count int64
cfg.DB().Model(&models.UserRole{}).
Where("user_id = ? AND role_id = ?", userID, role.ID).
Count(&count)
4 months ago
if count > 0 {
return nil // 已经有该角色
4 months ago
}
userRole := models.UserRole{
UserID: userID,
RoleID: role.ID,
4 months ago
}
return cfg.DB().Create(&userRole).Error
}
4 months ago
// RevokeRole 撤销角色
func (a *vbaseProvider) RevokeRole(ctx context.Context, userID, roleCode string) error {
var role models.Role
if err := cfg.DB().Where("code = ?", roleCode).First(&role).Error; err != nil {
return err
}
4 months ago
return cfg.DB().Where("user_id = ? AND role_id = ?", userID, role.ID).
Delete(&models.UserRole{}).Error
}
// Check reports whether userID holds permissionID at (at least) level in
// this provider's scope, taking both direct and role-derived permissions
// into account.
//
// The check fails closed: an invalid permissionID/level combination or a
// failure to load the user's permissions yields false. Invalid combinations
// are logged rather than panicking, so a malformed ID cannot crash the
// request from inside an authorization check.
func (a *vbaseProvider) Check(ctx context.Context, userID, permissionID string, level int) bool {
	if err := validatePermission(permissionID, level); err != nil {
		logv.Warn().Msgf("Check: invalid permission %q at level %d: %v", permissionID, level, err)
		return false
	}

	// 1. Load every permission of the user in this scope (direct + via roles).
	perms, err := a.getUserPermissions(userID)
	if err != nil {
		return false
	}

	// 2. Evaluate.
	return checkPermissionLevel(perms, permissionID, level)
}
// ListResources 查询用户在特定资源类型下的详细权限信息
func (a *vbaseProvider) ListResources(ctx context.Context, userID, resourceType string) (map[string]int, error) {
perms, err := a.getUserPermissions(userID)
if err != nil {
return nil, err
}
result := make(map[string]int)
prefix := resourceType + ":"
4 months ago
for _, p := range perms {
if p.PermissionID == "*" && p.Level == LevelAdmin {
continue
4 months ago
}
4 months ago
if strings.HasPrefix(p.PermissionID, prefix) {
suffix := p.PermissionID[len(prefix):]
parts := strings.Split(suffix, ":")
if len(parts) > 0 {
instanceID := parts[0]
level := p.Level
// If permission is deeper, assume LevelRead for parent unless explicit
if len(parts) > 1 {
level = LevelRead
}
4 months ago
if currentLevel, ok := result[instanceID]; !ok || level > currentLevel {
result[instanceID] = level
}
}
}
}
return result, nil
}
// ListUsers 查询特定资源的所有协作者及其权限
func (a *vbaseProvider) ListUsers(ctx context.Context, permissionID string) (map[string]int, error) {
parents := getAllParents(permissionID)
4 months ago
var perms []models.Permission
db := cfg.DB().Where("scope = ?", a.scope)
4 months ago
conditions := []string{"permission_id = ?"}
args := []interface{}{permissionID}
if len(parents) > 0 {
conditions = append(conditions, "(permission_id IN ? AND level = ?)")
args = append(args, parents, LevelAdmin)
4 months ago
}
conditions = append(conditions, "(permission_id = ? AND level = ?)")
args = append(args, "*", LevelAdmin)
query := db.Where(strings.Join(conditions, " OR "), args...)
4 months ago
if err := query.Find(&perms).Error; err != nil {
return nil, err
4 months ago
}
result := make(map[string]int)
4 months ago
for _, p := range perms {
if p.UserID != nil {
uid := *p.UserID
if l, ok := result[uid]; !ok || p.Level > l {
result[uid] = p.Level
}
}
if p.RoleID != nil {
var userRoles []models.UserRole
cfg.DB().Where("role_id = ?", *p.RoleID).Find(&userRoles)
for _, ur := range userRoles {
if l, ok := result[ur.UserID]; !ok || p.Level > l {
result[ur.UserID] = p.Level
}
}
}
}
return result, nil
4 months ago
}
// getAllParents returns every proper ":"-separated prefix of permID, from
// shortest to longest. For example "a:b:c" yields ["a", "a:b"]. An ID
// without ":" has no parents and yields nil.
func getAllParents(permID string) []string {
	var parents []string
	for offset := 0; ; {
		idx := strings.IndexByte(permID[offset:], ':')
		if idx < 0 {
			return parents
		}
		offset += idx
		// Everything before this separator is one parent.
		parents = append(parents, permID[:offset])
		offset++ // step over the ':'
	}
}
4 months ago
// AddRole 添加角色定义
func (a *vbaseProvider) AddRole(code, name string, policies ...string) error {
a.roleDefs[code] = roleDefinition{
code: code,
name: name,
policies: policies,
4 months ago
}
return nil
}
// init 初始化角色到数据库
func (a *vbaseProvider) init() error {
db := cfg.DB()
for code, def := range a.roleDefs {
// 1. 确保角色存在
var role models.Role
err := db.Where("code = ?", code).First(&role).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
role = models.Role{
Code: code,
Name: def.name,
IsSystem: true,
Status: 1,
}
if err := db.Create(&role).Error; err != nil {
return err
}
} else {
return err
}
4 months ago
}
// 2. 同步角色权限 (Diff Sync)
// ID格式: scoped:permission:level (用于唯一标识和 Diff Sync)
var targetIDs []string
// 获取该角色当前scope下的所有权限ID用于快速比对
var existingIDs []string
if err := db.Model(&models.Permission{}).Where("role_id = ? AND scope = ?", role.ID, a.scope).Pluck("id", &existingIDs).Error; err != nil {
return err
4 months ago
}
existingMap := make(map[string]bool)
for _, id := range existingIDs {
existingMap[id] = true
}
4 months ago
for _, policy := range def.policies {
// policy 格式: "permissionID:level"
parts := strings.Split(policy, ":")
if len(parts) < 2 {
continue
}
levelStr := parts[len(parts)-1]
permID := strings.Join(parts[:len(parts)-1], ":")
var level int
fmt.Sscanf(levelStr, "%d", &level)
// 生成确定性 ID: scoped:permission:level
id := fmt.Sprintf("%s:%s:%d", a.scope, permID, level)
targetIDs = append(targetIDs, id)
// 检查是否存在
if !existingMap[id] {
// 不存在,创建新权限
newPerm := models.Permission{
Scope: a.scope,
RoleID: &role.ID,
PermissionID: permID,
Level: level,
}
newPerm.ID = id
if err := db.Create(&newPerm).Error; err != nil {
return err
}
}
}
// 3. 清理不再需要的权限
if len(targetIDs) > 0 {
if err := db.Unscoped().Where("role_id = ? AND scope = ? AND id NOT IN ?", role.ID, a.scope, targetIDs).
Delete(&models.Permission{}).Error; err != nil {
return err
}
} else {
// 如果没有策略,删除所有
if err := db.Unscoped().Where("role_id = ? AND scope = ?", role.ID, a.scope).
Delete(&models.Permission{}).Error; err != nil {
return err
}
}
4 months ago
}
return nil
}
4 months ago
// ========== Internal helpers ==========
4 months ago
// getUserPermissions 获取用户的所有权限(聚合 Role 和 Direct Permission
func (a *vbaseProvider) getUserPermissions(userID string) ([]models.Permission, error) {
var perms []models.Permission
db := cfg.DB()
4 months ago
// 1. 直接权限
if err := db.Where("user_id = ? AND scope = ?", userID, a.scope).Find(&perms).Error; err != nil {
return nil, err
4 months ago
}
// 2. 角色权限
// 查用户角色
// UserRole 关联的是 RoleID
// Role 表已经没有 Scope所以这里查出用户拥有的所有角色ID
4 months ago
var roleIDs []string
if err := db.Model(&models.UserRole{}).
Where("user_id = ?", userID).
Pluck("role_id", &roleIDs).Error; err != nil {
4 months ago
return nil, err
}
if len(roleIDs) > 0 {
var rolePerms []models.Permission
// 查询这些角色在当前 scope 下拥有的权限
if err := db.Where("role_id IN ? AND scope = ?", roleIDs, a.scope).Find(&rolePerms).Error; err != nil {
4 months ago
return nil, err
}
perms = append(perms, rolePerms...)
}
4 months ago
return perms, nil
}
// checkPermissionLevel 核心鉴权逻辑
func checkPermissionLevel(perms []models.Permission, targetPermID string, requiredLevel int) bool {
for _, p := range perms {
// 1. 管理员特权 (Level 7 且是父级或同级)
if p.Level == LevelAdmin {
// 如果拥有的权限是 target 的前缀,或者是 *
if p.PermissionID == "*" || strings.HasPrefix(targetPermID, p.PermissionID) {
return true
}
4 months ago
}
// 2. 普通权限匹配
if p.Level >= requiredLevel {
if p.PermissionID == targetPermID {
return true
}
}
4 months ago
}
return false
4 months ago
}
// parsePermissionID 解析动态权限ID
func parsePermissionID(x *vigo.X, code string) (string, error) {
// 简单实现,支持 {key}, {key@query}, {key@header}
start := strings.Index(code, "{")
if start == -1 {
return code, nil
4 months ago
}
end := strings.Index(code, "}")
if end == -1 || end < start {
return "", fmt.Errorf("invalid permission format")
4 months ago
}
raw := code[start+1 : end]
parts := strings.Split(raw, "@")
key := parts[0]
source := "path"
if len(parts) > 1 {
source = parts[1]
4 months ago
}
var val string
switch source {
case "query":
val = x.Request.URL.Query().Get(key)
case "header":
val = x.Request.Header.Get(key)
case "ctx":
if v, ok := x.Get(key).(string); ok {
val = v
}
default: // path
val = x.PathParams.Get(key)
4 months ago
}
if val == "" {
return "", fmt.Errorf("param %s not found in %s", key, source)
}
return code[:start] + val + code[end+1:], nil
}
// extractToken returns the raw access token carried by the request, or ""
// when none is found. Sources are tried in this order:
//  1. the "<CookiePrefix>access" cookie, when present and non-empty;
//  2. a non-empty "Authorization: Bearer <token>" header;
//  3. the "access_token" query parameter.
func extractToken(x *vigo.X) string {
	cookieName := cfg.Global.JWT.CookiePrefix + "access"
	if cookie, err := x.Request.Cookie(cookieName); err == nil && cookie.Value != "" {
		return cookie.Value
	}

	const scheme = "Bearer "
	header := x.Request.Header.Get("Authorization")
	if strings.HasPrefix(header, scheme) && len(header) > len(scheme) {
		return header[len(scheme):]
	}

	return x.Request.URL.Query().Get("access_token")
}
// ========== Session management ==========

// sessionKey returns the Redis key of the hash caching session sid.
// Plain concatenation is used because this runs on every authenticated
// request and needs no formatting.
func sessionKey(sid string) string {
	return "vb:session:" + sid
}
// userSessionsKey returns the Redis key of the set listing the cached
// session IDs of user uid. Plain concatenation is used because no formatting
// is needed.
func userSessionsKey(uid string) string {
	return "vb:user_sessions:" + uid
}
// CreateSession stores a new login session for userID (version 1, with the
// given device info, IP and expiry) in the database and then mirrors it into
// the Redis cache. It returns the stored session, or the database error when
// the insert fails.
func CreateSession(userID, deviceInfo, ip string, expiresAt time.Time) (*models.Session, error) {
	created := new(models.Session)
	created.UserID = userID
	created.Version = 1
	created.DeviceInfo = deviceInfo
	created.IP = ip
	created.ExpiresAt = expiresAt

	if result := cfg.DB().Create(created); result.Error != nil {
		return nil, result.Error
	}

	// Populate the Redis cache.
	fillSessionCache(created)
	return created, nil
}
// GetCurrentSessionID returns the session ID stored in the request context
// under CtxKeySessionID, or "" when it is absent or not a string.
func GetCurrentSessionID(x *vigo.X) string {
	sid, _ := x.Get(CtxKeySessionID).(string)
	return sid
}
// ValidateAccessSession reports whether an access token carrying
// (sessionID, userID, tokenVer) belongs to a live session. The token version
// may equal the session's current version or the one immediately before it.
//
// The Redis cache is consulted first. A cached "revoked" flag or a cached
// owner different from userID rejects the token immediately. When Redis is
// unavailable, the entry is missing, or the entry is incomplete or
// unreadable, the database decides and the cache is refilled from the row.
func ValidateAccessSession(sessionID, userID string, tokenVer int64) bool {
	// Try Redis first.
	if rds := cfg.Redis(); rds != nil {
		ctx := context.Background()
		key := sessionKey(sessionID)
		if revoked, err := rds.HGet(ctx, key, "revoked").Result(); err == nil {
			if revoked == "true" {
				return false
			}
			suid, uidErr := rds.HGet(ctx, key, "user_id").Result()
			verStr, verErr := rds.HGet(ctx, key, "version").Result()
			if uidErr == nil && verErr == nil {
				if suid != userID {
					return false
				}
				if ver, err := strconv.ParseInt(verStr, 10, 64); err == nil {
					return tokenVer == ver || tokenVer == ver-1
				}
			}
			// Partial or unparsable cache entry: do not reject a possibly
			// valid session on cache state alone; ask the database instead.
		}
	}

	// Cache miss: consult the database.
	var session models.Session
	if err := cfg.DB().Where("id = ? AND user_id = ?", sessionID, userID).First(&session).Error; err != nil {
		return false
	}
	if session.Revoked {
		return false
	}

	// Refill the cache.
	fillSessionCache(&session)
	return tokenVer == session.Version || tokenVer == session.Version-1
}
// ValidateRefreshSession validates a refresh token for (userID, sessionID)
// and rotates the session version, returning the new version.
//
// The token is accepted only when the session exists, is not revoked and its
// stored version equals tokenVer. On success the version is incremented and
// the session expiry is extended by cfg.Global.JWT.RefreshExpiry.
//
// The database update is conditional on the version still being tokenVer
// (and the session still not being revoked), so of two concurrent refreshes
// presenting the same token only one succeeds; the other gets a version
// mismatch error. The Redis copy is updated on a best-effort basis.
func ValidateRefreshSession(userID, sessionID string, tokenVer int64) (int64, error) {
	var session models.Session
	if err := cfg.DB().Where("id = ? AND user_id = ?", sessionID, userID).First(&session).Error; err != nil {
		return 0, fmt.Errorf("session not found")
	}
	if session.Revoked {
		return 0, fmt.Errorf("session revoked")
	}
	if tokenVer != session.Version {
		return 0, fmt.Errorf("refresh token version mismatch: expected %d, got %d", session.Version, tokenVer)
	}

	// Bump the version and extend the session expiry in one guarded update.
	newVer := session.Version + 1
	newExpiresAt := time.Now().Add(cfg.Global.JWT.RefreshExpiry)
	res := cfg.DB().Model(&models.Session{}).
		Where("id = ? AND user_id = ? AND version = ? AND revoked = ?", sessionID, userID, tokenVer, false).
		Updates(map[string]interface{}{
			"version":    newVer,
			"expires_at": newExpiresAt,
		})
	if res.Error != nil {
		return 0, res.Error
	}
	if res.RowsAffected == 0 {
		// Someone else rotated or revoked the session between our read and write.
		return 0, fmt.Errorf("refresh token version mismatch: version %d was already rotated or the session was revoked", tokenVer)
	}

	// Update Redis.
	if rds := cfg.Redis(); rds != nil {
		if ttl := time.Until(newExpiresAt); ttl > 0 {
			ctx := context.Background()
			pipe := rds.Pipeline()
			pipe.HSet(ctx, sessionKey(sessionID), "version", newVer)
			pipe.Expire(ctx, sessionKey(sessionID), ttl)
			pipe.Expire(ctx, userSessionsKey(userID), ttl)
			if _, err := pipe.Exec(ctx); err != nil {
				logv.Warn().Msgf("ValidateRefreshSession: redis pipeline failed: %v", err)
			}
		}
	}
	return newVer, nil
}
// RevokeSession marks session sessionID of userID as revoked in the database
// and returns the database error, if any. When Redis is configured it then
// flags the cached session as revoked and removes the ID from the user's
// session set; Redis failures are only logged.
func RevokeSession(userID, sessionID string) error {
	changes := map[string]interface{}{
		"revoked":    true,
		"revoked_at": time.Now(),
	}
	update := cfg.DB().Model(&models.Session{}).
		Where("id = ? AND user_id = ?", sessionID, userID).
		Updates(changes)
	if update.Error != nil {
		return update.Error
	}

	// Mirror the revocation into Redis.
	rds := cfg.Redis()
	if rds == nil {
		return nil
	}
	ctx := context.Background()
	if err := rds.HSet(ctx, sessionKey(sessionID), "revoked", "true").Err(); err != nil {
		logv.Warn().Msgf("RevokeSession: redis HSet failed: %v", err)
	}
	if err := rds.SRem(ctx, userSessionsKey(userID), sessionID).Err(); err != nil {
		logv.Warn().Msgf("RevokeSession: redis SRem failed: %v", err)
	}
	return nil
}
// RevokeOtherSessions revokes every non-revoked session of userID except
// currentSessionID in the database and returns the database error, if any.
//
// When Redis is configured, each session ID listed in the user's session set
// (other than the current one) is flagged as revoked in the cache and
// removed from the set. Redis failures are logged, as in RevokeSession, and
// do not fail the call.
//
// NOTE(review): only sessions listed in the user's Redis set are touched
// here; confirm that a cached session missing from that set cannot exist.
func RevokeOtherSessions(userID, currentSessionID string) error {
	now := time.Now()
	if err := cfg.DB().Model(&models.Session{}).
		Where("user_id = ? AND id != ? AND revoked = ?", userID, currentSessionID, false).
		Updates(map[string]interface{}{
			"revoked":    true,
			"revoked_at": now,
		}).Error; err != nil {
		return err
	}

	rds := cfg.Redis()
	if rds == nil {
		return nil
	}
	ctx := context.Background()
	setKey := userSessionsKey(userID)
	members, err := rds.SMembers(ctx, setKey).Result()
	if err != nil {
		logv.Warn().Msgf("RevokeOtherSessions: redis SMembers failed: %v", err)
		return nil
	}
	for _, sid := range members {
		if sid == currentSessionID {
			continue
		}
		if err := rds.HSet(ctx, sessionKey(sid), "revoked", "true").Err(); err != nil {
			logv.Warn().Msgf("RevokeOtherSessions: redis HSet failed: %v", err)
		}
		if err := rds.SRem(ctx, setKey, sid).Err(); err != nil {
			logv.Warn().Msgf("RevokeOtherSessions: redis SRem failed: %v", err)
		}
	}
	return nil
}
// RevokeAllSessions revokes every non-revoked session of userID in the
// database and returns the database error, if any.
//
// When Redis is configured, the cached entry of each session listed in the
// user's session set is deleted, followed by the set itself. Redis failures,
// including per-session deletions, are logged and do not fail the call.
func RevokeAllSessions(userID string) error {
	now := time.Now()
	if err := cfg.DB().Model(&models.Session{}).
		Where("user_id = ? AND revoked = ?", userID, false).
		Updates(map[string]interface{}{
			"revoked":    true,
			"revoked_at": now,
		}).Error; err != nil {
		return err
	}

	// Clean up Redis.
	rds := cfg.Redis()
	if rds == nil {
		return nil
	}
	ctx := context.Background()
	setKey := userSessionsKey(userID)
	members, err := rds.SMembers(ctx, setKey).Result()
	if err != nil {
		logv.Warn().Msgf("RevokeAllSessions: redis SMembers failed: %v", err)
	}
	for _, sid := range members {
		if err := rds.Del(ctx, sessionKey(sid)).Err(); err != nil {
			logv.Warn().Msgf("RevokeAllSessions: redis Del session failed: %v", err)
		}
	}
	if err := rds.Del(ctx, setKey).Err(); err != nil {
		logv.Warn().Msgf("RevokeAllSessions: redis Del failed: %v", err)
	}
	return nil
}
// ListSessions returns the active sessions of userID, newest first. A
// session is active when it is not revoked and its expiry lies in the
// future. On a database error it returns nil and that error.
func ListSessions(userID string) ([]models.Session, error) {
	var active []models.Session
	query := cfg.DB().
		Where("user_id = ? AND revoked = ? AND expires_at > ?", userID, false, time.Now()).
		Order("created_at DESC")
	if result := query.Find(&active); result.Error != nil {
		return nil, result.Error
	}
	return active, nil
}
// fillSessionCache writes session into the Redis cache. It does nothing when
// Redis is not configured or the session has already expired.
//
// In one pipeline it stores the version, revoked flag and owner in the
// session hash, adds the session ID to the owner's session set, and gives
// both keys a TTL equal to the session's remaining lifetime. A pipeline
// failure is only logged.
func fillSessionCache(session *models.Session) {
	rds := cfg.Redis()
	if rds == nil {
		return
	}
	remaining := time.Until(session.ExpiresAt)
	if remaining <= 0 {
		return
	}

	revokedFlag := "true"
	if !session.Revoked {
		revokedFlag = "false"
	}

	ctx := context.Background()
	hashKey := sessionKey(session.ID)
	setKey := userSessionsKey(session.UserID)

	pipe := rds.Pipeline()
	pipe.HSet(ctx, hashKey,
		"version", session.Version,
		"revoked", revokedFlag,
		"user_id", session.UserID,
	)
	pipe.Expire(ctx, hashKey, remaining)
	pipe.SAdd(ctx, setKey, session.ID)
	pipe.Expire(ctx, setKey, remaining)
	if _, err := pipe.Exec(ctx); err != nil {
		logv.Warn().Msgf("fillSessionCache: redis pipeline failed: %v", err)
	}
}
// validatePermission checks that level may be used with permission ID code.
//
// The wildcard "*" is only valid with LevelAdmin. For any other ID the depth
// is its number of ":"-separated segments: LevelCreate requires an odd depth
// (a resource type), every other level requires an even depth (a resource
// instance). It returns nil when the combination is valid and a descriptive
// error otherwise.
func validatePermission(code string, level int) error {
	if code == "*" {
		if level == LevelAdmin {
			return nil
		}
		return fmt.Errorf("wildcard * requires LevelAdmin")
	}

	depth := strings.Count(code, ":") + 1
	oddDepth := depth%2 != 0
	switch {
	case level == LevelCreate && !oddDepth:
		return fmt.Errorf("LevelCreate requires odd depth (resource type), got %d for %s", depth, code)
	case level != LevelCreate && oddDepth:
		// Levels 2, 4, 6, 7 (and anything that is not LevelCreate).
		return fmt.Errorf("Level %d requires even depth (resource instance), got %d for %s", level, depth, code)
	}
	return nil
}