feat(auth): Improve OAuth callback and add memory cache fallback

- Add random avatar generation for new users in register
    - Fix OAuth state parsing with type assertions and error handling
    - Add TempToken field to CallbackResponse for bind flow
    - Implement dynamic redirect URI resolution with X-Forwarded headers support
    - Add memory cache fallback when Redis is unavailable
    - Change default port from 4001 to 4000 in Makefile
master
veypi 3 weeks ago
parent 627439bc4d
commit 46f01afc9f

@ -6,4 +6,4 @@
# #
run: run:
@go run ./cli/*.go -f ./cfg/dev.yml -l debug -p 4001 @go run ./cli/*.go -f ./cfg/dev.yml -l debug -p 4000

@ -7,8 +7,11 @@
package auth package auth
import ( import (
"fmt"
"math/rand"
"regexp" "regexp"
"strings" "strings"
"time"
baseauth "github.com/veypi/vbase/auth" baseauth "github.com/veypi/vbase/auth"
"github.com/veypi/vbase/cfg" "github.com/veypi/vbase/cfg"
@ -135,6 +138,9 @@ func register(x *vigo.X, req *RegisterRequest) (*AuthResponse, error) {
user.Nickname = user.Username user.Nickname = user.Username
} }
// 生成随机头像
user.Avatar = fmt.Sprintf("https://public.veypi.com/img/avatar/%04d.jpg", rand.New(rand.NewSource(time.Now().UnixNano())).Intn(220))
if err := cfg.DB().Create(user).Error; err != nil { if err := cfg.DB().Create(user).Error; err != nil {
return nil, vigo.ErrInternalServer.WithError(err) return nil, vigo.ErrInternalServer.WithError(err)
} }

@ -96,7 +96,7 @@ func authorizeThirdParty(x *vigo.X, req *AuthorizeRequest) (*AuthorizeResponse,
} }
// 构建授权URL // 构建授权URL
authURL, err := buildAuthURL(provider, state) authURL, err := buildAuthURL(x, provider, state)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -126,6 +126,7 @@ type CallbackResponse struct {
Provider string `json:"provider,omitempty"` Provider string `json:"provider,omitempty"`
ProviderID string `json:"provider_id,omitempty"` ProviderID string `json:"provider_id,omitempty"`
AuthURL string `json:"auth_url,omitempty"` AuthURL string `json:"auth_url,omitempty"`
TempToken string `json:"temp_token,omitempty"`
// 登录成功 // 登录成功
*AuthResponse *AuthResponse
@ -147,18 +148,28 @@ func callbackThirdParty(x *vigo.X, req *CallbackRequest) (*CallbackResponse, err
return nil, vigo.ErrInvalidArg.WithString("invalid or expired state") return nil, vigo.ErrInvalidArg.WithString("invalid or expired state")
} }
provider := stateData["provider"].(string) provider, ok := stateData["provider"].(string)
bindMode := stateData["bind_mode"].(bool) if !ok {
return nil, vigo.ErrInvalidArg.WithString("invalid provider in state")
}
bindMode := false
if bm, ok := stateData["bind_mode"].(bool); ok {
bindMode = bm
}
// 交换access_token并获取用户信息 // 交换access_token并获取用户信息
userInfo, err := exchangeAndGetUserInfo(provider, req.Code) userInfo, err := exchangeAndGetUserInfo(x, provider, req.Code)
if err != nil { if err != nil {
return nil, vigo.ErrInternalServer.WithError(err) return nil, vigo.ErrInternalServer.WithError(err)
} }
// 绑定模式:将第三方身份绑定到当前用户 // 绑定模式:将第三方身份绑定到当前用户
if bindMode { if bindMode {
userID := stateData["user_id"].(string) userID, ok := stateData["user_id"].(string)
if !ok {
return nil, vigo.ErrInvalidArg.WithString("invalid user_id in state")
}
if err := bindIdentity(userID, provider, userInfo); err != nil { if err := bindIdentity(userID, provider, userInfo); err != nil {
return nil, err return nil, err
} }
@ -190,6 +201,7 @@ func callbackThirdParty(x *vigo.X, req *CallbackRequest) (*CallbackResponse, err
NeedBind: true, NeedBind: true,
Provider: provider, Provider: provider,
ProviderID: userInfo.ID, ProviderID: userInfo.ID,
TempToken: tempToken,
AuthURL: "/auth/bind?token=" + tempToken, AuthURL: "/auth/bind?token=" + tempToken,
}, nil }, nil
} }
@ -391,7 +403,7 @@ func verifyState(state string) (map[string]any, error) {
return data, nil return data, nil
} }
func buildAuthURL(providerCode, state string) (string, error) { func buildAuthURL(x *vigo.X, providerCode, state string) (string, error) {
// 从数据库获取提供商配置 // 从数据库获取提供商配置
var provider models.OAuthProvider var provider models.OAuthProvider
if err := cfg.DB().Where("code = ? AND enabled = ?", providerCode, true).First(&provider).Error; err != nil { if err := cfg.DB().Where("code = ? AND enabled = ?", providerCode, true).First(&provider).Error; err != nil {
@ -407,7 +419,7 @@ func buildAuthURL(providerCode, state string) (string, error) {
params.Set("client_id", provider.ClientID) params.Set("client_id", provider.ClientID)
} }
params.Set("redirect_uri", getRedirectURI(providerCode)) params.Set("redirect_uri", resolveRedirectURI(x, provider))
params.Set("response_type", "code") params.Set("response_type", "code")
params.Set("state", state) params.Set("state", state)
if len(provider.Scopes) > 0 { if len(provider.Scopes) > 0 {
@ -417,11 +429,25 @@ func buildAuthURL(providerCode, state string) (string, error) {
return provider.AuthURL + "?" + params.Encode(), nil return provider.AuthURL + "?" + params.Encode(), nil
} }
func getRedirectURI(provider string) string { func resolveRedirectURI(x *vigo.X, provider models.OAuthProvider) string {
return "/auth/callback/" + provider if provider.RedirectURI != "" {
return provider.RedirectURI
}
scheme := "http"
if x.Request.TLS != nil || x.Request.Header.Get("X-Forwarded-Proto") == "https" {
scheme = "https"
}
host := x.Request.Host
if x.Request.Header.Get("X-Forwarded-Host") != "" {
host = x.Request.Header.Get("X-Forwarded-Host")
}
return fmt.Sprintf("%s://%s/callback/%s", scheme, host, provider.Code)
} }
func exchangeAndGetUserInfo(providerCode, code string) (*ThirdPartyUserInfo, error) { func exchangeAndGetUserInfo(x *vigo.X, providerCode, code string) (*ThirdPartyUserInfo, error) {
// 从数据库获取提供商配置 // 从数据库获取提供商配置
var provider models.OAuthProvider var provider models.OAuthProvider
if err := cfg.DB().Where("code = ? AND enabled = ?", providerCode, true).First(&provider).Error; err != nil { if err := cfg.DB().Where("code = ? AND enabled = ?", providerCode, true).First(&provider).Error; err != nil {
@ -429,16 +455,13 @@ func exchangeAndGetUserInfo(providerCode, code string) (*ThirdPartyUserInfo, err
} }
// 统一处理 OAuth 流程 // 统一处理 OAuth 流程
return exchangeGeneric(provider, code) return exchangeGeneric(x, provider, code)
} }
func exchangeGeneric(provider models.OAuthProvider, code string) (*ThirdPartyUserInfo, error) { func exchangeGeneric(x *vigo.X, provider models.OAuthProvider, code string) (*ThirdPartyUserInfo, error) {
// 1. 交换 code 获取 access_token // 1. 交换 code 获取 access_token
tokenURL := provider.TokenURL tokenURL := provider.TokenURL
redirectURI := provider.RedirectURI redirectURI := resolveRedirectURI(x, provider)
if redirectURI == "" {
redirectURI = getRedirectURI(provider.Code)
}
// 解密 ClientSecret // 解密 ClientSecret
clientSecret := provider.ClientSecret clientSecret := provider.ClientSecret
@ -553,7 +576,14 @@ func extractField(data map[string]any, path, defaultKey string) string {
// postFormJSON 发送 POST form 请求并解析 JSON 响应 // postFormJSON 发送 POST form 请求并解析 JSON 响应
func postFormJSON(url string, data url.Values, result any) error { func postFormJSON(url string, data url.Values, result any) error {
resp, err := http.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) req, err := http.NewRequest("POST", url, strings.NewReader(data.Encode()))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return err return err
} }
@ -564,7 +594,13 @@ func postFormJSON(url string, data url.Values, result any) error {
// getJSON 发送 GET 请求并解析 JSON 响应 // getJSON 发送 GET 请求并解析 JSON 响应
func getJSON(url string, result any) error { func getJSON(url string, result any) error {
resp, err := http.Get(url) req, err := http.NewRequest("GET", url, nil)
if err != nil {
return err
}
req.Header.Set("Accept", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return err return err
} }

231
libs/cache/cache.go vendored

@ -10,6 +10,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"sync"
"time" "time"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
@ -17,168 +18,248 @@ import (
) )
var ( var (
Client *redis.Client Ctx = context.Background()
Ctx = context.Background()
)
// Init 初始化Redis连接
func Init() error {
if cfg.Global.Redis.Addr == "" || cfg.Global.Redis.Addr == "memory" {
return nil
}
Client = redis.NewClient(&redis.Options{ // Memory store fallback
Addr: cfg.Global.Redis.Addr, memStore sync.Map
Password: cfg.Global.Redis.Password, )
DB: cfg.Global.Redis.DB,
})
if err := Client.Ping(Ctx).Err(); err != nil {
return fmt.Errorf("failed to connect redis: %w", err)
}
return nil type memItem struct {
Value []byte
Expiration int64
} }
// IsEnabled 是否启用缓存 // IsEnabled 是否启用Redis缓存
func IsEnabled() bool { func IsEnabled() bool {
return cfg.Global.Redis.Addr != "" && cfg.Global.Redis.Addr != "memory" && Client != nil if cfg.Global.Redis.Addr == "" || cfg.Global.Redis.Addr == "memory" {
return false
}
return cfg.Redis() != nil
} }
// Get 获取字符串值 // Get 获取字符串值
func Get(key string) (string, error) { func Get(key string) (string, error) {
if !IsEnabled() { if IsEnabled() {
return "", fmt.Errorf("redis not enabled") return cfg.Redis().Get(Ctx, key).Result()
} }
return Client.Get(Ctx, key).Result()
// Memory fallback
val, ok := memStore.Load(key)
if !ok {
return "", redis.Nil
}
item := val.(memItem)
if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration {
memStore.Delete(key)
return "", redis.Nil
}
return string(item.Value), nil
} }
// GetObject 获取并反序列化对象 // GetObject 获取并反序列化对象
func GetObject(key string, dest interface{}) error { func GetObject(key string, dest interface{}) error {
if !IsEnabled() { if IsEnabled() {
return fmt.Errorf("redis not enabled") data, err := cfg.Redis().Get(Ctx, key).Bytes()
if err != nil {
return err
}
return json.Unmarshal(data, dest)
} }
data, err := Client.Get(Ctx, key).Bytes()
if err != nil { // Memory fallback
return err val, ok := memStore.Load(key)
if !ok {
return redis.Nil
} }
return json.Unmarshal(data, dest) item := val.(memItem)
if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration {
memStore.Delete(key)
return redis.Nil
}
return json.Unmarshal(item.Value, dest)
} }
// Set 设置字符串值 // Set 设置字符串值
func Set(key string, value string, expiration time.Duration) error { func Set(key string, value string, expiration time.Duration) error {
if !IsEnabled() { if IsEnabled() {
return nil return cfg.Redis().Set(Ctx, key, value, expiration).Err()
}
// Memory fallback
exp := int64(0)
if expiration > 0 {
exp = time.Now().Add(expiration).UnixNano()
} }
return Client.Set(Ctx, key, value, expiration).Err() memStore.Store(key, memItem{
Value: []byte(value),
Expiration: exp,
})
return nil
} }
// SetObject 序列化并设置对象 // SetObject 序列化并设置对象
func SetObject(key string, value interface{}, expiration time.Duration) error { func SetObject(key string, value interface{}, expiration time.Duration) error {
if !IsEnabled() {
return nil
}
data, err := json.Marshal(value) data, err := json.Marshal(value)
if err != nil { if err != nil {
return err return err
} }
return Client.Set(Ctx, key, data, expiration).Err()
if IsEnabled() {
return cfg.Redis().Set(Ctx, key, data, expiration).Err()
}
// Memory fallback
exp := int64(0)
if expiration > 0 {
exp = time.Now().Add(expiration).UnixNano()
}
memStore.Store(key, memItem{
Value: data,
Expiration: exp,
})
return nil
} }
// Delete 删除key // Delete 删除key
func Delete(keys ...string) error { func Delete(keys ...string) error {
if !IsEnabled() { if IsEnabled() {
return nil return cfg.Redis().Del(Ctx, keys...).Err()
} }
return Client.Del(Ctx, keys...).Err()
// Memory fallback
for _, key := range keys {
memStore.Delete(key)
}
return nil
} }
// Exists 检查key是否存在 // Exists 检查key是否存在
func Exists(keys ...string) (int64, error) { func Exists(keys ...string) (int64, error) {
if !IsEnabled() { if IsEnabled() {
return 0, nil return cfg.Redis().Exists(Ctx, keys...).Result()
}
// Memory fallback
count := int64(0)
for _, key := range keys {
if val, ok := memStore.Load(key); ok {
item := val.(memItem)
if item.Expiration == 0 || time.Now().UnixNano() <= item.Expiration {
count++
} else {
memStore.Delete(key)
}
}
} }
return Client.Exists(Ctx, keys...).Result() return count, nil
} }
// Expire 设置过期时间 // Expire 设置过期时间
func Expire(key string, expiration time.Duration) error { func Expire(key string, expiration time.Duration) error {
if !IsEnabled() { if IsEnabled() {
return cfg.Redis().Expire(Ctx, key, expiration).Err()
}
// Memory fallback
if val, ok := memStore.Load(key); ok {
item := val.(memItem)
item.Expiration = time.Now().Add(expiration).UnixNano()
memStore.Store(key, item)
return nil return nil
} }
return Client.Expire(Ctx, key, expiration).Err() return nil // Key not found, ignore
} }
// TTL 获取剩余过期时间 // TTL 获取剩余过期时间
func TTL(key string) (time.Duration, error) { func TTL(key string) (time.Duration, error) {
if !IsEnabled() { if IsEnabled() {
return 0, nil return cfg.Redis().TTL(Ctx, key).Result()
} }
return Client.TTL(Ctx, key).Result()
// Memory fallback
if val, ok := memStore.Load(key); ok {
item := val.(memItem)
if item.Expiration == 0 {
return -1, nil // No expiration
}
ttl := time.Duration(item.Expiration - time.Now().UnixNano())
if ttl <= 0 {
memStore.Delete(key)
return -2, nil // Expired
}
return ttl, nil
}
return -2, nil // Not found
} }
// Incr 自增 // Incr 自增
func Incr(key string) (int64, error) { func Incr(key string) (int64, error) {
if !IsEnabled() { if IsEnabled() {
return 0, fmt.Errorf("redis not enabled") return cfg.Redis().Incr(Ctx, key).Result()
} }
return Client.Incr(Ctx, key).Result() // Memory fallback not implemented for counters
return 0, fmt.Errorf("memory cache: incr not implemented")
} }
// IncrBy 增加指定值 // IncrBy 增加指定值
func IncrBy(key string, value int64) (int64, error) { func IncrBy(key string, value int64) (int64, error) {
if !IsEnabled() { if IsEnabled() {
return 0, fmt.Errorf("redis not enabled") return cfg.Redis().IncrBy(Ctx, key, value).Result()
} }
return Client.IncrBy(Ctx, key, value).Result() return 0, fmt.Errorf("memory cache: incrby not implemented")
} }
// Decr 自减 // Decr 自减
func Decr(key string) (int64, error) { func Decr(key string) (int64, error) {
if !IsEnabled() { if IsEnabled() {
return 0, fmt.Errorf("redis not enabled") return cfg.Redis().Decr(Ctx, key).Result()
} }
return Client.Decr(Ctx, key).Result() return 0, fmt.Errorf("memory cache: decr not implemented")
} }
// HSet 设置hash字段 // HSet 设置hash字段
func HSet(key string, values ...interface{}) error { func HSet(key string, values ...interface{}) error {
if !IsEnabled() { if IsEnabled() {
return nil return cfg.Redis().HSet(Ctx, key, values...).Err()
} }
return Client.HSet(Ctx, key, values...).Err() return fmt.Errorf("memory cache: hset not implemented")
} }
// HGet 获取hash字段 // HGet 获取hash字段
func HGet(key, field string) (string, error) { func HGet(key, field string) (string, error) {
if !IsEnabled() { if IsEnabled() {
return "", fmt.Errorf("redis not enabled") return cfg.Redis().HGet(Ctx, key, field).Result()
} }
return Client.HGet(Ctx, key, field).Result() return "", fmt.Errorf("memory cache: hget not implemented")
} }
// HGetAll 获取hash所有字段 // HGetAll 获取hash所有字段
func HGetAll(key string) (map[string]string, error) { func HGetAll(key string) (map[string]string, error) {
if !IsEnabled() { if IsEnabled() {
return nil, fmt.Errorf("redis not enabled") return cfg.Redis().HGetAll(Ctx, key).Result()
} }
return Client.HGetAll(Ctx, key).Result() return nil, fmt.Errorf("memory cache: hgetall not implemented")
} }
// HDel 删除hash字段 // HDel 删除hash字段
func HDel(key string, fields ...string) error { func HDel(key string, fields ...string) error {
if !IsEnabled() { if IsEnabled() {
return nil return cfg.Redis().HDel(Ctx, key, fields...).Err()
} }
return Client.HDel(Ctx, key, fields...).Err() return fmt.Errorf("memory cache: hdel not implemented")
} }
// SetNX 仅当key不存在时才设置用于分布式锁 // SetNX 仅当key不存在时才设置用于分布式锁
func SetNX(key string, value interface{}, expiration time.Duration) (bool, error) { func SetNX(key string, value interface{}, expiration time.Duration) (bool, error) {
if !IsEnabled() { if IsEnabled() {
return false, fmt.Errorf("redis not enabled") return cfg.Redis().SetNX(Ctx, key, value, expiration).Result()
}
// Simple memory implementation
if _, ok := memStore.Load(key); ok {
return false, nil
} }
return Client.SetNX(Ctx, key, value, expiration).Result() SetObject(key, value, expiration)
return true, nil
} }
// ==================== 用户缓存 ==================== // ==================== 用户缓存 ====================
@ -224,6 +305,12 @@ func RateLimitKey(identifier, path string) string {
// IncrRateLimit 增加限流计数 // IncrRateLimit 增加限流计数
func IncrRateLimit(identifier, path string, window time.Duration) (int64, error) { func IncrRateLimit(identifier, path string, window time.Duration) (int64, error) {
key := RateLimitKey(identifier, path) key := RateLimitKey(identifier, path)
// For memory cache, we need a special implementation or just skip rate limiting
if !IsEnabled() {
// Simple bypass for memory mode
return 1, nil
}
count, err := Incr(key) count, err := Incr(key)
if err != nil { if err != nil {
return 0, err return 0, err

Loading…
Cancel
Save