Skip to content

Commit 1acfd23

Browse files
authored
Add write function for binding (#893)
1 parent 4e475ea commit 1acfd23

File tree

4 files changed

+48
-2
lines changed

4 files changed

+48
-2
lines changed

slangpy/core/dispatchdata.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,15 @@ def dispatch(
218218
**kwargs: dict[str, Any],
219219
) -> None:
220220

221-
# Merge uniforms
221+
# Merge uniforms and collect writer tuples
222222
uniforms: dict[str, Any] = {}
223+
writers: list[tuple] = []
223224
if opts.uniforms is not None:
224225
for u in opts.uniforms:
225226
if isinstance(u, dict):
226227
uniforms.update(u)
228+
elif isinstance(u, tuple):
229+
writers.append(u)
227230
else:
228231
uniforms.update(u(self)) # type: ignore (need to work out native dispatch)
229232
uniforms.update(vars)
@@ -245,6 +248,8 @@ def dispatch(
245248
compute_pass = command_encoder.begin_compute_pass()
246249
cursor = ShaderCursor(compute_pass.bind_pipeline(self.compute_pipeline))
247250
cursor.write(uniforms)
251+
for fn, args, wkwargs in writers:
252+
fn(cursor, *args, **wkwargs)
248253
cursor.find_entry_point(0).write(call_data)
249254
compute_pass.dispatch(thread_count)
250255
compute_pass.end()

slangpy/core/function.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,16 @@ def set(self, *args: Any, **kwargs: Any):
156156
"Set requires either keyword arguments or 1 dictionary / hook argument"
157157
)
158158

159+
def write(self, fn: Callable, *args: Any, **kwargs: Any):
160+
"""
161+
Specify a writer function that receives a ShaderCursor and optional arguments
162+
to write uniforms directly. The function signature should be:
163+
fn(cursor: ShaderCursor, *args, **kwargs)
164+
"""
165+
if not callable(fn):
166+
raise ValueError("write() requires a callable as the first argument")
167+
return FunctionNodeSet(self, (fn, args, kwargs))
168+
159169
def cuda_stream(self, stream: NativeHandle) -> "FunctionNode":
160170
"""
161171
Specify a CUDA stream to use for the function. This is useful for synchronizing with other

slangpy/tests/slangpy_tests/test_sets_and_hooks.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
22

3+
from typing import Any
4+
35
import pytest
46
import numpy as np
57

6-
from slangpy import DeviceType, Module
8+
from slangpy import DeviceType, Module, ShaderCursor
79
from slangpy.types import Tensor
810
from slangpy.testing import helpers
911

@@ -65,5 +67,27 @@ def test_set_with_callback(device_type: DeviceType):
6567
assert np.allclose(res_data, val_data + 10)
6668

6769

70+
@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES)
71+
def test_write(device_type: DeviceType):
72+
m = load_test_module(device_type)
73+
assert m is not None
74+
75+
add_k = m.add_k.as_func()
76+
77+
val = Tensor.empty(m.device, dtype=float, shape=(10,))
78+
val_data = np.zeros(10, dtype=np.float32) # np.random.rand(10).astype(np.float32)
79+
val.copy_from_numpy(val_data)
80+
81+
def writer(cursor: ShaderCursor, *args: Any, **kwargs: Any):
82+
cursor.write({"params": {"k": kwargs["myvalue"]}})
83+
84+
add_k = add_k.write(writer, myvalue=10)
85+
86+
res = add_k(val)
87+
88+
res_data = res.to_numpy().view(dtype=np.float32)
89+
assert np.allclose(res_data, val_data + 10)
90+
91+
6892
if __name__ == "__main__":
6993
pytest.main([__file__, "-v", "-s"])

src/slangpy_ext/utils/slangpy.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,13 @@ nb::object NativeCallData::exec(
933933
for (auto u : uniforms) {
934934
if (nb::isinstance<nb::dict>(u)) {
935935
write_shader_cursor(cursor, nb::cast<nb::dict>(u));
936+
} else if (nb::isinstance<nb::tuple>(u)) {
937+
// Writer tuple: (fn, args, kwargs)
938+
nb::tuple t = nb::cast<nb::tuple>(u);
939+
nb::object fn = t[0];
940+
nb::tuple args = nb::cast<nb::tuple>(t[1]);
941+
nb::dict kwargs = nb::cast<nb::dict>(t[2]);
942+
fn(nb::cast(cursor), *args, **kwargs);
936943
} else {
937944
write_shader_cursor(cursor, nb::cast<nb::dict>(u(this)));
938945
}

0 commit comments

Comments
 (0)