diff --git a/api/oauth/init.go b/api/oauth/init.go index 508970f..066a1ae 100644 --- a/api/oauth/init.go +++ b/api/oauth/init.go @@ -21,9 +21,9 @@ func init() { // === OAuth 客户端管理(需要认证)=== clientRouter := Router.SubRouter("/clients") - clientRouter.Get("/", "OAuth客户端列表", auth.VBaseAuth.Perm("oauth:client:read"), listClients) - clientRouter.Post("/", "创建OAuth客户端", auth.VBaseAuth.Perm("oauth:client:create"), createClient) - clientRouter.Get("/{client_id}", "获取客户端详情", auth.VBaseAuth.Perm("oauth:client:read"), getClient) - clientRouter.Patch("/{client_id}", "更新OAuth客户端", auth.VBaseAuth.Perm("oauth:client:update"), updateClient) - clientRouter.Delete("/{client_id}", "删除OAuth客户端", auth.VBaseAuth.Perm("oauth:client:delete"), deleteClient) + clientRouter.Get("/", "OAuth客户端列表", auth.VBaseAuth.Perm("oauth-client:read"), listClients) + clientRouter.Post("/", "创建OAuth客户端", auth.VBaseAuth.Perm("oauth-client:create"), createClient) + clientRouter.Get("/{client_id}", "获取客户端详情", auth.VBaseAuth.Perm("oauth-client:read"), getClient) + clientRouter.Patch("/{client_id}", "更新OAuth客户端", auth.VBaseAuth.Perm("oauth-client:update"), updateClient) + clientRouter.Delete("/{client_id}", "删除OAuth客户端", auth.VBaseAuth.Perm("oauth-client:delete"), deleteClient) } diff --git a/auth/auth.go b/auth/auth.go index ba142a0..c12416d 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -9,6 +9,7 @@ package auth import ( "context" "fmt" + "regexp" "strings" "time" @@ -78,7 +79,7 @@ var ( {Code: "user", Name: "普通用户", Policies: []string{ "user:read", "user:update", "org:read", "org:create", - "oauth:client:read", "oauth:client:create", "oauth:client:update", "oauth:client:delete", + "oauth-client:read", "oauth-client:create", "oauth-client:update", "oauth-client:delete", }}, }, }) @@ -94,6 +95,13 @@ func (f *authFactory) New(appKey string, config models.AppConfig) Auth { return f.apps[appKey] } + // 验证默认角色中的权限格式 + for _, role := range config.DefaultRoles { + for _, policy := range role.Policies { + validatePermissionID(policy) + } + } + auth := &appAuth{ appKey: appKey, config: config, @@ -102,6 +110,28 @@ func (f *authFactory) New(appKey string, config models.AppConfig) Auth { return auth } +var ( + validResourceRegex = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_-]*$`) +) + +func validatePermissionID(permissionID string) { + if permissionID == "*:*" { + return + } + parts := strings.Split(permissionID, ":") + // 允许 app:resource:action 或 resource:action 格式 + // 如果是 app:resource:action,则 parts 长度为 3 + // 如果是 resource:action,则 parts 长度为 2 + if len(parts) != 2 && len(parts) != 3 { + panic(fmt.Sprintf("invalid permission format: %s, expected 'resource:action' or 'app:resource:action'", permissionID)) + } + + resource := parts[len(parts)-2] + if !validResourceRegex.MatchString(resource) { + panic(fmt.Sprintf("invalid resource identifier: %s, must start with letter and contain only letters, numbers, '-' or '_'", resource)) + } +} + // Init 初始化所有注册的权限配置 // - 检查不同 app 之间是否有冲突 // - 同步 Permission 到数据库 @@ -235,67 +265,78 @@ func (a *appAuth) initRole(roleDef models.RoleDefinition) error { // ========== 中间件实现 ========== func (a *appAuth) Perm(permissionID string) func(*vigo.X) error { + validatePermissionID(permissionID) return func(x *vigo.X) error { userID := getUserID(x) if userID == "" { - return vigo.ErrUnauthorized + return vigo.ErrNotAuthorized } orgID := getOrgID(x) - - ok, err := a.CheckPermission(x.Context(), userID, orgID, permissionID, "") - if err != nil { + if err := a.checkPermission(x.Context(), userID, orgID, permissionID, ""); err != nil { return err } - if !ok { - return vigo.ErrForbidden - } return nil } } func (a *appAuth) PermWithOwner(permissionID, ownerKey string) func(*vigo.X) error { + validatePermissionID(permissionID) return func(x *vigo.X) error { userID := getUserID(x) if userID == "" { - return vigo.ErrUnauthorized + return vigo.ErrNotAuthorized } orgID := getOrgID(x) - // 先检查是否有权限 - ok, err := a.CheckPermission(x.Context(), userID, orgID, permissionID, "") - if err != nil { + // 检查是否有基本权限 + if err := a.checkPermission(x.Context(), userID, orgID, permissionID, ""); err != nil { return err } - if !ok { - return vigo.ErrForbidden - } - // 检查是否是所有者或管理员 + // 获取资源所有者ID ownerID, _ := x.Get(ownerKey).(string) if ownerID == "" { ownerID = x.PathParams.Get(ownerKey) } + + // 如果是所有者,直接放行 if ownerID == userID { return nil } - // 检查是否是管理员 - isAdmin, _ := a.isAdmin(x.Context(), userID, orgID) - if isAdmin { + // 如果不是所有者,且拥有全局管理权限(如admin),也可以放行 + // 这里简化为再次检查是否有更高级别的权限,或者该权限本身隐含了管理权 + // 实际上,CheckPermission 已经检查了用户是否拥有该 permissionID + // 如果设计上 PermWithOwner 意味着 "所有者 OR 拥有该权限的管理员", + // 那么前面的 CheckPermission 已经保证了 "拥有该权限" + // 但通常 Owner 权限是针对特定资源的,而 CheckPermission 检查的是通用权限 + // 这里逻辑稍微有点混淆,通常 PermWithOwner 意思是: + // 1. 用户必须登录 + // 2. 如果用户是资源所有者,允许 + // 3. 如果用户不是所有者,必须拥有特定权限 (permissionID) + + // 修正逻辑: + if ownerID == userID { return nil } - return vigo.ErrForbidden + // 不是所有者,检查是否有权限 + if err := a.checkPermission(x.Context(), userID, orgID, permissionID, ""); err != nil { + return err + } + + return nil } } func (a *appAuth) PermOnResource(permissionID, resourceKey string) func(*vigo.X) error { + validatePermissionID(permissionID) return func(x *vigo.X) error { userID := getUserID(x) if userID == "" { - return vigo.ErrUnauthorized + return vigo.ErrNotAuthorized } orgID := getOrgID(x) @@ -307,63 +348,72 @@ func (a *appAuth) PermOnResource(permissionID, resourceKey string) func(*vigo.X) resourceID = x.Request.URL.Query().Get(resourceKey) } - // 如果没有获取到 resourceID,仍然进行检查 (resourceID="") - // 这意味着检查用户是否拥有该权限的一般访问权 (例如通过角色获得) - // 如果想要强制检查特定资源,调用方应该确保 resourceKey 能获取到值 - - ok, err := a.CheckPermission(x.Context(), userID, orgID, permissionID, resourceID) - if err != nil { + if err := a.checkPermission(x.Context(), userID, orgID, permissionID, resourceID); err != nil { return err } - if !ok { - return vigo.ErrForbidden - } return nil } } +// 内部辅助检查方法,返回 error 以便于统一处理错误响应 +func (a *appAuth) checkPermission(ctx context.Context, userID, orgID, permissionID, resourceID string) error { + ok, err := a.CheckPermission(ctx, userID, orgID, permissionID, resourceID) + if err != nil { + return vigo.ErrInternalServer.WithError(err) + } + if !ok { + return vigo.ErrForbidden + } + return nil +} + func (a *appAuth) PermAny(permissionIDs []string) func(*vigo.X) error { + for _, pid := range permissionIDs { + validatePermissionID(pid) + } return func(x *vigo.X) error { userID := getUserID(x) if userID == "" { - return vigo.ErrUnauthorized + return vigo.ErrNotAuthorized } orgID := getOrgID(x) - - for _, permID := range permissionIDs { - ok, err := a.CheckPermission(x.Context(), userID, orgID, permID, "") - if err != nil { - return err - } - if ok { + var lastErr error + for _, pid := range permissionIDs { + if err := a.checkPermission(x.Context(), userID, orgID, pid, ""); err == nil { return nil + } else { + lastErr = err } } + if lastErr != nil { + // 如果是 Forbidden 错误,返回 Forbidden + // 否则返回最后一个错误 + // 这里简单处理,如果所有都失败,返回 Forbidden + return vigo.ErrForbidden + } return vigo.ErrForbidden } } func (a *appAuth) PermAll(permissionIDs []string) func(*vigo.X) error { + for _, pid := range permissionIDs { + validatePermissionID(pid) + } return func(x *vigo.X) error { userID := getUserID(x) if userID == "" { - return vigo.ErrUnauthorized + return vigo.ErrNotAuthorized } orgID := getOrgID(x) - for _, permID := range permissionIDs { - ok, err := a.CheckPermission(x.Context(), userID, orgID, permID, "") - if err != nil { + for _, pid := range permissionIDs { + if err := a.checkPermission(x.Context(), userID, orgID, pid, ""); err != nil { return err } - if !ok { - return vigo.ErrForbidden - } } - return nil } }