Skip to content

Commit 64e2126

Browse files
casteryhfacebook-github-bot
authored andcommitted
fix multiple tests. with latest monarch and rdma, ts.put and ts.get only works inside an actor
Differential Revision: D86156904
1 parent 3a5f18e commit 64e2126

File tree

3 files changed

+129
-57
lines changed

3 files changed

+129
-57
lines changed

tests/test_keys.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,36 @@
2121
@pytest.mark.asyncio
2222
async def test_keys_basic():
2323
"""Test basic put/get functionality"""
24+
class TestActor(Actor):
25+
@endpoint
26+
async def test(self) -> Exception or None:
27+
try:
28+
await ts.put("", torch.tensor([1, 2, 3]))
29+
await ts.put(".x", torch.tensor([1, 2, 3]))
30+
await ts.put("v0.x", torch.tensor([1, 2, 3]))
31+
await ts.put("v0.y", torch.tensor([4, 5, 6]))
32+
await ts.put("v0.x.z", torch.tensor([7, 8, 9]))
33+
await ts.put("v1.x", torch.tensor([7, 8, 9]))
34+
await ts.put("v1.y", torch.tensor([10, 11, 12]))
35+
36+
assert await ts.keys() == unordered(
37+
["", ".x", "v0.x", "v0.y", "v0.x.z", "v1.x", "v1.y"]
38+
)
39+
assert await ts.keys("v0") == unordered(["v0.x", "v0.y", "v0.x.z"])
40+
assert await ts.keys("v0.x") == unordered(["v0.x", "v0.x.z"])
41+
assert await ts.keys("v0.x.z") == unordered(["v0.x.z"])
42+
assert await ts.keys("") == unordered(["", ".x"])
43+
assert await ts.keys("v1") == unordered(["v1.x", "v1.y"])
44+
except Exception as e:
45+
return e
46+
47+
2448
await ts.initialize()
2549

26-
await ts.put("", torch.tensor([1, 2, 3]))
27-
await ts.put(".x", torch.tensor([1, 2, 3]))
28-
await ts.put("v0.x", torch.tensor([1, 2, 3]))
29-
await ts.put("v0.y", torch.tensor([4, 5, 6]))
30-
await ts.put("v0.x.z", torch.tensor([7, 8, 9]))
31-
await ts.put("v1.x", torch.tensor([7, 8, 9]))
32-
await ts.put("v1.y", torch.tensor([10, 11, 12]))
50+
actor = await spawn_actors(1, TestActor, "actor_0")
51+
err = await actor.test.call_one()
3352

34-
assert await ts.keys() == unordered(
35-
["", ".x", "v0.x", "v0.y", "v0.x.z", "v1.x", "v1.y"]
36-
)
37-
assert await ts.keys("v0") == unordered(["v0.x", "v0.y", "v0.x.z"])
38-
assert await ts.keys("v0.x") == unordered(["v0.x", "v0.x.z"])
39-
assert await ts.keys("v0.x.z") == unordered(["v0.x.z"])
40-
assert await ts.keys("") == unordered(["", ".x"])
41-
assert await ts.keys("v1") == unordered(["v1.x", "v1.y"])
53+
assert err is None
4254

4355
await ts.shutdown()
4456

tests/test_store.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -303,19 +303,30 @@ async def get(self, key):
303303
@pytest.mark.asyncio
304304
async def test_key_miss():
305305
"""Test the behavior of get() when the key is missing."""
306-
await ts.initialize()
306+
class TestActor(Actor):
307+
@endpoint
308+
async def test(self) -> Exception or None:
309+
try:
310+
key = "foo"
311+
value = torch.tensor([1, 2, 3])
312+
await ts.put(key, value)
313+
314+
# Get the value back
315+
retrieved_value = await ts.get(key)
316+
assert torch.equal(value, retrieved_value)
317+
318+
# Get a missing key
319+
with pytest.raises(KeyError):
320+
await ts.get("bar")
321+
except Exception as e:
322+
return e
307323

308-
key = "foo"
309-
value = torch.tensor([1, 2, 3])
310-
await ts.put(key, value)
324+
await ts.initialize()
311325

312-
# Get the value back
313-
retrieved_value = await ts.get(key)
314-
assert torch.equal(value, retrieved_value)
326+
actor = await spawn_actors(1, TestActor, "actor_0")
327+
err = await actor.test.call_one()
315328

316-
# Get a missing key
317-
with pytest.raises(KeyError):
318-
await ts.get("bar")
329+
assert err is None
319330

320331
await ts.shutdown()
321332

tests/test_tensor_slice.py

Lines changed: 81 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ def __init__(self, world_size):
4242
async def put(self, key, tensor):
4343
await ts.put(key, tensor)
4444

45+
class TensorSliceGetActor(Actor):
46+
"""Actor for getting tensors."""
47+
48+
@endpoint
49+
async def get(self, key, tensor_slice_spec=None):
50+
return await ts.get(key, tensor_slice_spec=tensor_slice_spec)
51+
4552
volume_world_size, strategy = strategy_params
4653
await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy)
4754

@@ -53,6 +60,8 @@ async def put(self, key, tensor):
5360
world_size=volume_world_size,
5461
)
5562

63+
get_actor = await spawn_actors(1, TensorSliceGetActor, "tensor_slice_get_actor")
64+
5665
try:
5766
# Create a 100x100 tensor filled with sequential values 0-9999
5867
test_tensor = torch.arange(10000).reshape(100, 100).float()
@@ -63,7 +72,7 @@ async def put(self, key, tensor):
6372
await put_actor.put.call(key, test_tensor)
6473

6574
# Test full tensor retrieval using get actor mesh
66-
retrieved_tensor = await ts.get(key)
75+
retrieved_tensor = await get_actor.get.call_one(key)
6776
assert torch.equal(test_tensor, retrieved_tensor)
6877

6978
# Test slice retrieval using get actor mesh
@@ -75,7 +84,7 @@ async def put(self, key, tensor):
7584
mesh_shape=(),
7685
)
7786

78-
tensor_slice = await ts.get(key, tensor_slice_spec=tensor_slice_spec)
87+
tensor_slice = await get_actor.get.call_one(key, tensor_slice_spec=tensor_slice_spec)
7988
expected_slice = test_tensor[10:15, 20:30]
8089
assert torch.equal(tensor_slice, expected_slice)
8190
assert tensor_slice.shape == (5, 10)
@@ -89,32 +98,43 @@ async def put(self, key, tensor):
8998
@pytest.mark.asyncio
9099
async def test_tensor_slice_inplace():
91100
"""Test tensor slice API with in-place operations"""
101+
class TestActor(Actor):
102+
@endpoint
103+
async def test(self, test_tensor) -> Exception or None:
104+
try:
105+
# Store a test tensor
106+
await ts.put("inplace_test", test_tensor)
107+
108+
# Test in-place retrieval with slice
109+
slice_spec = TensorSlice(
110+
offsets=(10, 20),
111+
coordinates=(),
112+
global_shape=(100, 200),
113+
local_shape=(30, 40),
114+
mesh_shape=(),
115+
)
116+
117+
# Create pre-allocated buffer
118+
slice_buffer = torch.empty(30, 40)
119+
result = await ts.get(
120+
"inplace_test", inplace_tensor=slice_buffer, tensor_slice_spec=slice_spec
121+
)
122+
123+
# Verify in-place operation
124+
assert result is slice_buffer
125+
expected_slice = test_tensor[10:40, 20:60]
126+
assert torch.equal(slice_buffer, expected_slice)
127+
except Exception as e:
128+
return e
129+
92130
await ts.initialize(num_storage_volumes=1)
93131

94132
try:
95-
# Store a test tensor
96133
test_tensor = torch.randn(100, 200)
97-
await ts.put("inplace_test", test_tensor)
98-
99-
# Test in-place retrieval with slice
100-
slice_spec = TensorSlice(
101-
offsets=(10, 20),
102-
coordinates=(),
103-
global_shape=(100, 200),
104-
local_shape=(30, 40),
105-
mesh_shape=(),
106-
)
107-
108-
# Create pre-allocated buffer
109-
slice_buffer = torch.empty(30, 40)
110-
result = await ts.get(
111-
"inplace_test", inplace_tensor=slice_buffer, tensor_slice_spec=slice_spec
112-
)
134+
actor = await spawn_actors(1, TestActor, "actor_0")
135+
err = await actor.test.call_one(test_tensor)
113136

114-
# Verify in-place operation
115-
assert result is slice_buffer
116-
expected_slice = test_tensor[10:40, 20:60]
117-
assert torch.equal(slice_buffer, expected_slice)
137+
assert err is None
118138

119139
finally:
120140
await ts.shutdown()
@@ -123,6 +143,11 @@ async def test_tensor_slice_inplace():
123143
@pytest.mark.asyncio
124144
async def test_put_dtensor_get_full_tensor():
125145
"""Test basic DTensor put/get functionality with separate put and get meshes using shared DTensorActor"""
146+
class GetActor(Actor):
147+
@endpoint
148+
async def get_tensor(self, key):
149+
return await ts.get(key)
150+
126151
await ts.initialize(num_storage_volumes=2, strategy=ts.LocalRankStrategy())
127152

128153
original_tensor = torch.arange(16).reshape(4, 4).float()
@@ -142,7 +167,8 @@ async def test_put_dtensor_get_full_tensor():
142167

143168
await put_mesh.do_put.call()
144169

145-
fetched_tensor = await ts.get("test_key")
170+
get_actor = await spawn_actors(1, GetActor, "get_actor_0")
171+
fetched_tensor = await get_actor.get_tensor.call_one("test_key")
146172
assert torch.equal(original_tensor, fetched_tensor)
147173

148174
finally:
@@ -167,6 +193,11 @@ async def test_dtensor_fetch_slice():
167193
"""
168194
import tempfile
169195

196+
class GetActor(Actor):
197+
@endpoint
198+
async def get_tensor(self, key, tensor_slice_spec=None):
199+
return await ts.get(key, tensor_slice_spec=tensor_slice_spec)
200+
170201
# Use LocalRankStrategy with 2 storage volumes (no RDMA, no parametrization)
171202
os.environ["TORCHSTORE_RDMA_ENABLED"] = "0"
172203
os.environ["LOCAL_RANK"] = "0" # Required by LocalRankStrategy
@@ -195,6 +226,8 @@ async def test_dtensor_fetch_slice():
195226

196227
await put_mesh.do_put.call()
197228

229+
get_actor = await spawn_actors(1, GetActor, "get_actor_0")
230+
198231
# Test 1: Cross-volume slice (spans both volumes)
199232
# Request rows 2-5 (spans volume boundary at row 4)
200233
cross_volume_slice = TensorSlice(
@@ -205,7 +238,7 @@ async def test_dtensor_fetch_slice():
205238
mesh_shape=(),
206239
)
207240

208-
cross_volume_result = await ts.get(
241+
cross_volume_result = await get_actor.get_tensor.call_one(
209242
"test_key", tensor_slice_spec=cross_volume_slice
210243
)
211244
expected_cross_volume = original_tensor[2:6, 1:5]
@@ -223,7 +256,7 @@ async def test_dtensor_fetch_slice():
223256
mesh_shape=(),
224257
)
225258

226-
single_volume_result = await ts.get(
259+
single_volume_result = await get_actor.get_tensor.call_one(
227260
"test_key", tensor_slice_spec=single_volume_slice
228261
)
229262
expected_single_volume = original_tensor[1:3, 0:3]
@@ -249,6 +282,19 @@ async def test_partial_put():
249282
because the DTensor is not fully committed (only rank 0's shard is stored).
250283
"""
251284

285+
class TestActor(Actor):
286+
@endpoint
287+
async def exists(self, key):
288+
return await ts.exists(key)
289+
290+
@endpoint
291+
async def get(self, key):
292+
try:
293+
result = await ts.get(key)
294+
return {"success": True, "result": result}
295+
except Exception as e:
296+
return {"success": False, "error": e, "error_str": str(e)}
297+
252298
await ts.initialize(num_storage_volumes=2, strategy=ts.LocalRankStrategy())
253299

254300
original_tensor = torch.arange(16).reshape(4, 4).float()
@@ -270,15 +316,18 @@ async def test_partial_put():
270316
# Execute the put - rank 0 will put, rank 1 will skip
271317
await put_mesh.do_put.call()
272318

273-
assert not await ts.exists("test_key")
319+
test_actor = await spawn_actors(1, TestActor, "test_actor_0")
320+
321+
assert not await test_actor.exists.call_one("test_key")
274322
# Try to get the tensor - should raise KeyError because only rank 0 has committed
275-
with pytest.raises(KeyError) as exc_info:
276-
await ts.get("test_key")
323+
result = await test_actor.get.call_one("test_key")
324+
325+
assert not result["success"], "Expected get to fail but it succeeded"
326+
assert isinstance(result["error"], KeyError), f"Expected KeyError but got {type(result['error'])}"
277327

278328
# Verify the error message mentions partial commit
279-
assert "partially committed" in str(
280-
exc_info.value
281-
), f"Error message should mention partial commit: {exc_info.value}"
329+
assert "partially committed" in result["error_str"], \
330+
f"Error message should mention partial commit: {result['error_str']}"
282331

283332
finally:
284333
# Clean up process groups

0 commit comments

Comments
 (0)