// OneAuth/libs/cache/cache.go

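//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-03-04 16:08:06
// Distributed under terms of the MIT license.
//

// Package cache wraps a shared go-redis client with small helpers for plain
// values, JSON objects, hashes and counters, plus key builders for permission,
// user/organization, token-revocation, rate-limit, captcha and OAuth entries.
// When the configured Redis address is empty or "memory", Redis is treated as
// disabled and most helpers either do nothing or return an error.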
package cache
import (
	"context"
	"encoding/json"
	"fmt"
	"time"

	"github.com/redis/go-redis/v9"
	"github.com/veypi/vbase/cfg"
)
// Client is the shared Redis client used by every helper in this package.
// Init assigns it; it stays nil when Redis is disabled (empty address or
// "memory").
var Client *redis.Client

// Ctx is the context passed to every Redis command issued by this package.
var Ctx = context.Background()
// Init 初始化Redis连接
func Init() error {
if cfg.Config.Redis.Addr == "" || cfg.Config.Redis.Addr == "memory" {
2 weeks ago
return nil
}
Client = redis.NewClient(&redis.Options{
Addr: cfg.Config.Redis.Addr,
Password: cfg.Config.Redis.Password,
DB: cfg.Config.Redis.DB,
2 weeks ago
})
if err := Client.Ping(Ctx).Err(); err != nil {
return fmt.Errorf("failed to connect redis: %w", err)
}
return nil
}
// IsEnabled 是否启用缓存
func IsEnabled() bool {
return cfg.Config.Redis.Addr != "" && cfg.Config.Redis.Addr != "memory" && Client != nil
2 weeks ago
}
// Get returns the string stored at key.
// It fails with "redis not enabled" when caching is disabled; otherwise the
// result and error of the Redis GET are returned unchanged (callers in this
// file treat redis.Nil as "key missing").
func Get(key string) (string, error) {
	if IsEnabled() {
		return Client.Get(Ctx, key).Result()
	}
	return "", fmt.Errorf("redis not enabled")
}
// GetObject reads the value stored at key and JSON-decodes it into dest.
// It fails with "redis not enabled" when caching is disabled. Errors from the
// Redis GET and from json.Unmarshal are returned unchanged.
func GetObject(key string, dest interface{}) error {
	if IsEnabled() {
		raw, err := Client.Get(Ctx, key).Bytes()
		if err == nil {
			err = json.Unmarshal(raw, dest)
		}
		return err
	}
	return fmt.Errorf("redis not enabled")
}
// Set stores value at key with the given expiration.
// It is a silent no-op (returning nil) when caching is disabled.
func Set(key string, value string, expiration time.Duration) error {
	var err error
	if IsEnabled() {
		err = Client.Set(Ctx, key, value, expiration).Err()
	}
	return err
}
// SetObject JSON-encodes value and stores the bytes at key with the given
// expiration. It is a silent no-op (returning nil) when caching is disabled;
// in that case value is not even marshalled.
func SetObject(key string, value interface{}, expiration time.Duration) error {
	if !IsEnabled() {
		return nil
	}
	payload, marshalErr := json.Marshal(value)
	if marshalErr != nil {
		return marshalErr
	}
	return Client.Set(Ctx, key, payload, expiration).Err()
}
// Delete removes the given keys.
// It returns nil without contacting Redis when caching is disabled or when no
// keys are passed — a bare DEL with no arguments is rejected by Redis, and
// deleting nothing is trivially successful.
func Delete(keys ...string) error {
	if len(keys) == 0 || !IsEnabled() {
		return nil
	}
	return Client.Del(Ctx, keys...).Err()
}
// Exists returns the count reported by Redis EXISTS for the given keys.
// It returns (0, nil) without contacting Redis when caching is disabled or
// when no keys are passed (EXISTS requires at least one key).
func Exists(keys ...string) (int64, error) {
	if len(keys) == 0 || !IsEnabled() {
		return 0, nil
	}
	return Client.Exists(Ctx, keys...).Result()
}
// Expire sets a time-to-live on key.
// It is a silent no-op (returning nil) when caching is disabled.
func Expire(key string, expiration time.Duration) error {
	if IsEnabled() {
		return Client.Expire(Ctx, key, expiration).Err()
	}
	return nil
}
// TTL returns the remaining time-to-live of key as reported by Redis.
// When caching is disabled it returns 0 and a nil error.
func TTL(key string) (time.Duration, error) {
	if IsEnabled() {
		return Client.TTL(Ctx, key).Result()
	}
	return 0, nil
}
// Incr increments the integer stored at key by one and returns the new value.
// It fails with "redis not enabled" when caching is disabled.
func Incr(key string) (int64, error) {
	if IsEnabled() {
		return Client.Incr(Ctx, key).Result()
	}
	return 0, fmt.Errorf("redis not enabled")
}
// IncrBy adds value to the integer stored at key and returns the new value.
// It fails with "redis not enabled" when caching is disabled.
func IncrBy(key string, value int64) (int64, error) {
	if IsEnabled() {
		return Client.IncrBy(Ctx, key, value).Result()
	}
	return 0, fmt.Errorf("redis not enabled")
}
// Decr decrements the integer stored at key by one and returns the new value.
// It fails with "redis not enabled" when caching is disabled.
func Decr(key string) (int64, error) {
	if IsEnabled() {
		return Client.Decr(Ctx, key).Result()
	}
	return 0, fmt.Errorf("redis not enabled")
}
// HSet writes fields of the hash stored at key; values are forwarded to the
// go-redis HSet call unchanged. It is a silent no-op (returning nil) when
// caching is disabled.
func HSet(key string, values ...interface{}) error {
	if IsEnabled() {
		return Client.HSet(Ctx, key, values...).Err()
	}
	return nil
}
// HGet returns the value of field in the hash stored at key.
// It fails with "redis not enabled" when caching is disabled; otherwise the
// Redis result and error are returned unchanged.
func HGet(key, field string) (string, error) {
	if IsEnabled() {
		return Client.HGet(Ctx, key, field).Result()
	}
	return "", fmt.Errorf("redis not enabled")
}
// HGetAll returns every field/value pair of the hash stored at key.
// It fails with "redis not enabled" (and a nil map) when caching is disabled.
func HGetAll(key string) (map[string]string, error) {
	if IsEnabled() {
		return Client.HGetAll(Ctx, key).Result()
	}
	return nil, fmt.Errorf("redis not enabled")
}
// HDel removes the given fields from the hash stored at key.
// It returns nil without contacting Redis when caching is disabled or when no
// fields are passed — HDEL requires at least one field, and removing nothing
// is trivially successful.
func HDel(key string, fields ...string) error {
	if len(fields) == 0 || !IsEnabled() {
		return nil
	}
	return Client.HDel(Ctx, key, fields...).Err()
}
// SetNX stores value at key only if key does not already exist, and reports
// whether the write happened. Intended as a simple distributed-lock primitive.
// It fails with "redis not enabled" when caching is disabled.
func SetNX(key string, value interface{}, expiration time.Duration) (bool, error) {
	if IsEnabled() {
		return Client.SetNX(Ctx, key, value, expiration).Result()
	}
	return false, fmt.Errorf("redis not enabled")
}
// ==================== Permission cache ====================

// PermKey builds the cache key for a permission-check result:
// "perm:<userID>:<resource>:<action>" when orgID is empty, otherwise
// "perm:<userID>:<orgID>:<resource>:<action>".
func PermKey(userID, orgID, resource, action string) string {
	if orgID != "" {
		return fmt.Sprintf("perm:%s:%s:%s:%s", userID, orgID, resource, action)
	}
	return fmt.Sprintf("perm:%s:%s:%s", userID, resource, action)
}
// SetPermission 缓存权限结果
func SetPermission(userID, orgID, resource, action string, allowed bool, expiration time.Duration) error {
key := PermKey(userID, orgID, resource, action)
value := "deny"
if allowed {
value = "allow"
}
return Set(key, value, expiration)
}
// GetPermission 获取缓存的权限结果
func GetPermission(userID, orgID, resource, action string) (allowed bool, cached bool, err error) {
key := PermKey(userID, orgID, resource, action)
value, err := Get(key)
if err != nil {
if err == redis.Nil {
return false, false, nil
}
return false, false, err
}
return value == "allow", true, nil
}
// DeletePermission 删除权限缓存
func DeletePermission(userID, orgID, resource, action string) error {
key := PermKey(userID, orgID, resource, action)
return Delete(key)
}
// DeleteUserPermissions invalidates every cached permission entry for userID
// by deleting keys matching the glob "perm:<userID>:*". No-op when caching is
// disabled.
//
// NOTE(review): userID is inserted into the glob unescaped, so '*', '?' or '['
// in an ID would widen the match. That causes extra cache invalidation only,
// but confirm IDs cannot contain those characters.
func DeleteUserPermissions(userID string) error {
	if IsEnabled() {
		return deleteByPattern(fmt.Sprintf("perm:%s:*", userID))
	}
	return nil
}
// DeleteOrgPermissions invalidates cached permission entries scoped to orgID
// by deleting keys matching the glob "perm:*:<orgID>:*". No-op when caching is
// disabled.
//
// NOTE(review): Redis glob '*' also matches ':', so this can additionally hit
// keys where orgID appears in a later segment (e.g. as a resource name). It
// only over-invalidates, but the pattern is not an exact org match.
func DeleteOrgPermissions(orgID string) error {
	if IsEnabled() {
		return deleteByPattern(fmt.Sprintf("perm:*:%s:*", orgID))
	}
	return nil
}
// deleteByPattern 根据pattern删除key
func deleteByPattern(pattern string) error {
iter := Client.Scan(Ctx, 0, pattern, 0).Iterator()
var keys []string
for iter.Next(Ctx) {
keys = append(keys, iter.Val())
if len(keys) >= 100 {
if err := Delete(keys...); err != nil {
return err
}
keys = keys[:0]
}
}
if err := iter.Err(); err != nil {
return err
}
if len(keys) > 0 {
return Delete(keys...)
}
return nil
}
// ==================== User / organization cache ====================

// UserKey returns the cache key for a user record: "user:<userID>".
func UserKey(userID string) string {
	// Both operands are strings, so Sprint joins them without a separator.
	return fmt.Sprint("user:", userID)
}
// OrgKey returns the cache key for an organization record: "org:<orgID>".
func OrgKey(orgID string) string {
	// Both operands are strings, so Sprint joins them without a separator.
	return fmt.Sprint("org:", orgID)
}
// OrgMemberKey returns the cache key for one organization membership:
// "org:<orgID>:member:<userID>".
func OrgMemberKey(orgID, userID string) string {
	// All operands are strings, so Sprint concatenates them without spaces.
	return fmt.Sprint("org:", orgID, ":member:", userID)
}
// ==================== Token blacklist ====================

// TokenBlacklistKey returns the key marking a revoked token:
// "token:revoked:<jti>".
func TokenBlacklistKey(jti string) string {
	// Both operands are strings, so Sprint joins them without a separator.
	return fmt.Sprint("token:revoked:", jti)
}
// BlacklistToken 将Token加入黑名单
func BlacklistToken(jti string, expiration time.Duration) error {
key := TokenBlacklistKey(jti)
return Set(key, "1", expiration)
}
// IsTokenBlacklisted 检查Token是否在黑名单中
func IsTokenBlacklisted(jti string) (bool, error) {
key := TokenBlacklistKey(jti)
_, err := Get(key)
if err != nil {
if err == redis.Nil {
return false, nil
}
return false, err
}
return true, nil
}
// ==================== Rate limiting ====================

// RateLimitKey returns the counter key for one caller on one path:
// "ratelimit:<identifier>:<path>".
func RateLimitKey(identifier, path string) string {
	// All operands are strings, so Sprint concatenates them without spaces.
	return fmt.Sprint("ratelimit:", identifier, ":", path)
}
// IncrRateLimit 增加限流计数
func IncrRateLimit(identifier, path string, window time.Duration) (int64, error) {
key := RateLimitKey(identifier, path)
count, err := Incr(key)
if err != nil {
return 0, err
}
// 第一次设置过期时间
if count == 1 {
Expire(key, window)
}
return count, nil
}
// GetRateLimit 获取当前限流计数
func GetRateLimit(identifier, path string) (int64, error) {
key := RateLimitKey(identifier, path)
count, err := Get(key)
if err != nil {
if err == redis.Nil {
return 0, nil
}
return 0, err
}
var result int64
fmt.Sscanf(count, "%d", &result)
return result, nil
}
// ==================== Captcha cache ====================

// CaptchaKey returns the key holding a captcha's expected code:
// "captcha:<captchaID>".
func CaptchaKey(captchaID string) string {
	// Both operands are strings, so Sprint joins them without a separator.
	return fmt.Sprint("captcha:", captchaID)
}
// SetCaptcha 存储验证码
func SetCaptcha(captchaID, code string, expiration time.Duration) error {
key := CaptchaKey(captchaID)
return Set(key, code, expiration)
}
// VerifyCaptcha 验证验证码(验证后删除)
func VerifyCaptcha(captchaID, code string) (bool, error) {
key := CaptchaKey(captchaID)
storedCode, err := Get(key)
if err != nil {
if err == redis.Nil {
return false, nil
}
return false, err
}
// 验证后删除
Delete(key)
return storedCode == code, nil
}
// ==================== OAuth cache ====================

// OAuthCodeKey returns the key for an OAuth authorization code:
// "oauth:code:<code>".
func OAuthCodeKey(code string) string {
	// Both operands are strings, so Sprint joins them without a separator.
	return fmt.Sprint("oauth:code:", code)
}
// OAuthStateKey returns the key for an OAuth state value:
// "oauth:state:<state>".
func OAuthStateKey(state string) string {
	// Both operands are strings, so Sprint joins them without a separator.
	return fmt.Sprint("oauth:state:", state)
}