|
|
//
|
|
|
// Copyright (C) 2024 veypi <i@veypi.com>
|
|
|
// 2025-03-04 16:08:06
|
|
|
// Distributed under terms of the MIT license.
|
|
|
//
|
|
|
|
|
|
package auth
|
|
|
|
|
|
import (
|
|
|
"crypto/rand"
|
|
|
"encoding/hex"
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
"net/http"
|
|
|
"net/url"
|
|
|
"strings"
|
|
|
"time"
|
|
|
|
|
|
baseauth "github.com/veypi/vbase/auth"
|
|
|
"github.com/veypi/vbase/cfg"
|
|
|
"github.com/veypi/vbase/libs/cache"
|
|
|
"github.com/veypi/vbase/libs/crypto"
|
|
|
"github.com/veypi/vbase/libs/jwt"
|
|
|
"github.com/veypi/vbase/models"
|
|
|
"github.com/veypi/vigo"
|
|
|
"gorm.io/gorm"
|
|
|
)
|
|
|
|
|
|
// ProviderInfo describes one third-party login provider as reported to clients
// by listProviders.
type ProviderInfo struct {
	// Name is the provider key ("google", "github" or "wechat").
	Name string `json:"name"`
	// DisplayName is the human-readable label for the UI.
	DisplayName string `json:"display_name"`
	// Icon is an icon identifier for the frontend.
	Icon string `json:"icon"`
	// Enabled mirrors the provider's Enabled flag in configuration.
	Enabled bool `json:"enabled"`
}
|
|
|
|
|
|
// listProviders 获取支持的第三方登录提供商列表
|
|
|
func listProviders(x *vigo.X) ([]ProviderInfo, error) {
|
|
|
providers := []ProviderInfo{
|
|
|
{Name: "google", DisplayName: "Google", Icon: "google", Enabled: cfg.Config.Providers.Google.Enabled},
|
|
|
{Name: "github", DisplayName: "GitHub", Icon: "github", Enabled: cfg.Config.Providers.GitHub.Enabled},
|
|
|
{Name: "wechat", DisplayName: "微信", Icon: "wechat", Enabled: cfg.Config.Providers.WeChat.Enabled},
|
|
|
}
|
|
|
return providers, nil
|
|
|
}
|
|
|
|
|
|
// AuthorizeRequest is the query input for starting a third-party login or
// account-binding flow.
type AuthorizeRequest struct {
	// Provider selects the OAuth provider; authorizeThirdParty lower-cases it.
	Provider string `json:"provider" src:"query" desc:"提供商: google/github/wechat"`
	// Redirect is the intended post-login redirect address; it is stored with
	// the OAuth state.
	Redirect string `json:"redirect" src:"query" desc:"登录成功后重定向地址"`
	// BindMode links the third-party account to the currently logged-in user
	// instead of logging in; it requires an authenticated caller.
	BindMode bool `json:"bind_mode" src:"query" desc:"是否为绑定模式"`
}
|
|
|
|
|
|
// AuthorizeResponse is returned by authorizeThirdParty.
type AuthorizeResponse struct {
	// AuthURL is the provider authorization URL the client should open.
	AuthURL string `json:"auth_url"`
	// State is the one-time value that the provider echoes back on callback.
	State string `json:"state"`
}
|
|
|
|
|
|
// authorizeThirdParty starts a third-party OAuth flow and returns the URL the
// client should be sent to, together with the generated state value.
//
// A random state is cached for 10 minutes along with the provider, the
// requested redirect and, in bind mode, the current user's ID.
// callbackThirdParty consumes that entry to validate the provider callback.
//
// Errors:
//   - ErrInvalidArg for an unsupported or disabled provider;
//   - ErrUnauthorized when bind mode is requested without a logged-in user;
//   - ErrInternalServer when the state cannot be cached.
func authorizeThirdParty(x *vigo.X, req *AuthorizeRequest) (*AuthorizeResponse, error) {
	provider := strings.ToLower(strings.TrimSpace(req.Provider))
	state := generateState()

	// Build (and thereby validate) the authorization URL before persisting
	// anything. Requests for unsupported or disabled providers then never
	// leave orphaned state entries in the cache.
	authURL, err := buildAuthURL(provider, state)
	if err != nil {
		return nil, err
	}

	// NOTE(review): req.Redirect is stored unvalidated. Whatever eventually
	// consumes it must check it against an allow-list to avoid an open redirect.
	stateData := map[string]any{
		"provider":   provider,
		"redirect":   req.Redirect,
		"bind_mode":  req.BindMode,
		"created_at": time.Now().Unix(),
	}

	// Bind mode attaches the identity to the logged-in user, so one is required.
	if req.BindMode {
		userID := getCurrentUserID(x)
		if userID == "" {
			return nil, vigo.ErrUnauthorized.WithString("login required for bind mode")
		}
		stateData["user_id"] = userID
	}

	if err := saveState(state, stateData); err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	return &AuthorizeResponse{
		AuthURL: authURL,
		State:   state,
	}, nil
}
|
|
|
|
|
|
// CallbackRequest carries the parameters a provider sends back to the OAuth
// callback.
type CallbackRequest struct {
	// Provider is taken from the URL path.
	Provider string `json:"provider" src:"path" desc:"提供商"`
	// Code is the authorization code to exchange for an access token.
	Code string `json:"code" src:"query" desc:"授权码"`
	// State must match a value previously issued by authorizeThirdParty.
	State string `json:"state" src:"query" desc:"状态值"`
	// Error is set by the provider when authorization failed or was denied.
	Error string `json:"error" src:"query" desc:"错误信息"`
}
|
|
|
|
|
|
// CallbackResponse is the result of handling an OAuth callback. Which fields
// are populated depends on the outcome:
//   - bind mode: BindMode and Bound are true;
//   - login, third-party account not yet linked: NeedBind, Provider,
//     ProviderID and AuthURL (a path carrying a temporary bind token) are set;
//   - login, account already linked: the embedded AuthResponse holds the tokens.
type CallbackResponse struct {
	// Bind-mode result.
	BindMode bool `json:"bind_mode,omitempty"`
	Bound    bool `json:"bound,omitempty"`

	// Login mode, when the third-party account still needs to be linked.
	NeedBind   bool   `json:"need_bind,omitempty"`
	Provider   string `json:"provider,omitempty"`
	ProviderID string `json:"provider_id,omitempty"`
	AuthURL    string `json:"auth_url,omitempty"`

	// Successful login. The pointer is embedded, so encoding/json promotes its
	// fields into this object.
	*AuthResponse
}
|
|
|
|
|
|
// callbackThirdParty handles the redirect back from an OAuth provider.
//
// It consumes the one-time state issued by authorizeThirdParty and exchanges
// the authorization code for the provider's user profile. It then does one of:
//   - in bind mode, links that identity to the user recorded in the state;
//   - in login mode, logs in the already-linked user;
//   - otherwise, returns a temporary bind token the client can use to link
//     or register an account.
func callbackThirdParty(x *vigo.X, req *CallbackRequest) (*CallbackResponse, error) {
	if req.Error != "" {
		return nil, vigo.ErrInvalidArg.WithString("oauth error: " + req.Error)
	}
	if req.Code == "" || req.State == "" {
		return nil, vigo.ErrInvalidArg.WithString("missing code or state")
	}

	// verifyState removes the entry, so each state can be used only once.
	stateData, err := verifyState(req.State)
	if err != nil {
		return nil, vigo.ErrInvalidArg.WithString("invalid or expired state")
	}

	// Checked assertions: a malformed cache entry must yield a client error,
	// not a panic in the handler.
	provider, _ := stateData["provider"].(string)
	bindMode, _ := stateData["bind_mode"].(bool)
	if provider == "" {
		return nil, vigo.ErrInvalidArg.WithString("invalid or expired state")
	}
	// The provider named in the callback path must be the one the flow began with.
	if req.Provider != "" && !strings.EqualFold(req.Provider, provider) {
		return nil, vigo.ErrInvalidArg.WithString("provider mismatch")
	}

	userInfo, err := exchangeAndGetUserInfo(provider, req.Code)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	// An empty provider-side ID would make every such account resolve to the
	// same identity row, so it is never accepted.
	if userInfo.ID == "" {
		return nil, vigo.ErrInternalServer.WithError(fmt.Errorf("provider %s returned an empty user id", provider))
	}

	// Bind mode: link the third-party identity to the user recorded in the state.
	// NOTE(review): the target user comes from the state, not from whoever
	// completes this callback. Consider also requiring the caller to be that
	// user, to prevent account-linking CSRF.
	if bindMode {
		userID, _ := stateData["user_id"].(string)
		if userID == "" {
			return nil, vigo.ErrInvalidArg.WithString("invalid or expired state")
		}
		if err := bindIdentity(userID, provider, userInfo); err != nil {
			return nil, err
		}
		return &CallbackResponse{
			BindMode: true,
			Bound:    true,
		}, nil
	}

	// Login mode: an already-linked identity logs straight in.
	identity, err := findIdentity(provider, userInfo.ID)
	if err != nil {
		return nil, err
	}
	if identity != nil {
		return loginByIdentity(x, identity)
	}

	// Not linked yet: hand out a short-lived token for the bind/register step.
	tempToken, err := generateTempBindToken(provider, userInfo)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	return &CallbackResponse{
		NeedBind:   true,
		Provider:   provider,
		ProviderID: userInfo.ID,
		AuthURL:    "/auth/bind?token=" + tempToken,
	}, nil
}
|
|
|
|
|
|
// BindRequest links a pending third-party identity to an existing local
// account, which is authenticated by login name and password.
type BindRequest struct {
	// TempToken is the temporary bind token returned by the OAuth callback.
	TempToken string `json:"temp_token" src:"json" desc:"临时绑定令牌"`
	// Username is matched against the username, email and phone columns.
	Username string `json:"username" src:"json" desc:"用户名/邮箱/手机号"`
	// Password is the local account password.
	Password string `json:"password" src:"json" desc:"密码"`
}
|
|
|
|
|
|
// bindThirdParty 绑定第三方账号到已有账号
|
|
|
func bindThirdParty(x *vigo.X, req *BindRequest) (*AuthResponse, error) {
|
|
|
// 验证临时token
|
|
|
userInfo, err := verifyTempBindToken(req.TempToken)
|
|
|
if err != nil {
|
|
|
return nil, vigo.ErrInvalidArg.WithString("invalid or expired token")
|
|
|
}
|
|
|
|
|
|
// 查找用户
|
|
|
var user models.User
|
|
|
query := cfg.DB().Where("username = ? OR email = ? OR phone = ?", req.Username, req.Username, req.Username)
|
|
|
if err := query.First(&user).Error; err != nil {
|
|
|
return nil, vigo.ErrUnauthorized.WithString("invalid credentials")
|
|
|
}
|
|
|
|
|
|
// 验证密码
|
|
|
if !crypto.VerifyPassword(req.Password, user.Password) {
|
|
|
return nil, vigo.ErrUnauthorized.WithString("invalid credentials")
|
|
|
}
|
|
|
|
|
|
// 检查用户状态
|
|
|
if user.Status != models.UserStatusActive {
|
|
|
return nil, vigo.ErrForbidden.WithString("user is disabled")
|
|
|
}
|
|
|
|
|
|
// 绑定身份
|
|
|
if err := bindIdentity(user.ID, userInfo.Provider, userInfo); err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
// 生成登录token
|
|
|
return generateAuthResponse(x, &user)
|
|
|
}
|
|
|
|
|
|
// BindWithRegisterRequest creates a new local account and links the pending
// third-party identity to it (optional flow).
type BindWithRegisterRequest struct {
	// TempToken is the temporary bind token returned by the OAuth callback.
	TempToken string `json:"temp_token" src:"json" desc:"临时绑定令牌"`
	// Username for the new account; must be unique.
	Username string `json:"username" src:"json" desc:"用户名"`
	// Email is optional; when given it must be unique.
	Email string `json:"email" src:"json" desc:"邮箱"`
	// Phone is optional.
	Phone string `json:"phone" src:"json" desc:"手机号"`
}
|
|
|
|
|
|
// bindWithRegister creates a new local account for a pending third-party
// identity, links the identity to it, and logs the new account in.
//
// The account gets a random password the user must reset later. Nickname and
// avatar are taken from the provider profile, and the nickname falls back to
// the username.
//
// Errors:
//   - ErrInvalidArg for a missing username, an invalid token, or a duplicate
//     username/email/phone;
//   - ErrForbidden when the identity is already linked to another user;
//   - ErrInternalServer for storage or hashing failures.
func bindWithRegister(x *vigo.X, req *BindWithRegisterRequest) (*AuthResponse, error) {
	// Validate before consuming the single-use bind token.
	if strings.TrimSpace(req.Username) == "" {
		return nil, vigo.ErrInvalidArg.WithString("username is required")
	}

	userInfo, err := verifyTempBindToken(req.TempToken)
	if err != nil {
		return nil, vigo.ErrInvalidArg.WithString("invalid or expired token")
	}

	// Uniqueness checks. Optional fields are skipped when empty, and query
	// errors are surfaced instead of being treated as "no duplicates".
	uniqueChecks := []struct{ column, value, msg string }{
		{"username", req.Username, "username already exists"},
		{"email", req.Email, "email already exists"},
		{"phone", req.Phone, "phone already exists"},
	}
	for _, c := range uniqueChecks {
		if c.value == "" {
			continue
		}
		var count int64
		// c.column is one of the constants above, never user input.
		if err := cfg.DB().Model(&models.User{}).Where(c.column+" = ?", c.value).Count(&count).Error; err != nil {
			return nil, vigo.ErrInternalServer.WithError(err)
		}
		if count > 0 {
			return nil, vigo.ErrInvalidArg.WithString(c.msg)
		}
	}

	// Refuse before creating the user if this third-party account already
	// belongs to someone. Otherwise the new user row would be left orphaned
	// when bindIdentity rejects the link.
	existing, err := findIdentity(userInfo.Provider, userInfo.ID)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	if existing != nil {
		return nil, vigo.ErrForbidden.WithString("this account is already bound to another user")
	}

	// Random password; the user is expected to set a real one later.
	hashedPassword, err := crypto.HashPassword(generateRandomPassword(16), cfg.Config.Security.BcryptCost)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	var email *string
	if req.Email != "" {
		email = &req.Email
	}
	var phone *string
	if req.Phone != "" {
		phone = &req.Phone
	}

	user := &models.User{
		Username: req.Username,
		Password: hashedPassword,
		Email:    email,
		Phone:    phone,
		Nickname: userInfo.Name,
		Avatar:   userInfo.Avatar,
		Status:   models.UserStatusActive,
	}
	if user.Nickname == "" {
		user.Nickname = req.Username
	}

	if err := cfg.DB().Create(user).Error; err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	// Best effort: failing to grant the default "user" role does not block
	// registration.
	// NOTE(review): the error is dropped; consider logging it.
	_ = baseauth.VBaseAuth.GrantRole(x.Context(), user.ID, "", "user")

	if err := bindIdentity(user.ID, userInfo.Provider, userInfo); err != nil {
		return nil, err
	}

	return generateAuthResponse(x, user)
}
|
|
|
|
|
|
// UnbindRequest identifies which provider binding to remove.
type UnbindRequest struct {
	// Provider is taken from the URL path.
	Provider string `json:"provider" src:"path" desc:"提供商"`
}
|
|
|
|
|
|
// unbindThirdParty removes the current user's binding for req.Provider.
//
// The provider name is trimmed and lower-cased, matching how
// authorizeThirdParty normalises it before it is stored, so "GitHub" and
// "github" behave alike. Removing a binding that does not exist is not an
// error.
func unbindThirdParty(x *vigo.X, req *UnbindRequest) error {
	userID := getCurrentUserID(x)
	if userID == "" {
		return vigo.ErrUnauthorized
	}

	provider := strings.ToLower(strings.TrimSpace(req.Provider))
	if provider == "" {
		return vigo.ErrInvalidArg.WithString("provider is required")
	}

	if err := cfg.DB().Where("user_id = ? AND provider = ?", userID, provider).Delete(&models.Identity{}).Error; err != nil {
		return vigo.ErrInternalServer.WithError(err)
	}

	return nil
}
|
|
|
|
|
|
// BindingInfo describes one third-party identity linked to the current user.
type BindingInfo struct {
	// Provider is the provider key.
	Provider string `json:"provider"`
	// ProviderName is the display name reported by the provider.
	ProviderName string `json:"provider_name"`
	// Avatar is the avatar reported by the provider.
	Avatar string `json:"avatar"`
	// Email is the email reported by the provider (may be empty).
	Email string `json:"email"`
	// CreatedAt is formatted by listBindings as "2006-01-02 15:04:05".
	CreatedAt string `json:"created_at"`
}
|
|
|
|
|
|
// listBindings 获取当前用户的第三方绑定列表
|
|
|
func listBindings(x *vigo.X) ([]BindingInfo, error) {
|
|
|
userID := getCurrentUserID(x)
|
|
|
if userID == "" {
|
|
|
return nil, vigo.ErrUnauthorized
|
|
|
}
|
|
|
|
|
|
var identities []models.Identity
|
|
|
if err := cfg.DB().Where("user_id = ?", userID).Find(&identities).Error; err != nil {
|
|
|
return nil, vigo.ErrInternalServer.WithError(err)
|
|
|
}
|
|
|
|
|
|
result := make([]BindingInfo, 0, len(identities))
|
|
|
for _, id := range identities {
|
|
|
result = append(result, BindingInfo{
|
|
|
Provider: id.Provider,
|
|
|
ProviderName: id.ProviderName,
|
|
|
Avatar: id.Avatar,
|
|
|
Email: id.Email,
|
|
|
CreatedAt: id.CreatedAt.Format("2006-01-02 15:04:05"),
|
|
|
})
|
|
|
}
|
|
|
|
|
|
return result, nil
|
|
|
}
|
|
|
|
|
|
// ==================== Helpers ====================

// ThirdPartyUserInfo is the provider-independent profile produced by the
// exchange* functions.
type ThirdPartyUserInfo struct {
	// Provider is the provider key ("google", "github" or "wechat").
	Provider string
	// ID is the provider-side user identifier, stored as Identity.ProviderUID.
	ID string
	// Name is the provider display name.
	Name string
	// Email may be empty (the WeChat exchange never sets it).
	Email string
	// Avatar is the profile picture URL.
	Avatar string
	// Raw is carried through the temporary bind token; none of the exchange
	// functions populate it.
	Raw map[string]any
}
|
|
|
|
|
|
// generateState returns a 32-character hex string built from 16 bytes of
// cryptographic randomness, used as the OAuth state parameter.
//
// It panics if the system CSPRNG fails. Returning a predictable (all-zero)
// state would silently defeat the CSRF protection the state provides.
func generateState() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic("generateState: crypto/rand failed: " + err.Error())
	}
	return hex.EncodeToString(b)
}
|
|
|
|
|
|
func saveState(state string, data map[string]any) error {
|
|
|
key := "oauth:state:" + state
|
|
|
return cache.SetObject(key, data, 10*time.Minute)
|
|
|
}
|
|
|
|
|
|
func verifyState(state string) (map[string]any, error) {
|
|
|
key := "oauth:state:" + state
|
|
|
var data map[string]any
|
|
|
if err := cache.GetObject(key, &data); err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
// 验证后删除
|
|
|
cache.Delete(key)
|
|
|
return data, nil
|
|
|
}
|
|
|
|
|
|
// buildAuthURL returns the configured authorization endpoint of provider with
// the OAuth 2.0 authorization-code query parameters appended.
//
// Behaviour:
//   - A query string or fragment already present in the configured AuthURL
//     is preserved.
//   - For WeChat, "appid" is sent alongside "client_id", and the
//     "#wechat_redirect" fragment that WeChat QR-code login expects is added
//     when none is configured.
//
// Errors: ErrInvalidArg for an unknown or disabled provider.
func buildAuthURL(provider, state string) (string, error) {
	var pc cfg.OAuthProviderConfig
	switch provider {
	case "google":
		pc = cfg.Config.Providers.Google
	case "github":
		pc = cfg.Config.Providers.GitHub
	case "wechat":
		pc = cfg.Config.Providers.WeChat
	default:
		return "", vigo.ErrInvalidArg.WithString("unsupported provider: " + provider)
	}

	if !pc.Enabled {
		return "", vigo.ErrInvalidArg.WithString("provider not enabled: " + provider)
	}

	params := url.Values{}
	params.Set("client_id", pc.ClientID)
	// NOTE(review): getRedirectURI returns a relative path, but providers
	// generally require an absolute redirect URI that matches the registered
	// one. Confirm how the host part is supplied.
	params.Set("redirect_uri", getRedirectURI(provider))
	params.Set("response_type", "code")
	params.Set("state", state)
	params.Set("scope", strings.Join(pc.Scopes, " "))

	// WeChat identifies the application by "appid" rather than "client_id".
	if provider == "wechat" {
		params.Set("appid", pc.ClientID)
	}

	// Split off any configured fragment so the query goes before it, and
	// append with "&" if the base URL already carries a query string.
	base, fragment, _ := strings.Cut(pc.AuthURL, "#")
	if provider == "wechat" && fragment == "" {
		fragment = "wechat_redirect"
	}
	sep := "?"
	if strings.Contains(base, "?") {
		sep = "&"
	}

	authURL := base + sep + params.Encode()
	if fragment != "" {
		authURL += "#" + fragment
	}
	return authURL, nil
}
|
|
|
|
|
|
// getRedirectURI returns the callback path for provider, e.g.
// "/auth/callback/github".
//
// NOTE(review): the value is relative. OAuth providers usually require an
// absolute redirect URI identical to the registered one — confirm how the
// scheme and host are added.
func getRedirectURI(provider string) string {
	const callbackPrefix = "/auth/callback/"
	return callbackPrefix + provider
}
|
|
|
|
|
|
// exchangeAndGetUserInfo dispatches the authorization code to the exchange
// function of provider and returns the resulting profile. Unknown providers
// yield an "unsupported provider" error.
func exchangeAndGetUserInfo(provider, code string) (*ThirdPartyUserInfo, error) {
	exchangers := map[string]func(string) (*ThirdPartyUserInfo, error){
		"google": exchangeGoogle,
		"github": exchangeGitHub,
		"wechat": exchangeWeChat,
	}

	exchange, ok := exchangers[provider]
	if !ok {
		return nil, fmt.Errorf("unsupported provider: %s", provider)
	}
	return exchange(code)
}
|
|
|
|
|
|
// exchangeGoogle trades a Google authorization code for an access token and
// fetches the user's profile from the configured userinfo endpoint.
//
// Safeguards:
//   - All requests use a client with a timeout, so a slow provider cannot
//     hang the handler.
//   - Non-200 profile responses are errors.
//   - A profile with no subject identifier is rejected. The ID is taken from
//     "sub" (OpenID Connect / v3 userinfo), falling back to "id" (v2 userinfo).
func exchangeGoogle(code string) (*ThirdPartyUserInfo, error) {
	pc := cfg.Config.Providers.Google
	client := &http.Client{Timeout: 10 * time.Second}

	// Exchange the authorization code for an access token.
	form := url.Values{}
	form.Set("code", code)
	form.Set("client_id", pc.ClientID)
	form.Set("client_secret", pc.ClientSecret)
	form.Set("redirect_uri", getRedirectURI("google"))
	form.Set("grant_type", "authorization_code")

	// PostForm sends Content-Type application/x-www-form-urlencoded.
	resp, err := client.PostForm(pc.TokenURL, form)
	if err != nil {
		return nil, fmt.Errorf("google token request: %w", err)
	}
	defer resp.Body.Close()

	var tokenResp struct {
		AccessToken string `json:"access_token"`
		Error       string `json:"error"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
		return nil, fmt.Errorf("google token response (status %d): %w", resp.StatusCode, err)
	}
	if tokenResp.Error != "" {
		return nil, fmt.Errorf("token error: %s", tokenResp.Error)
	}
	if tokenResp.AccessToken == "" {
		return nil, fmt.Errorf("google token response (status %d) has no access_token", resp.StatusCode)
	}

	// Fetch the user profile.
	req, err := http.NewRequest(http.MethodGet, pc.UserInfoURL, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)

	userResp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("google userinfo request: %w", err)
	}
	defer userResp.Body.Close()

	if userResp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("google userinfo request: unexpected status %d", userResp.StatusCode)
	}

	var userData struct {
		Sub     string `json:"sub"`
		ID      string `json:"id"`
		Name    string `json:"name"`
		Email   string `json:"email"`
		Picture string `json:"picture"`
	}
	if err := json.NewDecoder(userResp.Body).Decode(&userData); err != nil {
		return nil, err
	}

	id := userData.Sub
	if id == "" {
		id = userData.ID
	}
	if id == "" {
		return nil, fmt.Errorf("google userinfo response has no user id")
	}

	return &ThirdPartyUserInfo{
		Provider: "google",
		ID:       id,
		Name:     userData.Name,
		Email:    userData.Email,
		Avatar:   userData.Picture,
	}, nil
}
|
|
|
|
|
|
// exchangeGitHub trades a GitHub authorization code for an access token and
// fetches the user's profile from the configured user endpoint.
//
// Request details:
//   - The token request is a form-encoded POST with Accept: application/json,
//     so GitHub answers in JSON.
//   - All requests use a client with a timeout.
//   - Non-200 profile responses and a missing user ID are errors.
//   - The display name falls back to the login when the user has not set one.
func exchangeGitHub(code string) (*ThirdPartyUserInfo, error) {
	pc := cfg.Config.Providers.GitHub
	client := &http.Client{Timeout: 10 * time.Second}

	// Exchange the authorization code for an access token.
	form := url.Values{}
	form.Set("code", code)
	form.Set("client_id", pc.ClientID)
	form.Set("client_secret", pc.ClientSecret)
	form.Set("redirect_uri", getRedirectURI("github"))

	req, err := http.NewRequest(http.MethodPost, pc.TokenURL, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, err
	}
	// The body is form-encoded; declare it so the server parses it as such.
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("github token request: %w", err)
	}
	defer resp.Body.Close()

	var tokenResp struct {
		AccessToken      string `json:"access_token"`
		Error            string `json:"error"`
		ErrorDescription string `json:"error_description"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
		return nil, fmt.Errorf("github token response (status %d): %w", resp.StatusCode, err)
	}
	if tokenResp.Error != "" {
		msg := tokenResp.Error
		if tokenResp.ErrorDescription != "" {
			msg += " (" + tokenResp.ErrorDescription + ")"
		}
		return nil, fmt.Errorf("token error: %s", msg)
	}
	if tokenResp.AccessToken == "" {
		return nil, fmt.Errorf("github token response (status %d) has no access_token", resp.StatusCode)
	}

	// Fetch the user profile.
	userReq, err := http.NewRequest(http.MethodGet, pc.UserInfoURL, nil)
	if err != nil {
		return nil, err
	}
	userReq.Header.Set("Authorization", "token "+tokenResp.AccessToken)
	userReq.Header.Set("Accept", "application/vnd.github+json")

	userResp, err := client.Do(userReq)
	if err != nil {
		return nil, fmt.Errorf("github user request: %w", err)
	}
	defer userResp.Body.Close()

	if userResp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("github user request: unexpected status %d", userResp.StatusCode)
	}

	var userData struct {
		ID        int64  `json:"id"`
		Login     string `json:"login"`
		Name      string `json:"name"`
		Email     string `json:"email"`
		AvatarURL string `json:"avatar_url"`
	}
	if err := json.NewDecoder(userResp.Body).Decode(&userData); err != nil {
		return nil, err
	}
	if userData.ID == 0 {
		return nil, fmt.Errorf("github user response has no id")
	}

	name := userData.Name
	if name == "" {
		name = userData.Login
	}

	return &ThirdPartyUserInfo{
		Provider: "github",
		ID:       fmt.Sprintf("%d", userData.ID),
		Name:     name,
		Email:    userData.Email,
		Avatar:   userData.AvatarURL,
	}, nil
}
|
|
|
|
|
|
// exchangeWeChat trades a WeChat authorization code for an access token and
// openid, then fetches the user's profile.
//
// Request details:
//   - Query parameters are URL-encoded, so a crafted code cannot inject extra
//     parameters.
//   - Requests use a client with a timeout.
//   - A non-zero errcode in either response is an error.
//
// The returned ID is the unionid when WeChat supplies one and otherwise the
// openid, so it is never empty.
func exchangeWeChat(code string) (*ThirdPartyUserInfo, error) {
	pc := cfg.Config.Providers.WeChat
	client := &http.Client{Timeout: 10 * time.Second}

	// Exchange the authorization code for an access token.
	tokenQuery := url.Values{}
	tokenQuery.Set("appid", pc.ClientID)
	tokenQuery.Set("secret", pc.ClientSecret)
	tokenQuery.Set("code", code)
	tokenQuery.Set("grant_type", "authorization_code")

	resp, err := client.Get(pc.TokenURL + "?" + tokenQuery.Encode())
	if err != nil {
		return nil, fmt.Errorf("wechat token request: %w", err)
	}
	defer resp.Body.Close()

	var tokenResp struct {
		AccessToken string `json:"access_token"`
		OpenID      string `json:"openid"`
		UnionID     string `json:"unionid"`
		ErrCode     int    `json:"errcode"`
		ErrMsg      string `json:"errmsg"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
		return nil, err
	}
	if tokenResp.ErrCode != 0 {
		return nil, fmt.Errorf("wechat error: %s", tokenResp.ErrMsg)
	}
	if tokenResp.AccessToken == "" || tokenResp.OpenID == "" {
		return nil, fmt.Errorf("wechat token response is missing access_token or openid")
	}

	// Fetch the user profile.
	userQuery := url.Values{}
	userQuery.Set("access_token", tokenResp.AccessToken)
	userQuery.Set("openid", tokenResp.OpenID)

	userResp, err := client.Get(pc.UserInfoURL + "?" + userQuery.Encode())
	if err != nil {
		return nil, fmt.Errorf("wechat userinfo request: %w", err)
	}
	defer userResp.Body.Close()

	var userData struct {
		OpenID   string `json:"openid"`
		Nickname string `json:"nickname"`
		HeadImg  string `json:"headimgurl"`
		UnionID  string `json:"unionid"`
		ErrCode  int    `json:"errcode"`
		ErrMsg   string `json:"errmsg"`
	}
	if err := json.NewDecoder(userResp.Body).Decode(&userData); err != nil {
		return nil, err
	}
	if userData.ErrCode != 0 {
		return nil, fmt.Errorf("wechat error: %s", userData.ErrMsg)
	}

	// Prefer the unionid, which is stable across apps. WeChat only returns it
	// when the app is attached to an Open Platform account. Without it, fall
	// back to the openid instead of an empty ID that every user would share.
	id := tokenResp.UnionID
	if id == "" {
		id = userData.UnionID
	}
	if id == "" {
		id = tokenResp.OpenID
	}

	return &ThirdPartyUserInfo{
		Provider: "wechat",
		ID:       id,
		Name:     userData.Nickname,
		Avatar:   userData.HeadImg,
	}, nil
}
|
|
|
|
|
|
func findIdentity(provider, providerUID string) (*models.Identity, error) {
|
|
|
var identity models.Identity
|
|
|
if err := cfg.DB().Where("provider = ? AND provider_uid = ?", provider, providerUID).First(&identity).Error; err != nil {
|
|
|
if err == gorm.ErrRecordNotFound {
|
|
|
return nil, nil
|
|
|
}
|
|
|
return nil, err
|
|
|
}
|
|
|
return &identity, nil
|
|
|
}
|
|
|
|
|
|
func bindIdentity(userID, provider string, userInfo *ThirdPartyUserInfo) error {
|
|
|
// 检查是否已绑定
|
|
|
var existing models.Identity
|
|
|
if err := cfg.DB().Where("provider = ? AND provider_uid = ?", provider, userInfo.ID).First(&existing).Error; err == nil {
|
|
|
// 已绑定到其他账号
|
|
|
if existing.UserID != userID {
|
|
|
return vigo.ErrForbidden.WithString("this account is already bound to another user")
|
|
|
}
|
|
|
// 已绑定到当前账号,更新信息
|
|
|
existing.ProviderName = userInfo.Name
|
|
|
existing.Avatar = userInfo.Avatar
|
|
|
existing.Email = userInfo.Email
|
|
|
return cfg.DB().Save(&existing).Error
|
|
|
}
|
|
|
|
|
|
// 创建新绑定
|
|
|
identity := &models.Identity{
|
|
|
UserID: userID,
|
|
|
Provider: provider,
|
|
|
ProviderUID: userInfo.ID,
|
|
|
ProviderName: userInfo.Name,
|
|
|
Avatar: userInfo.Avatar,
|
|
|
Email: userInfo.Email,
|
|
|
}
|
|
|
return cfg.DB().Create(identity).Error
|
|
|
}
|
|
|
|
|
|
func loginByIdentity(x *vigo.X, identity *models.Identity) (*CallbackResponse, error) {
|
|
|
// 查找用户
|
|
|
var user models.User
|
|
|
if err := cfg.DB().First(&user, "id = ?", identity.UserID).Error; err != nil {
|
|
|
return nil, vigo.ErrNotFound.WithString("user not found")
|
|
|
}
|
|
|
|
|
|
if user.Status != models.UserStatusActive {
|
|
|
return nil, vigo.ErrForbidden.WithString("user is disabled")
|
|
|
}
|
|
|
|
|
|
// 生成token
|
|
|
authResp, err := generateAuthResponse(x, &user)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
return &CallbackResponse{
|
|
|
AuthResponse: authResp,
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
func generateAuthResponse(x *vigo.X, user *models.User) (*AuthResponse, error) {
|
|
|
// 获取用户的组织信息
|
|
|
orgs, err := getUserOrgs(user.ID)
|
|
|
if err != nil {
|
|
|
return nil, vigo.ErrInternalServer.WithError(err)
|
|
|
}
|
|
|
|
|
|
orgClaims := make([]jwt.OrgClaim, 0, len(orgs))
|
|
|
for _, org := range orgs {
|
|
|
orgClaims = append(orgClaims, jwt.OrgClaim{
|
|
|
OrgID: org.OrgID,
|
|
|
Code: org.Code,
|
|
|
Name: org.Name,
|
|
|
Roles: org.Roles,
|
|
|
Status: org.Status,
|
|
|
})
|
|
|
}
|
|
|
|
|
|
tokenPair, err := jwt.GenerateTokenPair(
|
|
|
user.ID,
|
|
|
user.Username,
|
|
|
user.Nickname,
|
|
|
user.Avatar,
|
|
|
func() string {
|
|
|
if user.Email != nil {
|
|
|
return *user.Email
|
|
|
}
|
|
|
return ""
|
|
|
}(),
|
|
|
orgClaims,
|
|
|
)
|
|
|
if err != nil {
|
|
|
return nil, vigo.ErrInternalServer.WithError(err)
|
|
|
}
|
|
|
|
|
|
return &AuthResponse{
|
|
|
AccessToken: tokenPair.AccessToken,
|
|
|
RefreshToken: tokenPair.RefreshToken,
|
|
|
TokenType: tokenPair.TokenType,
|
|
|
ExpiresIn: tokenPair.ExpiresIn,
|
|
|
User: &UserInfo{
|
|
|
ID: user.ID,
|
|
|
Username: user.Username,
|
|
|
Nickname: user.Nickname,
|
|
|
Email: user.Email,
|
|
|
Avatar: user.Avatar,
|
|
|
},
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
func generateTempBindToken(provider string, userInfo *ThirdPartyUserInfo) (string, error) {
|
|
|
// 生成临时token用于后续绑定
|
|
|
token := generateRandomToken(32)
|
|
|
key := "oauth:bind:" + token
|
|
|
data := map[string]any{
|
|
|
"provider": provider,
|
|
|
"provider_id": userInfo.ID,
|
|
|
"name": userInfo.Name,
|
|
|
"email": userInfo.Email,
|
|
|
"avatar": userInfo.Avatar,
|
|
|
"raw": userInfo.Raw,
|
|
|
}
|
|
|
if err := cache.SetObject(key, data, 10*time.Minute); err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
return token, nil
|
|
|
}
|
|
|
|
|
|
// verifyTempBindToken redeems a token issued by generateTempBindToken and
// returns the stored third-party profile. The cache entry is deleted on first
// read, so a token can be redeemed only once.
//
// Fields are read with checked type assertions. Raw is never populated by the
// exchange functions, so after a cache round-trip it may come back as an
// untyped nil (presumably, if the cache serialises through JSON — confirm).
// The previous unchecked assertion would then panic. A payload without a
// provider or provider ID is rejected as malformed.
func verifyTempBindToken(token string) (*ThirdPartyUserInfo, error) {
	key := "oauth:bind:" + token

	var data map[string]any
	if err := cache.GetObject(key, &data); err != nil {
		return nil, err
	}
	// Single use: remove immediately after reading.
	cache.Delete(key)

	str := func(k string) string {
		s, _ := data[k].(string)
		return s
	}
	raw, _ := data["raw"].(map[string]any)

	info := &ThirdPartyUserInfo{
		Provider: str("provider"),
		ID:       str("provider_id"),
		Name:     str("name"),
		Email:    str("email"),
		Avatar:   str("avatar"),
		Raw:      raw,
	}
	if info.Provider == "" || info.ID == "" {
		return nil, fmt.Errorf("malformed bind token payload")
	}
	return info, nil
}
|
|
|
|
|
|
// generateRandomPassword returns a random password of length characters drawn
// uniformly from letters, digits and "!@#$%^&*". It returns "" when length is
// not positive.
//
// Rejection sampling removes the modulo bias of mapping a byte (256 values)
// onto the 70-character charset. The function panics if the system CSPRNG
// fails rather than return a predictable password.
func generateRandomPassword(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*"
	if length <= 0 {
		return ""
	}

	// Largest multiple of len(charset) not exceeding 256. Bytes at or above
	// it are discarded, so every charset index is equally likely.
	limit := 256 - 256%len(charset)

	out := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(out) < length {
		if _, err := rand.Read(buf); err != nil {
			panic("generateRandomPassword: crypto/rand failed: " + err.Error())
		}
		for _, c := range buf {
			if int(c) >= limit {
				continue
			}
			out = append(out, charset[int(c)%len(charset)])
			if len(out) == length {
				break
			}
		}
	}
	return string(out)
}
|
|
|
|
|
|
// generateRandomToken returns the hex encoding of length cryptographically
// random bytes, i.e. a string of 2*length characters.
//
// It panics if the system CSPRNG fails. The caller uses the result as a
// bearer credential, so an all-zero fallback would be guessable.
func generateRandomToken(length int) string {
	b := make([]byte, length)
	if _, err := rand.Read(b); err != nil {
		panic("generateRandomToken: crypto/rand failed: " + err.Error())
	}
	return hex.EncodeToString(b)
}
|