// Source path: OneAuth/api/auth/thirdparty.go
//
// Copyright (C) 2024 veypi <i@veypi.com>
// 2025-03-04 16:08:06
// Distributed under terms of the MIT license.
//
package auth
import (
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/veypi/vbase/cfg"
	"github.com/veypi/vbase/libs/cache"
	"github.com/veypi/vbase/libs/crypto"
	"github.com/veypi/vbase/libs/jwt"
	"github.com/veypi/vbase/models"
	"github.com/veypi/vigo"
	"gorm.io/gorm"
)
// ProviderInfo describes one third-party login provider as reported by
// listProviders.
type ProviderInfo struct {
	// Name is the provider key, e.g. "google".
	Name string `json:"name"`
	// DisplayName is the human-readable provider label.
	DisplayName string `json:"display_name"`
	// Icon is the icon identifier for the provider.
	Icon string `json:"icon"`
	// Enabled mirrors the provider's enabled flag from configuration.
	Enabled bool `json:"enabled"`
}
// listProviders 获取支持的第三方登录提供商列表
func listProviders(x *vigo.X) ([]ProviderInfo, error) {
1 week ago
providers := []ProviderInfo{
{Name: "google", DisplayName: "Google", Icon: "google", Enabled: cfg.Config.Providers.Google.Enabled},
{Name: "github", DisplayName: "GitHub", Icon: "github", Enabled: cfg.Config.Providers.GitHub.Enabled},
{Name: "wechat", DisplayName: "微信", Icon: "wechat", Enabled: cfg.Config.Providers.WeChat.Enabled},
1 week ago
}
return providers, nil
}
// AuthorizeRequest carries the parameters for starting a third-party login
// or bind flow. All fields are tagged src:"query".
type AuthorizeRequest struct {
	// Provider selects the OAuth provider: google, github or wechat.
	Provider string `json:"provider" src:"query" desc:"提供商: google/github/wechat"`
	// Redirect is the address to return to after a successful login.
	Redirect string `json:"redirect" src:"query" desc:"登录成功后重定向地址"`
	// BindMode requests binding the provider account to the current user
	// instead of logging in.
	BindMode bool `json:"bind_mode" src:"query" desc:"是否为绑定模式"`
}
// AuthorizeResponse is returned by authorizeThirdParty.
type AuthorizeResponse struct {
	// AuthURL is the provider authorization URL the client should visit.
	AuthURL string `json:"auth_url"`
	// State is the opaque value that the provider echoes back on callback.
	State string `json:"state"`
}
// authorizeThirdParty 获取第三方登录授权URL
func authorizeThirdParty(x *vigo.X, req *AuthorizeRequest) (*AuthorizeResponse, error) {
1 week ago
provider := strings.ToLower(req.Provider)
// 生成state
state := generateState()
// 存储state到缓存用于验证回调
stateData := map[string]any{
"provider": provider,
"redirect": req.Redirect,
"bind_mode": req.BindMode,
1 week ago
"created_at": time.Now().Unix(),
}
// 如果是绑定模式,需要当前用户登录
if req.BindMode {
userID := getCurrentUserID(x)
if userID == "" {
return nil, vigo.ErrNotAuthorized.WithString("login required for bind mode")
}
stateData["user_id"] = userID
}
// 保存state数据到缓存10分钟有效
if err := saveState(state, stateData); err != nil {
return nil, vigo.ErrInternalServer.WithError(err)
}
// 构建授权URL
authURL, err := buildAuthURL(provider, state)
if err != nil {
return nil, err
}
return &AuthorizeResponse{
AuthURL: authURL,
State: state,
}, nil
}
// CallbackRequest carries the parameters of the provider's OAuth callback.
type CallbackRequest struct {
	// Provider is the provider name, tagged src:"path".
	Provider string `json:"provider" src:"path" desc:"提供商"`
	// Code is the authorization code issued by the provider.
	Code string `json:"code" src:"query" desc:"授权码"`
	// State is the value originally issued by authorizeThirdParty.
	State string `json:"state" src:"query" desc:"状态值"`
	// Error is set by the provider when authorization failed.
	Error string `json:"error" src:"query" desc:"错误信息"`
}
// CallbackResponse 回调响应
type CallbackResponse struct {
// 绑定模式
BindMode bool `json:"bind_mode,omitempty"`
Bound bool `json:"bound,omitempty"`
// 登录模式
NeedBind bool `json:"need_bind,omitempty"`
Provider string `json:"provider,omitempty"`
ProviderID string `json:"provider_id,omitempty"`
AuthURL string `json:"auth_url,omitempty"`
1 week ago
// 登录成功
*AuthResponse
}
// callbackThirdParty 处理第三方登录回调
func callbackThirdParty(x *vigo.X, req *CallbackRequest) (*CallbackResponse, error) {
1 week ago
if req.Error != "" {
return nil, vigo.ErrArgInvalid.WithString("oauth error: " + req.Error)
}
if req.Code == "" || req.State == "" {
return nil, vigo.ErrArgInvalid.WithString("missing code or state")
}
// 验证state
stateData, err := verifyState(req.State)
if err != nil {
return nil, vigo.ErrArgInvalid.WithString("invalid or expired state")
}
provider := stateData["provider"].(string)
bindMode := stateData["bind_mode"].(bool)
// 交换access_token并获取用户信息
userInfo, err := exchangeAndGetUserInfo(provider, req.Code)
if err != nil {
return nil, vigo.ErrInternalServer.WithError(err)
}
// 绑定模式:将第三方身份绑定到当前用户
if bindMode {
userID := stateData["user_id"].(string)
if err := bindIdentity(userID, provider, userInfo); err != nil {
return nil, err
}
return &CallbackResponse{
BindMode: true,
Bound: true,
}, nil
}
// 登录模式:查找是否已绑定
identity, err := findIdentity(provider, userInfo.ID)
if err != nil {
return nil, err
}
// 如果已绑定,直接登录
if identity != nil {
return loginByIdentity(x, identity)
}
// 未绑定,返回需要绑定的信息
// 生成临时token用于后续绑定
tempToken, err := generateTempBindToken(provider, userInfo)
if err != nil {
return nil, vigo.ErrInternalServer.WithError(err)
}
return &CallbackResponse{
NeedBind: true,
Provider: provider,
ProviderID: userInfo.ID,
AuthURL: "/auth/bind?token=" + tempToken,
}, nil
}
// BindRequest is the body of a request that binds a third-party identity to
// an existing account.
type BindRequest struct {
	// TempToken is the temporary bind token handed out by the callback.
	TempToken string `json:"temp_token" src:"json" desc:"临时绑定令牌"`
	// Username is the account's username, email or phone number.
	Username string `json:"username" src:"json" desc:"用户名/邮箱/手机号"`
	// Password is the account password.
	Password string `json:"password" src:"json" desc:"密码"`
}
// bindThirdParty 绑定第三方账号到已有账号
func bindThirdParty(x *vigo.X, req *BindRequest) (*AuthResponse, error) {
1 week ago
// 验证临时token
userInfo, err := verifyTempBindToken(req.TempToken)
if err != nil {
return nil, vigo.ErrArgInvalid.WithString("invalid or expired token")
}
// 查找用户
var user models.User
query := cfg.DB().Where("username = ? OR email = ? OR phone = ?", req.Username, req.Username, req.Username)
1 week ago
if err := query.First(&user).Error; err != nil {
return nil, vigo.ErrNotAuthorized.WithString("invalid credentials")
}
// 验证密码
if !crypto.VerifyPassword(req.Password, user.Password) {
return nil, vigo.ErrNotAuthorized.WithString("invalid credentials")
}
// 检查用户状态
if user.Status != models.UserStatusActive {
1 week ago
return nil, vigo.ErrForbidden.WithString("user is disabled")
}
// 绑定身份
if err := bindIdentity(user.ID, userInfo.Provider, userInfo); err != nil {
return nil, err
}
// 生成登录token
return generateAuthResponse(x, &user)
}
// BindWithRegisterRequest is the body of a request that registers a new
// account and binds a third-party identity to it (optional feature).
type BindWithRegisterRequest struct {
	// TempToken is the temporary bind token handed out by the callback.
	TempToken string `json:"temp_token" src:"json" desc:"临时绑定令牌"`
	// Username is the username for the new account.
	Username string `json:"username" src:"json" desc:"用户名"`
	// Email is the email address for the new account.
	Email string `json:"email" src:"json" desc:"邮箱"`
	// Phone is the phone number for the new account.
	Phone string `json:"phone" src:"json" desc:"手机号"`
}
// bindWithRegister 绑定并创建新账号
func bindWithRegister(x *vigo.X, req *BindWithRegisterRequest) (*AuthResponse, error) {
1 week ago
// 验证临时token
userInfo, err := verifyTempBindToken(req.TempToken)
if err != nil {
return nil, vigo.ErrArgInvalid.WithString("invalid or expired token")
}
// 检查用户名是否已存在
var count int64
cfg.DB().Model(&models.User{}).Where("username = ?", req.Username).Count(&count)
1 week ago
if count > 0 {
return nil, vigo.ErrArgInvalid.WithString("username already exists")
}
// 检查邮箱是否已存在
if req.Email != "" {
cfg.DB().Model(&models.User{}).Where("email = ?", req.Email).Count(&count)
1 week ago
if count > 0 {
return nil, vigo.ErrArgInvalid.WithString("email already exists")
}
}
// 创建用户(随机密码,需要后续设置)
randomPassword := generateRandomPassword(16)
hashedPassword, _ := crypto.HashPassword(randomPassword, cfg.Config.Security.BcryptCost)
1 week ago
user := &models.User{
1 week ago
Username: req.Username,
Password: hashedPassword,
Email: req.Email,
Phone: req.Phone,
Nickname: userInfo.Name,
Avatar: userInfo.Avatar,
Status: models.UserStatusActive,
1 week ago
}
if user.Nickname == "" {
user.Nickname = req.Username
}
if err := cfg.DB().Create(user).Error; err != nil {
1 week ago
return nil, vigo.ErrInternalServer.WithError(err)
}
// 绑定第三方身份
if err := bindIdentity(user.ID, userInfo.Provider, userInfo); err != nil {
return nil, err
}
// 生成登录token
return generateAuthResponse(x, user)
}
// UnbindRequest identifies the provider binding to remove.
type UnbindRequest struct {
	// Provider is the provider name, tagged src:"path".
	Provider string `json:"provider" src:"path" desc:"提供商"`
}
// unbindThirdParty 解除第三方账号绑定
func unbindThirdParty(x *vigo.X, req *UnbindRequest) error {
1 week ago
userID := getCurrentUserID(x)
if userID == "" {
return vigo.ErrNotAuthorized
}
// 删除绑定关系
if err := cfg.DB().Where("user_id = ? AND provider = ?", userID, req.Provider).Delete(&models.Identity{}).Error; err != nil {
1 week ago
return vigo.ErrInternalServer.WithError(err)
}
return nil
}
// BindingInfo describes one third-party binding of a user, as returned by
// listBindings.
type BindingInfo struct {
	// Provider is the provider key of the binding.
	Provider string `json:"provider"`
	// ProviderName is the user's name at the provider.
	ProviderName string `json:"provider_name"`
	// Avatar is the avatar URL stored with the binding.
	Avatar string `json:"avatar"`
	// Email is the email address stored with the binding.
	Email string `json:"email"`
	// CreatedAt is the creation time formatted as "2006-01-02 15:04:05".
	CreatedAt string `json:"created_at"`
}
// listBindings 获取当前用户的第三方绑定列表
func listBindings(x *vigo.X) ([]BindingInfo, error) {
1 week ago
userID := getCurrentUserID(x)
if userID == "" {
return nil, vigo.ErrNotAuthorized
}
var identities []models.Identity
if err := cfg.DB().Where("user_id = ?", userID).Find(&identities).Error; err != nil {
1 week ago
return nil, vigo.ErrInternalServer.WithError(err)
}
result := make([]BindingInfo, 0, len(identities))
for _, id := range identities {
result = append(result, BindingInfo{
Provider: id.Provider,
ProviderName: id.ProviderName,
Avatar: id.Avatar,
Email: id.Email,
CreatedAt: id.CreatedAt.Format("2006-01-02 15:04:05"),
})
}
return result, nil
}
// ==================== Helpers ====================

// ThirdPartyUserInfo is the provider-independent view of a user returned by
// the exchange functions.
type ThirdPartyUserInfo struct {
	// Provider is the provider key ("google", "github" or "wechat").
	Provider string
	// ID is the user's identifier at the provider.
	ID string
	// Name is the user's display name at the provider.
	Name string
	// Email is the user's email address, when the provider supplies one.
	Email string
	// Avatar is the URL of the user's avatar image.
	Avatar string
	// Raw holds additional provider data; none of the exchange functions
	// in this file populate it.
	Raw map[string]any
}
// generateState returns a fresh OAuth state value: 16 bytes from
// crypto/rand, hex-encoded to 32 characters.
//
// It panics if the system random source fails. Silently continuing would
// produce a predictable all-zero state and defeat the CSRF protection the
// state provides.
func generateState() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic("auth: reading random bytes for oauth state: " + err.Error())
	}
	return hex.EncodeToString(b)
}
// saveState stores the flow data for an OAuth state value in the cache
// under "oauth:state:<state>" with a lifetime of ten minutes.
func saveState(state string, data map[string]any) error {
	const ttl = 10 * time.Minute
	return cache.SetObject("oauth:state:"+state, data, ttl)
}
// verifyState loads the flow data stored for an OAuth state value and then
// removes the cache entry, so each state is accepted once. It returns the
// cache error when the entry cannot be read.
func verifyState(state string) (map[string]any, error) {
	cacheKey := "oauth:state:" + state

	var payload map[string]any
	err := cache.GetObject(cacheKey, &payload)
	if err != nil {
		return nil, err
	}

	// One-time use: drop the entry once it has been read.
	cache.Delete(cacheKey)
	return payload, nil
}
func buildAuthURL(provider, state string) (string, error) {
var pc cfg.OAuthProviderConfig
1 week ago
switch provider {
case "google":
pc = cfg.Config.Providers.Google
1 week ago
case "github":
pc = cfg.Config.Providers.GitHub
1 week ago
case "wechat":
pc = cfg.Config.Providers.WeChat
1 week ago
default:
return "", vigo.ErrArgInvalid.WithString("unsupported provider: " + provider)
}
if !pc.Enabled {
1 week ago
return "", vigo.ErrArgInvalid.WithString("provider not enabled: " + provider)
}
params := url.Values{}
params.Set("client_id", pc.ClientID)
1 week ago
params.Set("redirect_uri", getRedirectURI(provider))
params.Set("response_type", "code")
params.Set("state", state)
params.Set("scope", strings.Join(pc.Scopes, " "))
1 week ago
// 特殊处理
if provider == "wechat" {
params.Set("appid", pc.ClientID)
1 week ago
}
return pc.AuthURL + "?" + params.Encode(), nil
1 week ago
}
// getRedirectURI returns the callback path registered for provider:
// "/auth/callback/<provider>".
//
// NOTE(review): the result is a relative path, while OAuth providers
// generally expect an absolute redirect_uri. Confirm where the scheme and
// host are supplied.
func getRedirectURI(provider string) string {
	return "/auth/callback/" + provider
}
// exchangeAndGetUserInfo dispatches to the provider-specific exchange
// function, which trades the authorization code for the provider's user
// info. Unknown providers yield an "unsupported provider" error.
func exchangeAndGetUserInfo(provider, code string) (*ThirdPartyUserInfo, error) {
	if provider == "google" {
		return exchangeGoogle(code)
	}
	if provider == "github" {
		return exchangeGitHub(code)
	}
	if provider == "wechat" {
		return exchangeWeChat(code)
	}
	return nil, fmt.Errorf("unsupported provider: %s", provider)
}
func exchangeGoogle(code string) (*ThirdPartyUserInfo, error) {
pc := cfg.Config.Providers.Google
1 week ago
// 交换access_token
data := url.Values{}
data.Set("code", code)
data.Set("client_id", pc.ClientID)
data.Set("client_secret", pc.ClientSecret)
1 week ago
data.Set("redirect_uri", getRedirectURI("google"))
data.Set("grant_type", "authorization_code")
resp, err := http.Post(pc.TokenURL, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
1 week ago
if err != nil {
return nil, err
}
defer resp.Body.Close()
var tokenResp struct {
AccessToken string `json:"access_token"`
Error string `json:"error"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return nil, err
}
if tokenResp.Error != "" {
return nil, fmt.Errorf("token error: %s", tokenResp.Error)
}
// 获取用户信息
req, _ := http.NewRequest("GET", pc.UserInfoURL, nil)
1 week ago
req.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
userResp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer userResp.Body.Close()
var userData struct {
Sub string `json:"sub"`
Name string `json:"name"`
Email string `json:"email"`
Picture string `json:"picture"`
}
if err := json.NewDecoder(userResp.Body).Decode(&userData); err != nil {
return nil, err
}
return &ThirdPartyUserInfo{
Provider: "google",
ID: userData.Sub,
Name: userData.Name,
Email: userData.Email,
Avatar: userData.Picture,
}, nil
}
func exchangeGitHub(code string) (*ThirdPartyUserInfo, error) {
pc := cfg.Config.Providers.GitHub
1 week ago
// 交换access_token
data := url.Values{}
data.Set("code", code)
data.Set("client_id", pc.ClientID)
data.Set("client_secret", pc.ClientSecret)
1 week ago
data.Set("redirect_uri", getRedirectURI("github"))
req, _ := http.NewRequest("POST", pc.TokenURL, strings.NewReader(data.Encode()))
1 week ago
req.Header.Set("Accept", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var tokenResp struct {
AccessToken string `json:"access_token"`
Error string `json:"error"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return nil, err
}
if tokenResp.Error != "" {
return nil, fmt.Errorf("token error: %s", tokenResp.Error)
}
// 获取用户信息
req, _ = http.NewRequest("GET", pc.UserInfoURL, nil)
1 week ago
req.Header.Set("Authorization", "token "+tokenResp.AccessToken)
userResp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer userResp.Body.Close()
var userData struct {
ID int `json:"id"`
Login string `json:"login"`
Name string `json:"name"`
Email string `json:"email"`
1 week ago
AvatarURL string `json:"avatar_url"`
}
if err := json.NewDecoder(userResp.Body).Decode(&userData); err != nil {
return nil, err
}
return &ThirdPartyUserInfo{
Provider: "github",
ID: fmt.Sprintf("%d", userData.ID),
Name: userData.Name,
Email: userData.Email,
Avatar: userData.AvatarURL,
}, nil
}
func exchangeWeChat(code string) (*ThirdPartyUserInfo, error) {
pc := cfg.Config.Providers.WeChat
1 week ago
// 交换access_token
urlStr := fmt.Sprintf("%s?appid=%s&secret=%s&code=%s&grant_type=authorization_code",
pc.TokenURL, pc.ClientID, pc.ClientSecret, code)
1 week ago
resp, err := http.Get(urlStr)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var tokenResp struct {
AccessToken string `json:"access_token"`
OpenID string `json:"openid"`
UnionID string `json:"unionid"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
1 week ago
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return nil, err
}
if tokenResp.ErrCode != 0 {
return nil, fmt.Errorf("wechat error: %s", tokenResp.ErrMsg)
}
// 获取用户信息
userInfoURL := fmt.Sprintf("%s?access_token=%s&openid=%s",
pc.UserInfoURL, tokenResp.AccessToken, tokenResp.OpenID)
1 week ago
userResp, err := http.Get(userInfoURL)
if err != nil {
return nil, err
}
defer userResp.Body.Close()
var userData struct {
OpenID string `json:"openid"`
Nickname string `json:"nickname"`
HeadImg string `json:"headimgurl"`
UnionID string `json:"unionid"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
if err := json.NewDecoder(userResp.Body).Decode(&userData); err != nil {
return nil, err
}
if userData.ErrCode != 0 {
return nil, fmt.Errorf("wechat error: %s", userData.ErrMsg)
}
return &ThirdPartyUserInfo{
Provider: "wechat",
ID: tokenResp.UnionID,
Name: userData.Nickname,
Avatar: userData.HeadImg,
}, nil
}
func findIdentity(provider, providerUID string) (*models.Identity, error) {
var identity models.Identity
if err := cfg.DB().Where("provider = ? AND provider_uid = ?", provider, providerUID).First(&identity).Error; err != nil {
1 week ago
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, err
}
return &identity, nil
}
func bindIdentity(userID, provider string, userInfo *ThirdPartyUserInfo) error {
// 检查是否已绑定
var existing models.Identity
if err := cfg.DB().Where("provider = ? AND provider_uid = ?", provider, userInfo.ID).First(&existing).Error; err == nil {
1 week ago
// 已绑定到其他账号
if existing.UserID != userID {
return vigo.ErrForbidden.WithString("this account is already bound to another user")
}
// 已绑定到当前账号,更新信息
existing.ProviderName = userInfo.Name
existing.Avatar = userInfo.Avatar
existing.Email = userInfo.Email
return cfg.DB().Save(&existing).Error
1 week ago
}
// 创建新绑定
identity := &models.Identity{
1 week ago
UserID: userID,
Provider: provider,
ProviderUID: userInfo.ID,
ProviderName: userInfo.Name,
Avatar: userInfo.Avatar,
Email: userInfo.Email,
}
return cfg.DB().Create(identity).Error
1 week ago
}
func loginByIdentity(x *vigo.X, identity *models.Identity) (*CallbackResponse, error) {
1 week ago
// 查找用户
var user models.User
if err := cfg.DB().First(&user, "id = ?", identity.UserID).Error; err != nil {
1 week ago
return nil, vigo.ErrNotFound.WithString("user not found")
}
if user.Status != models.UserStatusActive {
1 week ago
return nil, vigo.ErrForbidden.WithString("user is disabled")
}
// 生成token
authResp, err := generateAuthResponse(x, &user)
if err != nil {
return nil, err
}
return &CallbackResponse{
AuthResponse: authResp,
}, nil
}
func generateAuthResponse(x *vigo.X, user *models.User) (*AuthResponse, error) {
1 week ago
// 获取用户的组织信息
orgs, err := getUserOrgs(user.ID)
if err != nil {
return nil, vigo.ErrInternalServer.WithError(err)
}
orgClaims := make([]jwt.OrgClaim, 0, len(orgs))
for _, org := range orgs {
orgClaims = append(orgClaims, jwt.OrgClaim{
OrgID: org.OrgID,
Code: org.Code,
Name: org.Name,
Roles: org.Roles,
Status: org.Status,
})
}
tokenPair, err := jwt.GenerateTokenPair(
user.ID,
user.Username,
user.Nickname,
user.Avatar,
user.Email,
orgClaims,
)
if err != nil {
return nil, vigo.ErrInternalServer.WithError(err)
}
return &AuthResponse{
AccessToken: tokenPair.AccessToken,
RefreshToken: tokenPair.RefreshToken,
TokenType: tokenPair.TokenType,
ExpiresIn: tokenPair.ExpiresIn,
User: &UserInfo{
ID: user.ID,
Username: user.Username,
Nickname: user.Nickname,
Email: user.Email,
Avatar: user.Avatar,
},
}, nil
}
// generateTempBindToken stores the provider user info in the cache under
// "oauth:bind:<token>" for ten minutes and returns the freshly generated
// token, for use by the follow-up bind request.
func generateTempBindToken(provider string, userInfo *ThirdPartyUserInfo) (string, error) {
	bindToken := generateRandomToken(32)

	payload := map[string]any{
		"provider":    provider,
		"provider_id": userInfo.ID,
		"name":        userInfo.Name,
		"email":       userInfo.Email,
		"avatar":      userInfo.Avatar,
		"raw":         userInfo.Raw,
	}

	err := cache.SetObject("oauth:bind:"+bindToken, payload, 10*time.Minute)
	if err != nil {
		return "", err
	}
	return bindToken, nil
}
// verifyTempBindToken loads the provider user info stored for a temporary
// bind token and removes the cache entry, so each token is accepted once.
//
// Cached values are read with comma-ok assertions. "raw" is nil whenever
// the exchange functions did not populate it, and an unchecked assertion on
// a nil value panics. A payload without provider or provider ID is
// rejected; the remaining fields default to their zero value.
func verifyTempBindToken(token string) (*ThirdPartyUserInfo, error) {
	if token == "" {
		return nil, fmt.Errorf("empty bind token")
	}

	key := "oauth:bind:" + token
	var data map[string]any
	if err := cache.GetObject(key, &data); err != nil {
		return nil, err
	}
	// One-time use: drop the entry once it has been read.
	cache.Delete(key)

	provider, _ := data["provider"].(string)
	providerID, _ := data["provider_id"].(string)
	if provider == "" || providerID == "" {
		return nil, fmt.Errorf("malformed bind token payload")
	}
	name, _ := data["name"].(string)
	email, _ := data["email"].(string)
	avatar, _ := data["avatar"].(string)
	raw, _ := data["raw"].(map[string]any)

	return &ThirdPartyUserInfo{
		Provider: provider,
		ID:       providerID,
		Name:     name,
		Email:    email,
		Avatar:   avatar,
		Raw:      raw,
	}, nil
}
// generateRandomPassword returns a random password of the given length made
// of letters, digits and the symbols !@#$%^&*. A non-positive length yields
// the empty string.
//
// Characters are drawn uniformly: random bytes at or above the largest
// multiple of the charset size below 256 are discarded, because reducing
// them modulo the charset size would favor the first characters of the set.
// It panics if the system random source fails, since a predictable password
// must never be issued.
func generateRandomPassword(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*"
	// Bytes in [limit, 256) are rejected to avoid modulo bias.
	const limit = 256 - 256%len(charset)

	if length <= 0 {
		return ""
	}

	out := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(out) < length {
		if _, err := rand.Read(buf); err != nil {
			panic("auth: reading random bytes for password: " + err.Error())
		}
		for _, b := range buf {
			if int(b) >= limit {
				continue
			}
			out = append(out, charset[int(b)%len(charset)])
			if len(out) == length {
				break
			}
		}
	}
	return string(out)
}
// generateRandomToken returns length bytes from crypto/rand, hex-encoded
// (so the result has 2*length characters). A non-positive length yields the
// empty string.
//
// It panics if the system random source fails, since a predictable token
// must never be issued.
func generateRandomToken(length int) string {
	if length <= 0 {
		return ""
	}
	b := make([]byte, length)
	if _, err := rand.Read(b); err != nil {
		panic("auth: reading random bytes for token: " + err.Error())
	}
	return hex.EncodeToString(b)
}