@@ -10,11 +10,9 @@ Triton CrossEntropy Executor
1010
1111The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an optimized kernel written in OpenAI Triton (https://github.com/openai/triton). It can be used like in the following example::
1212
13+ import torch
1314 import thunder
14- from thunder.executors import nvfuserex, torchex
15- from thunder.executors.triton_crossentropy import deregister_triton_entropyex, register_triton_entropyex
16-
17- register_triton_entropyex(add_to_default_executors=False)
15+ from thunder.executors.triton_crossentropy import triton_ex as triton_cross_entropy_ex
1816
1917 def xentropy(logits, labels, weight, reduction, ignore_index):
2018 return thunder.torch.cross_entropy(
@@ -23,7 +21,7 @@ The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an
2321
2422 jitted_xentropy = thunder.jit(
2523 xentropy,
26- executors_list=['triton_crossentropy', nvfuserex, torchex ]
24+ executors=[triton_cross_entropy_ex, ]
2725 )
2826
2927 device = 'cuda'
@@ -41,43 +39,42 @@ The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an
4139
4240This prints::
4341
44- # Constructed by Delete Last Used
42+ # Constructed by Delete Last Used (took 0 milliseconds)
4543 import torch
44+ from thunder.executors.torchex import no_autocast
45+
4646 @torch.no_grad()
47- def xentropy(logits, labels, weight, reduction, ignore_index):
47+ @no_autocast()
48+ def computation(logits, labels, weight):
4849 # logits: "cuda:0 f32[2048, 50257]"
4950 # labels: "cuda:0 i64[2048]"
5051 # weight: "cuda:0 f32[50257]"
51- # "sum"
52- # ignore_index: "int 10106"
53- t22 = triton_cross_entropy(logits, labels, weight, None, ignore_index, None, "sum", 0.0) # t22: "cuda:0 f32[]"
54- del [logits, labels, weight, ignore_index]
55- return t22
52+ t23 = triton_crossentropy(logits, labels, weight, None, 45279, None, 'sum', 0.0) # t23: "cuda:0 f32[]"
53+ del logits, labels, weight
54+ return t23
5655
57- As shown in the above trace, ``triton_cross_entropy () `` is the one running the operation.
56+ As shown in the above trace, ``triton_crossentropy()`` is the one running the operation.
5857
5958Apex CrossEntropy Executor
6059==========================
6160
6261The Apex CrossEntropy executor can execute ``torch.cross_entropy()`` through an optimized kernel, like this::
6362
63+ import torch
6464 import thunder
65- from thunder.executors import nvfuserex, torchex
66- from thunder.executors.apex_entropyex import deregister_apex_entropyex, register_apex_entropyex
67-
68- register_apex_entropyex(add_to_default_executors=False)
65+ from thunder.executors.apex_entropyex import apex_ex
6966
7067 def xentropy(logits, labels):
7168 return thunder.torch.cross_entropy(
7269 logits, labels, reduction='mean', ignore_index=-1
7370 )
7471
75- jitted_xentropy = thunder.jit(xentropy, executors_list=['apex_xentropy', nvfuserex, torchex ])
72+ jitted_xentropy = thunder.jit(xentropy, executors=[apex_ex, ])
7673
7774 device = 'cuda'
7875 dtype = torch.float32
7976
80- logits = torch.randn([2048, 50257], device=device, dtype=thunder.torch.to_torch_dtype( dtype) )
77+ logits = torch.randn([2048, 50257], device=device, dtype=dtype)
8178 labels = torch.randint(0, 50257, [2048], device=device)
8279
8380 jitted_xentropy(logits, labels)
@@ -86,14 +83,17 @@ The Apex CrossEntropy executor can execute ``torch.cross_entropy()`` through an
8683
8784This prints::
8885
89- # Constructed by Delete Last Used
86+ # Constructed by Delete Last Used (took 0 milliseconds)
9087 import torch
88+ from thunder.executors.torchex import no_autocast
89+
9190 @torch.no_grad()
92- def xentropy(logits, labels):
91+ @no_autocast()
92+ def computation(logits, labels):
9393 # logits: "cuda:0 f32[2048, 50257]"
9494 # labels: "cuda:0 i64[2048]"
95- t18 = apex_cross_entropy(logits, labels, None, None, -1, None, " mean" , 0.0) # t18: "cuda:0 f32[]"
96- del [ logits, labels]
95+ (t18, _) = apex_cross_entropy(logits, labels, 'mean', 0.0)
96+ del logits, labels
9797 return t18
9898
9999showing that Apex is running the operation.
0 commit comments