@@ -826,13 +826,22 @@ def forward(
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :]
             shift_labels = labels[..., 1:]
-            # Flatten the tokens
-            loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = nn.functional.one_hot(shift_labels.view(-1), self.config.vocab_size)
-            # Enable model parallelism
-            loss, _ = loss_fct(shift_logits, shift_labels.to(shift_logits.dtype))
-            loss = loss.mean()
+            if ON_ORANGE_PI:
+                # Flatten the tokens
+                loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                shift_labels = nn.functional.one_hot(shift_labels.view(-1), self.config.vocab_size)
+                # Enable model parallelism
+                loss, _ = loss_fct(shift_logits, shift_labels.to(shift_logits.dtype))
+                loss = loss.mean()
+            else:
+                # Flatten the tokens
+                loss_fct = CrossEntropyLoss()
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                shift_labels = shift_labels.view(-1)
+                # Enable model parallelism
+                loss = loss_fct(shift_logits, shift_labels)
+
 
         if not return_dict:
             output = (logits,) + outputs[1:]
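Both sides of the new `ON_ORANGE_PI` branch compute the same mean cross-entropy; they differ only in label format. The `mindspore.ops.SoftmaxCrossEntropyWithLogits` primitive expects one-hot float labels and returns per-sample losses (plus the backprop gradient, discarded here), hence the manual `one_hot` and `.mean()`, whereas `CrossEntropyLoss` takes integer class indices and reduces internally (and, unlike the primitive path, skips `ignore_index` targets, so the two agree only when no labels are masked). A minimal standalone sketch of that equivalence, assuming stock MindSpore; the shapes and the `mindspore.nn.CrossEntropyLoss` spelling are illustrative, since the diff's `CrossEntropyLoss` is imported elsewhere in the file:

import numpy as np
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops

# Illustrative shapes; vocab_size stands in for self.config.vocab_size.
num_tokens, vocab_size = 6, 32
logits = mindspore.Tensor(np.random.randn(num_tokens, vocab_size), mindspore.float32)
labels = mindspore.Tensor(np.random.randint(0, vocab_size, (num_tokens,)), mindspore.int32)

# Orange Pi path: primitive op, one-hot float labels, manual mean.
loss_fct = ops.SoftmaxCrossEntropyWithLogits()
one_hot_labels = ops.one_hot(labels, vocab_size,
                             mindspore.Tensor(1.0, mindspore.float32),
                             mindspore.Tensor(0.0, mindspore.float32))
per_token, _ = loss_fct(logits, one_hot_labels)  # second output is dloss/dlogits
loss_a = per_token.mean()

# Default path: integer class indices, mean reduction built in.
loss_b = nn.CrossEntropyLoss()(logits, labels)

print(float(loss_a), float(loss_b))  # equal up to float32 rounding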
@@ -1004,10 +1013,14 @@ def forward(
                 else:
                     loss = loss_fct(pooled_logits, labels)
             elif self.config.problem_type == "single_label_classification":
-                loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
-                labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
-                loss, _ = loss_fct(pooled_logits.view(-1, self.num_labels), labels.to(pooled_logits.dtype))
-                loss = loss.mean()
+                if ON_ORANGE_PI:
+                    loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
+                    labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
+                    loss, _ = loss_fct(pooled_logits.view(-1, self.num_labels), labels.to(pooled_logits.dtype))
+                    loss = loss.mean()
+                else:
+                    loss_fct = CrossEntropyLoss()
+                    loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
             elif self.config.problem_type == "multi_label_classification":
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(pooled_logits, labels)
@@ -1086,10 +1099,14 @@ def forward(
 
         loss = None
         if labels is not None:
-            loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
-            labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
-            loss, _ = loss_fct(logits.view(-1, self.num_labels), labels.to(logits.dtype))
-            loss = loss.mean()
+            if ON_ORANGE_PI:
+                loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
+                labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
+                loss, _ = loss_fct(logits.view(-1, self.num_labels), labels.to(logits.dtype))
+                loss = loss.mean()
+            else:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         if not return_dict:
             output = (logits,) + outputs[2:]
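All three hunks switch on the same module-level `ON_ORANGE_PI` flag, whose definition is outside this diff. Purely as a hypothetical illustration (the repository's actual detection logic may differ, e.g. probing the Ascend device instead), such a switch could be as simple as:

import os

# Hypothetical only: the real ON_ORANGE_PI definition lives elsewhere in the repo.
ON_ORANGE_PI = os.environ.get("ON_ORANGE_PI", "0").lower() in ("1", "true")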