3 files changed: +21 −6 lines changed
          sudo apt-get install -y protobuf-compiler

+         pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
          pip install .[dev] -v

          pip install -r docs/requirements.txt
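
The nightly cu128 wheel is installed before pip install .[dev] so that the project's torch requirement is already satisfied and pip does not replace it with the stable release. A minimal sanity check (illustrative only, not part of this workflow change) that the cu128 nightly is the build actually imported:

# Illustrative check that the nightly CUDA 12.8 build is in use.
import torch

print(torch.__version__)   # nightly builds carry a ".dev" version tag
print(torch.version.cuda)  # expected "12.8" for the cu128 wheel
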
    init_device_mesh,
    ProcessGroup as BaseProcessGroup,
)
+from torch.distributed._mesh_layout import _MeshLayout
from torch.distributed.tensor.device_mesh import _mesh_resources

from torchft.manager import Manager
@@ -69,12 +70,20 @@ def __init__(
        self.replicate_dim_name: str = mesh_dim_names[replicate_dim]
        self.parent = parent
        self.flatten_meshes: Dict[str, DeviceMesh] = {}
+        self._flatten_mapping: Dict[str, "DeviceMesh"] = {}
        self._device_type: str
        if mesh is not None:
            self._device_type = mesh.device_type
+            mesh_tensor = (
+                mesh.detach().to(dtype=torch.int).contiguous()
+                if isinstance(mesh, torch.Tensor)
+                else torch.tensor(mesh, device="cpu", dtype=torch.int)
+            )
+            self._layout = _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
        else:
            assert parent is not None
            self._device_type = parent.device_type
+            self._layout = parent._layout
        self._flatten_mesh_list: tuple[DeviceMesh, ...] = tuple()
        self._thread_id: Optional[int] = None
        self._hash: Optional[int] = None
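
_MeshLayout is a private torch.distributed helper; as used here it only records the size/stride pair of the mesh tensor, and the new branch normalizes a nested-list mesh into a CPU int tensor first so that pair is always available. A standalone sketch of that normalization, assuming only the constructor signature visible in this diff:

# Illustrative sketch of the normalization above; assumes only
# _MeshLayout(size, stride) as shown in the diff (requires a recent nightly torch).
import torch
from torch.distributed._mesh_layout import _MeshLayout

mesh = [[0, 1, 2, 3], [4, 5, 6, 7]]  # e.g. 2 replica groups x 4 shards
mesh_tensor = torch.tensor(mesh, device="cpu", dtype=torch.int)

layout = _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
print(mesh_tensor.size())    # torch.Size([2, 4])
print(mesh_tensor.stride())  # (4, 1)
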
@@ -1253,14 +1253,19 @@ def _assert_same_stream(self) -> None:
    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        self._assert_same_stream()

-        with get_stream_context(self._stream):
-            self._work.wait()
-            self._set_future_callback()
+        try:
+            with get_stream_context(self._stream):
+                self._work.wait()
+                self._set_future_callback()

-        with get_stream_context(self._stream):
-            self._managed_fut_tail.wait()
+            with get_stream_context(self._stream):
+                self._managed_fut_tail.wait()

-        return True
+            return True
+        except Exception as e:
+            self._manager._logger.exception(f"got exception waiting for work {e}")
+            self._manager.report_error(e)
+            return False
12651270    def  block_current_stream (self , timeout : Optional [timedelta ] =  None ) ->  None :
12661271        self ._assert_same_stream ()
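
With this change wait() no longer propagates exceptions from the underlying work; it logs them, reports them to the manager, and returns False so the caller can skip the step instead of crashing. A standalone sketch of the same pattern (FakeManager and WaitWrapper are illustrative stand-ins, not torchft classes):

# Illustrative sketch of the "report and return False" pattern from the diff.
import logging
from datetime import timedelta
from typing import Optional


class FakeManager:
    # Stand-in for torchft's Manager: report_error records the failure so the
    # current step is not committed.
    _logger = logging.getLogger("torchft")

    def report_error(self, exc: Exception) -> None:
        self._logger.error("step failed: %s", exc)


class WaitWrapper:
    # Mirrors the diff: never raise out of wait(); signal failure through the
    # boolean return value instead.
    def __init__(self, work, manager: FakeManager) -> None:
        self._work = work
        self._manager = manager

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        try:
            self._work.wait()
            return True
        except Exception as e:
            self._manager._logger.exception(f"got exception waiting for work {e}")
            self._manager.report_error(e)
            return False
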