You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/internal/api/auth/handler.go

514 lines
12 KiB
Go

2 weeks ago
package auth
import (
	"errors"
	"net/http"
	"strings"
	"time"

	"github.com/veypi/vbase/internal/cache"
	"github.com/veypi/vbase/internal/config"
	"github.com/veypi/vbase/internal/model"
	"github.com/veypi/vbase/internal/pkg/crypto"
	"github.com/veypi/vbase/internal/pkg/jwt"
	"github.com/veypi/vigo"
)
// LoginRequest is the JSON request body accepted by Login.
//
// The src/desc struct tags appear to drive vigo's request binding and API
// documentation — TODO confirm against the vigo package.
//
// NOTE(review): Login in this file does not read CaptchaID, CaptchaCode or
// Remember; they are accepted but currently have no effect.
type LoginRequest struct {
	// Username is matched against the username, email and phone columns.
	Username string `json:"username" src:"json" desc:"用户名/邮箱/手机号"`
	// Password is the plaintext password, checked with crypto.VerifyPassword.
	Password string `json:"password" src:"json" desc:"密码"`
	// CaptchaID identifies a captcha challenge.
	CaptchaID string `json:"captcha_id,omitempty" src:"json" desc:"验证码ID"`
	// CaptchaCode is the answer to the captcha challenge.
	CaptchaCode string `json:"captcha_code,omitempty" src:"json" desc:"验证码"`
	// Remember is the client's "remember me" flag.
	Remember bool `json:"remember,omitempty" src:"json" desc:"记住登录"`
}
// RegisterRequest is the JSON request body accepted by Register.
type RegisterRequest struct {
	// Username is the login name; Register rejects one that is already taken.
	Username string `json:"username" src:"json" desc:"用户名"`
	// Password is the plaintext password; it is hashed with bcrypt-cost config before storage.
	Password string `json:"password" src:"json" desc:"密码"`
	// Email is optional; when non-empty Register checks that it is unused.
	Email string `json:"email,omitempty" src:"json" desc:"邮箱"`
	// Phone is an optional phone number.
	Phone string `json:"phone,omitempty" src:"json" desc:"手机号"`
	// Nickname is the optional display name; Register falls back to Username when empty.
	Nickname string `json:"nickname,omitempty" src:"json" desc:"昵称"`
}
// AuthResponse is returned by Login, Register and Refresh: a token pair
// produced by jwt.GenerateTokenPair plus the authenticated user's profile.
type AuthResponse struct {
	// AccessToken is the signed access token.
	AccessToken string `json:"access_token"`
	// RefreshToken is the token to present to Refresh.
	RefreshToken string `json:"refresh_token"`
	// TokenType is copied from the generated token pair.
	TokenType string `json:"token_type"`
	// ExpiresIn is copied from the generated token pair; presumably the access
	// token lifetime in seconds — TODO confirm in the jwt package.
	ExpiresIn int `json:"expires_in"`
	// User is the profile of the user the tokens were issued to.
	User *UserInfo `json:"user"`
}
// UserInfo is the public view of a user returned by Me and UpdateMe and
// embedded in AuthResponse. It deliberately omits the password hash and status.
type UserInfo struct {
	ID       string `json:"id"`
	Username string `json:"username"`
	Nickname string `json:"nickname"`
	Email    string `json:"email"`
	Avatar   string `json:"avatar"`
}
// Login authenticates a user by username, email or phone number plus password
// and returns a freshly issued access/refresh token pair with the user's profile.
//
// Errors:
//   - ErrArgInvalid when the identifier or the password is empty.
//   - ErrNotAuthorized when no user matches or the password is wrong; both
//     cases share one message so callers cannot tell them apart.
//   - ErrForbidden when the credentials are valid but the account is not active.
//   - ErrInternalServer when loading organisations, signing tokens or
//     persisting the session fails.
//
// NOTE(review): req.CaptchaID, req.CaptchaCode and req.Remember are not read here.
func Login(x *vigo.X, req *LoginRequest) (*AuthResponse, error) {
	// An empty identifier would match any account whose email or phone is
	// empty, so reject it before querying.
	if req.Username == "" || req.Password == "" {
		return nil, vigo.ErrArgInvalid.WithString("username and password are required")
	}

	// The identifier may be a username, an email or a phone number.
	// NOTE(review): if one account's username equals another account's
	// email/phone, First returns whichever row the database yields first.
	var user model.User
	query := model.DB.Where("username = ? OR email = ? OR phone = ?", req.Username, req.Username, req.Username)
	if err := query.First(&user).Error; err != nil {
		return nil, vigo.ErrNotAuthorized.WithString("invalid username or password")
	}

	// Verify the password before looking at the account status, so that the
	// "disabled" answer is only given to callers who already know the password.
	if !crypto.VerifyPassword(req.Password, user.Password) {
		return nil, vigo.ErrNotAuthorized.WithString("invalid username or password")
	}
	if user.Status != model.UserStatusActive {
		return nil, vigo.ErrForbidden.WithString("user is disabled")
	}

	// Embed the user's active organisation memberships in the token claims.
	orgs, err := getUserOrgs(user.ID)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	orgClaims := make([]jwt.OrgClaim, 0, len(orgs))
	for _, org := range orgs {
		orgClaims = append(orgClaims, jwt.OrgClaim{
			OrgID:  org.OrgID,
			Code:   org.Code,
			Name:   org.Name,
			Roles:  org.Roles,
			Status: org.Status,
		})
	}

	tokenPair, err := jwt.GenerateTokenPair(
		user.ID,
		user.Username,
		user.Nickname,
		user.Avatar,
		user.Email,
		orgClaims,
	)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	// Record the session. Logout revokes sessions by token_id, so a session
	// that was never stored could not be revoked there; fail instead of
	// handing out an untracked token.
	now := time.Now()
	session := &model.Session{
		UserID:     user.ID,
		TokenID:    getJTI(tokenPair.AccessToken),
		Type:       "access",
		DeviceInfo: x.Request.UserAgent(),
		IP:         x.GetRemoteIP(),
		ExpiresAt:  now.Add(config.C.JWT.AccessExpiry),
	}
	if err := model.DB.Create(session).Error; err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	// Best effort: failing to record the login time must not fail the login.
	_ = model.DB.Model(&user).Update("last_login_at", now).Error

	return &AuthResponse{
		AccessToken:  tokenPair.AccessToken,
		RefreshToken: tokenPair.RefreshToken,
		TokenType:    tokenPair.TokenType,
		ExpiresIn:    tokenPair.ExpiresIn,
		User: &UserInfo{
			ID:       user.ID,
			Username: user.Username,
			Nickname: user.Nickname,
			Email:    user.Email,
			Avatar:   user.Avatar,
		},
	}, nil
}
// Register creates a new active account, signs the user in immediately and
// returns an access/refresh token pair for it.
//
// Username and password are required. The username, and the email and phone
// when given, must not already be in use. Login resolves one identifier
// against all three columns, so duplicates would make login ambiguous. An
// empty nickname defaults to the username. A new user belongs to no
// organisation, so the token carries no org claims.
//
// Errors: ErrArgInvalid for missing fields or identifiers already taken;
// ErrInternalServer for database, hashing or token-signing failures.
//
// NOTE(review): the uniqueness checks and the insert are not atomic, so a
// concurrent registration can slip in between them unless the database
// enforces unique indexes — TODO confirm.
func Register(x *vigo.X, req *RegisterRequest) (*AuthResponse, error) {
	if req.Username == "" || req.Password == "" {
		return nil, vigo.ErrArgInvalid.WithString("username and password are required")
	}

	// Uniqueness checks; Count overwrites count on every call.
	var count int64
	if err := model.DB.Model(&model.User{}).Where("username = ?", req.Username).Count(&count).Error; err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	if count > 0 {
		return nil, vigo.ErrArgInvalid.WithString("username already exists")
	}
	if req.Email != "" {
		if err := model.DB.Model(&model.User{}).Where("email = ?", req.Email).Count(&count).Error; err != nil {
			return nil, vigo.ErrInternalServer.WithError(err)
		}
		if count > 0 {
			return nil, vigo.ErrArgInvalid.WithString("email already exists")
		}
	}
	if req.Phone != "" {
		if err := model.DB.Model(&model.User{}).Where("phone = ?", req.Phone).Count(&count).Error; err != nil {
			return nil, vigo.ErrInternalServer.WithError(err)
		}
		if count > 0 {
			return nil, vigo.ErrArgInvalid.WithString("phone already exists")
		}
	}

	hashedPassword, err := crypto.HashPassword(req.Password, config.C.Security.BcryptCost)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	user := &model.User{
		Username: req.Username,
		Password: hashedPassword,
		Email:    req.Email,
		Phone:    req.Phone,
		Nickname: req.Nickname,
		Status:   model.UserStatusActive,
	}
	if user.Nickname == "" {
		user.Nickname = user.Username
	}
	if err := model.DB.Create(user).Error; err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	tokenPair, err := jwt.GenerateTokenPair(
		user.ID,
		user.Username,
		user.Nickname,
		user.Avatar,
		user.Email,
		nil, // a brand-new user has no organisations
	)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	// Record the session. Logout revokes sessions by token_id, so an unstored
	// session could not be revoked there; surface the failure instead.
	session := &model.Session{
		UserID:     user.ID,
		TokenID:    getJTI(tokenPair.AccessToken),
		Type:       "access",
		DeviceInfo: x.Request.UserAgent(),
		IP:         x.GetRemoteIP(),
		ExpiresAt:  time.Now().Add(config.C.JWT.AccessExpiry),
	}
	if err := model.DB.Create(session).Error; err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	return &AuthResponse{
		AccessToken:  tokenPair.AccessToken,
		RefreshToken: tokenPair.RefreshToken,
		TokenType:    tokenPair.TokenType,
		ExpiresIn:    tokenPair.ExpiresIn,
		User: &UserInfo{
			ID:       user.ID,
			Username: user.Username,
			Nickname: user.Nickname,
			Email:    user.Email,
			Avatar:   user.Avatar,
		},
	}, nil
}
// RefreshRequest is the JSON request body accepted by Refresh.
type RefreshRequest struct {
	// RefreshToken is the refresh token previously issued in an AuthResponse.
	RefreshToken string `json:"refresh_token" src:"json" desc:"刷新令牌"`
}
// Refresh exchanges a valid refresh token for a new access/refresh token pair.
//
// The user is reloaded and must still be active. Organisation claims are
// recomputed, so the new tokens reflect current memberships.
//
// Errors: ErrNotAuthorized for an expired, malformed or non-refresh token or
// an unknown user; ErrForbidden for an inactive user; ErrInternalServer for
// organisation lookup, token signing or session persistence failures.
//
// NOTE(review): the presented refresh token is neither revoked nor checked
// against a blacklist here, so it remains reusable until it expires.
func Refresh(x *vigo.X, req *RefreshRequest) (*AuthResponse, error) {
	claims, err := jwt.ParseToken(req.RefreshToken)
	if err != nil {
		// errors.Is also matches when the jwt package wraps its sentinel.
		if errors.Is(err, jwt.ErrExpiredToken) {
			return nil, vigo.ErrNotAuthorized.WithString("refresh token expired")
		}
		return nil, vigo.ErrNotAuthorized.WithString("invalid refresh token")
	}
	// Access tokens must not be usable to mint new token pairs.
	if !jwt.IsRefreshToken(claims) {
		return nil, vigo.ErrNotAuthorized.WithString("invalid token type")
	}

	var user model.User
	if err := model.DB.First(&user, "id = ?", claims.UserID).Error; err != nil {
		return nil, vigo.ErrNotAuthorized.WithString("user not found")
	}
	if user.Status != model.UserStatusActive {
		return nil, vigo.ErrForbidden.WithString("user is disabled")
	}

	orgs, err := getUserOrgs(user.ID)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	orgClaims := make([]jwt.OrgClaim, 0, len(orgs))
	for _, org := range orgs {
		orgClaims = append(orgClaims, jwt.OrgClaim{
			OrgID:  org.OrgID,
			Code:   org.Code,
			Name:   org.Name,
			Roles:  org.Roles,
			Status: org.Status,
		})
	}

	tokenPair, err := jwt.GenerateTokenPair(
		user.ID,
		user.Username,
		user.Nickname,
		user.Avatar,
		user.Email,
		orgClaims,
	)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	// Record the session for the new access token. Logout revokes sessions by
	// token_id, so an unstored session could not be revoked there.
	session := &model.Session{
		UserID:     user.ID,
		TokenID:    getJTI(tokenPair.AccessToken),
		Type:       "access",
		DeviceInfo: x.Request.UserAgent(),
		IP:         x.GetRemoteIP(),
		ExpiresAt:  time.Now().Add(config.C.JWT.AccessExpiry),
	}
	if err := model.DB.Create(session).Error; err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	return &AuthResponse{
		AccessToken:  tokenPair.AccessToken,
		RefreshToken: tokenPair.RefreshToken,
		TokenType:    tokenPair.TokenType,
		ExpiresIn:    tokenPair.ExpiresIn,
		User: &UserInfo{
			ID:       user.ID,
			Username: user.Username,
			Nickname: user.Nickname,
			Email:    user.Email,
			Avatar:   user.Avatar,
		},
	}, nil
}
// Logout revokes the access token presented with the request.
//
// A missing or unparsable token is treated as already logged out and yields
// nil. When the cache is enabled, the token's JTI is blacklisted for the rest
// of the token's lifetime. The session row with that JTI is always marked
// revoked, and a failure to do so is reported as ErrInternalServer rather
// than silently claiming success.
//
// NOTE(review): only the access token is handled here; the refresh token
// issued alongside it stays valid.
func Logout(x *vigo.X) error {
	tokenString := extractTokenFromRequest(x.Request)
	if tokenString == "" {
		return nil
	}
	jti, err := jwt.GetJTI(tokenString)
	if err != nil {
		return nil
	}

	if cache.IsEnabled() {
		// Blacklist only for the token's remaining lifetime. An unreadable or
		// already-past expiration leaves nothing to blacklist.
		if expiration, err := jwt.GetExpiration(tokenString); err == nil {
			if ttl := time.Until(expiration); ttl > 0 {
				cache.BlacklistToken(jti, ttl)
			}
		}
	}

	// Mark the session revoked. Without the cache this is the only revocation
	// record, so a failure must be reported.
	if err := model.DB.Model(&model.Session{}).Where("token_id = ?", jti).Updates(map[string]interface{}{
		"revoked":    true,
		"revoked_at": time.Now(),
	}).Error; err != nil {
		return vigo.ErrInternalServer.WithError(err)
	}
	return nil
}
// Me returns the profile of the authenticated user.
//
// It answers ErrNotAuthorized when no user ID is attached to the request.
// It answers ErrNotFound when loading the user record fails for any reason;
// the underlying lookup error is not distinguished.
func Me(x *vigo.X) (*UserInfo, error) {
	uid := getCurrentUserID(x)
	if uid == "" {
		return nil, vigo.ErrNotAuthorized
	}

	user := model.User{}
	if model.DB.First(&user, "id = ?", uid).Error != nil {
		return nil, vigo.ErrNotFound
	}

	info := UserInfo{
		ID:       user.ID,
		Username: user.Username,
		Nickname: user.Nickname,
		Email:    user.Email,
		Avatar:   user.Avatar,
	}
	return &info, nil
}
// UpdateMeRequest is the JSON request body accepted by UpdateMe.
//
// Every field is a pointer: nil (absent from the JSON) leaves the stored
// value unchanged, while a non-nil pointer overwrites it, including with "".
type UpdateMeRequest struct {
	Nickname *string `json:"nickname,omitempty" src:"json" desc:"昵称"`
	Avatar   *string `json:"avatar,omitempty" src:"json" desc:"头像"`
	Email    *string `json:"email,omitempty" src:"json" desc:"邮箱"`
}
// UpdateMe applies a partial profile update to the authenticated user and
// returns the resulting profile (as Me would).
//
// Only fields present in req are written. A non-empty email must not belong
// to another user. Clearing the email (setting it to "") is always allowed,
// matching Register, which only checks non-empty emails.
//
// Errors: ErrNotAuthorized without an authenticated user; ErrArgInvalid when
// the email is taken; ErrInternalServer on database failure; plus any error
// returned by Me.
func UpdateMe(x *vigo.X, req *UpdateMeRequest) (*UserInfo, error) {
	userID := getCurrentUserID(x)
	if userID == "" {
		return nil, vigo.ErrNotAuthorized
	}

	updates := make(map[string]interface{}, 3)
	if req.Nickname != nil {
		updates["nickname"] = *req.Nickname
	}
	if req.Avatar != nil {
		updates["avatar"] = *req.Avatar
	}
	if req.Email != nil {
		// Without the non-empty guard, clearing the email would be rejected
		// whenever any other user also has no email.
		if *req.Email != "" {
			var count int64
			if err := model.DB.Model(&model.User{}).Where("email = ? AND id != ?", *req.Email, userID).Count(&count).Error; err != nil {
				return nil, vigo.ErrInternalServer.WithError(err)
			}
			if count > 0 {
				return nil, vigo.ErrArgInvalid.WithString("email already exists")
			}
		}
		updates["email"] = *req.Email
	}

	// Nothing to change: skip the write and just report the current profile.
	if len(updates) == 0 {
		return Me(x)
	}
	if err := model.DB.Model(&model.User{}).Where("id = ?", userID).Updates(updates).Error; err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	return Me(x)
}
// ChangePasswordRequest is the JSON request body accepted by ChangePassword.
type ChangePasswordRequest struct {
	// OldPassword is the current plaintext password, verified before changing.
	OldPassword string `json:"old_password" src:"json" desc:"旧密码"`
	// NewPassword is the replacement plaintext password; it is stored hashed.
	NewPassword string `json:"new_password" src:"json" desc:"新密码"`
}
// ChangePassword replaces the authenticated user's password after verifying
// the old one.
//
// Errors:
//   - ErrNotAuthorized without an authenticated user.
//   - ErrArgInvalid when the new password is empty or the old password does
//     not match.
//   - ErrNotFound when the user record cannot be loaded (for any reason).
//   - ErrInternalServer when hashing or saving fails.
//
// NOTE(review): existing sessions and tokens are left untouched; consider
// revoking them after a password change.
func ChangePassword(x *vigo.X, req *ChangePasswordRequest) error {
	userID := getCurrentUserID(x)
	if userID == "" {
		return vigo.ErrNotAuthorized
	}
	// An empty password would otherwise be hashed and stored like any other.
	if req.NewPassword == "" {
		return vigo.ErrArgInvalid.WithString("new password is required")
	}

	var user model.User
	if err := model.DB.First(&user, "id = ?", userID).Error; err != nil {
		return vigo.ErrNotFound
	}

	if !crypto.VerifyPassword(req.OldPassword, user.Password) {
		return vigo.ErrArgInvalid.WithString("old password is incorrect")
	}

	hashedPassword, err := crypto.HashPassword(req.NewPassword, config.C.Security.BcryptCost)
	if err != nil {
		return vigo.ErrInternalServer.WithError(err)
	}
	if err := model.DB.Model(&user).Update("password", hashedPassword).Error; err != nil {
		return vigo.ErrInternalServer.WithError(err)
	}
	return nil
}
// helper functions
// getCurrentUserID returns the user ID stored on x under "current_user", or
// "" when nothing is stored there or the stored value is not a string.
// Presumably an authentication middleware sets this key — TODO confirm.
func getCurrentUserID(x *vigo.X) string {
	// A failed type assertion leaves uid as the zero value "".
	uid, _ := x.Get("current_user").(string)
	return uid
}
// extractToken returns the bearer token carried by the request wrapped in r.
//
// It delegates to extractTokenFromRequest, so both helpers share a single
// implementation of the header/query-string precedence rules instead of
// maintaining two copies.
func extractToken(r *vigo.X) string {
	return extractTokenFromRequest(r.Request)
}
// extractTokenFromRequest returns the bearer token carried by r.
//
// The Authorization header is checked first. Its auth-scheme is matched
// case-insensitively ("Bearer", "bearer", ...), as RFC 7235 §2.1 requires.
// If the header is absent, uses another scheme, or has an empty credential,
// the access_token query parameter is returned instead, which may be "".
func extractTokenFromRequest(r *http.Request) string {
	const prefix = "Bearer "
	if auth := r.Header.Get("Authorization"); len(auth) > len(prefix) && strings.EqualFold(auth[:len(prefix)], prefix) {
		return auth[len(prefix):]
	}
	// NOTE(review): tokens in the query string can leak through access logs
	// and Referer headers; consider restricting this fallback.
	return r.URL.Query().Get("access_token")
}
// getJTI returns the JTI (token ID) claim of token as reported by jwt.GetJTI.
// The error from jwt.GetJTI is deliberately discarded. On failure the caller
// gets whatever value jwt.GetJTI returned alongside the error (presumably "").
func getJTI(token string) (jti string) {
	jti, _ = jwt.GetJTI(token)
	return jti
}
// userOrgInfo is one organisation membership of a user as assembled by
// getUserOrgs: membership fields plus the organisation's code and name.
type userOrgInfo struct {
	// OrgID is the organisation's ID, taken from the membership row.
	OrgID string
	// Code is the organisation's code.
	Code string
	// Name is the organisation's display name.
	Name string
	// Roles holds the result of parseRoles on the membership's RoleIDs.
	Roles []string
	// Status is the membership status (getUserOrgs only loads active memberships).
	Status int
}
// getUserOrgs returns the active organisation memberships of userID, each
// joined with its organisation's code and name, in membership query order.
//
// Memberships whose organisation row cannot be found are skipped. Database
// errors from either query are returned to the caller instead of being
// silently dropped. A user with no memberships gets an empty, non-nil slice.
func getUserOrgs(userID string) ([]userOrgInfo, error) {
	var members []model.OrgMember
	if err := model.DB.Where("user_id = ? AND status = ?", userID, model.MemberStatusActive).Find(&members).Error; err != nil {
		return nil, err
	}
	if len(members) == 0 {
		return []userOrgInfo{}, nil
	}

	// Load every referenced organisation in one query rather than one query
	// per membership (N+1).
	orgIDs := make([]string, 0, len(members))
	for _, m := range members {
		orgIDs = append(orgIDs, m.OrgID)
	}
	var orgs []model.Org
	if err := model.DB.Where("id IN (?)", orgIDs).Find(&orgs).Error; err != nil {
		return nil, err
	}
	// NOTE(review): assumes model.Org.ID is a string, like OrgMember.OrgID.
	orgByID := make(map[string]model.Org, len(orgs))
	for _, org := range orgs {
		orgByID[org.ID] = org
	}

	result := make([]userOrgInfo, 0, len(members))
	for _, m := range members {
		org, ok := orgByID[m.OrgID]
		if !ok {
			// Dangling membership: the organisation no longer exists.
			continue
		}
		result = append(result, userOrgInfo{
			OrgID:  m.OrgID,
			Code:   org.Code,
			Name:   org.Name,
			Roles:  parseRoles(m.RoleIDs),
			Status: m.Status,
		})
	}
	return result, nil
}
// parseRoles is meant to turn a membership's stored RoleIDs value into a list
// of role IDs.
//
// NOTE(review): this is currently a stub. For every input, including "", it
// returns an empty, non-nil slice, so organisation claims never carry roles.
// The storage format of RoleIDs is not visible here; implement once it is confirmed.
func parseRoles(roleIDs string) []string {
	return make([]string, 0)
}