package middleware

import (
	"fmt"
	"net/http"
	"time"

	"github.com/veypi/vbase/internal/cache"
	"github.com/veypi/vigo"
)

// RateLimiter returns a middleware that allows at most maxRequests requests
// per remote IP and request path within each window. Requests over the limit
// are rejected with 429 Too Many Requests and a Retry-After header.
//
// When the cache is disabled, or the counter cannot be updated, the request
// is passed on (fail open) so a cache outage does not block all traffic.
func RateLimiter(maxRequests int, window time.Duration) func(*vigo.X) (any, error) {
	return func(x *vigo.X) (any, error) {
		if !cache.IsEnabled() {
			x.Next()
			return nil, nil
		}
		// Identify the client by IP; applyRateLimit adds the request path.
		return applyRateLimit(x, x.GetRemoteIP(), maxRequests, window)
	}
}

// RateLimiterByUser is like RateLimiter but counts requests per user (as
// returned by CurrentUser) instead of per remote IP. Requests without a user
// ID fall back to per-IP limiting with the same limits.
func RateLimiterByUser(maxRequests int, window time.Duration) func(*vigo.X) (any, error) {
	return func(x *vigo.X) (any, error) {
		if !cache.IsEnabled() {
			x.Next()
			return nil, nil
		}
		userID := CurrentUser(x)
		if userID == "" {
			// Anonymous client: limit by IP, exactly as RateLimiter does.
			return applyRateLimit(x, x.GetRemoteIP(), maxRequests, window)
		}
		// The "user:" prefix keeps user identifiers distinct from the raw-IP
		// identifiers used for anonymous clients.
		return applyRateLimit(x, "user:"+userID, maxRequests, window)
	}
}

// applyRateLimit increments the counter for identifier on the current request
// path. Once the count exceeds maxRequests it rejects the request with 429
// and a Retry-After header; otherwise it calls the next handler. Cache errors
// fail open.
func applyRateLimit(x *vigo.X, identifier string, maxRequests int, window time.Duration) (any, error) {
	count, err := cache.IncrRateLimit(identifier, x.Request.URL.Path, window)
	if err != nil {
		// Fail open: let the request through when the cache is unavailable.
		x.Next()
		return nil, nil
	}
	if count > int64(maxRequests) {
		x.ResponseWriter().Header().Set("Retry-After", fmt.Sprintf("%d", retryAfterSeconds(window)))
		return nil, vigo.NewError("rate limit exceeded").WithCode(http.StatusTooManyRequests)
	}
	x.Next()
	return nil, nil
}

// retryAfterSeconds converts d to whole seconds for a Retry-After header,
// rounding up. Truncating (int(d.Seconds())) would advertise "0" for any
// sub-second window and understate fractional ones. Non-positive d yields 0.
func retryAfterSeconds(d time.Duration) int64 {
	if d <= 0 {
		return 0
	}
	secs := int64(d / time.Second)
	if d%time.Second != 0 {
		secs++
	}
	return secs
}

// LoginRateLimit returns a stricter middleware for login endpoints. It counts
// attempts per remote IP under the cache key "login_attempt:<ip>". The key's
// expiry is set to 15 minutes on the first attempt, and ResetLoginAttempts
// can clear it early.
//
//   - From the 5th attempt on, "require_captcha" is set to true on x.
//   - From the 10th attempt on, the request is rejected with 429.
//
// Like RateLimiter, it fails open when the cache is disabled or unavailable.
func LoginRateLimit() func(*vigo.X) (any, error) {
	const (
		keyPrefix        = "login_attempt:"
		attemptWindow    = 15 * time.Minute
		captchaThreshold = 5
		rejectThreshold  = 10
	)
	return func(x *vigo.X) (any, error) {
		if !cache.IsEnabled() {
			x.Next()
			return nil, nil
		}
		key := keyPrefix + x.GetRemoteIP()
		count, err := cache.Incr(key)
		if err != nil {
			// Fail open, matching RateLimiter, instead of silently treating
			// a failed increment as a zero count.
			x.Next()
			return nil, nil
		}
		if count == 1 {
			// First attempt in this window: start the expiry clock.
			// NOTE(review): if Expire fails here the key may never expire and
			// the IP stays counted until ResetLoginAttempts runs — confirm
			// whether cache.Expire reports errors and handle them.
			cache.Expire(key, attemptWindow)
		}
		if count >= captchaThreshold {
			x.Set("require_captcha", true)
		}
		if count >= rejectThreshold {
			// The full window is an upper bound on the remaining wait.
			x.ResponseWriter().Header().Set("Retry-After", fmt.Sprintf("%d", retryAfterSeconds(attemptWindow)))
			return nil, vigo.NewError("too many login attempts, please try again later").WithCode(http.StatusTooManyRequests)
		}
		x.Next()
		return nil, nil
	}
}
// ResetLoginAttempts 重置登录尝试次数 func ResetLoginAttempts(x *vigo.X) { if !cache.IsEnabled() { return } identifier := x.GetRemoteIP() key := "login_attempt:" + identifier cache.Delete(key) }