Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 59 additions & 38 deletions shardy/dialect/mpmd/transforms/optimize/pipeline_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,17 @@ bool OneFOneBMustHappenBefore(FragmentOp fragment1, FragmentOp fragment2) {
// This guarantees that the transpose counts and call counts of each fragment
// are defined.
SDY_CHECK(IsSchedulingUnit(fragment1) && IsSchedulingUnit(fragment2));
if (!fragment1.getStageIdAttr() || !fragment2.getStageIdAttr()) {
SDY_LOG(ERROR) << "Cannot schedule for 1F1B pipelining without stages.";
return false;
}
int64_t call_counter_f1 = *TryToFindCallCounter(fragment1);
int64_t call_counter_f2 = *TryToFindCallCounter(fragment2);
int64_t transpose_count_f1 = *TryToFindSingleTransposeCount(fragment1);
int64_t transpose_count_f2 = *TryToFindSingleTransposeCount(fragment2);

const int num_meshes = GetNumMeshes(fragment1);
const int mesh_id = GetMeshIndex(fragment1);
const int stage_id = fragment1.getStageIdAttr().getInt();

// The following two conditions guarantee the forward and backward fragments
// are interleaved in the steady state of the pipeline.
Expand All @@ -148,14 +152,14 @@ bool OneFOneBMustHappenBefore(FragmentOp fragment1, FragmentOp fragment2) {
// of microbatch 0 must be scheduled before the forward computation of
// microbatch 4: 0 == 4 - 4 + 0.
if (transpose_count_f1 == 1 && transpose_count_f2 == 0) {
return call_counter_f1 == call_counter_f2 - num_meshes + mesh_id;
return call_counter_f1 == call_counter_f2 - num_meshes + stage_id;
}

// Example: in mesh/stage 0 of a pipeline of depth 4, the forward computation of
// microbatch 5 must be scheduled before the backward computation of
// microbatch 2: 5 == 2 + 4 - (0 + 1).
if (transpose_count_f1 == 0 && transpose_count_f2 == 1) {
return call_counter_f1 == call_counter_f2 + num_meshes - (mesh_id + 1);
return call_counter_f1 == call_counter_f2 + num_meshes - (stage_id + 1);
}

// If the fragments have the same transpose count, guarantee that the
Expand Down Expand Up @@ -197,9 +201,14 @@ bool GPipeMustHappenBefore(FragmentOp fragment1, FragmentOp fragment2) {
// Requires: IsSchedulingUnit(fragment1) && IsSchedulingUnit(fragment2).
bool GPipeBut1F1BLastMeshHappenBefore(FragmentOp fragment1,
FragmentOp fragment2) {
if (!fragment1.getStageIdAttr() || !fragment2.getStageIdAttr()) {
SDY_LOG(ERROR)
<< "Cannot schedule for GPipeBut1F1BForLastMesh without stages.";
return false;
}
const int num_meshes = GetNumMeshes(fragment1);
const int mesh_id = GetMeshIndex(fragment1);
if (mesh_id == num_meshes - 1) {
const int stage_id = fragment1.getStageIdAttr().getInt();
if (stage_id == num_meshes - 1) {
return OneFOneBMustHappenBefore(fragment1, fragment2);
}
return GPipeMustHappenBefore(fragment1, fragment2);
Expand All @@ -219,13 +228,17 @@ bool ZeroBubbleH1MustHappenBefore(FragmentOp fragment1, FragmentOp fragment2) {
// This guarantees that the transpose counts and call counts of each fragment
// are defined.
SDY_CHECK(IsSchedulingUnit(fragment1) && IsSchedulingUnit(fragment2));
if (!fragment1.getStageIdAttr() || !fragment2.getStageIdAttr()) {
SDY_LOG(ERROR) << "Cannot schedule for ZeroBubbleH1 without stages.";
return false;
}
int64_t call_counter_f1 = *TryToFindCallCounter(fragment1);
int64_t call_counter_f2 = *TryToFindCallCounter(fragment2);
int64_t transpose_count_f1 = *TryToFindSingleTransposeCount(fragment1);
int64_t transpose_count_f2 = *TryToFindSingleTransposeCount(fragment2);

const int num_meshes = GetNumMeshes(fragment1);
const int mesh_id = GetMeshIndex(fragment1);
const int stage_id = fragment1.getStageIdAttr().getInt();

bool is_wgrad_f1 = IsSplitDropTransferred(fragment1);
bool is_wgrad_f2 = IsSplitDropTransferred(fragment2);
Expand All @@ -234,68 +247,72 @@ bool ZeroBubbleH1MustHappenBefore(FragmentOp fragment1, FragmentOp fragment2) {
// are interleaved in the steady state of the pipeline. They are just like
// 1F1B but specialized to actual back-propagation fragments.

// Clause 1: Ba(i) < F(i + num_meshes - mesh_id)
// Clause 1: Ba(i) < F(i + num_meshes - stage_id)
if (transpose_count_f1 == 1 && !is_wgrad_f1 && transpose_count_f2 == 0) {
return call_counter_f1 == call_counter_f2 - num_meshes + mesh_id;
return call_counter_f1 == call_counter_f2 - num_meshes + stage_id;
}
// Clause 2: F(i + num_meshes - mesh_id - 1) < Ba(i)
// Clause 2: F(i + num_meshes - stage_id - 1) < Ba(i)
if (transpose_count_f1 == 0 && transpose_count_f2 == 1 && !is_wgrad_f2) {
return call_counter_f1 == call_counter_f2 + num_meshes - (mesh_id + 1);
return call_counter_f1 == call_counter_f2 + num_meshes - (stage_id + 1);
}

// The rest of the conditions position the parameter gradient fragments.
// Clause 3: Bw(i) < F(i + num_meshes)
// e.g. Bw(0) < F(4) above.
if (transpose_count_f1 == 1 && (is_wgrad_f1 || mesh_id == 0) &&
if (transpose_count_f1 == 1 && (is_wgrad_f1 || stage_id == 0) &&
transpose_count_f2 == 0) {
return call_counter_f2 - call_counter_f1 == num_meshes;
}
// Clause 4: Ba(i + mesh_id) < Bw(i)
// Clause 4: Ba(i + stage_id) < Bw(i)
// e.g.
// mesh0: Ba(0) < Bw(0)
// mesh1: Ba(1) < Bw(0)
// mesh2: Ba(2) < Bw(0)
// mesh3: Ba(3) < Bw(0)
if (transpose_count_f1 == 1 && !is_wgrad_f1 && transpose_count_f2 == 1 &&
is_wgrad_f2) {
return call_counter_f1 - call_counter_f2 == mesh_id;
return call_counter_f1 - call_counter_f2 == stage_id;
}

// This is just needed for transitively completing Clauses 3 and 2, needed for
// the final phase where there may be no remaining forward to anchor to.
// Bw(i) < Ba(i + mesh_id + 1)
// Bw(i) < Ba(i + stage_id + 1)
if (transpose_count_f1 == 1 && is_wgrad_f1 && transpose_count_f2 == 1 &&
!is_wgrad_f2) {
return call_counter_f2 - call_counter_f1 == mesh_id + 1;
return call_counter_f2 - call_counter_f1 == stage_id + 1;
}

return false;
}

// A function to calculate, for a given stage, how many forward microbatches
// need to be streamed in before we can schedule the first backward.
using InitFwdPerMeshFn = std::function<int(int)>;
using InitFwdPerStageFn = std::function<int(int)>;

bool ZeroBubbleH2MustHappenBefore(FragmentOp fragment1, FragmentOp fragment2,
InitFwdPerMeshFn init_fwd_per_mesh) {
InitFwdPerStageFn init_fwd_per_stage) {
SDY_CHECK(IsSchedulingUnit(fragment1) && IsSchedulingUnit(fragment2));
if (!fragment1.getStageIdAttr() || !fragment2.getStageIdAttr()) {
SDY_LOG(ERROR) << "Cannot schedule for ZeroBubbleH2 without stages.";
return false;
}
int64_t call_counter_f1 = *TryToFindCallCounter(fragment1);
int64_t call_counter_f2 = *TryToFindCallCounter(fragment2);
int64_t transpose_count_f1 = *TryToFindSingleTransposeCount(fragment1);
int64_t transpose_count_f2 = *TryToFindSingleTransposeCount(fragment2);

const int num_meshes = GetNumMeshes(fragment1);
const int mesh_id = GetMeshIndex(fragment1);
const int stage_id = fragment1.getStageIdAttr().getInt();

bool is_wgrad_f1 = IsSplitDropTransferred(fragment1);
bool is_wgrad_f2 = IsSplitDropTransferred(fragment2);

// How many fwd we are allowed to stream before entering steady state.
int init_fwd = init_fwd_per_mesh(mesh_id);
int init_fwd = init_fwd_per_stage(stage_id);
// The ZeroBubbleH2 pipeline is diagonally symmetric (replacing forward with
// backwards parameter gradient) so the following quantity is also part of the
// schedule invariants below.
int complement_init_fwd = init_fwd_per_mesh(num_meshes - mesh_id - 1);
int complement_init_fwd = init_fwd_per_stage(num_meshes - stage_id - 1);

// Initial phase.
// Clause 1: F(i) <= B(_) for i < init_fwd.
Expand Down Expand Up @@ -362,20 +379,22 @@ bool LatencyHidingZeroBubbleH2MustHappenBefore(float latency_stage_fraction,
// The `init_fwds_per_mesh` returns the e_i in the diagram above, for
// every mesh_i. This is the number of forward microbatches that can execute
// before the first backwards microbatch can be executed on this mesh.
auto init_fwds_per_mesh = [num_meshes, latency_stage_fraction](int mesh_id) {
auto init_fwds_per_stage = [num_meshes,
latency_stage_fraction](int stage_id) {
// The number of transfers from the beginning until the first backward
// fragment can execute on mesh_id, see the diagram above. We call this the
// fragment can execute on stage_id, see the diagram above. We call this the
// "initial" path of the first microbatch in the pipeline.
float num_init_transfers = 2.0f * (num_meshes - mesh_id - 1);
float num_init_transfers = 2.0f * (num_meshes - stage_id - 1);
// How much compute has happened in that initial first microbatch path, i.e.
// until the point where the first backward fragment can execute on mesh_id.
// The assumption that time(fwd) == time(bwd) (NB: this is just the backprop
// bwd) may need to be revisited for real use cases.
float num_init_compute = 2.0f * (num_meshes - mesh_id) - 1.0f;
// until the point where the first backward fragment can execute on
// stage_id. The assumption that time(fwd) == time(bwd) (NB: this is just
// the backprop bwd) may need to be revisited for real use cases.
float num_init_compute = 2.0f * (num_meshes - stage_id) - 1.0f;
return std::floor(num_init_compute +
num_init_transfers * latency_stage_fraction);
};
return ZeroBubbleH2MustHappenBefore(fragment1, fragment2, init_fwds_per_mesh);
return ZeroBubbleH2MustHappenBefore(fragment1, fragment2,
init_fwds_per_stage);
}

// Returns true if `fragment1` must happen before `fragment2` in a parallel
Expand Down Expand Up @@ -408,6 +427,11 @@ bool ParallelPipelinesWithWrapAroundMustHappenBefore(FragmentOp fragment1,
// This guarantees that the transpose counts and call counts of each fragment
// are defined.
SDY_CHECK(IsSchedulingUnit(fragment1) && IsSchedulingUnit(fragment2));
if (!fragment1.getStageIdAttr() || !fragment2.getStageIdAttr()) {
SDY_LOG(ERROR) << "Cannot schedule for ParallelPipelinesWithWrapAround "
"without stages.";
return false;
}
// Only allowed for forwards for now.
SDY_CHECK(IsForwardFragment(fragment1));
SDY_CHECK(IsForwardFragment(fragment2));
Expand All @@ -417,21 +441,18 @@ bool ParallelPipelinesWithWrapAroundMustHappenBefore(FragmentOp fragment1,
SDY_CHECK_NE(call_counter_f1, call_counter_f2)
<< "Should not have duplicate call counter.";

int64_t mesh_num = 0;
SDY_CHECK(llvm::to_integer(fragment1.getMeshName().drop_until(
[](char c) { return llvm::isDigit(c); }),
mesh_num));
// The entrypoint to mesh{i} is call_counter {i}, so this always happens
int64_t stage_id = fragment1.getStageIdAttr().getInt();
// The entrypoint to stage{i} is call_counter {i}, so this always happens
// before.
if (call_counter_f1 == mesh_num || call_counter_f2 == mesh_num) {
return call_counter_f1 == mesh_num;
if (call_counter_f1 == stage_id || call_counter_f2 == stage_id) {
return call_counter_f1 == stage_id;
}

// `mesh_num` is the pivot. If both call_counters are on the same side of
// `stage_id` is the pivot. If both call_counters are on the same side of
// the pivot, we flip the order. But if they are on different
// sides, then we take the order as per normal.
if ((call_counter_f1 > mesh_num && call_counter_f2 > mesh_num) ||
(call_counter_f1 < mesh_num && call_counter_f2 < mesh_num)) {
if ((call_counter_f1 > stage_id && call_counter_f2 > stage_id) ||
(call_counter_f1 < stage_id && call_counter_f2 < stage_id)) {
return call_counter_f1 > call_counter_f2;
}
return call_counter_f1 < call_counter_f2;
Expand Down
Loading