Skip to content

Commit 6549798

Browse files
committed
docs: update formatting and examples (PR2472)
1 parent 42f2309 commit 6549798

File tree

2 files changed

+25
-25
lines changed

2 files changed

+25
-25
lines changed

docs/source/fundamentals/installation.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ Thunder can easily integrate OpenAI Triton kernels. You can install Triton using
5656
Install Thunder
5757
===============
5858

59-
You can now install Thunder
59+
You can now install Thunder::
6060

6161
pip install git+https://github.com/Lightning-AI/lightning-thunder.git
6262

63-
Alternatively you can clone the Thunder repository and install locally
63+
Alternatively, you can clone the Thunder repository and install it locally::
6464

6565
git clone https://github.com/Lightning-AI/lightning-thunder.git
6666
cd lightning-thunder

docs/source/intermediate/additional_executors.rst

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,9 @@ Triton CrossEntropy Executor
1010

1111
The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an optimized kernel written in OpenAI Triton (https://github.com/openai/triton). It can be used as shown in the following example::
1212

13+
import torch
1314
import thunder
14-
from thunder.executors import nvfuserex, torchex
15-
from thunder.executors.triton_crossentropy import deregister_triton_entropyex, register_triton_entropyex
16-
17-
register_triton_entropyex(add_to_default_executors=False)
15+
from thunder.executors.triton_crossentropy import triton_ex as triton_cross_entropy_ex
1816

1917
def xentropy(logits, labels, weight, reduction, ignore_index):
2018
return thunder.torch.cross_entropy(
@@ -23,7 +21,7 @@ The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an
2321

2422
jitted_xentropy = thunder.jit(
2523
xentropy,
26-
executors_list=['triton_crossentropy', nvfuserex, torchex]
24+
executors=[triton_cross_entropy_ex,]
2725
)
2826

2927
device = 'cuda'
@@ -41,43 +39,42 @@ The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an
4139

4240
This prints::
4341

44-
# Constructed by Delete Last Used
42+
# Constructed by Delete Last Used (took 0 milliseconds)
4543
import torch
44+
from thunder.executors.torchex import no_autocast
45+
4646
@torch.no_grad()
47-
def xentropy(logits, labels, weight, reduction, ignore_index):
47+
@no_autocast()
48+
def computation(logits, labels, weight):
4849
# logits: "cuda:0 f32[2048, 50257]"
4950
# labels: "cuda:0 i64[2048]"
5051
# weight: "cuda:0 f32[50257]"
51-
# "sum"
52-
# ignore_index: "int 10106"
53-
t22 = triton_cross_entropy(logits, labels, weight, None, ignore_index, None, "sum", 0.0) # t22: "cuda:0 f32[]"
54-
del [logits, labels, weight, ignore_index]
55-
return t22
52+
t23 = triton_crossentropy(logits, labels, weight, None, 45279, None, 'sum', 0.0) # t23: "cuda:0 f32[]"
53+
del logits, labels, weight
54+
return t23
5655

57-
As shown in the above trace, ``triton_cross_entropy()`` is the one running the operation.
56+
As shown in the above trace, ``triton_crossentropy()`` is the one running the operation.
5857

5958
Apex CrossEntropy Executor
6059
==========================
6160

6261
The Apex CrossEntropy executor can execute ``torch.cross_entropy()`` through an optimized kernel, like this::
6362

63+
import torch
6464
import thunder
65-
from thunder.executors import nvfuserex, torchex
66-
from thunder.executors.apex_entropyex import deregister_apex_entropyex, register_apex_entropyex
67-
68-
register_apex_entropyex(add_to_default_executors=False)
65+
from thunder.executors.apex_entropyex import apex_ex
6966

7067
def xentropy(logits, labels):
7168
return thunder.torch.cross_entropy(
7269
logits, labels, reduction='mean', ignore_index=-1
7370
)
7471

75-
jitted_xentropy = thunder.jit(xentropy, executors_list=['apex_xentropy', nvfuserex, torchex])
72+
jitted_xentropy = thunder.jit(xentropy, executors=[apex_ex,])
7673

7774
device = 'cuda'
7875
dtype = torch.float32
7976

80-
logits = torch.randn([2048, 50257], device=device, dtype=thunder.torch.to_torch_dtype(dtype))
77+
logits = torch.randn([2048, 50257], device=device, dtype=dtype)
8178
labels = torch.randint(0, 50257, [2048], device=device)
8279

8380
jitted_xentropy(logits, labels)
@@ -86,14 +83,17 @@ The Apex CrossEntropy executor can execute ``torch.cross_entropy()`` through an
8683

8784
This prints::
8885

89-
# Constructed by Delete Last Used
86+
# Constructed by Delete Last Used (took 0 milliseconds)
9087
import torch
88+
from thunder.executors.torchex import no_autocast
89+
9190
@torch.no_grad()
92-
def xentropy(logits, labels):
91+
@no_autocast()
92+
def computation(logits, labels):
9393
# logits: "cuda:0 f32[2048, 50257]"
9494
# labels: "cuda:0 i64[2048]"
95-
t18 = apex_cross_entropy(logits, labels, None, None, -1, None, "mean", 0.0) # t18: "cuda:0 f32[]"
96-
del [logits, labels]
95+
(t18, _) = apex_cross_entropy(logits, labels, 'mean', 0.0)
96+
del logits, labels
9797
return t18
9898

9999
showing that Apex is running the operation.

0 commit comments

Comments
 (0)