Skip to content

Commit ec0fa31

Browse files
casteryh and facebook-github-bot
authored and committed
fix multiple tests. with latest monarch and rdma, ts.put and ts.get only work inside an actor (#75)
Summary: Pull Request resolved: #75 Reviewed By: LucasLLC, amirafzali Differential Revision: D86156904
1 parent a19c7cd commit ec0fa31

File tree

9 files changed

+139
-76
lines changed

9 files changed

+139
-76
lines changed

tests/test_keys.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import pytest
1212
import torch
13-
1413
import torchstore as ts
1514
from monarch.actor import Actor, current_rank, endpoint
1615
from pytest_unordered import unordered
@@ -21,24 +20,36 @@
2120
@pytest.mark.asyncio
2221
async def test_keys_basic():
2322
"""Test basic put/get functionality"""
23+
24+
class TestActor(Actor):
25+
@endpoint
26+
async def test(self) -> Exception or None:
27+
try:
28+
await ts.put("", torch.tensor([1, 2, 3]))
29+
await ts.put(".x", torch.tensor([1, 2, 3]))
30+
await ts.put("v0.x", torch.tensor([1, 2, 3]))
31+
await ts.put("v0.y", torch.tensor([4, 5, 6]))
32+
await ts.put("v0.x.z", torch.tensor([7, 8, 9]))
33+
await ts.put("v1.x", torch.tensor([7, 8, 9]))
34+
await ts.put("v1.y", torch.tensor([10, 11, 12]))
35+
36+
assert await ts.keys() == unordered(
37+
["", ".x", "v0.x", "v0.y", "v0.x.z", "v1.x", "v1.y"]
38+
)
39+
assert await ts.keys("v0") == unordered(["v0.x", "v0.y", "v0.x.z"])
40+
assert await ts.keys("v0.x") == unordered(["v0.x", "v0.x.z"])
41+
assert await ts.keys("v0.x.z") == unordered(["v0.x.z"])
42+
assert await ts.keys("") == unordered(["", ".x"])
43+
assert await ts.keys("v1") == unordered(["v1.x", "v1.y"])
44+
except Exception as e:
45+
return e
46+
2447
await ts.initialize()
2548

26-
await ts.put("", torch.tensor([1, 2, 3]))
27-
await ts.put(".x", torch.tensor([1, 2, 3]))
28-
await ts.put("v0.x", torch.tensor([1, 2, 3]))
29-
await ts.put("v0.y", torch.tensor([4, 5, 6]))
30-
await ts.put("v0.x.z", torch.tensor([7, 8, 9]))
31-
await ts.put("v1.x", torch.tensor([7, 8, 9]))
32-
await ts.put("v1.y", torch.tensor([10, 11, 12]))
49+
actor = await spawn_actors(1, TestActor, "actor_0")
50+
err = await actor.test.call_one()
3351

34-
assert await ts.keys() == unordered(
35-
["", ".x", "v0.x", "v0.y", "v0.x.z", "v1.x", "v1.y"]
36-
)
37-
assert await ts.keys("v0") == unordered(["v0.x", "v0.y", "v0.x.z"])
38-
assert await ts.keys("v0.x") == unordered(["v0.x", "v0.x.z"])
39-
assert await ts.keys("v0.x.z") == unordered(["v0.x.z"])
40-
assert await ts.keys("") == unordered(["", ".x"])
41-
assert await ts.keys("v1") == unordered(["v1.x", "v1.y"])
52+
assert err is None
4253

4354
await ts.shutdown()
4455

tests/test_large_tensors.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import pytest
1212
import torch
13-
1413
import torchstore as ts
1514
from monarch.actor import Actor, endpoint
1615
from torchstore.logging import init_logging

tests/test_models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import pytest
1414
import torch
15-
1615
import torchstore as ts
1716
from monarch.actor import Actor, current_rank, endpoint
1817
from torch.distributed.device_mesh import init_device_mesh

tests/test_resharding_basic.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,8 @@
1111
from typing import List, Tuple, Union
1212

1313
import pytest
14-
1514
import torch
16-
1715
import torchstore as ts
18-
1916
from torch.distributed._tensor import Replicate, Shard
2017
from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset
2118
from torchstore.utils import get_local_tensor, spawn_actors

tests/test_resharding_ext.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77
from logging import getLogger
88

99
import pytest
10-
1110
from torch.distributed._tensor import Shard
1211

1312
from .test_resharding_basic import _test_resharding
14-
1513
from .utils import main, transport_plus_strategy_params
1614

1715
logger = getLogger(__name__)

tests/test_state_dict.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,10 @@
1212
from typing import Union
1313

1414
import pytest
15-
1615
import torch
1716
import torch.distributed.checkpoint as dcp
1817
import torch.nn as nn
19-
2018
import torchstore as ts
21-
2219
from monarch.actor import Actor, current_rank, endpoint
2320
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
2421
from torch.distributed.checkpoint.state_dict import (
@@ -279,7 +276,6 @@ def _assert_equal_state_dict(state_dict1, state_dict2):
279276
flattened_state_dict_2
280277
), f"{flattened_state_dict_1.keys()=}\n{flattened_state_dict_2.keys()=}"
281278
for key in flattened_state_dict_1:
282-
283279
assert key in flattened_state_dict_2
284280
if isinstance(flattened_state_dict_1[key], torch.Tensor):
285281
t1, t2 = flattened_state_dict_1[key], flattened_state_dict_2[key]

tests/test_store.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,9 @@
88
from logging import getLogger
99

1010
import pytest
11-
1211
import torch
13-
1412
import torchstore as ts
15-
1613
from monarch.actor import Actor, current_rank, endpoint
17-
1814
from torchstore.logging import init_logging
1915
from torchstore.utils import spawn_actors
2016

@@ -303,19 +299,31 @@ async def get(self, key):
303299
@pytest.mark.asyncio
304300
async def test_key_miss():
305301
"""Test the behavior of get() when the key is missing."""
306-
await ts.initialize()
307302

308-
key = "foo"
309-
value = torch.tensor([1, 2, 3])
310-
await ts.put(key, value)
303+
class TestActor(Actor):
304+
@endpoint
305+
async def test(self) -> Exception or None:
306+
try:
307+
key = "foo"
308+
value = torch.tensor([1, 2, 3])
309+
await ts.put(key, value)
310+
311+
# Get the value back
312+
retrieved_value = await ts.get(key)
313+
assert torch.equal(value, retrieved_value)
314+
315+
# Get a missing key
316+
with pytest.raises(KeyError):
317+
await ts.get("bar")
318+
except Exception as e:
319+
return e
320+
321+
await ts.initialize()
311322

312-
# Get the value back
313-
retrieved_value = await ts.get(key)
314-
assert torch.equal(value, retrieved_value)
323+
actor = await spawn_actors(1, TestActor, "actor_0")
324+
err = await actor.test.call_one()
315325

316-
# Get a missing key
317-
with pytest.raises(KeyError):
318-
await ts.get("bar")
326+
assert err is None
319327

320328
await ts.shutdown()
321329

0 commit comments

Comments
 (0)