Skip to content

Commit 9dd276c

Browse files
authored
Remove manual monarch installation from CI (#72)
* Remove manual monarch installation from CI

  torchmonarch is already declared in pyproject.toml dependencies, so manual wheel installation is redundant. Let pip install it automatically like the forge repo does.

* Remove manual torch installation from CI

  torch==2.9.0 is already declared in pyproject.toml dependencies, so manual installation is redundant.

* Adapt CI workflow to match torchforge pattern

  - Use conda-incubator/setup-miniconda for environment setup
  - Use uv package manager for installation
  - Add uv configuration with PyTorch index to pyproject.toml
  - Add concurrency control and shell defaults

  This matches the torchforge workflow and should resolve linking issues.

* Add repository owner check to unit test workflow

* Add workflow_dispatch to allow manual workflow runs

* Use GPU runner to match forge workflow

* Comment out _proc_mesh.stop() calls due to monarch bug

  - Comment out all _proc_mesh.stop() calls in test files
  - Add TODO comments to investigate monarch bug
  - Tests are failing due to these calls
  - Files updated: test_store.py, test_tensor_slice.py, test_state_dict.py, test_large_tensors.py, test_resharding_basic.py
1 parent 5c20853 commit 9dd276c

File tree

8 files changed

+56
-28
lines changed

8 files changed

+56
-28
lines changed

.github/workflows/unit_test.yaml

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,38 @@ name: Unit Test
22

33
on:
44
pull_request:
5+
workflow_dispatch:
56

7+
concurrency:
8+
group: unit-test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
9+
cancel-in-progress: true
10+
11+
defaults:
12+
run:
13+
shell: bash -l -eo pipefail {0}
614

715
jobs:
816
unit_tests:
9-
runs-on: ubuntu-latest
17+
if: github.repository_owner == 'meta-pytorch'
18+
runs-on: linux.g5.12xlarge.nvidia.gpu
1019
timeout-minutes: 60
1120
strategy:
1221
matrix:
1322
python-version: ['3.10']
1423
steps:
1524
- name: Check out repo
1625
uses: actions/checkout@v4
17-
- name: Setup Python
18-
uses: actions/setup-python@v4
26+
- name: Setup conda env
27+
uses: conda-incubator/setup-miniconda@v2
1928
with:
29+
auto-update-conda: true
30+
miniconda-version: "latest"
31+
activate-environment: test
2032
python-version: ${{ matrix.python-version }}
21-
- name: Verify Python version
22-
run: |
23-
python --version
24-
python -c "import sys; print(f'Python version: {sys.version_info.major}.{sys.version_info.minor}')"
2533
- name: Update pip
2634
run: python -m pip install --upgrade pip
27-
- name: Install torch
28-
run: python -m pip install torch
29-
- name: Install dependencies
30-
run: python -m pip install -e ".[dev]"
31-
- name: Install monarch from local wheel
32-
run: python -m pip install assets/ci/monarch-0.0.1-cp310-cp310-linux_x86_64.whl --force-reinstall
35+
- name: Install torchstore
36+
run: pip install uv && uv pip install . && uv pip install .[dev]
3337
- name: Run slice tests (test_slice.py) with coverage
3438
run: |
3539
TORCHSTORE_RDMA_ENABLED=0 \
-23.7 MB
Binary file not shown.

pyproject.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,15 @@ dev = [
4646
[tool.setuptools.packages.find]
4747
include = ["torchstore*"]
4848
exclude = ["assets*", "tests*", "docs*"]
49+
50+
# pytorch
51+
[[tool.uv.index]]
52+
name = "pytorch-cu128"
53+
url = "https://download.pytorch.org/whl/cu128"
54+
55+
[tool.uv.sources]
56+
torch = { index = "pytorch-cu128" }
57+
58+
[tool.uv]
59+
index-strategy = "unsafe-best-match"
60+
prerelease = "allow"

tests/test_large_tensors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ async def get(self):
9797
await actor.get.call_one()
9898
# TODO: assert equal tensors from put/get
9999
finally:
100-
await actor._proc_mesh.stop()
100+
# TODO: Investigate monarch bug with proc_mesh.stop()
101+
# await actor._proc_mesh.stop()
101102
await ts.shutdown()
102103

103104

tests/test_resharding_basic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,11 @@ async def _test_resharding(
247247

248248
# teardown distributed or the next test will complain
249249
await put_mesh.destroy_process_group.call()
250-
await put_mesh._proc_mesh.stop()
250+
# TODO: Investigate monarch bug with proc_mesh.stop()
251+
# await put_mesh._proc_mesh.stop()
251252
await get_mesh.destroy_process_group.call()
252-
await get_mesh._proc_mesh.stop()
253+
# TODO: Investigate monarch bug with proc_mesh.stop()
254+
# await get_mesh._proc_mesh.stop()
253255
await ts.shutdown()
254256

255257

tests/test_state_dict.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,9 @@ async def test_dcp_sharding_parity(strategy_params, use_rdma):
265265
f"Assertion failed on rank {coord.rank} ({save_mesh_shape=} {get_mesh_shape=}): {e}"
266266
) from e
267267
finally:
268-
await save_world._proc_mesh.stop()
269-
await get_world._proc_mesh.stop()
268+
# TODO: Investigate monarch bug with proc_mesh.stop()
269+
# await save_world._proc_mesh.stop()
270+
# await get_world._proc_mesh.stop()
270271
await ts.shutdown()
271272

272273

tests/test_store.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,9 @@ async def get(self, rank_offset=0):
7979
expected = torch.tensor([other_rank + 1] * 10)
8080
assert torch.equal(expected, val), f"{expected} != {val}"
8181
finally:
82-
await actor_mesh_0._proc_mesh.stop()
83-
await actor_mesh_1._proc_mesh.stop()
82+
# TODO: Investigate monarch bug with proc_mesh.stop()
83+
# await actor_mesh_0._proc_mesh.stop()
84+
# await actor_mesh_1._proc_mesh.stop()
8485
await ts.shutdown()
8586

8687

@@ -142,8 +143,9 @@ def __eq__(self, other: object) -> bool:
142143
assert expected == val, f"{expected.val} != {val.val}"
143144

144145
finally:
145-
await actor_mesh_0._proc_mesh.stop()
146-
await actor_mesh_1._proc_mesh.stop()
146+
# TODO: Investigate monarch bug with proc_mesh.stop()
147+
# await actor_mesh_0._proc_mesh.stop()
148+
# await actor_mesh_1._proc_mesh.stop()
147149
await ts.shutdown()
148150

149151

@@ -211,7 +213,8 @@ async def exists(self, key):
211213
assert exists_result
212214

213215
finally:
214-
await actor_mesh._proc_mesh.stop()
216+
# TODO: Investigate monarch bug with proc_mesh.stop()
217+
# await actor_mesh._proc_mesh.stop()
215218
await ts.shutdown()
216219

217220

@@ -292,7 +295,8 @@ async def get(self, key):
292295
await actor_mesh.get.call("tensor_key_0")
293296

294297
finally:
295-
await actor_mesh._proc_mesh.stop()
298+
# TODO: Investigate monarch bug with proc_mesh.stop()
299+
# await actor_mesh._proc_mesh.stop()
296300
await ts.shutdown()
297301

298302

tests/test_tensor_slice.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ async def put(self, key, tensor):
8181
assert tensor_slice.shape == (5, 10)
8282

8383
finally:
84-
await put_actor_mesh._proc_mesh.stop()
84+
# TODO: Investigate monarch bug with proc_mesh.stop()
85+
# await put_actor_mesh._proc_mesh.stop()
8586
await ts.shutdown()
8687

8788

@@ -147,7 +148,8 @@ async def test_put_dtensor_get_full_tensor():
147148
finally:
148149
# Clean up process groups
149150
await put_mesh.destroy_process_group.call()
150-
await put_mesh._proc_mesh.stop()
151+
# TODO: Investigate monarch bug with proc_mesh.stop()
152+
# await put_mesh._proc_mesh.stop()
151153
await ts.shutdown()
152154

153155

@@ -232,7 +234,8 @@ async def test_dtensor_fetch_slice():
232234
finally:
233235
if put_mesh is not None:
234236
await put_mesh.destroy_process_group.call()
235-
await put_mesh._proc_mesh.stop()
237+
# TODO: Investigate monarch bug with proc_mesh.stop()
238+
# await put_mesh._proc_mesh.stop()
236239
await ts.shutdown()
237240

238241

@@ -280,7 +283,8 @@ async def test_partial_put():
280283
finally:
281284
# Clean up process groups
282285
await put_mesh.destroy_process_group.call()
283-
await put_mesh._proc_mesh.stop()
286+
# TODO: Investigate monarch bug with proc_mesh.stop()
287+
# await put_mesh._proc_mesh.stop()
284288
await ts.shutdown()
285289

286290

0 commit comments

Comments
 (0)