|
|
//
|
|
|
// Copyright (C) 2024 veypi <i@veypi.com>
|
|
|
// 2025-03-04 16:08:06
|
|
|
// Distributed under terms of the MIT license.
|
|
|
//
|
|
|
|
|
|
package auth
|
|
|
|
|
|
import (
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	baseauth "github.com/veypi/vbase/auth"
	"github.com/veypi/vbase/cfg"
	"github.com/veypi/vbase/libs/cache"
	"github.com/veypi/vbase/libs/crypto"
	"github.com/veypi/vbase/libs/jwt"
	"github.com/veypi/vbase/models"
	"github.com/veypi/vigo"
	"gorm.io/gorm"
)
|
|
|
|
|
|
// ProviderInfo is the public description of a third-party login provider, as
// returned by listProviders.
type ProviderInfo struct {
	// Name is the provider code (models.OAuthProvider.Code); buildAuthURL looks
	// providers up by this code.
	Name string `json:"name"`
	// DisplayName is the human-readable provider name (models.OAuthProvider.Name).
	DisplayName string `json:"display_name"`
	// Icon is copied from models.OAuthProvider.Icon.
	Icon string `json:"icon"`
	// Enabled mirrors the provider's enabled flag. listProviders only returns
	// enabled providers, so it is always true in that output.
	Enabled bool `json:"enabled"`
}
|
|
|
|
|
|
// listProviders 获取支持的第三方登录提供商列表
|
|
|
func listProviders(x *vigo.X) ([]ProviderInfo, error) {
|
|
|
var dbProviders []models.OAuthProvider
|
|
|
if err := cfg.DB().Where("enabled = ?", true).Order("sort_order").Find(&dbProviders).Error; err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
providers := make([]ProviderInfo, 0, len(dbProviders))
|
|
|
for _, p := range dbProviders {
|
|
|
providers = append(providers, ProviderInfo{
|
|
|
Name: p.Code,
|
|
|
DisplayName: p.Name,
|
|
|
Icon: p.Icon,
|
|
|
Enabled: p.Enabled,
|
|
|
})
|
|
|
}
|
|
|
return providers, nil
|
|
|
}
|
|
|
|
|
|
// AuthorizeRequest is the query input of authorizeThirdParty, which starts a
// third-party login (or account-binding) flow.
type AuthorizeRequest struct {
	// Provider is the provider code; authorizeThirdParty lower-cases it before use.
	Provider string `json:"provider" src:"query" desc:"提供商: google/github/wechat"`
	// Redirect is stored in the OAuth state as the post-login redirect target.
	// NOTE(review): nothing in this file reads it back; confirm where it is consumed.
	Redirect string `json:"redirect" src:"query" desc:"登录成功后重定向地址"`
	// BindMode requests linking the third-party account to the currently
	// logged-in user instead of logging in; it requires an authenticated caller.
	BindMode bool `json:"bind_mode" src:"query" desc:"是否为绑定模式"`
}
|
|
|
|
|
|
// AuthorizeResponse is returned by authorizeThirdParty.
type AuthorizeResponse struct {
	// AuthURL is the provider authorization URL the client should navigate to.
	AuthURL string `json:"auth_url"`
	// State is the one-time value embedded in AuthURL. It is cached for
	// 10 minutes and verified (then discarded) by the callback.
	State string `json:"state"`
}
|
|
|
|
|
|
// authorizeThirdParty builds the authorization URL for a third-party login.
//
// It generates a random one-time state. The provider, redirect target and
// bind-mode flag are recorded under that state in the cache for 10 minutes,
// plus the current user ID in bind mode. It then returns the provider's
// authorization URL together with the state.
//
// Errors:
//   - ErrInvalidArg for an empty, unknown or disabled provider;
//   - ErrUnauthorized when bind mode is requested without a logged-in user;
//   - ErrInternalServer when the state cannot be stored.
func authorizeThirdParty(x *vigo.X, req *AuthorizeRequest) (*AuthorizeResponse, error) {
	provider := strings.ToLower(strings.TrimSpace(req.Provider))
	if provider == "" {
		return nil, vigo.ErrInvalidArg.WithString("provider is required")
	}

	stateData := map[string]any{
		"provider":   provider,
		"redirect":   req.Redirect,
		"bind_mode":  req.BindMode,
		"created_at": time.Now().Unix(),
	}

	// Bind mode links the third-party account to the caller, so a login is required.
	if req.BindMode {
		userID := baseauth.VBaseAuth.UserID(x)
		if userID == "" {
			return nil, vigo.ErrUnauthorized.WithString("login required for bind mode")
		}
		stateData["user_id"] = userID
	}

	state := generateState()

	// Build the URL first: it validates the provider, so an unknown or disabled
	// provider never leaves an orphaned state entry in the cache.
	authURL, err := buildAuthURL(provider, state)
	if err != nil {
		return nil, err
	}

	if err := saveState(state, stateData); err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	return &AuthorizeResponse{
		AuthURL: authURL,
		State:   state,
	}, nil
}
|
|
|
|
|
|
// CallbackRequest is the input of callbackThirdParty: the parameters the
// provider sends back after the user authorizes (or declines).
type CallbackRequest struct {
	// Provider is the provider code taken from the URL path.
	Provider string `json:"provider" src:"path" desc:"提供商"`
	// Code is the authorization code to exchange for an access token.
	Code string `json:"code" src:"query" desc:"授权码"`
	// State must match a value previously issued by authorizeThirdParty.
	State string `json:"state" src:"query" desc:"状态值"`
	// Error is the provider-reported error; a non-empty value aborts the callback.
	Error string `json:"error" src:"query" desc:"错误信息"`
}
|
|
|
|
|
|
// CallbackResponse is the result of callbackThirdParty. One of three shapes is
// populated:
//   - bind mode: BindMode and Bound are true;
//   - login, third-party account not yet linked: NeedBind, Provider,
//     ProviderID and AuthURL;
//   - login, account already linked: the embedded *AuthResponse.
type CallbackResponse struct {
	// Bind mode.
	BindMode bool `json:"bind_mode,omitempty"`
	Bound    bool `json:"bound,omitempty"`

	// Login mode, account not linked yet. AuthURL points to the bind page and
	// carries a single-use temporary bind token.
	NeedBind   bool   `json:"need_bind,omitempty"`
	Provider   string `json:"provider,omitempty"`
	ProviderID string `json:"provider_id,omitempty"`
	AuthURL    string `json:"auth_url,omitempty"`

	// Login succeeded.
	*AuthResponse
}
|
|
|
|
|
|
// callbackThirdParty handles the provider's redirect back after authorization.
//
// It first validates and consumes the one-time state created by
// authorizeThirdParty, then exchanges the authorization code for the
// provider's user profile. What happens next depends on the mode:
//   - bind mode: links the identity to the user recorded in the state;
//   - login mode, identity already linked: logs that user in;
//   - login mode, identity unknown: stores the profile under a temporary bind
//     token and returns NeedBind with a bind-page URL carrying that token.
func callbackThirdParty(x *vigo.X, req *CallbackRequest) (*CallbackResponse, error) {
	if req.Error != "" {
		return nil, vigo.ErrInvalidArg.WithString("oauth error: " + req.Error)
	}
	if req.Code == "" || req.State == "" {
		return nil, vigo.ErrInvalidArg.WithString("missing code or state")
	}

	// The state is single-use: verifyState removes it from the cache.
	stateData, err := verifyState(req.State)
	if err != nil {
		return nil, vigo.ErrInvalidArg.WithString("invalid or expired state")
	}

	// Checked assertions: a malformed cache entry must yield an error, not a panic.
	provider, _ := stateData["provider"].(string)
	if provider == "" {
		return nil, vigo.ErrInvalidArg.WithString("invalid or expired state")
	}
	// The provider in the callback path must be the one the flow was started for.
	if req.Provider != "" && !strings.EqualFold(req.Provider, provider) {
		return nil, vigo.ErrInvalidArg.WithString("provider does not match state")
	}
	bindMode, _ := stateData["bind_mode"].(bool)
	bindUserID, _ := stateData["user_id"].(string)
	if bindMode && bindUserID == "" {
		return nil, vigo.ErrInvalidArg.WithString("invalid or expired state")
	}

	// Exchange the code for an access token and fetch the user profile.
	userInfo, err := exchangeAndGetUserInfo(provider, req.Code)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	// Bind mode: attach the third-party identity to the user who started the flow.
	if bindMode {
		if err := bindIdentity(bindUserID, provider, userInfo); err != nil {
			return nil, err
		}
		return &CallbackResponse{
			BindMode: true,
			Bound:    true,
		}, nil
	}

	// Login mode: an already-linked identity logs its user straight in.
	identity, err := findIdentity(provider, userInfo.ID)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	if identity != nil {
		return loginByIdentity(x, identity)
	}

	// Not linked yet: issue a temporary token for the follow-up bind/register step.
	tempToken, err := generateTempBindToken(provider, userInfo)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	return &CallbackResponse{
		NeedBind:   true,
		Provider:   provider,
		ProviderID: userInfo.ID,
		AuthURL:    "/auth/bind?token=" + tempToken,
	}, nil
}
|
|
|
|
|
|
// BindRequest is the input of bindThirdParty. It links a pending third-party
// identity to an existing local account after that account's password is
// verified.
type BindRequest struct {
	// TempToken is the single-use bind token issued by callbackThirdParty.
	TempToken string `json:"temp_token" src:"json" desc:"临时绑定令牌"`
	// Username identifies the local account; it is matched against the
	// username, email and phone columns.
	Username string `json:"username" src:"json" desc:"用户名/邮箱/手机号"`
	// Password is the local account's password.
	Password string `json:"password" src:"json" desc:"密码"`
}
|
|
|
|
|
|
// bindThirdParty links a pending third-party identity (identified by a
// temporary bind token) to an existing local account, then logs that account in.
//
// The account is looked up by username, email or phone and authenticated with
// its password. The bind token is single-use: it is consumed as soon as it is
// verified, even if the password check then fails.
//
// Errors:
//   - ErrInvalidArg for missing credentials or a bad or expired token;
//   - ErrUnauthorized for an unknown account or a wrong password;
//   - ErrForbidden for a disabled account;
//   - ErrInternalServer for database failures.
func bindThirdParty(x *vigo.X, req *BindRequest) (*AuthResponse, error) {
	// Reject malformed input before verifyTempBindToken consumes the token.
	if req.Username == "" || req.Password == "" {
		return nil, vigo.ErrInvalidArg.WithString("username and password are required")
	}

	userInfo, err := verifyTempBindToken(req.TempToken)
	if err != nil {
		return nil, vigo.ErrInvalidArg.WithString("invalid or expired token")
	}

	var user models.User
	err = cfg.DB().
		Where("username = ? OR email = ? OR phone = ?", req.Username, req.Username, req.Username).
		First(&user).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, vigo.ErrUnauthorized.WithString("invalid credentials")
	}
	if err != nil {
		// A database outage is not a credential problem.
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	if !crypto.VerifyPassword(req.Password, user.Password) {
		return nil, vigo.ErrUnauthorized.WithString("invalid credentials")
	}

	if user.Status != models.UserStatusActive {
		return nil, vigo.ErrForbidden.WithString("user is disabled")
	}

	if err := bindIdentity(user.ID, userInfo.Provider, userInfo); err != nil {
		return nil, err
	}

	return generateAuthResponse(x, &user)
}
|
|
|
|
|
|
// BindWithRegisterRequest is the input of bindWithRegister (optional feature).
// It creates a new local account and links the pending third-party identity
// to it.
type BindWithRegisterRequest struct {
	// TempToken is the single-use bind token issued by callbackThirdParty.
	TempToken string `json:"temp_token" src:"json" desc:"临时绑定令牌"`
	// Username for the new account; checked for uniqueness.
	Username string `json:"username" src:"json" desc:"用户名"`
	// Email is optional; when set it is checked for uniqueness.
	Email string `json:"email" src:"json" desc:"邮箱"`
	// Phone is optional.
	Phone string `json:"phone" src:"json" desc:"手机号"`
}
|
|
|
|
|
|
// bindWithRegister 绑定并创建新账号
|
|
|
func bindWithRegister(x *vigo.X, req *BindWithRegisterRequest) (*AuthResponse, error) {
|
|
|
// 验证临时token
|
|
|
userInfo, err := verifyTempBindToken(req.TempToken)
|
|
|
if err != nil {
|
|
|
return nil, vigo.ErrInvalidArg.WithString("invalid or expired token")
|
|
|
}
|
|
|
|
|
|
// 检查用户名是否已存在
|
|
|
var count int64
|
|
|
cfg.DB().Model(&models.User{}).Where("username = ?", req.Username).Count(&count)
|
|
|
if count > 0 {
|
|
|
return nil, vigo.ErrInvalidArg.WithString("username already exists")
|
|
|
}
|
|
|
|
|
|
// 检查邮箱是否已存在
|
|
|
if req.Email != "" {
|
|
|
cfg.DB().Model(&models.User{}).Where("email = ?", req.Email).Count(&count)
|
|
|
if count > 0 {
|
|
|
return nil, vigo.ErrInvalidArg.WithString("email already exists")
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 创建用户(随机密码,需要后续设置)
|
|
|
randomPassword := generateRandomPassword(16)
|
|
|
hashedPassword, _ := crypto.HashPassword(randomPassword, 12)
|
|
|
|
|
|
var email *string
|
|
|
if req.Email != "" {
|
|
|
email = &req.Email
|
|
|
}
|
|
|
var phone *string
|
|
|
if req.Phone != "" {
|
|
|
phone = &req.Phone
|
|
|
}
|
|
|
|
|
|
user := &models.User{
|
|
|
Username: req.Username,
|
|
|
Password: hashedPassword,
|
|
|
Email: email,
|
|
|
Phone: phone,
|
|
|
Nickname: userInfo.Name,
|
|
|
Avatar: userInfo.Avatar,
|
|
|
Status: models.UserStatusActive,
|
|
|
}
|
|
|
|
|
|
if user.Nickname == "" {
|
|
|
user.Nickname = req.Username
|
|
|
}
|
|
|
|
|
|
if err := cfg.DB().Create(user).Error; err != nil {
|
|
|
return nil, vigo.ErrInternalServer.WithError(err)
|
|
|
}
|
|
|
|
|
|
// 授予默认角色 "user"
|
|
|
if err := baseauth.VBaseAuth.GrantRole(x.Context(), user.ID, "", "user"); err != nil {
|
|
|
// 记录错误但允许流程继续
|
|
|
}
|
|
|
|
|
|
// 绑定第三方身份
|
|
|
if err := bindIdentity(user.ID, userInfo.Provider, userInfo); err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
// 生成登录token
|
|
|
return generateAuthResponse(x, user)
|
|
|
}
|
|
|
|
|
|
// UnbindRequest is the input of unbindThirdParty.
type UnbindRequest struct {
	// Provider is the provider code taken from the URL path.
	Provider string `json:"provider" src:"path" desc:"提供商"`
}
|
|
|
|
|
|
// unbindThirdParty removes the current user's binding(s) for the given provider.
//
// Unbinding a provider that is not bound is not an error.
//
// Errors:
//   - ErrUnauthorized without a logged-in user;
//   - ErrInvalidArg for an empty provider;
//   - ErrInternalServer on database failure.
func unbindThirdParty(x *vigo.X, req *UnbindRequest) error {
	userID := baseauth.VBaseAuth.UserID(x)
	if userID == "" {
		return vigo.ErrUnauthorized
	}

	// Identities are stored under the lower-cased provider code, because
	// authorizeThirdParty lower-cases it. Normalize the same way here, or a
	// request for "GitHub" would silently delete nothing.
	provider := strings.ToLower(strings.TrimSpace(req.Provider))
	if provider == "" {
		return vigo.ErrInvalidArg.WithString("provider is required")
	}

	if err := cfg.DB().Where("user_id = ? AND provider = ?", userID, provider).Delete(&models.Identity{}).Error; err != nil {
		return vigo.ErrInternalServer.WithError(err)
	}

	return nil
}
|
|
|
|
|
|
// BindingInfo describes one third-party identity bound to the current user, as
// returned by listBindings.
type BindingInfo struct {
	// Provider is the provider code.
	Provider string `json:"provider"`
	// ProviderName is the account's display name at the provider, taken from
	// models.Identity.ProviderName (filled from the profile name by bindIdentity).
	ProviderName string `json:"provider_name"`
	// Avatar is the account's avatar as reported by the provider.
	Avatar string `json:"avatar"`
	// Email is the account's email as reported by the provider (may be empty).
	Email string `json:"email"`
	// CreatedAt is the binding creation time, formatted "2006-01-02 15:04:05".
	CreatedAt string `json:"created_at"`
}
|
|
|
|
|
|
// listBindings 获取当前用户的第三方绑定列表
|
|
|
func listBindings(x *vigo.X) ([]BindingInfo, error) {
|
|
|
userID := baseauth.VBaseAuth.UserID(x)
|
|
|
if userID == "" {
|
|
|
return nil, vigo.ErrUnauthorized
|
|
|
}
|
|
|
|
|
|
var identities []models.Identity
|
|
|
if err := cfg.DB().Where("user_id = ?", userID).Find(&identities).Error; err != nil {
|
|
|
return nil, vigo.ErrInternalServer.WithError(err)
|
|
|
}
|
|
|
|
|
|
result := make([]BindingInfo, 0, len(identities))
|
|
|
for _, id := range identities {
|
|
|
result = append(result, BindingInfo{
|
|
|
Provider: id.Provider,
|
|
|
ProviderName: id.ProviderName,
|
|
|
Avatar: id.Avatar,
|
|
|
Email: id.Email,
|
|
|
CreatedAt: id.CreatedAt.Format("2006-01-02 15:04:05"),
|
|
|
})
|
|
|
}
|
|
|
|
|
|
return result, nil
|
|
|
}
|
|
|
|
|
|
// ==================== Helpers ====================
|
|
|
|
|
|
// ThirdPartyUserInfo is a user profile fetched from a third-party provider and
// normalized to a few common fields.
type ThirdPartyUserInfo struct {
	// Provider is the provider code. ID is the provider-side user ID
	// (exchangeGeneric never returns an empty one). Name, Email and Avatar may
	// be empty.
	Provider, ID, Name, Email, Avatar string

	// Raw is the decoded user-info JSON exactly as received from the provider.
	Raw map[string]any
}
|
|
|
|
|
|
// generateState returns a fresh 128-bit random OAuth state as 32 hex characters.
//
// It panics if the system CSPRNG fails. Silently continuing would hand out an
// all-zero, predictable state and defeat the CSRF protection the state
// provides. (Since Go 1.24 crypto/rand.Read never returns an error and aborts
// the program itself.)
func generateState() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic("auth: crypto/rand failed: " + err.Error())
	}
	return hex.EncodeToString(b)
}
|
|
|
|
|
|
func saveState(state string, data map[string]any) error {
|
|
|
key := "oauth:state:" + state
|
|
|
return cache.SetObject(key, data, 10*time.Minute)
|
|
|
}
|
|
|
|
|
|
// verifyState loads the flow data stored by saveState for state and removes it
// from the cache, so each state can be used at most once.
//
// It returns an error if the state is empty, unknown or expired, or if the
// cached payload is empty.
func verifyState(state string) (map[string]any, error) {
	if state == "" {
		return nil, errors.New("empty oauth state")
	}

	key := "oauth:state:" + state
	var data map[string]any
	if err := cache.GetObject(key, &data); err != nil {
		return nil, err
	}

	// One-time use: drop the state immediately.
	// NOTE(review): Get followed by Delete is not atomic; two concurrent callbacks
	// with the same state could both pass. An atomic get-and-delete would close this.
	cache.Delete(key)

	// A nil payload would make callers' map lookups yield nothing useful.
	if data == nil {
		return nil, errors.New("empty oauth state payload")
	}
	return data, nil
}
|
|
|
|
|
|
// buildAuthURL returns the authorization-endpoint URL for the enabled provider
// providerCode. The URL carries the client ID, redirect_uri,
// response_type=code, the state, and the space-separated scopes when any are
// configured.
//
// Query parameters already present in provider.AuthURL are preserved, as is
// any fragment. The redirect_uri is provider.RedirectURI when configured, and
// otherwise getRedirectURI(providerCode). exchangeGeneric applies the same
// rule, and it must: the token request has to repeat the exact redirect_uri
// used here (RFC 6749 §4.1.3).
//
// Errors: ErrInvalidArg for an unknown or disabled provider; ErrInternalServer
// for an unparsable configured AuthURL.
func buildAuthURL(providerCode, state string) (string, error) {
	var provider models.OAuthProvider
	if err := cfg.DB().Where("code = ? AND enabled = ?", providerCode, true).First(&provider).Error; err != nil {
		return "", vigo.ErrInvalidArg.WithString("provider not found or not enabled: " + providerCode)
	}

	authURL, err := url.Parse(provider.AuthURL)
	if err != nil {
		return "", vigo.ErrInternalServer.WithError(fmt.Errorf("invalid auth url for provider %s: %w", providerCode, err))
	}

	redirectURI := provider.RedirectURI
	if redirectURI == "" {
		redirectURI = getRedirectURI(providerCode)
	}

	params := authURL.Query()

	// Some providers name the client-ID parameter differently
	// (WeChat uses "appid"); ExtraConfig["appid_param"] overrides it.
	clientIDParam := "client_id"
	if name := provider.ExtraConfig["appid_param"]; name != "" {
		clientIDParam = name
	}
	params.Set(clientIDParam, provider.ClientID)
	params.Set("redirect_uri", redirectURI)
	params.Set("response_type", "code")
	params.Set("state", state)
	if len(provider.Scopes) > 0 {
		params.Set("scope", strings.Join(provider.Scopes, " "))
	}

	authURL.RawQuery = params.Encode()
	return authURL.String(), nil
}
|
|
|
|
|
|
// getRedirectURI returns the default OAuth callback path for a provider code,
// e.g. "github" -> "/auth/callback/github".
//
// NOTE(review): this is a relative path; most providers require an absolute,
// pre-registered redirect URI. Confirm whether a host prefix is added elsewhere
// or whether providers are expected to set RedirectURI explicitly.
func getRedirectURI(provider string) string {
	const callbackPrefix = "/auth/callback/"
	return callbackPrefix + provider
}
|
|
|
|
|
|
// exchangeAndGetUserInfo loads the enabled provider providerCode from the
// database and runs the generic OAuth code exchange plus user-info fetch
// against it.
//
// It returns an error mentioning the code if the provider is unknown or disabled.
func exchangeAndGetUserInfo(providerCode, code string) (*ThirdPartyUserInfo, error) {
	var p models.OAuthProvider
	res := cfg.DB().Where("code = ? AND enabled = ?", providerCode, true).First(&p)
	if res.Error != nil {
		return nil, fmt.Errorf("provider not found: %s", providerCode)
	}
	// Every provider goes through the same configurable OAuth flow.
	return exchangeGeneric(p, code)
}
|
|
|
|
|
|
// exchangeGeneric runs the configurable OAuth2 authorization-code flow for a
// provider:
//  1. It exchanges code for an access token at provider.TokenURL. This is a
//     form POST when ExtraConfig["use_post_token"] == "true", and otherwise a
//     GET with query parameters.
//  2. It fetches provider.UserInfoURL using that token.
//  3. It maps the JSON profile onto ThirdPartyUserInfo via the configured
//     User*Path fields. Dotted paths are supported, and the defaults are
//     "id", "name", "email" and "avatar".
//
// Optional ExtraConfig keys:
//   - "appid_param": name of the client-ID parameter (default "client_id"),
//     the same key buildAuthURL honours;
//   - "secret_param": name of the client-secret parameter
//     (default "client_secret").
//
// It returns an error if the token exchange fails, the provider reports an
// error or no access token, the user-info request fails, or no user ID can be
// extracted.
func exchangeGeneric(provider models.OAuthProvider, code string) (*ThirdPartyUserInfo, error) {
	// redirect_uri must equal the one sent in the authorization request
	// (RFC 6749 §4.1.3).
	redirectURI := provider.RedirectURI
	if redirectURI == "" {
		redirectURI = getRedirectURI(provider.Code)
	}

	clientSecret := provider.ClientSecret
	if clientSecret != "" {
		// Secrets are normally stored encrypted. If decryption fails, the row is
		// assumed to be legacy plaintext and the raw value is used as-is.
		if decrypted, err := cfg.Global.Key.Decrypt(clientSecret); err == nil {
			clientSecret = decrypted
		}
	}

	clientIDParam := "client_id"
	if name := provider.ExtraConfig["appid_param"]; name != "" {
		clientIDParam = name
	}
	secretParam := "client_secret"
	if name := provider.ExtraConfig["secret_param"]; name != "" {
		secretParam = name
	}

	form := url.Values{}
	form.Set("code", code)
	form.Set(clientIDParam, provider.ClientID)
	form.Set(secretParam, clientSecret)
	form.Set("redirect_uri", redirectURI)
	form.Set("grant_type", "authorization_code")

	// expires_in is deliberately not decoded. It is unused here, and some
	// providers send it as a string, which would fail decoding into an int.
	var tokenResp struct {
		AccessToken string `json:"access_token"`
		TokenType   string `json:"token_type"`
		Error       string `json:"error"`
		ErrorDesc   string `json:"error_description"`
	}

	var err error
	if provider.ExtraConfig["use_post_token"] == "true" {
		err = postFormJSON(provider.TokenURL, form, &tokenResp)
	} else {
		sep := "?"
		if strings.Contains(provider.TokenURL, "?") {
			sep = "&"
		}
		err = getJSON(provider.TokenURL+sep+form.Encode(), &tokenResp)
	}
	if err != nil {
		return nil, fmt.Errorf("exchange token failed: %w", err)
	}
	if tokenResp.Error != "" {
		return nil, fmt.Errorf("oauth error: %s - %s", tokenResp.Error, tokenResp.ErrorDesc)
	}
	if tokenResp.AccessToken == "" {
		return nil, fmt.Errorf("exchange token failed: no access_token in response")
	}

	var rawUser map[string]any
	if err := getJSONWithAuth(provider.UserInfoURL, tokenResp.AccessToken, tokenResp.TokenType, &rawUser); err != nil {
		return nil, fmt.Errorf("get user info failed: %w", err)
	}

	userInfo := &ThirdPartyUserInfo{
		Provider: provider.Code,
		ID:       extractField(rawUser, provider.UserIDPath, "id"),
		Name:     extractField(rawUser, provider.UserNamePath, "name"),
		Email:    extractField(rawUser, provider.UserEmailPath, "email"),
		Avatar:   extractField(rawUser, provider.UserAvatarPath, "avatar"),
		Raw:      rawUser,
	}
	if userInfo.ID == "" {
		return nil, fmt.Errorf("failed to extract user id from response")
	}

	return userInfo, nil
}
|
|
|
|
|
|
// extractField reads a value from decoded JSON data and returns it as a string.
//
// path is a dot-separated key path (e.g. "data.employee_id"); when empty,
// defaultKey is used instead. Conversion rules:
//   - strings are returned as-is;
//   - float64 numbers are formatted without decimals (so integer IDs keep
//     their digits);
//   - json.Number values use their literal text;
//   - a missing key, a JSON null, or a non-object intermediate gives "";
//   - any other type is formatted with %v.
func extractField(data map[string]any, path, defaultKey string) string {
	if path == "" {
		path = defaultKey
	}
	if path == "" {
		return ""
	}

	keys := strings.Split(path, ".")
	current := data

	// Walk every key but the last; each must lead to a nested object.
	for _, key := range keys[:len(keys)-1] {
		next, ok := current[key].(map[string]any)
		if !ok {
			return ""
		}
		current = next
	}

	switch v := current[keys[len(keys)-1]].(type) {
	case nil:
		// Missing key or explicit JSON null (e.g. GitHub's private email).
		// Previously this produced the literal string "<nil>".
		return ""
	case string:
		return v
	case float64:
		return fmt.Sprintf("%.0f", v)
	case json.Number:
		return v.String()
	case int:
		return fmt.Sprintf("%d", v)
	default:
		return fmt.Sprintf("%v", v)
	}
}
|
|
|
|
|
|
// postFormJSON POSTs data as an application/x-www-form-urlencoded body to
// rawURL and decodes the JSON response into result.
//
// The request sends Accept: application/json because some OAuth token
// endpoints (notably GitHub's) reply form-encoded otherwise. The whole request
// is bounded by a 15-second timeout. A non-2xx response becomes an error that
// carries the status and up to 4 KiB of the body.
func postFormJSON(rawURL string, data url.Values, result any) error {
	req, err := http.NewRequest(http.MethodPost, rawURL, strings.NewReader(data.Encode()))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	// A zero Transport means http.DefaultTransport, so connections are still pooled.
	client := &http.Client{Timeout: 15 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		// The body snippet is best-effort context for the error message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		return fmt.Errorf("unexpected status %s: %s", resp.Status, strings.TrimSpace(string(body)))
	}

	return json.NewDecoder(resp.Body).Decode(result)
}
|
|
|
|
|
|
// getJSON issues a GET request to rawURL and decodes the JSON response into result.
//
// The request sends Accept: application/json because some OAuth token
// endpoints (notably GitHub's) reply form-encoded otherwise. The whole request
// is bounded by a 15-second timeout. A non-2xx response becomes an error that
// carries the status and up to 4 KiB of the body.
func getJSON(rawURL string, result any) error {
	req, err := http.NewRequest(http.MethodGet, rawURL, nil)
	if err != nil {
		return err
	}
	req.Header.Set("Accept", "application/json")

	// A zero Transport means http.DefaultTransport, so connections are still pooled.
	client := &http.Client{Timeout: 15 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		// The body snippet is best-effort context for the error message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		return fmt.Errorf("unexpected status %s: %s", resp.Status, strings.TrimSpace(string(body)))
	}

	return json.NewDecoder(resp.Body).Decode(result)
}
|
|
|
|
|
|
// getJSONWithAuth issues an authenticated GET request to rawURL and decodes
// the JSON response into result.
//
// The Authorization header is "<scheme> <token>". The scheme is "Bearer" when
// tokenType is empty or any casing of "bearer"; otherwise tokenType is used
// verbatim (e.g. GitHub's legacy "token").
//
// The request also sends Accept: application/json and is bounded by a
// 15-second timeout. A non-2xx response becomes an error that carries the
// status and up to 4 KiB of the body.
func getJSONWithAuth(rawURL, token, tokenType string, result any) error {
	req, err := http.NewRequest(http.MethodGet, rawURL, nil)
	if err != nil {
		return err
	}

	// token_type is case-insensitive (RFC 6749 §5.1), and several providers,
	// GitHub included, return lower-case "bearer"; send the canonical form.
	scheme := tokenType
	if scheme == "" || strings.EqualFold(scheme, "bearer") {
		scheme = "Bearer"
	}
	req.Header.Set("Authorization", scheme+" "+token)
	req.Header.Set("Accept", "application/json")

	// A zero Transport means http.DefaultTransport, so connections are still pooled.
	client := &http.Client{Timeout: 15 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		// The body snippet is best-effort context for the error message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		return fmt.Errorf("unexpected status %s: %s", resp.Status, strings.TrimSpace(string(body)))
	}

	return json.NewDecoder(resp.Body).Decode(result)
}
|
|
|
|
|
|
// findIdentity returns the identity bound to (provider, providerUID).
//
// It returns (nil, nil) when no such binding exists, and (nil, err) on any
// other database error.
func findIdentity(provider, providerUID string) (*models.Identity, error) {
	var identity models.Identity
	err := cfg.DB().Where("provider = ? AND provider_uid = ?", provider, providerUID).First(&identity).Error
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		// errors.Is also matches a NotFound error that GORM has joined with
		// another error; == would miss that case.
		return nil, nil
	case err != nil:
		return nil, err
	}
	return &identity, nil
}
|
|
|
|
|
|
// bindIdentity links the third-party account (provider, userInfo.ID) to userID.
//
// If the account is already linked to userID, the stored name, avatar and
// email are refreshed from userInfo. Otherwise a new identity row is created.
//
// Errors:
//   - ErrInvalidArg when userInfo or its ID is missing;
//   - ErrForbidden when the account is linked to a different user;
//   - ErrInternalServer when the existence lookup itself fails;
//   - the raw database error when saving or creating fails.
func bindIdentity(userID, provider string, userInfo *ThirdPartyUserInfo) error {
	if userInfo == nil || userInfo.ID == "" {
		return vigo.ErrInvalidArg.WithString("missing third-party user id")
	}

	var existing models.Identity
	err := cfg.DB().Where("provider = ? AND provider_uid = ?", provider, userInfo.ID).First(&existing).Error
	switch {
	case err == nil:
		if existing.UserID != userID {
			return vigo.ErrForbidden.WithString("this account is already bound to another user")
		}
		// Already linked to this user: refresh the profile snapshot.
		existing.ProviderName = userInfo.Name
		existing.Avatar = userInfo.Avatar
		existing.Email = userInfo.Email
		return cfg.DB().Save(&existing).Error
	case !errors.Is(err, gorm.ErrRecordNotFound):
		// A real database failure must not be mistaken for "not bound yet".
		return vigo.ErrInternalServer.WithError(err)
	}

	identity := &models.Identity{
		UserID:       userID,
		Provider:     provider,
		ProviderUID:  userInfo.ID,
		ProviderName: userInfo.Name,
		Avatar:       userInfo.Avatar,
		Email:        userInfo.Email,
	}
	return cfg.DB().Create(identity).Error
}
|
|
|
|
|
|
// loginByIdentity logs in the user that owns identity and wraps the issued
// tokens in a CallbackResponse.
//
// Errors:
//   - ErrNotFound when the user no longer exists;
//   - ErrForbidden when the user is not active;
//   - ErrInternalServer on database or token failures.
func loginByIdentity(x *vigo.X, identity *models.Identity) (*CallbackResponse, error) {
	var user models.User
	if err := cfg.DB().First(&user, "id = ?", identity.UserID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, vigo.ErrNotFound.WithString("user not found")
		}
		// Do not report a database outage as a missing user.
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	if user.Status != models.UserStatusActive {
		return nil, vigo.ErrForbidden.WithString("user is disabled")
	}

	authResp, err := generateAuthResponse(x, &user)
	if err != nil {
		return nil, err
	}

	return &CallbackResponse{
		AuthResponse: authResp,
	}, nil
}
|
|
|
|
|
|
func generateAuthResponse(x *vigo.X, user *models.User) (*AuthResponse, error) {
|
|
|
// 获取用户的组织信息
|
|
|
orgs, err := getUserOrgs(user.ID)
|
|
|
if err != nil {
|
|
|
return nil, vigo.ErrInternalServer.WithError(err)
|
|
|
}
|
|
|
|
|
|
orgClaims := make([]jwt.OrgClaim, 0, len(orgs))
|
|
|
for _, org := range orgs {
|
|
|
orgClaims = append(orgClaims, jwt.OrgClaim{
|
|
|
OrgID: org.OrgID,
|
|
|
Code: org.Code,
|
|
|
Name: org.Name,
|
|
|
Roles: org.Roles,
|
|
|
Status: org.Status,
|
|
|
})
|
|
|
}
|
|
|
|
|
|
tokenPair, err := jwt.GenerateTokenPair(
|
|
|
user.ID,
|
|
|
user.Username,
|
|
|
user.Nickname,
|
|
|
user.Avatar,
|
|
|
func() string {
|
|
|
if user.Email != nil {
|
|
|
return *user.Email
|
|
|
}
|
|
|
return ""
|
|
|
}(),
|
|
|
orgClaims,
|
|
|
)
|
|
|
if err != nil {
|
|
|
return nil, vigo.ErrInternalServer.WithError(err)
|
|
|
}
|
|
|
|
|
|
return &AuthResponse{
|
|
|
AccessToken: tokenPair.AccessToken,
|
|
|
RefreshToken: tokenPair.RefreshToken,
|
|
|
TokenType: tokenPair.TokenType,
|
|
|
ExpiresIn: tokenPair.ExpiresIn,
|
|
|
User: &UserInfo{
|
|
|
ID: user.ID,
|
|
|
Username: user.Username,
|
|
|
Nickname: user.Nickname,
|
|
|
Email: user.Email,
|
|
|
Avatar: user.Avatar,
|
|
|
},
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
// generateTempBindToken caches a third-party profile under a fresh random
// token for 10 minutes and returns that token. The cache key is
// "oauth:bind:<token>".
//
// The token is later redeemed (once) by verifyTempBindToken, when the user
// binds an existing account or registers a new one.
func generateTempBindToken(provider string, userInfo *ThirdPartyUserInfo) (string, error) {
	payload := map[string]any{
		"provider":    provider,
		"provider_id": userInfo.ID,
		"name":        userInfo.Name,
		"email":       userInfo.Email,
		"avatar":      userInfo.Avatar,
		"raw":         userInfo.Raw,
	}

	token := generateRandomToken(32)
	if err := cache.SetObject("oauth:bind:"+token, payload, 10*time.Minute); err != nil {
		return "", err
	}
	return token, nil
}
|
|
|
|
|
|
// verifyTempBindToken redeems a temporary bind token created by
// generateTempBindToken and returns the third-party profile stored with it.
//
// The token is removed from the cache on successful lookup, so it is
// single-use. It returns an error if the token is empty, unknown or expired,
// or if the cached payload is malformed (missing provider or provider_id).
func verifyTempBindToken(token string) (*ThirdPartyUserInfo, error) {
	if token == "" {
		return nil, errors.New("empty bind token")
	}

	key := "oauth:bind:" + token
	var data map[string]any
	if err := cache.GetObject(key, &data); err != nil {
		return nil, err
	}
	// One-time use: drop the token immediately.
	cache.Delete(key)

	// Checked assertions: a malformed cache entry must yield an error, not a panic.
	str := func(k string) string {
		s, _ := data[k].(string)
		return s
	}
	info := &ThirdPartyUserInfo{
		Provider: str("provider"),
		ID:       str("provider_id"),
		Name:     str("name"),
		Email:    str("email"),
		Avatar:   str("avatar"),
	}
	info.Raw, _ = data["raw"].(map[string]any)

	if info.Provider == "" || info.ID == "" {
		return nil, errors.New("malformed bind token payload")
	}
	return info, nil
}
|
|
|
|
|
|
// generateRandomPassword returns a random password of length characters drawn
// uniformly from letters, digits and "!@#$%^&*". A length <= 0 yields "".
//
// Bytes from crypto/rand are filtered by rejection sampling, so every
// character is equally likely. A plain byte%len(charset) mapping would be
// biased, because 256 is not a multiple of 70. It panics if the system CSPRNG
// fails rather than returning a predictable password.
func generateRandomPassword(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*"
	// Largest multiple of len(charset) not above 256; bytes >= limit are discarded.
	const limit = 256 - 256%len(charset)

	if length <= 0 {
		return ""
	}

	out := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(out) < length {
		if _, err := rand.Read(buf); err != nil {
			panic("auth: crypto/rand failed: " + err.Error())
		}
		for _, c := range buf {
			if int(c) >= limit {
				continue
			}
			out = append(out, charset[int(c)%len(charset)])
			if len(out) == length {
				break
			}
		}
	}
	return string(out)
}
|
|
|
|
|
|
// generateRandomToken returns length cryptographically random bytes, encoded
// as 2*length hex characters.
//
// It panics if the system CSPRNG fails. A zero-filled buffer would produce a
// predictable bind token that anyone could redeem.
func generateRandomToken(length int) string {
	b := make([]byte, length)
	if _, err := rand.Read(b); err != nil {
		panic("auth: crypto/rand failed: " + err.Error())
	}
	return hex.EncodeToString(b)
}
|