feat(auth): Improve OAuth callback and add memory cache fallback

- Add random avatar generation for new users in register
    - Fix OAuth state parsing with type assertions and error handling
    - Add TempToken field to CallbackResponse for bind flow
    - Implement dynamic redirect URI resolution with X-Forwarded headers support
    - Add memory cache fallback when Redis is unavailable
    - Change default port from 4001 to 4000 in Makefile
master
veypi 3 weeks ago
parent 627439bc4d
commit 46f01afc9f

@ -6,4 +6,4 @@
#
run:
@go run ./cli/*.go -f ./cfg/dev.yml -l debug -p 4001
@go run ./cli/*.go -f ./cfg/dev.yml -l debug -p 4000

@ -7,8 +7,11 @@
package auth
import (
"fmt"
"math/rand"
"regexp"
"strings"
"time"
baseauth "github.com/veypi/vbase/auth"
"github.com/veypi/vbase/cfg"
@ -135,6 +138,9 @@ func register(x *vigo.X, req *RegisterRequest) (*AuthResponse, error) {
user.Nickname = user.Username
}
// 生成随机头像
user.Avatar = fmt.Sprintf("https://public.veypi.com/img/avatar/%04d.jpg", rand.New(rand.NewSource(time.Now().UnixNano())).Intn(220))
if err := cfg.DB().Create(user).Error; err != nil {
return nil, vigo.ErrInternalServer.WithError(err)
}

@ -96,7 +96,7 @@ func authorizeThirdParty(x *vigo.X, req *AuthorizeRequest) (*AuthorizeResponse,
}
// 构建授权URL
authURL, err := buildAuthURL(provider, state)
authURL, err := buildAuthURL(x, provider, state)
if err != nil {
return nil, err
}
@ -126,6 +126,7 @@ type CallbackResponse struct {
Provider string `json:"provider,omitempty"`
ProviderID string `json:"provider_id,omitempty"`
AuthURL string `json:"auth_url,omitempty"`
TempToken string `json:"temp_token,omitempty"`
// 登录成功
*AuthResponse
@ -147,18 +148,28 @@ func callbackThirdParty(x *vigo.X, req *CallbackRequest) (*CallbackResponse, err
return nil, vigo.ErrInvalidArg.WithString("invalid or expired state")
}
provider := stateData["provider"].(string)
bindMode := stateData["bind_mode"].(bool)
provider, ok := stateData["provider"].(string)
if !ok {
return nil, vigo.ErrInvalidArg.WithString("invalid provider in state")
}
bindMode := false
if bm, ok := stateData["bind_mode"].(bool); ok {
bindMode = bm
}
// 交换access_token并获取用户信息
userInfo, err := exchangeAndGetUserInfo(provider, req.Code)
userInfo, err := exchangeAndGetUserInfo(x, provider, req.Code)
if err != nil {
return nil, vigo.ErrInternalServer.WithError(err)
}
// 绑定模式:将第三方身份绑定到当前用户
if bindMode {
userID := stateData["user_id"].(string)
userID, ok := stateData["user_id"].(string)
if !ok {
return nil, vigo.ErrInvalidArg.WithString("invalid user_id in state")
}
if err := bindIdentity(userID, provider, userInfo); err != nil {
return nil, err
}
@ -190,6 +201,7 @@ func callbackThirdParty(x *vigo.X, req *CallbackRequest) (*CallbackResponse, err
NeedBind: true,
Provider: provider,
ProviderID: userInfo.ID,
TempToken: tempToken,
AuthURL: "/auth/bind?token=" + tempToken,
}, nil
}
@ -391,7 +403,7 @@ func verifyState(state string) (map[string]any, error) {
return data, nil
}
func buildAuthURL(providerCode, state string) (string, error) {
func buildAuthURL(x *vigo.X, providerCode, state string) (string, error) {
// 从数据库获取提供商配置
var provider models.OAuthProvider
if err := cfg.DB().Where("code = ? AND enabled = ?", providerCode, true).First(&provider).Error; err != nil {
@ -407,7 +419,7 @@ func buildAuthURL(providerCode, state string) (string, error) {
params.Set("client_id", provider.ClientID)
}
params.Set("redirect_uri", getRedirectURI(providerCode))
params.Set("redirect_uri", resolveRedirectURI(x, provider))
params.Set("response_type", "code")
params.Set("state", state)
if len(provider.Scopes) > 0 {
@ -417,11 +429,25 @@ func buildAuthURL(providerCode, state string) (string, error) {
return provider.AuthURL + "?" + params.Encode(), nil
}
func getRedirectURI(provider string) string {
return "/auth/callback/" + provider
func resolveRedirectURI(x *vigo.X, provider models.OAuthProvider) string {
if provider.RedirectURI != "" {
return provider.RedirectURI
}
scheme := "http"
if x.Request.TLS != nil || x.Request.Header.Get("X-Forwarded-Proto") == "https" {
scheme = "https"
}
func exchangeAndGetUserInfo(providerCode, code string) (*ThirdPartyUserInfo, error) {
host := x.Request.Host
if x.Request.Header.Get("X-Forwarded-Host") != "" {
host = x.Request.Header.Get("X-Forwarded-Host")
}
return fmt.Sprintf("%s://%s/callback/%s", scheme, host, provider.Code)
}
func exchangeAndGetUserInfo(x *vigo.X, providerCode, code string) (*ThirdPartyUserInfo, error) {
// 从数据库获取提供商配置
var provider models.OAuthProvider
if err := cfg.DB().Where("code = ? AND enabled = ?", providerCode, true).First(&provider).Error; err != nil {
@ -429,16 +455,13 @@ func exchangeAndGetUserInfo(providerCode, code string) (*ThirdPartyUserInfo, err
}
// 统一处理 OAuth 流程
return exchangeGeneric(provider, code)
return exchangeGeneric(x, provider, code)
}
func exchangeGeneric(provider models.OAuthProvider, code string) (*ThirdPartyUserInfo, error) {
func exchangeGeneric(x *vigo.X, provider models.OAuthProvider, code string) (*ThirdPartyUserInfo, error) {
// 1. 交换 code 获取 access_token
tokenURL := provider.TokenURL
redirectURI := provider.RedirectURI
if redirectURI == "" {
redirectURI = getRedirectURI(provider.Code)
}
redirectURI := resolveRedirectURI(x, provider)
// 解密 ClientSecret
clientSecret := provider.ClientSecret
@ -553,7 +576,14 @@ func extractField(data map[string]any, path, defaultKey string) string {
// postFormJSON 发送 POST form 请求并解析 JSON 响应
func postFormJSON(url string, data url.Values, result any) error {
resp, err := http.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
req, err := http.NewRequest("POST", url, strings.NewReader(data.Encode()))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
@ -564,7 +594,13 @@ func postFormJSON(url string, data url.Values, result any) error {
// getJSON 发送 GET 请求并解析 JSON 响应
func getJSON(url string, result any) error {
resp, err := http.Get(url)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return err
}
req.Header.Set("Accept", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}

225
libs/cache/cache.go vendored

@ -10,6 +10,7 @@ import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/redis/go-redis/v9"
@ -17,168 +18,248 @@ import (
)
var (
Client *redis.Client
Ctx = context.Background()
)
// Init 初始化Redis连接
func Init() error {
if cfg.Global.Redis.Addr == "" || cfg.Global.Redis.Addr == "memory" {
return nil
}
Client = redis.NewClient(&redis.Options{
Addr: cfg.Global.Redis.Addr,
Password: cfg.Global.Redis.Password,
DB: cfg.Global.Redis.DB,
})
if err := Client.Ping(Ctx).Err(); err != nil {
return fmt.Errorf("failed to connect redis: %w", err)
}
// Memory store fallback
memStore sync.Map
)
return nil
type memItem struct {
Value []byte
Expiration int64
}
// IsEnabled 是否启用缓存
// IsEnabled 是否启用Redis缓存
func IsEnabled() bool {
return cfg.Global.Redis.Addr != "" && cfg.Global.Redis.Addr != "memory" && Client != nil
if cfg.Global.Redis.Addr == "" || cfg.Global.Redis.Addr == "memory" {
return false
}
return cfg.Redis() != nil
}
// Get 获取字符串值
func Get(key string) (string, error) {
if !IsEnabled() {
return "", fmt.Errorf("redis not enabled")
if IsEnabled() {
return cfg.Redis().Get(Ctx, key).Result()
}
// Memory fallback
val, ok := memStore.Load(key)
if !ok {
return "", redis.Nil
}
return Client.Get(Ctx, key).Result()
item := val.(memItem)
if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration {
memStore.Delete(key)
return "", redis.Nil
}
return string(item.Value), nil
}
// GetObject 获取并反序列化对象
func GetObject(key string, dest interface{}) error {
if !IsEnabled() {
return fmt.Errorf("redis not enabled")
}
data, err := Client.Get(Ctx, key).Bytes()
if IsEnabled() {
data, err := cfg.Redis().Get(Ctx, key).Bytes()
if err != nil {
return err
}
return json.Unmarshal(data, dest)
}
// Memory fallback
val, ok := memStore.Load(key)
if !ok {
return redis.Nil
}
item := val.(memItem)
if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration {
memStore.Delete(key)
return redis.Nil
}
return json.Unmarshal(item.Value, dest)
}
// Set 设置字符串值
func Set(key string, value string, expiration time.Duration) error {
if !IsEnabled() {
return nil
if IsEnabled() {
return cfg.Redis().Set(Ctx, key, value, expiration).Err()
}
return Client.Set(Ctx, key, value, expiration).Err()
// Memory fallback
exp := int64(0)
if expiration > 0 {
exp = time.Now().Add(expiration).UnixNano()
}
memStore.Store(key, memItem{
Value: []byte(value),
Expiration: exp,
})
return nil
}
// SetObject 序列化并设置对象
func SetObject(key string, value interface{}, expiration time.Duration) error {
if !IsEnabled() {
return nil
}
data, err := json.Marshal(value)
if err != nil {
return err
}
return Client.Set(Ctx, key, data, expiration).Err()
if IsEnabled() {
return cfg.Redis().Set(Ctx, key, data, expiration).Err()
}
// Memory fallback
exp := int64(0)
if expiration > 0 {
exp = time.Now().Add(expiration).UnixNano()
}
memStore.Store(key, memItem{
Value: data,
Expiration: exp,
})
return nil
}
// Delete 删除key
func Delete(keys ...string) error {
if !IsEnabled() {
return nil
if IsEnabled() {
return cfg.Redis().Del(Ctx, keys...).Err()
}
return Client.Del(Ctx, keys...).Err()
// Memory fallback
for _, key := range keys {
memStore.Delete(key)
}
return nil
}
// Exists 检查key是否存在
func Exists(keys ...string) (int64, error) {
if !IsEnabled() {
return 0, nil
if IsEnabled() {
return cfg.Redis().Exists(Ctx, keys...).Result()
}
// Memory fallback
count := int64(0)
for _, key := range keys {
if val, ok := memStore.Load(key); ok {
item := val.(memItem)
if item.Expiration == 0 || time.Now().UnixNano() <= item.Expiration {
count++
} else {
memStore.Delete(key)
}
}
return Client.Exists(Ctx, keys...).Result()
}
return count, nil
}
// Expire 设置过期时间
func Expire(key string, expiration time.Duration) error {
if !IsEnabled() {
if IsEnabled() {
return cfg.Redis().Expire(Ctx, key, expiration).Err()
}
// Memory fallback
if val, ok := memStore.Load(key); ok {
item := val.(memItem)
item.Expiration = time.Now().Add(expiration).UnixNano()
memStore.Store(key, item)
return nil
}
return Client.Expire(Ctx, key, expiration).Err()
return nil // Key not found, ignore
}
// TTL 获取剩余过期时间
func TTL(key string) (time.Duration, error) {
if !IsEnabled() {
return 0, nil
if IsEnabled() {
return cfg.Redis().TTL(Ctx, key).Result()
}
// Memory fallback
if val, ok := memStore.Load(key); ok {
item := val.(memItem)
if item.Expiration == 0 {
return -1, nil // No expiration
}
return Client.TTL(Ctx, key).Result()
ttl := time.Duration(item.Expiration - time.Now().UnixNano())
if ttl <= 0 {
memStore.Delete(key)
return -2, nil // Expired
}
return ttl, nil
}
return -2, nil // Not found
}
// Incr 自增
func Incr(key string) (int64, error) {
if !IsEnabled() {
return 0, fmt.Errorf("redis not enabled")
if IsEnabled() {
return cfg.Redis().Incr(Ctx, key).Result()
}
return Client.Incr(Ctx, key).Result()
// Memory fallback not implemented for counters
return 0, fmt.Errorf("memory cache: incr not implemented")
}
// IncrBy 增加指定值
func IncrBy(key string, value int64) (int64, error) {
if !IsEnabled() {
return 0, fmt.Errorf("redis not enabled")
if IsEnabled() {
return cfg.Redis().IncrBy(Ctx, key, value).Result()
}
return Client.IncrBy(Ctx, key, value).Result()
return 0, fmt.Errorf("memory cache: incrby not implemented")
}
// Decr 自减
func Decr(key string) (int64, error) {
if !IsEnabled() {
return 0, fmt.Errorf("redis not enabled")
if IsEnabled() {
return cfg.Redis().Decr(Ctx, key).Result()
}
return Client.Decr(Ctx, key).Result()
return 0, fmt.Errorf("memory cache: decr not implemented")
}
// HSet 设置hash字段
func HSet(key string, values ...interface{}) error {
if !IsEnabled() {
return nil
if IsEnabled() {
return cfg.Redis().HSet(Ctx, key, values...).Err()
}
return Client.HSet(Ctx, key, values...).Err()
return fmt.Errorf("memory cache: hset not implemented")
}
// HGet 获取hash字段
func HGet(key, field string) (string, error) {
if !IsEnabled() {
return "", fmt.Errorf("redis not enabled")
if IsEnabled() {
return cfg.Redis().HGet(Ctx, key, field).Result()
}
return Client.HGet(Ctx, key, field).Result()
return "", fmt.Errorf("memory cache: hget not implemented")
}
// HGetAll 获取hash所有字段
func HGetAll(key string) (map[string]string, error) {
if !IsEnabled() {
return nil, fmt.Errorf("redis not enabled")
if IsEnabled() {
return cfg.Redis().HGetAll(Ctx, key).Result()
}
return Client.HGetAll(Ctx, key).Result()
return nil, fmt.Errorf("memory cache: hgetall not implemented")
}
// HDel 删除hash字段
func HDel(key string, fields ...string) error {
if !IsEnabled() {
return nil
if IsEnabled() {
return cfg.Redis().HDel(Ctx, key, fields...).Err()
}
return Client.HDel(Ctx, key, fields...).Err()
return fmt.Errorf("memory cache: hdel not implemented")
}
// SetNX 仅当key不存在时才设置用于分布式锁
func SetNX(key string, value interface{}, expiration time.Duration) (bool, error) {
if !IsEnabled() {
return false, fmt.Errorf("redis not enabled")
if IsEnabled() {
return cfg.Redis().SetNX(Ctx, key, value, expiration).Result()
}
return Client.SetNX(Ctx, key, value, expiration).Result()
// Simple memory implementation
if _, ok := memStore.Load(key); ok {
return false, nil
}
SetObject(key, value, expiration)
return true, nil
}
// ==================== 用户缓存 ====================
@ -224,6 +305,12 @@ func RateLimitKey(identifier, path string) string {
// IncrRateLimit 增加限流计数
func IncrRateLimit(identifier, path string, window time.Duration) (int64, error) {
key := RateLimitKey(identifier, path)
// For memory cache, we need a special implementation or just skip rate limiting
if !IsEnabled() {
// Simple bypass for memory mode
return 1, nil
}
count, err := Incr(key)
if err != nil {
return 0, err

Loading…
Cancel
Save