mirror of https://github.com/veypi/OneAuth.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
514 lines
12 KiB
Go
514 lines
12 KiB
Go
package auth
|
|
|
|
import (
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/veypi/vbase/internal/cache"
|
|
"github.com/veypi/vbase/internal/config"
|
|
"github.com/veypi/vbase/internal/model"
|
|
"github.com/veypi/vbase/internal/pkg/crypto"
|
|
"github.com/veypi/vbase/internal/pkg/jwt"
|
|
"github.com/veypi/vigo"
|
|
)
|
|
|
|
// LoginRequest is the JSON body accepted by Login.
type LoginRequest struct {
	// Username is the login identifier. Login compares it against the
	// username, email and phone columns.
	Username string `json:"username" src:"json" desc:"用户名/邮箱/手机号"`
	// Password is the plaintext password to verify.
	Password string `json:"password" src:"json" desc:"密码"`
	// CaptchaID is the optional captcha identifier.
	CaptchaID string `json:"captcha_id,omitempty" src:"json" desc:"验证码ID"`
	// CaptchaCode is the optional captcha answer.
	CaptchaCode string `json:"captcha_code,omitempty" src:"json" desc:"验证码"`
	// Remember is the optional "remember me" flag.
	Remember bool `json:"remember,omitempty" src:"json" desc:"记住登录"`
}
|
|
|
|
// RegisterRequest is the JSON body accepted by Register.
type RegisterRequest struct {
	// Username is the account name for the new user.
	Username string `json:"username" src:"json" desc:"用户名"`
	// Password is the plaintext password; Register stores only its hash.
	Password string `json:"password" src:"json" desc:"密码"`
	// Email is the optional e-mail address.
	Email string `json:"email,omitempty" src:"json" desc:"邮箱"`
	// Phone is the optional phone number.
	Phone string `json:"phone,omitempty" src:"json" desc:"手机号"`
	// Nickname is the optional display name; Register falls back to
	// Username when it is empty.
	Nickname string `json:"nickname,omitempty" src:"json" desc:"昵称"`
}
|
|
|
|
// AuthResponse 认证响应
|
|
type AuthResponse struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
TokenType string `json:"token_type"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
User *UserInfo `json:"user"`
|
|
}
|
|
|
|
// UserInfo is the user profile exposed to clients, both on its own (Me,
// UpdateMe) and embedded in AuthResponse.
type UserInfo struct {
	// ID is the user's identifier.
	ID string `json:"id"`
	// Username is the account name.
	Username string `json:"username"`
	// Nickname is the display name.
	Nickname string `json:"nickname"`
	// Email is the user's e-mail address.
	Email string `json:"email"`
	// Avatar is the user's avatar value as stored on the user record.
	Avatar string `json:"avatar"`
}
|
|
|
|
// Login 用户登录
|
|
func Login(x *vigo.X, req *LoginRequest) (*AuthResponse, error) {
|
|
// 查找用户
|
|
var user model.User
|
|
query := model.DB.Where("username = ? OR email = ? OR phone = ?", req.Username, req.Username, req.Username)
|
|
if err := query.First(&user).Error; err != nil {
|
|
return nil, vigo.ErrNotAuthorized.WithString("invalid username or password")
|
|
}
|
|
|
|
// 检查用户状态
|
|
if user.Status != model.UserStatusActive {
|
|
return nil, vigo.ErrForbidden.WithString("user is disabled")
|
|
}
|
|
|
|
// 验证密码
|
|
if !crypto.VerifyPassword(req.Password, user.Password) {
|
|
return nil, vigo.ErrNotAuthorized.WithString("invalid username or password")
|
|
}
|
|
|
|
// 获取用户的组织信息
|
|
orgs, err := getUserOrgs(user.ID)
|
|
if err != nil {
|
|
return nil, vigo.ErrInternalServer.WithError(err)
|
|
}
|
|
|
|
// 生成token
|
|
orgClaims := make([]jwt.OrgClaim, 0, len(orgs))
|
|
for _, org := range orgs {
|
|
orgClaims = append(orgClaims, jwt.OrgClaim{
|
|
OrgID: org.OrgID,
|
|
Code: org.Code,
|
|
Name: org.Name,
|
|
Roles: org.Roles,
|
|
Status: org.Status,
|
|
})
|
|
}
|
|
|
|
tokenPair, err := jwt.GenerateTokenPair(
|
|
user.ID,
|
|
user.Username,
|
|
user.Nickname,
|
|
user.Avatar,
|
|
user.Email,
|
|
orgClaims,
|
|
)
|
|
if err != nil {
|
|
return nil, vigo.ErrInternalServer.WithError(err)
|
|
}
|
|
|
|
// 保存session
|
|
session := &model.Session{
|
|
UserID: user.ID,
|
|
TokenID: getJTI(tokenPair.AccessToken),
|
|
Type: "access",
|
|
DeviceInfo: x.Request.UserAgent(),
|
|
IP: x.GetRemoteIP(),
|
|
ExpiresAt: time.Now().Add(config.C.JWT.AccessExpiry),
|
|
}
|
|
model.DB.Create(session)
|
|
|
|
// 更新最后登录时间
|
|
now := time.Now()
|
|
model.DB.Model(&user).Update("last_login_at", now)
|
|
|
|
return &AuthResponse{
|
|
AccessToken: tokenPair.AccessToken,
|
|
RefreshToken: tokenPair.RefreshToken,
|
|
TokenType: tokenPair.TokenType,
|
|
ExpiresIn: tokenPair.ExpiresIn,
|
|
User: &UserInfo{
|
|
ID: user.ID,
|
|
Username: user.Username,
|
|
Nickname: user.Nickname,
|
|
Email: user.Email,
|
|
Avatar: user.Avatar,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// Register 用户注册
|
|
func Register(x *vigo.X, req *RegisterRequest) (*AuthResponse, error) {
|
|
// 检查用户名是否已存在
|
|
var count int64
|
|
model.DB.Model(&model.User{}).Where("username = ?", req.Username).Count(&count)
|
|
if count > 0 {
|
|
return nil, vigo.ErrArgInvalid.WithString("username already exists")
|
|
}
|
|
|
|
// 检查邮箱是否已存在
|
|
if req.Email != "" {
|
|
model.DB.Model(&model.User{}).Where("email = ?", req.Email).Count(&count)
|
|
if count > 0 {
|
|
return nil, vigo.ErrArgInvalid.WithString("email already exists")
|
|
}
|
|
}
|
|
|
|
// 哈希密码
|
|
hashedPassword, err := crypto.HashPassword(req.Password, config.C.Security.BcryptCost)
|
|
if err != nil {
|
|
return nil, vigo.ErrInternalServer.WithError(err)
|
|
}
|
|
|
|
// 创建用户
|
|
user := &model.User{
|
|
Username: req.Username,
|
|
Password: hashedPassword,
|
|
Email: req.Email,
|
|
Phone: req.Phone,
|
|
Nickname: req.Nickname,
|
|
Status: model.UserStatusActive,
|
|
}
|
|
|
|
if user.Nickname == "" {
|
|
user.Nickname = user.Username
|
|
}
|
|
|
|
if err := model.DB.Create(user).Error; err != nil {
|
|
return nil, vigo.ErrInternalServer.WithError(err)
|
|
}
|
|
|
|
// 生成token
|
|
tokenPair, err := jwt.GenerateTokenPair(
|
|
user.ID,
|
|
user.Username,
|
|
user.Nickname,
|
|
user.Avatar,
|
|
user.Email,
|
|
nil, // 新用户无组织
|
|
)
|
|
if err != nil {
|
|
return nil, vigo.ErrInternalServer.WithError(err)
|
|
}
|
|
|
|
// 保存session
|
|
session := &model.Session{
|
|
UserID: user.ID,
|
|
TokenID: getJTI(tokenPair.AccessToken),
|
|
Type: "access",
|
|
DeviceInfo: x.Request.UserAgent(),
|
|
IP: x.GetRemoteIP(),
|
|
ExpiresAt: time.Now().Add(config.C.JWT.AccessExpiry),
|
|
}
|
|
model.DB.Create(session)
|
|
|
|
return &AuthResponse{
|
|
AccessToken: tokenPair.AccessToken,
|
|
RefreshToken: tokenPair.RefreshToken,
|
|
TokenType: tokenPair.TokenType,
|
|
ExpiresIn: tokenPair.ExpiresIn,
|
|
User: &UserInfo{
|
|
ID: user.ID,
|
|
Username: user.Username,
|
|
Nickname: user.Nickname,
|
|
Email: user.Email,
|
|
Avatar: user.Avatar,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// RefreshRequest is the JSON body accepted by Refresh.
type RefreshRequest struct {
	// RefreshToken is the refresh token to exchange for a new token pair.
	RefreshToken string `json:"refresh_token" src:"json" desc:"刷新令牌"`
}
|
|
|
|
// Refresh 刷新Token
|
|
func Refresh(x *vigo.X, req *RefreshRequest) (*AuthResponse, error) {
|
|
// 解析refresh token
|
|
claims, err := jwt.ParseToken(req.RefreshToken)
|
|
if err != nil {
|
|
if err == jwt.ErrExpiredToken {
|
|
return nil, vigo.ErrNotAuthorized.WithString("refresh token expired")
|
|
}
|
|
return nil, vigo.ErrNotAuthorized.WithString("invalid refresh token")
|
|
}
|
|
|
|
if !jwt.IsRefreshToken(claims) {
|
|
return nil, vigo.ErrNotAuthorized.WithString("invalid token type")
|
|
}
|
|
|
|
// 查找用户
|
|
var user model.User
|
|
if err := model.DB.First(&user, "id = ?", claims.UserID).Error; err != nil {
|
|
return nil, vigo.ErrNotAuthorized.WithString("user not found")
|
|
}
|
|
|
|
if user.Status != model.UserStatusActive {
|
|
return nil, vigo.ErrForbidden.WithString("user is disabled")
|
|
}
|
|
|
|
// 获取用户的组织信息
|
|
orgs, err := getUserOrgs(user.ID)
|
|
if err != nil {
|
|
return nil, vigo.ErrInternalServer.WithError(err)
|
|
}
|
|
|
|
orgClaims := make([]jwt.OrgClaim, 0, len(orgs))
|
|
for _, org := range orgs {
|
|
orgClaims = append(orgClaims, jwt.OrgClaim{
|
|
OrgID: org.OrgID,
|
|
Code: org.Code,
|
|
Name: org.Name,
|
|
Roles: org.Roles,
|
|
Status: org.Status,
|
|
})
|
|
}
|
|
|
|
// 生成新token
|
|
tokenPair, err := jwt.GenerateTokenPair(
|
|
user.ID,
|
|
user.Username,
|
|
user.Nickname,
|
|
user.Avatar,
|
|
user.Email,
|
|
orgClaims,
|
|
)
|
|
if err != nil {
|
|
return nil, vigo.ErrInternalServer.WithError(err)
|
|
}
|
|
|
|
// 保存新session
|
|
session := &model.Session{
|
|
UserID: user.ID,
|
|
TokenID: getJTI(tokenPair.AccessToken),
|
|
Type: "access",
|
|
DeviceInfo: x.Request.UserAgent(),
|
|
IP: x.GetRemoteIP(),
|
|
ExpiresAt: time.Now().Add(config.C.JWT.AccessExpiry),
|
|
}
|
|
model.DB.Create(session)
|
|
|
|
return &AuthResponse{
|
|
AccessToken: tokenPair.AccessToken,
|
|
RefreshToken: tokenPair.RefreshToken,
|
|
TokenType: tokenPair.TokenType,
|
|
ExpiresIn: tokenPair.ExpiresIn,
|
|
User: &UserInfo{
|
|
ID: user.ID,
|
|
Username: user.Username,
|
|
Nickname: user.Nickname,
|
|
Email: user.Email,
|
|
Avatar: user.Avatar,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// Logout 用户登出
|
|
func Logout(x *vigo.X) error {
|
|
tokenString := extractTokenFromRequest(x.Request)
|
|
if tokenString == "" {
|
|
return nil
|
|
}
|
|
|
|
jti, err := jwt.GetJTI(tokenString)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
// 加入黑名单
|
|
expiration, _ := jwt.GetExpiration(tokenString)
|
|
if cache.IsEnabled() {
|
|
ttl := time.Until(expiration)
|
|
if ttl > 0 {
|
|
cache.BlacklistToken(jti, ttl)
|
|
}
|
|
}
|
|
|
|
// 标记session为撤销
|
|
model.DB.Model(&model.Session{}).Where("token_id = ?", jti).Updates(map[string]interface{}{
|
|
"revoked": true,
|
|
"revoked_at": time.Now(),
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
// Me 获取当前用户信息
|
|
func Me(x *vigo.X) (*UserInfo, error) {
|
|
userID := getCurrentUserID(x)
|
|
if userID == "" {
|
|
return nil, vigo.ErrNotAuthorized
|
|
}
|
|
|
|
var user model.User
|
|
if err := model.DB.First(&user, "id = ?", userID).Error; err != nil {
|
|
return nil, vigo.ErrNotFound
|
|
}
|
|
|
|
return &UserInfo{
|
|
ID: user.ID,
|
|
Username: user.Username,
|
|
Nickname: user.Nickname,
|
|
Email: user.Email,
|
|
Avatar: user.Avatar,
|
|
}, nil
|
|
}
|
|
|
|
// UpdateMeRequest is the JSON body accepted by UpdateMe. Every field is a
// pointer so that an omitted field (nil) can be told apart from a field
// explicitly set to the empty string.
type UpdateMeRequest struct {
	// Nickname, when non-nil, replaces the display name.
	Nickname *string `json:"nickname,omitempty" src:"json" desc:"昵称"`
	// Avatar, when non-nil, replaces the avatar.
	Avatar *string `json:"avatar,omitempty" src:"json" desc:"头像"`
	// Email, when non-nil, replaces the e-mail address.
	Email *string `json:"email,omitempty" src:"json" desc:"邮箱"`
}
|
|
|
|
// UpdateMe 更新当前用户信息
|
|
func UpdateMe(x *vigo.X, req *UpdateMeRequest) (*UserInfo, error) {
|
|
userID := getCurrentUserID(x)
|
|
if userID == "" {
|
|
return nil, vigo.ErrNotAuthorized
|
|
}
|
|
|
|
updates := make(map[string]interface{})
|
|
if req.Nickname != nil {
|
|
updates["nickname"] = *req.Nickname
|
|
}
|
|
if req.Avatar != nil {
|
|
updates["avatar"] = *req.Avatar
|
|
}
|
|
if req.Email != nil {
|
|
// 检查邮箱是否被其他用户使用
|
|
var count int64
|
|
model.DB.Model(&model.User{}).Where("email = ? AND id != ?", *req.Email, userID).Count(&count)
|
|
if count > 0 {
|
|
return nil, vigo.ErrArgInvalid.WithString("email already exists")
|
|
}
|
|
updates["email"] = *req.Email
|
|
}
|
|
|
|
if err := model.DB.Model(&model.User{}).Where("id = ?", userID).Updates(updates).Error; err != nil {
|
|
return nil, vigo.ErrInternalServer.WithError(err)
|
|
}
|
|
|
|
return Me(x)
|
|
}
|
|
|
|
// ChangePasswordRequest is the JSON body accepted by ChangePassword.
type ChangePasswordRequest struct {
	// OldPassword is the current plaintext password, verified before the
	// change is applied.
	OldPassword string `json:"old_password" src:"json" desc:"旧密码"`
	// NewPassword is the plaintext password to switch to.
	NewPassword string `json:"new_password" src:"json" desc:"新密码"`
}
|
|
|
|
// ChangePassword 修改密码
|
|
func ChangePassword(x *vigo.X, req *ChangePasswordRequest) error {
|
|
userID := getCurrentUserID(x)
|
|
if userID == "" {
|
|
return vigo.ErrNotAuthorized
|
|
}
|
|
|
|
var user model.User
|
|
if err := model.DB.First(&user, "id = ?", userID).Error; err != nil {
|
|
return vigo.ErrNotFound
|
|
}
|
|
|
|
// 验证旧密码
|
|
if !crypto.VerifyPassword(req.OldPassword, user.Password) {
|
|
return vigo.ErrArgInvalid.WithString("old password is incorrect")
|
|
}
|
|
|
|
// 哈希新密码
|
|
hashedPassword, err := crypto.HashPassword(req.NewPassword, config.C.Security.BcryptCost)
|
|
if err != nil {
|
|
return vigo.ErrInternalServer.WithError(err)
|
|
}
|
|
|
|
// 更新密码
|
|
if err := model.DB.Model(&user).Update("password", hashedPassword).Error; err != nil {
|
|
return vigo.ErrInternalServer.WithError(err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// helper functions
|
|
|
|
func getCurrentUserID(x *vigo.X) string {
|
|
if uid, ok := x.Get("current_user").(string); ok {
|
|
return uid
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func extractToken(r *vigo.X) string {
|
|
// 从Header获取
|
|
auth := r.Request.Header.Get("Authorization")
|
|
if auth != "" {
|
|
parts := make([]string, 0)
|
|
for _, p := range []string{auth} {
|
|
parts = append(parts, p)
|
|
}
|
|
// Simple check
|
|
if len(auth) > 7 && auth[:7] == "Bearer " {
|
|
return auth[7:]
|
|
}
|
|
}
|
|
|
|
// 从Query获取
|
|
return r.Request.URL.Query().Get("access_token")
|
|
}
|
|
|
|
// extractTokenFromRequest returns the access token carried by r.
//
// An Authorization header of the exact form "Bearer <token>" with a
// non-empty token is preferred. The scheme match is case-sensitive. In
// every other case the "access_token" query parameter is returned, which is
// "" when absent.
func extractTokenFromRequest(r *http.Request) string {
	const scheme = "Bearer "

	header := r.Header.Get("Authorization")
	if len(header) > len(scheme) && header[:len(scheme)] == scheme {
		return header[len(scheme):]
	}

	// Fall back to the query string.
	return r.URL.Query().Get("access_token")
}
|
|
|
|
func getJTI(token string) string {
|
|
jti, _ := jwt.GetJTI(token)
|
|
return jti
|
|
}
|
|
|
|
// userOrgInfo describes one organization membership of a user, in the shape
// needed to build a jwt.OrgClaim.
type userOrgInfo struct {
	// OrgID is the organization ID taken from the membership row.
	OrgID string
	// Code is the organization's code.
	Code string
	// Name is the organization's name.
	Name string
	// Roles is the list produced by parseRoles from the membership's role IDs.
	Roles []string
	// Status is the membership status.
	Status int
}
|
|
|
|
func getUserOrgs(userID string) ([]userOrgInfo, error) {
|
|
var members []model.OrgMember
|
|
if err := model.DB.Where("user_id = ? AND status = ?", userID, model.MemberStatusActive).Find(&members).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(members) == 0 {
|
|
return []userOrgInfo{}, nil
|
|
}
|
|
|
|
result := make([]userOrgInfo, 0, len(members))
|
|
for _, m := range members {
|
|
var org model.Org
|
|
if err := model.DB.First(&org, "id = ?", m.OrgID).Error; err != nil {
|
|
continue
|
|
}
|
|
|
|
// 解析角色ID
|
|
roles := parseRoles(m.RoleIDs)
|
|
|
|
result = append(result, userOrgInfo{
|
|
OrgID: m.OrgID,
|
|
Code: org.Code,
|
|
Name: org.Name,
|
|
Roles: roles,
|
|
Status: m.Status,
|
|
})
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// parseRoles converts a membership's stored role IDs into a list of roles.
//
// NOTE(review): this is a placeholder. Whatever roleIDs contains, the result
// is an empty, non-nil slice, so role lists built from it are always empty.
// The storage format of roleIDs is not visible here; implement real parsing
// once it is confirmed.
func parseRoles(roleIDs string) []string {
	none := []string{}
	return none
}
|