You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/libs/cache/cache.go

381 lines
8.6 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-03-04 16:08:06
// Distributed under terms of the MIT license.
//
package cache
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/veypi/vbase/cfg"
)
// Client is the shared Redis client used by every helper in this package.
// Init leaves it nil when no Redis address is configured or the address is
// "memory"; in that case IsEnabled reports false.
var Client *redis.Client

// Ctx is the context passed to every Redis command issued by this package.
// It is context.Background(), so commands carry no deadline or cancellation.
var Ctx = context.Background()
// Init connects to the Redis server described by cfg.Global.Redis.
//
// When no address is configured, or the address is "memory", Init does nothing
// and caching stays disabled. Otherwise it creates a client and checks
// connectivity with PING. If the PING fails, the new client is closed and
// Client is reset to nil. This keeps IsEnabled reporting false instead of
// handing later callers a client that cannot reach the server.
func Init() error {
	addr := cfg.Global.Redis.Addr
	if addr == "" || addr == "memory" {
		return nil
	}
	c := redis.NewClient(&redis.Options{
		Addr:     addr,
		Password: cfg.Global.Redis.Password,
		DB:       cfg.Global.Redis.DB,
	})
	if err := c.Ping(Ctx).Err(); err != nil {
		// Release the pool. The PING error is the one worth reporting.
		_ = c.Close()
		Client = nil
		return fmt.Errorf("failed to connect redis: %w", err)
	}
	Client = c
	return nil
}
// IsEnabled reports whether Redis caching is active. Three things must hold:
// an address is configured, the address is not the "memory" placeholder, and
// Init has created Client.
func IsEnabled() bool {
	addr := cfg.Global.Redis.Addr
	if addr == "" || addr == "memory" {
		return false
	}
	return Client != nil
}
// Get returns the string stored at key.
// It fails with "redis not enabled" when caching is disabled. The go-redis
// error is passed through unchanged otherwise, which is redis.Nil for a
// missing key.
func Get(key string) (string, error) {
	if IsEnabled() {
		cmd := Client.Get(Ctx, key)
		return cmd.Result()
	}
	return "", fmt.Errorf("redis not enabled")
}
// GetObject loads the value stored at key and JSON-decodes it into dest.
// dest must be a non-nil pointer, as json.Unmarshal requires.
// It fails with "redis not enabled" when caching is disabled. Read errors,
// such as redis.Nil for a missing key, are returned before any decoding.
func GetObject(key string, dest interface{}) error {
	if !IsEnabled() {
		return fmt.Errorf("redis not enabled")
	}
	raw, err := Client.Get(Ctx, key).Bytes()
	if err == nil {
		err = json.Unmarshal(raw, dest)
	}
	return err
}
// Set stores value at key with the given expiration.
// It silently does nothing when caching is disabled.
func Set(key string, value string, expiration time.Duration) error {
	if IsEnabled() {
		cmd := Client.Set(Ctx, key, value, expiration)
		return cmd.Err()
	}
	return nil
}
// SetObject JSON-encodes value and stores the encoding at key with the given
// expiration. It silently does nothing when caching is disabled. A marshalling
// failure is returned without touching Redis.
func SetObject(key string, value interface{}, expiration time.Duration) error {
	if !IsEnabled() {
		return nil
	}
	payload, err := json.Marshal(value)
	if err == nil {
		err = Client.Set(Ctx, key, payload, expiration).Err()
	}
	return err
}
// Delete removes the given keys.
// It is a no-op when caching is disabled or when no keys are passed. The
// empty case matters because Redis rejects a DEL with no arguments, and
// callers often pass slices that may turn out to be empty.
func Delete(keys ...string) error {
	if !IsEnabled() || len(keys) == 0 {
		return nil
	}
	return Client.Del(Ctx, keys...).Err()
}
// Exists returns how many of the given keys exist. Under Redis EXISTS
// semantics, a key named more than once is counted once per mention.
// It returns 0, nil when caching is disabled or when no keys are passed,
// since Redis rejects an EXISTS with no arguments.
func Exists(keys ...string) (int64, error) {
	if !IsEnabled() || len(keys) == 0 {
		return 0, nil
	}
	return Client.Exists(Ctx, keys...).Result()
}
// Expire sets a time-to-live of expiration on key.
// It silently does nothing when caching is disabled.
func Expire(key string, expiration time.Duration) error {
	if IsEnabled() {
		cmd := Client.Expire(Ctx, key, expiration)
		return cmd.Err()
	}
	return nil
}
// TTL returns the remaining time-to-live of key, exactly as reported by
// go-redis. Special values for missing or persistent keys come through
// unchanged. It returns 0, nil when caching is disabled.
func TTL(key string) (time.Duration, error) {
	if IsEnabled() {
		cmd := Client.TTL(Ctx, key)
		return cmd.Result()
	}
	return 0, nil
}
// Incr atomically adds 1 to the integer at key and returns the new value.
// It fails with "redis not enabled" when caching is disabled.
func Incr(key string) (int64, error) {
	if IsEnabled() {
		cmd := Client.Incr(Ctx, key)
		return cmd.Result()
	}
	return 0, fmt.Errorf("redis not enabled")
}
// IncrBy atomically adds value to the integer at key and returns the new value.
// It fails with "redis not enabled" when caching is disabled.
func IncrBy(key string, value int64) (int64, error) {
	if IsEnabled() {
		cmd := Client.IncrBy(Ctx, key, value)
		return cmd.Result()
	}
	return 0, fmt.Errorf("redis not enabled")
}
// Decr atomically subtracts 1 from the integer at key and returns the new value.
// It fails with "redis not enabled" when caching is disabled.
func Decr(key string) (int64, error) {
	if IsEnabled() {
		cmd := Client.Decr(Ctx, key)
		return cmd.Result()
	}
	return 0, fmt.Errorf("redis not enabled")
}
// HSet writes fields into the hash at key. values is forwarded unchanged to
// go-redis HSet and accepts whatever forms that method accepts.
// It silently does nothing when caching is disabled.
func HSet(key string, values ...interface{}) error {
	if IsEnabled() {
		cmd := Client.HSet(Ctx, key, values...)
		return cmd.Err()
	}
	return nil
}
// HGet returns the value of field in the hash at key.
// It fails with "redis not enabled" when caching is disabled. Otherwise
// go-redis errors are passed through unchanged.
func HGet(key, field string) (string, error) {
	if IsEnabled() {
		cmd := Client.HGet(Ctx, key, field)
		return cmd.Result()
	}
	return "", fmt.Errorf("redis not enabled")
}
// HGetAll returns every field/value pair of the hash at key.
// It returns a nil map plus "redis not enabled" when caching is disabled.
func HGetAll(key string) (map[string]string, error) {
	if IsEnabled() {
		cmd := Client.HGetAll(Ctx, key)
		return cmd.Result()
	}
	return nil, fmt.Errorf("redis not enabled")
}
// HDel removes the given fields from the hash at key.
// It is a no-op when caching is disabled or when no fields are passed, since
// Redis rejects an HDEL that names no fields.
func HDel(key string, fields ...string) error {
	if !IsEnabled() || len(fields) == 0 {
		return nil
	}
	return Client.HDel(Ctx, key, fields...).Err()
}
// SetNX stores value at key with the given expiration, but only if key does
// not already exist. It reports whether the value was written.
// It is intended as a building block for distributed locks. Unlike Set, it
// fails with "redis not enabled" when caching is disabled, so a lock is never
// silently assumed to be held.
func SetNX(key string, value interface{}, expiration time.Duration) (bool, error) {
	if IsEnabled() {
		cmd := Client.SetNX(Ctx, key, value, expiration)
		return cmd.Result()
	}
	return false, fmt.Errorf("redis not enabled")
}
// ==================== 权限缓存相关 ====================
// PermKey builds the cache key for a permission decision.
// With an organization it is "perm:<user>:<org>:<resource>:<action>".
// Without one (orgID == "") the org segment is omitted:
// "perm:<user>:<resource>:<action>".
func PermKey(userID, orgID, resource, action string) string {
	if orgID != "" {
		return "perm:" + userID + ":" + orgID + ":" + resource + ":" + action
	}
	return "perm:" + userID + ":" + resource + ":" + action
}
// SetPermission 缓存权限结果
func SetPermission(userID, orgID, resource, action string, allowed bool, expiration time.Duration) error {
key := PermKey(userID, orgID, resource, action)
value := "deny"
if allowed {
value = "allow"
}
return Set(key, value, expiration)
}
// GetPermission 获取缓存的权限结果
func GetPermission(userID, orgID, resource, action string) (allowed bool, cached bool, err error) {
key := PermKey(userID, orgID, resource, action)
value, err := Get(key)
if err != nil {
if err == redis.Nil {
return false, false, nil
}
return false, false, err
}
return value == "allow", true, nil
}
// DeletePermission 删除权限缓存
func DeletePermission(userID, orgID, resource, action string) error {
key := PermKey(userID, orgID, resource, action)
return Delete(key)
}
// DeleteUserPermissions drops every cached permission decision for userID,
// with or without an organization, by deleting all keys matching
// "perm:<userID>:*". It is a no-op when caching is disabled.
// NOTE(review): userID is inserted into a glob pattern unescaped. This
// assumes IDs never contain glob metacharacters (*, ?, [) — confirm.
func DeleteUserPermissions(userID string) error {
	if IsEnabled() {
		return deleteByPattern("perm:" + userID + ":*")
	}
	return nil
}
// DeleteOrgPermissions drops cached permission decisions scoped to orgID by
// deleting all keys matching "perm:*:<orgID>:*". It is a no-op when caching
// is disabled.
// NOTE(review): glob "*" also matches ':'. The pattern can therefore also hit
// keys where orgID appears in another segment, such as a resource name.
// Over-deleting cache entries only costs extra misses.
func DeleteOrgPermissions(orgID string) error {
	if IsEnabled() {
		return deleteByPattern("perm:*:" + orgID + ":*")
	}
	return nil
}
// deleteByPattern 根据pattern删除key
func deleteByPattern(pattern string) error {
iter := Client.Scan(Ctx, 0, pattern, 0).Iterator()
var keys []string
for iter.Next(Ctx) {
keys = append(keys, iter.Val())
if len(keys) >= 100 {
if err := Delete(keys...); err != nil {
return err
}
keys = keys[:0]
}
}
if err := iter.Err(); err != nil {
return err
}
if len(keys) > 0 {
return Delete(keys...)
}
return nil
}
// ==================== 用户/组织缓存 ====================
// UserKey returns the cache key for a user record: "user:<userID>".
func UserKey(userID string) string {
	return "user:" + userID
}
// OrgKey returns the cache key for an organization record: "org:<orgID>".
func OrgKey(orgID string) string {
	return "org:" + orgID
}
// OrgMemberKey returns the cache key for one user's membership in an
// organization: "org:<orgID>:member:<userID>".
func OrgMemberKey(orgID, userID string) string {
	return "org:" + orgID + ":member:" + userID
}
// ==================== Token黑名单 ====================
// TokenBlacklistKey returns the revocation marker key for a token ID:
// "token:revoked:<jti>".
func TokenBlacklistKey(jti string) string {
	return "token:revoked:" + jti
}
// BlacklistToken 将Token加入黑名单
func BlacklistToken(jti string, expiration time.Duration) error {
key := TokenBlacklistKey(jti)
return Set(key, "1", expiration)
}
// IsTokenBlacklisted 检查Token是否在黑名单中
func IsTokenBlacklisted(jti string) (bool, error) {
key := TokenBlacklistKey(jti)
_, err := Get(key)
if err != nil {
if err == redis.Nil {
return false, nil
}
return false, err
}
return true, nil
}
// ==================== 限流缓存 ====================
// RateLimitKey returns the rate-limit counter key for a caller on a path:
// "ratelimit:<identifier>:<path>".
func RateLimitKey(identifier, path string) string {
	return "ratelimit:" + identifier + ":" + path
}
// IncrRateLimit records one more request from identifier on path and returns
// the updated count for the current window. This is a fixed-window counter.
//
// The first request of a window (count == 1) opens the window by giving the
// counter a TTL of window. When the key expires, counting starts over. If that
// TTL cannot be set, the counter is removed (best effort) and an error is
// returned, because a counter without a TTL would never reset.
//
// NOTE(review): INCR and EXPIRE are still two round trips. A crash between
// them can leave a counter without a TTL. An atomic INCR+EXPIRE, for example a
// Lua script, would close that gap.
func IncrRateLimit(identifier, path string, window time.Duration) (int64, error) {
	key := RateLimitKey(identifier, path)
	count, err := Incr(key)
	if err != nil {
		return 0, err
	}
	if count == 1 {
		if err := Expire(key, window); err != nil {
			// Drop the TTL-less counter so the next request opens a fresh window.
			// The Expire error is the one reported.
			_ = Delete(key)
			return 0, fmt.Errorf("set rate limit window for %q: %w", key, err)
		}
	}
	return count, nil
}
// GetRateLimit 获取当前限流计数
func GetRateLimit(identifier, path string) (int64, error) {
key := RateLimitKey(identifier, path)
count, err := Get(key)
if err != nil {
if err == redis.Nil {
return 0, nil
}
return 0, err
}
var result int64
fmt.Sscanf(count, "%d", &result)
return result, nil
}
// ==================== 验证码缓存 ====================
// CaptchaKey returns the storage key for a captcha: "captcha:<captchaID>".
func CaptchaKey(captchaID string) string {
	return "captcha:" + captchaID
}
// SetCaptcha 存储验证码
func SetCaptcha(captchaID, code string, expiration time.Duration) error {
key := CaptchaKey(captchaID)
return Set(key, code, expiration)
}
// VerifyCaptcha checks code against the stored captcha and consumes it.
//
// The captcha is deleted on every attempt, right or wrong, so each captcha ID
// can be tried only once. It returns false, nil when the captcha does not
// exist (expired or already used) or was consumed by a concurrent
// verification. Errors from the lookup or from the delete are returned, so a
// captcha is never reported valid while it is still stored.
func VerifyCaptcha(captchaID, code string) (bool, error) {
	key := CaptchaKey(captchaID)
	storedCode, err := Get(key)
	if err != nil {
		if err == redis.Nil {
			return false, nil
		}
		return false, err
	}
	// Only the caller whose DEL actually removed the key may succeed. A
	// concurrent request that read the same value sees 0 here and fails,
	// closing the read-then-delete replay window. Get succeeded, so Client is
	// set.
	removed, err := Client.Del(Ctx, key).Result()
	if err != nil {
		return false, fmt.Errorf("consume captcha %q: %w", captchaID, err)
	}
	if removed == 0 {
		return false, nil
	}
	return storedCode == code, nil
}
// ==================== OAuth缓存 ====================
// OAuthCodeKey returns the storage key for an OAuth authorization code:
// "oauth:code:<code>".
func OAuthCodeKey(code string) string {
	return "oauth:code:" + code
}
// OAuthStateKey returns the storage key for an OAuth state value:
// "oauth:state:<state>".
func OAuthStateKey(state string) string {
	return "oauth:state:" + state
}