You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/api/auth/login.go

319 lines
7.2 KiB
Go

//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-03-04 16:08:06
// Distributed under terms of the MIT license.
//
package auth
import (
	"errors"
	"strings"
	"time"

	"github.com/veypi/vbase/cfg"
	"github.com/veypi/vbase/libs/cache"
	"github.com/veypi/vbase/libs/crypto"
	"github.com/veypi/vbase/libs/jwt"
	"github.com/veypi/vbase/models"
	"github.com/veypi/vigo"
)
// LoginRequest is the JSON body accepted by the login handler.
type LoginRequest struct {
	// Username identifies the account; it may be a username, an email
	// address or a phone number.
	Username string `json:"username" src:"json" desc:"用户名/邮箱/手机号"`
	// Password is the plaintext password to verify.
	Password string `json:"password" src:"json" desc:"密码"`
	// CaptchaID and CaptchaCode carry an optional captcha challenge/answer.
	CaptchaID   string `json:"captcha_id,omitempty" src:"json" desc:"验证码ID"`
	CaptchaCode string `json:"captcha_code,omitempty" src:"json" desc:"验证码"`
	// Remember is the client's "remember me" flag.
	Remember bool `json:"remember,omitempty" src:"json" desc:"记住登录"`
}
// AuthResponse is the JSON body returned after a successful login or token
// refresh.
type AuthResponse struct {
	AccessToken  string `json:"access_token"`  // newly issued access token
	RefreshToken string `json:"refresh_token"` // newly issued refresh token
	TokenType    string `json:"token_type"`    // token type reported by the token generator
	// ExpiresIn is copied from the generated token pair.
	// NOTE(review): presumably seconds until the access token expires — confirm in jwt.GenerateTokenPair.
	ExpiresIn int       `json:"expires_in"`
	User      *UserInfo `json:"user"` // public profile of the authenticated user
}
// UserInfo is the public subset of a user's profile returned to clients in
// authentication responses.
//
// Email is a pointer so that a user without an email address is serialized
// as JSON null rather than an empty string.
type UserInfo struct {
	ID       string  `json:"id"`
	Username string  `json:"username"`
	Nickname string  `json:"nickname"`
	Email    *string `json:"email"`
	Avatar   string  `json:"avatar"`
}
// login 用户登录
func login(x *vigo.X, req *LoginRequest) (*AuthResponse, error) {
// 查找用户
var user models.User
query := cfg.DB().Where("username = ? OR email = ? OR phone = ?", req.Username, req.Username, req.Username)
if err := query.First(&user).Error; err != nil {
return nil, vigo.ErrUnauthorized.WithString("invalid username or password")
}
// 检查用户状态
if user.Status != models.UserStatusActive {
return nil, vigo.ErrForbidden.WithString("user is disabled")
}
// 验证密码
if !crypto.VerifyPassword(req.Password, user.Password) {
return nil, vigo.ErrUnauthorized.WithString("invalid username or password")
}
// 获取用户的组织信息
orgs, err := getUserOrgs(user.ID)
if err != nil {
return nil, vigo.ErrInternalServer.WithError(err)
}
// 生成token
orgClaims := make([]jwt.OrgClaim, 0, len(orgs))
for _, org := range orgs {
orgClaims = append(orgClaims, jwt.OrgClaim{
OrgID: org.OrgID,
Code: org.Code,
Name: org.Name,
Roles: org.Roles,
Status: org.Status,
})
}
1 week ago
emailStr := ""
if user.Email != nil {
emailStr = *user.Email
}
tokenPair, err := jwt.GenerateTokenPair(
user.ID,
user.Username,
user.Nickname,
user.Avatar,
1 week ago
emailStr,
orgClaims,
)
if err != nil {
return nil, vigo.ErrInternalServer.WithError(err)
}
// 保存session
session := &models.Session{
UserID: user.ID,
TokenID: getJTI(tokenPair.AccessToken),
Type: "access",
DeviceInfo: x.Request.UserAgent(),
IP: x.GetRemoteIP(),
ExpiresAt: time.Now().Add(cfg.Config.JWT.AccessExpiry),
}
cfg.DB().Create(session)
// 更新最后登录时间
now := time.Now()
cfg.DB().Model(&user).Update("last_login_at", now)
return &AuthResponse{
AccessToken: tokenPair.AccessToken,
RefreshToken: tokenPair.RefreshToken,
TokenType: tokenPair.TokenType,
ExpiresIn: tokenPair.ExpiresIn,
User: &UserInfo{
ID: user.ID,
Username: user.Username,
Nickname: user.Nickname,
Email: user.Email,
Avatar: user.Avatar,
},
}, nil
}
// RefreshRequest 刷新请求
type RefreshRequest struct {
RefreshToken string `json:"refresh_token" src:"json" desc:"刷新令牌"`
}
// refresh 刷新Token
func refresh(x *vigo.X, req *RefreshRequest) (*AuthResponse, error) {
// 解析refresh token
claims, err := jwt.ParseToken(req.RefreshToken)
if err != nil {
if err == jwt.ErrExpiredToken {
return nil, vigo.ErrTokenExpired
}
return nil, vigo.ErrTokenInvalid
}
if !jwt.IsRefreshToken(claims) {
return nil, vigo.ErrTokenInvalid
}
// 查找用户
var user models.User
if err := cfg.DB().First(&user, "id = ?", claims.UserID).Error; err != nil {
return nil, vigo.ErrTokenInvalid
}
if user.Status != models.UserStatusActive {
return nil, vigo.ErrForbidden.WithString("user is disabled")
}
// 获取用户的组织信息
orgs, err := getUserOrgs(user.ID)
if err != nil {
return nil, vigo.ErrInternalServer.WithError(err)
}
orgClaims := make([]jwt.OrgClaim, 0, len(orgs))
for _, org := range orgs {
orgClaims = append(orgClaims, jwt.OrgClaim{
OrgID: org.OrgID,
Code: org.Code,
Name: org.Name,
Roles: org.Roles,
Status: org.Status,
})
}
// 生成新token
1 week ago
emailStr := ""
if user.Email != nil {
emailStr = *user.Email
}
tokenPair, err := jwt.GenerateTokenPair(
user.ID,
user.Username,
user.Nickname,
user.Avatar,
1 week ago
emailStr,
orgClaims,
)
if err != nil {
return nil, vigo.ErrInternalServer.WithError(err)
}
// 保存新session
session := &models.Session{
UserID: user.ID,
TokenID: getJTI(tokenPair.AccessToken),
Type: "access",
DeviceInfo: x.Request.UserAgent(),
IP: x.GetRemoteIP(),
ExpiresAt: time.Now().Add(cfg.Config.JWT.AccessExpiry),
}
cfg.DB().Create(session)
return &AuthResponse{
AccessToken: tokenPair.AccessToken,
RefreshToken: tokenPair.RefreshToken,
TokenType: tokenPair.TokenType,
ExpiresIn: tokenPair.ExpiresIn,
User: &UserInfo{
ID: user.ID,
Username: user.Username,
Nickname: user.Nickname,
Email: user.Email,
Avatar: user.Avatar,
},
}, nil
}
// logout revokes the access token presented with the request.
//
// It is best effort and always returns nil: a missing token or one whose jti
// cannot be read is silently ignored. If the cache is enabled and the token
// has not yet expired, its jti is blacklisted for the remaining lifetime. The
// session row with that token ID is then marked revoked.
//
// NOTE(review): only the presented token's session is revoked here; any
// refresh token issued alongside it is not touched by this function.
func logout(x *vigo.X) error {
	raw := extractToken(x)
	if raw == "" {
		return nil
	}
	tokenID, err := jwt.GetJTI(raw)
	if err != nil {
		return nil
	}

	// Error deliberately ignored. NOTE(review): presumably a failed lookup yields
	// a zero time, so the TTL below is non-positive and blacklisting is skipped — confirm.
	expiresAt, _ := jwt.GetExpiration(raw)
	if cache.IsEnabled() {
		if remaining := time.Until(expiresAt); remaining > 0 {
			cache.BlacklistToken(tokenID, remaining)
		}
	}

	revocation := map[string]any{
		"revoked":    true,
		"revoked_at": time.Now(),
	}
	cfg.DB().Model(&models.Session{}).Where("token_id = ?", tokenID).Updates(revocation)
	return nil
}
// helper functions
// getJTI returns the token ID (jti claim) of token. Any error from
// jwt.GetJTI is discarded and whatever ID value it returned is passed through.
func getJTI(token string) (jti string) {
	jti, _ = jwt.GetJTI(token)
	return jti
}
// extractToken returns the token sent with the request, or "" if there is none.
//
// It first looks for an "Authorization: Bearer <token>" header. The scheme is
// matched case-insensitively (RFC 7235 §2.1), and whitespace around the token
// is ignored. If the header does not carry a non-empty bearer token, it falls
// back to the "access_token" query parameter.
//
// NOTE(review): tokens in query strings tend to leak into access logs and
// Referer headers (RFC 6750 §2.3); keep the fallback only if clients need it.
func extractToken(x *vigo.X) string {
	header := strings.TrimSpace(x.Request.Header.Get("Authorization"))
	if scheme, token, ok := strings.Cut(header, " "); ok && strings.EqualFold(scheme, "Bearer") {
		if token = strings.TrimSpace(token); token != "" {
			return token
		}
	}
	return x.Request.URL.Query().Get("access_token")
}
// userOrgInfo describes one active organization membership of a user, as
// loaded by getUserOrgs and copied into the token's org claims.
type userOrgInfo struct {
	OrgID, Code, Name string   // organization ID, code and name
	Roles             []string // roles parsed from the membership record
	Status            int      // membership status
}
// getUserOrgs returns the organizations in which userID has an active
// membership, in membership order, with the member's roles and status.
//
// Memberships whose organization row is not found are skipped. Database
// errors from either query are returned to the caller.
func getUserOrgs(userID string) ([]userOrgInfo, error) {
	var members []models.OrgMember
	if err := cfg.DB().Where("user_id = ? AND status = ?", userID, models.MemberStatusActive).Find(&members).Error; err != nil {
		return nil, err
	}
	if len(members) == 0 {
		return []userOrgInfo{}, nil
	}

	// Load every referenced organization in one query instead of one query
	// per membership.
	orgIDs := make([]string, 0, len(members))
	for _, m := range members {
		orgIDs = append(orgIDs, m.OrgID)
	}
	var orgs []models.Org
	if err := cfg.DB().Where("id IN ?", orgIDs).Find(&orgs).Error; err != nil {
		return nil, err
	}
	orgByID := make(map[string]*models.Org, len(orgs))
	for i := range orgs {
		orgByID[orgs[i].ID] = &orgs[i]
	}

	result := make([]userOrgInfo, 0, len(members))
	for _, m := range members {
		org, ok := orgByID[m.OrgID]
		if !ok {
			// Dangling membership: its organization no longer exists.
			continue
		}
		result = append(result, userOrgInfo{
			OrgID:  m.OrgID,
			Code:   org.Code,
			Name:   org.Name,
			Roles:  parseRoles(m.RoleIDs),
			Status: m.Status,
		})
	}
	return result, nil
}
// parseRoles converts a membership's stored role-ID string into a list of
// role identifiers.
//
// NOTE(review): this is currently a stub. It returns an empty, non-nil slice
// for every input, so issued tokens never carry org roles. The storage
// encoding of roleIDs (comma-separated? JSON array?) is not visible here;
// implement the decoding once that is confirmed.
func parseRoles(roleIDs string) []string {
	roles := make([]string, 0)
	if len(roleIDs) == 0 {
		return roles
	}
	// TODO: decode roleIDs; the encoding is not defined in this file.
	return roles
}