@@ -140,15 +140,15 @@ def __init__(
 
         if self._distributed:
             logger.info("Setting up distributed samplers.")
-            self.train_sampler_ = DistributedSampler(self.train_dataset_)
-            self.test_sampler_ = DistributedSampler(self.test_dataset_)
+            self.train_sampler_ = DistributedSampler(train_dataset)
+            self.test_sampler_ = DistributedSampler(test_dataset)
         else:
             self.train_sampler_ = None
             self.test_sampler_ = None
 
         self.train_loader_ = DataLoader(
             train_dataset,
-            shuffle=True,
+            shuffle=(not self._distributed),
             collate_fn=data_collator,
             batch_size=self._batch_size,
             sampler=self.train_sampler_,
@@ -157,7 +157,7 @@ def __init__(
             test_dataset,
             collate_fn=data_collator,
             batch_size=self._batch_size,
-            sampler=self.train_sampler_,
+            sampler=self.test_sampler_,
         )
 
         logger.info(f"Train loader steps: {len(self.train_loader_)}")
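
The two hunks above converge on the usual sampler/shuffle contract: DataLoader rejects shuffle=True when an explicit sampler is supplied, and each split needs its own DistributedSampler rather than sharing the training one. A minimal, self-contained sketch of that pattern (the make_loaders helper and its arguments are illustrative, not part of the patch):

# Sketch only: mirrors the loader setup above for distributed training.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def make_loaders(train_dataset, test_dataset, batch_size, collate_fn, distributed):
    # DistributedSampler already shards and shuffles per epoch, so the
    # DataLoader may only request shuffling in the non-distributed case.
    train_sampler = DistributedSampler(train_dataset) if distributed else None
    test_sampler = DistributedSampler(test_dataset) if distributed else None

    train_loader = DataLoader(
        train_dataset,
        shuffle=(not distributed),
        collate_fn=collate_fn,
        batch_size=batch_size,
        sampler=train_sampler,
    )
    # The evaluation loader gets its own sampler; reusing the train sampler,
    # as the old code did, draws indices sized for the training set instead.
    test_loader = DataLoader(
        test_dataset,
        collate_fn=collate_fn,
        batch_size=batch_size,
        sampler=test_sampler,
    )
    return train_loader, test_loader
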
@@ -179,19 +179,19 @@ def __init__(
                 trainable_params += p.numel()
         logger.info(f"Trainable parameters: {trainable_params}")
 
-        self.model_.train()
+        self.model_.to(self.device_)
         if self._distributed:
             self.model_ = DDP(
                 self.model_,
                 device_ids=[self._rank] if self._rank is not None else None,
                 output_device=self._rank,
             )
-        self.model_.to(self.device_)
         self.metric_ = evaluate.load("seqeval")
 
         # DP
         logger.info(f"DP: {dp}")
         if dp:
+            self.model_.train()
            if not ModuleValidator.is_valid(self.model_):
                self.model_ = ModuleValidator.fix(self.model_)
 
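
The reordering above follows the common DDP recipe: parameters are moved to their device before the DDP wrapper registers them, and the differential-privacy branch only then puts the model in train mode and runs the Opacus ModuleValidator. A rough sketch of that ordering, assuming model, device, rank, distributed, and dp carry the same meaning as in the trainer (prepare_model is a made-up helper name):

# Sketch only: reproduces the ordering established by the hunk above.
from torch.nn.parallel import DistributedDataParallel as DDP
from opacus.validators import ModuleValidator

def prepare_model(model, device, rank, distributed, dp):
    model.to(device)  # place parameters first, so DDP registers them on the right device
    if distributed:
        model = DDP(
            model,
            device_ids=[rank] if rank is not None else None,
            output_device=rank,
        )
    if dp:
        model.train()
        # Opacus rejects layers such as BatchNorm; fix() swaps in compatible ones.
        if not ModuleValidator.is_valid(model):
            model = ModuleValidator.fix(model)
    return model
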
@@ -625,13 +625,13 @@ def run(args):
     logger.info(f"Distributed process rank: {os.environ['RANK']}")
     logger.info(f"Distributed world size: {os.environ['WORLD_SIZE']}")
 
-    if int(os.environ["WORLD_SIZE"]) > 1 and torch.cuda.is_available():
+    if int(os.environ.get("WORLD_SIZE", "1")) > 1 and torch.cuda.is_available():
         dist.init_process_group(
             "nccl",
             rank=int(os.environ["RANK"]),
-            world_size=int(os.environ["WORLD_SIZE"]),
+            world_size=int(os.environ.get("WORLD_SIZE", "1")),
         )
-    elif int(os.environ["WORLD_SIZE"]) > 1:
+    elif int(os.environ.get("WORLD_SIZE", "1")) > 1:
         dist.init_process_group("gloo")
 
     trainer = NERTrainer(
@@ -651,11 +651,12 @@ def run(args):
         iteration_num=args.iteration_num,
         batch_size=args.batch_size,
         device_id=int(os.environ["RANK"]),
-        distributed=int(os.environ["WORLD_SIZE"]) > 1 and torch.cuda.is_available(),
+        distributed=int(os.environ.get("WORLD_SIZE", "1")) > 1
+        and torch.cuda.is_available(),
     )
     trainer.execute(args.checkpoint)
 
-    if torch.cuda.is_available() or int(os.environ["WORLD_SIZE"]) > 1:
+    if int(os.environ.get("WORLD_SIZE", "1")) > 1:
        dist.destroy_process_group()
 
 
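
The run() changes gate every torch.distributed call behind the same WORLD_SIZE check, defaulting to "1" when the variable is unset, picking NCCL only when CUDA is available with a gloo fallback otherwise, and tearing the group down under the condition that created it. Sketched as stand-alone helpers (the function names are illustrative; RANK and WORLD_SIZE are assumed to be set by the launcher in distributed runs):

# Sketch only: the env-var-guarded process-group lifecycle used above.
import os
import torch
import torch.distributed as dist

def init_distributed() -> bool:
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size > 1 and torch.cuda.is_available():
        dist.init_process_group(
            "nccl",
            rank=int(os.environ["RANK"]),
            world_size=world_size,
        )
    elif world_size > 1:
        # CPU-only multi-process runs fall back to the gloo backend.
        dist.init_process_group("gloo")
    return world_size > 1

def teardown_distributed() -> None:
    # Mirror the init condition so single-process runs never touch the group.
    if int(os.environ.get("WORLD_SIZE", "1")) > 1:
        dist.destroy_process_group()
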