Skip to content

Commit 08d1cbf

Browse files
committed
[monarch] implement __len__ for MeshTrait
Pull Request resolved: #729 straightforward enough ghstack-source-id: 300256522 @exported-using-ghexport Differential Revision: [D79483634](https://our.internmc.facebook.com/intern/diff/D79483634/)
1 parent 59caae0 commit 08d1cbf

File tree

4 files changed

+47
-3
lines changed

4 files changed

+47
-3
lines changed

python/monarch/_src/actor/actor_mesh.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -444,9 +444,6 @@ def items(self) -> Iterable[Tuple[Point, R]]:
444444
def __iter__(self) -> Iterator[Tuple[Point, R]]:
445445
return iter(self.items())
446446

447-
def __len__(self) -> int:
448-
return len(self._shape)
449-
450447
def __repr__(self) -> str:
451448
return f"ValueMesh({self._shape})"
452449

python/monarch/_src/actor/shape.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,5 +224,8 @@ def size(self, dim: Union[None, str, Sequence[str]] = None) -> int:
224224
def sizes(self) -> dict[str, int]:
225225
return dict(zip(self._labels, self._ndslice.sizes))
226226

227+
def __len__(self) -> int:
228+
return len(self._ndslice)
229+
227230

228231
__all__ = ["NDSlice", "Shape", "MeshTrait"]

python/tests/test_mesh_trait.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Iterable
8+
9+
from monarch._src.actor.shape import MeshTrait, NDSlice, Shape, Slice
10+
11+
12+
class Mesh(MeshTrait):
13+
"""
14+
A simple implementor of MeshTrait.
15+
"""
16+
17+
def __init__(self, shape: Shape, values: list[int]) -> None:
18+
self._shape = shape
19+
self._values = values
20+
21+
def _new_with_shape(self, shape: Shape) -> "Mesh":
22+
return Mesh(shape, self._values)
23+
24+
@property
25+
def _ndslice(self) -> NDSlice:
26+
return self._shape.ndslice
27+
28+
@property
29+
def _labels(self) -> Iterable[str]:
30+
return self._shape.labels
31+
32+
33+
def test_len() -> None:
34+
s = Slice(offset=0, sizes=[2, 3], strides=[3, 1])
35+
shape = Shape(["label0", "label1"], s)
36+
37+
mesh = Mesh(shape, [1, 2, 3, 4, 5, 6])
38+
assert 6 == len(mesh)

python/tests/test_python_actors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,3 +598,9 @@ def test_ported_actor():
598598
proc_mesh = local_proc_mesh(gpus=1).get()
599599
a = proc_mesh.spawn("port_actor", PortedActor).get()
600600
assert 5 == a.add.call_one(2).get()
601+
602+
603+
def test_mesh_len():
604+
proc_mesh = local_proc_mesh(gpus=12).get()
605+
s = proc_mesh.spawn("sync_actor", SyncActor).get()
606+
assert 12 == len(s)

0 commit comments

Comments
 (0)