@@ -122,26 +122,102 @@ def loop(self) -> None:
                 # receive data from producers
                 for r in range(self.num_producers):
                     print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                    self.buffer.extend(
-                        unbind_batch(
-                            ray_broadcast_tensor_dict(
-                                None, src=0, device=self.device, group_name=f"sync_data_{r}"
-                            )
-                        )
+                    raw_batch = ray_broadcast_tensor_dict(
+                        None, src=0, device=self.device, group_name=f"sync_data_{r}"
                     )
-                while len(self.buffer) >= self.dp_size * self.minibatch_size:
-                    batches = self.buffer[
-                        self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
+                    # calculate group reward and other metrics before filtering; only the filtered groups
+                    # (an incomplete set) are used for training, so the logging metrics must be computed here
+                    # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
+                    raw_batch_with_reward = self.calculate_reward(
+                        {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
+                    )
+                    raw_batch_with_reward = {
+                        k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
+                        for k, v in raw_batch_with_reward.items()
+                    }
+                    # [batch_size, num_generations] -> [batch_size]
+                    reward = raw_batch_with_reward["reward"][:, :, 0]
+                    format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
+                    ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
+                    response_len = (
+                        raw_batch_with_reward["response_idx"][:, :, 1]
+                        - raw_batch_with_reward["response_idx"][:, :, 0]
+                        + 1
+                    ).type(torch.float32)
+                    effective_group_mask = None
+                    if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
+                        # filter each group based on its mean answer accuracy
+                        group_ans_acc_mean = ans_acc.mean(dim=1)
+                        effective_group_mask = torch.logical_and(
+                            group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
+                        )
+                    raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
+                    for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
+                        self.buffer.append(
+                            [
+                                (
+                                    group_with_reward
+                                    if effective_group_mask is None or effective_group_mask[group_idx]
+                                    else None
+                                ),
+                                reward[group_idx],
+                                format_acc[group_idx],
+                                ans_acc[group_idx],
+                                response_len[group_idx],
+                            ]
+                        )
+                    if effective_group_mask is not None:
+                        print(
+                            f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
+                        )
+                # map effective group indices to raw buffer indices for later lookup
+                effective_group_to_raw_group_mapping = {}
+                for buffer_idx in range(len(self.buffer)):
+                    if self.buffer[buffer_idx][0] is not None:
+                        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                            buffer_idx
+                        )
+                print(
+                    f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
+                )
+
+                while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                    # on each dp_rank, we use minibatch_size effective samples to form a batch
+                    batches = [
+                        self.buffer[effective_group_to_raw_group_mapping[i]]
+                        for i in range(
+                            self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
+                        )
                     ]
-                    batch = bind_batch(batches)
+                    # every dp_rank receives a complete mini-batch, so no sync is needed within step() later
+                    # each mini-batch uses the first self.dp_size * minibatch_size effective samples
+                    raw_mini_batches = self.buffer[
+                        : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
+                    ]  # include the last effective sample
+                    raw_mini_batches_metric_dict = {
+                        "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
+                        "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
+                        "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
+                        "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
+                    }
+                    batch = bind_batch([t[0] for t in batches])
                     batch = post_recv(batch)
-                    loss, excessive_prompts_idx = self.step(i, pbar, **batch)
-
-                    if excessive_prompts_idx is not None:
-                        excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
-                        self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
-                    else:
-                        self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
+                    loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                    self.buffer = self.buffer[
+                        effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                    ]
+                    # recalculate the effective-group-to-raw-group mapping for the remaining buffer
+                    effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
+                    effective_group_to_raw_group_mapping = {}
+                    for buffer_idx in range(len(self.buffer)):
+                        if self.buffer[buffer_idx][0] is not None:
+                            effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                                buffer_idx
+                            )
+                    assert (
+                        len(effective_group_to_raw_group_mapping)
+                        == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
+                    )
                     if loss is not None:
                         pbar.set_postfix({"loss": loss})
                     i += 1
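
Note on the dynamic-batching filter introduced above: a group is kept only when its mean answer accuracy falls strictly inside filter_range, so all-wrong and all-correct groups are dropped because they carry little advantage signal for GRPO. Below is a minimal standalone sketch of that check, assuming ans_acc is a [batch_size, num_generations] tensor and filter_range = (low, high); the helper name and sample values are illustrative and not part of the commit.

import torch

def compute_effective_group_mask(ans_acc: torch.Tensor, filter_range) -> torch.Tensor:
    # keep groups whose mean answer accuracy lies strictly inside (low, high);
    # groups that are entirely wrong or entirely correct are marked ineffective
    group_ans_acc_mean = ans_acc.mean(dim=1)
    return torch.logical_and(
        group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1]
    )

ans_acc = torch.tensor(
    [
        [0.0, 0.0, 0.0, 0.0],  # all wrong -> filtered out
        [1.0, 0.0, 1.0, 0.0],  # mixed -> kept
        [1.0, 1.0, 1.0, 1.0],  # all correct -> filtered out
    ]
)
print(compute_effective_group_mask(ans_acc, (0.01, 0.99)))  # tensor([False,  True, False])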
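
The buffer bookkeeping in the while-loop can also be read in isolation: each buffer entry is [group_or_None, reward, format_acc, ans_acc, response_len], filtered-out groups stay in the buffer as None placeholders so their metrics are still logged, and the mapping translates effective indices to raw buffer positions. A hedged sketch under those assumptions; build_mapping and take_minibatch are illustrative helper names, not functions from the commit.

def build_mapping(buffer):
    # effective index -> raw buffer index, skipping groups filtered to None
    mapping = {}
    for buffer_idx, entry in enumerate(buffer):
        if entry[0] is not None:
            mapping[len(mapping)] = buffer_idx
    return mapping

def take_minibatch(buffer, mapping, dp_rank, dp_size, minibatch_size):
    # each dp rank reads its own contiguous slice of effective groups
    batches = [
        buffer[mapping[i]]
        for i in range(dp_rank * minibatch_size, (dp_rank + 1) * minibatch_size)
    ]
    # drop everything up to and including the last raw group consumed by any rank,
    # so the None placeholders interleaved with it are discarded as well
    last_raw_idx = mapping[dp_size * minibatch_size - 1]
    return batches, buffer[last_raw_idx + 1 :]

buffer = [
    [{"ids": 0}, 1.0, 1.0, 1.0, 8.0],
    [None, 0.0, 0.0, 0.0, 4.0],
    [{"ids": 2}, 0.5, 1.0, 0.5, 6.0],
]
mapping = build_mapping(buffer)  # {0: 0, 1: 2}
batches, buffer = take_minibatch(buffer, mapping, dp_rank=0, dp_size=1, minibatch_size=2)
# batches holds the two effective groups; buffer is now empty, mirroring the slice-and-rebuild in the diff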