@@ -42,6 +42,13 @@ def __init__(self, world_size):
4242 async def put (self , key , tensor ):
4343 await ts .put (key , tensor )
4444
class TensorSliceGetActor(Actor):
    """Actor exposing a single endpoint that reads a value back via `ts.get`.

    The test uses it so that retrieval is issued from inside an actor
    rather than directly from the test coroutine.
    """

    @endpoint
    async def get(self, key, tensor_slice_spec=None):
        """Fetch `key` through `ts.get` and return whatever it yields.

        Args:
            key: Key to look up.
            tensor_slice_spec: Optional slice specification, forwarded
                unchanged to `ts.get`; `None` is passed through as-is.
        """
        fetched = await ts.get(key, tensor_slice_spec=tensor_slice_spec)
        return fetched
51+
4552 volume_world_size , strategy = strategy_params
4653 await ts .initialize (num_storage_volumes = volume_world_size , strategy = strategy )
4754
@@ -53,6 +60,8 @@ async def put(self, key, tensor):
5360 world_size = volume_world_size ,
5461 )
5562
63+ get_actor = await spawn_actors (1 , TensorSliceGetActor , "tensor_slice_get_actor" )
64+
5665 try :
5766 # Create a 100x100 tensor filled with sequential values 0-9999
5867 test_tensor = torch .arange (10000 ).reshape (100 , 100 ).float ()
@@ -63,7 +72,7 @@ async def put(self, key, tensor):
6372 await put_actor .put .call (key , test_tensor )
6473
6574 # Test full tensor retrieval using get actor mesh
66- retrieved_tensor = await ts .get (key )
75+ retrieved_tensor = await get_actor .get . call_one (key )
6776 assert torch .equal (test_tensor , retrieved_tensor )
6877
6978 # Test slice retrieval using get actor mesh
@@ -75,7 +84,7 @@ async def put(self, key, tensor):
7584 mesh_shape = (),
7685 )
7786
78- tensor_slice = await ts .get (key , tensor_slice_spec = tensor_slice_spec )
87+ tensor_slice = await get_actor .get . call_one (key , tensor_slice_spec = tensor_slice_spec )
7988 expected_slice = test_tensor [10 :15 , 20 :30 ]
8089 assert torch .equal (tensor_slice , expected_slice )
8190 assert tensor_slice .shape == (5 , 10 )
@@ -89,32 +98,43 @@ async def put(self, key, tensor):
8998@pytest .mark .asyncio
9099async def test_tensor_slice_inplace ():
91100 """Test tensor slice API with in-place operations"""
class TestActor(Actor):
    """Actor that runs the in-place sliced `ts.get` check and reports the outcome."""

    @endpoint
    async def test(self, test_tensor) -> Exception | None:
        """Store `test_tensor`, read a slice of it back in place, and verify it.

        Args:
            test_tensor: Tensor to store under the key "inplace_test". The
                slice spec below declares a global shape of (100, 200), so
                the caller is expected to pass a tensor of that shape.

        Returns:
            None when every step and assertion succeeds; otherwise the
            exception that was raised (failed asserts included). The
            exception is returned rather than raised so the caller can
            inspect it as an ordinary value.
        """
        try:
            # Store a test tensor
            await ts.put("inplace_test", test_tensor)

            # Slice covering rows 10:40 and columns 20:60 (offset + local shape).
            slice_spec = TensorSlice(
                offsets=(10, 20),
                coordinates=(),
                global_shape=(100, 200),
                local_shape=(30, 40),
                mesh_shape=(),
            )

            # Pre-allocated destination buffer matching the slice's local shape.
            slice_buffer = torch.empty(30, 40)
            result = await ts.get(
                "inplace_test", inplace_tensor=slice_buffer, tensor_slice_spec=slice_spec
            )

            # The call must hand back the very buffer it was given,
            # filled with the requested region of the stored tensor.
            assert result is slice_buffer
            expected_slice = test_tensor[10:40, 20:60]
            assert torch.equal(slice_buffer, expected_slice)
        except Exception as e:
            # Deliberately broad: any failure, AssertionError included, is
            # handed back to the caller instead of propagating.
            return e
        return None
129+
92130 await ts .initialize (num_storage_volumes = 1 )
93131
94132 try :
95- # Store a test tensor
96133 test_tensor = torch .randn (100 , 200 )
97- await ts .put ("inplace_test" , test_tensor )
98-
99- # Test in-place retrieval with slice
100- slice_spec = TensorSlice (
101- offsets = (10 , 20 ),
102- coordinates = (),
103- global_shape = (100 , 200 ),
104- local_shape = (30 , 40 ),
105- mesh_shape = (),
106- )
107-
108- # Create pre-allocated buffer
109- slice_buffer = torch .empty (30 , 40 )
110- result = await ts .get (
111- "inplace_test" , inplace_tensor = slice_buffer , tensor_slice_spec = slice_spec
112- )
134+ actor = await spawn_actors (1 , TestActor , "actor_0" )
135+ err = await actor .test .call_one (test_tensor )
113136
114- # Verify in-place operation
115- assert result is slice_buffer
116- expected_slice = test_tensor [10 :40 , 20 :60 ]
117- assert torch .equal (slice_buffer , expected_slice )
137+ assert err is None
118138
119139 finally :
120140 await ts .shutdown ()
@@ -123,6 +143,11 @@ async def test_tensor_slice_inplace():
123143@pytest .mark .asyncio
124144async def test_put_dtensor_get_full_tensor ():
125145 """Test basic DTensor put/get functionality with separate put and get meshes using shared DTensorActor"""
class GetActor(Actor):
    """Actor with one endpoint that reads a key back via `ts.get`."""

    @endpoint
    async def get_tensor(self, key):
        """Return the result of `ts.get(key)`."""
        fetched = await ts.get(key)
        return fetched
150+
126151 await ts .initialize (num_storage_volumes = 2 , strategy = ts .LocalRankStrategy ())
127152
128153 original_tensor = torch .arange (16 ).reshape (4 , 4 ).float ()
@@ -142,7 +167,8 @@ async def test_put_dtensor_get_full_tensor():
142167
143168 await put_mesh .do_put .call ()
144169
145- fetched_tensor = await ts .get ("test_key" )
170+ get_actor = await spawn_actors (1 , GetActor , "get_actor_0" )
171+ fetched_tensor = await get_actor .get_tensor .call_one ("test_key" )
146172 assert torch .equal (original_tensor , fetched_tensor )
147173
148174 finally :
@@ -167,6 +193,11 @@ async def test_dtensor_fetch_slice():
167193 """
168194 import tempfile
169195
class GetActor(Actor):
    """Actor with one endpoint that reads a key, optionally sliced, via `ts.get`."""

    @endpoint
    async def get_tensor(self, key, tensor_slice_spec=None):
        """Return the result of `ts.get` for `key`.

        Args:
            key: Key to look up.
            tensor_slice_spec: Optional slice specification, forwarded
                unchanged to `ts.get`.
        """
        fetched = await ts.get(key, tensor_slice_spec=tensor_slice_spec)
        return fetched
200+
170201 # Use LocalRankStrategy with 2 storage volumes (no RDMA, no parametrization)
171202 os .environ ["TORCHSTORE_RDMA_ENABLED" ] = "0"
172203 os .environ ["LOCAL_RANK" ] = "0" # Required by LocalRankStrategy
@@ -195,6 +226,8 @@ async def test_dtensor_fetch_slice():
195226
196227 await put_mesh .do_put .call ()
197228
229+ get_actor = await spawn_actors (1 , GetActor , "get_actor_0" )
230+
198231 # Test 1: Cross-volume slice (spans both volumes)
199232 # Request rows 2-5 (spans volume boundary at row 4)
200233 cross_volume_slice = TensorSlice (
@@ -205,7 +238,7 @@ async def test_dtensor_fetch_slice():
205238 mesh_shape = (),
206239 )
207240
208- cross_volume_result = await ts . get (
241+ cross_volume_result = await get_actor . get_tensor . call_one (
209242 "test_key" , tensor_slice_spec = cross_volume_slice
210243 )
211244 expected_cross_volume = original_tensor [2 :6 , 1 :5 ]
@@ -223,7 +256,7 @@ async def test_dtensor_fetch_slice():
223256 mesh_shape = (),
224257 )
225258
226- single_volume_result = await ts . get (
259+ single_volume_result = await get_actor . get_tensor . call_one (
227260 "test_key" , tensor_slice_spec = single_volume_slice
228261 )
229262 expected_single_volume = original_tensor [1 :3 , 0 :3 ]
@@ -249,6 +282,19 @@ async def test_partial_put():
249282 because the DTensor is not fully committed (only rank 0's shard is stored).
250283 """
251284
class TestActor(Actor):
    """Actor exposing `ts.exists` and a non-raising wrapper around `ts.get`."""

    @endpoint
    async def exists(self, key):
        """Return the result of `ts.exists(key)`."""
        found = await ts.exists(key)
        return found

    @endpoint
    async def get(self, key):
        """Call `ts.get(key)` and report the outcome as a dict instead of raising.

        Returns:
            On success: {"success": True, "result": <value from ts.get>}.
            On any Exception: {"success": False, "error": <the exception>,
            "error_str": str(<the exception>)}.
        """
        try:
            value = await ts.get(key)
        except Exception as exc:
            # Deliberately broad: the caller inspects the exception type and
            # message from the returned dict rather than catching it.
            return {"success": False, "error": exc, "error_str": str(exc)}
        return {"success": True, "result": value}
297+
252298 await ts .initialize (num_storage_volumes = 2 , strategy = ts .LocalRankStrategy ())
253299
254300 original_tensor = torch .arange (16 ).reshape (4 , 4 ).float ()
@@ -270,15 +316,18 @@ async def test_partial_put():
270316 # Execute the put - rank 0 will put, rank 1 will skip
271317 await put_mesh .do_put .call ()
272318
273- assert not await ts .exists ("test_key" )
319+ test_actor = await spawn_actors (1 , TestActor , "test_actor_0" )
320+
321+ assert not await test_actor .exists .call_one ("test_key" )
274322 # Try to get the tensor - should raise KeyError because only rank 0 has committed
275- with pytest .raises (KeyError ) as exc_info :
276- await ts .get ("test_key" )
323+ result = await test_actor .get .call_one ("test_key" )
324+
325+ assert not result ["success" ], "Expected get to fail but it succeeded"
326+ assert isinstance (result ["error" ], KeyError ), f"Expected KeyError but got { type (result ['error' ])} "
277327
278328 # Verify the error message mentions partial commit
279- assert "partially committed" in str (
280- exc_info .value
281- ), f"Error message should mention partial commit: { exc_info .value } "
329+ assert "partially committed" in result ["error_str" ], \
330+ f"Error message should mention partial commit: { result ['error_str' ]} "
282331
283332 finally :
284333 # Clean up process groups
0 commit comments