// // Copyright (C) 2024 veypi // 2025-03-04 16:08:06 // Distributed under terms of the MIT license. // package auth import ( "crypto/rand" "encoding/hex" "encoding/json" "fmt" "net/http" "net/url" "strings" "time" baseauth "github.com/veypi/vbase/auth" "github.com/veypi/vbase/cfg" "github.com/veypi/vbase/libs/cache" "github.com/veypi/vbase/libs/crypto" "github.com/veypi/vbase/libs/jwt" "github.com/veypi/vbase/models" "github.com/veypi/vigo" "gorm.io/gorm" ) // ProviderInfo 第三方登录提供商信息 type ProviderInfo struct { Name string `json:"name"` DisplayName string `json:"display_name"` Icon string `json:"icon"` Enabled bool `json:"enabled"` } // listProviders 获取支持的第三方登录提供商列表 func listProviders(x *vigo.X) ([]ProviderInfo, error) { var dbProviders []models.OAuthProvider if err := cfg.DB().Where("enabled = ?", true).Order("sort_order").Find(&dbProviders).Error; err != nil { return nil, err } providers := make([]ProviderInfo, 0, len(dbProviders)) for _, p := range dbProviders { providers = append(providers, ProviderInfo{ Name: p.Code, DisplayName: p.Name, Icon: p.Icon, Enabled: p.Enabled, }) } return providers, nil } // AuthorizeRequest 第三方登录授权请求 type AuthorizeRequest struct { Provider string `json:"provider" src:"query" desc:"提供商: google/github/wechat"` Redirect string `json:"redirect" src:"query" desc:"登录成功后重定向地址"` BindMode bool `json:"bind_mode" src:"query" desc:"是否为绑定模式"` } // AuthorizeResponse 授权响应 type AuthorizeResponse struct { AuthURL string `json:"auth_url"` State string `json:"state"` } // authorizeThirdParty 获取第三方登录授权URL func authorizeThirdParty(x *vigo.X, req *AuthorizeRequest) (*AuthorizeResponse, error) { provider := strings.ToLower(req.Provider) // 生成state state := generateState() // 存储state到缓存(用于验证回调) stateData := map[string]any{ "provider": provider, "redirect": req.Redirect, "bind_mode": req.BindMode, "created_at": time.Now().Unix(), } // 如果是绑定模式,需要当前用户登录 if req.BindMode { userID := getCurrentUserID(x) if userID == "" { return nil, vigo.ErrUnauthorized.WithString("login required for bind mode") } stateData["user_id"] = userID } // 保存state数据到缓存(10分钟有效) if err := saveState(state, stateData); err != nil { return nil, vigo.ErrInternalServer.WithError(err) } // 构建授权URL authURL, err := buildAuthURL(provider, state) if err != nil { return nil, err } return &AuthorizeResponse{ AuthURL: authURL, State: state, }, nil } // CallbackRequest 第三方登录回调请求 type CallbackRequest struct { Provider string `json:"provider" src:"path" desc:"提供商"` Code string `json:"code" src:"query" desc:"授权码"` State string `json:"state" src:"query" desc:"状态值"` Error string `json:"error" src:"query" desc:"错误信息"` } // CallbackResponse 回调响应 type CallbackResponse struct { // 绑定模式 BindMode bool `json:"bind_mode,omitempty"` Bound bool `json:"bound,omitempty"` // 登录模式 NeedBind bool `json:"need_bind,omitempty"` Provider string `json:"provider,omitempty"` ProviderID string `json:"provider_id,omitempty"` AuthURL string `json:"auth_url,omitempty"` // 登录成功 *AuthResponse } // callbackThirdParty 处理第三方登录回调 func callbackThirdParty(x *vigo.X, req *CallbackRequest) (*CallbackResponse, error) { if req.Error != "" { return nil, vigo.ErrInvalidArg.WithString("oauth error: " + req.Error) } if req.Code == "" || req.State == "" { return nil, vigo.ErrInvalidArg.WithString("missing code or state") } // 验证state stateData, err := verifyState(req.State) if err != nil { return nil, vigo.ErrInvalidArg.WithString("invalid or expired state") } provider := stateData["provider"].(string) bindMode := stateData["bind_mode"].(bool) // 交换access_token并获取用户信息 userInfo, err := exchangeAndGetUserInfo(provider, req.Code) if err != nil { return nil, vigo.ErrInternalServer.WithError(err) } // 绑定模式:将第三方身份绑定到当前用户 if bindMode { userID := stateData["user_id"].(string) if err := bindIdentity(userID, provider, userInfo); err != nil { return nil, err } return &CallbackResponse{ BindMode: true, Bound: true, }, nil } // 登录模式:查找是否已绑定 identity, err := findIdentity(provider, userInfo.ID) if err != nil { return nil, err } // 如果已绑定,直接登录 if identity != nil { return loginByIdentity(x, identity) } // 未绑定,返回需要绑定的信息 // 生成临时token用于后续绑定 tempToken, err := generateTempBindToken(provider, userInfo) if err != nil { return nil, vigo.ErrInternalServer.WithError(err) } return &CallbackResponse{ NeedBind: true, Provider: provider, ProviderID: userInfo.ID, AuthURL: "/auth/bind?token=" + tempToken, }, nil } // BindRequest 绑定请求 type BindRequest struct { TempToken string `json:"temp_token" src:"json" desc:"临时绑定令牌"` Username string `json:"username" src:"json" desc:"用户名/邮箱/手机号"` Password string `json:"password" src:"json" desc:"密码"` } // bindThirdParty 绑定第三方账号到已有账号 func bindThirdParty(x *vigo.X, req *BindRequest) (*AuthResponse, error) { // 验证临时token userInfo, err := verifyTempBindToken(req.TempToken) if err != nil { return nil, vigo.ErrInvalidArg.WithString("invalid or expired token") } // 查找用户 var user models.User query := cfg.DB().Where("username = ? OR email = ? OR phone = ?", req.Username, req.Username, req.Username) if err := query.First(&user).Error; err != nil { return nil, vigo.ErrUnauthorized.WithString("invalid credentials") } // 验证密码 if !crypto.VerifyPassword(req.Password, user.Password) { return nil, vigo.ErrUnauthorized.WithString("invalid credentials") } // 检查用户状态 if user.Status != models.UserStatusActive { return nil, vigo.ErrForbidden.WithString("user is disabled") } // 绑定身份 if err := bindIdentity(user.ID, userInfo.Provider, userInfo); err != nil { return nil, err } // 生成登录token return generateAuthResponse(x, &user) } // BindWithRegisterRequest 绑定并注册新账号(可选功能) type BindWithRegisterRequest struct { TempToken string `json:"temp_token" src:"json" desc:"临时绑定令牌"` Username string `json:"username" src:"json" desc:"用户名"` Email string `json:"email" src:"json" desc:"邮箱"` Phone string `json:"phone" src:"json" desc:"手机号"` } // bindWithRegister 绑定并创建新账号 func bindWithRegister(x *vigo.X, req *BindWithRegisterRequest) (*AuthResponse, error) { // 验证临时token userInfo, err := verifyTempBindToken(req.TempToken) if err != nil { return nil, vigo.ErrInvalidArg.WithString("invalid or expired token") } // 检查用户名是否已存在 var count int64 cfg.DB().Model(&models.User{}).Where("username = ?", req.Username).Count(&count) if count > 0 { return nil, vigo.ErrInvalidArg.WithString("username already exists") } // 检查邮箱是否已存在 if req.Email != "" { cfg.DB().Model(&models.User{}).Where("email = ?", req.Email).Count(&count) if count > 0 { return nil, vigo.ErrInvalidArg.WithString("email already exists") } } // 创建用户(随机密码,需要后续设置) randomPassword := generateRandomPassword(16) hashedPassword, _ := crypto.HashPassword(randomPassword, 12) var email *string if req.Email != "" { email = &req.Email } var phone *string if req.Phone != "" { phone = &req.Phone } user := &models.User{ Username: req.Username, Password: hashedPassword, Email: email, Phone: phone, Nickname: userInfo.Name, Avatar: userInfo.Avatar, Status: models.UserStatusActive, } if user.Nickname == "" { user.Nickname = req.Username } if err := cfg.DB().Create(user).Error; err != nil { return nil, vigo.ErrInternalServer.WithError(err) } // 授予默认角色 "user" if err := baseauth.VBaseAuth.GrantRole(x.Context(), user.ID, "", "user"); err != nil { // 记录错误但允许流程继续 } // 绑定第三方身份 if err := bindIdentity(user.ID, userInfo.Provider, userInfo); err != nil { return nil, err } // 生成登录token return generateAuthResponse(x, user) } // UnbindRequest 解绑请求 type UnbindRequest struct { Provider string `json:"provider" src:"path" desc:"提供商"` } // unbindThirdParty 解除第三方账号绑定 func unbindThirdParty(x *vigo.X, req *UnbindRequest) error { userID := getCurrentUserID(x) if userID == "" { return vigo.ErrUnauthorized } // 删除绑定关系 if err := cfg.DB().Where("user_id = ? AND provider = ?", userID, req.Provider).Delete(&models.Identity{}).Error; err != nil { return vigo.ErrInternalServer.WithError(err) } return nil } // BindingInfo 绑定信息 type BindingInfo struct { Provider string `json:"provider"` ProviderName string `json:"provider_name"` Avatar string `json:"avatar"` Email string `json:"email"` CreatedAt string `json:"created_at"` } // listBindings 获取当前用户的第三方绑定列表 func listBindings(x *vigo.X) ([]BindingInfo, error) { userID := getCurrentUserID(x) if userID == "" { return nil, vigo.ErrUnauthorized } var identities []models.Identity if err := cfg.DB().Where("user_id = ?", userID).Find(&identities).Error; err != nil { return nil, vigo.ErrInternalServer.WithError(err) } result := make([]BindingInfo, 0, len(identities)) for _, id := range identities { result = append(result, BindingInfo{ Provider: id.Provider, ProviderName: id.ProviderName, Avatar: id.Avatar, Email: id.Email, CreatedAt: id.CreatedAt.Format("2006-01-02 15:04:05"), }) } return result, nil } // ==================== 辅助函数 ==================== // ThirdPartyUserInfo 第三方用户信息 type ThirdPartyUserInfo struct { Provider string ID string Name string Email string Avatar string Raw map[string]any } func generateState() string { b := make([]byte, 16) rand.Read(b) return hex.EncodeToString(b) } func saveState(state string, data map[string]any) error { key := "oauth:state:" + state return cache.SetObject(key, data, 10*time.Minute) } func verifyState(state string) (map[string]any, error) { key := "oauth:state:" + state var data map[string]any if err := cache.GetObject(key, &data); err != nil { return nil, err } // 验证后删除 cache.Delete(key) return data, nil } func buildAuthURL(providerCode, state string) (string, error) { // 从数据库获取提供商配置 var provider models.OAuthProvider if err := cfg.DB().Where("code = ? AND enabled = ?", providerCode, true).First(&provider).Error; err != nil { return "", vigo.ErrInvalidArg.WithString("provider not found or not enabled: " + providerCode) } params := url.Values{} // 特殊处理微信的 appid 参数 if appidParam, ok := provider.ExtraConfig["appid_param"]; ok && appidParam != "" { params.Set(appidParam, provider.ClientID) } else { params.Set("client_id", provider.ClientID) } params.Set("redirect_uri", getRedirectURI(providerCode)) params.Set("response_type", "code") params.Set("state", state) if len(provider.Scopes) > 0 { params.Set("scope", strings.Join(provider.Scopes, " ")) } return provider.AuthURL + "?" + params.Encode(), nil } func getRedirectURI(provider string) string { return "/auth/callback/" + provider } func exchangeAndGetUserInfo(providerCode, code string) (*ThirdPartyUserInfo, error) { // 从数据库获取提供商配置 var provider models.OAuthProvider if err := cfg.DB().Where("code = ? AND enabled = ?", providerCode, true).First(&provider).Error; err != nil { return nil, fmt.Errorf("provider not found: %s", providerCode) } // 统一处理 OAuth 流程 return exchangeGeneric(provider, code) } func exchangeGeneric(provider models.OAuthProvider, code string) (*ThirdPartyUserInfo, error) { // 1. 交换 code 获取 access_token tokenURL := provider.TokenURL redirectURI := provider.RedirectURI if redirectURI == "" { redirectURI = getRedirectURI(provider.Code) } // 解密 ClientSecret clientSecret := provider.ClientSecret if clientSecret != "" { decrypted, err := cfg.Config.Key.Decrypt(clientSecret) if err == nil { clientSecret = decrypted } // 如果解密失败,可能是明文存储的旧数据,继续使用原值 } // 构建请求参数 data := url.Values{} data.Set("code", code) data.Set("client_id", provider.ClientID) data.Set("client_secret", clientSecret) data.Set("redirect_uri", redirectURI) data.Set("grant_type", "authorization_code") // 发送 token 请求 var tokenResp struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` Error string `json:"error"` ErrorDesc string `json:"error_description"` } usePost := provider.ExtraConfig["use_post_token"] == "true" var err error if usePost { err = postFormJSON(tokenURL, data, &tokenResp) } else { err = getJSON(tokenURL+"?"+data.Encode(), &tokenResp) } if err != nil { return nil, fmt.Errorf("exchange token failed: %w", err) } if tokenResp.Error != "" { return nil, fmt.Errorf("oauth error: %s - %s", tokenResp.Error, tokenResp.ErrorDesc) } // 2. 获取用户信息 userInfoURL := provider.UserInfoURL var rawUser map[string]any if err := getJSONWithAuth(userInfoURL, tokenResp.AccessToken, tokenResp.TokenType, &rawUser); err != nil { return nil, fmt.Errorf("get user info failed: %w", err) } // 3. 解析字段映射 userInfo := &ThirdPartyUserInfo{ Provider: provider.Code, Raw: rawUser, } // 提取字段 userInfo.ID = extractField(rawUser, provider.UserIDPath, "id") userInfo.Name = extractField(rawUser, provider.UserNamePath, "name") userInfo.Email = extractField(rawUser, provider.UserEmailPath, "email") userInfo.Avatar = extractField(rawUser, provider.UserAvatarPath, "avatar") if userInfo.ID == "" { return nil, fmt.Errorf("failed to extract user id from response") } return userInfo, nil } // extractField 从嵌套 map 中提取字段值 func extractField(data map[string]any, path, defaultKey string) string { if path == "" { path = defaultKey } if path == "" { return "" } // 支持点号分隔的路径,如 "data.employee_id" keys := strings.Split(path, ".") current := data for i, key := range keys { if i == len(keys)-1 { // 最后一个 key if val, ok := current[key]; ok { switch v := val.(type) { case string: return v case float64: return fmt.Sprintf("%.0f", v) case int: return fmt.Sprintf("%d", v) default: return fmt.Sprintf("%v", v) } } return "" } // 继续深入 if next, ok := current[key].(map[string]any); ok { current = next } else { return "" } } return "" } // postFormJSON 发送 POST form 请求并解析 JSON 响应 func postFormJSON(url string, data url.Values, result any) error { resp, err := http.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) if err != nil { return err } defer resp.Body.Close() return json.NewDecoder(resp.Body).Decode(result) } // getJSON 发送 GET 请求并解析 JSON 响应 func getJSON(url string, result any) error { resp, err := http.Get(url) if err != nil { return err } defer resp.Body.Close() return json.NewDecoder(resp.Body).Decode(result) } // getJSONWithAuth 发送带认证的 GET 请求 func getJSONWithAuth(url, token, tokenType string, result any) error { req, err := http.NewRequest("GET", url, nil) if err != nil { return err } // 设置认证头 if tokenType == "Bearer" || tokenType == "" { req.Header.Set("Authorization", "Bearer "+token) } else { // 如 GitHub 使用 token 类型 req.Header.Set("Authorization", tokenType+" "+token) } resp, err := http.DefaultClient.Do(req) if err != nil { return err } defer resp.Body.Close() return json.NewDecoder(resp.Body).Decode(result) } func findIdentity(provider, providerUID string) (*models.Identity, error) { var identity models.Identity if err := cfg.DB().Where("provider = ? AND provider_uid = ?", provider, providerUID).First(&identity).Error; err != nil { if err == gorm.ErrRecordNotFound { return nil, nil } return nil, err } return &identity, nil } func bindIdentity(userID, provider string, userInfo *ThirdPartyUserInfo) error { // 检查是否已绑定 var existing models.Identity if err := cfg.DB().Where("provider = ? AND provider_uid = ?", provider, userInfo.ID).First(&existing).Error; err == nil { // 已绑定到其他账号 if existing.UserID != userID { return vigo.ErrForbidden.WithString("this account is already bound to another user") } // 已绑定到当前账号,更新信息 existing.ProviderName = userInfo.Name existing.Avatar = userInfo.Avatar existing.Email = userInfo.Email return cfg.DB().Save(&existing).Error } // 创建新绑定 identity := &models.Identity{ UserID: userID, Provider: provider, ProviderUID: userInfo.ID, ProviderName: userInfo.Name, Avatar: userInfo.Avatar, Email: userInfo.Email, } return cfg.DB().Create(identity).Error } func loginByIdentity(x *vigo.X, identity *models.Identity) (*CallbackResponse, error) { // 查找用户 var user models.User if err := cfg.DB().First(&user, "id = ?", identity.UserID).Error; err != nil { return nil, vigo.ErrNotFound.WithString("user not found") } if user.Status != models.UserStatusActive { return nil, vigo.ErrForbidden.WithString("user is disabled") } // 生成token authResp, err := generateAuthResponse(x, &user) if err != nil { return nil, err } return &CallbackResponse{ AuthResponse: authResp, }, nil } func generateAuthResponse(x *vigo.X, user *models.User) (*AuthResponse, error) { // 获取用户的组织信息 orgs, err := getUserOrgs(user.ID) if err != nil { return nil, vigo.ErrInternalServer.WithError(err) } orgClaims := make([]jwt.OrgClaim, 0, len(orgs)) for _, org := range orgs { orgClaims = append(orgClaims, jwt.OrgClaim{ OrgID: org.OrgID, Code: org.Code, Name: org.Name, Roles: org.Roles, Status: org.Status, }) } tokenPair, err := jwt.GenerateTokenPair( user.ID, user.Username, user.Nickname, user.Avatar, func() string { if user.Email != nil { return *user.Email } return "" }(), orgClaims, ) if err != nil { return nil, vigo.ErrInternalServer.WithError(err) } return &AuthResponse{ AccessToken: tokenPair.AccessToken, RefreshToken: tokenPair.RefreshToken, TokenType: tokenPair.TokenType, ExpiresIn: tokenPair.ExpiresIn, User: &UserInfo{ ID: user.ID, Username: user.Username, Nickname: user.Nickname, Email: user.Email, Avatar: user.Avatar, }, }, nil } func generateTempBindToken(provider string, userInfo *ThirdPartyUserInfo) (string, error) { // 生成临时token用于后续绑定 token := generateRandomToken(32) key := "oauth:bind:" + token data := map[string]any{ "provider": provider, "provider_id": userInfo.ID, "name": userInfo.Name, "email": userInfo.Email, "avatar": userInfo.Avatar, "raw": userInfo.Raw, } if err := cache.SetObject(key, data, 10*time.Minute); err != nil { return "", err } return token, nil } func verifyTempBindToken(token string) (*ThirdPartyUserInfo, error) { // 验证临时token并返回用户信息 key := "oauth:bind:" + token var data map[string]any if err := cache.GetObject(key, &data); err != nil { return nil, err } // 验证后删除 cache.Delete(key) return &ThirdPartyUserInfo{ Provider: data["provider"].(string), ID: data["provider_id"].(string), Name: data["name"].(string), Email: data["email"].(string), Avatar: data["avatar"].(string), Raw: data["raw"].(map[string]any), }, nil } func generateRandomPassword(length int) string { const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*" b := make([]byte, length) rand.Read(b) for i := range b { b[i] = charset[b[i]%byte(len(charset))] } return string(b) } func generateRandomToken(length int) string { b := make([]byte, length) rand.Read(b) return hex.EncodeToString(b) }