Skip to content

Commit fed64ce

Browse files
committed
[example] fused_linear_jsd
1 parent 01c831e commit fed64ce

File tree

3 files changed

+41
-1
lines changed

3 files changed

+41
-1
lines changed

benchmarks/run.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@
109109
("examples.matmul_split_k", "matmul_split_k_tritonbench"),
110110
],
111111
),
112+
"fused_linear_jsd": (
113+
"tritonbench.operators.fused_linear_jsd.operator",
114+
"examples.fused_linear_jsd",
115+
"fused_linear_jsd_fwd",
116+
),
112117
}
113118

114119

@@ -407,7 +412,9 @@ def helion_method(
407412

408413
def _inner() -> Callable[..., Any] | object:
    """Invoke the Helion kernel function once and return its final result.

    Returns:
        The value returned by ``kfunc``; if that value is itself callable,
        it is called with no arguments and its result is returned instead.
    """
    # BENCHMARK HOT PATH, do not add any new logic here
    # Attach the operator to the kernel function so the kernel wrapper can
    # read it back through its own ``_self`` attribute.
    kfunc._self = self
    # Call the kernel exactly once. A second ``kfunc(*args)`` call here would
    # run the kernel twice per measurement and silently drop ``kwargs``.
    result = kfunc(*args, **kwargs)
    if callable(result):
        return result()
    return result

examples/fused_linear_jsd.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import torch
2+
import helion
3+
import helion.language as hl
4+
5+
@helion.kernel()
def _kernel(beta, ignore_index, temperature, student_weight, teacher_weight, student_input, teacher_input):
    """Compute a scalar beta-weighted Jensen-Shannon loss between two models.

    Both inputs are projected through their weight matrices, converted to
    temperature-scaled log-probabilities, and compared against the mixture
    ``m = (1 - beta) * P_student + beta * P_teacher``. Each row contributes
    ``(1 - beta) * KL(P_student || m) + beta * KL(P_teacher || m)``; the
    returned value is the sum of the per-row terms divided by the number of
    rows.

    ``ignore_index`` is accepted but not used by this body.
    """
    student_logits = student_input @ student_weight.T
    teacher_logits = teacher_input @ teacher_weight.T
    n_rows = student_logits.shape[0]
    loss = student_logits.new_empty(student_input.shape[0], dtype=torch.float)
    for rows in hl.tile(n_rows):
        log_p_student = torch.log_softmax(student_logits[rows, :] / temperature, dim=-1)
        log_p_teacher = torch.log_softmax(teacher_logits[rows, :] / temperature, dim=-1)
        # Work in float32 on a 2-D (rows, last-dim) view.
        log_p_student = log_p_student.to(torch.float).view(-1, log_p_student.size(-1))
        log_p_teacher = log_p_teacher.to(torch.float).view(-1, log_p_teacher.size(-1))
        # Mixture distribution: p_student + beta * (p_teacher - p_student).
        p_student = torch.exp(log_p_student)
        mixture = p_student + beta * (torch.exp(log_p_teacher) - p_student)
        log_mixture = torch.log(mixture)
        # With log_target=True, kl_div yields exp(target) * (target - input)
        # pointwise; summing over the last dim gives one KL term per row.
        kl_teacher = torch.nn.functional.kl_div(
            log_mixture, log_p_teacher, reduction="none", log_target=True
        ).sum(dim=-1)
        kl_student = torch.nn.functional.kl_div(
            log_mixture, log_p_student, reduction="none", log_target=True
        ).sum(dim=-1)
        loss[rows] = kl_student + beta * (kl_teacher - kl_student)
    return (loss / n_rows).sum()
22+
23+
24+
def fused_linear_jsd_fwd(student_input, teacher_input, label=None):
    """Adapt the benchmark operator's inputs to a ``_kernel`` call.

    The caller must attach the benchmark operator to this function as
    ``fused_linear_jsd_fwd._self`` before invoking it. The JSD parameters
    (``beta``, ``ignore_index``, ``temperature``) and both projection
    weights are read from that operator's ``baseline_op``.

    Args:
        student_input: Student-side input, forwarded unchanged to ``_kernel``.
        teacher_input: Teacher-side input, forwarded unchanged to ``_kernel``.
        label: Must be ``None``; labels are not supported.

    Returns:
        Whatever ``_kernel`` returns for these inputs.

    Raises:
        NotImplementedError: If ``label`` is not ``None``.
        AttributeError: If no operator has been attached as ``_self``.
    """
    if label is not None:
        # Raise explicitly rather than ``assert`` so the check is not
        # stripped when Python runs with ``-O``.
        raise NotImplementedError("fused_linear_jsd_fwd does not support label")

    operator = getattr(fused_linear_jsd_fwd, "_self", None)
    if operator is None:
        raise AttributeError(
            "fused_linear_jsd_fwd._self is not set; assign the benchmark "
            "operator to it before calling this function"
        )

    baseline_op = operator.baseline_op
    jsd = baseline_op.jsd
    return _kernel(
        jsd.beta,
        jsd.ignore_index,
        baseline_op.temperature,
        baseline_op.student_lin.weight,
        baseline_op.teacher_lin.weight,
        student_input,
        teacher_input,
    )

helion/autotuner/base_search.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
"misaligned address", # CUDA Error
5353
"PassManager::run failed", # Triton Error
5454
"illegal memory access", # CUDA Error
55+
"exceeds triton maximum tensor numel", # Triton Error
5556
],
5657
)
5758
)
@@ -147,7 +148,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
147148
except PTXASError:
148149
self.log.warning(f"PTXASError compiling config: {config}")
149150
except Exception as e:
150-
if not _expected_errors_regexp.search(str(e)):
151+
if not _expected_errors_regexp.search(str(e) + str(e.__cause__)):
151152
raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e
152153
self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}")
153154
return inf

0 commit comments

Comments
 (0)