@@ -62,19 +62,19 @@ def __init__(
6262 raise ValueError (
6363 "ManagedDeviceMesh doesn't support both mesh and parent are None."
6464 )
65- self .mesh = mesh
66- self .mesh_dim_names = mesh_dim_names
65+ self ._mesh = mesh
66+ self ._mesh_dim_names = mesh_dim_names
6767 self .replicate_pg = replicate_pg
6868 self .replicate_dim = replicate_dim
6969 self .replicate_dim_name : str = mesh_dim_names [replicate_dim ]
7070 self .parent = parent
7171 self .flatten_meshes : Dict [str , DeviceMesh ] = {}
72- self .device_type : str
72+ self ._device_type : str
7373 if mesh is not None :
74- self .device_type = mesh .device_type
74+ self ._device_type = mesh .device_type
7575 else :
7676 assert parent is not None
77- self .device_type = parent .device_type
77+ self ._device_type = parent .device_type
7878 self ._flatten_mesh_list : tuple [DeviceMesh , ...] = tuple ()
7979 self ._thread_id : Optional [int ] = None
8080 self ._hash : Optional [int ] = None
@@ -102,20 +102,20 @@ def __getitem__(self, mesh_dim_names: Union[str, tuple[str, ...]]) -> DeviceMesh
102102 elif mesh_dim_names in self .flatten_meshes :
103103 res_submesh = self .flatten_meshes [mesh_dim_names ]
104104 else :
105- assert self .mesh is not None
106- res_submesh = self .mesh [mesh_dim_names ]
105+ assert self ._mesh is not None
106+ res_submesh = self ._mesh [mesh_dim_names ]
107107 else :
108108 assert isinstance (mesh_dim_names , tuple )
109109 if self .replicate_dim_name not in mesh_dim_names :
110- assert self .mesh is not None
111- res_submesh = self .mesh [mesh_dim_names ]
110+ assert self ._mesh is not None
111+ res_submesh = self ._mesh [mesh_dim_names ]
112112 else :
113113 mesh_dim_names_wo_replicate = tuple (
114114 n for n in mesh_dim_names if n != self .replicate_dim_name
115115 )
116- assert self .mesh is not None
116+ assert self ._mesh is not None
117117 res_submesh = ManagedDeviceMesh (
118- self .mesh [mesh_dim_names_wo_replicate ],
118+ self ._mesh [mesh_dim_names_wo_replicate ],
119119 mesh_dim_names ,
120120 self .replicate_pg ,
121121 mesh_dim_names .index (self .replicate_dim_name ),
@@ -125,7 +125,7 @@ def __getitem__(self, mesh_dim_names: Union[str, tuple[str, ...]]) -> DeviceMesh
125125 # TODO: find a better way to do this that doesn't depend on device mesh
126126 # internals
127127 root = _mesh_resources .get_root_mesh (self )
128- _mesh_resources . child_to_root_mapping [ res_submesh ] = root
128+ res_submesh . _root_mesh = root
129129
130130 return res_submesh
131131
@@ -134,7 +134,7 @@ def _real_mesh_dim(self, mesh_dim: int) -> int:
134134
135135 def get_group (self , mesh_dim : Optional [Union [int , str ]] = None ) -> BaseProcessGroup :
136136 if isinstance (mesh_dim , str ):
137- dim = self .mesh_dim_names .index (mesh_dim )
137+ dim = self ._mesh_dim_names .index (mesh_dim )
138138 else :
139139 dim = 0 if mesh_dim is None else int (mesh_dim )
140140
@@ -143,8 +143,8 @@ def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGr
143143 elif dim == self .replicate_dim :
144144 return self .replicate_pg
145145 else :
146- assert self .mesh is not None
147- return self .mesh .get_group (self ._real_mesh_dim (dim ))
146+ assert self ._mesh is not None
147+ return self ._mesh .get_group (self ._real_mesh_dim (dim ))
148148
149149 def _flatten (
150150 self ,
@@ -168,64 +168,64 @@ def size(self, mesh_dim: Optional[int] = None) -> int:
168168 # This is possible during the initialization stage of training.
169169 replicate_pg_size = 1 if replicate_pg_size == 0 else replicate_pg_size
170170 if mesh_dim is None :
171- if self .mesh is None :
171+ if self ._mesh is None :
172172 return replicate_pg_size
173173 else :
174- assert self .mesh is not None
175- return self .mesh .size () * replicate_pg_size
174+ assert self ._mesh is not None
175+ return self ._mesh .size () * replicate_pg_size
176176 elif mesh_dim == self .replicate_dim :
177177 return replicate_pg_size
178178 else :
179- assert self .mesh is not None
180- return self .mesh .size (self ._real_mesh_dim (mesh_dim ))
179+ assert self ._mesh is not None
180+ return self ._mesh .size (self ._real_mesh_dim (mesh_dim ))
181181
182182 @property
183183 def ndim (self ) -> int :
184- assert self .mesh is not None
185- return self .mesh .ndim + 1
184+ assert self ._mesh is not None
185+ return self ._mesh .ndim + 1
186186
187187 @property
188188 def shape (self ) -> tuple [int , ...]:
189- assert self .mesh is not None
190- ret : list [int ] = list (self .mesh .shape )
189+ assert self ._mesh is not None
190+ ret : list [int ] = list (self ._mesh .shape )
191191 ret .insert (self .replicate_dim , self .replicate_pg .size ())
192192 return tuple (ret )
193193
194194 def get_rank (self ) -> int :
195- assert self .mesh is not None
196- return self .mesh .get_rank ()
195+ assert self ._mesh is not None
196+ return self ._mesh .get_rank ()
197197
198198 def get_local_rank (self , mesh_dim : Optional [Union [int , str ]] = None ) -> int :
199199 if isinstance (mesh_dim , str ):
200- dim = self .mesh_dim_names .index (mesh_dim )
200+ dim = self ._mesh_dim_names .index (mesh_dim )
201201 else :
202202 dim = 0 if mesh_dim is None else int (mesh_dim )
203203
204204 if mesh_dim is None :
205- if self .mesh is None :
205+ if self ._mesh is None :
206206 return get_rank (self .replicate_pg )
207207
208208 assert self .replicate_dim == 0 , "replicate_dim must be the first one"
209- assert self .mesh is not None
210- other_dim_size = self .mesh .size ()
211- assert self .mesh is not None
212- other_dim_rank = self .mesh .get_local_rank ()
209+ assert self ._mesh is not None
210+ other_dim_size = self ._mesh .size ()
211+ assert self ._mesh is not None
212+ other_dim_rank = self ._mesh .get_local_rank ()
213213 replicate_pg_rank = get_rank (self .replicate_pg )
214214 return other_dim_size * replicate_pg_rank + other_dim_rank
215215 elif dim == self .replicate_dim :
216216 return get_rank (self .replicate_pg )
217217 else :
218- assert self .mesh is not None
219- return self .mesh .get_local_rank (self ._real_mesh_dim (dim ))
218+ assert self ._mesh is not None
219+ return self ._mesh .get_local_rank (self ._real_mesh_dim (dim ))
220220
221221 def get_coordinate (self ) -> Optional [list [int ]]:
222222 """
223223 Return the relative indices of this rank relative to all
224224 dimensions of the mesh. If this rank is not part of the mesh, return None.
225225 """
226- assert self .mesh is not None
226+ assert self ._mesh is not None
227227 coordinate = (
228- self .mesh ._coordinate_on_dim if self .mesh ._coordinate_on_dim else None
228+ self ._mesh ._coordinate_on_dim if self ._mesh ._coordinate_on_dim else None
229229 )
230230 if not coordinate :
231231 return coordinate
@@ -239,20 +239,20 @@ def get_all_groups(self) -> list[BaseProcessGroup]:
239239 raise NotImplementedError
240240
241241 def __repr__ (self ) -> str :
242- return f"ManagedDeviceMesh(mesh={ self .mesh } )"
242+ return f"ManagedDeviceMesh(mesh={ self ._mesh } )"
243243
244244 def __hash__ (self ) -> int :
245245 # lazily compute hash
246246 if not self ._hash :
247247 self ._hash = hash (
248248 (
249- self .mesh ,
250- self .mesh_dim_names ,
249+ self ._mesh ,
250+ self ._mesh_dim_names ,
251251 self .replicate_pg ,
252252 self .replicate_dim ,
253253 self .replicate_dim_name ,
254254 self .parent ,
255- self .device_type ,
255+ self ._device_type ,
256256 )
257257 )
258258 return self ._hash
0 commit comments