
Commit 84723e8

[feat][merge] Support one-behind to reduce bubble time. Add profiling code. (#6355)
* [feat][merge] Support one-behind to reduce bubble time. Add profiling code.
* [feat] Update sync model by tensor, fix tMbs problem, add qwen train benchmark.
* [feat] Update consumer init to run 32B, update qwen benchmark.
1 parent 9379a89 commit 84723e8

29 files changed: +1937 −300 lines
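The headline change is the new n_behind knob threaded into the consumer: training is allowed to lag rollout generation by up to n weight-sync steps, so producers keep generating while the trainer works instead of idling in a synchronization bubble. Below is a minimal, purely illustrative sketch of that scheduling idea; the producer/consumer objects and their methods are hypothetical stand-ins, not the ColossalChat API.

from collections import deque

def one_behind_loop(producer, consumer, num_steps, n_behind=1):
    # Illustrative only. With n_behind = 0 every batch is trained on as soon as
    # it is generated and the producer waits for the weight sync (a bubble);
    # with n_behind = 1 the producer may run one weight version ahead while the
    # consumer trains on the previous batch.
    pending = deque()
    for step in range(num_steps):
        pending.append(producer.generate())  # rollout with current (possibly stale) weights
        if len(pending) > n_behind:          # allowed staleness reached: train, then sync weights
            consumer.train_step(pending.popleft())
            producer.load_state_dict(consumer.state_dict())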

.gitignore

Lines changed: 6 additions & 0 deletions
@@ -167,3 +167,9 @@ applications/ColossalChat/wandb
 applications/ColossalChat/model
 applications/ColossalChat/eval
 applications/ColossalChat/rollouts
+applications/ColossalChat/*.txt
+applications/ColossalChat/*.db
+applications/ColossalChat/stdin
+applications/ColossalChat/*.zip
+applications/ColossalChat/*.prof
+applications/ColossalChat/*.png

applications/ColossalChat/coati/dataset/loader.py

Lines changed: 27 additions & 8 deletions
@@ -367,9 +367,9 @@ def apply_chat_template_and_mask(
     }

     # Format for RL.
-    gt_answer = None
-    if "messages" in chat and "gt_answer" in chat:
-        gt_answer = chat["gt_answer"]
+    if "messages" in chat:
+        gt_answer = chat.get("gt_answer", None)
+        test_cases = chat.get("test_cases", None)
         chat = [chat["messages"]]

     tokens = []
@@ -402,12 +402,14 @@ def apply_chat_template_and_mask(
     labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx

     if gt_answer is not None:
-        gt_answer = tokenizer.encode(
-            gt_answer, padding="max_length", truncation=True, max_length=128, return_tensors="pt"
-        )
-        gt_answer = gt_answer.squeeze(1)
         return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
-
+    elif test_cases is not None:
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels,
+            "test_cases": test_cases,
+        }
     return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
@@ -440,3 +442,20 @@ def __getitem__(self, index: int):
         tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
         self.tokenized_texts[index] = dict(tokens)
         return self.tokenized_texts[index]
+
+
+def collate_fn_grpo(batch):
+    input_ids = [item["input_ids"] for item in batch]
+    attention_mask = [item["attention_mask"] for item in batch]
+    labels = [item["labels"] for item in batch]
+    # Assume input_ids, attention_mask, labels are already of the same length,
+    # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+    input_ids = torch.stack(input_ids)
+    attention_mask = torch.stack(attention_mask)
+    labels = torch.stack(labels)
+    ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+    if "test_cases" in batch[0]:
+        ret["test_cases"] = [item["test_cases"] for item in batch]
+    if "gt_answer" in batch[0]:
+        ret["gt_answer"] = [item["gt_answer"] for item in batch]
+    return ret
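For reference, the new collate_fn_grpo is a drop-in collate_fn for a PyTorch DataLoader over items shaped like the output of apply_chat_template_and_mask. A small self-contained sketch follows; the stand-in items are illustrative, not from this commit, and the import path is inferred from the file location.

import torch
from torch.utils.data import DataLoader
from coati.dataset.loader import collate_fn_grpo

# Sketch only: stand-in items that mimic apply_chat_template_and_mask output
# (fixed-length tensors plus an optional raw "gt_answer" field). The tensor
# fields get stacked; gt_answer/test_cases stay plain Python lists.
items = [
    {
        "input_ids": torch.zeros(16, dtype=torch.long),
        "attention_mask": torch.ones(16, dtype=torch.long),
        "labels": torch.zeros(16, dtype=torch.long),
        "gt_answer": "42",
    }
    for _ in range(8)
]
loader = DataLoader(items, batch_size=4, collate_fn=collate_fn_grpo)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # torch.Size([4, 16])
print(batch["gt_answer"])        # ['42', '42', '42', '42']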

applications/ColossalChat/coati/distributed/comm.py

Lines changed: 29 additions & 0 deletions
@@ -55,3 +55,32 @@ def ray_broadcast_tensor_dict(
     if rank == src:
         out_dict = tensor_dict
     return out_dict
+
+
+def ray_broadcast_tensor_dict_and_load(
+    producer_obj, tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
+):
+    rank = cc.get_rank(group_name)
+    if rank == src:
+        metadata = []
+        for k, v in tensor_dict.items():
+            metadata.append((k, v.shape, v.dtype))
+    else:
+        metadata = None
+    metadata = ray_broadcast_object(metadata, src, device, group_name)
+    for k, shape, dtype in metadata:
+        if "consumer_global_step" == k:
+            continue
+        if rank == src:
+            tensor = tensor_dict[k]
+        else:
+            out_dict = {}
+            tensor = torch.empty(shape, dtype=dtype, device=device)
+        cc.broadcast(tensor, src, group_name)
+        if rank != src:
+            out_dict[k] = tensor
+            producer_obj.load_state_dict(out_dict)
+            del out_dict
+            torch.npu.empty_cache()
+    if rank == src:
+        out_dict = tensor_dict
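Compared with ray_broadcast_tensor_dict, the new helper loads each received tensor into producer_obj immediately and frees it before receiving the next one, so a non-source rank never holds more than one tensor of the synced state dict at a time (the consumer_global_step entry is skipped, and the torch.npu.empty_cache() call implies an Ascend NPU build of torch). A rough usage sketch follows; the group name, ranks, and the state_dict/producer/device variables are placeholders, and the collective group is assumed to have been created beforehand.

if cc.get_rank("sync_model") == 0:
    # trainer rank: broadcast the freshly updated weights, tensor by tensor
    ray_broadcast_tensor_dict_and_load(None, state_dict, src=0, device=device, group_name="sync_model")
else:
    # producer rank: receive each tensor and load it straight into the local model
    ray_broadcast_tensor_dict_and_load(producer, {}, src=0, device=device, group_name="sync_model")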

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 129 additions & 54 deletions
Large diffs are not rendered by default.

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 24 additions & 72 deletions
@@ -1,14 +1,12 @@
 from contextlib import nullcontext
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 import ray
 import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
-from coati.distributed.reward.verifiable_reward import VerifiableReward
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -40,6 +38,8 @@ def __init__(
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if (
@@ -62,12 +62,15 @@ def __init__(
             batch_size,
             model_config,
             plugin_config,
+            generate_config,
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
-        path = model_config.pop("path")
-        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.path = model_config.pop("path")
+        self.policy_model = AutoModelForCausalLM.from_pretrained(self.path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
@@ -95,12 +98,7 @@ def __init__(
             loss_variation=grpo_config.get("loss_variation", "sample_level"),
         )

-        # Reference model is initialized from policy model.
-        if self.policy_loss_fn.beta > 0:
-            self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
-            self.reference_model.eval()
-
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
         self.filter_range = grpo_config.get("filter_range", None)
@@ -119,20 +117,7 @@ def __init__(
                 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
             )
         # Initialize verifiable reward.
-        response_format_tags = grpo_config.get("response_format_tags", None)
-        reward_model_kwargs = {
-            k: v
-            for k, v in grpo_config.items()
-            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
-        }
-        self.reward_model = VerifiableReward(
-            reward_fns=[
-                math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
-            ],
-            tokenizer=self.tokenizer,
-            tags=response_format_tags,
-            **reward_model_kwargs,
-        )
+        grpo_config.get("response_format_tags", None)
         self.global_step = 0

         self.lr_scheduler = CosineAnnealingWarmupLR(
@@ -158,7 +143,10 @@ def setup(self):
         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
        )
+        # Reference model is initialized from policy model.
         if self.policy_loss_fn.beta > 0:
+            self.reference_model = AutoModelForCausalLM.from_pretrained(self.path, **self.model_config)
+            self.reference_model.eval()
             self.reference_model, *_ = self.booster.boost(self.reference_model)
         self.plugin.logger.set_level("ERROR")

@@ -295,12 +283,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                        )

                        if self.booster.plugin.stage_manager.is_last_stage():
-                            reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                            reference_action_log_probs = calc_action_log_probs(
-                                reference_model_logits / self.generate_config["temperature"],
+                            reference_action_log_probs = memory_efficient_logprob(
+                                reference_model_outputs["outputs"]["logits"],
                                input_ids_forward_micro_batch,
                                num_action,
-                                self.plugin.shard_config,
+                                shard_config=self.plugin.shard_config,
                            )
                        else:
                            # Dummy reference logprobs for data iterator.
@@ -323,11 +310,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

                    def _criterion(outputs, inputs):
                        action_logits = outputs.logits
-                        action_log_probs = calc_action_log_probs(
-                            action_logits / self.generate_config["temperature"],
+                        action_log_probs = memory_efficient_logprob(
+                            action_logits,
                            inputs["input_ids"],
                            num_action,
-                            self.plugin.shard_config,
+                            shard_config=self.plugin.shard_config,
                        )
                        if "reference_action_log_probs" in inputs:
                            per_token_kl = (
@@ -370,16 +357,15 @@ def _criterion(outputs, inputs):
                            mean_kl.append(kl)
                        mean_loss.append(all_reduce_mean(loss, self.plugin).data)
            else:
-
                policy_model_logits = self.policy_model(
                    input_ids=input_ids_forward_micro_batch,
                    attention_mask=attention_mask_forward_micro_batch,
                ).logits
-                action_log_probs = calc_action_log_probs(
+                action_log_probs = memory_efficient_logprob(
                    policy_model_logits / self.generate_config["temperature"],
                    input_ids_forward_micro_batch,
                    num_action,
-                    self.plugin.shard_config,
+                    shard_config=self.plugin.shard_config,
                )

                if self.policy_loss_fn.beta > 0:
@@ -388,11 +374,11 @@ def _criterion(outputs, inputs):
                        input_ids=input_ids_forward_micro_batch,
                        attention_mask=attention_mask_forward_micro_batch,
                    ).logits
-                    reference_action_log_probs = calc_action_log_probs(
+                    reference_action_log_probs = memory_efficient_logprob(
                        reference_model_logits / self.generate_config["temperature"],
                        input_ids_forward_micro_batch,
                        num_action,
-                        self.plugin.shard_config,
+                        shard_config=self.plugin.shard_config,
                    )
                    per_token_kl = (
                        torch.exp(reference_action_log_probs - action_log_probs)
@@ -498,40 +484,6 @@ def _criterion(outputs, inputs):
        else:
            return None

-    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Calculate the group reward for the given rollout group.
-
-        Args:
-            rollout_group (Dict[str, Any]):
-                a group of samples generated by the model from the same prompt
-                contain the following keys:
-                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
-                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
-                "action_mask": torch.Tensor, [num_of_generation, response_length]
-                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
-                "response_idx": int, torch.Tensor, [num_of_generation, 2]
-                "gt_answer": torch.Tensor, [num_of_generation, 128]
-                "temperature": torch.Tensor, [] (scalar)
-
-        Returns:
-            Dict[str, Any]: The new group data with calculated reward.
-        """
-        reward_model_output = self.reward_model(
-            rollout["input_ids"],
-            gt_answer=rollout["gt_answer"],
-            response_idx=rollout["response_idx"],
-        )
-        # [num_of_generation]
-        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
-
-        rollout["reward"] = reward.view((-1, 1))
-        rollout["format_acc"] = format_acc.view((-1, 1))
-        rollout["ans_acc"] = ans_acc.view((-1, 1))
-        return rollout
-
    def state_dict(self):
        self.policy_model._force_wait_all_gather()
        model = self.policy_model.unwrap()
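All call sites switch from calc_action_log_probs to memory_efficient_logprob, whose point is to avoid materializing the full sequence-by-vocabulary log-softmax at once when extracting per-token log-probabilities of the generated actions. The real implementation lives in coati.distributed.utils (and also takes shard_config for sharded logits); the following is only a rough sketch of the chunked idea under assumed tensor shapes, not that function.

import torch

def chunked_action_logprob(logits, input_ids, num_action, chunk_size=1024):
    # Sketch only: log-probabilities of the last num_action tokens, computed in
    # sequence chunks so the full-vocabulary log-softmax is never held for the
    # whole sequence at once. logits[:, t] predicts the token at position t + 1.
    action_logits = logits[:, -num_action - 1 : -1, :]
    action_ids = input_ids[:, -num_action:]
    out = []
    for start in range(0, num_action, chunk_size):
        chunk = torch.log_softmax(action_logits[:, start : start + chunk_size, :].float(), dim=-1)
        ids = action_ids[:, start : start + chunk_size]
        out.append(chunk.gather(-1, ids.unsqueeze(-1)).squeeze(-1))
    return torch.cat(out, dim=1)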

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 9 additions & 9 deletions
@@ -74,9 +74,8 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
         micro_batch_size = input_ids.size(0)
         input_ids = input_ids.to(get_current_device())
         attention_mask = attention_mask.to(get_current_device())
-        gt_answer = None
-        if "gt_answer" in kwargs:
-            gt_answer = kwargs.pop("gt_answer")
+        gt_answer = kwargs.pop("gt_answer", None)
+        test_cases = kwargs.pop("test_cases", None)
         if self.num_generations > 1:
             input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
             attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
@@ -116,8 +115,9 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
         data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}

         if gt_answer is not None:
-            # repeat gt_answer for each prompt.
-            data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1)
+            data["gt_answer"] = gt_answer
+        if test_cases is not None:
+            data["test_cases"] = test_cases
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data

@@ -270,11 +270,11 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
         }

         data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}
-
-        if "gt_answer" in kwargs:
-            # repeat gt_answer for each prompt.
-            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
         data = {k: v.to(get_current_device()) for k, v in data.items()}
+        if "gt_answer" in kwargs:
+            data["gt_answer"] = kwargs["gt_answer"]
+        if "test_cases" in kwargs:
+            data["test_cases"] = kwargs["test_cases"]
         return data

     def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
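Net effect of these hunks: gt_answer and test_cases are no longer tokenized and repeated per generation inside the backend; they are forwarded unchanged in the returned rollout dict, matching the loader change that keeps them as raw fields. A hypothetical call, with backend and batch as placeholder names rather than identifiers from this commit:

rollouts = backend.generate(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    gt_answer=batch["gt_answer"],        # passed through as-is
    test_cases=batch.get("test_cases"),  # likewise, when the task provides them
)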
