
Commit eb5e7d7

#fix qwen2 abnormal loss caused by SoftmaxCrossEntropyWithLogits on 910A/B (#2034)

1 parent 6ed1cf3 · commit eb5e7d7

2 files changed: 33 additions & 16 deletions


mindnlp/peft/peft_model.py

Lines changed: 1 addition & 1 deletion

@@ -125,7 +125,7 @@ def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
         # if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"):
         #     self.base_model.config.pretraining_tp = 1

-    def save_pretrained(self, save_directory, safe_serialization=False, **kwargs):
+    def save_pretrained(self, save_directory, safe_serialization=True, **kwargs):
         r"""
         This function saves the adapter model and the adapter configuration files to a directory, so that it can be
         reloaded using the [`LoraModel.from_pretrained`] class method, and also used by the [`LoraModel.push_to_hub`]
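
The only change in this file is the `safe_serialization` default flipping from `False` to `True`, so adapters are now saved in safetensors format unless the caller opts out. A minimal usage sketch, assuming an already-built `PeftModel` (the directory name and the `get_peft_model` setup mentioned in the comment are illustrative, not part of this commit):

```python
# peft_model: a mindnlp.peft.PeftModel built elsewhere (e.g. via get_peft_model(...)).

# After this commit, the default writes the adapter weights in safetensors format.
peft_model.save_pretrained("./qwen2_lora_adapter")

# Callers that still need the previous serialization format must now opt out explicitly.
peft_model.save_pretrained("./qwen2_lora_adapter", safe_serialization=False)
```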

mindnlp/transformers/models/qwen2/modeling_qwen2.py

Lines changed: 32 additions & 15 deletions
@@ -826,13 +826,22 @@ def forward(
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :]
             shift_labels = labels[..., 1:]
-            # Flatten the tokens
-            loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = nn.functional.one_hot(shift_labels.view(-1), self.config.vocab_size)
-            # Enable model parallelism
-            loss, _ = loss_fct(shift_logits, shift_labels.to(shift_logits.dtype))
-            loss = loss.mean()
+            if ON_ORANGE_PI:
+                # Flatten the tokens
+                loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                shift_labels = nn.functional.one_hot(shift_labels.view(-1), self.config.vocab_size)
+                # Enable model parallelism
+                loss, _ = loss_fct(shift_logits, shift_labels.to(shift_logits.dtype))
+                loss = loss.mean()
+            else:
+                # Flatten the tokens
+                loss_fct = CrossEntropyLoss()
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                shift_labels = shift_labels.view(-1)
+                # Enable model parallelism
+                loss = loss_fct(shift_logits, shift_labels)

         if not return_dict:
             output = (logits,) + outputs[1:]
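
Both branches compute the same mean cross entropy over the shifted tokens; only the primitive differs. A standalone sanity sketch of that equivalence, using toy shapes and values rather than code from the repository (it assumes a MindSpore build where both `nn.CrossEntropyLoss` and the `SoftmaxCrossEntropyWithLogits` primitive are available):

```python
import numpy as np
import mindspore
from mindspore import Tensor, nn, ops

vocab_size = 8
# Toy flattened shift_logits / shift_labels: (batch * seq, vocab) and (batch * seq,).
shift_logits = Tensor(np.random.randn(6, vocab_size).astype(np.float32))
shift_labels = Tensor(np.array([1, 3, 0, 7, 2, 5], dtype=np.int32))

# Path now taken on 910A/B (ON_ORANGE_PI is false): integer-label cross entropy.
loss_ce = nn.CrossEntropyLoss()(shift_logits, shift_labels)

# Path kept for ON_ORANGE_PI: one-hot labels + SoftmaxCrossEntropyWithLogits,
# which returns a per-token loss vector that still needs an explicit mean.
one_hot = ops.OneHot()(shift_labels, vocab_size,
                       Tensor(1.0, mindspore.float32), Tensor(0.0, mindspore.float32))
loss_sce, _ = ops.SoftmaxCrossEntropyWithLogits()(shift_logits, one_hot)
loss_sce = loss_sce.mean()

print(loss_ce, loss_sce)  # should agree up to floating-point precision
```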
@@ -1004,10 +1013,14 @@ def forward(
                 else:
                     loss = loss_fct(pooled_logits, labels)
             elif self.config.problem_type == "single_label_classification":
-                loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
-                labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
-                loss, _ = loss_fct(pooled_logits.view(-1, self.num_labels), labels.to(pooled_logits.dtype))
-                loss = loss.mean()
+                if ON_ORANGE_PI:
+                    loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
+                    labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
+                    loss, _ = loss_fct(pooled_logits.view(-1, self.num_labels), labels.to(pooled_logits.dtype))
+                    loss = loss.mean()
+                else:
+                    loss_fct = CrossEntropyLoss()
+                    loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
             elif self.config.problem_type == "multi_label_classification":
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(pooled_logits, labels)
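
Only the `single_label_classification` branch needed patching here; the regression and multi-label branches already go through `MSELoss` and `BCEWithLogitsLoss`. For context, a schematic of how transformers-style sequence-classification heads usually pick the branch when `config.problem_type` is unset; this mirrors the upstream convention rather than code in this diff, and the dtype check is an assumption:

```python
import mindspore


def infer_problem_type(num_labels: int, labels: mindspore.Tensor) -> str:
    """Schematic of the branch-selection convention used by sequence-classification heads."""
    if num_labels == 1:
        return "regression"                   # MSELoss branch
    if labels.dtype in (mindspore.int32, mindspore.int64):
        return "single_label_classification"  # the branch patched in this commit
    return "multi_label_classification"       # BCEWithLogitsLoss branch
```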
@@ -1086,10 +1099,14 @@ def forward(

         loss = None
         if labels is not None:
-            loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
-            labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
-            loss, _ = loss_fct(logits.view(-1, self.num_labels), labels.to(logits.dtype))
-            loss = loss.mean()
+            if ON_ORANGE_PI:
+                loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
+                labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
+                loss, _ = loss_fct(logits.view(-1, self.num_labels), labels.to(logits.dtype))
+                loss = loss.mean()
+            else:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

         if not return_dict:
             output = (logits,) + outputs[2:]
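
The same `ON_ORANGE_PI` branch now appears three times in this file (causal LM, sequence classification, token classification). A hypothetical helper that factors out the shared pattern could look like the sketch below; the helper name, the `on_orange_pi` argument, and the explicit `OneHot` call are illustrative assumptions, not part of the commit:

```python
import mindspore
from mindspore import nn, ops


def flattened_cross_entropy(logits, labels, num_classes: int, on_orange_pi: bool):
    """Cross entropy over flattened logits/labels, picking the primitive per platform."""
    logits = logits.view(-1, num_classes)
    # OneHot expects integer indices; int32 is the safest choice across backends.
    labels = labels.view(-1).astype(mindspore.int32)
    if on_orange_pi:
        # One-hot labels + SoftmaxCrossEntropyWithLogits, averaged manually
        # (the path this commit keeps for Orange Pi devices).
        one_hot = ops.OneHot()(labels, num_classes,
                               mindspore.Tensor(1.0, mindspore.float32),
                               mindspore.Tensor(0.0, mindspore.float32))
        loss, _ = ops.SoftmaxCrossEntropyWithLogits()(logits, one_hot.astype(logits.dtype))
        return loss.mean()
    # Standard integer-label cross entropy on 910A/B and everywhere else.
    return nn.CrossEntropyLoss()(logits, labels)
```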
