Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 27 additions & 16 deletions tests/test_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import pytest
import torch

import torchstore as ts
from monarch.actor import Actor, current_rank, endpoint
from pytest_unordered import unordered
Expand All @@ -21,24 +20,36 @@
@pytest.mark.asyncio
async def test_keys_basic():
    """Test basic put/keys functionality from inside a spawned actor.

    The store is populated and queried by a single ``TestActor``; any
    exception raised there (including a failed ``assert``) is returned to
    the test process, which then asserts that nothing was returned.
    """
    # Local import: only part of this file is visible here, so the typing
    # import is kept inside the block rather than assumed at module level.
    from typing import Optional

    class TestActor(Actor):
        @endpoint
        async def test(self) -> Optional[Exception]:
            """Run the put/keys checks and return the failure, if any.

            Returns:
                The exception raised by a store call or a failed assertion,
                or ``None`` when every check passed.
            """
            try:
                await ts.put("", torch.tensor([1, 2, 3]))
                await ts.put(".x", torch.tensor([1, 2, 3]))
                await ts.put("v0.x", torch.tensor([1, 2, 3]))
                await ts.put("v0.y", torch.tensor([4, 5, 6]))
                await ts.put("v0.x.z", torch.tensor([7, 8, 9]))
                await ts.put("v1.x", torch.tensor([7, 8, 9]))
                await ts.put("v1.y", torch.tensor([10, 11, 12]))

                # No argument: every stored key is expected back.
                assert await ts.keys() == unordered(
                    ["", ".x", "v0.x", "v0.y", "v0.x.z", "v1.x", "v1.y"]
                )
                # With an argument: only the listed subset is expected.
                assert await ts.keys("v0") == unordered(["v0.x", "v0.y", "v0.x.z"])
                assert await ts.keys("v0.x") == unordered(["v0.x", "v0.x.z"])
                assert await ts.keys("v0.x.z") == unordered(["v0.x.z"])
                assert await ts.keys("") == unordered(["", ".x"])
                assert await ts.keys("v1") == unordered(["v1.x", "v1.y"])
            except Exception as e:
                # Hand the failure back as a value so the test process can
                # assert on it.
                return e
            return None

    await ts.initialize()
    try:
        actor = await spawn_actors(1, TestActor, "actor_0")
        err = await actor.test.call_one()

        assert err is None
    finally:
        # Always shut the store down, even when the assertion above fails.
        await ts.shutdown()

Expand Down
1 change: 0 additions & 1 deletion tests/test_large_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import pytest
import torch

import torchstore as ts
from monarch.actor import Actor, endpoint
from torchstore.logging import init_logging
Expand Down
1 change: 0 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import pytest
import torch

import torchstore as ts
from monarch.actor import Actor, current_rank, endpoint
from torch.distributed.device_mesh import init_device_mesh
Expand Down
3 changes: 0 additions & 3 deletions tests/test_resharding_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@
from typing import List, Tuple, Union

import pytest

import torch

import torchstore as ts

from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset
from torchstore.utils import get_local_tensor, spawn_actors
Expand Down
14 changes: 12 additions & 2 deletions tests/test_resharding_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,28 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from logging import getLogger

import pytest

from torch.distributed._tensor import Shard

from .test_resharding_basic import _test_resharding

from .utils import main, transport_plus_strategy_params

logger = getLogger(__name__)


def slow_tests_enabled():
    """Report whether the slow tests have been opted into.

    Reads the ``TORCHSTORE_ENABLE_SLOW_TESTS`` environment variable, treating
    an unset variable as ``"0"``.

    Returns:
        bool: ``True`` only when the variable's value is exactly ``"1"``.
    """
    flag = os.environ.get("TORCHSTORE_ENABLE_SLOW_TESTS", "0")
    return flag == "1"


# Reusable pytest marker: a test decorated with it is skipped unless
# slow_tests_enabled() returns True. The condition is evaluated once, when
# this module-level statement runs.
requires_slow_tests_enabled = pytest.mark.skipif(
    not slow_tests_enabled(),
    reason="Slow tests are disabled by default, use TORCHSTORE_ENABLE_SLOW_TESTS=1 to enable them",
)


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.parametrize(
"put_mesh_shape,get_mesh_shape,put_sharding_dim,get_sharding_dim",
Expand Down Expand Up @@ -55,6 +64,7 @@ async def test_1d_resharding(
)


@requires_slow_tests_enabled
@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_2d_to_2d_resharding(strategy_params, use_rdma):
Expand Down
4 changes: 0 additions & 4 deletions tests/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@
from typing import Union

import pytest

import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn

import torchstore as ts

from monarch.actor import Actor, current_rank, endpoint
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint.state_dict import (
Expand Down Expand Up @@ -279,7 +276,6 @@ def _assert_equal_state_dict(state_dict1, state_dict2):
flattened_state_dict_2
), f"{flattened_state_dict_1.keys()=}\n{flattened_state_dict_2.keys()=}"
for key in flattened_state_dict_1:

assert key in flattened_state_dict_2
if isinstance(flattened_state_dict_1[key], torch.Tensor):
t1, t2 = flattened_state_dict_1[key], flattened_state_dict_2[key]
Expand Down
36 changes: 22 additions & 14 deletions tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,9 @@
from logging import getLogger

import pytest

import torch

import torchstore as ts

from monarch.actor import Actor, current_rank, endpoint

from torchstore.logging import init_logging
from torchstore.utils import spawn_actors

Expand Down Expand Up @@ -303,19 +299,31 @@ async def get(self, key):
@pytest.mark.asyncio
async def test_key_miss():
    """Test the behavior of get() when the key is missing.

    The checks run inside a single spawned ``TestActor``; any exception
    raised there is returned to the test process, which then asserts that
    nothing was returned.
    """
    # Local import: only part of this file is visible here, so the typing
    # import is kept inside the block rather than assumed at module level.
    from typing import Optional

    class TestActor(Actor):
        @endpoint
        async def test(self) -> Optional[Exception]:
            """Round-trip one key, then request a key that was never stored.

            Returns:
                The exception raised by a store call or a failed assertion,
                or ``None`` when every check passed.
            """
            try:
                key = "foo"
                value = torch.tensor([1, 2, 3])
                await ts.put(key, value)

                # Get the value back
                retrieved_value = await ts.get(key)
                assert torch.equal(value, retrieved_value)

                # Get a missing key
                with pytest.raises(KeyError):
                    await ts.get("bar")
            except Exception as e:
                # Hand the failure back as a value so the test process can
                # assert on it.
                return e
            return None

    await ts.initialize()
    try:
        actor = await spawn_actors(1, TestActor, "actor_0")
        err = await actor.test.call_one()

        assert err is None
    finally:
        # Always shut the store down, even when the assertion above fails.
        await ts.shutdown()

Expand Down
Loading
Loading