You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/api/auth/thirdparty.go

771 lines
20 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-03-04 16:08:06
// Distributed under terms of the MIT license.
//
package auth
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
baseauth "github.com/veypi/vbase/auth"
"github.com/veypi/vbase/cfg"
"github.com/veypi/vbase/libs/cache"
"github.com/veypi/vbase/libs/crypto"
"github.com/veypi/vbase/libs/jwt"
"github.com/veypi/vbase/models"
"github.com/veypi/vigo"
"gorm.io/gorm"
)
// ProviderInfo is the public description of one third-party login provider,
// as returned by listProviders.
type ProviderInfo struct {
	// Name is the provider code (models.OAuthProvider.Code), e.g. "github".
	Name string `json:"name"`
	// DisplayName is the human-readable provider name (models.OAuthProvider.Name).
	DisplayName string `json:"display_name"`
	// Icon is the provider icon value as stored in the database.
	Icon string `json:"icon"`
	// Enabled mirrors the provider's enabled flag.
	Enabled bool `json:"enabled"`
}
// listProviders returns every enabled third-party login provider, ordered by
// the providers' sort_order column. The result is never nil, so it encodes as
// a JSON array even when no provider is enabled.
func listProviders(x *vigo.X) ([]ProviderInfo, error) {
	var rows []models.OAuthProvider
	err := cfg.DB().Where("enabled = ?", true).Order("sort_order").Find(&rows).Error
	if err != nil {
		return nil, err
	}

	out := make([]ProviderInfo, len(rows))
	for i := range rows {
		row := &rows[i]
		out[i] = ProviderInfo{
			Name:        row.Code,
			DisplayName: row.Name,
			Icon:        row.Icon,
			Enabled:     row.Enabled,
		}
	}
	return out, nil
}
// AuthorizeRequest is the query input that starts a third-party login or
// account-binding flow.
type AuthorizeRequest struct {
	// Provider is the provider code; authorizeThirdParty lower-cases it.
	Provider string `json:"provider" src:"query" desc:"提供商: google/github/wechat"`
	// Redirect is the post-login redirect target, stored with the OAuth state.
	Redirect string `json:"redirect" src:"query" desc:"登录成功后重定向地址"`
	// BindMode binds the provider account to the logged-in user instead of
	// logging in.
	BindMode bool `json:"bind_mode" src:"query" desc:"是否为绑定模式"`
}
// AuthorizeResponse carries the provider authorization URL the client should
// open, together with the state value embedded in it.
type AuthorizeResponse struct {
	AuthURL string `json:"auth_url"` // full provider authorization URL
	State   string `json:"state"`    // one-time state, also present in AuthURL
}
// authorizeThirdParty starts a third-party OAuth flow. It returns the
// provider's authorization URL together with the generated state value.
//
// The state data is cached for ten minutes by saveState and consumed later by
// callbackThirdParty. It holds the provider, the redirect, the bind mode, the
// creation time and, in bind mode, the current user ID.
//
// Errors:
//   - ErrUnauthorized when BindMode is set but no user is logged in.
//   - ErrInvalidArg (from buildAuthURL) for an unknown or disabled provider.
//   - ErrInternalServer when the state cannot be cached.
func authorizeThirdParty(x *vigo.X, req *AuthorizeRequest) (*AuthorizeResponse, error) {
	provider := strings.ToLower(strings.TrimSpace(req.Provider))

	// NOTE(review): Redirect is stored verbatim and not validated here. If a
	// later step sends the browser to it, restrict it to trusted targets to
	// avoid an open redirect.
	stateData := map[string]any{
		"provider":   provider,
		"redirect":   req.Redirect,
		"bind_mode":  req.BindMode,
		"created_at": time.Now().Unix(),
	}

	// Bind mode attaches the provider account to the current user, so that
	// user must be logged in.
	if req.BindMode {
		userID := baseauth.GetUserID(x)
		if userID == "" {
			return nil, vigo.ErrUnauthorized.WithString("login required for bind mode")
		}
		stateData["user_id"] = userID
	}

	state := generateState()

	// Build the URL before persisting the state. An unknown or disabled
	// provider is then rejected without writing an orphan entry to the cache.
	authURL, err := buildAuthURL(provider, state)
	if err != nil {
		return nil, err
	}

	if err := saveState(state, stateData); err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	return &AuthorizeResponse{AuthURL: authURL, State: state}, nil
}
// CallbackRequest holds the parameters a provider sends back to the OAuth
// callback route.
type CallbackRequest struct {
	// Provider is taken from the route path.
	Provider string `json:"provider" src:"path" desc:"提供商"`
	// Code is the authorization code to exchange for an access token.
	Code string `json:"code" src:"query" desc:"授权码"`
	// State must match a value issued by authorizeThirdParty.
	State string `json:"state" src:"query" desc:"状态值"`
	// Error is set by the provider when authorization failed.
	Error string `json:"error" src:"query" desc:"错误信息"`
}
// CallbackResponse is returned by callbackThirdParty. Depending on the
// outcome, one group of fields is populated:
//   - bind mode: BindMode and Bound;
//   - account not yet bound: NeedBind, Provider, ProviderID and AuthURL
//     (a relative URL carrying a temporary bind token);
//   - successful login: the embedded *AuthResponse.
type CallbackResponse struct {
	// Bind-mode result.
	BindMode bool `json:"bind_mode,omitempty"`
	Bound    bool `json:"bound,omitempty"`

	// Login mode where the provider account is not yet bound to a local user.
	NeedBind   bool   `json:"need_bind,omitempty"`
	Provider   string `json:"provider,omitempty"`
	ProviderID string `json:"provider_id,omitempty"`
	AuthURL    string `json:"auth_url,omitempty"`

	// Successful login; nil in the other cases.
	*AuthResponse
}
// callbackThirdParty handles the OAuth redirect back from a provider.
//
// It validates and consumes the one-time state, then exchanges the
// authorization code for the provider's user profile. What happens next
// depends on the state:
//   - bind mode: bind the provider account to the user recorded in the state;
//   - login mode with an existing binding: log that user in;
//   - login mode without a binding: issue a short-lived bind token and return
//     NeedBind, so the client can bind an existing account or register.
func callbackThirdParty(x *vigo.X, req *CallbackRequest) (*CallbackResponse, error) {
	if req.Error != "" {
		return nil, vigo.ErrInvalidArg.WithString("oauth error: " + req.Error)
	}
	if req.Code == "" || req.State == "" {
		return nil, vigo.ErrInvalidArg.WithString("missing code or state")
	}

	stateData, err := verifyState(req.State)
	if err != nil {
		return nil, vigo.ErrInvalidArg.WithString("invalid or expired state")
	}

	// Use checked assertions. The state comes back from the cache, and a
	// malformed entry must produce a client error, not a panic.
	provider, _ := stateData["provider"].(string)
	bindMode, _ := stateData["bind_mode"].(bool)
	if provider == "" {
		return nil, vigo.ErrInvalidArg.WithString("invalid or expired state")
	}
	// A state is only valid for the provider it was issued for. Reject a
	// callback that arrives on another provider's route.
	if req.Provider != "" && !strings.EqualFold(req.Provider, provider) {
		return nil, vigo.ErrInvalidArg.WithString("state does not match provider")
	}

	// In bind mode the target user must be known before the code is spent.
	var bindUserID string
	if bindMode {
		bindUserID, _ = stateData["user_id"].(string)
		if bindUserID == "" {
			return nil, vigo.ErrInvalidArg.WithString("invalid or expired state")
		}
	}

	userInfo, err := exchangeAndGetUserInfo(provider, req.Code)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	if bindMode {
		if err := bindIdentity(bindUserID, provider, userInfo); err != nil {
			return nil, err
		}
		return &CallbackResponse{BindMode: true, Bound: true}, nil
	}

	// Login mode: an already-bound identity logs straight in.
	identity, err := findIdentity(provider, userInfo.ID)
	if err != nil {
		return nil, err
	}
	if identity != nil {
		return loginByIdentity(x, identity)
	}

	// Not bound yet: hand out a temporary token for the bind/register step.
	tempToken, err := generateTempBindToken(provider, userInfo)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	return &CallbackResponse{
		NeedBind:   true,
		Provider:   provider,
		ProviderID: userInfo.ID,
		AuthURL:    "/auth/bind?token=" + tempToken,
	}, nil
}
// BindRequest binds a pending third-party identity (identified by its
// temporary token) to an existing account, authenticated by login name and
// password.
type BindRequest struct {
	TempToken string `json:"temp_token" src:"json" desc:"临时绑定令牌"`     // token issued by the OAuth callback
	Username  string `json:"username" src:"json" desc:"用户名/邮箱/手机号"` // username, email or phone
	Password  string `json:"password" src:"json" desc:"密码"`
}
// bindThirdParty binds the third-party identity behind req.TempToken to an
// existing account and logs that account in.
//
// The account is looked up by username, email or phone (all matched against
// req.Username) and must pass password verification and be active.
//
// Errors:
//   - ErrInvalidArg for missing credentials or an invalid or expired token.
//   - ErrUnauthorized for unknown users or a wrong password.
//   - ErrForbidden for disabled users or identities bound elsewhere.
//   - ErrInternalServer on database failures.
func bindThirdParty(x *vigo.X, req *BindRequest) (*AuthResponse, error) {
	// Check the cheap preconditions first: verifyTempBindToken consumes the
	// one-time token, so a malformed request should not burn it.
	if req.Username == "" || req.Password == "" {
		return nil, vigo.ErrInvalidArg.WithString("username and password are required")
	}

	userInfo, err := verifyTempBindToken(req.TempToken)
	if err != nil {
		return nil, vigo.ErrInvalidArg.WithString("invalid or expired token")
	}

	var user models.User
	err = cfg.DB().Where("username = ? OR email = ? OR phone = ?", req.Username, req.Username, req.Username).First(&user).Error
	if err == gorm.ErrRecordNotFound {
		return nil, vigo.ErrUnauthorized.WithString("invalid credentials")
	}
	if err != nil {
		// A database failure is not a credential problem; don't report it as one.
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	if !crypto.VerifyPassword(req.Password, user.Password) {
		return nil, vigo.ErrUnauthorized.WithString("invalid credentials")
	}
	if user.Status != models.UserStatusActive {
		return nil, vigo.ErrForbidden.WithString("user is disabled")
	}

	if err := bindIdentity(user.ID, userInfo.Provider, userInfo); err != nil {
		return nil, err
	}
	return generateAuthResponse(x, &user)
}
// BindWithRegisterRequest creates a new account and binds the pending
// third-party identity (identified by its temporary token) to it. This is an
// optional flow.
type BindWithRegisterRequest struct {
	TempToken string `json:"temp_token" src:"json" desc:"临时绑定令牌"` // token issued by the OAuth callback
	Username  string `json:"username" src:"json" desc:"用户名"`
	Email     string `json:"email" src:"json" desc:"邮箱"` // optional
	Phone     string `json:"phone" src:"json" desc:"手机号"` // optional
}
// bindWithRegister creates a new account for the third-party identity behind
// req.TempToken, binds the identity to it and logs the new account in.
//
// Username is required. Email and Phone are optional, and each must be
// unused when given. The new account gets a random password that has to be
// set later (it is never returned). Nickname and avatar come from the
// provider profile, with Username as the fallback nickname.
func bindWithRegister(x *vigo.X, req *BindWithRegisterRequest) (*AuthResponse, error) {
	if req.Username == "" {
		return nil, vigo.ErrInvalidArg.WithString("username is required")
	}

	userInfo, err := verifyTempBindToken(req.TempToken)
	if err != nil {
		return nil, vigo.ErrInvalidArg.WithString("invalid or expired token")
	}

	// Refuse up front if this provider account is already bound. Otherwise
	// the user row below would be created and bindIdentity would then fail,
	// leaving an orphan account behind.
	bound, err := findIdentity(userInfo.Provider, userInfo.ID)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	if bound != nil {
		return nil, vigo.ErrForbidden.WithString("this account is already bound to another user")
	}

	// taken reports whether some user already has value in column. column is
	// always one of the fixed names below, never user input.
	taken := func(column, value string) (bool, error) {
		var count int64
		err := cfg.DB().Model(&models.User{}).Where(column+" = ?", value).Count(&count).Error
		return count > 0, err
	}
	uniqueChecks := []struct{ column, value, msg string }{
		{"username", req.Username, "username already exists"},
		{"email", req.Email, "email already exists"},
		{"phone", req.Phone, "phone already exists"},
	}
	for _, c := range uniqueChecks {
		if c.value == "" {
			continue // optional field not supplied
		}
		exists, err := taken(c.column, c.value)
		if err != nil {
			return nil, vigo.ErrInternalServer.WithError(err)
		}
		if exists {
			return nil, vigo.ErrInvalidArg.WithString(c.msg)
		}
	}

	// Random password that must be set later; it is never revealed.
	hashedPassword, err := crypto.HashPassword(generateRandomPassword(16), 12)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	var email, phone *string
	if req.Email != "" {
		email = &req.Email
	}
	if req.Phone != "" {
		phone = &req.Phone
	}

	user := &models.User{
		Username: req.Username,
		Password: hashedPassword,
		Email:    email,
		Phone:    phone,
		Nickname: userInfo.Name,
		Avatar:   userInfo.Avatar,
		Status:   models.UserStatusActive,
	}
	if user.Nickname == "" {
		user.Nickname = req.Username
	}
	if err := cfg.DB().Create(user).Error; err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	// Granting the default "user" role is deliberately best-effort: a failure
	// must not abort registration.
	// NOTE(review): no logger is in scope here; surface this error through the
	// project's logging facility.
	_ = baseauth.VBaseAuth.GrantRole(x.Context(), user.ID, "", "user")

	if err := bindIdentity(user.ID, userInfo.Provider, userInfo); err != nil {
		return nil, err
	}
	return generateAuthResponse(x, user)
}
// UnbindRequest names the provider whose binding should be removed from the
// current user. The provider is taken from the route path.
type UnbindRequest struct {
	Provider string `json:"provider" src:"path" desc:"提供商"`
}
// unbindThirdParty removes the current user's binding(s) for req.Provider.
//
// The provider code is trimmed and lower-cased, matching the normalization
// authorizeThirdParty applies before identities are bound. Without this, a
// mixed-case path such as "GitHub" would silently delete nothing.
// Removing a binding that does not exist is not an error.
//
// Errors: ErrUnauthorized when no user is logged in, ErrInvalidArg for an
// empty provider, ErrInternalServer on database failures.
func unbindThirdParty(x *vigo.X, req *UnbindRequest) error {
	userID := baseauth.GetUserID(x)
	if userID == "" {
		return vigo.ErrUnauthorized
	}

	provider := strings.ToLower(strings.TrimSpace(req.Provider))
	if provider == "" {
		return vigo.ErrInvalidArg.WithString("provider is required")
	}

	// NOTE(review): accounts created via bindWithRegister have a random
	// password. Consider refusing to remove their last binding so they are
	// not locked out.
	if err := cfg.DB().Where("user_id = ? AND provider = ?", userID, provider).Delete(&models.Identity{}).Error; err != nil {
		return vigo.ErrInternalServer.WithError(err)
	}
	return nil
}
// BindingInfo describes one third-party identity bound to the current user,
// as returned by listBindings.
type BindingInfo struct {
	Provider     string `json:"provider"`      // provider code
	ProviderName string `json:"provider_name"` // display name reported by the provider
	Avatar       string `json:"avatar"`
	Email        string `json:"email"`
	CreatedAt    string `json:"created_at"` // formatted as "2006-01-02 15:04:05"
}
// listBindings returns the third-party identities bound to the logged-in
// user. The result is never nil, so an empty list encodes as a JSON array.
//
// Errors: ErrUnauthorized when no user is logged in, ErrInternalServer on
// database failures.
func listBindings(x *vigo.X) ([]BindingInfo, error) {
	uid := baseauth.GetUserID(x)
	if uid == "" {
		return nil, vigo.ErrUnauthorized
	}

	var rows []models.Identity
	err := cfg.DB().Where("user_id = ?", uid).Find(&rows).Error
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	const layout = "2006-01-02 15:04:05"
	bindings := make([]BindingInfo, len(rows))
	for i := range rows {
		row := &rows[i]
		bindings[i] = BindingInfo{
			Provider:     row.Provider,
			ProviderName: row.ProviderName,
			Avatar:       row.Avatar,
			Email:        row.Email,
			CreatedAt:    row.CreatedAt.Format(layout),
		}
	}
	return bindings, nil
}
// ==================== Helpers ====================

// ThirdPartyUserInfo is a provider user profile normalized to the fields
// this package uses.
type ThirdPartyUserInfo struct {
	Provider string         // provider code
	ID       string         // provider-side user ID (stored as provider_uid)
	Name     string         // display name; may be empty
	Email    string         // may be empty
	Avatar   string         // avatar value; may be empty
	Raw      map[string]any // decoded user-info response as returned by the provider
}
// generateState returns a fresh OAuth state value: 16 bytes from crypto/rand,
// hex-encoded (32 characters).
//
// The state is the flow's CSRF defense. If the system RNG fails, it panics
// rather than returning a predictable (all-zero) value.
func generateState() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic("auth: crypto/rand failed: " + err.Error())
	}
	return hex.EncodeToString(b)
}
// saveState caches the OAuth state data under "oauth:state:<state>" for ten
// minutes, where verifyState can later retrieve it.
func saveState(state string, data map[string]any) error {
	const ttl = 10 * time.Minute
	return cache.SetObject("oauth:state:"+state, data, ttl)
}
// verifyState loads the data cached for state and deletes the cache entry,
// so each state can be used only once. A cache miss (unknown or expired
// state) is returned as an error.
func verifyState(state string) (map[string]any, error) {
	cacheKey := "oauth:state:" + state

	var stored map[string]any
	err := cache.GetObject(cacheKey, &stored)
	if err != nil {
		return nil, err
	}

	// NOTE(review): Get followed by Delete is two separate cache calls. Two
	// concurrent callbacks with the same state could both pass unless the
	// cache offers an atomic get-and-delete.
	cache.Delete(cacheKey)
	return stored, nil
}
// buildAuthURL builds the authorization URL for the enabled provider whose
// code is providerCode, embedding state.
//
// Details:
//   - The client-id parameter is named by ExtraConfig["appid_param"] when set
//     (e.g. WeChat's "appid"); otherwise it is "client_id".
//   - redirect_uri is provider.RedirectURI, or the default callback path when
//     that is empty. This is the same rule exchangeGeneric uses, because
//     RFC 6749 §4.1.3 requires the token request to repeat the exact
//     redirect_uri.
//   - Scopes are joined with spaces.
//   - Any query string or fragment already present in provider.AuthURL is
//     preserved.
//
// Errors: ErrInvalidArg if the provider is unknown or disabled, or
// ErrInternalServer if its AuthURL cannot be parsed.
func buildAuthURL(providerCode, state string) (string, error) {
	var provider models.OAuthProvider
	if err := cfg.DB().Where("code = ? AND enabled = ?", providerCode, true).First(&provider).Error; err != nil {
		return "", vigo.ErrInvalidArg.WithString("provider not found or not enabled: " + providerCode)
	}

	authURL, err := url.Parse(provider.AuthURL)
	if err != nil {
		return "", vigo.ErrInternalServer.WithError(err)
	}

	clientIDParam := "client_id"
	if appidParam := provider.ExtraConfig["appid_param"]; appidParam != "" {
		clientIDParam = appidParam
	}

	redirectURI := provider.RedirectURI
	if redirectURI == "" {
		redirectURI = getRedirectURI(provider.Code)
	}

	// Merge into any query the configured AuthURL already carries.
	params := authURL.Query()
	params.Set(clientIDParam, provider.ClientID)
	params.Set("redirect_uri", redirectURI)
	params.Set("response_type", "code")
	params.Set("state", state)
	if len(provider.Scopes) > 0 {
		params.Set("scope", strings.Join(provider.Scopes, " "))
	}
	authURL.RawQuery = params.Encode()

	return authURL.String(), nil
}
// getRedirectURI returns the default OAuth callback path for a provider,
// "/auth/callback/<provider>". It is relative; it is used whenever a provider
// has no explicit RedirectURI configured.
func getRedirectURI(provider string) string {
	const callbackPrefix = "/auth/callback/"
	return callbackPrefix + provider
}
// exchangeAndGetUserInfo loads the enabled provider configuration for
// providerCode and runs the generic OAuth code exchange with it. Any lookup
// failure is reported as "provider not found".
func exchangeAndGetUserInfo(providerCode, code string) (*ThirdPartyUserInfo, error) {
	var providerCfg models.OAuthProvider
	lookup := cfg.DB().Where("code = ? AND enabled = ?", providerCode, true)
	if lookup.First(&providerCfg).Error != nil {
		return nil, fmt.Errorf("provider not found: %s", providerCode)
	}
	return exchangeGeneric(providerCfg, code)
}
// exchangeGeneric runs the OAuth 2.0 authorization-code exchange against
// provider and returns the normalized user profile.
//
// Steps:
//  1. Call the token endpoint with code, client credentials, redirect_uri and
//     grant_type=authorization_code. It uses POST when
//     ExtraConfig["use_post_token"] == "true", and GET otherwise.
//  2. Call the user-info endpoint with the access token in the
//     Authorization header.
//  3. Map the user-info JSON onto ThirdPartyUserInfo using the provider's
//     field paths (see extractField). The defaults are id, name, email and
//     avatar.
//
// Honored ExtraConfig keys:
//   - "use_post_token": send the token request as POST.
//   - "appid_param": name of the client-id parameter, also used by
//     buildAuthURL.
//   - "secret_param": name of the client-secret parameter; defaults to
//     "client_secret".
func exchangeGeneric(provider models.OAuthProvider, code string) (*ThirdPartyUserInfo, error) {
	// redirect_uri must match the one used in the authorization request
	// (RFC 6749 §4.1.3).
	redirectURI := provider.RedirectURI
	if redirectURI == "" {
		redirectURI = getRedirectURI(provider.Code)
	}

	clientSecret := provider.ClientSecret
	if clientSecret != "" {
		// Secrets are normally stored encrypted. A decryption failure is
		// tolerated on purpose: legacy rows may still hold the plaintext.
		if decrypted, err := cfg.Global.Key.Decrypt(clientSecret); err == nil {
			clientSecret = decrypted
		}
	}

	// Some providers (e.g. WeChat) use non-standard credential parameter
	// names. The client-id name must agree with the authorization request.
	clientIDParam := "client_id"
	if p := provider.ExtraConfig["appid_param"]; p != "" {
		clientIDParam = p
	}
	clientSecretParam := "client_secret"
	if p := provider.ExtraConfig["secret_param"]; p != "" {
		clientSecretParam = p
	}

	data := url.Values{}
	data.Set("code", code)
	data.Set(clientIDParam, provider.ClientID)
	data.Set(clientSecretParam, clientSecret)
	data.Set("redirect_uri", redirectURI)
	data.Set("grant_type", "authorization_code")

	var tokenResp struct {
		AccessToken string `json:"access_token"`
		TokenType   string `json:"token_type"`
		ExpiresIn   int    `json:"expires_in"`
		Error       string `json:"error"`
		ErrorDesc   string `json:"error_description"`
	}
	var err error
	if provider.ExtraConfig["use_post_token"] == "true" {
		err = postFormJSON(provider.TokenURL, data, &tokenResp)
	} else {
		// NOTE(review): GET puts client_secret in the URL, where proxies or
		// servers may log it. Prefer use_post_token=true where supported.
		err = getJSON(provider.TokenURL+"?"+data.Encode(), &tokenResp)
	}
	if err != nil {
		return nil, fmt.Errorf("exchange token failed: %w", err)
	}
	if tokenResp.Error != "" {
		return nil, fmt.Errorf("oauth error: %s - %s", tokenResp.Error, tokenResp.ErrorDesc)
	}
	// Some providers report failures without an "error" field. Never call the
	// user-info endpoint with an empty token.
	if tokenResp.AccessToken == "" {
		return nil, fmt.Errorf("exchange token failed: no access_token in response")
	}

	var rawUser map[string]any
	if err := getJSONWithAuth(provider.UserInfoURL, tokenResp.AccessToken, tokenResp.TokenType, &rawUser); err != nil {
		return nil, fmt.Errorf("get user info failed: %w", err)
	}

	userInfo := &ThirdPartyUserInfo{
		Provider: provider.Code,
		ID:       extractField(rawUser, provider.UserIDPath, "id"),
		Name:     extractField(rawUser, provider.UserNamePath, "name"),
		Email:    extractField(rawUser, provider.UserEmailPath, "email"),
		Avatar:   extractField(rawUser, provider.UserAvatarPath, "avatar"),
		Raw:      rawUser,
	}
	if userInfo.ID == "" {
		return nil, fmt.Errorf("failed to extract user id from response")
	}
	return userInfo, nil
}
// extractField returns the value at a dot-separated path (e.g.
// "data.employee_id") in data, converted to a string. An empty path falls
// back to defaultKey.
//
// Conversions:
//   - Strings are returned as-is.
//   - float64 values (how encoding/json decodes numbers) are formatted
//     without decimals, so numeric IDs such as 123 become "123".
//   - Other values use their %v form.
//
// A missing key, a non-object intermediate value, or a JSON null yields "".
// Providers such as GitHub send "name": null for users without a public name,
// which must not turn into the literal string "<nil>".
func extractField(data map[string]any, path, defaultKey string) string {
	if path == "" {
		path = defaultKey
	}
	if path == "" {
		return ""
	}

	keys := strings.Split(path, ".")
	current := data
	// Walk every key except the last; each must lead to a nested object.
	for _, key := range keys[:len(keys)-1] {
		next, ok := current[key].(map[string]any)
		if !ok {
			return ""
		}
		current = next
	}

	switch v := current[keys[len(keys)-1]].(type) {
	case nil:
		return "" // key absent or JSON null
	case string:
		return v
	case float64:
		return fmt.Sprintf("%.0f", v)
	case int:
		return fmt.Sprintf("%d", v)
	default:
		return fmt.Sprintf("%v", v)
	}
}
// thirdPartyHTTPClient is shared by the OAuth HTTP helpers below. Unlike
// http.DefaultClient it has a timeout, so a slow provider cannot hang a
// request indefinitely. Its nil Transport means http.DefaultTransport (and
// its connection pool) is used.
var thirdPartyHTTPClient = &http.Client{Timeout: 15 * time.Second}

// doThirdPartyJSON sends req with "Accept: application/json" and decodes a
// 2xx JSON response body into result.
//
// The Accept header matters: GitHub's token endpoint, for example, responds
// with application/x-www-form-urlencoded unless JSON is requested. A non-2xx
// response is returned as an error carrying the status code and, when the
// body is valid JSON, the body itself (e.g. an OAuth error object).
func doThirdPartyJSON(req *http.Request, result any) error {
	req.Header.Set("Accept", "application/json")
	resp, err := thirdPartyHTTPClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		var body json.RawMessage
		if json.NewDecoder(resp.Body).Decode(&body) == nil {
			return fmt.Errorf("unexpected HTTP status %d: %s", resp.StatusCode, body)
		}
		return fmt.Errorf("unexpected HTTP status %d", resp.StatusCode)
	}
	return json.NewDecoder(resp.Body).Decode(result)
}

// postFormJSON POSTs data as application/x-www-form-urlencoded to rawURL and
// decodes the JSON response into result.
func postFormJSON(rawURL string, data url.Values, result any) error {
	req, err := http.NewRequest(http.MethodPost, rawURL, strings.NewReader(data.Encode()))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	return doThirdPartyJSON(req, result)
}

// getJSON GETs rawURL and decodes the JSON response into result.
func getJSON(rawURL string, result any) error {
	req, err := http.NewRequest(http.MethodGet, rawURL, nil)
	if err != nil {
		return err
	}
	return doThirdPartyJSON(req, result)
}

// getJSONWithAuth GETs rawURL with "Authorization: <type> <token>" and decodes
// the JSON response into result. An empty or any-case "bearer" tokenType is
// sent as "Bearer". Other types (e.g. GitHub's "token") are passed through
// unchanged.
func getJSONWithAuth(rawURL, token, tokenType string, result any) error {
	req, err := http.NewRequest(http.MethodGet, rawURL, nil)
	if err != nil {
		return err
	}
	scheme := tokenType
	if scheme == "" || strings.EqualFold(scheme, "bearer") {
		scheme = "Bearer"
	}
	req.Header.Set("Authorization", scheme+" "+token)
	return doThirdPartyJSON(req, result)
}
// findIdentity looks up the identity bound to (provider, providerUID).
// It returns (nil, nil) when no such binding exists, and the database error
// for any other failure.
func findIdentity(provider, providerUID string) (*models.Identity, error) {
	identity := new(models.Identity)
	err := cfg.DB().Where("provider = ? AND provider_uid = ?", provider, providerUID).First(identity).Error
	switch {
	case err == nil:
		return identity, nil
	case err == gorm.ErrRecordNotFound:
		return nil, nil
	default:
		return nil, err
	}
}
// bindIdentity binds the provider account described by userInfo to userID.
//
// Behavior:
//   - Already bound to userID: the stored name, avatar and email are
//     refreshed.
//   - Bound to another user: returns ErrForbidden.
//   - Not bound: a new identity row is created.
//
// Database errors are returned as-is.
func bindIdentity(userID, provider string, userInfo *ThirdPartyUserInfo) error {
	var existing models.Identity
	err := cfg.DB().Where("provider = ? AND provider_uid = ?", provider, userInfo.ID).First(&existing).Error
	switch {
	case err == nil:
		if existing.UserID != userID {
			return vigo.ErrForbidden.WithString("this account is already bound to another user")
		}
		existing.ProviderName = userInfo.Name
		existing.Avatar = userInfo.Avatar
		existing.Email = userInfo.Email
		return cfg.DB().Save(&existing).Error
	case err != gorm.ErrRecordNotFound:
		// A real lookup failure must not be mistaken for "not bound yet".
		return err
	}

	identity := &models.Identity{
		UserID:       userID,
		Provider:     provider,
		ProviderUID:  userInfo.ID,
		ProviderName: userInfo.Name,
		Avatar:       userInfo.Avatar,
		Email:        userInfo.Email,
	}
	// NOTE(review): check-then-create is racy; a unique index on
	// (provider, provider_uid) is what ultimately prevents double binding.
	return cfg.DB().Create(identity).Error
}
// loginByIdentity logs in the user that identity is bound to.
//
// Errors: ErrNotFound when the user row cannot be loaded, ErrForbidden when
// the user is not active, plus any error from generateAuthResponse.
func loginByIdentity(x *vigo.X, identity *models.Identity) (*CallbackResponse, error) {
	user := new(models.User)
	if cfg.DB().First(user, "id = ?", identity.UserID).Error != nil {
		return nil, vigo.ErrNotFound.WithString("user not found")
	}
	if user.Status != models.UserStatusActive {
		return nil, vigo.ErrForbidden.WithString("user is disabled")
	}

	authResp, err := generateAuthResponse(x, user)
	if err != nil {
		return nil, err
	}
	return &CallbackResponse{AuthResponse: authResp}, nil
}
// generateAuthResponse issues an access/refresh token pair for user and wraps
// it, together with a public user summary, in an AuthResponse.
//
// The user's organizations (from getUserOrgs) are embedded in the token as
// jwt.OrgClaim values. A nil Email is passed to the token generator as "".
// Failures are reported as ErrInternalServer.
func generateAuthResponse(x *vigo.X, user *models.User) (*AuthResponse, error) {
	orgs, err := getUserOrgs(user.ID)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	claims := make([]jwt.OrgClaim, 0, len(orgs))
	for _, o := range orgs {
		claims = append(claims, jwt.OrgClaim{
			OrgID:  o.OrgID,
			Code:   o.Code,
			Name:   o.Name,
			Roles:  o.Roles,
			Status: o.Status,
		})
	}

	email := ""
	if user.Email != nil {
		email = *user.Email
	}

	pair, err := jwt.GenerateTokenPair(user.ID, user.Username, user.Nickname, user.Avatar, email, claims)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	summary := &UserInfo{
		ID:       user.ID,
		Username: user.Username,
		Nickname: user.Nickname,
		Email:    user.Email,
		Avatar:   user.Avatar,
	}
	return &AuthResponse{
		AccessToken:  pair.AccessToken,
		RefreshToken: pair.RefreshToken,
		TokenType:    pair.TokenType,
		ExpiresIn:    pair.ExpiresIn,
		User:         summary,
	}, nil
}
// generateTempBindToken stores the provider profile in the cache under
// "oauth:bind:<token>" for ten minutes and returns the random token. The
// token is later redeemed once by verifyTempBindToken.
func generateTempBindToken(provider string, userInfo *ThirdPartyUserInfo) (string, error) {
	token := generateRandomToken(32)

	payload := make(map[string]any, 6)
	payload["provider"] = provider
	payload["provider_id"] = userInfo.ID
	payload["name"] = userInfo.Name
	payload["email"] = userInfo.Email
	payload["avatar"] = userInfo.Avatar
	payload["raw"] = userInfo.Raw

	err := cache.SetObject("oauth:bind:"+token, payload, 10*time.Minute)
	if err != nil {
		return "", err
	}
	return token, nil
}
// verifyTempBindToken redeems a bind token created by generateTempBindToken.
// It returns the cached provider profile and deletes the entry, so each token
// works only once.
//
// Fields are read with checked type assertions. The data has made a round
// trip through the cache, and a nil or missing value (e.g. "raw" stored as
// JSON null) must not panic. It returns an error for an empty or unknown
// token, or when the entry lacks a provider or provider ID.
func verifyTempBindToken(token string) (*ThirdPartyUserInfo, error) {
	if token == "" {
		return nil, fmt.Errorf("empty bind token")
	}

	key := "oauth:bind:" + token
	var data map[string]any
	if err := cache.GetObject(key, &data); err != nil {
		return nil, err
	}
	cache.Delete(key) // one-time token

	str := func(k string) string {
		s, _ := data[k].(string)
		return s
	}
	info := &ThirdPartyUserInfo{
		Provider: str("provider"),
		ID:       str("provider_id"),
		Name:     str("name"),
		Email:    str("email"),
		Avatar:   str("avatar"),
	}
	info.Raw, _ = data["raw"].(map[string]any)

	if info.Provider == "" || info.ID == "" {
		return nil, fmt.Errorf("bind token data is incomplete")
	}
	return info, nil
}
// generateRandomPassword returns a random password of length characters
// drawn uniformly from letters, digits and "!@#$%^&*". It returns "" when
// length <= 0.
//
// Uniformity uses rejection sampling. Random bytes at or above the largest
// multiple of len(charset) that fits in a byte are discarded. A plain
// b % len(charset) would favor the first 256 % len(charset) characters.
// It panics if the system RNG fails, rather than return a weak password.
func generateRandomPassword(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*"
	const limit = 256 - 256%len(charset)

	if length <= 0 {
		return ""
	}

	out := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(out) < length {
		if _, err := rand.Read(buf); err != nil {
			panic("auth: crypto/rand failed: " + err.Error())
		}
		for _, c := range buf {
			if int(c) >= limit {
				continue
			}
			out = append(out, charset[int(c)%len(charset)])
			if len(out) == length {
				break
			}
		}
	}
	return string(out)
}
// generateRandomToken returns length random bytes from crypto/rand,
// hex-encoded (2*length characters). It returns "" when length <= 0.
//
// The token guards a pending account binding, so it panics if the system RNG
// fails rather than return a predictable value.
func generateRandomToken(length int) string {
	if length <= 0 {
		return ""
	}
	b := make([]byte, length)
	if _, err := rand.Read(b); err != nil {
		panic("auth: crypto/rand failed: " + err.Error())
	}
	return hex.EncodeToString(b)
}