|
1 | 1 | # 'shard' Dialect
|
2 | 2 |
|
3 |
| -This dialect contains a set of attributes, operations and interfaces that |
4 |
| -are useful for representing sharding of tensors and communication between |
5 |
| -devices. |
| 3 | +The 'shard' dialect defines a set of attributes, operations, and interfaces for |
| 4 | +working with tensor sharding and device communication. |
6 | 5 |
|
7 |
| -The Shard dialect was inspired by GSPMD (GSPMD: General and Scalable |
8 |
| -Parallelization for ML Computation Graphs). |
| 6 | +It’s inspired by [GSPMD](https://arxiv.org/abs/2205.05559) (*General and |
| 7 | +Scalable Parallelization for ML Computation Graphs*). |
9 | 8 |
|
10 |
| -It was originally introduced under the name 'mesh' but was later renamed |
11 |
| -to better reflect its purpose. |
| 9 | +Originally, the dialect was called `mesh`, but it was renamed to better reflect |
| 10 | +what it actually does. |
12 | 11 |
|
13 | 12 | [TOC]
|
14 | 13 |
|
15 | 14 | ## Collective Communication Operations
|
16 |
| -There are a number of operations in the Shard dialect to facilitate |
17 |
| -communication between devices in a grid. |
18 |
| -It is assumed that the user is familiar with collective operations. |
19 |
| -[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good |
20 |
| -explanation. |
21 |
| -The main addition is that the collectives in this dialect have grid |
22 |
| -semantics. |
23 |
| - |
24 |
| -### Device groups |
25 |
| -The operation attributes `grid` and `grid_axes` specifies a list of device grid |
26 |
| -axes that partition the devices into disjoint groups. |
27 |
| -The collective operation is performed between devices in the same group. |
28 |
| -Devices that have the same coordinates outside of axes `grid_axes` are in the |
29 |
| -same group. |
30 |
| -A group is described by its multi-index along the axes outside of `grid_axes`. |
31 |
| -For example if we have a device grid of size `2x3x4x5` and the partition grid |
32 |
| -axes list is `[0, 1]` then devices are partitioned into the groups |
33 |
| -`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`. |
34 |
| -The device groups would be `{ (k, m) | 0<=k<4, 0<=m<5 }`. |
35 |
| -Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group. |
36 |
| -Device (1, 0, 2, 4) will be in another group. |
37 |
| -Some collective operations like all-to-all and all-gather care about the |
38 |
| -order of devices. |
39 |
| -The order of device in a device group is induced by the order of axes in |
40 |
| -`grid_axes`. |
41 |
| -The axes are ordered from outer to inner. |
42 |
| -If we have an axis list `[3, 1]` then device `(i, 1, k, 0)` will precede |
43 |
| -both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`. |
44 |
| - |
45 |
| -### In-group Device |
46 |
| -Some operations like `broadcast`, `scatter` and `send` specify devices in each |
47 |
| -device-group. |
48 |
| -These devices are represented with their multi-index over the grid axes that |
49 |
| -are not constant within a device group. |
50 |
| -These are the axes specified by `grid_axes` attribute. |
51 |
| - |
52 |
| -For Example on a 3D grid an operation with `grid_axes = [0, 2]` would specify |
53 |
| -an in-group device with `(i, j)`. Then for each group with index `g` on the |
54 |
| -second axis, the in-group device would be `(i, g, j)`. |
55 |
| -### Purity |
56 |
| -Collectives that involve the whole device group to perform a single operation |
57 |
| -are pure. The exceptions are `send` and `recv`. |
58 |
| - |
59 |
| -There is an assumption that the execution is SPMD. |
60 |
| -Not only that each process runs the same program, but that at the point of |
61 |
| -execution of a collective operation, all processes are in a coherent state. |
62 |
| -All compiler transformations must be consistent. |
63 |
| -Collective operations in the IR that may correspond to the same runtime |
64 |
| -collective operation must be transformed in a consistent manner. |
65 |
| -For example if a collective operation is optimized out, than it must also |
66 |
| -not appear in any path of execution on any process. |
67 |
| - |
68 |
| -Having the operations as `Pure` implies that if an interpreter is to execute |
69 |
| -the IR containing the `grid` collectives, all processes would execute the same |
70 |
| -line when they reach a pure collective operation. |
71 |
| -This requirement stems from the need to be compatible with general optimization |
72 |
| -passes like dead code and common sub-expression elimination. |
| 15 | + |
| 16 | +The 'shard' dialect includes several collective operations that help coordinate |
| 17 | +communication between devices arranged in a grid. |
| 18 | + |
| 19 | +If you’re not already familiar with collective operations, [this Wikipedia |
| 20 | +article](https://en.wikipedia.org/wiki/Collective_operation) is a good starting |
| 21 | +point. |
| 22 | + |
| 23 | +Unlike traditional collectives that are defined in terms of message-passing |
| 24 | +between explicit buffers on each process, the collectives in this dialect work |
| 25 | +at a higher level. They’re defined in terms of how data moves across the |
| 26 | +dimensions of a tensor, and the participating processes are inferred from how |
| 27 | +the tensor is sharded - not specified manually. |
| 28 | + |
| 29 | +### Device Groups |
| 30 | + |
| 31 | +Each collective operation runs within a group of devices. You define groups |
| 32 | +using the `grid` and `grid_axes` attributes, which describe how to slice the |
| 33 | +full device grid into smaller groups. |
| 34 | + |
| 35 | +Devices that have the same coordinates *outside* the listed `grid_axes` belong |
| 36 | +to the same group. |
| 37 | + |
| 38 | +Example: Say your device grid is shaped `2×3×4×5`, and you set |
| 39 | +`grid_axes = [0, 1]`. This splits the grid into groups by fixing axes 2 and 3. You’d get groups like: |
| 40 | + |
| 41 | +``` |
| 42 | +{ { (i, j, k, m) | 0 ≤ i < 2, 0 ≤ j < 3 } | 0 ≤ k < 4, 0 ≤ m < 5 } |
| 43 | +``` |
| 44 | + |
| 45 | +So the groups are identified by the coordinates `(k, m)`, and devices like |
| 46 | +`(1, 0, 2, 3)` and `(1, 1, 2, 3)` are in the same group. But `(1, 0, 2, 4)` |
| 47 | +is in a different group. |
| 48 | + |
| 49 | +For some collectives (like `all-to-all`), the order of devices in the group |
| 50 | +matters. The device order is based on the order of axes in `grid_axes`, from |
| 51 | +outermost to innermost. |
| 52 | + |
| 53 | +Example: If `grid_axes = [3, 1]`, then device `(i, 1, k, 0)` comes before |
| 54 | +`(i, 0, k, 1)` and `(i, 2, k, 0)`. |
| 55 | + |
| 56 | +### In-group Devices |
| 57 | + |
| 58 | +Some operations (like `broadcast`, `scatter`, and `send`) refer to a specific |
| 59 | +device within each group. These in-group devices are identified using their |
| 60 | +coordinates over the axes listed in `grid_axes`. |
| 61 | + |
| 62 | +Example: In a 3D grid with `grid_axes = [0, 2]`, an in-group device is specified |
| 63 | +as `(i, j)`. If a group is fixed at coordinate `g` on axis 1, then the full |
| 64 | +device index would be `(i, g, j)`. |
| 65 | + |
| 66 | +### Purity and Execution Model |
| 67 | + |
| 68 | +Collective operations involve all devices in a group (e.g. `all-gather`, |
| 69 | +`all-to-all`) and are considered pure. Operations like `send` and `recv` are not |
| 70 | +collective and are not pure. |
| 71 | + |
| 72 | +The execution model assumes SPMD (Single Program, Multiple Data): |
| 73 | + |
| 74 | +* Every process runs the same program. |
| 75 | +* At any collective operation, all processes are in sync. |
| 76 | + |
| 77 | +This means compiler optimizations must treat collective ops carefully. For |
| 78 | +example, if a collective is removed during optimization, it must be removed from |
| 79 | +*every* path and *every* process that would have participated - otherwise, you’ll |
| 80 | +get undefined behavior at runtime. |
| 81 | + |
| 82 | +Marking these ops as pure also helps with standard compiler passes like dead |
| 83 | +code elimination and common subexpression elimination. It ensures that when the |
| 84 | +program is executed, all devices hit the same line of code at the same time |
| 85 | +during collectives and so avoid dead-locks. |
73 | 86 |
|
74 | 87 | ## Operations
|
75 | 88 |
|
|
0 commit comments