diff --git a/shardy/dialect/mpmd/transforms/optimize/pipeline_schedule.cc b/shardy/dialect/mpmd/transforms/optimize/pipeline_schedule.cc index eef6b3b0..320d8a14 100644 --- a/shardy/dialect/mpmd/transforms/optimize/pipeline_schedule.cc +++ b/shardy/dialect/mpmd/transforms/optimize/pipeline_schedule.cc @@ -133,13 +133,17 @@ bool OneFOneBMustHappenBefore(FragmentOp fragment1, FragmentOp fragment2) { // This guarantees that the transpose counts and call counts of each fragment // are defined. SDY_CHECK(IsSchedulingUnit(fragment1) && IsSchedulingUnit(fragment2)); + if (!fragment1.getStageIdAttr() || !fragment2.getStageIdAttr()) { + SDY_LOG(ERROR) << "Cannot schedule for 1F1B pipelining without stages."; + return false; + } int64_t call_counter_f1 = *TryToFindCallCounter(fragment1); int64_t call_counter_f2 = *TryToFindCallCounter(fragment2); int64_t transpose_count_f1 = *TryToFindSingleTransposeCount(fragment1); int64_t transpose_count_f2 = *TryToFindSingleTransposeCount(fragment2); const int num_meshes = GetNumMeshes(fragment1); - const int mesh_id = GetMeshIndex(fragment1); + const int stage_id = fragment1.getStageIdAttr().getInt(); // The following two conditions guarantee the forward and backward fragments // are interleaved in the steady state of the pipeline. @@ -148,14 +152,14 @@ bool OneFOneBMustHappenBefore(FragmentOp fragment1, FragmentOp fragment2) { // of microbatch 0 must be scheduled before the forward computation of // microbatch 4: 0 == 4 - 4 + 0. if (transpose_count_f1 == 1 && transpose_count_f2 == 0) { - return call_counter_f1 == call_counter_f2 - num_meshes + mesh_id; + return call_counter_f1 == call_counter_f2 - num_meshes + stage_id; } // Example: in mesh/stage 0 of pipeline of depth 4, the forward computation of // microbatch 5 must be scheduled before the backward computation of // microbatch 2: 5 == 2 + 4 - (0 + 1). if (transpose_count_f1 == 0 && transpose_count_f2 == 1) { - return call_counter_f1 == call_counter_f2 + num_meshes - (mesh_id + 1); + return call_counter_f1 == call_counter_f2 + num_meshes - (stage_id + 1); } // If the fragments have the same transpose count, guarantee that the @@ -197,9 +201,14 @@ bool GPipeMustHappenBefore(FragmentOp fragment1, FragmentOp fragment2) { // Requires: IsSchedulingUnit(fragment1) && IsSchedulingUnit(fragment2). bool GPipeBut1F1BLastMeshHappenBefore(FragmentOp fragment1, FragmentOp fragment2) { + if (!fragment1.getStageIdAttr() || !fragment2.getStageIdAttr()) { + SDY_LOG(ERROR) + << "Cannot schedule for GPipeBut1F1BForLastMesh without stages."; + return false; + } const int num_meshes = GetNumMeshes(fragment1); - const int mesh_id = GetMeshIndex(fragment1); - if (mesh_id == num_meshes - 1) { + const int stage_id = fragment1.getStageIdAttr().getInt(); + if (stage_id == num_meshes - 1) { return OneFOneBMustHappenBefore(fragment1, fragment2); } return GPipeMustHappenBefore(fragment1, fragment2); @@ -219,13 +228,17 @@ bool ZeroBubbleH1MustHappenBefore(FragmentOp fragment1, FragmentOp fragment2) { // This guarantees that the transpose counts and call counts of each fragment // are defined. SDY_CHECK(IsSchedulingUnit(fragment1) && IsSchedulingUnit(fragment2)); + if (!fragment1.getStageIdAttr() || !fragment2.getStageIdAttr()) { + SDY_LOG(ERROR) << "Cannot schedule for ZeroBubbleH1 without stages."; + return false; + } int64_t call_counter_f1 = *TryToFindCallCounter(fragment1); int64_t call_counter_f2 = *TryToFindCallCounter(fragment2); int64_t transpose_count_f1 = *TryToFindSingleTransposeCount(fragment1); int64_t transpose_count_f2 = *TryToFindSingleTransposeCount(fragment2); const int num_meshes = GetNumMeshes(fragment1); - const int mesh_id = GetMeshIndex(fragment1); + const int stage_id = fragment1.getStageIdAttr().getInt(); bool is_wgrad_f1 = IsSplitDropTransferred(fragment1); bool is_wgrad_f2 = IsSplitDropTransferred(fragment2); @@ -234,23 +247,23 @@ bool ZeroBubbleH1MustHappenBefore(FragmentOp fragment1, FragmentOp fragment2) { // are interleaved in the steady state of the pipeline. They are just like // 1F1B but specialized to actual back-propagation fragments. - // Clause 1: Ba(i) < F(i + num_meshes - mesh_id) + // Clause 1: Ba(i) < F(i + num_meshes - stage_id) if (transpose_count_f1 == 1 && !is_wgrad_f1 && transpose_count_f2 == 0) { - return call_counter_f1 == call_counter_f2 - num_meshes + mesh_id; + return call_counter_f1 == call_counter_f2 - num_meshes + stage_id; } - // Clause 2: F(i + num_meshes - mesh_id - 1) < Ba(i) + // Clause 2: F(i + num_meshes - stage_id - 1) < Ba(i) if (transpose_count_f1 == 0 && transpose_count_f2 == 1 && !is_wgrad_f2) { - return call_counter_f1 == call_counter_f2 + num_meshes - (mesh_id + 1); + return call_counter_f1 == call_counter_f2 + num_meshes - (stage_id + 1); } // The rest of the conditions position the parameter gradient fragments. // Clause 3: Bw(i) < F(i + num_meshes) // e.g. Bw(0) < F(4) above. - if (transpose_count_f1 == 1 && (is_wgrad_f1 || mesh_id == 0) && + if (transpose_count_f1 == 1 && (is_wgrad_f1 || stage_id == 0) && transpose_count_f2 == 0) { return call_counter_f2 - call_counter_f1 == num_meshes; } - // Clause 4: Ba(i + mesh_id) < Bw(i) + // Clause 4: Ba(i + stage_id) < Bw(i) // e.g. // mesh0: Ba(0) < Bw(0) // mesh1: Ba(1) < Bw(0) @@ -258,15 +271,15 @@ bool ZeroBubbleH1MustHappenBefore(FragmentOp fragment1, FragmentOp fragment2) { // mesh3: Ba(3) < Bw(0) if (transpose_count_f1 == 1 && !is_wgrad_f1 && transpose_count_f2 == 1 && is_wgrad_f2) { - return call_counter_f1 - call_counter_f2 == mesh_id; + return call_counter_f1 - call_counter_f2 == stage_id; } // This is just needed for transitively completing Clauses 3 and 2, needed for // the final phase where there may be no remaining forward to anchor to. - // Bw(i) < Ba(i + mesh_id + 1) + // Bw(i) < Ba(i + stage_id + 1) if (transpose_count_f1 == 1 && is_wgrad_f1 && transpose_count_f2 == 1 && !is_wgrad_f2) { - return call_counter_f2 - call_counter_f1 == mesh_id + 1; + return call_counter_f2 - call_counter_f1 == stage_id + 1; } return false; @@ -274,28 +287,32 @@ bool ZeroBubbleH1MustHappenBefore(FragmentOp fragment1, FragmentOp fragment2) { // A function to calculate, for a given mesh, how many forward microbatches // need to be streamed in, before we can schedule the first backward. -using InitFwdPerMeshFn = std::function; +using InitFwdPerStageFn = std::function; bool ZeroBubbleH2MustHappenBefore(FragmentOp fragment1, FragmentOp fragment2, - InitFwdPerMeshFn init_fwd_per_mesh) { + InitFwdPerStageFn init_fwd_per_stage) { SDY_CHECK(IsSchedulingUnit(fragment1) && IsSchedulingUnit(fragment2)); + if (!fragment1.getStageIdAttr() || !fragment2.getStageIdAttr()) { + SDY_LOG(ERROR) << "Cannot schedule for ZeroBubbleH2 without stages."; + return false; + } int64_t call_counter_f1 = *TryToFindCallCounter(fragment1); int64_t call_counter_f2 = *TryToFindCallCounter(fragment2); int64_t transpose_count_f1 = *TryToFindSingleTransposeCount(fragment1); int64_t transpose_count_f2 = *TryToFindSingleTransposeCount(fragment2); const int num_meshes = GetNumMeshes(fragment1); - const int mesh_id = GetMeshIndex(fragment1); + const int stage_id = fragment1.getStageIdAttr().getInt(); bool is_wgrad_f1 = IsSplitDropTransferred(fragment1); bool is_wgrad_f2 = IsSplitDropTransferred(fragment2); // How many fwd we are allowed to stream before entering steady state. - int init_fwd = init_fwd_per_mesh(mesh_id); + int init_fwd = init_fwd_per_stage(stage_id); // The ZeroBubbleH2 pipeline is diagonally symmetric (replacing forward with // backwards parameter gradient) so the following quantity is also part of the // schedule invariants below. - int complement_init_fwd = init_fwd_per_mesh(num_meshes - mesh_id - 1); + int complement_init_fwd = init_fwd_per_stage(num_meshes - stage_id - 1); // Initial phase. // Clause 1: F(i) <= B(_) for i < init_fwd. @@ -362,20 +379,22 @@ bool LatencyHidingZeroBubbleH2MustHappenBefore(float latency_stage_fraction, // The `init_fwds_per_mesh` returns the e_i in the diagram above, for // every mesh_i. This it the number of forward microbatches that can execute // before the first backwards microbatch can be executed on this mesh. - auto init_fwds_per_mesh = [num_meshes, latency_stage_fraction](int mesh_id) { + auto init_fwds_per_stage = [num_meshes, + latency_stage_fraction](int stage_id) { // The number of transfers from the beginning until the first backward - // fragment can execute on mesh_id, see the diagram above. We call this the + // fragment can execute on stage_id, see the diagram above. We call this the // "initial" path of the first microbatch in the pipeline. - float num_init_transfers = 2.0f * (num_meshes - mesh_id - 1); + float num_init_transfers = 2.0f * (num_meshes - stage_id - 1); // How much compute has happened in that initial first microbatch path, i.e. - // until the point where the first backward fragment can execute on mesh_id. - // The assumption that time(fwd) == time(bwd) (NB: this is just the backprop - // bwd) may need to be revisited for real use cases. - float num_init_compute = 2.0f * (num_meshes - mesh_id) - 1.0f; + // until the point where the first backward fragment can execute on + // stage_id. The assumption that time(fwd) == time(bwd) (NB: this is just + // the backprop bwd) may need to be revisited for real use cases. + float num_init_compute = 2.0f * (num_meshes - stage_id) - 1.0f; return std::floor(num_init_compute + num_init_transfers * latency_stage_fraction); }; - return ZeroBubbleH2MustHappenBefore(fragment1, fragment2, init_fwds_per_mesh); + return ZeroBubbleH2MustHappenBefore(fragment1, fragment2, + init_fwds_per_stage); } // Returns true if `fragment1` must happen before `fragment2` in a parallel @@ -408,6 +427,11 @@ bool ParallelPipelinesWithWrapAroundMustHappenBefore(FragmentOp fragment1, // This guarantees that the transpose counts and call counts of each fragment // are defined. SDY_CHECK(IsSchedulingUnit(fragment1) && IsSchedulingUnit(fragment2)); + if (!fragment1.getStageIdAttr() || !fragment2.getStageIdAttr()) { + SDY_LOG(ERROR) << "Cannot schedule for ParallelPipelinesWithWrapAround " + "without stages."; + return false; + } // Only allowed for forwards for now. SDY_CHECK(IsForwardFragment(fragment1)); SDY_CHECK(IsForwardFragment(fragment2)); @@ -417,21 +441,18 @@ bool ParallelPipelinesWithWrapAroundMustHappenBefore(FragmentOp fragment1, SDY_CHECK_NE(call_counter_f1, call_counter_f2) << "Should not have duplicate call counter."; - int64_t mesh_num = 0; - SDY_CHECK(llvm::to_integer(fragment1.getMeshName().drop_until( - [](char c) { return llvm::isDigit(c); }), - mesh_num)); - // The entrypoint to mesh{i} is call_counter {i}, so this always happens + int64_t stage_id = fragment1.getStageIdAttr().getInt(); + // The entrypoint to stage{i} is call_counter {i}, so this always happens // before. - if (call_counter_f1 == mesh_num || call_counter_f2 == mesh_num) { - return call_counter_f1 == mesh_num; + if (call_counter_f1 == stage_id || call_counter_f2 == stage_id) { + return call_counter_f1 == stage_id; } - // `mesh_num` is the pivot. If both call_counters are on the same side of + // `stage_id` is the pivot. If both call_counters are on the same side of // the pivot, we flip the order. But if they are on different // sides, then we take the order as per normal. - if ((call_counter_f1 > mesh_num && call_counter_f2 > mesh_num) || - (call_counter_f1 < mesh_num && call_counter_f2 < mesh_num)) { + if ((call_counter_f1 > stage_id && call_counter_f2 > stage_id) || + (call_counter_f1 < stage_id && call_counter_f2 < stage_id)) { return call_counter_f1 > call_counter_f2; } return call_counter_f1 < call_counter_f2;