From 8b2a1aba3b67c42d36bb95b8cad139030741f622 Mon Sep 17 00:00:00 2001 From: veypi Date: Sun, 15 Feb 2026 03:33:37 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E7=BB=9F=E4=B8=80API=E9=94=99?= =?UTF-8?q?=E8=AF=AF=E7=B1=BB=E5=9E=8B=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/auth/login.go | 12 ++++++------ api/auth/me.go | 10 +++++----- api/auth/register.go | 27 +++++++++++++++++++++------ api/auth/thirdparty.go | 28 ++++++++++++++-------------- api/oauth/authorize.go | 4 ++-- api/oauth/client.go | 2 +- api/oauth/oidc.go | 2 +- api/oauth/token.go | 10 +++++----- api/org/create.go | 4 ++-- api/user/create.go | 4 ++-- api/user/patch.go | 2 +- auth/auth.go | 10 +++++----- auth/middleware.go | 8 ++++---- 13 files changed, 69 insertions(+), 54 deletions(-) diff --git a/api/auth/login.go b/api/auth/login.go index e73ea6d..8cbeacf 100644 --- a/api/auth/login.go +++ b/api/auth/login.go @@ -50,7 +50,7 @@ func login(x *vigo.X, req *LoginRequest) (*AuthResponse, error) { var user models.User query := cfg.DB().Where("username = ? OR email = ? OR phone = ?", req.Username, req.Username, req.Username) if err := query.First(&user).Error; err != nil { - return nil, vigo.ErrNotAuthorized.WithString("invalid username or password") + return nil, vigo.ErrUnauthorized.WithString("invalid username or password") } // 检查用户状态 @@ -60,7 +60,7 @@ func login(x *vigo.X, req *LoginRequest) (*AuthResponse, error) { // 验证密码 if !crypto.VerifyPassword(req.Password, user.Password) { - return nil, vigo.ErrNotAuthorized.WithString("invalid username or password") + return nil, vigo.ErrUnauthorized.WithString("invalid username or password") } // 获取用户的组织信息 @@ -139,19 +139,19 @@ func refresh(x *vigo.X, req *RefreshRequest) (*AuthResponse, error) { claims, err := jwt.ParseToken(req.RefreshToken) if err != nil { if err == jwt.ErrExpiredToken { - return nil, vigo.ErrNotAuthorized.WithString("refresh token expired") + return nil, vigo.ErrUnauthorized.WithString("refresh token expired") } - return nil, vigo.ErrNotAuthorized.WithString("invalid refresh token") + return nil, vigo.ErrUnauthorized.WithString("invalid refresh token") } if !jwt.IsRefreshToken(claims) { - return nil, vigo.ErrNotAuthorized.WithString("invalid token type") + return nil, vigo.ErrUnauthorized.WithString("invalid token type") } // 查找用户 var user models.User if err := cfg.DB().First(&user, "id = ?", claims.UserID).Error; err != nil { - return nil, vigo.ErrNotAuthorized.WithString("user not found") + return nil, vigo.ErrUnauthorized.WithString("user not found") } if user.Status != models.UserStatusActive { diff --git a/api/auth/me.go b/api/auth/me.go index 43a0297..b193ba7 100644 --- a/api/auth/me.go +++ b/api/auth/me.go @@ -17,7 +17,7 @@ import ( func me(x *vigo.X) (*UserInfo, error) { userID := getCurrentUserID(x) if userID == "" { - return nil, vigo.ErrNotAuthorized + return nil, vigo.ErrUnauthorized } var user models.User @@ -45,7 +45,7 @@ type UpdateMeRequest struct { func updateMe(x *vigo.X, req *UpdateMeRequest) (*UserInfo, error) { userID := getCurrentUserID(x) if userID == "" { - return nil, vigo.ErrNotAuthorized + return nil, vigo.ErrUnauthorized } updates := make(map[string]any) @@ -60,7 +60,7 @@ func updateMe(x *vigo.X, req *UpdateMeRequest) (*UserInfo, error) { var count int64 cfg.DB().Model(&models.User{}).Where("email = ? AND id != ?", *req.Email, userID).Count(&count) if count > 0 { - return nil, vigo.ErrArgInvalid.WithString("email already exists") + return nil, vigo.ErrInvalidArg.WithString("email already exists") } updates["email"] = *req.Email } @@ -82,7 +82,7 @@ type ChangePasswordRequest struct { func changePassword(x *vigo.X, req *ChangePasswordRequest) error { userID := getCurrentUserID(x) if userID == "" { - return vigo.ErrNotAuthorized + return vigo.ErrUnauthorized } var user models.User @@ -92,7 +92,7 @@ func changePassword(x *vigo.X, req *ChangePasswordRequest) error { // 验证旧密码 if !crypto.VerifyPassword(req.OldPassword, user.Password) { - return vigo.ErrArgInvalid.WithString("old password is incorrect") + return vigo.ErrInvalidArg.WithString("old password is incorrect") } // 哈希新密码 diff --git a/api/auth/register.go b/api/auth/register.go index 8572406..e2d4995 100644 --- a/api/auth/register.go +++ b/api/auth/register.go @@ -26,18 +26,28 @@ type RegisterRequest struct { // register 用户注册 func register(x *vigo.X, req *RegisterRequest) (*AuthResponse, error) { + // 检查是否是第一个用户(需要在创建用户之前检查) + var userCount int64 + if err := cfg.DB().Model(&models.User{}).Count(&userCount).Error; err != nil { + return nil, vigo.ErrInternalServer.WithError(err) + } + // 检查用户名是否已存在 var count int64 - cfg.DB().Model(&models.User{}).Where("username = ?", req.Username).Count(&count) + if err := cfg.DB().Model(&models.User{}).Where("username = ?", req.Username).Count(&count).Error; err != nil { + return nil, vigo.ErrInternalServer.WithError(err) + } if count > 0 { - return nil, vigo.ErrArgInvalid.WithString("username already exists") + return nil, vigo.ErrInvalidArg.WithArgs("username already exists") } // 检查邮箱是否已存在 if req.Email != "" { - cfg.DB().Model(&models.User{}).Where("email = ?", req.Email).Count(&count) + if err := cfg.DB().Model(&models.User{}).Where("email = ?", req.Email).Count(&count).Error; err != nil { + return nil, vigo.ErrInternalServer.WithError(err) + } if count > 0 { - return nil, vigo.ErrArgInvalid.WithString("email already exists") + return nil, vigo.ErrInvalidArg.WithArgs("email already exists") } } @@ -74,8 +84,13 @@ func register(x *vigo.X, req *RegisterRequest) (*AuthResponse, error) { return nil, vigo.ErrInternalServer.WithError(err) } - // 授予默认角色 "user" - if err := baseauth.VBaseAuth.GrantRole(x.Context(), user.ID, "", "user"); err != nil { + // 第一个用户授予 admin 角色,其他用户授予 user 角色 + roleCode := "user" + if userCount == 0 { + roleCode = "admin" + } + + if err := baseauth.VBaseAuth.GrantRole(x.Context(), user.ID, "", roleCode); err != nil { // 记录错误但允许注册继续,或者回滚 // 这里简单处理,继续流程,用户可能需要管理员手动授权 // 或者返回错误 diff --git a/api/auth/thirdparty.go b/api/auth/thirdparty.go index 99a1f56..2d43acb 100644 --- a/api/auth/thirdparty.go +++ b/api/auth/thirdparty.go @@ -76,7 +76,7 @@ func authorizeThirdParty(x *vigo.X, req *AuthorizeRequest) (*AuthorizeResponse, if req.BindMode { userID := getCurrentUserID(x) if userID == "" { - return nil, vigo.ErrNotAuthorized.WithString("login required for bind mode") + return nil, vigo.ErrUnauthorized.WithString("login required for bind mode") } stateData["user_id"] = userID } @@ -125,17 +125,17 @@ type CallbackResponse struct { // callbackThirdParty 处理第三方登录回调 func callbackThirdParty(x *vigo.X, req *CallbackRequest) (*CallbackResponse, error) { if req.Error != "" { - return nil, vigo.ErrArgInvalid.WithString("oauth error: " + req.Error) + return nil, vigo.ErrInvalidArg.WithString("oauth error: " + req.Error) } if req.Code == "" || req.State == "" { - return nil, vigo.ErrArgInvalid.WithString("missing code or state") + return nil, vigo.ErrInvalidArg.WithString("missing code or state") } // 验证state stateData, err := verifyState(req.State) if err != nil { - return nil, vigo.ErrArgInvalid.WithString("invalid or expired state") + return nil, vigo.ErrInvalidArg.WithString("invalid or expired state") } provider := stateData["provider"].(string) @@ -197,19 +197,19 @@ func bindThirdParty(x *vigo.X, req *BindRequest) (*AuthResponse, error) { // 验证临时token userInfo, err := verifyTempBindToken(req.TempToken) if err != nil { - return nil, vigo.ErrArgInvalid.WithString("invalid or expired token") + return nil, vigo.ErrInvalidArg.WithString("invalid or expired token") } // 查找用户 var user models.User query := cfg.DB().Where("username = ? OR email = ? OR phone = ?", req.Username, req.Username, req.Username) if err := query.First(&user).Error; err != nil { - return nil, vigo.ErrNotAuthorized.WithString("invalid credentials") + return nil, vigo.ErrUnauthorized.WithString("invalid credentials") } // 验证密码 if !crypto.VerifyPassword(req.Password, user.Password) { - return nil, vigo.ErrNotAuthorized.WithString("invalid credentials") + return nil, vigo.ErrUnauthorized.WithString("invalid credentials") } // 检查用户状态 @@ -239,21 +239,21 @@ func bindWithRegister(x *vigo.X, req *BindWithRegisterRequest) (*AuthResponse, e // 验证临时token userInfo, err := verifyTempBindToken(req.TempToken) if err != nil { - return nil, vigo.ErrArgInvalid.WithString("invalid or expired token") + return nil, vigo.ErrInvalidArg.WithString("invalid or expired token") } // 检查用户名是否已存在 var count int64 cfg.DB().Model(&models.User{}).Where("username = ?", req.Username).Count(&count) if count > 0 { - return nil, vigo.ErrArgInvalid.WithString("username already exists") + return nil, vigo.ErrInvalidArg.WithString("username already exists") } // 检查邮箱是否已存在 if req.Email != "" { cfg.DB().Model(&models.User{}).Where("email = ?", req.Email).Count(&count) if count > 0 { - return nil, vigo.ErrArgInvalid.WithString("email already exists") + return nil, vigo.ErrInvalidArg.WithString("email already exists") } } @@ -311,7 +311,7 @@ type UnbindRequest struct { func unbindThirdParty(x *vigo.X, req *UnbindRequest) error { userID := getCurrentUserID(x) if userID == "" { - return vigo.ErrNotAuthorized + return vigo.ErrUnauthorized } // 删除绑定关系 @@ -335,7 +335,7 @@ type BindingInfo struct { func listBindings(x *vigo.X) ([]BindingInfo, error) { userID := getCurrentUserID(x) if userID == "" { - return nil, vigo.ErrNotAuthorized + return nil, vigo.ErrUnauthorized } var identities []models.Identity @@ -401,11 +401,11 @@ func buildAuthURL(provider, state string) (string, error) { case "wechat": pc = cfg.Config.Providers.WeChat default: - return "", vigo.ErrArgInvalid.WithString("unsupported provider: " + provider) + return "", vigo.ErrInvalidArg.WithString("unsupported provider: " + provider) } if !pc.Enabled { - return "", vigo.ErrArgInvalid.WithString("provider not enabled: " + provider) + return "", vigo.ErrInvalidArg.WithString("provider not enabled: " + provider) } params := url.Values{} diff --git a/api/oauth/authorize.go b/api/oauth/authorize.go index c40bb2f..a07f03b 100644 --- a/api/oauth/authorize.go +++ b/api/oauth/authorize.go @@ -29,7 +29,7 @@ func authorize(x *vigo.X, req *AuthorizeRequest) (*AuthorizeResponse, error) { // 验证客户端 var client models.OAuthClient if err := cfg.DB().First(&client, "client_id = ?", req.ClientID).Error; err != nil { - return nil, vigo.ErrNotAuthorized.WithString("invalid client") + return nil, vigo.ErrUnauthorized.WithString("invalid client") } if client.Status != models.OAuthClientStatusActive { @@ -39,7 +39,7 @@ func authorize(x *vigo.X, req *AuthorizeRequest) (*AuthorizeResponse, error) { // 获取当前用户 userID := getCurrentUserID(x) if userID == "" { - return nil, vigo.ErrNotAuthorized + return nil, vigo.ErrUnauthorized } // 生成授权码 diff --git a/api/oauth/client.go b/api/oauth/client.go index b0b14a4..399f6b4 100644 --- a/api/oauth/client.go +++ b/api/oauth/client.go @@ -67,7 +67,7 @@ type CreateClientResponse struct { func createClient(x *vigo.X, req *CreateClientRequest) (*CreateClientResponse, error) { ownerID := getCurrentUserID(x) if ownerID == "" { - return nil, vigo.ErrNotAuthorized + return nil, vigo.ErrUnauthorized } clientID := crypto.GenerateClientID() diff --git a/api/oauth/oidc.go b/api/oauth/oidc.go index 192ee23..df164d5 100644 --- a/api/oauth/oidc.go +++ b/api/oauth/oidc.go @@ -15,7 +15,7 @@ func userInfo(x *vigo.X) (map[string]any, error) { // 从token中解析用户ID userID := getCurrentUserID(x) if userID == "" { - return nil, vigo.ErrNotAuthorized + return nil, vigo.ErrUnauthorized } var user models.User diff --git a/api/oauth/token.go b/api/oauth/token.go index 603051a..25a4098 100644 --- a/api/oauth/token.go +++ b/api/oauth/token.go @@ -37,7 +37,7 @@ func token(x *vigo.X, req *TokenRequest) (*TokenResponse, error) { case "refresh_token": return handleRefreshToken(req) default: - return nil, vigo.ErrArgInvalid.WithString("unsupported grant type") + return nil, vigo.ErrInvalidArg.WithString("unsupported grant type") } } @@ -45,13 +45,13 @@ func handleAuthorizationCode(req *TokenRequest) (*TokenResponse, error) { // 验证客户端 var client models.OAuthClient if err := cfg.DB().First(&client, "client_id = ?", req.ClientID).Error; err != nil { - return nil, vigo.ErrNotAuthorized.WithString("invalid client") + return nil, vigo.ErrUnauthorized.WithString("invalid client") } // 验证授权码 var authData map[string]any if err := cache.GetObject(cache.OAuthCodeKey(req.Code), &authData); err != nil { - return nil, vigo.ErrNotAuthorized.WithString("invalid or expired code") + return nil, vigo.ErrUnauthorized.WithString("invalid or expired code") } // 删除已使用的授权码 @@ -91,11 +91,11 @@ func handleRefreshToken(req *TokenRequest) (*TokenResponse, error) { // 查找刷新令牌 var token models.OAuthToken if err := cfg.DB().First(&token, "refresh_token = ?", req.RefreshToken).Error; err != nil { - return nil, vigo.ErrNotAuthorized.WithString("invalid refresh token") + return nil, vigo.ErrUnauthorized.WithString("invalid refresh token") } if token.Revoked { - return nil, vigo.ErrNotAuthorized.WithString("token has been revoked") + return nil, vigo.ErrUnauthorized.WithString("token has been revoked") } // 生成新的访问令牌 diff --git a/api/org/create.go b/api/org/create.go index 232a54f..522b750 100644 --- a/api/org/create.go +++ b/api/org/create.go @@ -24,13 +24,13 @@ func create(x *vigo.X, req *CreateRequest) (*models.Org, error) { var count int64 cfg.DB().Model(&models.Org{}).Where("code = ?", req.Code).Count(&count) if count > 0 { - return nil, vigo.ErrArgInvalid.WithString("organization code already exists") + return nil, vigo.ErrInvalidArg.WithString("organization code already exists") } // 获取当前用户ID作为所有者 ownerID := getCurrentUserID(x) if ownerID == "" { - return nil, vigo.ErrNotAuthorized + return nil, vigo.ErrUnauthorized } org := &models.Org{ diff --git a/api/user/create.go b/api/user/create.go index e7db816..874dd6a 100644 --- a/api/user/create.go +++ b/api/user/create.go @@ -30,14 +30,14 @@ func create(x *vigo.X, req *CreateRequest) (*models.User, error) { var count int64 cfg.DB().Model(&models.User{}).Where("username = ?", req.Username).Count(&count) if count > 0 { - return nil, vigo.ErrArgInvalid.WithString("username already exists") + return nil, vigo.ErrInvalidArg.WithString("username already exists") } // 检查邮箱是否已存在 if req.Email != "" { cfg.DB().Model(&models.User{}).Where("email = ?", req.Email).Count(&count) if count > 0 { - return nil, vigo.ErrArgInvalid.WithString("email already exists") + return nil, vigo.ErrInvalidArg.WithString("email already exists") } } diff --git a/api/user/patch.go b/api/user/patch.go index 157559a..58869bc 100644 --- a/api/user/patch.go +++ b/api/user/patch.go @@ -37,7 +37,7 @@ func patch(x *vigo.X, req *PatchRequest) (*models.User, error) { var count int64 cfg.DB().Model(&models.User{}).Where("email = ? AND id != ?", *req.Email, req.UserID).Count(&count) if count > 0 { - return nil, vigo.ErrArgInvalid.WithString("email already exists") + return nil, vigo.ErrInvalidArg.WithString("email already exists") } updates["email"] = *req.Email } diff --git a/auth/auth.go b/auth/auth.go index d94b59c..b6b9335 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -270,7 +270,7 @@ func (a *appAuth) Perm(permissionID string) func(*vigo.X) error { return func(x *vigo.X) error { userID := getUserID(x) if userID == "" { - return vigo.ErrNotAuthorized + return vigo.ErrUnauthorized } orgID := getOrgID(x) @@ -286,7 +286,7 @@ func (a *appAuth) PermWithOwner(permissionID, ownerKey string) func(*vigo.X) err return func(x *vigo.X) error { userID := getUserID(x) if userID == "" { - return vigo.ErrNotAuthorized + return vigo.ErrUnauthorized } orgID := getOrgID(x) @@ -337,7 +337,7 @@ func (a *appAuth) PermOnResource(permissionID, resourceKey string) func(*vigo.X) return func(x *vigo.X) error { userID := getUserID(x) if userID == "" { - return vigo.ErrNotAuthorized + return vigo.ErrUnauthorized } orgID := getOrgID(x) @@ -375,7 +375,7 @@ func (a *appAuth) PermAny(permissionIDs []string) func(*vigo.X) error { return func(x *vigo.X) error { userID := getUserID(x) if userID == "" { - return vigo.ErrNotAuthorized + return vigo.ErrUnauthorized } orgID := getOrgID(x) @@ -405,7 +405,7 @@ func (a *appAuth) PermAll(permissionIDs []string) func(*vigo.X) error { return func(x *vigo.X) error { userID := getUserID(x) if userID == "" { - return vigo.ErrNotAuthorized + return vigo.ErrUnauthorized } orgID := getOrgID(x) diff --git a/auth/middleware.go b/auth/middleware.go index 623005e..69f39be 100644 --- a/auth/middleware.go +++ b/auth/middleware.go @@ -22,23 +22,23 @@ func AuthMiddleware() func(*vigo.X) error { // === 1. JWT 认证部分 === tokenString := extractToken(x) if tokenString == "" { - return vigo.ErrNotAuthorized.WithString("missing token") + return vigo.ErrUnauthorized.WithString("missing token") } // 解析token claims, err := jwt.ParseToken(tokenString) if err != nil { if err == jwt.ErrExpiredToken { - return vigo.ErrNotAuthorized.WithString("token expired") + return vigo.ErrUnauthorized.WithString("token expired") } - return vigo.ErrNotAuthorized.WithString("invalid token") + return vigo.ErrUnauthorized.WithString("invalid token") } // 检查token是否在黑名单中 if cache.IsEnabled() { blacklisted, _ := cache.IsTokenBlacklisted(claims.ID) if blacklisted { - return vigo.ErrNotAuthorized.WithString("token has been revoked") + return vigo.ErrUnauthorized.WithString("token has been revoked") } }