|
10 | 10 |
|
11 | 11 | import pytest |
12 | 12 | import torch |
13 | | - |
14 | 13 | import torchstore as ts |
15 | 14 | from monarch.actor import Actor, current_rank, endpoint |
16 | 15 | from pytest_unordered import unordered |
|
21 | 20 | @pytest.mark.asyncio |
22 | 21 | async def test_keys_basic(): |
23 | 22 | """Test basic put/get functionality""" |
| 23 | + |
| 24 | + class TestActor(Actor): |
| 25 | + @endpoint |
| 26 | + async def test(self) -> Exception or None: |
| 27 | + try: |
| 28 | + await ts.put("", torch.tensor([1, 2, 3])) |
| 29 | + await ts.put(".x", torch.tensor([1, 2, 3])) |
| 30 | + await ts.put("v0.x", torch.tensor([1, 2, 3])) |
| 31 | + await ts.put("v0.y", torch.tensor([4, 5, 6])) |
| 32 | + await ts.put("v0.x.z", torch.tensor([7, 8, 9])) |
| 33 | + await ts.put("v1.x", torch.tensor([7, 8, 9])) |
| 34 | + await ts.put("v1.y", torch.tensor([10, 11, 12])) |
| 35 | + |
| 36 | + assert await ts.keys() == unordered( |
| 37 | + ["", ".x", "v0.x", "v0.y", "v0.x.z", "v1.x", "v1.y"] |
| 38 | + ) |
| 39 | + assert await ts.keys("v0") == unordered(["v0.x", "v0.y", "v0.x.z"]) |
| 40 | + assert await ts.keys("v0.x") == unordered(["v0.x", "v0.x.z"]) |
| 41 | + assert await ts.keys("v0.x.z") == unordered(["v0.x.z"]) |
| 42 | + assert await ts.keys("") == unordered(["", ".x"]) |
| 43 | + assert await ts.keys("v1") == unordered(["v1.x", "v1.y"]) |
| 44 | + except Exception as e: |
| 45 | + return e |
| 46 | + |
24 | 47 | await ts.initialize() |
25 | 48 |
|
26 | | - await ts.put("", torch.tensor([1, 2, 3])) |
27 | | - await ts.put(".x", torch.tensor([1, 2, 3])) |
28 | | - await ts.put("v0.x", torch.tensor([1, 2, 3])) |
29 | | - await ts.put("v0.y", torch.tensor([4, 5, 6])) |
30 | | - await ts.put("v0.x.z", torch.tensor([7, 8, 9])) |
31 | | - await ts.put("v1.x", torch.tensor([7, 8, 9])) |
32 | | - await ts.put("v1.y", torch.tensor([10, 11, 12])) |
| 49 | + actor = await spawn_actors(1, TestActor, "actor_0") |
| 50 | + err = await actor.test.call_one() |
33 | 51 |
|
34 | | - assert await ts.keys() == unordered( |
35 | | - ["", ".x", "v0.x", "v0.y", "v0.x.z", "v1.x", "v1.y"] |
36 | | - ) |
37 | | - assert await ts.keys("v0") == unordered(["v0.x", "v0.y", "v0.x.z"]) |
38 | | - assert await ts.keys("v0.x") == unordered(["v0.x", "v0.x.z"]) |
39 | | - assert await ts.keys("v0.x.z") == unordered(["v0.x.z"]) |
40 | | - assert await ts.keys("") == unordered(["", ".x"]) |
41 | | - assert await ts.keys("v1") == unordered(["v1.x", "v1.y"]) |
| 52 | + assert err is None |
42 | 53 |
|
43 | 54 | await ts.shutdown() |
44 | 55 |
|
|
0 commit comments