pmpp v2 #48
base: main
Changes from all commits: ad7b5dd, 2f8544d, f6e0876, 680144e, 7d15e92, f6b03a8, 3fa35dc, 67b1a00, ff91f9f, 7c5e50c
@@ -12,13 +12,6 @@ def custom_kernel(data: input_t) -> output_t:
     Returns:
         Output tensor after convolution
     """
-
-    torch.backends.cudnn.allow_tf32 = False
-    torch.backends.cudnn.deterministic = True
-    input_tensor, kernel = data
-    return F.conv2d(
-        input_tensor,
-        kernel,
-        stride=1,
-        padding=0
-    )
+    input_tensor, kernel, output = data
+    output = F.conv2d(input_tensor, kernel, stride=1, padding=0)
+    return output

Review comment on `input_tensor, kernel, output = data`: add test submission doing conf with tf32 / fp16; decide on whether we want this to pass or fail
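A minimal sketch of the kind of test submission the reviewer seems to be asking for: one that turns reduced precision on before the convolution, so the harness has to decide whether such a submission should pass or fail. The TF32 flags and the `from task import ...` line are assumptions mirroring the surrounding files, not part of this PR.

```python
import torch
import torch.nn.functional as F
from task import input_t, output_t  # assumed import, mirroring the other kernels


def custom_kernel(data: input_t) -> output_t:
    # Deliberately enable reduced-precision paths; an fp16 variant could instead
    # cast input_tensor and kernel to torch.float16 before the convolution.
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
    input_tensor, kernel, output = data
    output = F.conv2d(input_tensor, kernel, stride=1, padding=0)
    return output
```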
@@ -4,4 +4,5 @@


 def custom_kernel(data: input_t) -> output_t:
+    data, _output = data
     return torch.empty(size=(data.shape[0], data.shape[1]), device=data.device, dtype=data.dtype)

Review comment: return _output?
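A minimal sketch of what "return _output?" would look like, assuming the intent is to hand back the preallocated buffer instead of a freshly allocated empty tensor:

```python
from task import input_t, output_t  # assumed import, mirroring the other kernels


def custom_kernel(data: input_t) -> output_t:
    data, _output = data
    # Return the preallocated output buffer rather than allocating a new empty tensor.
    return _output
```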
@@ -1,9 +1,14 @@
 from typing import TypedDict, TypeVar
 import torch

-input_t = TypeVar("input_t", bound=torch.Tensor)  # Input will be (H, W, 3) RGB tensor
-output_t = TypeVar("output_t", bound=torch.Tensor)  # Output will be (H, W) grayscale tensor
+input_t = TypeVar(
+    "input_t", bound=tuple[torch.Tensor, torch.Tensor]
+)  # Input is a pair of tensors (input, output) where input is (H, W, 3) RGB tensor and output is (H, W) grayscale tensor
+output_t = TypeVar(
+    "output_t", bound=torch.Tensor
+)  # Output will be (H, W) grayscale tensor

+
 class TestSpec(TypedDict):
     size: int  # Size of the square image (H=W)
-    seed: int
+    seed: int
@@ -3,4 +3,6 @@


 def custom_kernel(data: input_t) -> output_t:
-    return torch.bincount(data, minlength=256)
+    data, output = data
+    output = torch.bincount(data, minlength=256)
+    return output

Review comment: out=output
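The `out=output` suggestion presumably means writing the counts into the preallocated buffer rather than rebinding the name. As far as I know `torch.bincount` does not accept an `out=` argument, so a sketch of the in-place variant would go through `Tensor.copy_` instead (assuming `output` is a preallocated 256-element tensor of a compatible dtype):

```python
import torch
from task import input_t, output_t  # assumed import, mirroring the other kernels


def custom_kernel(data: input_t) -> output_t:
    data, output = data
    # bincount has no out= parameter, so copy the counts into the provided buffer.
    output.copy_(torch.bincount(data, minlength=256))
    return output
```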
@@ -4,4 +4,5 @@


 def custom_kernel(data: input_t) -> output_t:
+    data, _output = data
     return torch.empty(size=(256,), device=data.device, dtype=data.dtype)

Review comment: return _output
@@ -1,6 +1,7 @@
 import torch
 from task import input_t, output_t

+
 def custom_kernel(data: input_t) -> output_t:
     """
     Reference implementation of histogram using PyTorch.
@@ -9,4 +10,7 @@ def custom_kernel(data: input_t) -> output_t:
     Returns:
         Tensor containing bin counts
     """
-    return torch.bincount(data, minlength=256)
+    data, output = data
+    # Compute histogram with 256 bins
+    output = torch.bincount(data, minlength=256)
+    return output

Review comment: out=
@@ -1,11 +1,11 @@
 from typing import TypedDict, TypeVar
 import torch

-input_t = TypeVar("input_t", bound=torch.Tensor)
+input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor])
 output_t = TypeVar("output_t", bound=torch.Tensor)


 class TestSpec(TypedDict):
     size: int
     seed: int
     contention: int
@@ -3,6 +3,6 @@


 def custom_kernel(data: input_t) -> output_t:
-    a, b = data
-    return a @ b
-
+    a, b, c = data
+    c = a @ b
+    return c

Review comment: out=
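A sketch of the `out=` variant the reviewer is hinting at: `torch.matmul` accepts an `out=` tensor, so the product can be written straight into the preallocated `c` (the `task` import is assumed, mirroring the other kernels):

```python
import torch
from task import input_t, output_t  # assumed import


def custom_kernel(data: input_t) -> output_t:
    a, b, c = data
    # Write the product directly into the preallocated output tensor.
    torch.matmul(a, b, out=c)
    return c
```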
@@ -3,5 +3,6 @@


 def custom_kernel(data: input_t) -> output_t:
-    a, b = data
-    return (a.to(torch.bfloat16) @ b.to(torch.bfloat16)).to(a.dtype)
+    a, b, c = data
+    c = (a.to(torch.bfloat16) @ b.to(torch.bfloat16)).to(a.dtype)
+    return c

Review comment: out=
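For the bfloat16 variant, a direct `out=c` is awkward because the matmul result is bfloat16 while `c` presumably is not; a sketch that still fills the preallocated buffer in place could copy the cast result instead (an assumption about the intended dtypes, not something stated in the PR):

```python
import torch
from task import input_t, output_t  # assumed import


def custom_kernel(data: input_t) -> output_t:
    a, b, c = data
    # Cast the bfloat16 product back and copy it into the preallocated c.
    c.copy_((a.to(torch.bfloat16) @ b.to(torch.bfloat16)).to(c.dtype))
    return c
```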
@@ -1,5 +1,6 @@
 from task import input_t, output_t

 def custom_kernel(data: input_t) -> output_t:
-    a, b = data
-    return a @ b
+    a, b, c = data
+    c = a @ b
+    return c

Review comment: out=
@@ -11,7 +11,9 @@ def ref_kernel(data: input_t) -> output_t:
     Returns:
         Tensor containing the inclusive prefix sum
     """
-    return torch.cumsum(data.to(torch.float64), dim=0).to(torch.float64)
+    data, output = data
+    output = torch.cumsum(data.to(torch.float64), dim=0).to(torch.float64)
+    return output


 def generate_input(size: int, seed: int) -> input_t:
@@ -20,18 +22,22 @@ def generate_input(size: int, seed: int) -> input_t:
     Returns:
         Tensor to compute prefix sum on
     """
-    gen = torch.Generator(device='cuda')
+    gen = torch.Generator(device="cuda")
     gen.manual_seed(seed)
-    return torch.randn(size, device='cuda', dtype=torch.float32, generator=gen).contiguous()
+    x = torch.randn(
+        size, device="cuda", dtype=torch.float32, generator=gen
+    ).contiguous()
+    y = torch.empty(size, device="cuda", dtype=torch.float32).contiguous()
+    return (x, y)


 # This algorithm is very sensitive to the tolerance and the error is magnified by the input size
 # The tolerance is scaled by the square root of the input size
 def check_implementation(data: input_t, output: output_t) -> str:
     # Then get the size for scaling the tolerance
     n = data.numel()
-    scale_factor = n ** 0.5  # Square root of input size
+    scale_factor = n**0.5  # Square root of input size
     rtol = 1e-5 * scale_factor
     atol = 1e-5 * scale_factor

Review comment on the `torch.cumsum` line: out=
Review comment: just return output?
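Putting the two comments above together, a sketch of the reference kernel that writes into the preallocated buffer and returns it. `Tensor.copy_` is used rather than `out=` because the accumulation is done in float64 while the buffer from `generate_input` is float32; whether that downcast is acceptable for the reference is an open question, not something this PR states.

```python
import torch
from task import input_t, output_t  # assumed import


def ref_kernel(data: input_t) -> output_t:
    x, output = data
    # Accumulate in float64 for accuracy, then copy (downcasting) into the
    # preallocated float32 buffer and return that same tensor.
    output.copy_(torch.cumsum(x.to(torch.float64), dim=0))
    return output
```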