// // Copyright (C) 2024 veypi // 2025-03-04 16:08:06 // Distributed under terms of the MIT license. // package cache import ( "context" "encoding/json" "fmt" "sync" "time" "github.com/redis/go-redis/v9" "github.com/veypi/vbase/cfg" ) var ( Ctx = context.Background() // Memory store fallback memStore sync.Map ) type memItem struct { Value []byte Expiration int64 } // IsEnabled 是否启用Redis缓存 func IsEnabled() bool { if cfg.Global.Redis.Addr == "" || cfg.Global.Redis.Addr == "memory" { return false } return cfg.Redis() != nil } // Get 获取字符串值 func Get(key string) (string, error) { if IsEnabled() { return cfg.Redis().Get(Ctx, key).Result() } // Memory fallback val, ok := memStore.Load(key) if !ok { return "", redis.Nil } item := val.(memItem) if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration { memStore.Delete(key) return "", redis.Nil } return string(item.Value), nil } // GetObject 获取并反序列化对象 func GetObject(key string, dest interface{}) error { if IsEnabled() { data, err := cfg.Redis().Get(Ctx, key).Bytes() if err != nil { return err } return json.Unmarshal(data, dest) } // Memory fallback val, ok := memStore.Load(key) if !ok { return redis.Nil } item := val.(memItem) if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration { memStore.Delete(key) return redis.Nil } return json.Unmarshal(item.Value, dest) } // Set 设置字符串值 func Set(key string, value string, expiration time.Duration) error { if IsEnabled() { return cfg.Redis().Set(Ctx, key, value, expiration).Err() } // Memory fallback exp := int64(0) if expiration > 0 { exp = time.Now().Add(expiration).UnixNano() } memStore.Store(key, memItem{ Value: []byte(value), Expiration: exp, }) return nil } // SetObject 序列化并设置对象 func SetObject(key string, value interface{}, expiration time.Duration) error { data, err := json.Marshal(value) if err != nil { return err } if IsEnabled() { return cfg.Redis().Set(Ctx, key, data, expiration).Err() } // Memory fallback exp := int64(0) if expiration > 0 { exp = time.Now().Add(expiration).UnixNano() } memStore.Store(key, memItem{ Value: data, Expiration: exp, }) return nil } // Delete 删除key func Delete(keys ...string) error { if IsEnabled() { return cfg.Redis().Del(Ctx, keys...).Err() } // Memory fallback for _, key := range keys { memStore.Delete(key) } return nil } // Exists 检查key是否存在 func Exists(keys ...string) (int64, error) { if IsEnabled() { return cfg.Redis().Exists(Ctx, keys...).Result() } // Memory fallback count := int64(0) for _, key := range keys { if val, ok := memStore.Load(key); ok { item := val.(memItem) if item.Expiration == 0 || time.Now().UnixNano() <= item.Expiration { count++ } else { memStore.Delete(key) } } } return count, nil } // Expire 设置过期时间 func Expire(key string, expiration time.Duration) error { if IsEnabled() { return cfg.Redis().Expire(Ctx, key, expiration).Err() } // Memory fallback if val, ok := memStore.Load(key); ok { item := val.(memItem) item.Expiration = time.Now().Add(expiration).UnixNano() memStore.Store(key, item) return nil } return nil // Key not found, ignore } // TTL 获取剩余过期时间 func TTL(key string) (time.Duration, error) { if IsEnabled() { return cfg.Redis().TTL(Ctx, key).Result() } // Memory fallback if val, ok := memStore.Load(key); ok { item := val.(memItem) if item.Expiration == 0 { return -1, nil // No expiration } ttl := time.Duration(item.Expiration - time.Now().UnixNano()) if ttl <= 0 { memStore.Delete(key) return -2, nil // Expired } return ttl, nil } return -2, nil // Not found } // Incr 自增 func Incr(key string) (int64, error) { if IsEnabled() { return cfg.Redis().Incr(Ctx, key).Result() } // Memory fallback not implemented for counters return 0, fmt.Errorf("memory cache: incr not implemented") } // IncrBy 增加指定值 func IncrBy(key string, value int64) (int64, error) { if IsEnabled() { return cfg.Redis().IncrBy(Ctx, key, value).Result() } return 0, fmt.Errorf("memory cache: incrby not implemented") } // Decr 自减 func Decr(key string) (int64, error) { if IsEnabled() { return cfg.Redis().Decr(Ctx, key).Result() } return 0, fmt.Errorf("memory cache: decr not implemented") } // HSet 设置hash字段 func HSet(key string, values ...interface{}) error { if IsEnabled() { return cfg.Redis().HSet(Ctx, key, values...).Err() } return fmt.Errorf("memory cache: hset not implemented") } // HGet 获取hash字段 func HGet(key, field string) (string, error) { if IsEnabled() { return cfg.Redis().HGet(Ctx, key, field).Result() } return "", fmt.Errorf("memory cache: hget not implemented") } // HGetAll 获取hash所有字段 func HGetAll(key string) (map[string]string, error) { if IsEnabled() { return cfg.Redis().HGetAll(Ctx, key).Result() } return nil, fmt.Errorf("memory cache: hgetall not implemented") } // HDel 删除hash字段 func HDel(key string, fields ...string) error { if IsEnabled() { return cfg.Redis().HDel(Ctx, key, fields...).Err() } return fmt.Errorf("memory cache: hdel not implemented") } // SetNX 仅当key不存在时才设置(用于分布式锁) func SetNX(key string, value interface{}, expiration time.Duration) (bool, error) { if IsEnabled() { return cfg.Redis().SetNX(Ctx, key, value, expiration).Result() } // Simple memory implementation if _, ok := memStore.Load(key); ok { return false, nil } SetObject(key, value, expiration) return true, nil } // ==================== 用户缓存 ==================== // UserKey 用户缓存key func UserKey(userID string) string { return fmt.Sprintf("user:%s", userID) } // ==================== Token黑名单 ==================== // TokenBlacklistKey Token黑名单key func TokenBlacklistKey(jti string) string { return fmt.Sprintf("token:revoked:%s", jti) } // BlacklistToken 将Token加入黑名单 func BlacklistToken(jti string, expiration time.Duration) error { key := TokenBlacklistKey(jti) return Set(key, "1", expiration) } // IsTokenBlacklisted 检查Token是否在黑名单中 func IsTokenBlacklisted(jti string) (bool, error) { key := TokenBlacklistKey(jti) _, err := Get(key) if err != nil { if err == redis.Nil { return false, nil } return false, err } return true, nil } // ==================== 限流缓存 ==================== // RateLimitKey 限流key func RateLimitKey(identifier, path string) string { return fmt.Sprintf("ratelimit:%s:%s", identifier, path) } // IncrRateLimit 增加限流计数 func IncrRateLimit(identifier, path string, window time.Duration) (int64, error) { key := RateLimitKey(identifier, path) // For memory cache, we need a special implementation or just skip rate limiting if !IsEnabled() { // Simple bypass for memory mode return 1, nil } count, err := Incr(key) if err != nil { return 0, err } // 第一次设置过期时间 if count == 1 { Expire(key, window) } return count, nil } // GetRateLimit 获取当前限流计数 func GetRateLimit(identifier, path string) (int64, error) { key := RateLimitKey(identifier, path) count, err := Get(key) if err != nil { if err == redis.Nil { return 0, nil } return 0, err } var result int64 fmt.Sscanf(count, "%d", &result) return result, nil } // ==================== 验证码缓存 ==================== // CaptchaKey 验证码key func CaptchaKey(captchaID string) string { return fmt.Sprintf("captcha:%s", captchaID) } // SetCaptcha 存储验证码 func SetCaptcha(captchaID, code string, expiration time.Duration) error { key := CaptchaKey(captchaID) return Set(key, code, expiration) } // VerifyCaptcha 验证验证码(验证后删除) func VerifyCaptcha(captchaID, code string) (bool, error) { key := CaptchaKey(captchaID) storedCode, err := Get(key) if err != nil { if err == redis.Nil { return false, nil } return false, err } // 验证后删除 Delete(key) return storedCode == code, nil } // ==================== OAuth缓存 ==================== // OAuthCodeKey OAuth授权码key func OAuthCodeKey(code string) string { return fmt.Sprintf("oauth:code:%s", code) } // OAuthStateKey OAuth state key func OAuthStateKey(state string) string { return fmt.Sprintf("oauth:state:%s", state) }