|
|
//
|
|
|
// Copyright (C) 2024 veypi <i@veypi.com>
|
|
|
// 2025-02-14 16:08:06
|
|
|
// Distributed under terms of the MIT license.
|
|
|
//
|
|
|
|
|
|
package auth
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
"time"
|
|
|
|
|
|
"github.com/veypi/vbase/cfg"
|
|
|
"github.com/veypi/vbase/libs/jwt"
|
|
|
"github.com/veypi/vbase/models"
|
|
|
"github.com/veypi/vigo"
|
|
|
pub "github.com/veypi/vigo/contrib/auth"
|
|
|
"github.com/veypi/vigo/contrib/event"
|
|
|
"github.com/veypi/vigo/logv"
|
|
|
"gorm.io/gorm"
|
|
|
)
|
|
|
|
|
|
const (
	// CtxKeyUserID is the request-context key under which UserID stores the
	// authenticated user's ID.
	CtxKeyUserID = "auth:user_id"
	// CtxKeySessionID is the request-context key under which UserID stores the
	// ID of the session the access token belongs to.
	CtxKeySessionID = "auth:session_id"
)

// Permission levels. Create, Read and Write are single bits; ReadWrite and
// Admin are unions of them.
//
// validatePermission accepts LevelCreate only on odd-depth IDs (resource
// types, e.g. "app") and every other level only on even-depth IDs (resource
// instances, e.g. "app:1"); the wildcard "*" requires LevelAdmin.
const (
	LevelNone      = 0
	LevelCreate    = 1 << 0                               // 001
	LevelRead      = 1 << 1                               // 010
	LevelWrite     = 1 << 2                               // 100
	LevelReadWrite = LevelRead | LevelWrite               // 110
	LevelAdmin     = LevelCreate | LevelRead | LevelWrite // 111, full control
)
|
|
|
|
|
|
// ctxKeyTokenParsed 标记token是否已解析(请求级别,避免重复解析)
|
|
|
const ctxKeyTokenParsed = "_token_parsed"
|
|
|
|
|
|
// PermFunc 权限检查函数类型
|
|
|
type PermFunc = pub.PermFunc
|
|
|
|
|
|
// Provider 是 auth.Provider 的别名,用于实现端
|
|
|
type Provider = pub.Provider
|
|
|
|
|
|
// Factory 全局 Auth 工厂
|
|
|
var Factory = &authFactory{
|
|
|
apps: make(map[string]Provider),
|
|
|
}
|
|
|
|
|
|
var _ pub.Provider = &vbaseProvider{}
|
|
|
|
|
|
func init() {
|
|
|
// 注册权限初始化回调
|
|
|
event.Add("vb.init.auth", Factory.init)
|
|
|
}
|
|
|
|
|
|
type authFactory struct {
|
|
|
apps map[string]Provider
|
|
|
}
|
|
|
|
|
|
// New 创建权限 Provider 实例
|
|
|
func (f *authFactory) New(scope string) Provider {
|
|
|
if p, exists := f.apps[scope]; exists {
|
|
|
return p
|
|
|
}
|
|
|
p := &vbaseProvider{
|
|
|
scope: scope,
|
|
|
roleDefs: make(map[string]roleDefinition),
|
|
|
}
|
|
|
f.apps[scope] = p
|
|
|
return p
|
|
|
}
|
|
|
|
|
|
func (f *authFactory) init() error {
|
|
|
for appKey, p := range f.apps {
|
|
|
if vp, ok := p.(*vbaseProvider); ok {
|
|
|
if err := vp.init(); err != nil {
|
|
|
return fmt.Errorf("failed to init auth for %s: %w", appKey, err)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// roleDefinition is a role declared in code through AddRole; vbaseProvider.init
// writes it (and its policies) to the database.
type roleDefinition struct {
	code, name string
	policies   []string // each entry is "permissionID:level"
}

// vbaseProvider is the database-backed Provider for a single scope.
type vbaseProvider struct {
	scope    string                    // every permission query/insert is filtered by this scope
	roleDefs map[string]roleDefinition // role code -> definition, filled by AddRole
}
|
|
|
|
|
|
// ========== Provider 接口实现 ==========
|
|
|
|
|
|
func (a *vbaseProvider) UserID(x *vigo.X) string {
|
|
|
// 1. 检查是否已解析过(无论成功与否,避免重复解析)
|
|
|
if _, parsed := x.Get(ctxKeyTokenParsed).(bool); parsed {
|
|
|
if uid, ok := x.Get(CtxKeyUserID).(string); ok {
|
|
|
return uid
|
|
|
}
|
|
|
return ""
|
|
|
}
|
|
|
|
|
|
// 2. 惰性解析:从请求中提取 token
|
|
|
tokenStr := extractToken(x)
|
|
|
if tokenStr == "" {
|
|
|
x.Set(ctxKeyTokenParsed, true)
|
|
|
return ""
|
|
|
}
|
|
|
|
|
|
// 3. 解析并验证 token
|
|
|
claims, err := jwt.ParseToken(tokenStr)
|
|
|
if err != nil {
|
|
|
x.Set(ctxKeyTokenParsed, true)
|
|
|
return ""
|
|
|
}
|
|
|
|
|
|
// 确保是 access token
|
|
|
if !jwt.IsAccessToken(claims) {
|
|
|
x.Set(ctxKeyTokenParsed, true)
|
|
|
return ""
|
|
|
}
|
|
|
|
|
|
// 验证 access token 对应的 session(允许当前版本或上一版本,防止多 tab 并发刷新互踢)
|
|
|
if !ValidateAccessSession(claims.SessionID, claims.UserID, claims.Version) {
|
|
|
x.Set(ctxKeyTokenParsed, true)
|
|
|
return ""
|
|
|
}
|
|
|
|
|
|
// 5. 设置到上下文中,供后续调用使用
|
|
|
x.Set(CtxKeyUserID, claims.UserID)
|
|
|
x.Set(CtxKeySessionID, claims.SessionID)
|
|
|
x.Set(ctxKeyTokenParsed, true)
|
|
|
|
|
|
return claims.UserID
|
|
|
}
|
|
|
|
|
|
// Grant 授予权限
|
|
|
func (a *vbaseProvider) Grant(ctx context.Context, userID, permissionID string, level int) error {
|
|
|
if err := validatePermission(permissionID, level); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
// 检查是否存在
|
|
|
var count int64
|
|
|
cfg.DB().Model(&models.Permission{}).
|
|
|
Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
|
|
|
Count(&count)
|
|
|
|
|
|
if count > 0 {
|
|
|
// 更新等级
|
|
|
return cfg.DB().Model(&models.Permission{}).
|
|
|
Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
|
|
|
Update("level", level).Error
|
|
|
}
|
|
|
|
|
|
// 创建
|
|
|
perm := models.Permission{
|
|
|
Scope: a.scope,
|
|
|
UserID: &userID,
|
|
|
PermissionID: permissionID,
|
|
|
Level: level,
|
|
|
}
|
|
|
return cfg.DB().Create(&perm).Error
|
|
|
}
|
|
|
|
|
|
// Revoke 撤销权限
|
|
|
func (a *vbaseProvider) Revoke(ctx context.Context, userID, permissionID string) error {
|
|
|
return cfg.DB().Where("user_id = ? AND permission_id = ? AND scope = ?", userID, permissionID, a.scope).
|
|
|
Delete(&models.Permission{}).Error
|
|
|
}
|
|
|
|
|
|
// GrantRole 授予角色
|
|
|
func (a *vbaseProvider) GrantRole(ctx context.Context, userID, roleCode string) error {
|
|
|
var role models.Role
|
|
|
if err := cfg.DB().Where("code = ?", roleCode).First(&role).Error; err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
var count int64
|
|
|
cfg.DB().Model(&models.UserRole{}).
|
|
|
Where("user_id = ? AND role_id = ?", userID, role.ID).
|
|
|
Count(&count)
|
|
|
if count > 0 {
|
|
|
return nil // 已经有该角色
|
|
|
}
|
|
|
|
|
|
userRole := models.UserRole{
|
|
|
UserID: userID,
|
|
|
RoleID: role.ID,
|
|
|
}
|
|
|
return cfg.DB().Create(&userRole).Error
|
|
|
}
|
|
|
|
|
|
// RevokeRole 撤销角色
|
|
|
func (a *vbaseProvider) RevokeRole(ctx context.Context, userID, roleCode string) error {
|
|
|
var role models.Role
|
|
|
if err := cfg.DB().Where("code = ?", roleCode).First(&role).Error; err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
return cfg.DB().Where("user_id = ? AND role_id = ?", userID, role.ID).
|
|
|
Delete(&models.UserRole{}).Error
|
|
|
}
|
|
|
|
|
|
// Check 检查权限
|
|
|
func (a *vbaseProvider) Check(ctx context.Context, userID, permissionID string, level int) bool {
|
|
|
if err := validatePermission(permissionID, level); err != nil {
|
|
|
panic(err)
|
|
|
}
|
|
|
// 1. 获取用户在该作用域下的所有权限(直接权限 + 角色权限)
|
|
|
perms, err := a.getUserPermissions(userID)
|
|
|
if err != nil {
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
// 2. 检查
|
|
|
return checkPermissionLevel(perms, permissionID, level)
|
|
|
}
|
|
|
|
|
|
// ListResources 查询用户在特定资源类型下的详细权限信息
|
|
|
func (a *vbaseProvider) ListResources(ctx context.Context, userID, resourceType string) (map[string]int, error) {
|
|
|
perms, err := a.getUserPermissions(userID)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
result := make(map[string]int)
|
|
|
prefix := resourceType + ":"
|
|
|
|
|
|
for _, p := range perms {
|
|
|
if p.PermissionID == "*" && p.Level == LevelAdmin {
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
if strings.HasPrefix(p.PermissionID, prefix) {
|
|
|
suffix := p.PermissionID[len(prefix):]
|
|
|
parts := strings.Split(suffix, ":")
|
|
|
if len(parts) > 0 {
|
|
|
instanceID := parts[0]
|
|
|
level := p.Level
|
|
|
// If permission is deeper, assume LevelRead for parent unless explicit
|
|
|
if len(parts) > 1 {
|
|
|
level = LevelRead
|
|
|
}
|
|
|
|
|
|
if currentLevel, ok := result[instanceID]; !ok || level > currentLevel {
|
|
|
result[instanceID] = level
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
return result, nil
|
|
|
}
|
|
|
|
|
|
// ListUsers 查询特定资源的所有协作者及其权限
|
|
|
func (a *vbaseProvider) ListUsers(ctx context.Context, permissionID string) (map[string]int, error) {
|
|
|
parents := getAllParents(permissionID)
|
|
|
|
|
|
var perms []models.Permission
|
|
|
db := cfg.DB().Where("scope = ?", a.scope)
|
|
|
|
|
|
conditions := []string{"permission_id = ?"}
|
|
|
args := []interface{}{permissionID}
|
|
|
|
|
|
if len(parents) > 0 {
|
|
|
conditions = append(conditions, "(permission_id IN ? AND level = ?)")
|
|
|
args = append(args, parents, LevelAdmin)
|
|
|
}
|
|
|
|
|
|
conditions = append(conditions, "(permission_id = ? AND level = ?)")
|
|
|
args = append(args, "*", LevelAdmin)
|
|
|
|
|
|
query := db.Where(strings.Join(conditions, " OR "), args...)
|
|
|
|
|
|
if err := query.Find(&perms).Error; err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
result := make(map[string]int)
|
|
|
|
|
|
for _, p := range perms {
|
|
|
if p.UserID != nil {
|
|
|
uid := *p.UserID
|
|
|
if l, ok := result[uid]; !ok || p.Level > l {
|
|
|
result[uid] = p.Level
|
|
|
}
|
|
|
}
|
|
|
if p.RoleID != nil {
|
|
|
var userRoles []models.UserRole
|
|
|
cfg.DB().Where("role_id = ?", *p.RoleID).Find(&userRoles)
|
|
|
for _, ur := range userRoles {
|
|
|
if l, ok := result[ur.UserID]; !ok || p.Level > l {
|
|
|
result[ur.UserID] = p.Level
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return result, nil
|
|
|
}
|
|
|
|
|
|
// getAllParents returns every proper ancestor of permID, shortest first,
// where an ancestor is the text before one of its ':' separators.
// "app:1:doc:2" yields ["app", "app:1", "app:1:doc"]; an ID with no ':'
// yields nil.
func getAllParents(permID string) []string {
	var parents []string
	offset := 0
	for {
		idx := strings.IndexByte(permID[offset:], ':')
		if idx < 0 {
			return parents
		}
		offset += idx
		parents = append(parents, permID[:offset])
		offset++ // step past the ':' just consumed
	}
}
|
|
|
|
|
|
// AddRole 添加角色定义
|
|
|
func (a *vbaseProvider) AddRole(code, name string, policies ...string) error {
|
|
|
a.roleDefs[code] = roleDefinition{
|
|
|
code: code,
|
|
|
name: name,
|
|
|
policies: policies,
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// init 初始化角色到数据库
|
|
|
func (a *vbaseProvider) init() error {
|
|
|
db := cfg.DB()
|
|
|
for code, def := range a.roleDefs {
|
|
|
// 1. 确保角色存在
|
|
|
var role models.Role
|
|
|
err := db.Where("code = ?", code).First(&role).Error
|
|
|
if err != nil {
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
role = models.Role{
|
|
|
Code: code,
|
|
|
Name: def.name,
|
|
|
IsSystem: true,
|
|
|
Status: 1,
|
|
|
}
|
|
|
if err := db.Create(&role).Error; err != nil {
|
|
|
return err
|
|
|
}
|
|
|
} else {
|
|
|
return err
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 2. 同步角色权限 (Diff Sync)
|
|
|
// ID格式: scoped:permission:level (用于唯一标识和 Diff Sync)
|
|
|
var targetIDs []string
|
|
|
|
|
|
// 获取该角色当前scope下的所有权限ID,用于快速比对
|
|
|
var existingIDs []string
|
|
|
if err := db.Model(&models.Permission{}).Where("role_id = ? AND scope = ?", role.ID, a.scope).Pluck("id", &existingIDs).Error; err != nil {
|
|
|
return err
|
|
|
}
|
|
|
existingMap := make(map[string]bool)
|
|
|
for _, id := range existingIDs {
|
|
|
existingMap[id] = true
|
|
|
}
|
|
|
|
|
|
for _, policy := range def.policies {
|
|
|
// policy 格式: "permissionID:level"
|
|
|
parts := strings.Split(policy, ":")
|
|
|
if len(parts) < 2 {
|
|
|
continue
|
|
|
}
|
|
|
levelStr := parts[len(parts)-1]
|
|
|
permID := strings.Join(parts[:len(parts)-1], ":")
|
|
|
var level int
|
|
|
fmt.Sscanf(levelStr, "%d", &level)
|
|
|
|
|
|
// 生成确定性 ID: scoped:permission:level
|
|
|
id := fmt.Sprintf("%s:%s:%d", a.scope, permID, level)
|
|
|
targetIDs = append(targetIDs, id)
|
|
|
|
|
|
// 检查是否存在
|
|
|
if !existingMap[id] {
|
|
|
// 不存在,创建新权限
|
|
|
newPerm := models.Permission{
|
|
|
Scope: a.scope,
|
|
|
RoleID: &role.ID,
|
|
|
PermissionID: permID,
|
|
|
Level: level,
|
|
|
}
|
|
|
newPerm.ID = id
|
|
|
if err := db.Create(&newPerm).Error; err != nil {
|
|
|
return err
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 3. 清理不再需要的权限
|
|
|
if len(targetIDs) > 0 {
|
|
|
if err := db.Unscoped().Where("role_id = ? AND scope = ? AND id NOT IN ?", role.ID, a.scope, targetIDs).
|
|
|
Delete(&models.Permission{}).Error; err != nil {
|
|
|
return err
|
|
|
}
|
|
|
} else {
|
|
|
// 如果没有策略,删除所有
|
|
|
if err := db.Unscoped().Where("role_id = ? AND scope = ?", role.ID, a.scope).
|
|
|
Delete(&models.Permission{}).Error; err != nil {
|
|
|
return err
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// ========== 内部辅助方法 ==========
|
|
|
|
|
|
// getUserPermissions 获取用户的所有权限(聚合 Role 和 Direct Permission)
|
|
|
func (a *vbaseProvider) getUserPermissions(userID string) ([]models.Permission, error) {
|
|
|
var perms []models.Permission
|
|
|
db := cfg.DB()
|
|
|
|
|
|
// 1. 直接权限
|
|
|
if err := db.Where("user_id = ? AND scope = ?", userID, a.scope).Find(&perms).Error; err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
// 2. 角色权限
|
|
|
// 查用户角色
|
|
|
// UserRole 关联的是 RoleID
|
|
|
// Role 表已经没有 Scope,所以这里查出用户拥有的所有角色ID
|
|
|
var roleIDs []string
|
|
|
if err := db.Model(&models.UserRole{}).
|
|
|
Where("user_id = ?", userID).
|
|
|
Pluck("role_id", &roleIDs).Error; err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
if len(roleIDs) > 0 {
|
|
|
var rolePerms []models.Permission
|
|
|
// 查询这些角色在当前 scope 下拥有的权限
|
|
|
if err := db.Where("role_id IN ? AND scope = ?", roleIDs, a.scope).Find(&rolePerms).Error; err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
perms = append(perms, rolePerms...)
|
|
|
}
|
|
|
|
|
|
return perms, nil
|
|
|
}
|
|
|
|
|
|
// checkPermissionLevel 核心鉴权逻辑
|
|
|
func checkPermissionLevel(perms []models.Permission, targetPermID string, requiredLevel int) bool {
|
|
|
for _, p := range perms {
|
|
|
// 1. 管理员特权 (Level 7 且是父级或同级)
|
|
|
if p.Level == LevelAdmin {
|
|
|
// 如果拥有的权限是 target 的前缀,或者是 *
|
|
|
if p.PermissionID == "*" || strings.HasPrefix(targetPermID, p.PermissionID) {
|
|
|
return true
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 2. 普通权限匹配
|
|
|
if p.Level >= requiredLevel {
|
|
|
if p.PermissionID == targetPermID {
|
|
|
return true
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
// parsePermissionID 解析动态权限ID
|
|
|
func parsePermissionID(x *vigo.X, code string) (string, error) {
|
|
|
// 简单实现,支持 {key}, {key@query}, {key@header}
|
|
|
start := strings.Index(code, "{")
|
|
|
if start == -1 {
|
|
|
return code, nil
|
|
|
}
|
|
|
end := strings.Index(code, "}")
|
|
|
if end == -1 || end < start {
|
|
|
return "", fmt.Errorf("invalid permission format")
|
|
|
}
|
|
|
|
|
|
raw := code[start+1 : end]
|
|
|
parts := strings.Split(raw, "@")
|
|
|
key := parts[0]
|
|
|
source := "path"
|
|
|
if len(parts) > 1 {
|
|
|
source = parts[1]
|
|
|
}
|
|
|
|
|
|
var val string
|
|
|
switch source {
|
|
|
case "query":
|
|
|
val = x.Request.URL.Query().Get(key)
|
|
|
case "header":
|
|
|
val = x.Request.Header.Get(key)
|
|
|
case "ctx":
|
|
|
if v, ok := x.Get(key).(string); ok {
|
|
|
val = v
|
|
|
}
|
|
|
default: // path
|
|
|
val = x.PathParams.Get(key)
|
|
|
}
|
|
|
|
|
|
if val == "" {
|
|
|
return "", fmt.Errorf("param %s not found in %s", key, source)
|
|
|
}
|
|
|
|
|
|
return code[:start] + val + code[end+1:], nil
|
|
|
}
|
|
|
|
|
|
// extractToken 从请求中提取 token,优先级: Cookie > Authorization Header > Query
|
|
|
func extractToken(x *vigo.X) string {
|
|
|
// 1. Cookie (HttpOnly,浏览器自动携带)
|
|
|
if c, err := x.Request.Cookie(cfg.Global.JWT.CookiePrefix + "access"); err == nil && c.Value != "" {
|
|
|
return c.Value
|
|
|
}
|
|
|
|
|
|
// 2. Authorization Header
|
|
|
auth := x.Request.Header.Get("Authorization")
|
|
|
if auth != "" && len(auth) > 7 && strings.HasPrefix(auth, "Bearer ") {
|
|
|
return auth[7:]
|
|
|
}
|
|
|
|
|
|
// 3. Query 参数
|
|
|
return x.Request.URL.Query().Get("access_token")
|
|
|
}
|
|
|
|
|
|
// ========== Session 管理 ==========
|
|
|
|
|
|
// sessionKey returns the Redis key of the hash caching session sid.
// Built by plain concatenation: it runs several times per authenticated
// request, and fmt.Sprintf would box its argument and parse a format string
// each time for the same result.
func sessionKey(sid string) string {
	return "vb:session:" + sid
}

// userSessionsKey returns the Redis key of the set indexing the session IDs
// cached for user uid.
func userSessionsKey(uid string) string {
	return "vb:user_sessions:" + uid
}
|
|
|
|
|
|
// CreateSession 创建登录会话(DB + Redis)
|
|
|
func CreateSession(userID, deviceInfo, ip string, expiresAt time.Time) (*models.Session, error) {
|
|
|
session := &models.Session{
|
|
|
UserID: userID,
|
|
|
Version: 1,
|
|
|
DeviceInfo: deviceInfo,
|
|
|
IP: ip,
|
|
|
ExpiresAt: expiresAt,
|
|
|
}
|
|
|
if err := cfg.DB().Create(session).Error; err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
// 写 Redis 缓存
|
|
|
fillSessionCache(session)
|
|
|
return session, nil
|
|
|
}
|
|
|
|
|
|
// GetCurrentSessionID 从请求上下文获取当前 session ID
|
|
|
func GetCurrentSessionID(x *vigo.X) string {
|
|
|
if sid, ok := x.Get(CtxKeySessionID).(string); ok {
|
|
|
return sid
|
|
|
}
|
|
|
return ""
|
|
|
}
|
|
|
|
|
|
// ValidateAccessSession 验证 access token 对应的 session(允许当前版本或上一版本)
|
|
|
func ValidateAccessSession(sessionID, userID string, tokenVer int64) bool {
|
|
|
// 先查 Redis
|
|
|
rds := cfg.Redis()
|
|
|
if rds != nil {
|
|
|
revoked, err := rds.HGet(context.Background(), sessionKey(sessionID), "revoked").Result()
|
|
|
if err == nil {
|
|
|
verStr, _ := rds.HGet(context.Background(), sessionKey(sessionID), "version").Result()
|
|
|
suid, _ := rds.HGet(context.Background(), sessionKey(sessionID), "user_id").Result()
|
|
|
if revoked == "true" || suid != userID {
|
|
|
return false
|
|
|
}
|
|
|
ver, err := strconv.ParseInt(verStr, 10, 64)
|
|
|
if err != nil {
|
|
|
return false
|
|
|
}
|
|
|
return tokenVer == ver || tokenVer == ver-1
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// Redis 未命中,查 DB
|
|
|
var session models.Session
|
|
|
if err := cfg.DB().Where("id = ? AND user_id = ?", sessionID, userID).First(&session).Error; err != nil {
|
|
|
return false
|
|
|
}
|
|
|
if session.Revoked {
|
|
|
return false
|
|
|
}
|
|
|
// 回填 Redis
|
|
|
fillSessionCache(&session)
|
|
|
return tokenVer == session.Version || tokenVer == session.Version-1
|
|
|
}
|
|
|
|
|
|
// ValidateRefreshSession 验证 refresh token 并旋转版本号,返回新版本
|
|
|
func ValidateRefreshSession(userID, sessionID string, tokenVer int64) (int64, error) {
|
|
|
var session models.Session
|
|
|
if err := cfg.DB().Where("id = ? AND user_id = ?", sessionID, userID).First(&session).Error; err != nil {
|
|
|
return 0, fmt.Errorf("session not found")
|
|
|
}
|
|
|
if session.Revoked {
|
|
|
return 0, fmt.Errorf("session revoked")
|
|
|
}
|
|
|
if tokenVer != session.Version {
|
|
|
return 0, fmt.Errorf("refresh token version mismatch: expected %d, got %d", session.Version, tokenVer)
|
|
|
}
|
|
|
|
|
|
// 版本 +1,同时延长 session 过期时间
|
|
|
newVer := session.Version + 1
|
|
|
newExpiresAt := time.Now().Add(cfg.Global.JWT.RefreshExpiry)
|
|
|
if err := cfg.DB().Model(&session).Updates(map[string]interface{}{
|
|
|
"version": newVer,
|
|
|
"expires_at": newExpiresAt,
|
|
|
}).Error; err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
|
|
|
// 更新 Redis
|
|
|
rds := cfg.Redis()
|
|
|
if rds != nil {
|
|
|
ttl := time.Until(newExpiresAt)
|
|
|
if ttl > 0 {
|
|
|
pipe := rds.Pipeline()
|
|
|
pipe.HSet(context.Background(), sessionKey(sessionID), "version", newVer)
|
|
|
pipe.Expire(context.Background(), sessionKey(sessionID), ttl)
|
|
|
pipe.Expire(context.Background(), userSessionsKey(userID), ttl)
|
|
|
pipe.Exec(context.Background())
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return newVer, nil
|
|
|
}
|
|
|
|
|
|
// RevokeSession 撤销指定会话
|
|
|
func RevokeSession(userID, sessionID string) error {
|
|
|
now := time.Now()
|
|
|
res := cfg.DB().Model(&models.Session{}).Where("id = ? AND user_id = ?", sessionID, userID).Updates(map[string]interface{}{
|
|
|
"revoked": true,
|
|
|
"revoked_at": now,
|
|
|
})
|
|
|
if res.Error != nil {
|
|
|
return res.Error
|
|
|
}
|
|
|
|
|
|
// 更新 Redis
|
|
|
rds := cfg.Redis()
|
|
|
if rds != nil {
|
|
|
if err := rds.HSet(context.Background(), sessionKey(sessionID), "revoked", "true").Err(); err != nil {
|
|
|
logv.Warn().Msgf("RevokeSession: redis HSet failed: %v", err)
|
|
|
}
|
|
|
if err := rds.SRem(context.Background(), userSessionsKey(userID), sessionID).Err(); err != nil {
|
|
|
logv.Warn().Msgf("RevokeSession: redis SRem failed: %v", err)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// RevokeOtherSessions 撤销用户除当前外的其他所有会话
|
|
|
func RevokeOtherSessions(userID, currentSessionID string) error {
|
|
|
now := time.Now()
|
|
|
if err := cfg.DB().Model(&models.Session{}).Where("user_id = ? AND id != ? AND revoked = ?", userID, currentSessionID, false).Updates(map[string]interface{}{
|
|
|
"revoked": true,
|
|
|
"revoked_at": now,
|
|
|
}).Error; err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
rds := cfg.Redis()
|
|
|
if rds != nil {
|
|
|
members, err := rds.SMembers(context.Background(), userSessionsKey(userID)).Result()
|
|
|
if err == nil {
|
|
|
for _, sid := range members {
|
|
|
if sid != currentSessionID {
|
|
|
rds.HSet(context.Background(), sessionKey(sid), "revoked", "true")
|
|
|
rds.SRem(context.Background(), userSessionsKey(userID), sid)
|
|
|
}
|
|
|
}
|
|
|
} else {
|
|
|
logv.Warn().Msgf("RevokeOtherSessions: redis SMembers failed: %v", err)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// RevokeAllSessions 撤销用户所有会话
|
|
|
func RevokeAllSessions(userID string) error {
|
|
|
now := time.Now()
|
|
|
if err := cfg.DB().Model(&models.Session{}).Where("user_id = ? AND revoked = ?", userID, false).Updates(map[string]interface{}{
|
|
|
"revoked": true,
|
|
|
"revoked_at": now,
|
|
|
}).Error; err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
// 清理 Redis
|
|
|
rds := cfg.Redis()
|
|
|
if rds != nil {
|
|
|
members, err := rds.SMembers(context.Background(), userSessionsKey(userID)).Result()
|
|
|
if err == nil {
|
|
|
for _, sid := range members {
|
|
|
rds.Del(context.Background(), sessionKey(sid))
|
|
|
}
|
|
|
} else {
|
|
|
logv.Warn().Msgf("RevokeAllSessions: redis SMembers failed: %v", err)
|
|
|
}
|
|
|
if err := rds.Del(context.Background(), userSessionsKey(userID)).Err(); err != nil {
|
|
|
logv.Warn().Msgf("RevokeAllSessions: redis Del failed: %v", err)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// ListSessions 列出用户所有活跃会话
|
|
|
func ListSessions(userID string) ([]models.Session, error) {
|
|
|
var sessions []models.Session
|
|
|
if err := cfg.DB().Where("user_id = ? AND revoked = ? AND expires_at > ?", userID, false, time.Now()).Order("created_at DESC").Find(&sessions).Error; err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
return sessions, nil
|
|
|
}
|
|
|
|
|
|
// fillSessionCache 回填 Redis 缓存
|
|
|
func fillSessionCache(session *models.Session) {
|
|
|
rds := cfg.Redis()
|
|
|
if rds == nil {
|
|
|
return
|
|
|
}
|
|
|
ttl := time.Until(session.ExpiresAt)
|
|
|
if ttl <= 0 {
|
|
|
return
|
|
|
}
|
|
|
revoked := "false"
|
|
|
if session.Revoked {
|
|
|
revoked = "true"
|
|
|
}
|
|
|
pipe := rds.Pipeline()
|
|
|
pipe.HSet(context.Background(), sessionKey(session.ID),
|
|
|
"version", session.Version,
|
|
|
"revoked", revoked,
|
|
|
"user_id", session.UserID,
|
|
|
)
|
|
|
pipe.Expire(context.Background(), sessionKey(session.ID), ttl)
|
|
|
pipe.SAdd(context.Background(), userSessionsKey(session.UserID), session.ID)
|
|
|
pipe.Expire(context.Background(), userSessionsKey(session.UserID), ttl)
|
|
|
if _, err := pipe.Exec(context.Background()); err != nil {
|
|
|
logv.Warn().Msgf("fillSessionCache: redis pipeline failed: %v", err)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func validatePermission(code string, level int) error {
|
|
|
if code == "*" {
|
|
|
if level != LevelAdmin {
|
|
|
return fmt.Errorf("wildcard * requires LevelAdmin")
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
parts := strings.Split(code, ":")
|
|
|
depth := len(parts)
|
|
|
|
|
|
if level == LevelCreate {
|
|
|
if depth%2 == 0 {
|
|
|
return fmt.Errorf("LevelCreate requires odd depth (resource type), got %d for %s", depth, code)
|
|
|
}
|
|
|
} else {
|
|
|
// Level 2, 4, 6, 7
|
|
|
if depth%2 != 0 {
|
|
|
return fmt.Errorf("Level %d requires even depth (resource instance), got %d for %s", level, depth, code)
|
|
|
}
|
|
|
}
|
|
|
return nil
|
|
|
}
|