Commit d397aa6

Sample PR to fix CI (#182)
stack-info: PR: #182, branch: xmfan/stack/11
1 parent 01a6538 commit d397aa6

4 files changed: +28 −24 lines changed

.github/workflows/test_cuda.yml

Lines changed: 4 additions & 1 deletion
@@ -30,8 +30,11 @@ jobs:
       gpu-arch-type: ${{ matrix.gpu-arch-type }}
       gpu-arch-version: ${{ matrix.gpu-arch-version }}
       submodules: recursive
-      python-version: "3.12"
       script: |
+        conda create --yes --quiet --name py312 python=3.12
+        source $(conda info --base)/etc/profile.d/conda.sh
+        conda activate py312
+
         pip install --quiet -r requirements-test.txt
         # For some reason the spec above isnt working
         pip uninstall -y torch

.github/workflows/test_torchtitan.yml

Lines changed: 4 additions & 1 deletion
@@ -30,8 +30,11 @@ jobs:
       gpu-arch-type: ${{ matrix.gpu-arch-type }}
       gpu-arch-version: ${{ matrix.gpu-arch-version }}
       submodules: recursive
-      python-version: "3.12"
       script: |
+        conda create --yes --quiet --name py312 python=3.12
+        source $(conda info --base)/etc/profile.d/conda.sh
+        conda activate py312
+
         pip install --quiet -r requirements-test.txt
         # For some reason the spec above isnt working
         pip uninstall -y torch

autoparallel/optimize_sharding.py

Lines changed: 0 additions & 3 deletions
@@ -177,9 +177,6 @@ def build_sharding_metadata(self):
         assert (
             local_map_kwargs.get("in_grad_placements", None) is None
         ), "Not yet implemented"
-        assert (
-            local_map_kwargs.get("device_mesh", None) is None
-        ), "Must be provided by Autoparallel"
         assert not user_kwargs
         # TODO: get rid of this when HOP can install as a subgraph
         assert "call_local_map" in str(

examples/example_local_map.py

Lines changed: 20 additions & 19 deletions
@@ -15,6 +15,23 @@
 
 from autoparallel.api import AutoParallel
 
+world_size = 256
+
+fake_store = FakeStore()
+torch.distributed.init_process_group(
+    "fake", store=fake_store, rank=0, world_size=world_size
+)
+mesh = torch.distributed.device_mesh.init_device_mesh(
+    "cuda",
+    (world_size // 32, 8, 4),
+    mesh_dim_names=(
+        "dp",
+        "tp",
+        "cp",
+    ),
+)
+assert mesh.ndim == 3, "Please also update local_map"
+
 
 def policy_fn(ctx, op, *args, **kwargs):
     if (
@@ -37,7 +54,7 @@ def policy_fn(ctx, op, *args, **kwargs):
     ),
     redistribute_inputs=True,
     in_grad_placements=None,
-    device_mesh=None,
+    device_mesh=mesh,
 )
 def replicate_linear(w, x):
     return torch.matmul(x, w.t())
@@ -54,7 +71,7 @@ def replicate_linear(w, x):
     ),
     redistribute_inputs=True,
     in_grad_placements=None,
-    device_mesh=None,
+    device_mesh=mesh,
 )
 def sharded_pointwise(x, scalar):
     return x + scalar, scalar
@@ -69,7 +86,7 @@ def sharded_pointwise(x, scalar):
     ),
     redistribute_inputs=True,
     in_grad_placements=None,
-    device_mesh=None,
+    device_mesh=mesh,
 )
 def context_parallel_attention(query, key, value):
     out = nn.functional.scaled_dot_product_attention(
@@ -128,22 +145,6 @@ def forward(self, x):
         return o
 
 
-world_size = 256
-
-fake_store = FakeStore()
-torch.distributed.init_process_group(
-    "fake", store=fake_store, rank=0, world_size=world_size
-)
-mesh = torch.distributed.device_mesh.init_device_mesh(
-    "cuda",
-    (world_size // 32, 8, 4),
-    mesh_dim_names=(
-        "dp",
-        "tp",
-        "cp",
-    ),
-)
-
 bs = 8 * mesh.shape[0]
 seq_len = 256
 nheads = 48
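
The net effect of this example change is that the fake process group and the 3-D device mesh are built at module scope, before the local_map-decorated functions are defined, so each decorator can bind device_mesh=mesh directly instead of passing device_mesh=None (the assert in optimize_sharding.py that required device_mesh to be None is removed above). A condensed sketch of the resulting module-level setup follows; the FakeStore import path is an assumption, since the diff does not show the example's import block.

import torch
from torch.testing._internal.distributed.fake_pg import FakeStore  # assumed import path

world_size = 256

# Fake backend: lets the example build a 256-rank sharding plan in a
# single process, with no real GPUs or collectives involved.
fake_store = FakeStore()
torch.distributed.init_process_group(
    "fake", store=fake_store, rank=0, world_size=world_size
)

# 3-D mesh: (256 // 32) x 8 x 4 = 8 x 8 x 4 = 256 ranks, named dp / tp / cp.
mesh = torch.distributed.device_mesh.init_device_mesh(
    "cuda",
    (world_size // 32, 8, 4),
    mesh_dim_names=("dp", "tp", "cp"),
)
assert mesh.ndim == 3, "Please also update local_map"

# Because `mesh` now exists at import time, the local_map decorators defined
# after this point can reference it, as can downstream sizing such as:
bs = 8 * mesh.shape[0]  # batch dimension scales with the dp mesh size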
