// // Copyright (C) 2024 veypi // 2025-03-04 16:08:06 // Distributed under terms of the MIT license. // package auth import ( "time" "github.com/veypi/vbase/cfg" "github.com/veypi/vbase/libs/cache" "github.com/veypi/vbase/libs/crypto" "github.com/veypi/vbase/libs/jwt" "github.com/veypi/vbase/models" "github.com/veypi/vigo" ) // LoginRequest 登录请求 type LoginRequest struct { Username string `json:"username" src:"json" desc:"用户名/邮箱/手机号"` Password string `json:"password" src:"json" desc:"密码"` CaptchaID string `json:"captcha_id,omitempty" src:"json" desc:"验证码ID"` CaptchaCode string `json:"captcha_code,omitempty" src:"json" desc:"验证码"` Remember bool `json:"remember,omitempty" src:"json" desc:"记住登录"` } // AuthResponse 认证响应 type AuthResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` User *UserInfo `json:"user"` } // UserInfo 用户信息 type UserInfo struct { ID string `json:"id"` Username string `json:"username"` Nickname string `json:"nickname"` Email *string `json:"email"` Avatar string `json:"avatar"` } // login 用户登录 func login(x *vigo.X, req *LoginRequest) (*AuthResponse, error) { // 查找用户 var user models.User query := cfg.DB().Where("username = ? OR email = ? OR phone = ?", req.Username, req.Username, req.Username) if err := query.First(&user).Error; err != nil { return nil, vigo.ErrUnauthorized.WithString("invalid username or password") } // 检查用户状态 if user.Status != models.UserStatusActive { return nil, vigo.ErrForbidden.WithString("user is disabled") } // 验证密码 if !crypto.VerifyPassword(req.Password, user.Password) { return nil, vigo.ErrUnauthorized.WithString("invalid username or password") } // 获取用户的组织信息 orgs, err := getUserOrgs(user.ID) if err != nil { return nil, vigo.ErrInternalServer.WithError(err) } // 生成token orgClaims := make([]jwt.OrgClaim, 0, len(orgs)) for _, org := range orgs { orgClaims = append(orgClaims, jwt.OrgClaim{ OrgID: org.OrgID, Code: org.Code, Name: org.Name, Roles: org.Roles, Status: org.Status, }) } emailStr := "" if user.Email != nil { emailStr = *user.Email } tokenPair, err := jwt.GenerateTokenPair( user.ID, user.Username, user.Nickname, user.Avatar, emailStr, orgClaims, ) if err != nil { return nil, vigo.ErrInternalServer.WithError(err) } // 保存session session := &models.Session{ UserID: user.ID, TokenID: getJTI(tokenPair.AccessToken), Type: "access", DeviceInfo: x.Request.UserAgent(), IP: x.GetRemoteIP(), ExpiresAt: time.Now().Add(cfg.Config.JWT.AccessExpiry), } cfg.DB().Create(session) // 更新最后登录时间 now := time.Now() cfg.DB().Model(&user).Update("last_login_at", now) return &AuthResponse{ AccessToken: tokenPair.AccessToken, RefreshToken: tokenPair.RefreshToken, TokenType: tokenPair.TokenType, ExpiresIn: tokenPair.ExpiresIn, User: &UserInfo{ ID: user.ID, Username: user.Username, Nickname: user.Nickname, Email: user.Email, Avatar: user.Avatar, }, }, nil } // RefreshRequest 刷新请求 type RefreshRequest struct { RefreshToken string `json:"refresh_token" src:"json" desc:"刷新令牌"` } // refresh 刷新Token func refresh(x *vigo.X, req *RefreshRequest) (*AuthResponse, error) { // 解析refresh token claims, err := jwt.ParseToken(req.RefreshToken) if err != nil { if err == jwt.ErrExpiredToken { return nil, vigo.ErrUnauthorized.WithString("refresh token expired") } return nil, vigo.ErrUnauthorized.WithString("invalid refresh token") } if !jwt.IsRefreshToken(claims) { return nil, vigo.ErrUnauthorized.WithString("invalid token type") } // 查找用户 var user models.User if err := cfg.DB().First(&user, "id = ?", claims.UserID).Error; err != nil { return nil, vigo.ErrUnauthorized.WithString("user not found") } if user.Status != models.UserStatusActive { return nil, vigo.ErrForbidden.WithString("user is disabled") } // 获取用户的组织信息 orgs, err := getUserOrgs(user.ID) if err != nil { return nil, vigo.ErrInternalServer.WithError(err) } orgClaims := make([]jwt.OrgClaim, 0, len(orgs)) for _, org := range orgs { orgClaims = append(orgClaims, jwt.OrgClaim{ OrgID: org.OrgID, Code: org.Code, Name: org.Name, Roles: org.Roles, Status: org.Status, }) } // 生成新token emailStr := "" if user.Email != nil { emailStr = *user.Email } tokenPair, err := jwt.GenerateTokenPair( user.ID, user.Username, user.Nickname, user.Avatar, emailStr, orgClaims, ) if err != nil { return nil, vigo.ErrInternalServer.WithError(err) } // 保存新session session := &models.Session{ UserID: user.ID, TokenID: getJTI(tokenPair.AccessToken), Type: "access", DeviceInfo: x.Request.UserAgent(), IP: x.GetRemoteIP(), ExpiresAt: time.Now().Add(cfg.Config.JWT.AccessExpiry), } cfg.DB().Create(session) return &AuthResponse{ AccessToken: tokenPair.AccessToken, RefreshToken: tokenPair.RefreshToken, TokenType: tokenPair.TokenType, ExpiresIn: tokenPair.ExpiresIn, User: &UserInfo{ ID: user.ID, Username: user.Username, Nickname: user.Nickname, Email: user.Email, Avatar: user.Avatar, }, }, nil } // logout 用户登出 func logout(x *vigo.X) error { tokenString := extractToken(x) if tokenString == "" { return nil } jti, err := jwt.GetJTI(tokenString) if err != nil { return nil } // 加入黑名单 expiration, _ := jwt.GetExpiration(tokenString) if cache.IsEnabled() { ttl := time.Until(expiration) if ttl > 0 { cache.BlacklistToken(jti, ttl) } } // 标记session为撤销 cfg.DB().Model(&models.Session{}).Where("token_id = ?", jti).Updates(map[string]any{ "revoked": true, "revoked_at": time.Now(), }) return nil } // helper functions func getJTI(token string) string { jti, _ := jwt.GetJTI(token) return jti } func extractToken(x *vigo.X) string { // 从Header获取 auth := x.Request.Header.Get("Authorization") if auth != "" { if len(auth) > 7 && auth[:7] == "Bearer " { return auth[7:] } } // 从Query获取 return x.Request.URL.Query().Get("access_token") } type userOrgInfo struct { OrgID string Code string Name string Roles []string Status int } func getUserOrgs(userID string) ([]userOrgInfo, error) { var members []models.OrgMember if err := cfg.DB().Where("user_id = ? AND status = ?", userID, models.MemberStatusActive).Find(&members).Error; err != nil { return nil, err } if len(members) == 0 { return []userOrgInfo{}, nil } result := make([]userOrgInfo, 0, len(members)) for _, m := range members { var org models.Org if err := cfg.DB().First(&org, "id = ?", m.OrgID).Error; err != nil { continue } // 解析角色ID roles := parseRoles(m.RoleIDs) result = append(result, userOrgInfo{ OrgID: m.OrgID, Code: org.Code, Name: org.Name, Roles: roles, Status: m.Status, }) } return result, nil } func parseRoles(roleIDs string) []string { if roleIDs == "" { return []string{} } // 简单解析,实际可能需要更复杂的逻辑 return []string{} }