package cache import ( "context" "encoding/json" "fmt" "time" "github.com/redis/go-redis/v9" "github.com/veypi/vbase/internal/config" ) var ( Client *redis.Client Ctx = context.Background() ) // Init 初始化Redis连接 func Init() error { if !config.C.Redis.Enabled { return nil } Client = redis.NewClient(&redis.Options{ Addr: config.C.Redis.Addr, Password: config.C.Redis.Password, DB: config.C.Redis.DB, }) if err := Client.Ping(Ctx).Err(); err != nil { return fmt.Errorf("failed to connect redis: %w", err) } return nil } // IsEnabled 是否启用缓存 func IsEnabled() bool { return config.C.Redis.Enabled && Client != nil } // Get 获取字符串值 func Get(key string) (string, error) { if !IsEnabled() { return "", fmt.Errorf("redis not enabled") } return Client.Get(Ctx, key).Result() } // GetObject 获取并反序列化对象 func GetObject(key string, dest interface{}) error { if !IsEnabled() { return fmt.Errorf("redis not enabled") } data, err := Client.Get(Ctx, key).Bytes() if err != nil { return err } return json.Unmarshal(data, dest) } // Set 设置字符串值 func Set(key string, value string, expiration time.Duration) error { if !IsEnabled() { return nil } return Client.Set(Ctx, key, value, expiration).Err() } // SetObject 序列化并设置对象 func SetObject(key string, value interface{}, expiration time.Duration) error { if !IsEnabled() { return nil } data, err := json.Marshal(value) if err != nil { return err } return Client.Set(Ctx, key, data, expiration).Err() } // Delete 删除key func Delete(keys ...string) error { if !IsEnabled() { return nil } return Client.Del(Ctx, keys...).Err() } // Exists 检查key是否存在 func Exists(keys ...string) (int64, error) { if !IsEnabled() { return 0, nil } return Client.Exists(Ctx, keys...).Result() } // Expire 设置过期时间 func Expire(key string, expiration time.Duration) error { if !IsEnabled() { return nil } return Client.Expire(Ctx, key, expiration).Err() } // TTL 获取剩余过期时间 func TTL(key string) (time.Duration, error) { if !IsEnabled() { return 0, nil } return Client.TTL(Ctx, key).Result() } // Incr 自增 func Incr(key string) (int64, error) { if !IsEnabled() { return 0, fmt.Errorf("redis not enabled") } return Client.Incr(Ctx, key).Result() } // IncrBy 增加指定值 func IncrBy(key string, value int64) (int64, error) { if !IsEnabled() { return 0, fmt.Errorf("redis not enabled") } return Client.IncrBy(Ctx, key, value).Result() } // Decr 自减 func Decr(key string) (int64, error) { if !IsEnabled() { return 0, fmt.Errorf("redis not enabled") } return Client.Decr(Ctx, key).Result() } // HSet 设置hash字段 func HSet(key string, values ...interface{}) error { if !IsEnabled() { return nil } return Client.HSet(Ctx, key, values...).Err() } // HGet 获取hash字段 func HGet(key, field string) (string, error) { if !IsEnabled() { return "", fmt.Errorf("redis not enabled") } return Client.HGet(Ctx, key, field).Result() } // HGetAll 获取hash所有字段 func HGetAll(key string) (map[string]string, error) { if !IsEnabled() { return nil, fmt.Errorf("redis not enabled") } return Client.HGetAll(Ctx, key).Result() } // HDel 删除hash字段 func HDel(key string, fields ...string) error { if !IsEnabled() { return nil } return Client.HDel(Ctx, key, fields...).Err() } // SetNX 仅当key不存在时才设置(用于分布式锁) func SetNX(key string, value interface{}, expiration time.Duration) (bool, error) { if !IsEnabled() { return false, fmt.Errorf("redis not enabled") } return Client.SetNX(Ctx, key, value, expiration).Result() } // ==================== 权限缓存相关 ==================== // PermKey 生成权限缓存key func PermKey(userID, orgID, resource, action string) string { if orgID == "" { return fmt.Sprintf("perm:%s:%s:%s", userID, resource, action) } return fmt.Sprintf("perm:%s:%s:%s:%s", userID, orgID, resource, action) } // SetPermission 缓存权限结果 func SetPermission(userID, orgID, resource, action string, allowed bool, expiration time.Duration) error { key := PermKey(userID, orgID, resource, action) value := "deny" if allowed { value = "allow" } return Set(key, value, expiration) } // GetPermission 获取缓存的权限结果 func GetPermission(userID, orgID, resource, action string) (allowed bool, cached bool, err error) { key := PermKey(userID, orgID, resource, action) value, err := Get(key) if err != nil { if err == redis.Nil { return false, false, nil } return false, false, err } return value == "allow", true, nil } // DeletePermission 删除权限缓存 func DeletePermission(userID, orgID, resource, action string) error { key := PermKey(userID, orgID, resource, action) return Delete(key) } // DeleteUserPermissions 删除用户的所有权限缓存 func DeleteUserPermissions(userID string) error { if !IsEnabled() { return nil } pattern := fmt.Sprintf("perm:%s:*", userID) return deleteByPattern(pattern) } // DeleteOrgPermissions 删除组织的所有权限缓存 func DeleteOrgPermissions(orgID string) error { if !IsEnabled() { return nil } pattern := fmt.Sprintf("perm:*:%s:*", orgID) return deleteByPattern(pattern) } // deleteByPattern 根据pattern删除key func deleteByPattern(pattern string) error { iter := Client.Scan(Ctx, 0, pattern, 0).Iterator() var keys []string for iter.Next(Ctx) { keys = append(keys, iter.Val()) if len(keys) >= 100 { if err := Delete(keys...); err != nil { return err } keys = keys[:0] } } if err := iter.Err(); err != nil { return err } if len(keys) > 0 { return Delete(keys...) } return nil } // ==================== 用户/组织缓存 ==================== // UserKey 用户缓存key func UserKey(userID string) string { return fmt.Sprintf("user:%s", userID) } // OrgKey 组织缓存key func OrgKey(orgID string) string { return fmt.Sprintf("org:%s", orgID) } // OrgMemberKey 组织成员缓存key func OrgMemberKey(orgID, userID string) string { return fmt.Sprintf("org:%s:member:%s", orgID, userID) } // ==================== Token黑名单 ==================== // TokenBlacklistKey Token黑名单key func TokenBlacklistKey(jti string) string { return fmt.Sprintf("token:revoked:%s", jti) } // BlacklistToken 将Token加入黑名单 func BlacklistToken(jti string, expiration time.Duration) error { key := TokenBlacklistKey(jti) return Set(key, "1", expiration) } // IsTokenBlacklisted 检查Token是否在黑名单中 func IsTokenBlacklisted(jti string) (bool, error) { key := TokenBlacklistKey(jti) _, err := Get(key) if err != nil { if err == redis.Nil { return false, nil } return false, err } return true, nil } // ==================== 限流缓存 ==================== // RateLimitKey 限流key func RateLimitKey(identifier, path string) string { return fmt.Sprintf("ratelimit:%s:%s", identifier, path) } // IncrRateLimit 增加限流计数 func IncrRateLimit(identifier, path string, window time.Duration) (int64, error) { key := RateLimitKey(identifier, path) count, err := Incr(key) if err != nil { return 0, err } // 第一次设置过期时间 if count == 1 { Expire(key, window) } return count, nil } // GetRateLimit 获取当前限流计数 func GetRateLimit(identifier, path string) (int64, error) { key := RateLimitKey(identifier, path) count, err := Get(key) if err != nil { if err == redis.Nil { return 0, nil } return 0, err } var result int64 fmt.Sscanf(count, "%d", &result) return result, nil } // ==================== 验证码缓存 ==================== // CaptchaKey 验证码key func CaptchaKey(captchaID string) string { return fmt.Sprintf("captcha:%s", captchaID) } // SetCaptcha 存储验证码 func SetCaptcha(captchaID, code string, expiration time.Duration) error { key := CaptchaKey(captchaID) return Set(key, code, expiration) } // VerifyCaptcha 验证验证码(验证后删除) func VerifyCaptcha(captchaID, code string) (bool, error) { key := CaptchaKey(captchaID) storedCode, err := Get(key) if err != nil { if err == redis.Nil { return false, nil } return false, err } // 验证后删除 Delete(key) return storedCode == code, nil } // ==================== OAuth缓存 ==================== // OAuthCodeKey OAuth授权码key func OAuthCodeKey(code string) string { return fmt.Sprintf("oauth:code:%s", code) } // OAuthStateKey OAuth state key func OAuthStateKey(state string) string { return fmt.Sprintf("oauth:state:%s", state) }