|
|
package cache
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
"time"
|
|
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
|
"github.com/veypi/vbase/internal/config"
|
|
|
)
|
|
|
|
|
|
var (
|
|
|
Client *redis.Client
|
|
|
Ctx = context.Background()
|
|
|
)
|
|
|
|
|
|
// Init 初始化Redis连接
|
|
|
func Init() error {
|
|
|
if !config.C.Redis.Enabled {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
Client = redis.NewClient(&redis.Options{
|
|
|
Addr: config.C.Redis.Addr,
|
|
|
Password: config.C.Redis.Password,
|
|
|
DB: config.C.Redis.DB,
|
|
|
})
|
|
|
|
|
|
if err := Client.Ping(Ctx).Err(); err != nil {
|
|
|
return fmt.Errorf("failed to connect redis: %w", err)
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// IsEnabled 是否启用缓存
|
|
|
func IsEnabled() bool {
|
|
|
return config.C.Redis.Enabled && Client != nil
|
|
|
}
|
|
|
|
|
|
// Get 获取字符串值
|
|
|
func Get(key string) (string, error) {
|
|
|
if !IsEnabled() {
|
|
|
return "", fmt.Errorf("redis not enabled")
|
|
|
}
|
|
|
return Client.Get(Ctx, key).Result()
|
|
|
}
|
|
|
|
|
|
// GetObject 获取并反序列化对象
|
|
|
func GetObject(key string, dest interface{}) error {
|
|
|
if !IsEnabled() {
|
|
|
return fmt.Errorf("redis not enabled")
|
|
|
}
|
|
|
data, err := Client.Get(Ctx, key).Bytes()
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
return json.Unmarshal(data, dest)
|
|
|
}
|
|
|
|
|
|
// Set 设置字符串值
|
|
|
func Set(key string, value string, expiration time.Duration) error {
|
|
|
if !IsEnabled() {
|
|
|
return nil
|
|
|
}
|
|
|
return Client.Set(Ctx, key, value, expiration).Err()
|
|
|
}
|
|
|
|
|
|
// SetObject 序列化并设置对象
|
|
|
func SetObject(key string, value interface{}, expiration time.Duration) error {
|
|
|
if !IsEnabled() {
|
|
|
return nil
|
|
|
}
|
|
|
data, err := json.Marshal(value)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
return Client.Set(Ctx, key, data, expiration).Err()
|
|
|
}
|
|
|
|
|
|
// Delete 删除key
|
|
|
func Delete(keys ...string) error {
|
|
|
if !IsEnabled() {
|
|
|
return nil
|
|
|
}
|
|
|
return Client.Del(Ctx, keys...).Err()
|
|
|
}
|
|
|
|
|
|
// Exists 检查key是否存在
|
|
|
func Exists(keys ...string) (int64, error) {
|
|
|
if !IsEnabled() {
|
|
|
return 0, nil
|
|
|
}
|
|
|
return Client.Exists(Ctx, keys...).Result()
|
|
|
}
|
|
|
|
|
|
// Expire 设置过期时间
|
|
|
func Expire(key string, expiration time.Duration) error {
|
|
|
if !IsEnabled() {
|
|
|
return nil
|
|
|
}
|
|
|
return Client.Expire(Ctx, key, expiration).Err()
|
|
|
}
|
|
|
|
|
|
// TTL 获取剩余过期时间
|
|
|
func TTL(key string) (time.Duration, error) {
|
|
|
if !IsEnabled() {
|
|
|
return 0, nil
|
|
|
}
|
|
|
return Client.TTL(Ctx, key).Result()
|
|
|
}
|
|
|
|
|
|
// Incr 自增
|
|
|
func Incr(key string) (int64, error) {
|
|
|
if !IsEnabled() {
|
|
|
return 0, fmt.Errorf("redis not enabled")
|
|
|
}
|
|
|
return Client.Incr(Ctx, key).Result()
|
|
|
}
|
|
|
|
|
|
// IncrBy 增加指定值
|
|
|
func IncrBy(key string, value int64) (int64, error) {
|
|
|
if !IsEnabled() {
|
|
|
return 0, fmt.Errorf("redis not enabled")
|
|
|
}
|
|
|
return Client.IncrBy(Ctx, key, value).Result()
|
|
|
}
|
|
|
|
|
|
// Decr 自减
|
|
|
func Decr(key string) (int64, error) {
|
|
|
if !IsEnabled() {
|
|
|
return 0, fmt.Errorf("redis not enabled")
|
|
|
}
|
|
|
return Client.Decr(Ctx, key).Result()
|
|
|
}
|
|
|
|
|
|
// HSet 设置hash字段
|
|
|
func HSet(key string, values ...interface{}) error {
|
|
|
if !IsEnabled() {
|
|
|
return nil
|
|
|
}
|
|
|
return Client.HSet(Ctx, key, values...).Err()
|
|
|
}
|
|
|
|
|
|
// HGet 获取hash字段
|
|
|
func HGet(key, field string) (string, error) {
|
|
|
if !IsEnabled() {
|
|
|
return "", fmt.Errorf("redis not enabled")
|
|
|
}
|
|
|
return Client.HGet(Ctx, key, field).Result()
|
|
|
}
|
|
|
|
|
|
// HGetAll 获取hash所有字段
|
|
|
func HGetAll(key string) (map[string]string, error) {
|
|
|
if !IsEnabled() {
|
|
|
return nil, fmt.Errorf("redis not enabled")
|
|
|
}
|
|
|
return Client.HGetAll(Ctx, key).Result()
|
|
|
}
|
|
|
|
|
|
// HDel 删除hash字段
|
|
|
func HDel(key string, fields ...string) error {
|
|
|
if !IsEnabled() {
|
|
|
return nil
|
|
|
}
|
|
|
return Client.HDel(Ctx, key, fields...).Err()
|
|
|
}
|
|
|
|
|
|
// SetNX 仅当key不存在时才设置(用于分布式锁)
|
|
|
func SetNX(key string, value interface{}, expiration time.Duration) (bool, error) {
|
|
|
if !IsEnabled() {
|
|
|
return false, fmt.Errorf("redis not enabled")
|
|
|
}
|
|
|
return Client.SetNX(Ctx, key, value, expiration).Result()
|
|
|
}
|
|
|
|
|
|
// ==================== Permission cache ====================
|
|
|
|
|
|
// PermKey builds the cache key for a permission decision.
//
// With an empty orgID the key is "perm:<userID>:<resource>:<action>";
// otherwise the organisation is inserted after the user:
// "perm:<userID>:<orgID>:<resource>:<action>".
func PermKey(userID, orgID, resource, action string) string {
	scoped := orgID != ""
	if scoped {
		return fmt.Sprintf("perm:%s:%s:%s:%s", userID, orgID, resource, action)
	}
	return fmt.Sprintf("perm:%s:%s:%s", userID, resource, action)
}
|
|
|
|
|
|
// SetPermission 缓存权限结果
|
|
|
func SetPermission(userID, orgID, resource, action string, allowed bool, expiration time.Duration) error {
|
|
|
key := PermKey(userID, orgID, resource, action)
|
|
|
value := "deny"
|
|
|
if allowed {
|
|
|
value = "allow"
|
|
|
}
|
|
|
return Set(key, value, expiration)
|
|
|
}
|
|
|
|
|
|
// GetPermission 获取缓存的权限结果
|
|
|
func GetPermission(userID, orgID, resource, action string) (allowed bool, cached bool, err error) {
|
|
|
key := PermKey(userID, orgID, resource, action)
|
|
|
value, err := Get(key)
|
|
|
if err != nil {
|
|
|
if err == redis.Nil {
|
|
|
return false, false, nil
|
|
|
}
|
|
|
return false, false, err
|
|
|
}
|
|
|
return value == "allow", true, nil
|
|
|
}
|
|
|
|
|
|
// DeletePermission 删除权限缓存
|
|
|
func DeletePermission(userID, orgID, resource, action string) error {
|
|
|
key := PermKey(userID, orgID, resource, action)
|
|
|
return Delete(key)
|
|
|
}
|
|
|
|
|
|
// DeleteUserPermissions 删除用户的所有权限缓存
|
|
|
func DeleteUserPermissions(userID string) error {
|
|
|
if !IsEnabled() {
|
|
|
return nil
|
|
|
}
|
|
|
pattern := fmt.Sprintf("perm:%s:*", userID)
|
|
|
return deleteByPattern(pattern)
|
|
|
}
|
|
|
|
|
|
// DeleteOrgPermissions 删除组织的所有权限缓存
|
|
|
func DeleteOrgPermissions(orgID string) error {
|
|
|
if !IsEnabled() {
|
|
|
return nil
|
|
|
}
|
|
|
pattern := fmt.Sprintf("perm:*:%s:*", orgID)
|
|
|
return deleteByPattern(pattern)
|
|
|
}
|
|
|
|
|
|
// deleteByPattern 根据pattern删除key
|
|
|
func deleteByPattern(pattern string) error {
|
|
|
iter := Client.Scan(Ctx, 0, pattern, 0).Iterator()
|
|
|
var keys []string
|
|
|
for iter.Next(Ctx) {
|
|
|
keys = append(keys, iter.Val())
|
|
|
if len(keys) >= 100 {
|
|
|
if err := Delete(keys...); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
keys = keys[:0]
|
|
|
}
|
|
|
}
|
|
|
if err := iter.Err(); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
if len(keys) > 0 {
|
|
|
return Delete(keys...)
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// ==================== User / organisation cache ====================
|
|
|
|
|
|
// UserKey returns the cache key for a user: "user:<userID>".
func UserKey(userID string) string {
	const layout = "user:%s"
	return fmt.Sprintf(layout, userID)
}
|
|
|
|
|
|
// OrgKey returns the cache key for an organisation: "org:<orgID>".
func OrgKey(orgID string) string {
	key := fmt.Sprintf("org:%s", orgID)
	return key
}
|
|
|
|
|
|
// OrgMemberKey returns the cache key for an organisation member:
// "org:<orgID>:member:<userID>".
func OrgMemberKey(orgID, userID string) string {
	const layout = "org:%s:member:%s"
	return fmt.Sprintf(layout, orgID, userID)
}
|
|
|
|
|
|
// ==================== Token blacklist ====================
|
|
|
|
|
|
// TokenBlacklistKey returns the blacklist key for a token identifier:
// "token:revoked:<jti>".
func TokenBlacklistKey(jti string) string {
	const layout = "token:revoked:%s"
	return fmt.Sprintf(layout, jti)
}
|
|
|
|
|
|
// BlacklistToken 将Token加入黑名单
|
|
|
func BlacklistToken(jti string, expiration time.Duration) error {
|
|
|
key := TokenBlacklistKey(jti)
|
|
|
return Set(key, "1", expiration)
|
|
|
}
|
|
|
|
|
|
// IsTokenBlacklisted 检查Token是否在黑名单中
|
|
|
func IsTokenBlacklisted(jti string) (bool, error) {
|
|
|
key := TokenBlacklistKey(jti)
|
|
|
_, err := Get(key)
|
|
|
if err != nil {
|
|
|
if err == redis.Nil {
|
|
|
return false, nil
|
|
|
}
|
|
|
return false, err
|
|
|
}
|
|
|
return true, nil
|
|
|
}
|
|
|
|
|
|
// ==================== Rate-limit cache ====================
|
|
|
|
|
|
// RateLimitKey returns the rate-limit counter key:
// "ratelimit:<identifier>:<path>".
func RateLimitKey(identifier, path string) string {
	const layout = "ratelimit:%s:%s"
	return fmt.Sprintf(layout, identifier, path)
}
|
|
|
|
|
|
// IncrRateLimit 增加限流计数
|
|
|
func IncrRateLimit(identifier, path string, window time.Duration) (int64, error) {
|
|
|
key := RateLimitKey(identifier, path)
|
|
|
count, err := Incr(key)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
// 第一次设置过期时间
|
|
|
if count == 1 {
|
|
|
Expire(key, window)
|
|
|
}
|
|
|
return count, nil
|
|
|
}
|
|
|
|
|
|
// GetRateLimit 获取当前限流计数
|
|
|
func GetRateLimit(identifier, path string) (int64, error) {
|
|
|
key := RateLimitKey(identifier, path)
|
|
|
count, err := Get(key)
|
|
|
if err != nil {
|
|
|
if err == redis.Nil {
|
|
|
return 0, nil
|
|
|
}
|
|
|
return 0, err
|
|
|
}
|
|
|
var result int64
|
|
|
fmt.Sscanf(count, "%d", &result)
|
|
|
return result, nil
|
|
|
}
|
|
|
|
|
|
// ==================== Captcha cache ====================
|
|
|
|
|
|
// CaptchaKey returns the cache key for a captcha: "captcha:<captchaID>".
func CaptchaKey(captchaID string) string {
	key := fmt.Sprintf("captcha:%s", captchaID)
	return key
}
|
|
|
|
|
|
// SetCaptcha 存储验证码
|
|
|
func SetCaptcha(captchaID, code string, expiration time.Duration) error {
|
|
|
key := CaptchaKey(captchaID)
|
|
|
return Set(key, code, expiration)
|
|
|
}
|
|
|
|
|
|
// VerifyCaptcha 验证验证码(验证后删除)
|
|
|
func VerifyCaptcha(captchaID, code string) (bool, error) {
|
|
|
key := CaptchaKey(captchaID)
|
|
|
storedCode, err := Get(key)
|
|
|
if err != nil {
|
|
|
if err == redis.Nil {
|
|
|
return false, nil
|
|
|
}
|
|
|
return false, err
|
|
|
}
|
|
|
// 验证后删除
|
|
|
Delete(key)
|
|
|
return storedCode == code, nil
|
|
|
}
|
|
|
|
|
|
// ==================== OAuth cache ====================
|
|
|
|
|
|
// OAuthCodeKey returns the cache key for an OAuth authorization code:
// "oauth:code:<code>".
func OAuthCodeKey(code string) string {
	const layout = "oauth:code:%s"
	return fmt.Sprintf(layout, code)
}
|
|
|
|
|
|
// OAuthStateKey returns the cache key for an OAuth state value:
// "oauth:state:<state>".
func OAuthStateKey(state string) string {
	key := fmt.Sprintf("oauth:state:%s", state)
	return key
}
|