|
|
package middleware
|
|
|
|
|
|
import (
|
|
|
"fmt"
|
|
|
"net/http"
|
|
|
"time"
|
|
|
|
|
|
"github.com/veypi/vbase/internal/cache"
|
|
|
"github.com/veypi/vigo"
|
|
|
)
|
|
|
|
|
|
// RateLimiter 限流中间件
|
|
|
func RateLimiter(maxRequests int, window time.Duration) func(*vigo.X) (any, error) {
|
|
|
return func(x *vigo.X) (any, error) {
|
|
|
if !cache.IsEnabled() {
|
|
|
x.Next()
|
|
|
return nil, nil
|
|
|
}
|
|
|
|
|
|
// 使用IP+路径作为标识
|
|
|
identifier := x.GetRemoteIP()
|
|
|
path := x.Request.URL.Path
|
|
|
|
|
|
count, err := cache.IncrRateLimit(identifier, path, window)
|
|
|
if err != nil {
|
|
|
x.Next() // 缓存失败时放行
|
|
|
return nil, nil
|
|
|
}
|
|
|
|
|
|
if count > int64(maxRequests) {
|
|
|
x.ResponseWriter().Header().Set("Retry-After", fmt.Sprintf("%d", int(window.Seconds())))
|
|
|
return nil, vigo.NewError("rate limit exceeded").WithCode(http.StatusTooManyRequests)
|
|
|
}
|
|
|
|
|
|
x.Next()
|
|
|
return nil, nil
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// RateLimiterByUser 基于用户的限流
|
|
|
func RateLimiterByUser(maxRequests int, window time.Duration) func(*vigo.X) (any, error) {
|
|
|
return func(x *vigo.X) (any, error) {
|
|
|
if !cache.IsEnabled() {
|
|
|
x.Next()
|
|
|
return nil, nil
|
|
|
}
|
|
|
|
|
|
userID := CurrentUser(x)
|
|
|
if userID == "" {
|
|
|
// 未登录用户使用IP限流
|
|
|
_, err := RateLimiter(maxRequests, window)(x)
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
path := x.Request.URL.Path
|
|
|
count, err := cache.IncrRateLimit("user:"+userID, path, window)
|
|
|
if err != nil {
|
|
|
x.Next()
|
|
|
return nil, nil
|
|
|
}
|
|
|
|
|
|
if count > int64(maxRequests) {
|
|
|
x.ResponseWriter().Header().Set("Retry-After", fmt.Sprintf("%d", int(window.Seconds())))
|
|
|
return nil, vigo.NewError("rate limit exceeded").WithCode(http.StatusTooManyRequests)
|
|
|
}
|
|
|
|
|
|
x.Next()
|
|
|
return nil, nil
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// LoginRateLimit 登录限流(更严格)
|
|
|
func LoginRateLimit() func(*vigo.X) (any, error) {
|
|
|
return func(x *vigo.X) (any, error) {
|
|
|
if !cache.IsEnabled() {
|
|
|
x.Next()
|
|
|
return nil, nil
|
|
|
}
|
|
|
|
|
|
identifier := x.GetRemoteIP()
|
|
|
key := "login_attempt:" + identifier
|
|
|
|
|
|
count, _ := cache.Incr(key)
|
|
|
if count == 1 {
|
|
|
cache.Expire(key, 15*time.Minute)
|
|
|
}
|
|
|
|
|
|
// 5分钟内超过5次尝试,需要验证码
|
|
|
if count >= 5 {
|
|
|
x.Set("require_captcha", true)
|
|
|
}
|
|
|
|
|
|
// 超过10次直接拒绝
|
|
|
if count >= 10 {
|
|
|
return nil, vigo.NewError("too many login attempts, please try again later").WithCode(http.StatusTooManyRequests)
|
|
|
}
|
|
|
|
|
|
x.Next()
|
|
|
return nil, nil
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// ResetLoginAttempts 重置登录尝试次数
|
|
|
func ResetLoginAttempts(x *vigo.X) {
|
|
|
if !cache.IsEnabled() {
|
|
|
return
|
|
|
}
|
|
|
identifier := x.GetRemoteIP()
|
|
|
key := "login_attempt:" + identifier
|
|
|
cache.Delete(key)
|
|
|
}
|