// // Copyright (C) 2024 veypi // 2025-03-04 16:08:06 // Distributed under terms of the MIT license. // package auth import ( "crypto/rand" "encoding/hex" "encoding/json" "fmt" "net/http" "net/url" "strings" "time" "github.com/veypi/vbase/cfg" "github.com/veypi/vbase/libs/cache" "github.com/veypi/vbase/libs/crypto" "github.com/veypi/vbase/libs/jwt" "github.com/veypi/vbase/models" "github.com/veypi/vigo" "gorm.io/gorm" ) // ProviderInfo 第三方登录提供商信息 type ProviderInfo struct { Name string `json:"name"` DisplayName string `json:"display_name"` Icon string `json:"icon"` Enabled bool `json:"enabled"` } // listProviders 获取支持的第三方登录提供商列表 func listProviders(x *vigo.X) ([]ProviderInfo, error) { providers := []ProviderInfo{ {Name: "google", DisplayName: "Google", Icon: "google", Enabled: cfg.Config.Providers.Google.Enabled}, {Name: "github", DisplayName: "GitHub", Icon: "github", Enabled: cfg.Config.Providers.GitHub.Enabled}, {Name: "wechat", DisplayName: "微信", Icon: "wechat", Enabled: cfg.Config.Providers.WeChat.Enabled}, } return providers, nil } // AuthorizeRequest 第三方登录授权请求 type AuthorizeRequest struct { Provider string `json:"provider" src:"query" desc:"提供商: google/github/wechat"` Redirect string `json:"redirect" src:"query" desc:"登录成功后重定向地址"` BindMode bool `json:"bind_mode" src:"query" desc:"是否为绑定模式"` } // AuthorizeResponse 授权响应 type AuthorizeResponse struct { AuthURL string `json:"auth_url"` State string `json:"state"` } // authorizeThirdParty 获取第三方登录授权URL func authorizeThirdParty(x *vigo.X, req *AuthorizeRequest) (*AuthorizeResponse, error) { provider := strings.ToLower(req.Provider) // 生成state state := generateState() // 存储state到缓存(用于验证回调) stateData := map[string]any{ "provider": provider, "redirect": req.Redirect, "bind_mode": req.BindMode, "created_at": time.Now().Unix(), } // 如果是绑定模式,需要当前用户登录 if req.BindMode { userID := getCurrentUserID(x) if userID == "" { return nil, vigo.ErrNotAuthorized.WithString("login required for bind mode") } stateData["user_id"] = userID } // 保存state数据到缓存(10分钟有效) if err := saveState(state, stateData); err != nil { return nil, vigo.ErrInternalServer.WithError(err) } // 构建授权URL authURL, err := buildAuthURL(provider, state) if err != nil { return nil, err } return &AuthorizeResponse{ AuthURL: authURL, State: state, }, nil } // CallbackRequest 第三方登录回调请求 type CallbackRequest struct { Provider string `json:"provider" src:"path" desc:"提供商"` Code string `json:"code" src:"query" desc:"授权码"` State string `json:"state" src:"query" desc:"状态值"` Error string `json:"error" src:"query" desc:"错误信息"` } // CallbackResponse 回调响应 type CallbackResponse struct { // 绑定模式 BindMode bool `json:"bind_mode,omitempty"` Bound bool `json:"bound,omitempty"` // 登录模式 NeedBind bool `json:"need_bind,omitempty"` Provider string `json:"provider,omitempty"` ProviderID string `json:"provider_id,omitempty"` AuthURL string `json:"auth_url,omitempty"` // 登录成功 *AuthResponse } // callbackThirdParty 处理第三方登录回调 func callbackThirdParty(x *vigo.X, req *CallbackRequest) (*CallbackResponse, error) { if req.Error != "" { return nil, vigo.ErrArgInvalid.WithString("oauth error: " + req.Error) } if req.Code == "" || req.State == "" { return nil, vigo.ErrArgInvalid.WithString("missing code or state") } // 验证state stateData, err := verifyState(req.State) if err != nil { return nil, vigo.ErrArgInvalid.WithString("invalid or expired state") } provider := stateData["provider"].(string) bindMode := stateData["bind_mode"].(bool) // 交换access_token并获取用户信息 userInfo, err := exchangeAndGetUserInfo(provider, req.Code) if err != nil { return nil, vigo.ErrInternalServer.WithError(err) } // 绑定模式:将第三方身份绑定到当前用户 if bindMode { userID := stateData["user_id"].(string) if err := bindIdentity(userID, provider, userInfo); err != nil { return nil, err } return &CallbackResponse{ BindMode: true, Bound: true, }, nil } // 登录模式:查找是否已绑定 identity, err := findIdentity(provider, userInfo.ID) if err != nil { return nil, err } // 如果已绑定,直接登录 if identity != nil { return loginByIdentity(x, identity) } // 未绑定,返回需要绑定的信息 // 生成临时token用于后续绑定 tempToken, err := generateTempBindToken(provider, userInfo) if err != nil { return nil, vigo.ErrInternalServer.WithError(err) } return &CallbackResponse{ NeedBind: true, Provider: provider, ProviderID: userInfo.ID, AuthURL: "/auth/bind?token=" + tempToken, }, nil } // BindRequest 绑定请求 type BindRequest struct { TempToken string `json:"temp_token" src:"json" desc:"临时绑定令牌"` Username string `json:"username" src:"json" desc:"用户名/邮箱/手机号"` Password string `json:"password" src:"json" desc:"密码"` } // bindThirdParty 绑定第三方账号到已有账号 func bindThirdParty(x *vigo.X, req *BindRequest) (*AuthResponse, error) { // 验证临时token userInfo, err := verifyTempBindToken(req.TempToken) if err != nil { return nil, vigo.ErrArgInvalid.WithString("invalid or expired token") } // 查找用户 var user models.User query := cfg.DB().Where("username = ? OR email = ? OR phone = ?", req.Username, req.Username, req.Username) if err := query.First(&user).Error; err != nil { return nil, vigo.ErrNotAuthorized.WithString("invalid credentials") } // 验证密码 if !crypto.VerifyPassword(req.Password, user.Password) { return nil, vigo.ErrNotAuthorized.WithString("invalid credentials") } // 检查用户状态 if user.Status != models.UserStatusActive { return nil, vigo.ErrForbidden.WithString("user is disabled") } // 绑定身份 if err := bindIdentity(user.ID, userInfo.Provider, userInfo); err != nil { return nil, err } // 生成登录token return generateAuthResponse(x, &user) } // BindWithRegisterRequest 绑定并注册新账号(可选功能) type BindWithRegisterRequest struct { TempToken string `json:"temp_token" src:"json" desc:"临时绑定令牌"` Username string `json:"username" src:"json" desc:"用户名"` Email string `json:"email" src:"json" desc:"邮箱"` Phone string `json:"phone" src:"json" desc:"手机号"` } // bindWithRegister 绑定并创建新账号 func bindWithRegister(x *vigo.X, req *BindWithRegisterRequest) (*AuthResponse, error) { // 验证临时token userInfo, err := verifyTempBindToken(req.TempToken) if err != nil { return nil, vigo.ErrArgInvalid.WithString("invalid or expired token") } // 检查用户名是否已存在 var count int64 cfg.DB().Model(&models.User{}).Where("username = ?", req.Username).Count(&count) if count > 0 { return nil, vigo.ErrArgInvalid.WithString("username already exists") } // 检查邮箱是否已存在 if req.Email != "" { cfg.DB().Model(&models.User{}).Where("email = ?", req.Email).Count(&count) if count > 0 { return nil, vigo.ErrArgInvalid.WithString("email already exists") } } // 创建用户(随机密码,需要后续设置) randomPassword := generateRandomPassword(16) hashedPassword, _ := crypto.HashPassword(randomPassword, cfg.Config.Security.BcryptCost) user := &models.User{ Username: req.Username, Password: hashedPassword, Email: req.Email, Phone: req.Phone, Nickname: userInfo.Name, Avatar: userInfo.Avatar, Status: models.UserStatusActive, } if user.Nickname == "" { user.Nickname = req.Username } if err := cfg.DB().Create(user).Error; err != nil { return nil, vigo.ErrInternalServer.WithError(err) } // 绑定第三方身份 if err := bindIdentity(user.ID, userInfo.Provider, userInfo); err != nil { return nil, err } // 生成登录token return generateAuthResponse(x, user) } // UnbindRequest 解绑请求 type UnbindRequest struct { Provider string `json:"provider" src:"path" desc:"提供商"` } // unbindThirdParty 解除第三方账号绑定 func unbindThirdParty(x *vigo.X, req *UnbindRequest) error { userID := getCurrentUserID(x) if userID == "" { return vigo.ErrNotAuthorized } // 删除绑定关系 if err := cfg.DB().Where("user_id = ? AND provider = ?", userID, req.Provider).Delete(&models.Identity{}).Error; err != nil { return vigo.ErrInternalServer.WithError(err) } return nil } // BindingInfo 绑定信息 type BindingInfo struct { Provider string `json:"provider"` ProviderName string `json:"provider_name"` Avatar string `json:"avatar"` Email string `json:"email"` CreatedAt string `json:"created_at"` } // listBindings 获取当前用户的第三方绑定列表 func listBindings(x *vigo.X) ([]BindingInfo, error) { userID := getCurrentUserID(x) if userID == "" { return nil, vigo.ErrNotAuthorized } var identities []models.Identity if err := cfg.DB().Where("user_id = ?", userID).Find(&identities).Error; err != nil { return nil, vigo.ErrInternalServer.WithError(err) } result := make([]BindingInfo, 0, len(identities)) for _, id := range identities { result = append(result, BindingInfo{ Provider: id.Provider, ProviderName: id.ProviderName, Avatar: id.Avatar, Email: id.Email, CreatedAt: id.CreatedAt.Format("2006-01-02 15:04:05"), }) } return result, nil } // ==================== 辅助函数 ==================== // ThirdPartyUserInfo 第三方用户信息 type ThirdPartyUserInfo struct { Provider string ID string Name string Email string Avatar string Raw map[string]any } func generateState() string { b := make([]byte, 16) rand.Read(b) return hex.EncodeToString(b) } func saveState(state string, data map[string]any) error { key := "oauth:state:" + state return cache.SetObject(key, data, 10*time.Minute) } func verifyState(state string) (map[string]any, error) { key := "oauth:state:" + state var data map[string]any if err := cache.GetObject(key, &data); err != nil { return nil, err } // 验证后删除 cache.Delete(key) return data, nil } func buildAuthURL(provider, state string) (string, error) { var pc cfg.OAuthProviderConfig switch provider { case "google": pc = cfg.Config.Providers.Google case "github": pc = cfg.Config.Providers.GitHub case "wechat": pc = cfg.Config.Providers.WeChat default: return "", vigo.ErrArgInvalid.WithString("unsupported provider: " + provider) } if !pc.Enabled { return "", vigo.ErrArgInvalid.WithString("provider not enabled: " + provider) } params := url.Values{} params.Set("client_id", pc.ClientID) params.Set("redirect_uri", getRedirectURI(provider)) params.Set("response_type", "code") params.Set("state", state) params.Set("scope", strings.Join(pc.Scopes, " ")) // 特殊处理 if provider == "wechat" { params.Set("appid", pc.ClientID) } return pc.AuthURL + "?" + params.Encode(), nil } func getRedirectURI(provider string) string { return "/auth/callback/" + provider } func exchangeAndGetUserInfo(provider, code string) (*ThirdPartyUserInfo, error) { switch provider { case "google": return exchangeGoogle(code) case "github": return exchangeGitHub(code) case "wechat": return exchangeWeChat(code) default: return nil, fmt.Errorf("unsupported provider: %s", provider) } } func exchangeGoogle(code string) (*ThirdPartyUserInfo, error) { pc := cfg.Config.Providers.Google // 交换access_token data := url.Values{} data.Set("code", code) data.Set("client_id", pc.ClientID) data.Set("client_secret", pc.ClientSecret) data.Set("redirect_uri", getRedirectURI("google")) data.Set("grant_type", "authorization_code") resp, err := http.Post(pc.TokenURL, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) if err != nil { return nil, err } defer resp.Body.Close() var tokenResp struct { AccessToken string `json:"access_token"` Error string `json:"error"` } if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { return nil, err } if tokenResp.Error != "" { return nil, fmt.Errorf("token error: %s", tokenResp.Error) } // 获取用户信息 req, _ := http.NewRequest("GET", pc.UserInfoURL, nil) req.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken) userResp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } defer userResp.Body.Close() var userData struct { Sub string `json:"sub"` Name string `json:"name"` Email string `json:"email"` Picture string `json:"picture"` } if err := json.NewDecoder(userResp.Body).Decode(&userData); err != nil { return nil, err } return &ThirdPartyUserInfo{ Provider: "google", ID: userData.Sub, Name: userData.Name, Email: userData.Email, Avatar: userData.Picture, }, nil } func exchangeGitHub(code string) (*ThirdPartyUserInfo, error) { pc := cfg.Config.Providers.GitHub // 交换access_token data := url.Values{} data.Set("code", code) data.Set("client_id", pc.ClientID) data.Set("client_secret", pc.ClientSecret) data.Set("redirect_uri", getRedirectURI("github")) req, _ := http.NewRequest("POST", pc.TokenURL, strings.NewReader(data.Encode())) req.Header.Set("Accept", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } defer resp.Body.Close() var tokenResp struct { AccessToken string `json:"access_token"` Error string `json:"error"` } if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { return nil, err } if tokenResp.Error != "" { return nil, fmt.Errorf("token error: %s", tokenResp.Error) } // 获取用户信息 req, _ = http.NewRequest("GET", pc.UserInfoURL, nil) req.Header.Set("Authorization", "token "+tokenResp.AccessToken) userResp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } defer userResp.Body.Close() var userData struct { ID int `json:"id"` Login string `json:"login"` Name string `json:"name"` Email string `json:"email"` AvatarURL string `json:"avatar_url"` } if err := json.NewDecoder(userResp.Body).Decode(&userData); err != nil { return nil, err } return &ThirdPartyUserInfo{ Provider: "github", ID: fmt.Sprintf("%d", userData.ID), Name: userData.Name, Email: userData.Email, Avatar: userData.AvatarURL, }, nil } func exchangeWeChat(code string) (*ThirdPartyUserInfo, error) { pc := cfg.Config.Providers.WeChat // 交换access_token urlStr := fmt.Sprintf("%s?appid=%s&secret=%s&code=%s&grant_type=authorization_code", pc.TokenURL, pc.ClientID, pc.ClientSecret, code) resp, err := http.Get(urlStr) if err != nil { return nil, err } defer resp.Body.Close() var tokenResp struct { AccessToken string `json:"access_token"` OpenID string `json:"openid"` UnionID string `json:"unionid"` ErrCode int `json:"errcode"` ErrMsg string `json:"errmsg"` } if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { return nil, err } if tokenResp.ErrCode != 0 { return nil, fmt.Errorf("wechat error: %s", tokenResp.ErrMsg) } // 获取用户信息 userInfoURL := fmt.Sprintf("%s?access_token=%s&openid=%s", pc.UserInfoURL, tokenResp.AccessToken, tokenResp.OpenID) userResp, err := http.Get(userInfoURL) if err != nil { return nil, err } defer userResp.Body.Close() var userData struct { OpenID string `json:"openid"` Nickname string `json:"nickname"` HeadImg string `json:"headimgurl"` UnionID string `json:"unionid"` ErrCode int `json:"errcode"` ErrMsg string `json:"errmsg"` } if err := json.NewDecoder(userResp.Body).Decode(&userData); err != nil { return nil, err } if userData.ErrCode != 0 { return nil, fmt.Errorf("wechat error: %s", userData.ErrMsg) } return &ThirdPartyUserInfo{ Provider: "wechat", ID: tokenResp.UnionID, Name: userData.Nickname, Avatar: userData.HeadImg, }, nil } func findIdentity(provider, providerUID string) (*models.Identity, error) { var identity models.Identity if err := cfg.DB().Where("provider = ? AND provider_uid = ?", provider, providerUID).First(&identity).Error; err != nil { if err == gorm.ErrRecordNotFound { return nil, nil } return nil, err } return &identity, nil } func bindIdentity(userID, provider string, userInfo *ThirdPartyUserInfo) error { // 检查是否已绑定 var existing models.Identity if err := cfg.DB().Where("provider = ? AND provider_uid = ?", provider, userInfo.ID).First(&existing).Error; err == nil { // 已绑定到其他账号 if existing.UserID != userID { return vigo.ErrForbidden.WithString("this account is already bound to another user") } // 已绑定到当前账号,更新信息 existing.ProviderName = userInfo.Name existing.Avatar = userInfo.Avatar existing.Email = userInfo.Email return cfg.DB().Save(&existing).Error } // 创建新绑定 identity := &models.Identity{ UserID: userID, Provider: provider, ProviderUID: userInfo.ID, ProviderName: userInfo.Name, Avatar: userInfo.Avatar, Email: userInfo.Email, } return cfg.DB().Create(identity).Error } func loginByIdentity(x *vigo.X, identity *models.Identity) (*CallbackResponse, error) { // 查找用户 var user models.User if err := cfg.DB().First(&user, "id = ?", identity.UserID).Error; err != nil { return nil, vigo.ErrNotFound.WithString("user not found") } if user.Status != models.UserStatusActive { return nil, vigo.ErrForbidden.WithString("user is disabled") } // 生成token authResp, err := generateAuthResponse(x, &user) if err != nil { return nil, err } return &CallbackResponse{ AuthResponse: authResp, }, nil } func generateAuthResponse(x *vigo.X, user *models.User) (*AuthResponse, error) { // 获取用户的组织信息 orgs, err := getUserOrgs(user.ID) if err != nil { return nil, vigo.ErrInternalServer.WithError(err) } orgClaims := make([]jwt.OrgClaim, 0, len(orgs)) for _, org := range orgs { orgClaims = append(orgClaims, jwt.OrgClaim{ OrgID: org.OrgID, Code: org.Code, Name: org.Name, Roles: org.Roles, Status: org.Status, }) } tokenPair, err := jwt.GenerateTokenPair( user.ID, user.Username, user.Nickname, user.Avatar, user.Email, orgClaims, ) if err != nil { return nil, vigo.ErrInternalServer.WithError(err) } return &AuthResponse{ AccessToken: tokenPair.AccessToken, RefreshToken: tokenPair.RefreshToken, TokenType: tokenPair.TokenType, ExpiresIn: tokenPair.ExpiresIn, User: &UserInfo{ ID: user.ID, Username: user.Username, Nickname: user.Nickname, Email: user.Email, Avatar: user.Avatar, }, }, nil } func generateTempBindToken(provider string, userInfo *ThirdPartyUserInfo) (string, error) { // 生成临时token用于后续绑定 token := generateRandomToken(32) key := "oauth:bind:" + token data := map[string]any{ "provider": provider, "provider_id": userInfo.ID, "name": userInfo.Name, "email": userInfo.Email, "avatar": userInfo.Avatar, "raw": userInfo.Raw, } if err := cache.SetObject(key, data, 10*time.Minute); err != nil { return "", err } return token, nil } func verifyTempBindToken(token string) (*ThirdPartyUserInfo, error) { // 验证临时token并返回用户信息 key := "oauth:bind:" + token var data map[string]any if err := cache.GetObject(key, &data); err != nil { return nil, err } // 验证后删除 cache.Delete(key) return &ThirdPartyUserInfo{ Provider: data["provider"].(string), ID: data["provider_id"].(string), Name: data["name"].(string), Email: data["email"].(string), Avatar: data["avatar"].(string), Raw: data["raw"].(map[string]any), }, nil } func generateRandomPassword(length int) string { const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*" b := make([]byte, length) rand.Read(b) for i := range b { b[i] = charset[b[i]%byte(len(charset))] } return string(b) } func generateRandomToken(length int) string { b := make([]byte, length) rand.Read(b) return hex.EncodeToString(b) }