You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/libs/cache/cache.go

379 lines
8.3 KiB
Go

//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-03-04 16:08:06
// Distributed under terms of the MIT license.
//
1 month ago
package cache
import (
"context"
"encoding/json"
"fmt"
"sync"
1 month ago
"time"
"github.com/redis/go-redis/v9"
"github.com/veypi/vbase/cfg"
1 month ago
)
// Ctx is the context passed to every Redis command issued by this package.
var Ctx = context.Background()

// memStore is the in-process fallback store used whenever Redis is disabled.
// Keys are cache keys (string); values are memItem.
var memStore sync.Map
// memItem is one entry in the in-memory fallback store.
type memItem struct {
	// Value holds the raw stored bytes (a plain string or JSON, depending on
	// which setter wrote it).
	Value []byte
	// Expiration is the absolute deadline in Unix nanoseconds; 0 means the
	// entry never expires.
	Expiration int64
}
// IsEnabled 是否启用Redis缓存
1 month ago
func IsEnabled() bool {
if cfg.Global.Redis.Addr == "" || cfg.Global.Redis.Addr == "memory" {
return false
}
return cfg.Redis() != nil
1 month ago
}
// Get 获取字符串值
func Get(key string) (string, error) {
if IsEnabled() {
return cfg.Redis().Get(Ctx, key).Result()
1 month ago
}
// Memory fallback
val, ok := memStore.Load(key)
if !ok {
return "", redis.Nil
}
item := val.(memItem)
if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration {
memStore.Delete(key)
return "", redis.Nil
}
return string(item.Value), nil
1 month ago
}
// GetObject 获取并反序列化对象
func GetObject(key string, dest interface{}) error {
if IsEnabled() {
data, err := cfg.Redis().Get(Ctx, key).Bytes()
if err != nil {
return err
}
return json.Unmarshal(data, dest)
1 month ago
}
// Memory fallback
val, ok := memStore.Load(key)
if !ok {
return redis.Nil
1 month ago
}
item := val.(memItem)
if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration {
memStore.Delete(key)
return redis.Nil
}
return json.Unmarshal(item.Value, dest)
1 month ago
}
// Set 设置字符串值
func Set(key string, value string, expiration time.Duration) error {
if IsEnabled() {
return cfg.Redis().Set(Ctx, key, value, expiration).Err()
}
// Memory fallback
exp := int64(0)
if expiration > 0 {
exp = time.Now().Add(expiration).UnixNano()
1 month ago
}
memStore.Store(key, memItem{
Value: []byte(value),
Expiration: exp,
})
return nil
1 month ago
}
// SetObject 序列化并设置对象
func SetObject(key string, value interface{}, expiration time.Duration) error {
data, err := json.Marshal(value)
if err != nil {
return err
}
if IsEnabled() {
return cfg.Redis().Set(Ctx, key, data, expiration).Err()
}
// Memory fallback
exp := int64(0)
if expiration > 0 {
exp = time.Now().Add(expiration).UnixNano()
}
memStore.Store(key, memItem{
Value: data,
Expiration: exp,
})
return nil
1 month ago
}
// Delete 删除key
func Delete(keys ...string) error {
if IsEnabled() {
return cfg.Redis().Del(Ctx, keys...).Err()
1 month ago
}
// Memory fallback
for _, key := range keys {
memStore.Delete(key)
}
return nil
1 month ago
}
// Exists 检查key是否存在
func Exists(keys ...string) (int64, error) {
if IsEnabled() {
return cfg.Redis().Exists(Ctx, keys...).Result()
}
// Memory fallback
count := int64(0)
for _, key := range keys {
if val, ok := memStore.Load(key); ok {
item := val.(memItem)
if item.Expiration == 0 || time.Now().UnixNano() <= item.Expiration {
count++
} else {
memStore.Delete(key)
}
}
1 month ago
}
return count, nil
1 month ago
}
// Expire 设置过期时间
func Expire(key string, expiration time.Duration) error {
if IsEnabled() {
return cfg.Redis().Expire(Ctx, key, expiration).Err()
}
// Memory fallback
if val, ok := memStore.Load(key); ok {
item := val.(memItem)
item.Expiration = time.Now().Add(expiration).UnixNano()
memStore.Store(key, item)
1 month ago
return nil
}
return nil // Key not found, ignore
1 month ago
}
// TTL 获取剩余过期时间
func TTL(key string) (time.Duration, error) {
if IsEnabled() {
return cfg.Redis().TTL(Ctx, key).Result()
1 month ago
}
// Memory fallback
if val, ok := memStore.Load(key); ok {
item := val.(memItem)
if item.Expiration == 0 {
return -1, nil // No expiration
}
ttl := time.Duration(item.Expiration - time.Now().UnixNano())
if ttl <= 0 {
memStore.Delete(key)
return -2, nil // Expired
}
return ttl, nil
}
return -2, nil // Not found
1 month ago
}
// Incr 自增
func Incr(key string) (int64, error) {
if IsEnabled() {
return cfg.Redis().Incr(Ctx, key).Result()
1 month ago
}
// Memory fallback not implemented for counters
return 0, fmt.Errorf("memory cache: incr not implemented")
1 month ago
}
// IncrBy 增加指定值
func IncrBy(key string, value int64) (int64, error) {
if IsEnabled() {
return cfg.Redis().IncrBy(Ctx, key, value).Result()
1 month ago
}
return 0, fmt.Errorf("memory cache: incrby not implemented")
1 month ago
}
// Decr 自减
func Decr(key string) (int64, error) {
if IsEnabled() {
return cfg.Redis().Decr(Ctx, key).Result()
1 month ago
}
return 0, fmt.Errorf("memory cache: decr not implemented")
1 month ago
}
// HSet 设置hash字段
func HSet(key string, values ...interface{}) error {
if IsEnabled() {
return cfg.Redis().HSet(Ctx, key, values...).Err()
1 month ago
}
return fmt.Errorf("memory cache: hset not implemented")
1 month ago
}
// HGet 获取hash字段
func HGet(key, field string) (string, error) {
if IsEnabled() {
return cfg.Redis().HGet(Ctx, key, field).Result()
1 month ago
}
return "", fmt.Errorf("memory cache: hget not implemented")
1 month ago
}
// HGetAll 获取hash所有字段
func HGetAll(key string) (map[string]string, error) {
if IsEnabled() {
return cfg.Redis().HGetAll(Ctx, key).Result()
1 month ago
}
return nil, fmt.Errorf("memory cache: hgetall not implemented")
1 month ago
}
// HDel 删除hash字段
func HDel(key string, fields ...string) error {
if IsEnabled() {
return cfg.Redis().HDel(Ctx, key, fields...).Err()
1 month ago
}
return fmt.Errorf("memory cache: hdel not implemented")
1 month ago
}
// SetNX 仅当key不存在时才设置用于分布式锁
func SetNX(key string, value interface{}, expiration time.Duration) (bool, error) {
if IsEnabled() {
return cfg.Redis().SetNX(Ctx, key, value, expiration).Result()
}
// Simple memory implementation
if _, ok := memStore.Load(key); ok {
return false, nil
1 month ago
}
SetObject(key, value, expiration)
return true, nil
1 month ago
}
// ==================== User cache ====================

// UserKey returns the cache key for the user with the given ID: "user:<userID>".
func UserKey(userID string) string {
	return fmt.Sprint("user:", userID)
}
// ==================== Token blacklist ====================

// TokenBlacklistKey returns the key that marks the token identified by jti as
// revoked: "token:revoked:<jti>".
func TokenBlacklistKey(jti string) string {
	return fmt.Sprint("token:revoked:", jti)
}
// BlacklistToken 将Token加入黑名单
func BlacklistToken(jti string, expiration time.Duration) error {
key := TokenBlacklistKey(jti)
return Set(key, "1", expiration)
}
// IsTokenBlacklisted 检查Token是否在黑名单中
func IsTokenBlacklisted(jti string) (bool, error) {
key := TokenBlacklistKey(jti)
_, err := Get(key)
if err != nil {
if err == redis.Nil {
return false, nil
}
return false, err
}
return true, nil
}
// ==================== Rate limiting ====================

// RateLimitKey returns the counter key for identifier on path:
// "ratelimit:<identifier>:<path>".
func RateLimitKey(identifier, path string) string {
	return fmt.Sprint("ratelimit:", identifier, ":", path)
}
// IncrRateLimit 增加限流计数
func IncrRateLimit(identifier, path string, window time.Duration) (int64, error) {
key := RateLimitKey(identifier, path)
// For memory cache, we need a special implementation or just skip rate limiting
if !IsEnabled() {
// Simple bypass for memory mode
return 1, nil
}
1 month ago
count, err := Incr(key)
if err != nil {
return 0, err
}
// 第一次设置过期时间
if count == 1 {
Expire(key, window)
}
return count, nil
}
// GetRateLimit returns the current request count for identifier on path.
//
// A missing counter counts as 0. A stored value that is not an integer is
// reported as an error rather than silently read as 0, which would otherwise
// let a corrupted counter bypass the limit.
func GetRateLimit(identifier, path string) (int64, error) {
	key := RateLimitKey(identifier, path)
	count, err := Get(key)
	if err != nil {
		if err == redis.Nil {
			return 0, nil
		}
		return 0, err
	}
	var result int64
	if _, err := fmt.Sscanf(count, "%d", &result); err != nil {
		return 0, fmt.Errorf("cache: parse rate limit counter %q: %w", key, err)
	}
	return result, nil
}
// ==================== Captcha cache ====================

// CaptchaKey returns the key holding the expected answer for captchaID:
// "captcha:<captchaID>".
func CaptchaKey(captchaID string) string {
	return fmt.Sprint("captcha:", captchaID)
}
// SetCaptcha 存储验证码
func SetCaptcha(captchaID, code string, expiration time.Duration) error {
key := CaptchaKey(captchaID)
return Set(key, code, expiration)
}
// VerifyCaptcha checks code against the stored answer for captchaID and
// consumes the captcha, so each one can be checked only once.
//
// A missing or expired captcha returns (false, nil). The captcha is deleted
// whether or not the guess matches. If that delete fails, the check fails
// closed with an error, because otherwise the captcha would stay valid and
// could be retried.
//
// NOTE(review): Get followed by Delete is not atomic; two concurrent
// verifications of the same captcha can both read it before either deletes it.
// An atomic GETDEL (Redis >= 6.2) would close this window.
func VerifyCaptcha(captchaID, code string) (bool, error) {
	key := CaptchaKey(captchaID)
	storedCode, err := Get(key)
	if err != nil {
		if err == redis.Nil {
			return false, nil
		}
		return false, err
	}
	// Consume after reading, regardless of the outcome.
	if err := Delete(key); err != nil {
		return false, fmt.Errorf("cache: consume captcha %q: %w", captchaID, err)
	}
	return storedCode == code, nil
}
// ==================== OAuth cache ====================

// OAuthCodeKey returns the key for an OAuth authorization code:
// "oauth:code:<code>".
func OAuthCodeKey(code string) string {
	return fmt.Sprint("oauth:code:", code)
}
// OAuthStateKey returns the key for an OAuth state value: "oauth:state:<state>".
func OAuthStateKey(state string) string {
	return fmt.Sprint("oauth:state:", state)
}