diff --git a/sqle/api/controller/v1/pipeline.go b/sqle/api/controller/v1/pipeline.go index 9559290bd1..292fa05532 100644 --- a/sqle/api/controller/v1/pipeline.go +++ b/sqle/api/controller/v1/pipeline.go @@ -3,11 +3,11 @@ package v1 import ( "context" "fmt" - v1 "github.com/actiontech/dms/pkg/dms-common/api/dms/v1" - "github.com/actiontech/sqle/sqle/errors" "net/http" "strconv" + "github.com/actiontech/sqle/sqle/errors" + "github.com/actiontech/sqle/sqle/api/controller" "github.com/actiontech/sqle/sqle/dms" "github.com/actiontech/sqle/sqle/server/pipeline" @@ -234,19 +234,10 @@ func GetPipelines(c echo.Context) error { if err != nil { return errors.New(errors.ConnectStorageError, fmt.Errorf("check get pipelines failed: %v", err)) } - userId := "" - if !userPermission.CanViewProject() { - userId = user.GetIDStr() - } - rangeDatasourceIds := make([]string, 0) - viewPipelinePermission := userPermission.GetOnePermission(v1.OpPermissionViewPipeline) - if viewPipelinePermission != nil { - userId = "" - rangeDatasourceIds = viewPipelinePermission.RangeUids - } - // 4. 获取存储对象并查询流水线列表 + + // 3. 获取存储对象并查询流水线列表 var pipelineSvc pipeline.PipelineSvc - count, pipelineList, err := pipelineSvc.GetPipelineList(limit, offset, req.FuzzySearchNameDesc, projectUid, userId, rangeDatasourceIds) + count, pipelineList, err := pipelineSvc.GetPipelineListWithPermission(limit, offset, req.FuzzySearchNameDesc, projectUid, userPermission, user.GetIDStr()) if err != nil { return controller.JSONBaseErrorReq(c, err) } diff --git a/sqle/model/pipline.go b/sqle/model/pipline.go index 8c16050c0f..df434014da 100644 --- a/sqle/model/pipline.go +++ b/sqle/model/pipline.go @@ -101,22 +101,50 @@ func isValidAuditMethod(a string) bool { return false } -func (s *Storage) GetPipelineList(projectID ProjectUID, fuzzySearchContent string, limit, offset uint32, userId string, rangeDatasourceIds []string) ([]*Pipeline, uint64, error) { +func (s *Storage) GetPipelineList(projectID ProjectUID, fuzzySearchContent string, limit, offset uint32, userId string, rangeDatasourceIds []string, canViewAll bool) ([]*Pipeline, uint64, error) { var count int64 var pipelines []*Pipeline query := s.db.Model(&Pipeline{}).Where("project_uid = ?", projectID) - if userId != "" { - query = query.Where("create_user_id = ? OR create_user_id IS NULL", userId) - } + + // 1. 模糊搜索 if fuzzySearchContent != "" { query = query.Where("name LIKE ? OR description LIKE ?", "%"+fuzzySearchContent+"%", "%"+fuzzySearchContent+"%") } - if len(rangeDatasourceIds) > 0 { - query = query.Joins("JOIN pipeline_nodes ON pipelines.id = pipeline_nodes.pipeline_id"). - Where("pipeline_nodes.instance_id IN (?)", rangeDatasourceIds). - Group("pipelines.id") + + // 2. 权限过滤 + if !canViewAll { + if len(rangeDatasourceIds) > 0 { + // 有数据源权限的用户可以看到: + // 1. 包含权限范围内数据源的流水线(通过LEFT JOIN匹配) + // 2. 自己创建的所有流水线 + // 3. 所有节点都是离线节点的流水线(通过NOT EXISTS检查) + query = query. + Joins("LEFT JOIN pipeline_nodes ON pipelines.id = pipeline_nodes.pipeline_id"). + Where(` + pipeline_nodes.instance_id IN (?) OR + pipelines.create_user_id = ? OR + NOT EXISTS ( + SELECT 1 FROM pipeline_nodes pn2 + WHERE pn2.pipeline_id = pipelines.id + AND pn2.instance_id != 0 + )`, rangeDatasourceIds, userId). + Group("pipelines.id") // 去重,因为LEFT JOIN可能产生重复记录 + } else if userId != "" { + // 普通用户只能看到: + // 1. 自己创建的流水线 + // 2. 所有节点都是离线节点的流水线 + query = query.Where(` + create_user_id = ? OR + NOT EXISTS ( + SELECT 1 FROM pipeline_nodes pn + WHERE pn.pipeline_id = pipelines.id + AND pn.instance_id != 0 + )`, userId) + } } + // canViewAll = true 时不添加任何过滤条件 + // 3. 统计和分页查询 err := query.Count(&count).Error if err != nil { return pipelines, uint64(count), errors.New(errors.ConnectStorageError, err) @@ -169,6 +197,27 @@ func (s *Storage) GetPipelineNodesByInstanceId(instanceID uint64) ([]*PipelineNo return nodes, nil } +// GetPipelineNodesInBatch 批量获取多个流水线的节点 +func (s *Storage) GetPipelineNodesInBatch(pipelineIDs []uint) (map[uint][]*PipelineNode, error) { + if len(pipelineIDs) == 0 { + return make(map[uint][]*PipelineNode), nil + } + + var nodes []*PipelineNode + err := s.db.Model(PipelineNode{}).Where("pipeline_id IN (?)", pipelineIDs).Find(&nodes).Error + if err != nil { + return nil, errors.New(errors.ConnectStorageError, err) + } + + // 按pipeline_id分组 + nodeMap := make(map[uint][]*PipelineNode) + for _, node := range nodes { + nodeMap[node.PipelineID] = append(nodeMap[node.PipelineID], node) + } + + return nodeMap, nil +} + func (s *Storage) CreatePipeline(pipeline *Pipeline, nodes []*PipelineNode) error { return s.Tx(func(txDB *gorm.DB) error { // 4.1 保存 Pipeline 到数据库 diff --git a/sqle/server/pipeline/pipeline.go b/sqle/server/pipeline/pipeline.go index 111e087f1b..da6b64a6d4 100644 --- a/sqle/server/pipeline/pipeline.go +++ b/sqle/server/pipeline/pipeline.go @@ -9,6 +9,7 @@ import ( "github.com/actiontech/sqle/sqle/errors" + v1 "github.com/actiontech/dms/pkg/dms-common/api/dms/v1" dmsCommonJwt "github.com/actiontech/dms/pkg/dms-common/api/jwt" "github.com/actiontech/sqle/sqle/api/controller" scannerCmd "github.com/actiontech/sqle/sqle/cmd/scannerd/command" @@ -235,19 +236,59 @@ func (svc PipelineSvc) GetPipeline(projectUID string, pipelineID uint) (*Pipelin return svc.toPipeline(modelPipeline, modelPiplineNodes), nil } -func (svc PipelineSvc) GetPipelineList(limit, offset uint32, fuzzySearchNameDesc string, projectUID string, userId string, rangeDatasourceIds []string) (count uint64, pipelines []*Pipeline, err error) { +// GetPipelineListWithPermission 根据用户权限获取流水线列表 +func (svc PipelineSvc) GetPipelineListWithPermission(limit, offset uint32, fuzzySearchNameDesc string, projectUID string, userPermission *dms.UserPermission, userId string) (count uint64, pipelines []*Pipeline, err error) { s := model.GetStorage() - modelPipelines, count, err := s.GetPipelineList(model.ProjectUID(projectUID), fuzzySearchNameDesc, limit, offset, userId, rangeDatasourceIds) + + // 根据用户权限确定查询参数 + var queryUserId string + var rangeDatasourceIds []string + var canViewAll bool + + // 权限判断逻辑 + if userPermission.IsAdmin() || userPermission.IsProjectAdmin() { + // 超级管理员或项目管理员:可以查看所有流水线 + canViewAll = true + } else if viewPipelinePermission := userPermission.GetOnePermission(v1.OpPermissionViewPipeline); viewPipelinePermission != nil { + // 拥有"查看流水线"权限的普通用户:可以查看指定数据源相关的流水线 + 自己创建的所有流水线 + queryUserId = userId + rangeDatasourceIds = viewPipelinePermission.RangeUids + canViewAll = false + } else { + // 普通用户:只能查看自己创建的流水线 + queryUserId = userId + rangeDatasourceIds = nil + canViewAll = false + } + + // 执行数据库查询 + modelPipelines, count, err := s.GetPipelineList(model.ProjectUID(projectUID), fuzzySearchNameDesc, limit, offset, queryUserId, rangeDatasourceIds, canViewAll) if err != nil { return 0, nil, err } + + // 转换为服务层对象 pipelines = make([]*Pipeline, 0, len(modelPipelines)) + if len(modelPipelines) == 0 { + return count, pipelines, nil + } + + // 收集所有pipeline ID + pipelineIDs := make([]uint, 0, len(modelPipelines)) + for _, mp := range modelPipelines { + pipelineIDs = append(pipelineIDs, mp.ID) + } + + // 批量获取所有节点 + nodesMap, err := s.GetPipelineNodesInBatch(pipelineIDs) + if err != nil { + return 0, nil, err + } + + // 组装结果 for _, modelPipeline := range modelPipelines { - modelPiplineNodes, err := s.GetPipelineNodes(modelPipeline.ID) - if err != nil { - return 0, nil, err - } - pipelines = append(pipelines, svc.toPipeline(modelPipeline, modelPiplineNodes)) + nodes := nodesMap[modelPipeline.ID] + pipelines = append(pipelines, svc.toPipeline(modelPipeline, nodes)) } return count, pipelines, nil }