
Commit 6f504c5

improved shard doc

1 parent 03ffed9

1 file changed: +77 -64 lines changed

mlir/docs/Dialects/Shard.md

Lines changed: 77 additions & 64 deletions

# 'shard' Dialect

The 'shard' dialect defines a set of attributes, operations, and interfaces for
working with tensor sharding and device communication.

It’s inspired by [GSPMD](https://arxiv.org/abs/2105.04663) (*General and
Scalable Parallelization for ML Computation Graphs*).

Originally, the dialect was called `mesh`, but it was renamed to better reflect
what it actually does.

[TOC]

## Collective Communication Operations

The 'shard' dialect includes several collective operations that help coordinate
communication between devices arranged in a grid.

If you’re not already familiar with collective operations, [this Wikipedia
article](https://en.wikipedia.org/wiki/Collective_operation) is a good starting
point.

Unlike traditional collectives, which are defined in terms of message passing
between explicit buffers on each process, the collectives in this dialect work
at a higher level. They’re defined in terms of how data moves across the
dimensions of a tensor, and the participating processes are inferred from how
the tensor is sharded, not specified manually.

### Device Groups

Each collective operation runs within a group of devices. You define groups
using the `grid` and `grid_axes` attributes, which describe how to slice the
full device grid into smaller groups.

Devices that have the same coordinates *outside* the listed `grid_axes` belong
to the same group.

Example: Say your device grid is shaped `2×3×4×5`, and you set
`grid_axes = [0, 1]`. This splits the grid into groups by fixing axes 2 and 3.
You’d get groups like:

```
{ { (i, j, k, m) | 0 ≤ i < 2, 0 ≤ j < 3 } | 0 ≤ k < 4, 0 ≤ m < 5 }
```

So the groups are identified by the coordinates `(k, m)`, and devices like
`(1, 0, 2, 3)` and `(1, 1, 2, 3)` are in the same group. But `(1, 0, 2, 4)`
is in a different group.
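
For illustration, here’s a minimal Python sketch of this grouping rule (a
hypothetical helper, not part of the dialect): a device’s group is just its
coordinates on the axes *outside* `grid_axes`.

```python
# Hypothetical helper: the group a device belongs to is its multi-index
# over the axes NOT listed in grid_axes.
def group_of(device, grid_axes):
    return tuple(c for axis, c in enumerate(device) if axis not in grid_axes)

# 2x3x4x5 grid with grid_axes = [0, 1]: groups are identified by (k, m).
assert group_of((1, 0, 2, 3), grid_axes=[0, 1]) == (2, 3)
assert group_of((1, 1, 2, 3), grid_axes=[0, 1]) == (2, 3)  # same group
assert group_of((1, 0, 2, 4), grid_axes=[0, 1]) == (2, 4)  # different group
```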

For some collectives (like `all-to-all`), the order of devices in the group
matters. The device order follows the order of axes in `grid_axes`, from
outermost to innermost.

Example: If `grid_axes = [3, 1]`, then device `(i, 1, k, 0)` comes before both
`(i, 0, k, 1)` and `(i, 2, k, 0)`.
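
Continuing the sketch above (again a hypothetical helper, not dialect API), the
in-group order can be modeled as a lexicographic sort key built from the
`grid_axes` coordinates, outermost axis first:

```python
# Hypothetical helper: in-group order = lexicographic order of the device's
# coordinates on grid_axes, listed from outermost to innermost.
def in_group_key(device, grid_axes):
    return tuple(device[axis] for axis in grid_axes)

# grid_axes = [3, 1]: devices compare by (coord on axis 3, coord on axis 1).
i, k = 0, 0
assert in_group_key((i, 1, k, 0), [3, 1]) < in_group_key((i, 0, k, 1), [3, 1])
assert in_group_key((i, 1, k, 0), [3, 1]) < in_group_key((i, 2, k, 0), [3, 1])
```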

### In-group Devices

Some operations (like `broadcast`, `scatter`, and `send`) refer to a specific
device within each group. These in-group devices are identified by their
coordinates over the axes listed in `grid_axes`.

Example: In a 3D grid with `grid_axes = [0, 2]`, an in-group device is specified
as `(i, j)`. If a group is fixed at coordinate `g` on axis 1, then the full
device index would be `(i, g, j)`.
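
As a sketch (hypothetical helper, same assumptions as above), the full device
index can be rebuilt by interleaving the group’s coordinates with the in-group
coordinates:

```python
# Hypothetical helper: combine a group index (axes outside grid_axes) with an
# in-group index (axes in grid_axes) into a full device index.
def full_index(group, in_group, grid_axes, ndim):
    device = [None] * ndim
    for axis, coord in zip(grid_axes, in_group):
        device[axis] = coord
    remaining = iter(group)
    return tuple(c if c is not None else next(remaining) for c in device)

# 3D grid, grid_axes = [0, 2]: in-group device (i, j) in group g is (i, g, j).
assert full_index(group=(4,), in_group=(0, 2), grid_axes=[0, 2], ndim=3) == (0, 4, 2)
```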

### Purity and Execution Model

Collective operations involve all devices in a group (e.g. `all-gather`,
`all-to-all`) and are considered pure. Operations like `send` and `recv` are not
collective and are not pure.

The execution model assumes SPMD (Single Program, Multiple Data):

* Every process runs the same program.
* At any collective operation, all processes are in a coherent state.

This means compiler transformations must treat collective ops consistently. For
example, if a collective is removed during optimization, it must be removed from
*every* execution path on *every* process that would have participated;
otherwise, you’ll get undefined behavior at runtime.

Marking these ops as pure also keeps them compatible with standard compiler
passes like dead code elimination and common subexpression elimination, and it
ensures that when the program is executed, every process reaches the same pure
collective operation at the same point in the program, avoiding deadlocks.
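
To see why this consistency matters, here’s a small runnable Python simulation
(an illustration, not dialect code): a barrier stands in for a blocking
collective, and one "process" has it optimized away on only its own path.

```python
import threading

barrier = threading.Barrier(2)  # stands in for a 2-process collective

def process(rank, collective_eliminated):
    if collective_eliminated:
        return  # the collective was removed on this process only
    try:
        barrier.wait(timeout=1)  # blocks until every group member arrives
    except threading.BrokenBarrierError:
        print(f"process {rank}: stuck - its peer never reached the collective")

# An inconsistent transformation: the op was eliminated for process 1 only.
threads = [threading.Thread(target=process, args=(0, False)),
           threading.Thread(target=process, args=(1, True))]
for t in threads: t.start()
for t in threads: t.join()
```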

## Operations
