From ec47bcc192603f636cdd6d948e5701258cbe44a2 Mon Sep 17 00:00:00 2001 From: veypi Date: Fri, 20 Feb 2026 03:20:34 +0800 Subject: [PATCH] refactor: Migrate auth to Vigo auth interface and simplify initialization - Replace GetUserID/GetOrgID with VBaseAuth.UserID/OrgID methods across all APIs - Integrate vigoauth.Auth interface into appAuth for standard auth methods - Move AuthMiddleware to PermLogin method in auth package - Add role management methods: GetRole, ListRoles, GrantRoles, RevokeRoles, ListUserRoles - Update ListUserPermissions and ListResourceUsers to return vigoauth types - Export Redis client in cfg package - Simplify app initialization by separating vigo.New in cli/main.go - Remove deprecated auth/middleware.go file --- api/auth/me.go | 14 +- api/auth/thirdparty.go | 6 +- api/init.go | 2 +- api/oauth/authorize.go | 2 +- api/oauth/client.go | 6 +- api/oauth/oidc.go | 2 +- api/org/create.go | 2 +- api/user/get.go | 4 +- api/user/patch.go | 4 +- auth/auth.go | 303 ++++++++++++++++++++++++++++------------- auth/middleware.go | 58 -------- cfg/cfg.go | 2 + cli/main.go | 5 +- init.go | 15 +- 14 files changed, 242 insertions(+), 183 deletions(-) delete mode 100644 auth/middleware.go diff --git a/api/auth/me.go b/api/auth/me.go index 968f23e..d445643 100644 --- a/api/auth/me.go +++ b/api/auth/me.go @@ -33,7 +33,7 @@ type UserInfoWithPerms struct { // me 获取当前用户信息 func me(x *vigo.X) (*UserInfoWithPerms, error) { - userID := baseAuth.GetUserID(x) + userID := baseAuth.VBaseAuth.UserID(x) if userID == "" { return nil, vigo.ErrUnauthorized } @@ -43,7 +43,7 @@ func me(x *vigo.X) (*UserInfoWithPerms, error) { return nil, vigo.ErrNotFound } - orgID := baseAuth.GetOrgID(x) + orgID := baseAuth.VBaseAuth.OrgID(x) // 获取用户权限列表 perms, err := baseAuth.VBaseAuth.ListUserPermissions(x.Context(), userID, orgID) @@ -51,11 +51,13 @@ func me(x *vigo.X) (*UserInfoWithPerms, error) { return nil, vigo.ErrInternalServer.WithError(err) } - // 转换权限格式 + // 转换权限格式 (Vigo 接口返回 vigoauth.UserPermission,需要转换回 PermissionID 格式) userPerms := make([]UserPermissionInfo, 0, len(perms)) for _, p := range perms { + // 从 Resource 构造 PermissionID,格式为 vb:resource:* + permissionID := "vb:" + p.Resource + ":*" userPerms = append(userPerms, UserPermissionInfo{ - PermissionID: p.PermissionID, + PermissionID: permissionID, ResourceID: p.ResourceID, }) } @@ -87,7 +89,7 @@ type UpdateMeRequest struct { // updateMe 更新当前用户信息 func updateMe(x *vigo.X, req *UpdateMeRequest) (*UserInfoWithPerms, error) { - userID := baseAuth.GetUserID(x) + userID := baseAuth.VBaseAuth.UserID(x) if userID == "" { return nil, vigo.ErrUnauthorized } @@ -124,7 +126,7 @@ type ChangePasswordRequest struct { // changePassword 修改密码 func changePassword(x *vigo.X, req *ChangePasswordRequest) error { - userID := baseAuth.GetUserID(x) + userID := baseAuth.VBaseAuth.UserID(x) if userID == "" { return vigo.ErrUnauthorized } diff --git a/api/auth/thirdparty.go b/api/auth/thirdparty.go index a1ee6c4..5597a40 100644 --- a/api/auth/thirdparty.go +++ b/api/auth/thirdparty.go @@ -83,7 +83,7 @@ func authorizeThirdParty(x *vigo.X, req *AuthorizeRequest) (*AuthorizeResponse, // 如果是绑定模式,需要当前用户登录 if req.BindMode { - userID := baseauth.GetUserID(x) + userID := baseauth.VBaseAuth.UserID(x) if userID == "" { return nil, vigo.ErrUnauthorized.WithString("login required for bind mode") } @@ -318,7 +318,7 @@ type UnbindRequest struct { // unbindThirdParty 解除第三方账号绑定 func unbindThirdParty(x *vigo.X, req *UnbindRequest) error { - userID := baseauth.GetUserID(x) + userID := baseauth.VBaseAuth.UserID(x) if userID == "" { return vigo.ErrUnauthorized } @@ -342,7 +342,7 @@ type BindingInfo struct { // listBindings 获取当前用户的第三方绑定列表 func listBindings(x *vigo.X) ([]BindingInfo, error) { - userID := baseauth.GetUserID(x) + userID := baseauth.VBaseAuth.UserID(x) if userID == "" { return nil, vigo.ErrUnauthorized } diff --git a/api/init.go b/api/init.go index ecd6c55..eb0eb2b 100644 --- a/api/init.go +++ b/api/init.go @@ -23,7 +23,7 @@ var Router = vigo.NewRouter() func init() { // 注册全局中间件 - Router.Use(auth.AuthMiddleware()) + Router.Use(auth.VBaseAuth.PermLogin) Router.After(common.JsonResponse, common.JsonErrorResponse) // 子路由挂载 diff --git a/api/oauth/authorize.go b/api/oauth/authorize.go index 63c166f..5ef451f 100644 --- a/api/oauth/authorize.go +++ b/api/oauth/authorize.go @@ -40,7 +40,7 @@ func authorize(x *vigo.X, req *AuthorizeRequest) (*AuthorizeResponse, error) { } // 获取当前用户 - userID := auth.GetUserID(x) + userID := auth.VBaseAuth.UserID(x) if userID == "" { return nil, vigo.ErrUnauthorized } diff --git a/api/oauth/client.go b/api/oauth/client.go index 83886d2..a25a516 100644 --- a/api/oauth/client.go +++ b/api/oauth/client.go @@ -66,7 +66,7 @@ type CreateClientResponse struct { } func createClient(x *vigo.X, req *CreateClientRequest) (*CreateClientResponse, error) { - ownerID := auth.GetUserID(x) + ownerID := auth.VBaseAuth.UserID(x) if ownerID == "" { return nil, vigo.ErrUnauthorized } @@ -123,7 +123,7 @@ func updateClient(x *vigo.X, req *UpdateClientRequest) (*models.OAuthClient, err } // 检查权限:只有所有者或管理员可以修改 - currentUserID := auth.GetUserID(x) + currentUserID := auth.VBaseAuth.UserID(x) if currentUserID == "" { return nil, vigo.ErrUnauthorized } @@ -170,7 +170,7 @@ func deleteClient(x *vigo.X, req *DeleteClientRequest) error { } // 检查权限:只有所有者或管理员可以删除 - currentUserID := auth.GetUserID(x) + currentUserID := auth.VBaseAuth.UserID(x) if currentUserID == "" { return vigo.ErrUnauthorized } diff --git a/api/oauth/oidc.go b/api/oauth/oidc.go index 53dc2ae..3e39cf4 100644 --- a/api/oauth/oidc.go +++ b/api/oauth/oidc.go @@ -14,7 +14,7 @@ import ( // UserInfo OIDC用户信息 func userInfo(x *vigo.X) (map[string]any, error) { // 从token中解析用户ID - userID := auth.GetUserID(x) + userID := auth.VBaseAuth.UserID(x) if userID == "" { return nil, vigo.ErrUnauthorized } diff --git a/api/org/create.go b/api/org/create.go index f62a2f9..4ae94b2 100644 --- a/api/org/create.go +++ b/api/org/create.go @@ -28,7 +28,7 @@ func create(x *vigo.X, req *CreateRequest) (*models.Org, error) { } // 获取当前用户ID作为所有者 - ownerID := auth.GetUserID(x) + ownerID := auth.VBaseAuth.UserID(x) if ownerID == "" { return nil, vigo.ErrUnauthorized } diff --git a/api/user/get.go b/api/user/get.go index de479f1..2ce2dff 100644 --- a/api/user/get.go +++ b/api/user/get.go @@ -21,9 +21,9 @@ type GetRequest struct { // get 获取用户详情 func get(x *vigo.X, req *GetRequest) (*models.User, error) { // 手动鉴权: 只能查看自己的信息,或者是管理员 - uid := auth.GetUserID(x) + uid := auth.VBaseAuth.UserID(x) if uid != req.UserID { - if !auth.VBaseAuth.CheckPerm(x.Context(), uid, auth.GetOrgID(x), "user:read", "") { + if !auth.VBaseAuth.CheckPerm(x.Context(), uid, auth.VBaseAuth.OrgID(x), "user:read", "") { return nil, vigo.ErrForbidden } } diff --git a/api/user/patch.go b/api/user/patch.go index 5a9e177..b75f7f8 100644 --- a/api/user/patch.go +++ b/api/user/patch.go @@ -25,9 +25,9 @@ type PatchRequest struct { // patch 更新用户 func patch(x *vigo.X, req *PatchRequest) (*models.User, error) { // 手动鉴权: 只能修改自己的信息,或者是管理员 - uid := auth.GetUserID(x) + uid := auth.VBaseAuth.UserID(x) if uid != req.UserID { - if !auth.VBaseAuth.CheckPerm(x.Context(), uid, auth.GetOrgID(x), "user:update", "") { + if !auth.VBaseAuth.CheckPerm(x.Context(), uid, auth.VBaseAuth.OrgID(x), "user:update", "") { return nil, vigo.ErrForbidden } } diff --git a/auth/auth.go b/auth/auth.go index e1840c0..647bd70 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -16,8 +16,10 @@ import ( "github.com/veypi/vbase/cfg" "github.com/veypi/vbase/libs/cache" + "github.com/veypi/vbase/libs/jwt" "github.com/veypi/vbase/models" "github.com/veypi/vigo" + vigoauth "github.com/veypi/vigo/contrib/auth" "github.com/veypi/vigo/contrib/event" "gorm.io/gorm" ) @@ -36,79 +38,40 @@ const ( RoleCodeUser = "user" ) -// ========== 辅助函数 ========== -func GetUserID(x *vigo.X) string { - if userID, ok := x.Get(CtxKeyUserID).(string); ok { - return userID - } - return "" -} +// ========== Token 提取 ========== -func GetOrgID(x *vigo.X) string { - if orgID, ok := x.Get(CtxKeyOrgID).(string); ok { - return orgID +// extractToken 从 Header 或 Query 中提取 JWT token +func extractToken(x *vigo.X) string { + auth := x.Request.Header.Get("Authorization") + if auth != "" { + if len(auth) > 7 && strings.HasPrefix(auth, "Bearer ") { + return auth[7:] + } } - return "" + return x.Request.URL.Query().Get("access_token") } -func GetOrgRoles(x *vigo.X) []string { - if roles, ok := x.Get(CtxKeyOrgRoles).([]string); ok { - return roles +// getOrgID 从请求中提取组织ID (Header/Query/Path) +func getOrgID(x *vigo.X) string { + orgID := x.Request.Header.Get("X-Org-ID") + if orgID == "" { + orgID = x.Request.URL.Query().Get("org_id") } - return nil + if orgID == "" { + orgID = x.PathParams.Get("org_id") + } + return orgID } -// Auth 权限管理接口 +// Auth 权限管理接口 (继承 Vigo auth.Auth 并扩展 LoadOrg) +// 注意:appAuth 同时实现了此接口和 vigoauth.Auth 接口 type Auth interface { - UserID(x *vigo.X) string - OrgID(x *vigo.X) string + vigoauth.Auth // 加载组织信息 (中间件/手动调用) LoadOrg(x *vigo.X) error - // ========== 中间件生成 ========== - // 基础权限检查 - Perm(permissionID string) func(*vigo.X) error - - // 特定资源权限检查 (自动从 Path/Query 获取资源ID) - PermOnResource(permissionID, resourceKey string) func(*vigo.X) error - - // 满足任一权限 - PermAny(permissionIDs ...string) func(*vigo.X) error - - // 满足所有权限 - PermAll(permissionIDs ...string) func(*vigo.X) error - - // ========== 角色管理 ========== - // 添加角色定义 - // policies 格式: "resource:action",例如 "user:read", "*:*" - AddRole(roleCode, roleName string, policies ...string) error - - // ========== 权限管理 ========== - // 授予角色 - GrantRole(ctx context.Context, userID, orgID, roleCode string) error - - // 撤销角色 - RevokeRole(ctx context.Context, userID, orgID, roleCode string) error - - // 授予特定资源权限 - GrantResourcePerm(ctx context.Context, userID, orgID, permissionID, resourceID string) error - - // 撤销特定资源权限 - RevokeResourcePerm(ctx context.Context, userID, orgID, permissionID, resourceID string) error - - // 撤销用户所有权限 - RevokeAll(ctx context.Context, userID, orgID string) error - - // ========== 权限查询 ========== - // 检查权限 + // 检查权限 (兼容旧接口) CheckPermission(ctx context.Context, userID, orgID, permissionID, resourceID string) bool - CheckPerm(ctx context.Context, userID, orgID, permissionID, resourceID string) bool - - // 列出用户权限 - ListUserPermissions(ctx context.Context, userID, orgID string) ([]models.UserPermissionResult, error) - - // 列出资源授权用户 - ListResourceUsers(ctx context.Context, orgID, permissionID, resourceID string) ([]models.ResourceUser, error) } // 全局 Auth 工厂 @@ -262,6 +225,64 @@ func (a *appAuth) AddRole(roleCode, roleName string, policies ...string) error { return nil } +// GetRole 获取角色定义 +func (a *appAuth) GetRole(roleCode string) (*vigoauth.Role, error) { + roleDef, exists := a.roleDefs[roleCode] + if !exists { + return nil, fmt.Errorf("role not found: %s", roleCode) + } + + // 从数据库获取完整角色信息 + var role models.Role + err := cfg.DB().Where("code = ? AND org_id IS NULL", roleCode).First(&role).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fmt.Errorf("role not found: %s", roleCode) + } + return nil, err + } + + // 转换策略为 Vigo 格式 + policies := make([]string, 0) + if rolePolicies, ok := a.policies[roleCode]; ok { + for _, p := range rolePolicies { + policies = append(policies, fmt.Sprintf("%s:%s", p[0], p[1])) + } + } + + return &vigoauth.Role{ + Code: roleDef.code, + Name: roleDef.name, + Policies: policies, + Description: roleDef.description, + }, nil +} + +// ListRoles 列出所有角色定义 +func (a *appAuth) ListRoles() ([]*vigoauth.Role, error) { + result := make([]*vigoauth.Role, 0) + for code, roleDef := range a.roleDefs { + if code == "_app_info" { + continue + } + + policies := make([]string, 0) + if rolePolicies, ok := a.policies[code]; ok { + for _, p := range rolePolicies { + policies = append(policies, fmt.Sprintf("%s:%s", p[0], p[1])) + } + } + + result = append(result, &vigoauth.Role{ + Code: roleDef.code, + Name: roleDef.name, + Policies: policies, + Description: roleDef.description, + }) + } + return result, nil +} + // init 初始化应用的权限配置 func (a *appAuth) init() error { // 1. 同步权限定义到数据库 @@ -427,29 +448,59 @@ func (a *appAuth) initRole(roleCode string) error { // ========== 中间件实现 ========== +// PermLogin JWT 认证中间件 +// 解析 token、验证黑名单、设置 CtxKeyUserID +func (a *appAuth) PermLogin(x *vigo.X) error { + // 1. 提取 token + tokenString := extractToken(x) + if tokenString == "" { + return vigo.ErrUnauthorized.WithString("missing token") + } + + // 2. 解析 token + claims, err := jwt.ParseToken(tokenString) + if err != nil { + if err == jwt.ErrExpiredToken { + return vigo.ErrTokenExpired + } + return vigo.ErrTokenInvalid + } + + // 3. 检查 token 黑名单 + if cache.IsEnabled() { + blacklisted, _ := cache.IsTokenBlacklisted(claims.ID) + if blacklisted { + return vigo.ErrUnauthorized.WithString("token has been revoked") + } + } + + // 4. 设置用户ID到上下文 + x.Set(CtxKeyUserID, claims.UserID) + return nil +} + func (a *appAuth) UserID(x *vigo.X) string { - return GetUserID(x) + if userID, ok := x.Get(CtxKeyUserID).(string); ok { + return userID + } + return "" } func (a *appAuth) OrgID(x *vigo.X) string { - return GetOrgID(x) + if orgID, ok := x.Get(CtxKeyOrgID).(string); ok { + return orgID + } + return "" } +// LoadOrg 加载组织信息 func (a *appAuth) LoadOrg(x *vigo.X) error { - orgID := x.Request.Header.Get("X-Org-ID") + orgID := getOrgID(x) if orgID == "" { - orgID = x.Request.URL.Query().Get("org_id") - } - if orgID == "" { - orgID = x.PathParams.Get("org_id") - } - - if orgID == "" { - // 没有指定组织 return vigo.ErrInvalidArg.WithString("missing org_id") } - userID := GetUserID(x) + userID := a.UserID(x) if userID == "" { return vigo.ErrUnauthorized } @@ -478,13 +529,12 @@ func (a *appAuth) LoadOrg(x *vigo.X) error { func (a *appAuth) Perm(permissionID string) func(*vigo.X) error { validatePermissionID(permissionID) return func(x *vigo.X) error { - userID := GetUserID(x) + userID := a.UserID(x) if userID == "" { return vigo.ErrUnauthorized } - orgID := GetOrgID(x) - if err := a.checkPermission(x.Context(), userID, orgID, permissionID, ""); err != nil { + if err := a.checkPermission(x.Context(), userID, "", permissionID, ""); err != nil { return err } return nil @@ -504,12 +554,12 @@ func (a *appAuth) Perm(permissionID string) func(*vigo.X) error { func (a *appAuth) PermOnResource(permissionID, resourceKey string) func(*vigo.X) error { validatePermissionID(permissionID) return func(x *vigo.X) error { - userID := GetUserID(x) + userID := a.UserID(x) if userID == "" { return vigo.ErrUnauthorized } - orgID := GetOrgID(x) + orgID := a.OrgID(x) // 尝试从 PathParams 获取 resourceID := x.PathParams.Get(resourceKey) @@ -542,11 +592,11 @@ func (a *appAuth) PermAny(permissionIDs ...string) func(*vigo.X) error { validatePermissionID(pid) } return func(x *vigo.X) error { - userID := GetUserID(x) + userID := a.UserID(x) if userID == "" { return vigo.ErrUnauthorized } - orgID := GetOrgID(x) + orgID := a.OrgID(x) for _, pid := range permissionIDs { if err := a.checkPermission(x.Context(), userID, orgID, pid, ""); err == nil { @@ -562,11 +612,11 @@ func (a *appAuth) PermAll(permissionIDs ...string) func(*vigo.X) error { validatePermissionID(pid) } return func(x *vigo.X) error { - userID := GetUserID(x) + userID := a.UserID(x) if userID == "" { return vigo.ErrUnauthorized } - orgID := GetOrgID(x) + orgID := a.OrgID(x) for _, pid := range permissionIDs { if err := a.checkPermission(x.Context(), userID, orgID, pid, ""); err != nil { @@ -634,6 +684,49 @@ func (a *appAuth) GrantRole(ctx context.Context, userID, orgID, roleCode string) return nil } +// GrantRoles 批量授予角色 +func (a *appAuth) GrantRoles(ctx context.Context, userID, orgID string, roleCodes ...string) error { + for _, roleCode := range roleCodes { + if err := a.GrantRole(ctx, userID, orgID, roleCode); err != nil { + return err + } + } + return nil +} + +// ListUserRoles 查询用户的角色列表 +func (a *appAuth) ListUserRoles(ctx context.Context, userID, orgID string) ([]string, error) { + var roleIDs []string + query := cfg.DB().Model(&models.UserRole{}). + Where("user_id = ?", userID) + + if orgID != "" { + query = query.Where("org_id = ? OR org_id IS NULL", orgID) + } else { + query = query.Where("org_id IS NULL") + } + + if err := query.Pluck("role_id", &roleIDs).Error; err != nil { + return nil, err + } + + if len(roleIDs) == 0 { + return []string{}, nil + } + + // 获取角色代码 + var roles []models.Role + if err := cfg.DB().Where("id IN ?", roleIDs).Pluck("code", &roles).Error; err != nil { + return nil, err + } + + codes := make([]string, 0, len(roles)) + for _, role := range roles { + codes = append(codes, role.Code) + } + return codes, nil +} + func (a *appAuth) RevokeRole(ctx context.Context, userID, orgID, roleCode string) error { var role models.Role // 优先查找组织特定角色 @@ -669,6 +762,16 @@ func (a *appAuth) RevokeRole(ctx context.Context, userID, orgID, roleCode string return nil } +// RevokeRoles 批量撤销角色 +func (a *appAuth) RevokeRoles(ctx context.Context, userID, orgID string, roleCodes ...string) error { + for _, roleCode := range roleCodes { + if err := a.RevokeRole(ctx, userID, orgID, roleCode); err != nil { + return err + } + } + return nil +} + func (a *appAuth) GrantResourcePerm(ctx context.Context, userID, orgID, permissionID, resourceID string) error { if strings.Count(permissionID, ":") == 1 { permissionID = fmt.Sprintf("%s:%s", a.scope, permissionID) @@ -855,8 +958,8 @@ func (a *appAuth) checkPermissionDB(ctx context.Context, userID, orgID, permissi return userPermCount > 0, nil } -func (a *appAuth) ListUserPermissions(ctx context.Context, userID, orgID string) ([]models.UserPermissionResult, error) { - result := make([]models.UserPermissionResult, 0) +func (a *appAuth) ListUserPermissions(ctx context.Context, userID, orgID string) ([]*vigoauth.UserPermission, error) { + result := make([]*vigoauth.UserPermission, 0) // 1. 获取用户角色对应的权限 var roleIDs []string @@ -876,10 +979,15 @@ func (a *appAuth) ListUserPermissions(ctx context.Context, userID, orgID string) } for _, permID := range permIDs { - result = append(result, models.UserPermissionResult{ - PermissionID: permID, - ResourceID: "*", - Actions: []string{"*"}, + parts := strings.Split(permID, ":") + resource := "*" + if len(parts) >= 2 { + resource = parts[len(parts)-2] + } + result = append(result, &vigoauth.UserPermission{ + Resource: resource, + ResourceID: "*", + Actions: []string{"*"}, }) } } @@ -893,18 +1001,23 @@ func (a *appAuth) ListUserPermissions(ctx context.Context, userID, orgID string) } for _, up := range userPerms { - result = append(result, models.UserPermissionResult{ - PermissionID: up.PermissionID, - ResourceID: up.ResourceID, - Actions: []string{"*"}, + parts := strings.Split(up.PermissionID, ":") + resource := "*" + if len(parts) >= 2 { + resource = parts[len(parts)-2] + } + result = append(result, &vigoauth.UserPermission{ + Resource: resource, + ResourceID: up.ResourceID, + Actions: []string{"*"}, }) } return result, nil } -func (a *appAuth) ListResourceUsers(ctx context.Context, orgID, permissionID, resourceID string) ([]models.ResourceUser, error) { - result := make([]models.ResourceUser, 0) +func (a *appAuth) ListResourceUsers(ctx context.Context, orgID, permissionID, resourceID string) ([]*vigoauth.ResourceUser, error) { + result := make([]*vigoauth.ResourceUser, 0) // 查询有该资源权限的用户 var userPerms []models.UserPermission @@ -923,7 +1036,7 @@ func (a *appAuth) ListResourceUsers(ctx context.Context, orgID, permissionID, re } for userID, actions := range userMap { - result = append(result, models.ResourceUser{ + result = append(result, &vigoauth.ResourceUser{ UserID: userID, Actions: actions, }) diff --git a/auth/middleware.go b/auth/middleware.go deleted file mode 100644 index 345d35d..0000000 --- a/auth/middleware.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (C) 2024 veypi -// 2025-03-04 16:08:06 -// Distributed under terms of the MIT license. - -package auth - -import ( - "strings" - - "github.com/veypi/vbase/libs/cache" - "github.com/veypi/vbase/libs/jwt" - "github.com/veypi/vigo" -) - -// AuthMiddleware 统一认证中间件 -// 仅处理 JWT 认证,设置 CtxKeyUserID -// 组织信息的加载需按需调用 Auth.LoadOrg(x) -func AuthMiddleware() func(*vigo.X) error { - return func(x *vigo.X) error { - // === 1. JWT 认证部分 === - tokenString := extractToken(x) - if tokenString == "" { - return vigo.ErrUnauthorized.WithString("missing token") - } - - // 解析token - claims, err := jwt.ParseToken(tokenString) - if err != nil { - if err == jwt.ErrExpiredToken { - return vigo.ErrTokenExpired - } - return vigo.ErrTokenInvalid - } - - // 检查token是否在黑名单中 - if cache.IsEnabled() { - blacklisted, _ := cache.IsTokenBlacklisted(claims.ID) - if blacklisted { - return vigo.ErrUnauthorized.WithString("token has been revoked") - } - } - - // 将用户信息存入上下文 - x.Set(CtxKeyUserID, claims.UserID) - - return nil - } -} - -func extractToken(x *vigo.X) string { - auth := x.Request.Header.Get("Authorization") - if auth != "" { - if len(auth) > 7 && strings.HasPrefix(auth, "Bearer ") { - return auth[7:] - } - } - return x.Request.URL.Query().Get("access_token") -} diff --git a/cfg/cfg.go b/cfg/cfg.go index 4be60b2..420f24f 100644 --- a/cfg/cfg.go +++ b/cfg/cfg.go @@ -75,3 +75,5 @@ var Global = &Options{ } var DB = Global.DB.Client + +var Redis = Global.Redis.Client diff --git a/cli/main.go b/cli/main.go index e49cb5e..ff35bba 100644 --- a/cli/main.go +++ b/cli/main.go @@ -9,8 +9,11 @@ package main import ( "github.com/veypi/vbase" + "github.com/veypi/vbase/cfg" + "github.com/veypi/vigo" ) func main() { - panic(vbase.App.Run()) + app := vigo.New("vbase", vbase.Router, cfg.Global, vbase.Init) + panic(app.Run()) } diff --git a/init.go b/init.go index a47a1e6..4c8ee98 100644 --- a/init.go +++ b/init.go @@ -19,12 +19,15 @@ import ( ) var Router = vigo.NewRouter() - var ( - Auth = auth.Factory - AuthMiddleware = auth.AuthMiddleware + NewAuth = auth.Factory.New + Config = cfg.Global ) +func Init() error { + return models.Migrate() +} + //go:embed ui var uifs embed.FS @@ -34,9 +37,3 @@ func init() { Router.Extend("vhtml", vhtml.Router) vhtml.WrapUI(Router, uifs) } - -var App = vigo.New("vbase", Router, cfg.Global, Init) - -func Init() error { - return models.Migrate() -}