You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/api/auth/thirdparty.go

782 lines
20 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-03-04 16:08:06
// Distributed under terms of the MIT license.
//
package auth
import (
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"strings"
	"time"

	baseauth "github.com/veypi/vbase/auth"
	"github.com/veypi/vbase/cfg"
	"github.com/veypi/vbase/libs/cache"
	"github.com/veypi/vbase/libs/crypto"
	"github.com/veypi/vbase/libs/jwt"
	"github.com/veypi/vbase/models"
	"github.com/veypi/vigo"
	"gorm.io/gorm"
)
// ProviderInfo describes one third-party login provider as reported by
// listProviders.
type ProviderInfo struct {
	// Name is the internal provider key ("google", "github", "wechat").
	Name string `json:"name"`
	// DisplayName is the human-readable label shown to users.
	DisplayName string `json:"display_name"`
	// Icon is an icon identifier; presumably resolved by the frontend — confirm.
	Icon string `json:"icon"`
	// Enabled mirrors the provider's Enabled flag in cfg.Config.Providers.
	Enabled bool `json:"enabled"`
}
// listProviders reports every third-party login provider this package knows
// about. Disabled providers are still listed; their Enabled field is taken
// from the current configuration. It never returns an error.
func listProviders(x *vigo.X) ([]ProviderInfo, error) {
	conf := cfg.Config.Providers
	return []ProviderInfo{
		{Name: "google", DisplayName: "Google", Icon: "google", Enabled: conf.Google.Enabled},
		{Name: "github", DisplayName: "GitHub", Icon: "github", Enabled: conf.GitHub.Enabled},
		{Name: "wechat", DisplayName: "微信", Icon: "wechat", Enabled: conf.WeChat.Enabled},
	}, nil
}
// AuthorizeRequest is the query-string payload accepted by
// authorizeThirdParty to start a third-party login or binding flow.
type AuthorizeRequest struct {
	// Provider selects the provider: google / github / wechat (matched case-insensitively).
	Provider string `json:"provider" src:"query" desc:"提供商: google/github/wechat"`
	// Redirect is stored in the OAuth state; intended as the post-login redirect target.
	// NOTE(review): not validated here — confirm whoever consumes it guards against open redirects.
	Redirect string `json:"redirect" src:"query" desc:"登录成功后重定向地址"`
	// BindMode requests binding the provider account to the currently logged-in user.
	BindMode bool `json:"bind_mode" src:"query" desc:"是否为绑定模式"`
}
// AuthorizeResponse is returned by authorizeThirdParty.
type AuthorizeResponse struct {
	// AuthURL is the provider authorization URL the client should navigate to.
	AuthURL string `json:"auth_url"`
	// State is the random state value embedded in AuthURL and cached server-side.
	State string `json:"state"`
}
// authorizeThirdParty starts a third-party OAuth login (or account-binding)
// flow and returns the provider authorization URL the client should visit.
//
// A random state value is generated and cached for ten minutes (see
// saveState) together with the provider, redirect target and bind mode, so
// callbackThirdParty can validate the round trip. In bind mode the caller
// must already be logged in, and their user ID is stored in the state.
//
// The authorization URL is built before the state is persisted. buildAuthURL
// rejects unsupported or disabled providers, so invalid requests never leave
// orphan state entries in the cache.
func authorizeThirdParty(x *vigo.X, req *AuthorizeRequest) (*AuthorizeResponse, error) {
	provider := strings.ToLower(strings.TrimSpace(req.Provider))
	state := generateState()

	// Validates the provider; returns a vigo error for unknown/disabled ones.
	authURL, err := buildAuthURL(provider, state)
	if err != nil {
		return nil, err
	}

	stateData := map[string]any{
		"provider":   provider,
		"redirect":   req.Redirect,
		"bind_mode":  req.BindMode,
		"created_at": time.Now().Unix(),
	}

	// Binding attaches the provider account to the current session's user.
	if req.BindMode {
		userID := getCurrentUserID(x)
		if userID == "" {
			return nil, vigo.ErrNotAuthorized.WithString("login required for bind mode")
		}
		stateData["user_id"] = userID
	}

	if err := saveState(state, stateData); err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	return &AuthorizeResponse{
		AuthURL: authURL,
		State:   state,
	}, nil
}
// CallbackRequest carries the parameters of the provider's redirect back to
// callbackThirdParty.
type CallbackRequest struct {
	// Provider is taken from the URL path.
	Provider string `json:"provider" src:"path" desc:"提供商"`
	// Code is the authorization code to exchange for an access token.
	Code string `json:"code" src:"query" desc:"授权码"`
	// State must match a state previously issued by authorizeThirdParty.
	State string `json:"state" src:"query" desc:"状态值"`
	// Error is set by the provider when the user denied access or the flow failed.
	Error string `json:"error" src:"query" desc:"错误信息"`
}
// CallbackResponse is the result of callbackThirdParty. Exactly one of the
// three groups below is populated, depending on the outcome.
type CallbackResponse struct {
	// Bind mode: the provider account was attached to the logged-in user.
	BindMode bool `json:"bind_mode,omitempty"`
	Bound    bool `json:"bound,omitempty"`
	// Login mode, provider account not yet linked: the client must complete
	// binding via AuthURL, which carries a temporary bind token.
	NeedBind   bool   `json:"need_bind,omitempty"`
	Provider   string `json:"provider,omitempty"`
	ProviderID string `json:"provider_id,omitempty"`
	AuthURL    string `json:"auth_url,omitempty"`
	// Login success: token pair and user info. encoding/json promotes the
	// embedded struct's fields to the top level, and omits them when the pointer is nil.
	*AuthResponse
}
// callbackThirdParty handles the provider's redirect after the user
// authorized (or refused) access.
//
// It validates and consumes the cached state, then exchanges the code for the
// provider's user profile. The outcome depends on the state:
//   - bind mode: the profile is attached to the user recorded in the state;
//   - login mode, identity already linked: the linked user is logged in;
//   - login mode, identity unknown: a temporary bind token is issued and the
//     client is told to finish binding via /auth/bind.
//
// Cached state values are read with checked type assertions, so a malformed
// cache entry produces an error instead of a panic. The provider in the
// callback path must match the one the flow was started for. A profile with
// an empty provider user ID is rejected; otherwise every such user would map
// to the same identity row.
func callbackThirdParty(x *vigo.X, req *CallbackRequest) (*CallbackResponse, error) {
	if req.Error != "" {
		return nil, vigo.ErrArgInvalid.WithString("oauth error: " + req.Error)
	}
	if req.Code == "" || req.State == "" {
		return nil, vigo.ErrArgInvalid.WithString("missing code or state")
	}

	stateData, err := verifyState(req.State)
	if err != nil {
		return nil, vigo.ErrArgInvalid.WithString("invalid or expired state")
	}
	provider, _ := stateData["provider"].(string)
	bindMode, _ := stateData["bind_mode"].(bool)
	if provider == "" {
		return nil, vigo.ErrArgInvalid.WithString("invalid or expired state")
	}
	if req.Provider != "" && strings.ToLower(req.Provider) != provider {
		return nil, vigo.ErrArgInvalid.WithString("provider mismatch")
	}

	userInfo, err := exchangeAndGetUserInfo(provider, req.Code)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	if userInfo.ID == "" {
		return nil, vigo.ErrInternalServer.WithError(fmt.Errorf("provider %s returned an empty user id", provider))
	}

	// Bind mode: attach the provider identity to the user who started the flow.
	if bindMode {
		userID, _ := stateData["user_id"].(string)
		if userID == "" {
			return nil, vigo.ErrNotAuthorized.WithString("login required for bind mode")
		}
		if err := bindIdentity(userID, provider, userInfo); err != nil {
			return nil, err
		}
		return &CallbackResponse{
			BindMode: true,
			Bound:    true,
		}, nil
	}

	// Login mode: an already-linked identity logs straight in.
	identity, err := findIdentity(provider, userInfo.ID)
	if err != nil {
		return nil, err
	}
	if identity != nil {
		return loginByIdentity(x, identity)
	}

	// Unknown identity: hand out a short-lived token for the bind step.
	tempToken, err := generateTempBindToken(provider, userInfo)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	return &CallbackResponse{
		NeedBind:   true,
		Provider:   provider,
		ProviderID: userInfo.ID,
		AuthURL:    "/auth/bind?token=" + tempToken,
	}, nil
}
// BindRequest links a pending third-party identity to an existing local
// account (bindThirdParty).
type BindRequest struct {
	// TempToken is the temporary bind token issued by callbackThirdParty.
	TempToken string `json:"temp_token" src:"json" desc:"临时绑定令牌"`
	// Username is matched against username, email or phone of the local account.
	Username string `json:"username" src:"json" desc:"用户名/邮箱/手机号"`
	// Password is the local account's password.
	Password string `json:"password" src:"json" desc:"密码"`
}
// bindThirdParty attaches a pending third-party identity to an existing local
// account and then logs that account in.
//
// The pending identity is identified by the temporary token that
// callbackThirdParty issued; verifyTempBindToken consumes it. The local
// account is authenticated with a login name (username, email or phone) plus
// password. An unknown account and a wrong password both produce the same
// "invalid credentials" error. Disabled accounts are refused.
func bindThirdParty(x *vigo.X, req *BindRequest) (*AuthResponse, error) {
	pending, err := verifyTempBindToken(req.TempToken)
	if err != nil {
		return nil, vigo.ErrArgInvalid.WithString("invalid or expired token")
	}

	login := req.Username
	var account models.User
	lookupErr := cfg.DB().
		Where("username = ? OR email = ? OR phone = ?", login, login, login).
		First(&account).Error

	// Cases are evaluated in order; the password check only runs once the lookup succeeded.
	switch {
	case lookupErr != nil, !crypto.VerifyPassword(req.Password, account.Password):
		return nil, vigo.ErrNotAuthorized.WithString("invalid credentials")
	case account.Status != models.UserStatusActive:
		return nil, vigo.ErrForbidden.WithString("user is disabled")
	}

	if err := bindIdentity(account.ID, pending.Provider, pending); err != nil {
		return nil, err
	}
	return generateAuthResponse(x, &account)
}
// BindWithRegisterRequest creates a new local account for a pending
// third-party identity and binds it (bindWithRegister; optional feature).
type BindWithRegisterRequest struct {
	// TempToken is the temporary bind token issued by callbackThirdParty.
	TempToken string `json:"temp_token" src:"json" desc:"临时绑定令牌"`
	// Username for the new account.
	Username string `json:"username" src:"json" desc:"用户名"`
	// Email is optional; stored as NULL when empty.
	Email string `json:"email" src:"json" desc:"邮箱"`
	// Phone is optional; stored as NULL when empty.
	Phone string `json:"phone" src:"json" desc:"手机号"`
}
// bindWithRegister creates a brand-new local account for a pending
// third-party identity, binds the identity to it and logs the user in.
//
// The pending identity is identified by the temporary token issued in
// callbackThirdParty; verifyTempBindToken consumes it. The new account gets a
// random password that is never returned. As the original design notes, the
// password has to be set later; until then the user can presumably only sign
// in via the provider.
//
// Validation performed before anything is written:
//   - username must be non-empty;
//   - the provider identity must not already be bound (otherwise bindIdentity
//     would fail after Create and leave an orphan account);
//   - username, email and phone (when given) must be unused.
//
// Granting the default "user" role is best-effort: a failure is logged and
// the flow continues.
func bindWithRegister(x *vigo.X, req *BindWithRegisterRequest) (*AuthResponse, error) {
	userInfo, err := verifyTempBindToken(req.TempToken)
	if err != nil {
		return nil, vigo.ErrArgInvalid.WithString("invalid or expired token")
	}
	if strings.TrimSpace(req.Username) == "" {
		return nil, vigo.ErrArgInvalid.WithString("username is required")
	}

	bound, err := findIdentity(userInfo.Provider, userInfo.ID)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}
	if bound != nil {
		return nil, vigo.ErrForbidden.WithString("this account is already bound to another user")
	}

	// Column names are fixed constants; only the values are bound parameters.
	uniques := []struct{ column, value, msg string }{
		{"username", req.Username, "username already exists"},
		{"email", req.Email, "email already exists"},
		{"phone", req.Phone, "phone already exists"},
	}
	for _, u := range uniques {
		if u.value == "" {
			continue
		}
		var count int64
		if err := cfg.DB().Model(&models.User{}).Where(u.column+" = ?", u.value).Count(&count).Error; err != nil {
			return nil, vigo.ErrInternalServer.WithError(err)
		}
		if count > 0 {
			return nil, vigo.ErrArgInvalid.WithString(u.msg)
		}
	}

	// A failed hash must abort; storing an empty hash would create a broken account.
	hashedPassword, err := crypto.HashPassword(generateRandomPassword(16), cfg.Config.Security.BcryptCost)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	var email, phone *string
	if req.Email != "" {
		email = &req.Email
	}
	if req.Phone != "" {
		phone = &req.Phone
	}

	user := &models.User{
		Username: req.Username,
		Password: hashedPassword,
		Email:    email,
		Phone:    phone,
		Nickname: userInfo.Name,
		Avatar:   userInfo.Avatar,
		Status:   models.UserStatusActive,
	}
	if user.Nickname == "" {
		user.Nickname = req.Username
	}
	if err := cfg.DB().Create(user).Error; err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	// Best-effort default role; log instead of silently dropping the error.
	if err := baseauth.VBaseAuth.GrantRole(x.Context(), user.ID, "", "user"); err != nil {
		log.Printf("auth: grant default role to user %v: %v", user.ID, err)
	}

	if err := bindIdentity(user.ID, userInfo.Provider, userInfo); err != nil {
		return nil, err
	}
	return generateAuthResponse(x, user)
}
// UnbindRequest removes the current user's binding for one provider
// (unbindThirdParty).
type UnbindRequest struct {
	// Provider is taken from the URL path.
	Provider string `json:"provider" src:"path" desc:"提供商"`
}
// unbindThirdParty deletes the current user's identity binding for the given
// provider. Deleting a binding that does not exist is not an error.
//
// Provider keys are stored in lowercase: authorizeThirdParty lowercases them
// and the exchange functions emit lowercase names. The path value is
// therefore normalised the same way before matching. An empty provider is
// rejected.
//
// NOTE(review): accounts created by bindWithRegister have a random password;
// unbinding their last identity may lock the user out. Consider guarding
// against that.
func unbindThirdParty(x *vigo.X, req *UnbindRequest) error {
	userID := getCurrentUserID(x)
	if userID == "" {
		return vigo.ErrNotAuthorized
	}

	provider := strings.ToLower(strings.TrimSpace(req.Provider))
	if provider == "" {
		return vigo.ErrArgInvalid.WithString("provider is required")
	}

	if err := cfg.DB().Where("user_id = ? AND provider = ?", userID, provider).Delete(&models.Identity{}).Error; err != nil {
		return vigo.ErrInternalServer.WithError(err)
	}
	return nil
}
// BindingInfo is one entry of the current user's third-party bindings, as
// returned by listBindings.
type BindingInfo struct {
	// Provider is the provider key, e.g. "github".
	Provider string `json:"provider"`
	// ProviderName is the display name reported by the provider.
	ProviderName string `json:"provider_name"`
	// Avatar is the avatar URL reported by the provider.
	Avatar string `json:"avatar"`
	// Email is the email reported by the provider (may be empty).
	Email string `json:"email"`
	// CreatedAt is the binding time formatted as "2006-01-02 15:04:05".
	CreatedAt string `json:"created_at"`
}
// listBindings returns the current user's third-party identity bindings.
// The result is always a non-nil slice, so it encodes to [] rather than
// null when the user has no bindings.
func listBindings(x *vigo.X) ([]BindingInfo, error) {
	uid := getCurrentUserID(x)
	if uid == "" {
		return nil, vigo.ErrNotAuthorized
	}

	var rows []models.Identity
	err := cfg.DB().Where("user_id = ?", uid).Find(&rows).Error
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	const layout = "2006-01-02 15:04:05"
	bindings := make([]BindingInfo, len(rows))
	for i := range rows {
		row := &rows[i]
		bindings[i] = BindingInfo{
			Provider:     row.Provider,
			ProviderName: row.ProviderName,
			Avatar:       row.Avatar,
			Email:        row.Email,
			CreatedAt:    row.CreatedAt.Format(layout),
		}
	}
	return bindings, nil
}
// ==================== helpers ====================

// ThirdPartyUserInfo is the normalised user profile obtained from a provider
// after the authorization-code exchange.
type ThirdPartyUserInfo struct {
	// Provider is the provider key ("google", "github", "wechat").
	Provider string
	// ID is the provider-side user identifier stored as Identity.ProviderUID.
	ID string
	// Name is the display name / nickname reported by the provider.
	Name string
	// Email is the provider-reported email; may be empty.
	Email string
	// Avatar is the provider-reported avatar URL.
	Avatar string
	// Raw is meant for the raw provider payload; none of the exchange
	// functions in this file populate it, so it is nil in practice.
	Raw map[string]any
}
// generateState returns a fresh OAuth state value: 16 bytes from crypto/rand,
// hex-encoded (32 characters).
//
// If the system random source fails, the function panics. The original
// ignored the error and would then have returned a predictable all-zero
// state.
func generateState() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic("auth: crypto/rand unavailable: " + err.Error())
	}
	return hex.EncodeToString(b)
}
func saveState(state string, data map[string]any) error {
key := "oauth:state:" + state
return cache.SetObject(key, data, 10*time.Minute)
}
// verifyState loads the payload cached by saveState for the given state and
// removes the entry, so each state can be used once. A missing or expired
// entry surfaces as the cache lookup error.
//
// NOTE(review): the read and the delete are separate cache calls, so two
// concurrent callbacks with the same state could both succeed. The result of
// cache.Delete is ignored.
func verifyState(state string) (map[string]any, error) {
	key := "oauth:state:" + state

	var payload map[string]any
	err := cache.GetObject(key, &payload)
	if err != nil {
		return nil, err
	}

	cache.Delete(key)
	return payload, nil
}
// buildAuthURL returns the provider's authorization URL for the given state.
//
// It fails with ErrArgInvalid for unknown or disabled providers. It fails
// with ErrInternalServer when the configured AuthURL is empty or cannot be
// parsed.
//
// The OAuth parameters are merged into any query string already present in
// the configured AuthURL, and a fragment is kept after the query. Plain
// concatenation with "?" would corrupt both cases. WeChat additionally
// receives the client ID as "appid".
func buildAuthURL(provider, state string) (string, error) {
	var pc cfg.OAuthProviderConfig
	switch provider {
	case "google":
		pc = cfg.Config.Providers.Google
	case "github":
		pc = cfg.Config.Providers.GitHub
	case "wechat":
		pc = cfg.Config.Providers.WeChat
	default:
		return "", vigo.ErrArgInvalid.WithString("unsupported provider: " + provider)
	}
	if !pc.Enabled {
		return "", vigo.ErrArgInvalid.WithString("provider not enabled: " + provider)
	}
	if pc.AuthURL == "" {
		return "", vigo.ErrInternalServer.WithError(fmt.Errorf("auth url not configured for provider %s", provider))
	}

	base, err := url.Parse(pc.AuthURL)
	if err != nil {
		return "", vigo.ErrInternalServer.WithError(fmt.Errorf("invalid auth url for provider %s: %w", provider, err))
	}

	params := base.Query()
	params.Set("client_id", pc.ClientID)
	params.Set("redirect_uri", getRedirectURI(provider))
	params.Set("response_type", "code")
	params.Set("state", state)
	params.Set("scope", strings.Join(pc.Scopes, " "))
	// WeChat identifies the app by "appid" rather than "client_id".
	if provider == "wechat" {
		params.Set("appid", pc.ClientID)
	}

	base.RawQuery = params.Encode()
	return base.String(), nil
}
// getRedirectURI returns the callback path for a provider, e.g.
// "/auth/callback/github".
//
// NOTE(review): this is a relative path. OAuth providers generally require
// an absolute redirect_uri that exactly matches the registered one. Confirm
// whether a base URL should be prepended from configuration.
func getRedirectURI(provider string) string {
	const callbackPrefix = "/auth/callback/"
	return callbackPrefix + provider
}
func exchangeAndGetUserInfo(provider, code string) (*ThirdPartyUserInfo, error) {
switch provider {
case "google":
return exchangeGoogle(code)
case "github":
return exchangeGitHub(code)
case "wechat":
return exchangeWeChat(code)
default:
return nil, fmt.Errorf("unsupported provider: %s", provider)
}
}
// exchangeGoogle exchanges a Google authorization code for an access token
// and fetches the user's profile from the configured UserInfoURL.
//
// Compared with a bare http.Post/DefaultClient call, it:
//   - uses a client with a timeout (http.DefaultClient has none);
//   - checks HTTP status codes;
//   - rejects an empty access token or an empty "sub" (the stable user ID).
func exchangeGoogle(code string) (*ThirdPartyUserInfo, error) {
	pc := cfg.Config.Providers.Google
	client := &http.Client{Timeout: 15 * time.Second}

	// Step 1: authorization code -> access token (form-encoded POST).
	form := url.Values{}
	form.Set("code", code)
	form.Set("client_id", pc.ClientID)
	form.Set("client_secret", pc.ClientSecret)
	form.Set("redirect_uri", getRedirectURI("google"))
	form.Set("grant_type", "authorization_code")
	resp, err := client.PostForm(pc.TokenURL, form)
	if err != nil {
		return nil, fmt.Errorf("google token request: %w", err)
	}
	defer resp.Body.Close()

	var tokenResp struct {
		AccessToken string `json:"access_token"`
		Error       string `json:"error"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
		return nil, fmt.Errorf("google token response (status %d): %w", resp.StatusCode, err)
	}
	if tokenResp.Error != "" {
		return nil, fmt.Errorf("token error: %s", tokenResp.Error)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("google token endpoint: unexpected status %d", resp.StatusCode)
	}
	if tokenResp.AccessToken == "" {
		return nil, fmt.Errorf("google token endpoint: empty access token")
	}

	// Step 2: access token -> user profile.
	req, err := http.NewRequest(http.MethodGet, pc.UserInfoURL, nil)
	if err != nil {
		return nil, fmt.Errorf("google userinfo request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
	userResp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("google userinfo request: %w", err)
	}
	defer userResp.Body.Close()
	if userResp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("google userinfo: unexpected status %d", userResp.StatusCode)
	}

	var userData struct {
		Sub     string `json:"sub"`
		Name    string `json:"name"`
		Email   string `json:"email"`
		Picture string `json:"picture"`
	}
	if err := json.NewDecoder(userResp.Body).Decode(&userData); err != nil {
		return nil, fmt.Errorf("google userinfo response: %w", err)
	}
	// Without a subject every such login would collide on the same identity row.
	if userData.Sub == "" {
		return nil, fmt.Errorf("google userinfo: missing subject")
	}

	return &ThirdPartyUserInfo{
		Provider: "google",
		ID:       userData.Sub,
		Name:     userData.Name,
		Email:    userData.Email,
		Avatar:   userData.Picture,
	}, nil
}
// exchangeGitHub exchanges a GitHub authorization code for an access token
// and fetches the user's profile from the configured UserInfoURL.
//
// Fixes over the original:
//   - the token POST now declares Content-Type (the body is form-encoded and
//     previously carried no content type);
//   - request-construction errors are handled;
//   - a client with a timeout is used;
//   - HTTP status, the access token and the numeric user ID are validated.
//
// When the profile has no display name, the login handle is used instead.
// The original decoded Login but never used it.
func exchangeGitHub(code string) (*ThirdPartyUserInfo, error) {
	pc := cfg.Config.Providers.GitHub
	client := &http.Client{Timeout: 15 * time.Second}

	// Step 1: authorization code -> access token.
	form := url.Values{}
	form.Set("code", code)
	form.Set("client_id", pc.ClientID)
	form.Set("client_secret", pc.ClientSecret)
	form.Set("redirect_uri", getRedirectURI("github"))
	tokenReq, err := http.NewRequest(http.MethodPost, pc.TokenURL, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, fmt.Errorf("github token request: %w", err)
	}
	tokenReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	// Ask for JSON instead of the default form-encoded token response.
	tokenReq.Header.Set("Accept", "application/json")
	resp, err := client.Do(tokenReq)
	if err != nil {
		return nil, fmt.Errorf("github token request: %w", err)
	}
	defer resp.Body.Close()

	var tokenResp struct {
		AccessToken string `json:"access_token"`
		Error       string `json:"error"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
		return nil, fmt.Errorf("github token response (status %d): %w", resp.StatusCode, err)
	}
	if tokenResp.Error != "" {
		return nil, fmt.Errorf("token error: %s", tokenResp.Error)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("github token endpoint: unexpected status %d", resp.StatusCode)
	}
	if tokenResp.AccessToken == "" {
		return nil, fmt.Errorf("github token endpoint: empty access token")
	}

	// Step 2: access token -> user profile.
	userReq, err := http.NewRequest(http.MethodGet, pc.UserInfoURL, nil)
	if err != nil {
		return nil, fmt.Errorf("github user request: %w", err)
	}
	userReq.Header.Set("Authorization", "token "+tokenResp.AccessToken)
	userResp, err := client.Do(userReq)
	if err != nil {
		return nil, fmt.Errorf("github user request: %w", err)
	}
	defer userResp.Body.Close()
	if userResp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("github user endpoint: unexpected status %d", userResp.StatusCode)
	}

	var userData struct {
		ID        int64  `json:"id"`
		Login     string `json:"login"`
		Name      string `json:"name"`
		Email     string `json:"email"`
		AvatarURL string `json:"avatar_url"`
	}
	if err := json.NewDecoder(userResp.Body).Decode(&userData); err != nil {
		return nil, fmt.Errorf("github user response: %w", err)
	}
	// ID 0 means the field was absent; accepting it would make users collide on "0".
	if userData.ID == 0 {
		return nil, fmt.Errorf("github user endpoint: missing user id")
	}

	name := userData.Name
	if name == "" {
		name = userData.Login
	}
	return &ThirdPartyUserInfo{
		Provider: "github",
		ID:       fmt.Sprintf("%d", userData.ID),
		Name:     name,
		Email:    userData.Email,
		Avatar:   userData.AvatarURL,
	}, nil
}
// exchangeWeChat exchanges a WeChat authorization code for an access token
// and fetches the user's profile.
//
// Security and correctness fixes over the original:
//   - Query strings are built with url.Values, so the untrusted code cannot
//     inject extra parameters (it was previously formatted in unescaped).
//   - Transport errors are returned without the request URL, because
//     *url.Error embeds the full URL, which contains the app secret.
//   - A client with a timeout is used, and HTTP status, the access token and
//     the openid are validated.
//   - The identity ID prefers the unionid (from the token or the profile)
//     and falls back to the openid. Previously only the token's unionid was
//     used. It is empty when the app is not bound to an open-platform
//     account, so every such user mapped to the same empty ID.
//
// NOTE(review): openid is per-app. If the app later gains unionid support,
// users first stored under their openid will get a new identity row.
func exchangeWeChat(code string) (*ThirdPartyUserInfo, error) {
	pc := cfg.Config.Providers.WeChat
	client := &http.Client{Timeout: 15 * time.Second}

	// stripURL drops the *url.Error wrapper (and with it the secret-bearing URL).
	stripURL := func(err error) error {
		var ue *url.Error
		if errors.As(err, &ue) {
			return ue.Err
		}
		return err
	}

	// Step 1: authorization code -> access token + openid/unionid.
	tokenQuery := url.Values{}
	tokenQuery.Set("appid", pc.ClientID)
	tokenQuery.Set("secret", pc.ClientSecret)
	tokenQuery.Set("code", code)
	tokenQuery.Set("grant_type", "authorization_code")
	resp, err := client.Get(pc.TokenURL + "?" + tokenQuery.Encode())
	if err != nil {
		return nil, fmt.Errorf("wechat token request: %w", stripURL(err))
	}
	defer resp.Body.Close()

	var tokenResp struct {
		AccessToken string `json:"access_token"`
		OpenID      string `json:"openid"`
		UnionID     string `json:"unionid"`
		ErrCode     int    `json:"errcode"`
		ErrMsg      string `json:"errmsg"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
		return nil, fmt.Errorf("wechat token response (status %d): %w", resp.StatusCode, err)
	}
	if tokenResp.ErrCode != 0 {
		return nil, fmt.Errorf("wechat error: %s", tokenResp.ErrMsg)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("wechat token endpoint: unexpected status %d", resp.StatusCode)
	}
	if tokenResp.AccessToken == "" || tokenResp.OpenID == "" {
		return nil, fmt.Errorf("wechat token endpoint: missing access_token or openid")
	}

	// Step 2: access token -> user profile.
	userQuery := url.Values{}
	userQuery.Set("access_token", tokenResp.AccessToken)
	userQuery.Set("openid", tokenResp.OpenID)
	userResp, err := client.Get(pc.UserInfoURL + "?" + userQuery.Encode())
	if err != nil {
		return nil, fmt.Errorf("wechat userinfo request: %w", stripURL(err))
	}
	defer userResp.Body.Close()

	var userData struct {
		OpenID   string `json:"openid"`
		Nickname string `json:"nickname"`
		HeadImg  string `json:"headimgurl"`
		UnionID  string `json:"unionid"`
		ErrCode  int    `json:"errcode"`
		ErrMsg   string `json:"errmsg"`
	}
	if err := json.NewDecoder(userResp.Body).Decode(&userData); err != nil {
		return nil, fmt.Errorf("wechat userinfo response (status %d): %w", userResp.StatusCode, err)
	}
	if userData.ErrCode != 0 {
		return nil, fmt.Errorf("wechat error: %s", userData.ErrMsg)
	}
	if userResp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("wechat userinfo endpoint: unexpected status %d", userResp.StatusCode)
	}

	// Prefer the cross-app unionid; fall back to the per-app openid.
	id := tokenResp.UnionID
	if id == "" {
		id = userData.UnionID
	}
	if id == "" {
		id = tokenResp.OpenID
	}

	return &ThirdPartyUserInfo{
		Provider: "wechat",
		ID:       id,
		Name:     userData.Nickname,
		Avatar:   userData.HeadImg,
	}, nil
}
// findIdentity looks up the identity binding for (provider, providerUID).
//
// It returns (nil, nil) when no binding exists and the raw database error for
// any other failure. Not-found is detected with errors.Is, as GORM
// recommends, so wrapped errors are handled too.
//
// An empty providerUID is rejected. Matching on an empty UID would let every
// profile lacking an ID log into whichever account happens to hold such a
// row.
func findIdentity(provider, providerUID string) (*models.Identity, error) {
	if providerUID == "" {
		return nil, fmt.Errorf("empty provider uid for provider %s", provider)
	}

	var identity models.Identity
	err := cfg.DB().Where("provider = ? AND provider_uid = ?", provider, providerUID).First(&identity).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &identity, nil
}
// bindIdentity links the provider identity described by userInfo to userID.
//
// Behaviour by case:
//   - identity already linked to a different user: ErrForbidden;
//   - identity already linked to userID: its name, avatar and email are
//     refreshed;
//   - no existing row: a new binding is created.
//
// A lookup error other than "record not found" is reported instead of
// falling through to Create. Database errors are wrapped as
// ErrInternalServer, matching the rest of this file. An empty userID or
// provider user ID is rejected, so no binding is ever keyed on "".
func bindIdentity(userID, provider string, userInfo *ThirdPartyUserInfo) error {
	if userID == "" || userInfo == nil || userInfo.ID == "" {
		return vigo.ErrArgInvalid.WithString("invalid identity")
	}

	var existing models.Identity
	err := cfg.DB().Where("provider = ? AND provider_uid = ?", provider, userInfo.ID).First(&existing).Error
	switch {
	case err == nil:
		if existing.UserID != userID {
			return vigo.ErrForbidden.WithString("this account is already bound to another user")
		}
		// Already bound to this user: refresh the provider-side profile fields.
		existing.ProviderName = userInfo.Name
		existing.Avatar = userInfo.Avatar
		existing.Email = userInfo.Email
		if err := cfg.DB().Save(&existing).Error; err != nil {
			return vigo.ErrInternalServer.WithError(err)
		}
		return nil
	case !errors.Is(err, gorm.ErrRecordNotFound):
		return vigo.ErrInternalServer.WithError(err)
	}

	identity := &models.Identity{
		UserID:       userID,
		Provider:     provider,
		ProviderUID:  userInfo.ID,
		ProviderName: userInfo.Name,
		Avatar:       userInfo.Avatar,
		Email:        userInfo.Email,
	}
	if err := cfg.DB().Create(identity).Error; err != nil {
		return vigo.ErrInternalServer.WithError(err)
	}
	return nil
}
// loginByIdentity logs in the local user that owns the given identity
// binding. It fails with ErrNotFound if the user row cannot be loaded (for
// any database error) and with ErrForbidden if the user is not active.
// On success the token pair is wrapped in a CallbackResponse.
func loginByIdentity(x *vigo.X, identity *models.Identity) (*CallbackResponse, error) {
	var owner models.User
	if cfg.DB().First(&owner, "id = ?", identity.UserID).Error != nil {
		return nil, vigo.ErrNotFound.WithString("user not found")
	}
	if owner.Status != models.UserStatusActive {
		return nil, vigo.ErrForbidden.WithString("user is disabled")
	}

	tokens, err := generateAuthResponse(x, &owner)
	if err != nil {
		return nil, err
	}
	resp := &CallbackResponse{}
	resp.AuthResponse = tokens
	return resp, nil
}
// generateAuthResponse issues an access/refresh token pair for user and
// packages it with the user's public profile.
//
// The user's organisation memberships (from getUserOrgs) are embedded in the
// token as jwt.OrgClaim values. A nil Email pointer is passed to the token
// generator as "". Errors from either step are wrapped as ErrInternalServer.
func generateAuthResponse(x *vigo.X, user *models.User) (*AuthResponse, error) {
	memberships, err := getUserOrgs(user.ID)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	claims := make([]jwt.OrgClaim, 0, len(memberships))
	for _, m := range memberships {
		claim := jwt.OrgClaim{
			OrgID:  m.OrgID,
			Code:   m.Code,
			Name:   m.Name,
			Roles:  m.Roles,
			Status: m.Status,
		}
		claims = append(claims, claim)
	}

	email := ""
	if user.Email != nil {
		email = *user.Email
	}

	pair, err := jwt.GenerateTokenPair(user.ID, user.Username, user.Nickname, user.Avatar, email, claims)
	if err != nil {
		return nil, vigo.ErrInternalServer.WithError(err)
	}

	profile := &UserInfo{
		ID:       user.ID,
		Username: user.Username,
		Nickname: user.Nickname,
		Email:    user.Email,
		Avatar:   user.Avatar,
	}
	return &AuthResponse{
		AccessToken:  pair.AccessToken,
		RefreshToken: pair.RefreshToken,
		TokenType:    pair.TokenType,
		ExpiresIn:    pair.ExpiresIn,
		User:         profile,
	}, nil
}
// generateTempBindToken caches the pending third-party profile for ten
// minutes under "oauth:bind:<token>" and returns the token. The token is 32
// random bytes, hex-encoded. The client presents it to complete binding or
// registration; verifyTempBindToken reads it back.
func generateTempBindToken(provider string, userInfo *ThirdPartyUserInfo) (string, error) {
	token := generateRandomToken(32)

	payload := map[string]any{
		"provider":    provider,
		"provider_id": userInfo.ID,
		"name":        userInfo.Name,
		"email":       userInfo.Email,
		"avatar":      userInfo.Avatar,
		"raw":         userInfo.Raw,
	}
	err := cache.SetObject("oauth:bind:"+token, payload, 10*time.Minute)
	if err != nil {
		return "", err
	}
	return token, nil
}
// verifyTempBindToken consumes a temporary bind token and returns the
// pending third-party profile cached by generateTempBindToken.
//
// The entry is deleted after a successful read, so a token works only once.
// Values are read with checked type assertions. Raw is never populated by the
// exchange functions; if the cache serialises values (assumed, e.g. JSON —
// confirm), a nil Raw comes back as an untyped nil. The original unchecked
// assertion on "raw" would then panic on every bind. A payload without a
// provider or provider ID is rejected.
func verifyTempBindToken(token string) (*ThirdPartyUserInfo, error) {
	if token == "" {
		return nil, fmt.Errorf("empty bind token")
	}

	key := "oauth:bind:" + token
	var data map[string]any
	if err := cache.GetObject(key, &data); err != nil {
		return nil, err
	}
	// One-shot token: remove it once read.
	cache.Delete(key)

	str := func(k string) string {
		s, _ := data[k].(string)
		return s
	}
	raw, _ := data["raw"].(map[string]any)

	info := &ThirdPartyUserInfo{
		Provider: str("provider"),
		ID:       str("provider_id"),
		Name:     str("name"),
		Email:    str("email"),
		Avatar:   str("avatar"),
		Raw:      raw,
	}
	if info.Provider == "" || info.ID == "" {
		return nil, fmt.Errorf("malformed bind token payload")
	}
	return info, nil
}
// generateRandomPassword returns a random password of the given length drawn
// uniformly from letters, digits and "!@#$%^&*". A non-positive length yields
// "".
//
// Bytes at or above the largest multiple of len(charset) are rejected
// (rejection sampling). The original plain modulo biased the first 256 %
// len(charset) characters. The function panics if crypto/rand fails, rather
// than silently producing a predictable password.
func generateRandomPassword(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*"
	if length <= 0 {
		return ""
	}

	limit := 256 - 256%len(charset)
	out := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(out) < length {
		if _, err := rand.Read(buf); err != nil {
			panic("auth: crypto/rand unavailable: " + err.Error())
		}
		for _, c := range buf {
			if int(c) >= limit {
				continue
			}
			out = append(out, charset[int(c)%len(charset)])
			if len(out) == length {
				break
			}
		}
	}
	return string(out)
}
// generateRandomToken returns length random bytes from crypto/rand,
// hex-encoded (2*length characters).
//
// The function panics if the system random source fails. The original
// ignored the error and could return a predictable all-zero token, which
// would be usable as a bind credential.
func generateRandomToken(length int) string {
	b := make([]byte, length)
	if _, err := rand.Read(b); err != nil {
		panic("auth: crypto/rand unavailable: " + err.Error())
	}
	return hex.EncodeToString(b)
}