// Source path: OneAuth/api/auth/thirdparty.go

//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-03-04 16:08:06
// Distributed under terms of the MIT license.
//
package auth
import (
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	baseauth "github.com/veypi/vbase/auth"
	"github.com/veypi/vbase/cfg"
	"github.com/veypi/vbase/libs/cache"
	"github.com/veypi/vbase/libs/crypto"
	"github.com/veypi/vbase/libs/jwt"
	"github.com/veypi/vbase/models"
	"github.com/veypi/vigo"
	"gorm.io/gorm"
)
// ProviderInfo describes one third-party login provider as returned to
// clients by listProviders.
type ProviderInfo struct {
	// Name is the provider's machine-readable code.
	Name string `json:"name"`
	// DisplayName is the human-readable provider name.
	DisplayName string `json:"display_name"`
	// Icon is the provider's icon value (format not visible in this file).
	Icon string `json:"icon"`
	// Enabled reports whether the provider is enabled.
	Enabled bool `json:"enabled"`
}
// listProviders 获取支持的第三方登录提供商列表
func listProviders(x *vigo.X) ([]ProviderInfo, error) {
var dbProviders []models.OAuthProvider
if err := cfg.DB().Where("enabled = ?", true).Order("sort_order").Find(&dbProviders).Error; err != nil {
return nil, err
}
providers := make([]ProviderInfo, 0, len(dbProviders))
for _, p := range dbProviders {
providers = append(providers, ProviderInfo{
Name: p.Code,
DisplayName: p.Name,
Icon: p.Icon,
Enabled: p.Enabled,
})
1 week ago
}
return providers, nil
}
// AuthorizeRequest is the input for starting a third-party login. All fields
// are tagged for binding from the query string.
type AuthorizeRequest struct {
	// Provider is the provider code; it is lower-cased before use.
	Provider string `json:"provider" src:"query" desc:"提供商: google/github/wechat"`
	// Redirect is the address to return to after a successful login.
	Redirect string `json:"redirect" src:"query" desc:"登录成功后重定向地址"`
	// BindMode selects binding to the current user instead of logging in.
	BindMode bool `json:"bind_mode" src:"query" desc:"是否为绑定模式"`
}
// AuthorizeResponse carries the provider authorization URL together with the
// state value that was embedded in it.
type AuthorizeResponse struct {
	// AuthURL is the provider URL the client should be sent to.
	AuthURL string `json:"auth_url"`
	// State is the random value stored server-side for callback validation.
	State string `json:"state"`
}
// authorizeThirdParty 获取第三方登录授权URL
func authorizeThirdParty(x *vigo.X, req *AuthorizeRequest) (*AuthorizeResponse, error) {
1 week ago
provider := strings.ToLower(req.Provider)
// 生成state
state := generateState()
// 存储state到缓存用于验证回调
stateData := map[string]any{
"provider": provider,
"redirect": req.Redirect,
"bind_mode": req.BindMode,
1 week ago
"created_at": time.Now().Unix(),
}
// 如果是绑定模式,需要当前用户登录
if req.BindMode {
userID := baseauth.VBaseAuth.UserID(x)
1 week ago
if userID == "" {
return nil, vigo.ErrUnauthorized.WithString("login required for bind mode")
1 week ago
}
stateData["user_id"] = userID
}
// 保存state数据到缓存10分钟有效
if err := saveState(state, stateData); err != nil {
return nil, vigo.ErrInternalServer.WithError(err)
}
// 构建授权URL
authURL, err := buildAuthURL(provider, state)
if err != nil {
return nil, err
}
return &AuthorizeResponse{
AuthURL: authURL,
State: state,
}, nil
}
// CallbackRequest holds the parameters of the provider's redirect back to
// this service.
type CallbackRequest struct {
	// Provider is the provider code, tagged for binding from the URL path.
	Provider string `json:"provider" src:"path" desc:"提供商"`
	// Code is the authorization code issued by the provider.
	Code string `json:"code" src:"query" desc:"授权码"`
	// State is the value previously handed out by authorizeThirdParty.
	State string `json:"state" src:"query" desc:"状态值"`
	// Error is the error reported by the provider, if any.
	Error string `json:"error" src:"query" desc:"错误信息"`
}
// CallbackResponse 回调响应
type CallbackResponse struct {
// 绑定模式
BindMode bool `json:"bind_mode,omitempty"`
Bound bool `json:"bound,omitempty"`
// 登录模式
NeedBind bool `json:"need_bind,omitempty"`
Provider string `json:"provider,omitempty"`
ProviderID string `json:"provider_id,omitempty"`
AuthURL string `json:"auth_url,omitempty"`
1 week ago
// 登录成功
*AuthResponse
}
// callbackThirdParty 处理第三方登录回调
func callbackThirdParty(x *vigo.X, req *CallbackRequest) (*CallbackResponse, error) {
1 week ago
if req.Error != "" {
return nil, vigo.ErrInvalidArg.WithString("oauth error: " + req.Error)
1 week ago
}
if req.Code == "" || req.State == "" {
return nil, vigo.ErrInvalidArg.WithString("missing code or state")
1 week ago
}
// 验证state
stateData, err := verifyState(req.State)
if err != nil {
return nil, vigo.ErrInvalidArg.WithString("invalid or expired state")
1 week ago
}
provider := stateData["provider"].(string)
bindMode := stateData["bind_mode"].(bool)
// 交换access_token并获取用户信息
userInfo, err := exchangeAndGetUserInfo(provider, req.Code)
if err != nil {
return nil, vigo.ErrInternalServer.WithError(err)
}
// 绑定模式:将第三方身份绑定到当前用户
if bindMode {
userID := stateData["user_id"].(string)
if err := bindIdentity(userID, provider, userInfo); err != nil {
return nil, err
}
return &CallbackResponse{
BindMode: true,
Bound: true,
}, nil
}
// 登录模式:查找是否已绑定
identity, err := findIdentity(provider, userInfo.ID)
if err != nil {
return nil, err
}
// 如果已绑定,直接登录
if identity != nil {
return loginByIdentity(x, identity)
}
// 未绑定,返回需要绑定的信息
// 生成临时token用于后续绑定
tempToken, err := generateTempBindToken(provider, userInfo)
if err != nil {
return nil, vigo.ErrInternalServer.WithError(err)
}
return &CallbackResponse{
NeedBind: true,
Provider: provider,
ProviderID: userInfo.ID,
AuthURL: "/auth/bind?token=" + tempToken,
}, nil
}
// BindRequest is the JSON body for binding a third-party identity to an
// existing account.
type BindRequest struct {
	// TempToken is the temporary bind token issued by the callback.
	TempToken string `json:"temp_token" src:"json" desc:"临时绑定令牌"`
	// Username is matched against username, email and phone.
	Username string `json:"username" src:"json" desc:"用户名/邮箱/手机号"`
	// Password is the account password.
	Password string `json:"password" src:"json" desc:"密码"`
}
// bindThirdParty 绑定第三方账号到已有账号
func bindThirdParty(x *vigo.X, req *BindRequest) (*AuthResponse, error) {
1 week ago
// 验证临时token
userInfo, err := verifyTempBindToken(req.TempToken)
if err != nil {
return nil, vigo.ErrInvalidArg.WithString("invalid or expired token")
1 week ago
}
// 查找用户
var user models.User
query := cfg.DB().Where("username = ? OR email = ? OR phone = ?", req.Username, req.Username, req.Username)
1 week ago
if err := query.First(&user).Error; err != nil {
return nil, vigo.ErrUnauthorized.WithString("invalid credentials")
1 week ago
}
// 验证密码
if !crypto.VerifyPassword(req.Password, user.Password) {
return nil, vigo.ErrUnauthorized.WithString("invalid credentials")
1 week ago
}
// 检查用户状态
if user.Status != models.UserStatusActive {
1 week ago
return nil, vigo.ErrForbidden.WithString("user is disabled")
}
// 绑定身份
if err := bindIdentity(user.ID, userInfo.Provider, userInfo); err != nil {
return nil, err
}
// 生成登录token
return generateAuthResponse(x, &user)
}
// BindWithRegisterRequest is the JSON body for creating a new account and
// binding a third-party identity to it (optional feature).
type BindWithRegisterRequest struct {
	// TempToken is the temporary bind token issued by the callback.
	TempToken string `json:"temp_token" src:"json" desc:"临时绑定令牌"`
	// Username is the username of the account to create.
	Username string `json:"username" src:"json" desc:"用户名"`
	// Email is optional.
	Email string `json:"email" src:"json" desc:"邮箱"`
	// Phone is optional.
	Phone string `json:"phone" src:"json" desc:"手机号"`
}
// bindWithRegister 绑定并创建新账号
func bindWithRegister(x *vigo.X, req *BindWithRegisterRequest) (*AuthResponse, error) {
1 week ago
// 验证临时token
userInfo, err := verifyTempBindToken(req.TempToken)
if err != nil {
return nil, vigo.ErrInvalidArg.WithString("invalid or expired token")
1 week ago
}
// 检查用户名是否已存在
var count int64
cfg.DB().Model(&models.User{}).Where("username = ?", req.Username).Count(&count)
1 week ago
if count > 0 {
return nil, vigo.ErrInvalidArg.WithString("username already exists")
1 week ago
}
// 检查邮箱是否已存在
if req.Email != "" {
cfg.DB().Model(&models.User{}).Where("email = ?", req.Email).Count(&count)
1 week ago
if count > 0 {
return nil, vigo.ErrInvalidArg.WithString("email already exists")
1 week ago
}
}
// 创建用户(随机密码,需要后续设置)
randomPassword := generateRandomPassword(16)
hashedPassword, _ := crypto.HashPassword(randomPassword, 12)
1 week ago
1 week ago
var email *string
if req.Email != "" {
email = &req.Email
}
var phone *string
if req.Phone != "" {
phone = &req.Phone
}
user := &models.User{
1 week ago
Username: req.Username,
Password: hashedPassword,
1 week ago
Email: email,
Phone: phone,
1 week ago
Nickname: userInfo.Name,
Avatar: userInfo.Avatar,
Status: models.UserStatusActive,
1 week ago
}
if user.Nickname == "" {
user.Nickname = req.Username
}
if err := cfg.DB().Create(user).Error; err != nil {
1 week ago
return nil, vigo.ErrInternalServer.WithError(err)
}
1 week ago
// 授予默认角色 "user"
if err := baseauth.VBaseAuth.GrantRole(x.Context(), user.ID, "", "user"); err != nil {
// 记录错误但允许流程继续
}
1 week ago
// 绑定第三方身份
if err := bindIdentity(user.ID, userInfo.Provider, userInfo); err != nil {
return nil, err
}
// 生成登录token
return generateAuthResponse(x, user)
}
// UnbindRequest identifies the provider binding to remove.
type UnbindRequest struct {
	// Provider is the provider code, tagged for binding from the URL path.
	Provider string `json:"provider" src:"path" desc:"提供商"`
}
// unbindThirdParty 解除第三方账号绑定
func unbindThirdParty(x *vigo.X, req *UnbindRequest) error {
userID := baseauth.VBaseAuth.UserID(x)
1 week ago
if userID == "" {
return vigo.ErrUnauthorized
1 week ago
}
// 删除绑定关系
if err := cfg.DB().Where("user_id = ? AND provider = ?", userID, req.Provider).Delete(&models.Identity{}).Error; err != nil {
1 week ago
return vigo.ErrInternalServer.WithError(err)
}
return nil
}
// BindingInfo describes one third-party identity bound to a user, as
// returned by listBindings.
type BindingInfo struct {
	// Provider is the provider code of the binding.
	Provider string `json:"provider"`
	// ProviderName is the name stored with the identity.
	ProviderName string `json:"provider_name"`
	// Avatar is the avatar stored with the identity.
	Avatar string `json:"avatar"`
	// Email is the email stored with the identity.
	Email string `json:"email"`
	// CreatedAt is the creation time formatted as "2006-01-02 15:04:05".
	CreatedAt string `json:"created_at"`
}
// listBindings 获取当前用户的第三方绑定列表
func listBindings(x *vigo.X) ([]BindingInfo, error) {
userID := baseauth.VBaseAuth.UserID(x)
1 week ago
if userID == "" {
return nil, vigo.ErrUnauthorized
1 week ago
}
var identities []models.Identity
if err := cfg.DB().Where("user_id = ?", userID).Find(&identities).Error; err != nil {
1 week ago
return nil, vigo.ErrInternalServer.WithError(err)
}
result := make([]BindingInfo, 0, len(identities))
for _, id := range identities {
result = append(result, BindingInfo{
Provider: id.Provider,
ProviderName: id.ProviderName,
Avatar: id.Avatar,
Email: id.Email,
CreatedAt: id.CreatedAt.Format("2006-01-02 15:04:05"),
})
}
return result, nil
}
// ==================== Helpers ====================

// ThirdPartyUserInfo is the normalized user profile obtained from a
// third-party provider.
type ThirdPartyUserInfo struct {
	// Provider is the provider code the profile came from.
	Provider string
	// ID is the user's identifier at the provider.
	ID string
	// Name is the user's name at the provider.
	Name string
	// Email is the user's email at the provider.
	Email string
	// Avatar is the user's avatar at the provider.
	Avatar string
	// Raw is the undecoded user-info response.
	Raw map[string]any
}
// generateState returns a random OAuth state value: 16 bytes from
// crypto/rand, hex-encoded to 32 characters.
//
// It panics if the system random source fails. Continuing would hand out an
// all-zero, predictable state and defeat the CSRF protection it provides.
func generateState() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic("oauth: reading random bytes: " + err.Error())
	}
	return hex.EncodeToString(b)
}
// saveState stores the state payload in the cache under "oauth:state:<state>"
// with a 10 minute expiry.
func saveState(state string, data map[string]any) error {
	const ttl = 10 * time.Minute
	return cache.SetObject("oauth:state:"+state, data, ttl)
}
// verifyState loads the payload stored for state and then deletes the cache
// entry, so a state is accepted once. A failed cache lookup is returned as
// the error.
func verifyState(state string) (map[string]any, error) {
	cacheKey := "oauth:state:" + state

	var payload map[string]any
	err := cache.GetObject(cacheKey, &payload)
	if err != nil {
		return nil, err
	}

	// Single use: drop the entry once it has been read.
	cache.Delete(cacheKey)
	return payload, nil
}
func buildAuthURL(providerCode, state string) (string, error) {
// 从数据库获取提供商配置
var provider models.OAuthProvider
if err := cfg.DB().Where("code = ? AND enabled = ?", providerCode, true).First(&provider).Error; err != nil {
return "", vigo.ErrInvalidArg.WithString("provider not found or not enabled: " + providerCode)
1 week ago
}
params := url.Values{}
// 特殊处理微信的 appid 参数
if appidParam, ok := provider.ExtraConfig["appid_param"]; ok && appidParam != "" {
params.Set(appidParam, provider.ClientID)
} else {
params.Set("client_id", provider.ClientID)
1 week ago
}
params.Set("redirect_uri", getRedirectURI(providerCode))
1 week ago
params.Set("response_type", "code")
params.Set("state", state)
if len(provider.Scopes) > 0 {
params.Set("scope", strings.Join(provider.Scopes, " "))
1 week ago
}
return provider.AuthURL + "?" + params.Encode(), nil
1 week ago
}
// getRedirectURI returns the default callback path for a provider.
//
// NOTE(review): the result is a relative path; OAuth providers generally
// expect an absolute redirect URI, so confirm that a base URL is applied
// elsewhere or that RedirectURI is configured per provider.
func getRedirectURI(provider string) string {
	return "/auth/callback/" + provider
}
func exchangeAndGetUserInfo(providerCode, code string) (*ThirdPartyUserInfo, error) {
// 从数据库获取提供商配置
var provider models.OAuthProvider
if err := cfg.DB().Where("code = ? AND enabled = ?", providerCode, true).First(&provider).Error; err != nil {
return nil, fmt.Errorf("provider not found: %s", providerCode)
1 week ago
}
// 统一处理 OAuth 流程
return exchangeGeneric(provider, code)
1 week ago
}
func exchangeGeneric(provider models.OAuthProvider, code string) (*ThirdPartyUserInfo, error) {
// 1. 交换 code 获取 access_token
tokenURL := provider.TokenURL
redirectURI := provider.RedirectURI
if redirectURI == "" {
redirectURI = getRedirectURI(provider.Code)
}
1 week ago
// 解密 ClientSecret
clientSecret := provider.ClientSecret
if clientSecret != "" {
decrypted, err := cfg.Global.Key.Decrypt(clientSecret)
if err == nil {
clientSecret = decrypted
}
// 如果解密失败,可能是明文存储的旧数据,继续使用原值
}
// 构建请求参数
1 week ago
data := url.Values{}
data.Set("code", code)
data.Set("client_id", provider.ClientID)
data.Set("client_secret", clientSecret)
data.Set("redirect_uri", redirectURI)
1 week ago
data.Set("grant_type", "authorization_code")
// 发送 token 请求
1 week ago
var tokenResp struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
1 week ago
Error string `json:"error"`
ErrorDesc string `json:"error_description"`
1 week ago
}
usePost := provider.ExtraConfig["use_post_token"] == "true"
var err error
if usePost {
err = postFormJSON(tokenURL, data, &tokenResp)
} else {
err = getJSON(tokenURL+"?"+data.Encode(), &tokenResp)
}
1 week ago
if err != nil {
return nil, fmt.Errorf("exchange token failed: %w", err)
1 week ago
}
if tokenResp.Error != "" {
return nil, fmt.Errorf("oauth error: %s - %s", tokenResp.Error, tokenResp.ErrorDesc)
1 week ago
}
// 2. 获取用户信息
userInfoURL := provider.UserInfoURL
var rawUser map[string]any
if err := getJSONWithAuth(userInfoURL, tokenResp.AccessToken, tokenResp.TokenType, &rawUser); err != nil {
return nil, fmt.Errorf("get user info failed: %w", err)
1 week ago
}
// 3. 解析字段映射
userInfo := &ThirdPartyUserInfo{
Provider: provider.Code,
Raw: rawUser,
}
1 week ago
// 提取字段
userInfo.ID = extractField(rawUser, provider.UserIDPath, "id")
userInfo.Name = extractField(rawUser, provider.UserNamePath, "name")
userInfo.Email = extractField(rawUser, provider.UserEmailPath, "email")
userInfo.Avatar = extractField(rawUser, provider.UserAvatarPath, "avatar")
1 week ago
if userInfo.ID == "" {
return nil, fmt.Errorf("failed to extract user id from response")
}
1 week ago
return userInfo, nil
}
// extractField reads a value from a decoded JSON object and renders it as a
// string.
//
// path is a dot-separated key path such as "data.employee_id"; when it is
// empty, defaultKey is used instead. The result is "" when both are empty,
// when any step of the path is missing or is not an object, or when the
// final value is JSON null. Strings are returned unchanged, float64 values
// are printed without decimals, and any other value uses its default
// formatting.
func extractField(data map[string]any, path, defaultKey string) string {
	if path == "" {
		path = defaultKey
	}
	if path == "" {
		return ""
	}

	keys := strings.Split(path, ".")
	last := len(keys) - 1

	// Walk down through the intermediate objects.
	current := data
	for _, key := range keys[:last] {
		next, ok := current[key].(map[string]any)
		if !ok {
			return ""
		}
		current = next
	}

	switch v := current[keys[last]].(type) {
	case nil:
		// Missing key or explicit null: report "no value" rather than the
		// literal "<nil>", which callers would take for a real value.
		return ""
	case string:
		return v
	case float64:
		// JSON numbers decode as float64; IDs are printed without decimals.
		return fmt.Sprintf("%.0f", v)
	default:
		return fmt.Sprintf("%v", v)
	}
}
const (
	// oauthHTTPTimeout bounds each outbound provider request so a stalled
	// provider cannot block the calling handler indefinitely.
	oauthHTTPTimeout = 15 * time.Second
	// oauthMaxResponseBytes caps how much of a provider response is read.
	oauthMaxResponseBytes = 1 << 20
)

// oauthHTTPClient is the shared client for all provider requests.
var oauthHTTPClient = &http.Client{Timeout: oauthHTTPTimeout}

// doJSONRequest sends req, asking for a JSON response, and decodes the body
// into result.
//
// The body is decoded whatever the status code, because OAuth endpoints
// report errors as JSON bodies on non-2xx responses. When decoding fails on a
// non-2xx response, the returned error includes the status code.
func doJSONRequest(req *http.Request, result any) error {
	// Some providers answer with form encoding unless JSON is requested.
	req.Header.Set("Accept", "application/json")

	resp, err := oauthHTTPClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	err = json.NewDecoder(io.LimitReader(resp.Body, oauthMaxResponseBytes)).Decode(result)
	if err == nil {
		return nil
	}
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return fmt.Errorf("unexpected http status %d: %w", resp.StatusCode, err)
	}
	return err
}

// postFormJSON sends data as a form-encoded POST to url and decodes the JSON
// response into result.
func postFormJSON(url string, data url.Values, result any) error {
	req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(data.Encode()))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	return doJSONRequest(req, result)
}

// getJSON sends a GET request to url and decodes the JSON response into
// result.
func getJSON(url string, result any) error {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return err
	}
	return doJSONRequest(req, result)
}

// getJSONWithAuth sends a GET request to url with an Authorization header and
// decodes the JSON response into result.
//
// The header is "<tokenType> <token>"; an empty tokenType defaults to
// "Bearer". Other schemes (for example GitHub's "token") are sent as given.
func getJSONWithAuth(url, token, tokenType string, result any) error {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return err
	}
	if tokenType == "" {
		tokenType = "Bearer"
	}
	req.Header.Set("Authorization", tokenType+" "+token)
	return doJSONRequest(req, result)
}
func findIdentity(provider, providerUID string) (*models.Identity, error) {
var identity models.Identity
if err := cfg.DB().Where("provider = ? AND provider_uid = ?", provider, providerUID).First(&identity).Error; err != nil {
1 week ago
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, err
}
return &identity, nil
}
func bindIdentity(userID, provider string, userInfo *ThirdPartyUserInfo) error {
// 检查是否已绑定
var existing models.Identity
if err := cfg.DB().Where("provider = ? AND provider_uid = ?", provider, userInfo.ID).First(&existing).Error; err == nil {
1 week ago
// 已绑定到其他账号
if existing.UserID != userID {
return vigo.ErrForbidden.WithString("this account is already bound to another user")
}
// 已绑定到当前账号,更新信息
existing.ProviderName = userInfo.Name
existing.Avatar = userInfo.Avatar
existing.Email = userInfo.Email
return cfg.DB().Save(&existing).Error
1 week ago
}
// 创建新绑定
identity := &models.Identity{
1 week ago
UserID: userID,
Provider: provider,
ProviderUID: userInfo.ID,
ProviderName: userInfo.Name,
Avatar: userInfo.Avatar,
Email: userInfo.Email,
}
return cfg.DB().Create(identity).Error
1 week ago
}
func loginByIdentity(x *vigo.X, identity *models.Identity) (*CallbackResponse, error) {
1 week ago
// 查找用户
var user models.User
if err := cfg.DB().First(&user, "id = ?", identity.UserID).Error; err != nil {
1 week ago
return nil, vigo.ErrNotFound.WithString("user not found")
}
if user.Status != models.UserStatusActive {
1 week ago
return nil, vigo.ErrForbidden.WithString("user is disabled")
}
// 生成token
authResp, err := generateAuthResponse(x, &user)
if err != nil {
return nil, err
}
return &CallbackResponse{
AuthResponse: authResp,
}, nil
}
func generateAuthResponse(x *vigo.X, user *models.User) (*AuthResponse, error) {
1 week ago
// 获取用户的组织信息
orgs, err := getUserOrgs(user.ID)
if err != nil {
return nil, vigo.ErrInternalServer.WithError(err)
}
orgClaims := make([]jwt.OrgClaim, 0, len(orgs))
for _, org := range orgs {
orgClaims = append(orgClaims, jwt.OrgClaim{
OrgID: org.OrgID,
Code: org.Code,
Name: org.Name,
Roles: org.Roles,
Status: org.Status,
})
}
tokenPair, err := jwt.GenerateTokenPair(
user.ID,
user.Username,
user.Nickname,
user.Avatar,
1 week ago
func() string {
if user.Email != nil {
return *user.Email
}
return ""
}(),
1 week ago
orgClaims,
)
if err != nil {
return nil, vigo.ErrInternalServer.WithError(err)
}
return &AuthResponse{
AccessToken: tokenPair.AccessToken,
RefreshToken: tokenPair.RefreshToken,
TokenType: tokenPair.TokenType,
ExpiresIn: tokenPair.ExpiresIn,
User: &UserInfo{
ID: user.ID,
Username: user.Username,
Nickname: user.Nickname,
Email: user.Email,
Avatar: user.Avatar,
},
}, nil
}
// generateTempBindToken stores the third-party profile in the cache under a
// fresh random token ("oauth:bind:<token>", 10 minute expiry) and returns the
// token for use in a follow-up bind request.
func generateTempBindToken(provider string, userInfo *ThirdPartyUserInfo) (string, error) {
	bindToken := generateRandomToken(32)

	payload := map[string]any{
		"provider":    provider,
		"provider_id": userInfo.ID,
		"name":        userInfo.Name,
		"email":       userInfo.Email,
		"avatar":      userInfo.Avatar,
		"raw":         userInfo.Raw,
	}

	err := cache.SetObject("oauth:bind:"+bindToken, payload, 10*time.Minute)
	if err != nil {
		return "", err
	}
	return bindToken, nil
}
// verifyTempBindToken loads the profile stored for a temporary bind token and
// deletes the cache entry, so the token is accepted once.
//
// It returns an error when the token is empty, the cache lookup fails, or the
// payload lacks a provider or provider ID. Optional fields that are missing
// or have an unexpected type are left at their zero value instead of causing
// a panic; this includes a nil "raw" document.
func verifyTempBindToken(token string) (*ThirdPartyUserInfo, error) {
	if token == "" {
		return nil, errors.New("oauth: empty bind token")
	}

	key := "oauth:bind:" + token
	var data map[string]any
	if err := cache.GetObject(key, &data); err != nil {
		return nil, err
	}
	// Single use: drop the entry once it has been read.
	cache.Delete(key)

	// The payload went through the cache, so its shape is not guaranteed.
	str := func(k string) string {
		s, _ := data[k].(string)
		return s
	}
	info := &ThirdPartyUserInfo{
		Provider: str("provider"),
		ID:       str("provider_id"),
		Name:     str("name"),
		Email:    str("email"),
		Avatar:   str("avatar"),
	}
	info.Raw, _ = data["raw"].(map[string]any)

	if info.Provider == "" || info.ID == "" {
		return nil, errors.New("oauth: malformed bind token payload")
	}
	return info, nil
}
// generateRandomPassword returns a random password of the given length drawn
// uniformly from letters, digits and "!@#$%^&*". A length of zero or less
// yields "".
//
// Random bytes that would skew the distribution are discarded (rejection
// sampling) rather than reduced with a plain modulo. It panics if the system
// random source fails, since a predictable password must never be issued.
func generateRandomPassword(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*"
	// limit is the largest multiple of len(charset) that fits in a byte;
	// bytes at or above it would favor the first characters of charset.
	const limit = 256 - 256%len(charset)

	if length <= 0 {
		return ""
	}

	out := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(out) < length {
		if _, err := rand.Read(buf); err != nil {
			panic("oauth: reading random bytes: " + err.Error())
		}
		for _, b := range buf {
			if int(b) >= limit {
				continue
			}
			out = append(out, charset[int(b)%len(charset)])
			if len(out) == length {
				break
			}
		}
	}
	return string(out)
}
// generateRandomToken returns length random bytes from crypto/rand,
// hex-encoded, so the result is 2*length characters long. A length of zero or
// less yields "".
//
// It panics if the system random source fails, since a predictable token must
// never be issued.
func generateRandomToken(length int) string {
	if length <= 0 {
		return ""
	}
	b := make([]byte, length)
	if _, err := rand.Read(b); err != nil {
		panic("oauth: reading random bytes: " + err.Error())
	}
	return hex.EncodeToString(b)
}