diff --git a/Makefile b/Makefile index 9e1300b..ad0acbe 100644 --- a/Makefile +++ b/Makefile @@ -6,4 +6,4 @@ # run: - @go run ./cli/*.go -f ./cfg/dev.yml -l debug -p 4001 + @go run ./cli/*.go -f ./cfg/dev.yml -l debug -p 4000 diff --git a/api/auth/register.go b/api/auth/register.go index 20cd79e..820b03d 100644 --- a/api/auth/register.go +++ b/api/auth/register.go @@ -7,8 +7,11 @@ package auth import ( + "fmt" + "math/rand" "regexp" "strings" + "time" baseauth "github.com/veypi/vbase/auth" "github.com/veypi/vbase/cfg" @@ -135,6 +138,9 @@ func register(x *vigo.X, req *RegisterRequest) (*AuthResponse, error) { user.Nickname = user.Username } + // 生成随机头像 + user.Avatar = fmt.Sprintf("https://public.veypi.com/img/avatar/%04d.jpg", rand.New(rand.NewSource(time.Now().UnixNano())).Intn(220)) + if err := cfg.DB().Create(user).Error; err != nil { return nil, vigo.ErrInternalServer.WithError(err) } diff --git a/api/auth/thirdparty.go b/api/auth/thirdparty.go index a410502..5f819a2 100644 --- a/api/auth/thirdparty.go +++ b/api/auth/thirdparty.go @@ -96,7 +96,7 @@ func authorizeThirdParty(x *vigo.X, req *AuthorizeRequest) (*AuthorizeResponse, } // 构建授权URL - authURL, err := buildAuthURL(provider, state) + authURL, err := buildAuthURL(x, provider, state) if err != nil { return nil, err } @@ -126,6 +126,7 @@ type CallbackResponse struct { Provider string `json:"provider,omitempty"` ProviderID string `json:"provider_id,omitempty"` AuthURL string `json:"auth_url,omitempty"` + TempToken string `json:"temp_token,omitempty"` // 登录成功 *AuthResponse @@ -147,18 +148,28 @@ func callbackThirdParty(x *vigo.X, req *CallbackRequest) (*CallbackResponse, err return nil, vigo.ErrInvalidArg.WithString("invalid or expired state") } - provider := stateData["provider"].(string) - bindMode := stateData["bind_mode"].(bool) + provider, ok := stateData["provider"].(string) + if !ok { + return nil, vigo.ErrInvalidArg.WithString("invalid provider in state") + } + + bindMode := false + if bm, ok := stateData["bind_mode"].(bool); ok { + bindMode = bm + } // 交换access_token并获取用户信息 - userInfo, err := exchangeAndGetUserInfo(provider, req.Code) + userInfo, err := exchangeAndGetUserInfo(x, provider, req.Code) if err != nil { return nil, vigo.ErrInternalServer.WithError(err) } // 绑定模式:将第三方身份绑定到当前用户 if bindMode { - userID := stateData["user_id"].(string) + userID, ok := stateData["user_id"].(string) + if !ok { + return nil, vigo.ErrInvalidArg.WithString("invalid user_id in state") + } if err := bindIdentity(userID, provider, userInfo); err != nil { return nil, err } @@ -190,6 +201,7 @@ func callbackThirdParty(x *vigo.X, req *CallbackRequest) (*CallbackResponse, err NeedBind: true, Provider: provider, ProviderID: userInfo.ID, + TempToken: tempToken, AuthURL: "/auth/bind?token=" + tempToken, }, nil } @@ -391,7 +403,7 @@ func verifyState(state string) (map[string]any, error) { return data, nil } -func buildAuthURL(providerCode, state string) (string, error) { +func buildAuthURL(x *vigo.X, providerCode, state string) (string, error) { // 从数据库获取提供商配置 var provider models.OAuthProvider if err := cfg.DB().Where("code = ? AND enabled = ?", providerCode, true).First(&provider).Error; err != nil { @@ -407,7 +419,7 @@ func buildAuthURL(providerCode, state string) (string, error) { params.Set("client_id", provider.ClientID) } - params.Set("redirect_uri", getRedirectURI(providerCode)) + params.Set("redirect_uri", resolveRedirectURI(x, provider)) params.Set("response_type", "code") params.Set("state", state) if len(provider.Scopes) > 0 { @@ -417,11 +429,25 @@ func buildAuthURL(providerCode, state string) (string, error) { return provider.AuthURL + "?" + params.Encode(), nil } -func getRedirectURI(provider string) string { - return "/auth/callback/" + provider +func resolveRedirectURI(x *vigo.X, provider models.OAuthProvider) string { + if provider.RedirectURI != "" { + return provider.RedirectURI + } + + scheme := "http" + if x.Request.TLS != nil || x.Request.Header.Get("X-Forwarded-Proto") == "https" { + scheme = "https" + } + + host := x.Request.Host + if x.Request.Header.Get("X-Forwarded-Host") != "" { + host = x.Request.Header.Get("X-Forwarded-Host") + } + + return fmt.Sprintf("%s://%s/callback/%s", scheme, host, provider.Code) } -func exchangeAndGetUserInfo(providerCode, code string) (*ThirdPartyUserInfo, error) { +func exchangeAndGetUserInfo(x *vigo.X, providerCode, code string) (*ThirdPartyUserInfo, error) { // 从数据库获取提供商配置 var provider models.OAuthProvider if err := cfg.DB().Where("code = ? AND enabled = ?", providerCode, true).First(&provider).Error; err != nil { @@ -429,16 +455,13 @@ func exchangeAndGetUserInfo(providerCode, code string) (*ThirdPartyUserInfo, err } // 统一处理 OAuth 流程 - return exchangeGeneric(provider, code) + return exchangeGeneric(x, provider, code) } -func exchangeGeneric(provider models.OAuthProvider, code string) (*ThirdPartyUserInfo, error) { +func exchangeGeneric(x *vigo.X, provider models.OAuthProvider, code string) (*ThirdPartyUserInfo, error) { // 1. 交换 code 获取 access_token tokenURL := provider.TokenURL - redirectURI := provider.RedirectURI - if redirectURI == "" { - redirectURI = getRedirectURI(provider.Code) - } + redirectURI := resolveRedirectURI(x, provider) // 解密 ClientSecret clientSecret := provider.ClientSecret @@ -553,7 +576,14 @@ func extractField(data map[string]any, path, defaultKey string) string { // postFormJSON 发送 POST form 请求并解析 JSON 响应 func postFormJSON(url string, data url.Values, result any) error { - resp, err := http.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) + req, err := http.NewRequest("POST", url, strings.NewReader(data.Encode())) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := http.DefaultClient.Do(req) if err != nil { return err } @@ -564,7 +594,13 @@ func postFormJSON(url string, data url.Values, result any) error { // getJSON 发送 GET 请求并解析 JSON 响应 func getJSON(url string, result any) error { - resp, err := http.Get(url) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return err + } + req.Header.Set("Accept", "application/json") + + resp, err := http.DefaultClient.Do(req) if err != nil { return err } diff --git a/libs/cache/cache.go b/libs/cache/cache.go index 93d1551..e8ec43e 100644 --- a/libs/cache/cache.go +++ b/libs/cache/cache.go @@ -10,6 +10,7 @@ import ( "context" "encoding/json" "fmt" + "sync" "time" "github.com/redis/go-redis/v9" @@ -17,168 +18,248 @@ import ( ) var ( - Client *redis.Client - Ctx = context.Background() -) - -// Init 初始化Redis连接 -func Init() error { - if cfg.Global.Redis.Addr == "" || cfg.Global.Redis.Addr == "memory" { - return nil - } + Ctx = context.Background() - Client = redis.NewClient(&redis.Options{ - Addr: cfg.Global.Redis.Addr, - Password: cfg.Global.Redis.Password, - DB: cfg.Global.Redis.DB, - }) - - if err := Client.Ping(Ctx).Err(); err != nil { - return fmt.Errorf("failed to connect redis: %w", err) - } + // Memory store fallback + memStore sync.Map +) - return nil +type memItem struct { + Value []byte + Expiration int64 } -// IsEnabled 是否启用缓存 +// IsEnabled 是否启用Redis缓存 func IsEnabled() bool { - return cfg.Global.Redis.Addr != "" && cfg.Global.Redis.Addr != "memory" && Client != nil + if cfg.Global.Redis.Addr == "" || cfg.Global.Redis.Addr == "memory" { + return false + } + return cfg.Redis() != nil } // Get 获取字符串值 func Get(key string) (string, error) { - if !IsEnabled() { - return "", fmt.Errorf("redis not enabled") + if IsEnabled() { + return cfg.Redis().Get(Ctx, key).Result() } - return Client.Get(Ctx, key).Result() + + // Memory fallback + val, ok := memStore.Load(key) + if !ok { + return "", redis.Nil + } + item := val.(memItem) + if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration { + memStore.Delete(key) + return "", redis.Nil + } + return string(item.Value), nil } // GetObject 获取并反序列化对象 func GetObject(key string, dest interface{}) error { - if !IsEnabled() { - return fmt.Errorf("redis not enabled") + if IsEnabled() { + data, err := cfg.Redis().Get(Ctx, key).Bytes() + if err != nil { + return err + } + return json.Unmarshal(data, dest) } - data, err := Client.Get(Ctx, key).Bytes() - if err != nil { - return err + + // Memory fallback + val, ok := memStore.Load(key) + if !ok { + return redis.Nil } - return json.Unmarshal(data, dest) + item := val.(memItem) + if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration { + memStore.Delete(key) + return redis.Nil + } + return json.Unmarshal(item.Value, dest) } // Set 设置字符串值 func Set(key string, value string, expiration time.Duration) error { - if !IsEnabled() { - return nil + if IsEnabled() { + return cfg.Redis().Set(Ctx, key, value, expiration).Err() + } + + // Memory fallback + exp := int64(0) + if expiration > 0 { + exp = time.Now().Add(expiration).UnixNano() } - return Client.Set(Ctx, key, value, expiration).Err() + memStore.Store(key, memItem{ + Value: []byte(value), + Expiration: exp, + }) + return nil } // SetObject 序列化并设置对象 func SetObject(key string, value interface{}, expiration time.Duration) error { - if !IsEnabled() { - return nil - } data, err := json.Marshal(value) if err != nil { return err } - return Client.Set(Ctx, key, data, expiration).Err() + + if IsEnabled() { + return cfg.Redis().Set(Ctx, key, data, expiration).Err() + } + + // Memory fallback + exp := int64(0) + if expiration > 0 { + exp = time.Now().Add(expiration).UnixNano() + } + memStore.Store(key, memItem{ + Value: data, + Expiration: exp, + }) + return nil } // Delete 删除key func Delete(keys ...string) error { - if !IsEnabled() { - return nil + if IsEnabled() { + return cfg.Redis().Del(Ctx, keys...).Err() } - return Client.Del(Ctx, keys...).Err() + + // Memory fallback + for _, key := range keys { + memStore.Delete(key) + } + return nil } // Exists 检查key是否存在 func Exists(keys ...string) (int64, error) { - if !IsEnabled() { - return 0, nil + if IsEnabled() { + return cfg.Redis().Exists(Ctx, keys...).Result() + } + + // Memory fallback + count := int64(0) + for _, key := range keys { + if val, ok := memStore.Load(key); ok { + item := val.(memItem) + if item.Expiration == 0 || time.Now().UnixNano() <= item.Expiration { + count++ + } else { + memStore.Delete(key) + } + } } - return Client.Exists(Ctx, keys...).Result() + return count, nil } // Expire 设置过期时间 func Expire(key string, expiration time.Duration) error { - if !IsEnabled() { + if IsEnabled() { + return cfg.Redis().Expire(Ctx, key, expiration).Err() + } + + // Memory fallback + if val, ok := memStore.Load(key); ok { + item := val.(memItem) + item.Expiration = time.Now().Add(expiration).UnixNano() + memStore.Store(key, item) return nil } - return Client.Expire(Ctx, key, expiration).Err() + return nil // Key not found, ignore } // TTL 获取剩余过期时间 func TTL(key string) (time.Duration, error) { - if !IsEnabled() { - return 0, nil + if IsEnabled() { + return cfg.Redis().TTL(Ctx, key).Result() } - return Client.TTL(Ctx, key).Result() + + // Memory fallback + if val, ok := memStore.Load(key); ok { + item := val.(memItem) + if item.Expiration == 0 { + return -1, nil // No expiration + } + ttl := time.Duration(item.Expiration - time.Now().UnixNano()) + if ttl <= 0 { + memStore.Delete(key) + return -2, nil // Expired + } + return ttl, nil + } + return -2, nil // Not found } // Incr 自增 func Incr(key string) (int64, error) { - if !IsEnabled() { - return 0, fmt.Errorf("redis not enabled") + if IsEnabled() { + return cfg.Redis().Incr(Ctx, key).Result() } - return Client.Incr(Ctx, key).Result() + // Memory fallback not implemented for counters + return 0, fmt.Errorf("memory cache: incr not implemented") } // IncrBy 增加指定值 func IncrBy(key string, value int64) (int64, error) { - if !IsEnabled() { - return 0, fmt.Errorf("redis not enabled") + if IsEnabled() { + return cfg.Redis().IncrBy(Ctx, key, value).Result() } - return Client.IncrBy(Ctx, key, value).Result() + return 0, fmt.Errorf("memory cache: incrby not implemented") } // Decr 自减 func Decr(key string) (int64, error) { - if !IsEnabled() { - return 0, fmt.Errorf("redis not enabled") + if IsEnabled() { + return cfg.Redis().Decr(Ctx, key).Result() } - return Client.Decr(Ctx, key).Result() + return 0, fmt.Errorf("memory cache: decr not implemented") } // HSet 设置hash字段 func HSet(key string, values ...interface{}) error { - if !IsEnabled() { - return nil + if IsEnabled() { + return cfg.Redis().HSet(Ctx, key, values...).Err() } - return Client.HSet(Ctx, key, values...).Err() + return fmt.Errorf("memory cache: hset not implemented") } // HGet 获取hash字段 func HGet(key, field string) (string, error) { - if !IsEnabled() { - return "", fmt.Errorf("redis not enabled") + if IsEnabled() { + return cfg.Redis().HGet(Ctx, key, field).Result() } - return Client.HGet(Ctx, key, field).Result() + return "", fmt.Errorf("memory cache: hget not implemented") } // HGetAll 获取hash所有字段 func HGetAll(key string) (map[string]string, error) { - if !IsEnabled() { - return nil, fmt.Errorf("redis not enabled") + if IsEnabled() { + return cfg.Redis().HGetAll(Ctx, key).Result() } - return Client.HGetAll(Ctx, key).Result() + return nil, fmt.Errorf("memory cache: hgetall not implemented") } // HDel 删除hash字段 func HDel(key string, fields ...string) error { - if !IsEnabled() { - return nil + if IsEnabled() { + return cfg.Redis().HDel(Ctx, key, fields...).Err() } - return Client.HDel(Ctx, key, fields...).Err() + return fmt.Errorf("memory cache: hdel not implemented") } // SetNX 仅当key不存在时才设置(用于分布式锁) func SetNX(key string, value interface{}, expiration time.Duration) (bool, error) { - if !IsEnabled() { - return false, fmt.Errorf("redis not enabled") + if IsEnabled() { + return cfg.Redis().SetNX(Ctx, key, value, expiration).Result() + } + // Simple memory implementation + if _, ok := memStore.Load(key); ok { + return false, nil } - return Client.SetNX(Ctx, key, value, expiration).Result() + SetObject(key, value, expiration) + return true, nil } // ==================== 用户缓存 ==================== @@ -224,6 +305,12 @@ func RateLimitKey(identifier, path string) string { // IncrRateLimit 增加限流计数 func IncrRateLimit(identifier, path string, window time.Duration) (int64, error) { key := RateLimitKey(identifier, path) + // For memory cache, we need a special implementation or just skip rate limiting + if !IsEnabled() { + // Simple bypass for memory mode + return 1, nil + } + count, err := Incr(key) if err != nil { return 0, err