diff --git a/slangpy/core/function.py b/slangpy/core/function.py index c18c3876..2bd52fae 100644 --- a/slangpy/core/function.py +++ b/slangpy/core/function.py @@ -10,7 +10,15 @@ ) from slangpy.reflection import SlangFunction, SlangType -from slangpy import CommandEncoder, TypeConformance, uint3, Logger, NativeHandle, NativeHandleType +from slangpy import ( + CommandEncoder, + QueryPool, + TypeConformance, + uint3, + Logger, + NativeHandle, + NativeHandleType, +) from slangpy.slangpy import Shape from slangpy.bindings.typeregistry import PYTHON_SIGNATURES @@ -139,6 +147,12 @@ def cuda_stream(self, stream: NativeHandle) -> "FunctionNode": """ return FunctionNodeCUDAStream(self, stream) + def write_timestamps(self, write_timestamps: tuple[QueryPool, int, int]) -> "FunctionNode": + """ + Specify a query pool and and a before/after query index to write timestamps before/after the dispatch. + """ + return FunctionNodeWriteTimestamps(self, write_timestamps) + def constants(self, constants: dict[str, Any]): """ Specify link time constants that should be set when the function is compiled. These are @@ -427,6 +441,21 @@ def _populate_build_info(self, info: FunctionBuildInfo): info.options["cuda_stream"] = self.stream +class FunctionNodeWriteTimestamps(FunctionNode): + def __init__( + self, parent: NativeFunctionNode, write_timestamps: tuple[QueryPool, int, int] + ) -> None: + super().__init__(parent, FunctionNodeType.write_timestamps, write_timestamps) + self.slangpy_signature = str(write_timestamps) + + @property + def write_timestamps(self): + return cast(tuple[QueryPool, int, int], self._native_data) + + def _populate_build_info(self, info: FunctionBuildInfo): + info.options["write_timestamps"] = self.write_timestamps + + class FunctionNodeConstants(FunctionNode): def __init__(self, parent: NativeFunctionNode, constants: dict[str, Any]) -> None: super().__init__(parent, FunctionNodeType.kernelgen, constants) diff --git a/slangpy/testing/benchmark/fixtures.py b/slangpy/testing/benchmark/fixtures.py index ce666b7e..13e50b0b 100644 --- a/slangpy/testing/benchmark/fixtures.py +++ b/slangpy/testing/benchmark/fixtures.py @@ -93,9 +93,10 @@ def __call__( query_pool = device.create_query_pool(type=spy.QueryType.timestamp, count=iterations * 2) for i in range(iterations): command_encoder = device.create_command_encoder() - command_encoder.write_timestamp(query_pool, i * 2) - function(**kwargs, _append_to=command_encoder) - command_encoder.write_timestamp(query_pool, i * 2 + 1) + function.write_timestamps((query_pool, i * 2, i * 2 + 1))( + **kwargs, + _append_to=command_encoder, + ) device.submit_command_buffer(command_encoder.finish()) device.wait() queries = np.array(query_pool.get_results(0, iterations * 2)) @@ -134,9 +135,14 @@ def __call__( query_pool = device.create_query_pool(type=spy.QueryType.timestamp, count=iterations * 2) for i in range(iterations): command_encoder = device.create_command_encoder() - command_encoder.write_timestamp(query_pool, i * 2) - kernel.dispatch(thread_count, command_encoder=command_encoder, **kwargs) - command_encoder.write_timestamp(query_pool, i * 2 + 1) + kernel.dispatch( + thread_count, + command_encoder=command_encoder, + query_pool=query_pool, + query_index_before=i * 2, + query_index_after=i * 2 + 1, + **kwargs, + ) device.submit_command_buffer(command_encoder.finish()) device.wait() queries = np.array(query_pool.get_results(0, iterations * 2)) diff --git a/src/slangpy_ext/utils/slangpy.cpp b/src/slangpy_ext/utils/slangpy.cpp index 1b98e4be..c01fc633 100644 --- a/src/slangpy_ext/utils/slangpy.cpp +++ b/src/slangpy_ext/utils/slangpy.cpp @@ -10,6 +10,7 @@ #include "sgl/utils/slangpy.h" #include "sgl/device/device.h" #include "sgl/device/kernel.h" +#include "sgl/device/query.h" #include "sgl/device/command.h" #include "sgl/stl/bit.h" // Replace with when available on all platforms. @@ -649,10 +650,25 @@ nb::object NativeCallData::exec( if (command_encoder == nullptr) { // If we are not appending to a command encoder, we can dispatch directly. - m_kernel->dispatch(uint3(total_threads, 1, 1), bind_vars, CommandQueueType::graphics, cuda_stream); + m_kernel->dispatch( + uint3(total_threads, 1, 1), + bind_vars, + CommandQueueType::graphics, + cuda_stream, + opts->get_query_pool(), + opts->get_query_before_index(), + opts->get_query_after_index() + ); } else { // If we are appending to a command encoder, we need to use the command encoder. - m_kernel->dispatch(uint3(total_threads, 1, 1), bind_vars, command_encoder); + m_kernel->dispatch( + uint3(total_threads, 1, 1), + bind_vars, + command_encoder, + opts->get_query_pool(), + opts->get_query_before_index(), + opts->get_query_after_index() + ); } // If command_buffer is not null, return early. diff --git a/src/slangpy_ext/utils/slangpy.h b/src/slangpy_ext/utils/slangpy.h index 81f75477..a96ae4b5 100644 --- a/src/slangpy_ext/utils/slangpy.h +++ b/src/slangpy_ext/utils/slangpy.h @@ -598,10 +598,25 @@ class NativeCallRuntimeOptions : Object { /// Set the CUDA stream. void set_cuda_stream(NativeHandle cuda_stream) { m_cuda_stream = cuda_stream; } + QueryPool* get_query_pool() const { return m_query_pool.get(); } + + void set_query_pool(QueryPool* query_pool) { m_query_pool = ref(query_pool); } + + uint32_t get_query_before_index() const { return m_query_before_index; } + + void set_query_before_index(uint32_t query_before_index) { m_query_before_index = query_before_index; } + + uint32_t get_query_after_index() const { return m_query_after_index; } + + void set_query_after_index(uint32_t query_after_index) { m_query_after_index = query_after_index; } + private: nb::list m_uniforms; nb::object m_this{nb::none()}; NativeHandle m_cuda_stream; + ref m_query_pool; + uint32_t m_query_before_index{0}; + uint32_t m_query_after_index{0}; }; /// Defines the common logging functions for a given log level. diff --git a/src/slangpy_ext/utils/slangpyfunction.h b/src/slangpy_ext/utils/slangpyfunction.h index 309231da..61176c36 100644 --- a/src/slangpy_ext/utils/slangpyfunction.h +++ b/src/slangpy_ext/utils/slangpyfunction.h @@ -14,19 +14,21 @@ #include "sgl/device/fwd.h" #include "sgl/device/resource.h" +#include "sgl/device/query.h" #include "utils/slangpy.h" namespace sgl::slangpy { -enum class FunctionNodeType { unknown, uniforms, kernelgen, this_, cuda_stream }; +enum class FunctionNodeType { unknown, uniforms, kernelgen, this_, cuda_stream, write_timestamps }; SGL_ENUM_INFO( FunctionNodeType, {{FunctionNodeType::unknown, "unknown"}, {FunctionNodeType::uniforms, "uniforms"}, {FunctionNodeType::kernelgen, "kernelgen"}, {FunctionNodeType::this_, "this"}, - {FunctionNodeType::cuda_stream, "cuda_stream"}} + {FunctionNodeType::cuda_stream, "cuda_stream"}, + {FunctionNodeType::write_timestamps, "write_timestamps"}} ); SGL_ENUM_REGISTER(FunctionNodeType); @@ -72,6 +74,13 @@ class NativeFunctionNode : NativeObject { case sgl::slangpy::FunctionNodeType::cuda_stream: options->set_cuda_stream(nb::cast(m_data)); break; + case sgl::slangpy::FunctionNodeType::write_timestamps: { + nb::tuple t = nb::cast(m_data); + options->set_query_pool(nb::cast(t[0])); + options->set_query_before_index(nb::cast(t[1])); + options->set_query_after_index(nb::cast(t[2])); + break; + } default: break; }