5 changes: 1 addition & 4 deletions slangpy/builtin/array.py
@@ -7,15 +7,12 @@
from slangpy.builtin.value import ValueMarshall
from slangpy.reflection import SlangType, SlangProgramLayout
from slangpy.bindings import (
PYTHON_SIGNATURES,
PYTHON_TYPES,
BindContext,
BoundVariable,
BoundVariableRuntime,
CodeGenBlock,
)
from slangpy import ShaderCursor, ShaderObject
from slangpy.core.native import AccessType, CallContext, NativeValueMarshall, unpack_arg
from slangpy.core.native import AccessType, unpack_arg
import slangpy.reflection as kfr


65 changes: 42 additions & 23 deletions slangpy/builtin/tensor.py
@@ -3,7 +3,7 @@

from typing import Any, Optional, cast

from slangpy.core.native import AccessType, Shape
from slangpy.core.native import AccessType, Shape, CallMode

from slangpy.reflection.reflectiontypes import is_matching_array_type, VectorType
from slangpy.types.tensor import Tensor
@@ -199,18 +199,33 @@ def resolve_type(self, context: BindContext, bound_type: SlangType):
f"to tensor with element type {bound_type.dtype.full_name}"
)

# Atomic tensors are special: they must be passed as-is
if bound_type.name == "AtomicTensor":
return bound_type

return build_tensor_type(
self.layout,
bound_type.dtype,
bound_type.dims,
bound_type.writable,
self.d_in is not None,
self.d_out is not None,
)
# If binding to an interface, we need to decide on the tensor type based on the call mode
if bound_type.name in ("ITensor", "RWTensor"):
if context.call_mode == CallMode.prim:
# In the forward pass, bind a Tensor or RWTensor depending on writability
return build_tensor_type(
self.layout,
bound_type.dtype,
bound_type.dims,
bound_type.writable,
False,
False,
)
else:
# In the backward pass, we have no choice but to bind the full tensor
# type, as we don't have enough context at this point to know whether it
# should be a GradIn/GradOut/GradInOutTensor
return build_tensor_type(
self.layout,
bound_type.dtype,
bound_type.dims,
bound_type.writable,
self.d_in is not None,
self.d_out is not None,
)

# Non-interface types have to be bound to the exact type
return bound_type

# if implicit element casts enabled, allow conversion from type to element type
if context.options["implicit_element_casts"]:
@@ -256,21 +271,25 @@ def resolve_dimensionality(
return self.dims + len(self.slang_element_type.shape) - len(vector_target_type.shape)

def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: BoundVariable):
if isinstance(binding.vector_type, ITensorType):
writable = binding.vector_type.writable
else:
writable = binding.access[0] in (AccessType.write, AccessType.readwrite)

# Atomic tensors are special: they must be passed as-is
if binding.vector_type.name == "AtomicTensor":
type_name = binding.vector_type.full_name
if isinstance(binding.vector_type, ITensorType):
# If binding to a tensor type, we need to use the same basic tensor type. However, as the
# dimensionality may differ, we still need to generate the full tensor type name.
assert binding.vector_type.name not in ("ITensor", "RWTensor")
type_name = (
f"{binding.vector_type.name}<{self.slang_element_type.full_name}, {self.dims}>"
)
else:
# If binding to another type (e.g. vectorizing to a scalar), the tensor type to use
# is based on writability and the existence of gradients.
type_name = build_tensor_name(
self.slang_element_type,
self.dims,
writable,
self.d_in is not None,
self.d_out is not None,
binding.access[0] in (AccessType.write, AccessType.readwrite),
self.d_in is not None
and binding.access[1] in (AccessType.read, AccessType.readwrite),
self.d_out is not None
and binding.access[1] in (AccessType.write, AccessType.readwrite),
)
cgb.type_alias(f"_t_{binding.variable_name}", type_name)

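The resolve_type change above picks a concrete tensor type per call mode: in the forward pass an ITensor/RWTensor binding resolves to a plain Tensor or RWTensor based on writability, while the backward pass falls back to the full gradient-carrying type, and gen_calldata then applies the analogous access-based choice when generating the type alias. Below is a minimal, self-contained Python sketch of that decision logic; TensorBinding, pick_tensor_type, and the mapping of the gradient flags onto the GradIn/GradOut/GradInOutTensor names are illustrative stand-ins, not the real slangpy helpers (the actual code passes these flags to build_tensor_type / build_tensor_name, whose signatures and naming may differ).

from dataclasses import dataclass
from enum import Enum, auto


class CallMode(Enum):
    # Illustrative stand-in for slangpy.core.native.CallMode
    prim = auto()   # forward (primal) pass
    bwds = auto()   # backward pass


@dataclass
class TensorBinding:
    """Hypothetical stand-in for the bound tensor type being resolved."""
    name: str        # e.g. "ITensor", "RWTensor", "AtomicTensor", "Tensor"
    writable: bool
    has_d_in: bool   # the marshall carries input gradients (d_in is not None)
    has_d_out: bool  # the marshall carries output gradients (d_out is not None)


def pick_tensor_type(binding: TensorBinding, call_mode: CallMode) -> str:
    """Mirror of the branching in resolve_type: atomic tensors pass through,
    interface types resolve by call mode, everything else binds exactly."""
    if binding.name == "AtomicTensor":
        return "AtomicTensor"
    if binding.name in ("ITensor", "RWTensor"):
        if call_mode == CallMode.prim:
            # Forward pass: writability alone decides Tensor vs RWTensor.
            return "RWTensor" if binding.writable else "Tensor"
        # Backward pass: bind the full gradient-carrying type, since there is
        # no context here to narrow it down (illustrative name mapping only).
        if binding.has_d_in and binding.has_d_out:
            return "GradInOutTensor"
        if binding.has_d_in:
            return "GradInTensor"
        if binding.has_d_out:
            return "GradOutTensor"
        return "RWTensor" if binding.writable else "Tensor"
    # Non-interface types must be bound exactly as declared.
    return binding.name


b = TensorBinding(name="ITensor", writable=True, has_d_in=True, has_d_out=False)
assert pick_tensor_type(b, CallMode.prim) == "RWTensor"
assert pick_tensor_type(b, CallMode.bwds) == "GradInTensor"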
42 changes: 34 additions & 8 deletions src/slangpy_ext/utils/slangpytensor.cpp
@@ -141,18 +141,44 @@ void NativeTensorMarshall::write_shader_cursor_pre_dispatch(
const ref<NativeTensor>& grad_in = primal->grad_in();
const ref<NativeTensor>& grad_out = primal->grad_out();

if (!has_derivative()) {
write_shader_cursor_fields(context, binding, field, primal, read_back);
} else {
write_shader_cursor_fields(context, binding, field["primal"], primal, read_back);
ShaderCursor primal_field = field.find_field("primal");
if (primal_field.is_valid()) {
// Record these pointers for debug checks
Buffer* bound_primal_buffer = primal->storage().get();
Buffer* bound_grad_in_buffer = nullptr;
Buffer* bound_grad_out_buffer = nullptr;

// Binding to a Tensor object that contains child primal and derivative Tensors.
write_shader_cursor_fields(context, binding, primal_field, primal, read_back);

if (m_d_in) {
SGL_CHECK(grad_in, "Missing required input gradients");
write_shader_cursor_fields(context, binding, field["d_in"], grad_in.get(), read_back);
ShaderCursor d_in_field = field.find_field("d_in");
if (d_in_field.is_valid()) {
bound_grad_in_buffer = grad_in->storage().get();
write_shader_cursor_fields(context, binding, d_in_field, grad_in.get(), read_back);
}
}

if (m_d_out) {
SGL_CHECK(grad_out, "Missing required output gradients");
write_shader_cursor_fields(context, binding, field["d_out"], grad_out.get(), read_back);
ShaderCursor d_out_field = field.find_field("d_out");
if (d_out_field.is_valid()) {
bound_grad_out_buffer = grad_out->storage().get();
write_shader_cursor_fields(context, binding, d_out_field, grad_out.get(), read_back);
}
}

if (bound_primal_buffer == bound_grad_in_buffer || bound_primal_buffer == bound_grad_out_buffer) {
log_warn("Binding the same storage for primal and gradient on the same tensor. This will have serious "
"performance impacts.");
}
if (bound_grad_in_buffer != nullptr && bound_grad_in_buffer == bound_grad_out_buffer) {
log_warn("Binding the same storage for grad in and grad out on the same tensor. This will have serious "
"performance impacts.");
}

} else {
// Binding to a single Tensor object that represents the primal.
write_shader_cursor_fields(context, binding, field, primal, read_back);
}

if (context->call_mode() != CallMode::prim && grad_in && grad_in == grad_out) {
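The native-side change above adds aliasing diagnostics: when the cursor exposes a primal field, the marshall records which buffer gets bound for the primal and for each gradient, and warns if any of them share storage. Here is a hedged Python sketch of the equivalent check; FakeTensor and its storage/grad_in/grad_out attributes are hypothetical stand-ins for the native NativeTensor accessors, not the real API.

import logging
from dataclasses import dataclass
from typing import Optional

log = logging.getLogger(__name__)


@dataclass
class FakeTensor:
    """Hypothetical stand-in: 'storage' models the underlying Buffer pointer."""
    storage: object
    grad_in: Optional["FakeTensor"] = None
    grad_out: Optional["FakeTensor"] = None


def warn_on_gradient_aliasing(primal: FakeTensor) -> None:
    """Mirror of the new pre-dispatch checks: warn when the primal and a
    gradient tensor, or the two gradient tensors, share the same storage."""
    grad_in_storage = primal.grad_in.storage if primal.grad_in else None
    grad_out_storage = primal.grad_out.storage if primal.grad_out else None

    if primal.storage is grad_in_storage or primal.storage is grad_out_storage:
        log.warning(
            "Binding the same storage for primal and gradient on the same "
            "tensor. This will have serious performance impacts."
        )
    if grad_in_storage is not None and grad_in_storage is grad_out_storage:
        log.warning(
            "Binding the same storage for grad in and grad out on the same "
            "tensor. This will have serious performance impacts."
        )

In the actual marshall the comparison is done with raw Buffer pointers inside write_shader_cursor_pre_dispatch, so the warning is evaluated once per dispatch, right after the fields are written.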