diff --git a/tests/test_keys.py b/tests/test_keys.py
index 0010814..62ec4bd 100644
--- a/tests/test_keys.py
+++ b/tests/test_keys.py
@@ -10,7 +10,6 @@

 import pytest
 import torch
-
 import torchstore as ts
 from monarch.actor import Actor, current_rank, endpoint
 from pytest_unordered import unordered
@@ -21,24 +20,36 @@
 @pytest.mark.asyncio
 async def test_keys_basic():
     """Test basic put/get functionality"""
+
+    class TestActor(Actor):
+        @endpoint
+        async def test(self) -> Exception | None:
+            try:
+                await ts.put("", torch.tensor([1, 2, 3]))
+                await ts.put(".x", torch.tensor([1, 2, 3]))
+                await ts.put("v0.x", torch.tensor([1, 2, 3]))
+                await ts.put("v0.y", torch.tensor([4, 5, 6]))
+                await ts.put("v0.x.z", torch.tensor([7, 8, 9]))
+                await ts.put("v1.x", torch.tensor([7, 8, 9]))
+                await ts.put("v1.y", torch.tensor([10, 11, 12]))
+
+                assert await ts.keys() == unordered(
+                    ["", ".x", "v0.x", "v0.y", "v0.x.z", "v1.x", "v1.y"]
+                )
+                assert await ts.keys("v0") == unordered(["v0.x", "v0.y", "v0.x.z"])
+                assert await ts.keys("v0.x") == unordered(["v0.x", "v0.x.z"])
+                assert await ts.keys("v0.x.z") == unordered(["v0.x.z"])
+                assert await ts.keys("") == unordered(["", ".x"])
+                assert await ts.keys("v1") == unordered(["v1.x", "v1.y"])
+            except Exception as e:
+                return e
+
     await ts.initialize()

-    await ts.put("", torch.tensor([1, 2, 3]))
-    await ts.put(".x", torch.tensor([1, 2, 3]))
-    await ts.put("v0.x", torch.tensor([1, 2, 3]))
-    await ts.put("v0.y", torch.tensor([4, 5, 6]))
-    await ts.put("v0.x.z", torch.tensor([7, 8, 9]))
-    await ts.put("v1.x", torch.tensor([7, 8, 9]))
-    await ts.put("v1.y", torch.tensor([10, 11, 12]))
+    actor = await spawn_actors(1, TestActor, "actor_0")
+    err = await actor.test.call_one()

-    assert await ts.keys() == unordered(
-        ["", ".x", "v0.x", "v0.y", "v0.x.z", "v1.x", "v1.y"]
-    )
-    assert await ts.keys("v0") == unordered(["v0.x", "v0.y", "v0.x.z"])
-    assert await ts.keys("v0.x") == unordered(["v0.x", "v0.x.z"])
-    assert await ts.keys("v0.x.z") == unordered(["v0.x.z"])
-    assert await ts.keys("") == unordered(["", ".x"])
-    assert await ts.keys("v1") == unordered(["v1.x", "v1.y"])
+    assert err is None

     await ts.shutdown()

diff --git a/tests/test_large_tensors.py b/tests/test_large_tensors.py
index 44eb87c..fdf8701 100644
--- a/tests/test_large_tensors.py
+++ b/tests/test_large_tensors.py
@@ -10,7 +10,6 @@

 import pytest
 import torch
-
 import torchstore as ts
 from monarch.actor import Actor, endpoint
 from torchstore.logging import init_logging
diff --git a/tests/test_models.py b/tests/test_models.py
index 6fe2b3a..8473f57 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -12,7 +12,6 @@

 import pytest
 import torch
-
 import torchstore as ts
 from monarch.actor import Actor, current_rank, endpoint
 from torch.distributed.device_mesh import init_device_mesh
diff --git a/tests/test_resharding_basic.py b/tests/test_resharding_basic.py
index 0872c56..2ff8b91 100644
--- a/tests/test_resharding_basic.py
+++ b/tests/test_resharding_basic.py
@@ -11,11 +11,8 @@
 from typing import List, Tuple, Union

 import pytest
-
 import torch
-
 import torchstore as ts
-
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset
 from torchstore.utils import get_local_tensor, spawn_actors
diff --git a/tests/test_resharding_ext.py b/tests/test_resharding_ext.py
index 68bfb25..576313f 100644
--- a/tests/test_resharding_ext.py
+++ b/tests/test_resharding_ext.py
@@ -4,19 +4,28 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import os
 from logging import getLogger

 import pytest
-
 from torch.distributed._tensor import Shard

 from .test_resharding_basic import _test_resharding
-
 from .utils import main, transport_plus_strategy_params

 logger = getLogger(__name__)


+def slow_tests_enabled():
+    return os.environ.get("TORCHSTORE_ENABLE_SLOW_TESTS", "0") == "1"
+
+
+requires_slow_tests_enabled = pytest.mark.skipif(
+    not slow_tests_enabled(),
+    reason="Slow tests are disabled by default, use TORCHSTORE_ENABLE_SLOW_TESTS=1 to enable them",
+)
+
+
 @pytest.mark.parametrize(*transport_plus_strategy_params())
 @pytest.mark.parametrize(
     "put_mesh_shape,get_mesh_shape,put_sharding_dim,get_sharding_dim",
@@ -55,6 +64,7 @@ async def test_1d_resharding(
     )


+@requires_slow_tests_enabled
 @pytest.mark.parametrize(*transport_plus_strategy_params())
 @pytest.mark.asyncio
 async def test_2d_to_2d_resharding(strategy_params, use_rdma):
diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py
index b3aa4e9..9639316 100644
--- a/tests/test_state_dict.py
+++ b/tests/test_state_dict.py
@@ -12,13 +12,10 @@
 from typing import Union

 import pytest
-
 import torch
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-
 import torchstore as ts
-
 from monarch.actor import Actor, current_rank, endpoint
 from torch.distributed.checkpoint._nested_dict import flatten_state_dict
 from torch.distributed.checkpoint.state_dict import (
@@ -279,7 +276,6 @@ def _assert_equal_state_dict(state_dict1, state_dict2):
         flattened_state_dict_2
     ), f"{flattened_state_dict_1.keys()=}\n{flattened_state_dict_2.keys()=}"
     for key in flattened_state_dict_1:
-        assert key in flattened_state_dict_2
         if isinstance(flattened_state_dict_1[key], torch.Tensor):
             t1, t2 = flattened_state_dict_1[key], flattened_state_dict_2[key]

diff --git a/tests/test_store.py b/tests/test_store.py
index 19342e5..3305597 100644
--- a/tests/test_store.py
+++ b/tests/test_store.py
@@ -8,13 +8,9 @@
 from logging import getLogger

 import pytest
-
 import torch
-
 import torchstore as ts
-
 from monarch.actor import Actor, current_rank, endpoint
-
 from torchstore.logging import init_logging
 from torchstore.utils import spawn_actors

@@ -303,19 +299,31 @@ async def get(self, key):
 @pytest.mark.asyncio
 async def test_key_miss():
     """Test the behavior of get() when the key is missing."""
-    await ts.initialize()
-
-    key = "foo"
-    value = torch.tensor([1, 2, 3])
-    await ts.put(key, value)
+
+    class TestActor(Actor):
+        @endpoint
+        async def test(self) -> Exception | None:
+            try:
+                key = "foo"
+                value = torch.tensor([1, 2, 3])
+                await ts.put(key, value)
+
+                # Get the value back
+                retrieved_value = await ts.get(key)
+                assert torch.equal(value, retrieved_value)
+
+                # Get a missing key
+                with pytest.raises(KeyError):
+                    await ts.get("bar")
+            except Exception as e:
+                return e
+
+    await ts.initialize()

-    # Get the value back
-    retrieved_value = await ts.get(key)
-    assert torch.equal(value, retrieved_value)
+    actor = await spawn_actors(1, TestActor, "actor_0")
+    err = await actor.test.call_one()

-    # Get a missing key
-    with pytest.raises(KeyError):
-        await ts.get("bar")
+    assert err is None

     await ts.shutdown()

diff --git a/tests/test_tensor_slice.py b/tests/test_tensor_slice.py
index 9093aa8..e4f7557 100644
--- a/tests/test_tensor_slice.py
+++ b/tests/test_tensor_slice.py
@@ -9,7 +9,6 @@

 import pytest
 import torch
-
 import torchstore as ts
 from monarch.actor import Actor, current_rank, endpoint

@@ -42,6 +41,13 @@ def __init__(self, world_size):
         async def put(self, key, tensor):
             await ts.put(key, tensor)

+    class TensorSliceGetActor(Actor):
+        """Actor for getting tensors."""
+
+        @endpoint
+        async def get(self, key, tensor_slice_spec=None):
+            return await ts.get(key, tensor_slice_spec=tensor_slice_spec)
+
     volume_world_size, strategy = strategy_params
     await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy)

@@ -53,6 +59,8 @@ async def put(self, key, tensor):
         world_size=volume_world_size,
     )

+    get_actor = await spawn_actors(1, TensorSliceGetActor, "tensor_slice_get_actor")
+
     try:
         # Create a 100x100 tensor filled with sequential values 0-9999
         test_tensor = torch.arange(10000).reshape(100, 100).float()
@@ -63,7 +71,7 @@
         await put_actor.put.call(key, test_tensor)

         # Test full tensor retrieval using get actor mesh
-        retrieved_tensor = await ts.get(key)
+        retrieved_tensor = await get_actor.get.call_one(key)
         assert torch.equal(test_tensor, retrieved_tensor)

         # Test slice retrieval using get actor mesh
@@ -75,7 +83,9 @@
             mesh_shape=(),
         )

-        tensor_slice = await ts.get(key, tensor_slice_spec=tensor_slice_spec)
+        tensor_slice = await get_actor.get.call_one(
+            key, tensor_slice_spec=tensor_slice_spec
+        )
         expected_slice = test_tensor[10:15, 20:30]
         assert torch.equal(tensor_slice, expected_slice)
         assert tensor_slice.shape == (5, 10)
@@ -89,32 +99,46 @@
 @pytest.mark.asyncio
 async def test_tensor_slice_inplace():
     """Test tensor slice API with in-place operations"""
+
+    class TestActor(Actor):
+        @endpoint
+        async def test(self, test_tensor) -> Exception | None:
+            try:
+                # Store a test tensor
+                await ts.put("inplace_test", test_tensor)
+
+                # Test in-place retrieval with slice
+                slice_spec = TensorSlice(
+                    offsets=(10, 20),
+                    coordinates=(),
+                    global_shape=(100, 200),
+                    local_shape=(30, 40),
+                    mesh_shape=(),
+                )
+
+                # Create pre-allocated buffer
+                slice_buffer = torch.empty(30, 40)
+                result = await ts.get(
+                    "inplace_test",
+                    inplace_tensor=slice_buffer,
+                    tensor_slice_spec=slice_spec,
+                )
+
+                # Verify in-place operation
+                assert result is slice_buffer
+                expected_slice = test_tensor[10:40, 20:60]
+                assert torch.equal(slice_buffer, expected_slice)
+            except Exception as e:
+                return e
+
     await ts.initialize(num_storage_volumes=1)

     try:
-        # Store a test tensor
         test_tensor = torch.randn(100, 200)
-        await ts.put("inplace_test", test_tensor)
-
-        # Test in-place retrieval with slice
-        slice_spec = TensorSlice(
-            offsets=(10, 20),
-            coordinates=(),
-            global_shape=(100, 200),
-            local_shape=(30, 40),
-            mesh_shape=(),
-        )
+        actor = await spawn_actors(1, TestActor, "actor_0")
+        err = await actor.test.call_one(test_tensor)

-        # Create pre-allocated buffer
-        slice_buffer = torch.empty(30, 40)
-        result = await ts.get(
-            "inplace_test", inplace_tensor=slice_buffer, tensor_slice_spec=slice_spec
-        )
-
-        # Verify in-place operation
-        assert result is slice_buffer
-        expected_slice = test_tensor[10:40, 20:60]
-        assert torch.equal(slice_buffer, expected_slice)
+        assert err is None
     finally:
         await ts.shutdown()

@@ -123,6 +147,12 @@ async def test_tensor_slice_inplace():
 @pytest.mark.asyncio
 async def test_put_dtensor_get_full_tensor():
     """Test basic DTensor put/get functionality with separate put and get meshes using shared DTensorActor"""
+
+    class GetActor(Actor):
+        @endpoint
+        async def get_tensor(self, key):
+            return await ts.get(key)
+
     await ts.initialize(num_storage_volumes=2, strategy=ts.LocalRankStrategy())

     original_tensor = torch.arange(16).reshape(4, 4).float()
@@ -142,7 +172,8 @@ async def test_put_dtensor_get_full_tensor():

         await put_mesh.do_put.call()

-        fetched_tensor = await ts.get("test_key")
+        get_actor = await spawn_actors(1, GetActor, "get_actor_0")
+        fetched_tensor = await get_actor.get_tensor.call_one("test_key")
         assert torch.equal(original_tensor, fetched_tensor)

     finally:
@@ -167,6 +198,11 @@ async def test_dtensor_fetch_slice():
     """
     import tempfile

+    class GetActor(Actor):
+        @endpoint
+        async def get_tensor(self, key, tensor_slice_spec=None):
+            return await ts.get(key, tensor_slice_spec=tensor_slice_spec)
+
     # Use LocalRankStrategy with 2 storage volumes (no RDMA, no parametrization)
     os.environ["TORCHSTORE_RDMA_ENABLED"] = "0"
     os.environ["LOCAL_RANK"] = "0"  # Required by LocalRankStrategy
@@ -195,6 +231,8 @@

         await put_mesh.do_put.call()

+        get_actor = await spawn_actors(1, GetActor, "get_actor_0")
+
         # Test 1: Cross-volume slice (spans both volumes)
         # Request rows 2-5 (spans volume boundary at row 4)
         cross_volume_slice = TensorSlice(
@@ -205,7 +243,7 @@
             mesh_shape=(),
         )

-        cross_volume_result = await ts.get(
+        cross_volume_result = await get_actor.get_tensor.call_one(
             "test_key", tensor_slice_spec=cross_volume_slice
         )
         expected_cross_volume = original_tensor[2:6, 1:5]
@@ -223,7 +261,7 @@
             mesh_shape=(),
         )

-        single_volume_result = await ts.get(
+        single_volume_result = await get_actor.get_tensor.call_one(
             "test_key", tensor_slice_spec=single_volume_slice
         )
         expected_single_volume = original_tensor[1:3, 0:3]
@@ -249,6 +287,19 @@ async def test_partial_put():
     because the DTensor is not fully committed (only rank 0's shard is stored).
     """

+    class TestActor(Actor):
+        @endpoint
+        async def exists(self, key):
+            return await ts.exists(key)
+
+        @endpoint
+        async def get(self, key):
+            try:
+                result = await ts.get(key)
+                return {"success": True, "result": result}
+            except Exception as e:
+                return {"success": False, "error": e, "error_str": str(e)}
+
     await ts.initialize(num_storage_volumes=2, strategy=ts.LocalRankStrategy())

     original_tensor = torch.arange(16).reshape(4, 4).float()
@@ -270,15 +321,21 @@
         # Execute the put - rank 0 will put, rank 1 will skip
         await put_mesh.do_put.call()
-        assert not await ts.exists("test_key")
+        test_actor = await spawn_actors(1, TestActor, "test_actor_0")
+
+        assert not await test_actor.exists.call_one("test_key")

         # Try to get the tensor - should raise KeyError because only rank 0 has committed
-        with pytest.raises(KeyError) as exc_info:
-            await ts.get("test_key")
+        result = await test_actor.get.call_one("test_key")
+
+        assert not result["success"], "Expected get to fail but it succeeded"
+        assert isinstance(
+            result["error"], KeyError
+        ), f"Expected KeyError but got {type(result['error'])}"

         # Verify the error message mentions partial commit
-        assert "partially committed" in str(
-            exc_info.value
-        ), f"Error message should mention partial commit: {exc_info.value}"
+        assert (
+            "partially committed" in result["error_str"]
+        ), f"Error message should mention partial commit: {result['error_str']}"

     finally:
         # Clean up process groups
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 29bfa2e..bd02c89 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -7,9 +7,6 @@
 from logging import getLogger

 import pytest
-
 import torch
-
-
 from torchstore.utils import assemble_tensor, get_local_tensor

diff --git a/torchstore/controller.py b/torchstore/controller.py
index af726cc..d6bf4fe 100644
--- a/torchstore/controller.py
+++ b/torchstore/controller.py
@@ -120,13 +120,13 @@ async def init(
         self.is_initialized = True

     @endpoint
-    def get_controller_strategy(self) -> TorchStoreStrategy:
+    async def get_controller_strategy(self) -> TorchStoreStrategy:
         self.assert_initialized()
         assert self.strategy is not None, "Strategy is not set"
         return self.strategy

     @endpoint
-    def locate_volumes(
+    async def locate_volumes(
         self,
         key: str,
     ) -> Dict[str, StorageInfo]:
@@ -178,7 +178,9 @@ def locate_volumes(
         return volume_map

     @endpoint
-    def notify_put(self, key: str, request: Request, storage_volume_id: str) -> None:
+    async def notify_put(
+        self, key: str, request: Request, storage_volume_id: str
+    ) -> None:
         """Notify the controller that data has been stored in a storage volume.

         This should called after a successful put operation to
@@ -220,13 +222,13 @@ async def teardown(self) -> None:
         self.num_storage_volumes = None

     @endpoint
-    def keys(self, prefix=None) -> List[str]:
+    async def keys(self, prefix=None) -> List[str]:
         if prefix is None:
             return list(self.keys_to_storage_volumes.keys())
         return self.keys_to_storage_volumes.keys().filter_by_prefix(prefix)

     @endpoint
-    def notify_delete(self, key: str, storage_volume_id: str) -> None:
+    async def notify_delete(self, key: str, storage_volume_id: str) -> None:
         """
         Notify the controller that deletion of data is initiated in a storage volume.