import torch
import helion
import helion.language as hl


@helion.kernel()
def _kernel(beta, ignore_index, temperature, student_weight, teacher_weight, student_input, teacher_input):
    # Fused linear projections: compute vocab-sized logits for both models.
    student_logits = student_input @ student_weight.T
    teacher_logits = teacher_input @ teacher_weight.T
    loss = student_logits.new_empty(student_input.shape[0], dtype=torch.float)
    for batch in hl.tile(student_logits.shape[0]):
        # Temperature-scaled log-probabilities for this tile of rows
        # (despite the names, these hold log-probs, not probabilities).
        student_prob = torch.log_softmax(student_logits[batch, :] / temperature, dim=-1)
        teacher_prob = torch.log_softmax(teacher_logits[batch, :] / temperature, dim=-1)
        student_prob = student_prob.to(torch.float).view(-1, student_prob.size(-1))
        teacher_prob = teacher_prob.to(torch.float).view(-1, teacher_prob.size(-1))
        # Mixture distribution m = (1 - beta) * p_student + beta * p_teacher.
        m = torch.exp(student_prob) + beta * (torch.exp(teacher_prob) - torch.exp(student_prob))
        # kl_div(input, target, log_target=True) computes KL(target || input),
        # so these are KL(teacher || m) and KL(student || m), summed over the vocab.
        teacher_div = torch.nn.functional.kl_div(torch.log(m), teacher_prob, reduction="none", log_target=True).sum(dim=-1)
        student_div = torch.nn.functional.kl_div(torch.log(m), student_prob, reduction="none", log_target=True).sum(dim=-1)
        # Generalized JSD: (1 - beta) * KL(student || m) + beta * KL(teacher || m).
        batch_loss = student_div + beta * (teacher_div - student_div)
        loss[batch] = batch_loss
    # Average the per-row losses over the batch.
    return (loss / student_logits.shape[0]).sum()


def fused_linear_jsd_fwd(student_input, teacher_input, label=None):
    # Labels (and hence ignore_index masking) are not supported here.
    assert label is None
    # Hyperparameters and projection weights come from the baseline operator
    # attached to this function by the caller (`_self` is set externally).
    baseline_op = fused_linear_jsd_fwd._self.baseline_op
    beta = baseline_op.jsd.beta
    ignore_index = baseline_op.jsd.ignore_index
    temperature = baseline_op.temperature
    student_weight = baseline_op.student_lin.weight
    teacher_weight = baseline_op.teacher_lin.weight
    return _kernel(beta, ignore_index, temperature, student_weight, teacher_weight, student_input, teacher_input)
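
# ---------------------------------------------------------------------------
# Sanity-check sketch (not part of the commit above): a plain eager-PyTorch
# reference for the same generalized JSD, handy for validating _kernel
# numerically. The shape names (BT, H, V) and tolerances below are
# illustrative assumptions, not values taken from any benchmark config.
def _reference_fwd(beta, temperature, student_weight, teacher_weight, student_input, teacher_input):
    student_log_prob = torch.log_softmax((student_input @ student_weight.T) / temperature, dim=-1).float()
    teacher_log_prob = torch.log_softmax((teacher_input @ teacher_weight.T) / temperature, dim=-1).float()
    # Same mixture and divergence terms as the kernel, without tiling.
    m = (1 - beta) * student_log_prob.exp() + beta * teacher_log_prob.exp()
    student_div = torch.nn.functional.kl_div(m.log(), student_log_prob, reduction="none", log_target=True).sum(-1)
    teacher_div = torch.nn.functional.kl_div(m.log(), teacher_log_prob, reduction="none", log_target=True).sum(-1)
    return ((1 - beta) * student_div + beta * teacher_div).mean()

# Example comparison (assumed shapes; the Helion kernel needs a CUDA device):
#   BT, H, V = 1024, 4096, 32000
#   student_input = torch.randn(BT, H, device="cuda", dtype=torch.float16)
#   teacher_input = torch.randn(BT, H, device="cuda", dtype=torch.float16)
#   student_weight = torch.randn(V, H, device="cuda", dtype=torch.float16)
#   teacher_weight = torch.randn(V, H, device="cuda", dtype=torch.float16)
#   ref = _reference_fwd(0.5, 1.0, student_weight, teacher_weight, student_input, teacher_input)
#   out = _kernel(0.5, -100, 1.0, student_weight, teacher_weight, student_input, teacher_input)
#   torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)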