You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/libs/cache/cache.go

379 lines
8.3 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-03-04 16:08:06
// Distributed under terms of the MIT license.
//
package cache
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/redis/go-redis/v9"
"github.com/veypi/vbase/cfg"
)
// Ctx is the context passed to the Redis calls made by this package. It is
// context.Background(), so those calls are never cancelled through it.
var Ctx = context.Background()

// memStore is the in-process fallback store used when Redis is not enabled.
// Keys are cache keys; values are memItem entries.
var memStore sync.Map
// memItem is a single entry in memStore, the in-process fallback used when
// Redis is not enabled.
type memItem struct {
	// Value is the stored payload, e.g. the raw bytes of a string passed to
	// Set or the JSON encoding produced by SetObject.
	Value []byte

	// Expiration is an absolute deadline in Unix nanoseconds; 0 means the
	// entry never expires. Readers treat the entry as expired once
	// time.Now().UnixNano() is greater than this value.
	Expiration int64
}
// IsEnabled reports whether the Redis backend should be used.
//
// An empty address or the literal address "memory" selects the in-process
// fallback. Any other address enables Redis, provided cfg.Redis() returns a
// non-nil client.
func IsEnabled() bool {
	switch cfg.Global.Redis.Addr {
	case "", "memory":
		return false
	default:
		return cfg.Redis() != nil
	}
}
// Get returns the string stored at key.
//
// With Redis enabled the value is read from Redis. Otherwise memStore is
// consulted. A missing or expired entry yields redis.Nil, so callers can test
// for "not found" the same way on both backends. Expired entries are removed
// when encountered.
func Get(key string) (string, error) {
	if !IsEnabled() {
		raw, found := memStore.Load(key)
		if !found {
			return "", redis.Nil
		}
		entry := raw.(memItem)
		if expired := entry.Expiration > 0 && time.Now().UnixNano() > entry.Expiration; expired {
			memStore.Delete(key)
			return "", redis.Nil
		}
		return string(entry.Value), nil
	}
	return cfg.Redis().Get(Ctx, key).Result()
}
// GetObject reads the value stored at key and JSON-decodes it into dest.
//
// Read errors are returned unchanged, including redis.Nil for a missing
// key. In memory mode redis.Nil is also returned for an expired entry, which
// is removed. Otherwise the result of json.Unmarshal is returned.
func GetObject(key string, dest interface{}) error {
	var payload []byte
	if IsEnabled() {
		fetched, err := cfg.Redis().Get(Ctx, key).Bytes()
		if err != nil {
			return err
		}
		payload = fetched
	} else {
		raw, found := memStore.Load(key)
		if !found {
			return redis.Nil
		}
		entry := raw.(memItem)
		if entry.Expiration > 0 && time.Now().UnixNano() > entry.Expiration {
			memStore.Delete(key)
			return redis.Nil
		}
		payload = entry.Value
	}
	return json.Unmarshal(payload, dest)
}
// Set stores value under key. A positive expiration sets a time-to-live;
// zero or negative means the memory entry never expires. For Redis, the
// expiration is passed through to the client unchanged.
func Set(key string, value string, expiration time.Duration) error {
	if !IsEnabled() {
		entry := memItem{Value: []byte(value)}
		if expiration > 0 {
			entry.Expiration = time.Now().Add(expiration).UnixNano()
		}
		memStore.Store(key, entry)
		return nil
	}
	return cfg.Redis().Set(Ctx, key, value, expiration).Err()
}
// SetObject JSON-encodes value and stores the bytes under key.
//
// The encoding error, if any, is returned before any backend is touched.
// Expiration handling is the same as Set.
func SetObject(key string, value interface{}, expiration time.Duration) error {
	encoded, err := json.Marshal(value)
	if err != nil {
		return err
	}
	if !IsEnabled() {
		var deadline int64 // 0 = never expires
		if expiration > 0 {
			deadline = time.Now().Add(expiration).UnixNano()
		}
		memStore.Store(key, memItem{Value: encoded, Expiration: deadline})
		return nil
	}
	return cfg.Redis().Set(Ctx, key, encoded, expiration).Err()
}
// Delete removes the given keys; keys that do not exist are ignored.
//
// Calling Delete with no keys is a no-op on both backends. Without the
// guard, the Redis path would send a bare DEL, which Redis rejects with a
// "wrong number of arguments" error. The memory path already returned nil
// in that case.
func Delete(keys ...string) error {
	if len(keys) == 0 {
		return nil
	}
	if IsEnabled() {
		return cfg.Redis().Del(Ctx, keys...).Err()
	}
	for _, key := range keys {
		memStore.Delete(key)
	}
	return nil
}
// Exists returns how many of the given keys currently exist.
//
// In memory mode, an entry counts when it has no expiration or its deadline
// has not yet passed. Expired entries encountered here are deleted, and each
// occurrence of a key in keys is checked separately.
//
// Calling Exists with no keys returns (0, nil) on both backends. The Redis
// path would otherwise send a bare EXISTS, which Redis rejects as having the
// wrong number of arguments.
func Exists(keys ...string) (int64, error) {
	if len(keys) == 0 {
		return 0, nil
	}
	if IsEnabled() {
		return cfg.Redis().Exists(Ctx, keys...).Result()
	}
	var count int64
	for _, key := range keys {
		val, ok := memStore.Load(key)
		if !ok {
			continue
		}
		item := val.(memItem)
		if item.Expiration == 0 || time.Now().UnixNano() <= item.Expiration {
			count++
			continue
		}
		memStore.Delete(key)
	}
	return count, nil
}
// Expire sets a new time-to-live on an existing key. A missing key is
// silently ignored.
//
// Memory-mode behaviour:
//   - An entry whose deadline has already passed is treated as missing: it
//     is removed, not given a fresh deadline. Previously such an entry
//     could be brought back to life, whereas an expired Redis key no
//     longer exists.
//   - A non-positive expiration deletes the key, matching Redis EXPIRE
//     semantics for non-positive timeouts.
func Expire(key string, expiration time.Duration) error {
	if IsEnabled() {
		return cfg.Redis().Expire(Ctx, key, expiration).Err()
	}
	val, ok := memStore.Load(key)
	if !ok {
		return nil // key not found: nothing to do
	}
	item := val.(memItem)
	now := time.Now()
	if item.Expiration > 0 && now.UnixNano() > item.Expiration {
		// Already expired: drop it, do not resurrect it.
		memStore.Delete(key)
		return nil
	}
	if expiration <= 0 {
		memStore.Delete(key)
		return nil
	}
	item.Expiration = now.Add(expiration).UnixNano()
	memStore.Store(key, item)
	return nil
}
// TTL returns the remaining time-to-live of key.
//
// In memory mode it uses sentinel values intended to mirror Redis TTL
// replies: -1 when the key exists without an expiration, and -2 when the key
// is missing or has already expired. An expired entry is removed.
func TTL(key string) (time.Duration, error) {
	if IsEnabled() {
		return cfg.Redis().TTL(Ctx, key).Result()
	}
	const (
		noExpiry = time.Duration(-1)
		missing  = time.Duration(-2)
	)
	raw, found := memStore.Load(key)
	if !found {
		return missing, nil
	}
	entry := raw.(memItem)
	if entry.Expiration == 0 {
		return noExpiry, nil
	}
	remaining := time.Duration(entry.Expiration - time.Now().UnixNano())
	if remaining > 0 {
		return remaining, nil
	}
	memStore.Delete(key)
	return missing, nil
}
// Incr increments the integer stored at key by one and returns the new
// value. Counters are only supported with Redis; in memory mode an error is
// returned.
func Incr(key string) (int64, error) {
	if !IsEnabled() {
		// The memory fallback has no counter implementation.
		return 0, fmt.Errorf("memory cache: incr not implemented")
	}
	return cfg.Redis().Incr(Ctx, key).Result()
}
// IncrBy adds value to the integer stored at key and returns the new value.
// Only supported with Redis; in memory mode an error is returned.
func IncrBy(key string, value int64) (int64, error) {
	if !IsEnabled() {
		return 0, fmt.Errorf("memory cache: incrby not implemented")
	}
	return cfg.Redis().IncrBy(Ctx, key, value).Result()
}
// Decr decrements the integer stored at key by one and returns the new
// value. Only supported with Redis; in memory mode an error is returned.
func Decr(key string) (int64, error) {
	if !IsEnabled() {
		return 0, fmt.Errorf("memory cache: decr not implemented")
	}
	return cfg.Redis().Decr(Ctx, key).Result()
}
// HSet sets fields on the hash stored at key; values are passed to the
// go-redis client unchanged. Hashes are only supported with Redis; in memory
// mode an error is returned.
func HSet(key string, values ...interface{}) error {
	if !IsEnabled() {
		return fmt.Errorf("memory cache: hset not implemented")
	}
	return cfg.Redis().HSet(Ctx, key, values...).Err()
}
// HGet returns one field of the hash stored at key. Only supported with
// Redis; in memory mode an error is returned.
func HGet(key, field string) (string, error) {
	if !IsEnabled() {
		return "", fmt.Errorf("memory cache: hget not implemented")
	}
	return cfg.Redis().HGet(Ctx, key, field).Result()
}
// HGetAll returns every field of the hash stored at key. Only supported with
// Redis; in memory mode it returns a nil map and an error.
func HGetAll(key string) (map[string]string, error) {
	if !IsEnabled() {
		return nil, fmt.Errorf("memory cache: hgetall not implemented")
	}
	return cfg.Redis().HGetAll(Ctx, key).Result()
}
// HDel removes fields from the hash stored at key. Only supported with
// Redis; in memory mode an error is returned.
func HDel(key string, fields ...string) error {
	if !IsEnabled() {
		return fmt.Errorf("memory cache: hdel not implemented")
	}
	return cfg.Redis().HDel(Ctx, key, fields...).Err()
}
// SetNX stores value under key only if the key does not already exist. It
// returns true when this call performed the write, which makes it usable as
// a simple lock.
//
// Memory-mode behaviour:
//   - An entry whose deadline has passed counts as absent. Previously a
//     stale entry blocked acquisition until some other call happened to
//     read, and so purge, the key.
//   - A fresh key is claimed with sync.Map.LoadOrStore, so two concurrent
//     callers cannot both win it.
//   - string and []byte values are stored verbatim, as go-redis writes them,
//     so Get returns the same text on both backends. Other values are
//     JSON-encoded, and an encoding error is returned instead of being
//     silently dropped.
func SetNX(key string, value interface{}, expiration time.Duration) (bool, error) {
	if IsEnabled() {
		return cfg.Redis().SetNX(Ctx, key, value, expiration).Result()
	}

	var data []byte
	switch v := value.(type) {
	case string:
		data = []byte(v)
	case []byte:
		data = append([]byte(nil), v...) // copy: the caller may reuse its slice
	default:
		encoded, err := json.Marshal(v)
		if err != nil {
			return false, err
		}
		data = encoded
	}
	item := memItem{Value: data}
	if expiration > 0 {
		item.Expiration = time.Now().Add(expiration).UnixNano()
	}

	// At most two attempts: the second only runs after clearing an expired holder.
	for attempt := 0; attempt < 2; attempt++ {
		existing, loaded := memStore.LoadOrStore(key, item)
		if !loaded {
			return true, nil
		}
		held := existing.(memItem)
		if held.Expiration == 0 || time.Now().UnixNano() <= held.Expiration {
			return false, nil
		}
		// NOTE(review): not a true compare-and-delete. memItem holds a []byte,
		// so sync.Map.CompareAndDelete would panic on it. Two callers racing
		// on the same expired entry can therefore both succeed.
		memStore.Delete(key)
	}
	return false, nil
}
// ==================== User cache ====================

// UserKey returns the cache key for the user with the given ID ("user:<id>").
func UserKey(userID string) string {
	const prefix = "user:"
	// Sprint inserts no separator between string operands.
	return fmt.Sprint(prefix, userID)
}
// ==================== Token blacklist ====================

// TokenBlacklistKey returns the cache key that marks the token with the given
// JTI as revoked ("token:revoked:<jti>").
func TokenBlacklistKey(jti string) string {
	const prefix = "token:revoked:"
	return fmt.Sprint(prefix, jti)
}
// BlacklistToken 将Token加入黑名单
func BlacklistToken(jti string, expiration time.Duration) error {
key := TokenBlacklistKey(jti)
return Set(key, "1", expiration)
}
// IsTokenBlacklisted 检查Token是否在黑名单中
func IsTokenBlacklisted(jti string) (bool, error) {
key := TokenBlacklistKey(jti)
_, err := Get(key)
if err != nil {
if err == redis.Nil {
return false, nil
}
return false, err
}
return true, nil
}
// ==================== Rate limiting ====================

// RateLimitKey returns the counter key for identifier on path
// ("ratelimit:<identifier>:<path>").
func RateLimitKey(identifier, path string) string {
	// All operands are strings, so Sprint adds no spaces between them.
	return fmt.Sprint("ratelimit:", identifier, ":", path)
}
// IncrRateLimit 增加限流计数
func IncrRateLimit(identifier, path string, window time.Duration) (int64, error) {
key := RateLimitKey(identifier, path)
// For memory cache, we need a special implementation or just skip rate limiting
if !IsEnabled() {
// Simple bypass for memory mode
return 1, nil
}
count, err := Incr(key)
if err != nil {
return 0, err
}
// 第一次设置过期时间
if count == 1 {
Expire(key, window)
}
return count, nil
}
// GetRateLimit returns the current hit count for identifier on path.
//
// A missing counter (redis.Nil) yields 0. A stored value that cannot be
// parsed as an integer is reported as an error. Previously the parse error
// was discarded and the caller silently received 0, which looks like an
// empty window.
func GetRateLimit(identifier, path string) (int64, error) {
	key := RateLimitKey(identifier, path)
	raw, err := Get(key)
	if err == redis.Nil {
		return 0, nil
	}
	if err != nil {
		return 0, err
	}
	var count int64
	if _, scanErr := fmt.Sscanf(raw, "%d", &count); scanErr != nil {
		return 0, fmt.Errorf("ratelimit: parse counter %q: %w", key, scanErr)
	}
	return count, nil
}
// ==================== Captcha ====================

// CaptchaKey returns the cache key for the captcha with the given ID
// ("captcha:<id>").
func CaptchaKey(captchaID string) string {
	const prefix = "captcha:"
	return fmt.Sprint(prefix, captchaID)
}
// SetCaptcha 存储验证码
func SetCaptcha(captchaID, code string, expiration time.Duration) error {
key := CaptchaKey(captchaID)
return Set(key, code, expiration)
}
// VerifyCaptcha checks code against the captcha stored under captchaID and
// consumes the captcha, so each ID can be redeemed at most once.
//
// A missing or expired captcha yields (false, nil). Once found, the entry is
// removed whether or not the code matches. Lookup or delete failures are
// returned as errors.
func VerifyCaptcha(captchaID, code string) (bool, error) {
	key := CaptchaKey(captchaID)
	storedCode, err := Get(key)
	if err != nil {
		if err == redis.Nil {
			return false, nil
		}
		return false, err
	}

	// Consume the captcha. A plain read-then-delete lets two concurrent
	// requests both read the code before either deletes it. Only the caller
	// whose delete actually removed the key may pass: Redis DEL reports how
	// many keys it removed, and LoadAndDelete reports whether the entry was
	// present.
	if IsEnabled() {
		removed, delErr := cfg.Redis().Del(Ctx, key).Result()
		if delErr != nil {
			return false, delErr
		}
		if removed == 0 {
			return false, nil // already consumed by a concurrent request
		}
	} else if _, found := memStore.LoadAndDelete(key); !found {
		return false, nil // already consumed by a concurrent request
	}
	return storedCode == code, nil
}
// ==================== OAuth ====================

// OAuthCodeKey returns the cache key for an OAuth authorization code
// ("oauth:code:<code>").
func OAuthCodeKey(code string) string {
	const prefix = "oauth:code:"
	return fmt.Sprint(prefix, code)
}
// OAuthStateKey returns the cache key for an OAuth state value
// ("oauth:state:<state>").
func OAuthStateKey(state string) string {
	const prefix = "oauth:state:"
	return fmt.Sprint(prefix, state)
}