diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index ecc06a52..fa207c41 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -444,9 +444,6 @@ def items(self) -> Iterable[Tuple[Point, R]]: def __iter__(self) -> Iterator[Tuple[Point, R]]: return iter(self.items()) - def __len__(self) -> int: - return len(self._shape) - def __repr__(self) -> str: return f"ValueMesh({self._shape})" diff --git a/python/monarch/_src/actor/shape.py b/python/monarch/_src/actor/shape.py index 01762301..45eb86c9 100644 --- a/python/monarch/_src/actor/shape.py +++ b/python/monarch/_src/actor/shape.py @@ -224,5 +224,8 @@ def size(self, dim: Union[None, str, Sequence[str]] = None) -> int: def sizes(self) -> dict[str, int]: return dict(zip(self._labels, self._ndslice.sizes)) + def __len__(self) -> int: + return len(self._ndslice) + __all__ = ["NDSlice", "Shape", "MeshTrait"] diff --git a/python/tests/test_mesh_trait.py b/python/tests/test_mesh_trait.py new file mode 100644 index 00000000..74ed9cde --- /dev/null +++ b/python/tests/test_mesh_trait.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Iterable + +from monarch._src.actor.shape import MeshTrait, NDSlice, Shape, Slice + + +class Mesh(MeshTrait): + """ + A simple implementor of MeshTrait. + """ + + def __init__(self, shape: Shape, values: list[int]) -> None: + self._shape = shape + self._values = values + + def _new_with_shape(self, shape: Shape) -> "Mesh": + return Mesh(shape, self._values) + + @property + def _ndslice(self) -> NDSlice: + return self._shape.ndslice + + @property + def _labels(self) -> Iterable[str]: + return self._shape.labels + + +def test_len() -> None: + s = Slice(offset=0, sizes=[2, 3], strides=[3, 1]) + shape = Shape(["label0", "label1"], s) + + mesh = Mesh(shape, [1, 2, 3, 4, 5, 6]) + assert 6 == len(mesh) diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index cac04792..75120912 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -612,3 +612,8 @@ async def consume(): def test_python_task_tuple() -> None: PythonTask.from_coroutine(consume()).block_on() + +def test_mesh_len(): + proc_mesh = local_proc_mesh(gpus=12).get() + s = proc_mesh.spawn("sync_actor", SyncActor).get() + assert 12 == len(s)