|
|
//
|
|
|
// Copyright (C) 2024 veypi <i@veypi.com>
|
|
|
// 2025-03-04 16:08:06
|
|
|
// Distributed under terms of the MIT license.
|
|
|
//
|
|
|
|
|
|
package cache
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
"sync"
|
|
|
"time"
|
|
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
|
"github.com/veypi/vbase/cfg"
|
|
|
)
|
|
|
|
|
|
// Ctx is the context passed to every Redis command issued by this package.
var Ctx = context.Background()

// memStore is the in-process fallback store used whenever IsEnabled reports false.
// Keys are cache keys; values are memItem entries.
var memStore sync.Map
|
|
|
|
|
|
// memItem is a single entry of the in-memory fallback store (memStore).
type memItem struct {
	// Value is the stored payload: the raw bytes of a string (Set) or a JSON
	// encoding (SetObject).
	Value []byte
	// Expiration is the absolute expiry instant in Unix nanoseconds; 0 means the
	// entry never expires.
	Expiration int64
}
|
|
|
|
|
|
// IsEnabled 是否启用Redis缓存
|
|
|
func IsEnabled() bool {
|
|
|
if cfg.Global.Redis.Addr == "" || cfg.Global.Redis.Addr == "memory" {
|
|
|
return false
|
|
|
}
|
|
|
return cfg.Redis() != nil
|
|
|
}
|
|
|
|
|
|
// Get 获取字符串值
|
|
|
func Get(key string) (string, error) {
|
|
|
if IsEnabled() {
|
|
|
return cfg.Redis().Get(Ctx, key).Result()
|
|
|
}
|
|
|
|
|
|
// Memory fallback
|
|
|
val, ok := memStore.Load(key)
|
|
|
if !ok {
|
|
|
return "", redis.Nil
|
|
|
}
|
|
|
item := val.(memItem)
|
|
|
if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration {
|
|
|
memStore.Delete(key)
|
|
|
return "", redis.Nil
|
|
|
}
|
|
|
return string(item.Value), nil
|
|
|
}
|
|
|
|
|
|
// GetObject 获取并反序列化对象
|
|
|
func GetObject(key string, dest interface{}) error {
|
|
|
if IsEnabled() {
|
|
|
data, err := cfg.Redis().Get(Ctx, key).Bytes()
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
return json.Unmarshal(data, dest)
|
|
|
}
|
|
|
|
|
|
// Memory fallback
|
|
|
val, ok := memStore.Load(key)
|
|
|
if !ok {
|
|
|
return redis.Nil
|
|
|
}
|
|
|
item := val.(memItem)
|
|
|
if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration {
|
|
|
memStore.Delete(key)
|
|
|
return redis.Nil
|
|
|
}
|
|
|
return json.Unmarshal(item.Value, dest)
|
|
|
}
|
|
|
|
|
|
// Set 设置字符串值
|
|
|
func Set(key string, value string, expiration time.Duration) error {
|
|
|
if IsEnabled() {
|
|
|
return cfg.Redis().Set(Ctx, key, value, expiration).Err()
|
|
|
}
|
|
|
|
|
|
// Memory fallback
|
|
|
exp := int64(0)
|
|
|
if expiration > 0 {
|
|
|
exp = time.Now().Add(expiration).UnixNano()
|
|
|
}
|
|
|
memStore.Store(key, memItem{
|
|
|
Value: []byte(value),
|
|
|
Expiration: exp,
|
|
|
})
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// SetObject 序列化并设置对象
|
|
|
func SetObject(key string, value interface{}, expiration time.Duration) error {
|
|
|
data, err := json.Marshal(value)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
if IsEnabled() {
|
|
|
return cfg.Redis().Set(Ctx, key, data, expiration).Err()
|
|
|
}
|
|
|
|
|
|
// Memory fallback
|
|
|
exp := int64(0)
|
|
|
if expiration > 0 {
|
|
|
exp = time.Now().Add(expiration).UnixNano()
|
|
|
}
|
|
|
memStore.Store(key, memItem{
|
|
|
Value: data,
|
|
|
Expiration: exp,
|
|
|
})
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// Delete 删除key
|
|
|
func Delete(keys ...string) error {
|
|
|
if IsEnabled() {
|
|
|
return cfg.Redis().Del(Ctx, keys...).Err()
|
|
|
}
|
|
|
|
|
|
// Memory fallback
|
|
|
for _, key := range keys {
|
|
|
memStore.Delete(key)
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// Exists 检查key是否存在
|
|
|
func Exists(keys ...string) (int64, error) {
|
|
|
if IsEnabled() {
|
|
|
return cfg.Redis().Exists(Ctx, keys...).Result()
|
|
|
}
|
|
|
|
|
|
// Memory fallback
|
|
|
count := int64(0)
|
|
|
for _, key := range keys {
|
|
|
if val, ok := memStore.Load(key); ok {
|
|
|
item := val.(memItem)
|
|
|
if item.Expiration == 0 || time.Now().UnixNano() <= item.Expiration {
|
|
|
count++
|
|
|
} else {
|
|
|
memStore.Delete(key)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
return count, nil
|
|
|
}
|
|
|
|
|
|
// Expire 设置过期时间
|
|
|
func Expire(key string, expiration time.Duration) error {
|
|
|
if IsEnabled() {
|
|
|
return cfg.Redis().Expire(Ctx, key, expiration).Err()
|
|
|
}
|
|
|
|
|
|
// Memory fallback
|
|
|
if val, ok := memStore.Load(key); ok {
|
|
|
item := val.(memItem)
|
|
|
item.Expiration = time.Now().Add(expiration).UnixNano()
|
|
|
memStore.Store(key, item)
|
|
|
return nil
|
|
|
}
|
|
|
return nil // Key not found, ignore
|
|
|
}
|
|
|
|
|
|
// TTL 获取剩余过期时间
|
|
|
func TTL(key string) (time.Duration, error) {
|
|
|
if IsEnabled() {
|
|
|
return cfg.Redis().TTL(Ctx, key).Result()
|
|
|
}
|
|
|
|
|
|
// Memory fallback
|
|
|
if val, ok := memStore.Load(key); ok {
|
|
|
item := val.(memItem)
|
|
|
if item.Expiration == 0 {
|
|
|
return -1, nil // No expiration
|
|
|
}
|
|
|
ttl := time.Duration(item.Expiration - time.Now().UnixNano())
|
|
|
if ttl <= 0 {
|
|
|
memStore.Delete(key)
|
|
|
return -2, nil // Expired
|
|
|
}
|
|
|
return ttl, nil
|
|
|
}
|
|
|
return -2, nil // Not found
|
|
|
}
|
|
|
|
|
|
// Incr 自增
|
|
|
func Incr(key string) (int64, error) {
|
|
|
if IsEnabled() {
|
|
|
return cfg.Redis().Incr(Ctx, key).Result()
|
|
|
}
|
|
|
// Memory fallback not implemented for counters
|
|
|
return 0, fmt.Errorf("memory cache: incr not implemented")
|
|
|
}
|
|
|
|
|
|
// IncrBy 增加指定值
|
|
|
func IncrBy(key string, value int64) (int64, error) {
|
|
|
if IsEnabled() {
|
|
|
return cfg.Redis().IncrBy(Ctx, key, value).Result()
|
|
|
}
|
|
|
return 0, fmt.Errorf("memory cache: incrby not implemented")
|
|
|
}
|
|
|
|
|
|
// Decr 自减
|
|
|
func Decr(key string) (int64, error) {
|
|
|
if IsEnabled() {
|
|
|
return cfg.Redis().Decr(Ctx, key).Result()
|
|
|
}
|
|
|
return 0, fmt.Errorf("memory cache: decr not implemented")
|
|
|
}
|
|
|
|
|
|
// HSet 设置hash字段
|
|
|
func HSet(key string, values ...interface{}) error {
|
|
|
if IsEnabled() {
|
|
|
return cfg.Redis().HSet(Ctx, key, values...).Err()
|
|
|
}
|
|
|
return fmt.Errorf("memory cache: hset not implemented")
|
|
|
}
|
|
|
|
|
|
// HGet 获取hash字段
|
|
|
func HGet(key, field string) (string, error) {
|
|
|
if IsEnabled() {
|
|
|
return cfg.Redis().HGet(Ctx, key, field).Result()
|
|
|
}
|
|
|
return "", fmt.Errorf("memory cache: hget not implemented")
|
|
|
}
|
|
|
|
|
|
// HGetAll 获取hash所有字段
|
|
|
func HGetAll(key string) (map[string]string, error) {
|
|
|
if IsEnabled() {
|
|
|
return cfg.Redis().HGetAll(Ctx, key).Result()
|
|
|
}
|
|
|
return nil, fmt.Errorf("memory cache: hgetall not implemented")
|
|
|
}
|
|
|
|
|
|
// HDel 删除hash字段
|
|
|
func HDel(key string, fields ...string) error {
|
|
|
if IsEnabled() {
|
|
|
return cfg.Redis().HDel(Ctx, key, fields...).Err()
|
|
|
}
|
|
|
return fmt.Errorf("memory cache: hdel not implemented")
|
|
|
}
|
|
|
|
|
|
// SetNX 仅当key不存在时才设置(用于分布式锁)
|
|
|
func SetNX(key string, value interface{}, expiration time.Duration) (bool, error) {
|
|
|
if IsEnabled() {
|
|
|
return cfg.Redis().SetNX(Ctx, key, value, expiration).Result()
|
|
|
}
|
|
|
// Simple memory implementation
|
|
|
if _, ok := memStore.Load(key); ok {
|
|
|
return false, nil
|
|
|
}
|
|
|
SetObject(key, value, expiration)
|
|
|
return true, nil
|
|
|
}
|
|
|
|
|
|
// ==================== 用户缓存 ====================
|
|
|
|
|
|
// UserKey returns the cache key for the user with the given ID, formatted as
// "user:<userID>".
func UserKey(userID string) string {
	const format = "user:%s"
	return fmt.Sprintf(format, userID)
}
|
|
|
|
|
|
// ==================== Token黑名单 ====================
|
|
|
|
|
|
// TokenBlacklistKey returns the key that marks the token with the given JWT ID
// (jti) as revoked, formatted as "token:revoked:<jti>".
func TokenBlacklistKey(jti string) string {
	const format = "token:revoked:%s"
	return fmt.Sprintf(format, jti)
}
|
|
|
|
|
|
// BlacklistToken 将Token加入黑名单
|
|
|
func BlacklistToken(jti string, expiration time.Duration) error {
|
|
|
key := TokenBlacklistKey(jti)
|
|
|
return Set(key, "1", expiration)
|
|
|
}
|
|
|
|
|
|
// IsTokenBlacklisted 检查Token是否在黑名单中
|
|
|
func IsTokenBlacklisted(jti string) (bool, error) {
|
|
|
key := TokenBlacklistKey(jti)
|
|
|
_, err := Get(key)
|
|
|
if err != nil {
|
|
|
if err == redis.Nil {
|
|
|
return false, nil
|
|
|
}
|
|
|
return false, err
|
|
|
}
|
|
|
return true, nil
|
|
|
}
|
|
|
|
|
|
// ==================== 限流缓存 ====================
|
|
|
|
|
|
// RateLimitKey returns the rate-limit counter key for identifier on path,
// formatted as "ratelimit:<identifier>:<path>".
func RateLimitKey(identifier, path string) string {
	const format = "ratelimit:%s:%s"
	return fmt.Sprintf(format, identifier, path)
}
|
|
|
|
|
|
// IncrRateLimit 增加限流计数
|
|
|
func IncrRateLimit(identifier, path string, window time.Duration) (int64, error) {
|
|
|
key := RateLimitKey(identifier, path)
|
|
|
// For memory cache, we need a special implementation or just skip rate limiting
|
|
|
if !IsEnabled() {
|
|
|
// Simple bypass for memory mode
|
|
|
return 1, nil
|
|
|
}
|
|
|
|
|
|
count, err := Incr(key)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
// 第一次设置过期时间
|
|
|
if count == 1 {
|
|
|
Expire(key, window)
|
|
|
}
|
|
|
return count, nil
|
|
|
}
|
|
|
|
|
|
// GetRateLimit 获取当前限流计数
|
|
|
func GetRateLimit(identifier, path string) (int64, error) {
|
|
|
key := RateLimitKey(identifier, path)
|
|
|
count, err := Get(key)
|
|
|
if err != nil {
|
|
|
if err == redis.Nil {
|
|
|
return 0, nil
|
|
|
}
|
|
|
return 0, err
|
|
|
}
|
|
|
var result int64
|
|
|
fmt.Sscanf(count, "%d", &result)
|
|
|
return result, nil
|
|
|
}
|
|
|
|
|
|
// ==================== 验证码缓存 ====================
|
|
|
|
|
|
// CaptchaKey returns the cache key for the captcha with the given ID,
// formatted as "captcha:<captchaID>".
func CaptchaKey(captchaID string) string {
	const format = "captcha:%s"
	return fmt.Sprintf(format, captchaID)
}
|
|
|
|
|
|
// SetCaptcha 存储验证码
|
|
|
func SetCaptcha(captchaID, code string, expiration time.Duration) error {
|
|
|
key := CaptchaKey(captchaID)
|
|
|
return Set(key, code, expiration)
|
|
|
}
|
|
|
|
|
|
// VerifyCaptcha 验证验证码(验证后删除)
|
|
|
func VerifyCaptcha(captchaID, code string) (bool, error) {
|
|
|
key := CaptchaKey(captchaID)
|
|
|
storedCode, err := Get(key)
|
|
|
if err != nil {
|
|
|
if err == redis.Nil {
|
|
|
return false, nil
|
|
|
}
|
|
|
return false, err
|
|
|
}
|
|
|
// 验证后删除
|
|
|
Delete(key)
|
|
|
return storedCode == code, nil
|
|
|
}
|
|
|
|
|
|
// ==================== OAuth缓存 ====================
|
|
|
|
|
|
// OAuthCodeKey returns the cache key for an OAuth authorization code,
// formatted as "oauth:code:<code>".
func OAuthCodeKey(code string) string {
	const format = "oauth:code:%s"
	return fmt.Sprintf(format, code)
}
|
|
|
|
|
|
// OAuthStateKey returns the cache key for an OAuth state value, formatted as
// "oauth:state:<state>".
func OAuthStateKey(state string) string {
	const format = "oauth:state:%s"
	return fmt.Sprintf(format, state)
}
|