diff --git a/csrc/debug_print.cu b/csrc/debug_print.cu index 5527ba4..40ccbc5 100644 --- a/csrc/debug_print.cu +++ b/csrc/debug_print.cu @@ -15,116 +15,176 @@ TYPE, NAME, \ AT_DISPATCH_CASE_FLOATING_AND_REDUCED_FLOATING_TYPES(__VA_ARGS__)) -template -__global__ void PrintFloatTensor1D(float_t *__restrict__ x, - const size_t stride_0, const size_t n, - const bool print_ptr) { - if (print_ptr) { - printf("addr: %lld\n", x); +__device__ void PrintCommon(void* x, const char* name_ptr, const bool print_ptr) { + if (name_ptr != nullptr) { + printf("name=%s, ", name_ptr); } - for (size_t i = 0; i < n; ++i) { - printf("%.4f, ", float(x[i * stride_0])); + if (print_ptr) { + printf("addr=%lld, ", x); } - printf("\n"); } -template -__global__ void PrintIntTensor1D(int_t *__restrict__ x, const size_t stride_0, - const size_t n, const bool print_ptr) { - if (print_ptr) { - printf("addr: %lld\n", x); - } - for (size_t i = 0; i < n; ++i) { - printf("%lld, ", int64_t(x[i * stride_0])); - } - printf("\n"); +template +struct is_my_floating_point : std::is_floating_point {}; + +template <> +struct is_my_floating_point : std::true_type {}; + +template <> +struct is_my_floating_point : std::true_type {}; + +template +struct always_false : std::false_type {}; + +template +__device__ void PrintElem(scalar_t value) { + if constexpr (is_my_floating_point::value) { + printf("%.4f, ", float(value)); + } else if constexpr (std::is_integral::value) { + printf("%lld, ", static_cast(value)); + } else { + static_assert(always_false::value, "PrintElem: unsupported scalar_t type"); + } } template -__global__ void PrintFloatTensor2D(float_t *__restrict__ x, - const size_t shape_0, const size_t stride_1, - const size_t stride_0, const size_t n, - const bool print_ptr) { - if (print_ptr) { - printf("addr: %lld\n", x); +__global__ void PrintTensor1D( + float_t *__restrict__ x, + const size_t shape_0, + const size_t stride_0, + const char* name_ptr, const bool print_ptr, const bool print_shape +) { + PrintCommon(x, name_ptr, print_ptr); + if (print_shape) { + printf("shape=(%d), stride=(%d)", (int) shape_0, (int) stride_0); } - for (size_t i = 0; i < n; ++i) { - printf("%.4f, ", - float(x[(i / shape_0) * stride_1 + (i % shape_0) * stride_0])); + printf("\n["); + for (size_t index_0 = 0; index_0 < shape_0; ++index_0) { + PrintElem(x[index_0 * stride_0]); } - printf("\n"); + printf("]\n"); } -template -__global__ void PrintIntTensor2D(int_t *__restrict__ x, const size_t shape_0, - const size_t stride_1, const size_t stride_0, - const size_t n, const bool print_ptr) { - if (print_ptr) { - printf("addr: %lld\n", x); +template +__global__ void PrintTensor2D( + float_t *__restrict__ x, + const size_t shape_0, const size_t shape_1, + const size_t stride_0, const size_t stride_1, + const char* name_ptr, const bool print_ptr, const bool print_shape +) { + PrintCommon(x, name_ptr, print_ptr); + if (print_shape) { + printf("shape=(%d, %d), stride=(%d, %d)", (int) shape_0, (int) shape_1, (int) stride_0, (int) stride_1); } - for (size_t i = 0; i < n; ++i) { - printf("%lld, ", - int64_t(x[(i / shape_0) * stride_1 + (i % shape_0) * stride_0])); + printf("\n["); + for (size_t index_0 = 0; index_0 < shape_0; ++index_0) { + printf("["); + for (size_t index_1 = 0; index_1 < shape_1; ++index_1) { + PrintElem(x[index_0 * stride_0 + index_1 * stride_1]); + } + printf("], "); } - printf("\n"); + printf("]\n"); } template -__global__ void PrintFloatTensor3D(float_t *__restrict__ x, - const size_t shape_1, const size_t shape_0, - const size_t stride_2, const size_t stride_1, - const size_t stride_0, const size_t n, - const bool print_ptr) { - if (print_ptr) { - printf("addr: %lld\n", x); +__global__ void PrintTensor3D( + float_t *__restrict__ x, + const size_t shape_0, const size_t shape_1, const size_t shape_2, + const size_t stride_0, const size_t stride_1, const size_t stride_2, + const char* name_ptr, const bool print_ptr, const bool print_shape +) { + PrintCommon(x, name_ptr, print_ptr); + if (print_shape) { + printf("shape=(%d, %d, %d), stride=(%d, %d, %d)", (int) shape_0, (int) shape_1, (int) shape_2, (int) stride_0, (int) stride_1, (int) stride_2); } - for (size_t i = 0; i < n; ++i) { - printf("%.4f, ", float(x[(i / shape_0 / shape_1) * stride_2 + - ((i / shape_0) % shape_1) * stride_1 + - (i % shape_0) * stride_0])); + printf("\n["); + for (size_t index_0 = 0; index_0 < shape_0; ++index_0) { + printf("["); + for (size_t index_1 = 0; index_1 < shape_1; ++index_1) { + printf("["); + for (size_t index_2 = 0; index_2 < shape_2; ++index_2) { + PrintElem(x[index_0 * stride_0 + index_1 * stride_1 + index_2 * stride_2]); + } + printf("], "); + } + printf("], "); } - printf("\n"); + printf("]\n"); } -template -__global__ void PrintIntTensor3D(int_t *__restrict__ x, const size_t shape_1, - const size_t shape_0, const size_t stride_2, - const size_t stride_1, const size_t stride_0, - const size_t n, const bool print_ptr) { - if (print_ptr) { - printf("addr: %lld\n", x); +template +__global__ void PrintTensor4D( + float_t *__restrict__ x, + const size_t shape_0, const size_t shape_1, const size_t shape_2, const size_t shape_3, + const size_t stride_0, const size_t stride_1, const size_t stride_2, const size_t stride_3, + const char* name_ptr, const bool print_ptr, const bool print_shape +) { + PrintCommon(x, name_ptr, print_ptr); + if (print_shape) { + printf("shape=(%d, %d, %d, %d), stride=(%d, %d, %d, %d)", (int) shape_0, (int) shape_1, (int) shape_2, (int) shape_3, (int) stride_0, (int) stride_1, (int) stride_2, (int) stride_3); } - for (size_t i = 0; i < n; ++i) { - printf("%lld, ", int64_t(x[(i / shape_0 / shape_1) * stride_2 + - ((i / shape_0) % shape_1) * stride_1 + - (i % shape_0) * stride_0])); + printf("\n["); + for (size_t index_0 = 0; index_0 < shape_0; ++index_0) { + printf("["); + for (size_t index_1 = 0; index_1 < shape_1; ++index_1) { + printf("["); + for (size_t index_2 = 0; index_2 < shape_2; ++index_2) { + printf("["); + for (size_t index_3 = 0; index_3 < shape_3; ++index_3) { + PrintElem(x[index_0 * stride_0 + index_1 * stride_1 + index_2 * stride_2 + index_3 * stride_3]); + } + printf("], "); + } + printf("], "); + } + printf("], "); } - printf("\n"); + printf("]\n"); } -void PrintTensor(torch::Tensor x, bool print_ptr) { +void PrintTensor(torch::Tensor x, std::optional name_buffer, bool print_ptr, bool print_shape) { cudaStream_t stream = c10::cuda::getCurrentCUDAStream(x.device().index()); TORCH_CHECK(x.is_cuda(), "The input tensor should be a CUDA tensor"); + + const char* name_ptr = name_buffer.has_value() ? reinterpret_cast(name_buffer->data_ptr()) : nullptr; + if (x.is_floating_point()) { if (x.dim() == 1) { AT_DISPATCH_FLOATING_AND_REDUCED_FLOATING_TYPES( - x.scalar_type(), "PrintFloatTensor1D", ([&] { - PrintFloatTensor1D<<<1, 1, 0, stream>>>( - x.data_ptr(), x.stride(0), x.numel(), print_ptr); + x.scalar_type(), "PrintTensor1D", ([&] { + PrintTensor1D<<<1, 1, 0, stream>>>( + x.data_ptr(), + x.size(0), x.stride(0), + name_ptr, print_ptr, print_shape + ); })); } else if (x.dim() == 2) { AT_DISPATCH_FLOATING_AND_REDUCED_FLOATING_TYPES( - x.scalar_type(), "PrintFloatTensor2D", ([&] { - PrintFloatTensor2D<<<1, 1, 0, stream>>>( - x.data_ptr(), x.size(1), x.stride(0), x.stride(1), - x.numel(), print_ptr); + x.scalar_type(), "PrintTensor2D", ([&] { + PrintTensor2D<<<1, 1, 0, stream>>>( + x.data_ptr(), + x.size(0), x.size(1), x.stride(0), x.stride(1), + name_ptr, print_ptr, print_shape + ); })); } else if (x.dim() == 3) { AT_DISPATCH_FLOATING_AND_REDUCED_FLOATING_TYPES( - x.scalar_type(), "PrintFloatTensor3D", ([&] { - PrintFloatTensor3D<<<1, 1, 0, stream>>>( - x.data_ptr(), x.size(1), x.size(2), x.stride(0), - x.stride(1), x.stride(2), x.numel(), print_ptr); + x.scalar_type(), "PrintTensor3D", ([&] { + PrintTensor3D<<<1, 1, 0, stream>>>( + x.data_ptr(), + x.size(0), x.size(1), x.size(2), x.stride(0), x.stride(1), x.stride(2), + name_ptr, print_ptr, print_shape + ); + })); + } else if (x.dim() == 4) { + AT_DISPATCH_FLOATING_AND_REDUCED_FLOATING_TYPES( + x.scalar_type(), "PrintTensor4D", ([&] { + PrintTensor4D<<<1, 1, 0, stream>>>( + x.data_ptr(), + x.size(0), x.size(1), x.size(2), x.size(3), x.stride(0), x.stride(1), x.stride(2), x.stride(3), + name_ptr, print_ptr, print_shape + ); })); } else { // NOTE(Zihao): I'm just too lazy to do this, codegen for higher @@ -133,29 +193,41 @@ void PrintTensor(torch::Tensor x, bool print_ptr) { } cudaError_t status = cudaGetLastError(); TORCH_CHECK(status == cudaSuccess, - "PrintFloatTensor failed with error " + + "PrintTensor failed with error " + std::string(cudaGetErrorString(status))); } else { if (x.dim() == 1) { - AT_DISPATCH_INTEGRAL_TYPES(x.scalar_type(), "PrintIntTensor1D", ([&] { - PrintIntTensor1D<<<1, 1, 0, stream>>>( - x.data_ptr(), x.stride(0), - x.numel(), print_ptr); - })); + AT_DISPATCH_INTEGRAL_TYPES(x.scalar_type(), "PrintTensor1D", ([&] { + PrintTensor1D<<<1, 1, 0, stream>>>( + x.data_ptr(), + x.size(0), x.stride(0), + name_ptr, print_ptr, print_shape + ); + })); } else if (x.dim() == 2) { - AT_DISPATCH_INTEGRAL_TYPES(x.scalar_type(), "PrintIntTensor2D", ([&] { - PrintIntTensor2D<<<1, 1, 0, stream>>>( - x.data_ptr(), x.size(1), - x.stride(0), x.stride(1), x.numel(), - print_ptr); - })); + AT_DISPATCH_INTEGRAL_TYPES(x.scalar_type(), "PrintTensor2D", ([&] { + PrintTensor2D<<<1, 1, 0, stream>>>( + x.data_ptr(), + x.size(0), x.size(1), x.stride(0), x.stride(1), + name_ptr, print_ptr, print_shape + ); + })); } else if (x.dim() == 3) { - AT_DISPATCH_INTEGRAL_TYPES(x.scalar_type(), "PrintIntTensor3D", ([&] { - PrintIntTensor3D<<<1, 1, 0, stream>>>( - x.data_ptr(), x.size(1), - x.size(2), x.stride(0), x.stride(1), - x.stride(2), x.numel(), print_ptr); - })); + AT_DISPATCH_INTEGRAL_TYPES(x.scalar_type(), "PrintTensor3D", ([&] { + PrintTensor3D<<<1, 1, 0, stream>>>( + x.data_ptr(), + x.size(0), x.size(1), x.size(2), x.stride(0), x.stride(1), x.stride(2), + name_ptr, print_ptr, print_shape + ); + })); + } else if (x.dim() == 4) { + AT_DISPATCH_INTEGRAL_TYPES(x.scalar_type(), "PrintTensor4D", ([&] { + PrintTensor4D<<<1, 1, 0, stream>>>( + x.data_ptr(), + x.size(0), x.size(1), x.size(2), x.size(3), x.stride(0), x.stride(1), x.stride(2), x.stride(3), + name_ptr, print_ptr, print_shape + ); + })); } else { // NOTE(Zihao): I'm just too lazy to do this, codegen for higher // dimensions should be a better idea @@ -163,7 +235,7 @@ void PrintTensor(torch::Tensor x, bool print_ptr) { } cudaError_t status = cudaGetLastError(); TORCH_CHECK(status == cudaSuccess, - "PrintIntTensor failed with error " + + "PrintTensor failed with error " + std::string(cudaGetErrorString(status))); } } diff --git a/debug_print/__init__.py b/debug_print/__init__.py index addae96..d6b71ed 100644 --- a/debug_print/__init__.py +++ b/debug_print/__init__.py @@ -1,6 +1,81 @@ +from dataclasses import dataclass +from typing import Dict, Optional, List + import torch -from ._kernels import print_tensor as print_tensor_kernel +from ._kernels import print_tensor as _print_tensor_kernel + + +class _Buffer: + def __init__(self, device_index: int): + self._tensor = torch.zeros((10_000_000,), dtype=torch.uint8, device=f"cuda:{device_index}") + self._used_len = 0 + + def allocate(self, size: int): + output = self._tensor[self._used_len: self._used_len + size] + self._used_len += size + assert self._used_len <= len(self._tensor) + return output + + +@dataclass +class _CopyTask: + src: torch.Tensor + dst: torch.Tensor + + def execute(self): + self.dst.copy_(self.src) + + +class _DebugPrinter: + def __init__(self, device_id: Optional[int]): + if device_id is None: + device_id = torch.cuda.current_device() + + # Can be optimized + self._buffers: Dict[int, _Buffer] = {device_id: _Buffer(device_index=device_id)} + self._pending_copy_tasks: List[_CopyTask] = [] + + def post_initialize(self): + for copy_task in self._pending_copy_tasks: + copy_task.execute() + self._pending_copy_tasks.clear() + + def __call__(self, x: torch.Tensor, name: str, print_ptr: bool, print_shape: bool): + assert x.is_cuda, f"{x.device} must be on cuda" + name_buffer_gpu = self._compute_name_buffer_gpu(name=name, device_index=x.device.index) + _print_tensor_kernel(x, name_buffer_gpu, print_ptr, print_shape) + + def _compute_name_buffer_gpu(self, name: str, device_index: int): + if len(name) == 0: + return None + + name_bytes = name.encode("utf-8") + name_buffer_gpu = self._buffers[device_index].allocate(len(name_bytes) + 1) + name_cpu = torch.tensor(list(name_bytes) + [0], dtype=torch.uint8, device="cpu") + copy_task = _CopyTask(src=name_cpu, dst=name_buffer_gpu) + + if torch.cuda.is_current_stream_capturing(): + self._pending_copy_tasks.append(copy_task) + else: + copy_task.execute() + + return name_buffer_gpu + + +_printer: Optional[_DebugPrinter] = None + + +def initialize(device_id: Optional[int] = None): + global _printer + if _printer is not None: + print("debug_print.initialize skip since already initialized") + return + _printer = _DebugPrinter(device_id=device_id) + + +def post_initialize(): + _printer.post_initialize() -def print_tensor(x: torch.Tensor, print_ptr: bool = False): - print_tensor_kernel(x, print_ptr) +def print_tensor(x: torch.Tensor, name: str = "", print_ptr: bool = True, print_shape: bool = True): + _printer(x=x, name=name, print_ptr=print_ptr, print_shape=print_shape) diff --git a/example.py b/example.py index f731d77..b24ba69 100644 --- a/example.py +++ b/example.py @@ -1,7 +1,9 @@ import torch import debug_print +debug_print.initialize() +print("demo without cuda graph...") x = torch.rand(3, 4, 5).to(0) debug_print.print_tensor(x) debug_print.print_tensor(x[..., 0:3]) @@ -9,6 +11,27 @@ debug_print.print_tensor(x[..., 0]) debug_print.print_tensor(x[0:1, 1:3, 0:4]) +print("demo for all types...") +debug_print.print_tensor(torch.tensor([3, 4, 5], dtype=torch.int32, device="cuda:0"), name="for int32", print_shape=True, print_ptr=True) +debug_print.print_tensor(torch.tensor([3, 4, 5], dtype=torch.int64, device="cuda:0"), name="for int64", print_shape=True, print_ptr=True) +debug_print.print_tensor(torch.tensor([1.5, 2.5, 3.5], dtype=torch.float, device="cuda:0"), name="for float", print_shape=True, print_ptr=True) + +print("demo for all dims...") +debug_print.print_tensor(torch.tensor([3, 4, 5], dtype=torch.int32, device="cuda:0"), name="for 1D", print_shape=True, print_ptr=True) +debug_print.print_tensor(torch.tensor([[1, 2, 3], [3, 4, 5]], dtype=torch.int32, device="cuda:0"), name="for 2D", print_shape=True, print_ptr=True) +debug_print.print_tensor( + torch.tensor([[[1, 2, 3], [3, 4, 5]], [[10, 20, 30], [30, 40, 50]]], dtype=torch.int32, device="cuda:0"), + name="for 3D", print_shape=True, print_ptr=True) +debug_print.print_tensor( + torch.tensor( + [ + [[[1, 2, 3], [3, 4, 5]], [[10, 20, 30], [30, 40, 50]]], + [[[-1, -2, -3], [-3, -4, -5]], [[-10, -20, -30], [-30, -40, -50]]], + ], + dtype=torch.int32, device="cuda:0"), + name="for 4D", print_shape=True, print_ptr=True) + +print("start warmup...") s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) x = torch.empty(2, 2).half().to(0) @@ -19,7 +42,7 @@ z1 = z @ y z2 = z1 @ y - +print("start graph capture...") g = torch.cuda.CUDAGraph() with torch.cuda.graph(g, stream=s): debug_print.print_tensor(x) @@ -27,9 +50,11 @@ z = x @ y debug_print.print_tensor(z) z1 = z @ y - debug_print.print_tensor(z1[..., 0]) + debug_print.print_tensor(z1[..., 0], name="This is name for part of z1", print_shape=True, print_ptr=True) z2 = z1 @ y - debug_print.print_tensor(z2) + debug_print.print_tensor(z2, name="This is name for z2") + +debug_print.post_initialize() x.copy_(torch.randn(2, 2)) y.copy_(torch.ones(2, 2)) diff --git a/setup.py b/setup.py index dedb087..2ba3a10 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,7 @@ setup( name="debug_print", version="0.0.2", + packages=['debug_print'], ext_modules=[ CUDAExtension( name="debug_print._kernels",