Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion slangpy/core/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@
)

from slangpy.reflection import SlangFunction, SlangType
from slangpy import CommandEncoder, TypeConformance, uint3, Logger, NativeHandle, NativeHandleType
from slangpy import (
CommandEncoder,
QueryPool,
TypeConformance,
uint3,
Logger,
NativeHandle,
NativeHandleType,
)
from slangpy.slangpy import Shape
from slangpy.bindings.typeregistry import PYTHON_SIGNATURES

Expand Down Expand Up @@ -139,6 +147,12 @@ def cuda_stream(self, stream: NativeHandle) -> "FunctionNode":
"""
return FunctionNodeCUDAStream(self, stream)

def write_timestamps(self, write_timestamps: tuple[QueryPool, int, int]) -> "FunctionNode":
"""
Specify a query pool and and a before/after query index to write timestamps before/after the dispatch.
"""
return FunctionNodeWriteTimestamps(self, write_timestamps)

def constants(self, constants: dict[str, Any]):
"""
Specify link time constants that should be set when the function is compiled. These are
Expand Down Expand Up @@ -427,6 +441,21 @@ def _populate_build_info(self, info: FunctionBuildInfo):
info.options["cuda_stream"] = self.stream


class FunctionNodeWriteTimestamps(FunctionNode):
def __init__(
self, parent: NativeFunctionNode, write_timestamps: tuple[QueryPool, int, int]
) -> None:
super().__init__(parent, FunctionNodeType.write_timestamps, write_timestamps)
self.slangpy_signature = str(write_timestamps)

@property
def write_timestamps(self):
return cast(tuple[QueryPool, int, int], self._native_data)

def _populate_build_info(self, info: FunctionBuildInfo):
info.options["write_timestamps"] = self.write_timestamps


class FunctionNodeConstants(FunctionNode):
def __init__(self, parent: NativeFunctionNode, constants: dict[str, Any]) -> None:
super().__init__(parent, FunctionNodeType.kernelgen, constants)
Expand Down
18 changes: 12 additions & 6 deletions slangpy/testing/benchmark/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,10 @@ def __call__(
query_pool = device.create_query_pool(type=spy.QueryType.timestamp, count=iterations * 2)
for i in range(iterations):
command_encoder = device.create_command_encoder()
command_encoder.write_timestamp(query_pool, i * 2)
function(**kwargs, _append_to=command_encoder)
command_encoder.write_timestamp(query_pool, i * 2 + 1)
function.write_timestamps((query_pool, i * 2, i * 2 + 1))(
**kwargs,
_append_to=command_encoder,
)
device.submit_command_buffer(command_encoder.finish())
device.wait()
queries = np.array(query_pool.get_results(0, iterations * 2))
Expand Down Expand Up @@ -134,9 +135,14 @@ def __call__(
query_pool = device.create_query_pool(type=spy.QueryType.timestamp, count=iterations * 2)
for i in range(iterations):
command_encoder = device.create_command_encoder()
command_encoder.write_timestamp(query_pool, i * 2)
kernel.dispatch(thread_count, command_encoder=command_encoder, **kwargs)
command_encoder.write_timestamp(query_pool, i * 2 + 1)
kernel.dispatch(
thread_count,
command_encoder=command_encoder,
query_pool=query_pool,
query_index_before=i * 2,
query_index_after=i * 2 + 1,
**kwargs,
)
device.submit_command_buffer(command_encoder.finish())
device.wait()
queries = np.array(query_pool.get_results(0, iterations * 2))
Expand Down
20 changes: 18 additions & 2 deletions src/slangpy_ext/utils/slangpy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "sgl/utils/slangpy.h"
#include "sgl/device/device.h"
#include "sgl/device/kernel.h"
#include "sgl/device/query.h"
#include "sgl/device/command.h"
#include "sgl/stl/bit.h" // Replace with <bit> when available on all platforms.

Expand Down Expand Up @@ -649,10 +650,25 @@ nb::object NativeCallData::exec(

if (command_encoder == nullptr) {
// If we are not appending to a command encoder, we can dispatch directly.
m_kernel->dispatch(uint3(total_threads, 1, 1), bind_vars, CommandQueueType::graphics, cuda_stream);
m_kernel->dispatch(
uint3(total_threads, 1, 1),
bind_vars,
CommandQueueType::graphics,
cuda_stream,
opts->get_query_pool(),
opts->get_query_before_index(),
opts->get_query_after_index()
);
} else {
// If we are appending to a command encoder, we need to use the command encoder.
m_kernel->dispatch(uint3(total_threads, 1, 1), bind_vars, command_encoder);
m_kernel->dispatch(
uint3(total_threads, 1, 1),
bind_vars,
command_encoder,
opts->get_query_pool(),
opts->get_query_before_index(),
opts->get_query_after_index()
);
}

// If command_buffer is not null, return early.
Expand Down
15 changes: 15 additions & 0 deletions src/slangpy_ext/utils/slangpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -598,10 +598,25 @@ class NativeCallRuntimeOptions : Object {
/// Set the CUDA stream.
void set_cuda_stream(NativeHandle cuda_stream) { m_cuda_stream = cuda_stream; }

QueryPool* get_query_pool() const { return m_query_pool.get(); }

void set_query_pool(QueryPool* query_pool) { m_query_pool = ref(query_pool); }

uint32_t get_query_before_index() const { return m_query_before_index; }

void set_query_before_index(uint32_t query_before_index) { m_query_before_index = query_before_index; }

uint32_t get_query_after_index() const { return m_query_after_index; }

void set_query_after_index(uint32_t query_after_index) { m_query_after_index = query_after_index; }

private:
nb::list m_uniforms;
nb::object m_this{nb::none()};
NativeHandle m_cuda_stream;
ref<QueryPool> m_query_pool;
uint32_t m_query_before_index{0};
uint32_t m_query_after_index{0};
};

/// Defines the common logging functions for a given log level.
Expand Down
13 changes: 11 additions & 2 deletions src/slangpy_ext/utils/slangpyfunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,21 @@

#include "sgl/device/fwd.h"
#include "sgl/device/resource.h"
#include "sgl/device/query.h"

#include "utils/slangpy.h"

namespace sgl::slangpy {

enum class FunctionNodeType { unknown, uniforms, kernelgen, this_, cuda_stream };
enum class FunctionNodeType { unknown, uniforms, kernelgen, this_, cuda_stream, write_timestamps };
SGL_ENUM_INFO(
FunctionNodeType,
{{FunctionNodeType::unknown, "unknown"},
{FunctionNodeType::uniforms, "uniforms"},
{FunctionNodeType::kernelgen, "kernelgen"},
{FunctionNodeType::this_, "this"},
{FunctionNodeType::cuda_stream, "cuda_stream"}}
{FunctionNodeType::cuda_stream, "cuda_stream"},
{FunctionNodeType::write_timestamps, "write_timestamps"}}
);
SGL_ENUM_REGISTER(FunctionNodeType);

Expand Down Expand Up @@ -72,6 +74,13 @@ class NativeFunctionNode : NativeObject {
case sgl::slangpy::FunctionNodeType::cuda_stream:
options->set_cuda_stream(nb::cast<NativeHandle>(m_data));
break;
case sgl::slangpy::FunctionNodeType::write_timestamps: {
nb::tuple t = nb::cast<nb::tuple>(m_data);
options->set_query_pool(nb::cast<QueryPool*>(t[0]));
options->set_query_before_index(nb::cast<uint32_t>(t[1]));
options->set_query_after_index(nb::cast<uint32_t>(t[2]));
break;
}
default:
break;
}
Expand Down