Does `to_executorch()` support slicing tensors based on dynamic (bool) indices?

Here's an example to demonstrate the scenario that currently fails:
```python
import torch
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from torch.export import export_for_training
from torch.export.experimental import _export_forward_backward


class DynamicTestModel(torch.nn.Module):
    def __init__(self):
        super(DynamicTestModel, self).__init__()

    def forward(self, tensor_to_slice: torch.Tensor,
                slice_conditional_tensor: torch.Tensor,
                ) -> torch.Tensor:
        mask_tensor_bool = (slice_conditional_tensor.squeeze() != -100).bool()
        tensor_to_slice = tensor_to_slice[:, mask_tensor_bool, :]
        return tensor_to_slice.mean()


def _export(model, tensor_to_slice, slice_conditional_tensor):
    exp = export_for_training(model, (tensor_to_slice, slice_conditional_tensor))
    exp = _export_forward_backward(exp)
    edge = to_edge(
        exp,
        compile_config=EdgeCompileConfig()
    )
    edge = edge.to_backend(
        XnnpackPartitioner(force_fp32_dynamic_linear=True)
    )
    edge = edge.to_executorch(config=ExecutorchBackendConfig(
        external_mutable_weights=False
    ))


if __name__ == "__main__":
    tensor_to_slice = torch.randn((1, 10, 20))
    slice_conditional_tensor = torch.randint(0, 100, (1, 10))
    slice_conditional_tensor[:, :5] = -100
    tensor_to_slice.requires_grad_(True)
    _export(DynamicTestModel(), tensor_to_slice, slice_conditional_tensor.contiguous())
```
Error:

```
Traceback (most recent call last):
...
raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(20*u21, 0) (unhinted: Eq(20*u21, 0)). (Size-like symbols: u21)
ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.
Caused by: (executorch/exir/tensor.py:71 in dim_order_from_stride)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u21"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

While executing %aten_index_tensor : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.index.Tensor](args = (%tensor_to_slice, [None, %aten_ne_scalar]), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
...
```
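As far as I can tell, the unbacked symbol (u21) comes from the boolean index itself: the number of `True` entries in the mask is only known at runtime, so the sliced dimension has no concrete size at trace time. A stripped-down sketch of just that part (my assumption; this uses only `torch.export` with no ExecuTorch lowering, mirroring the repro above where `export_for_training` itself succeeds):

```python
import torch
from torch.export import export_for_training


class BoolIndexOnly(torch.nn.Module):
    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        mask = (cond.squeeze() != -100)
        # The second dimension of the result depends on how many entries of
        # `mask` are True, which is only known at runtime (an unbacked symint).
        return x[:, mask, :]


ep = export_for_training(
    BoolIndexOnly(),
    (torch.randn(1, 10, 20), torch.randint(0, 100, (1, 10))),
)
print(ep.graph_module.graph)  # the index op's output shape carries a u-symbol
```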
This seems to be a special case of data-dependent control flow involving data-dependent dynamic shapes. However, as mentioned in the doc linked above, such scenarios appear to be expected to cause graph breaks, and thus cannot be exported via `to_executorch()`. Is my understanding correct?

Or is there any way to export this?
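One rewrite I have been considering (only a sketch, not verified through the full ExecuTorch/XNNPACK flow) is to avoid the boolean index entirely and express the same reduction as a masked sum over static shapes, so no data-dependent sizes are created:

```python
import torch


class MaskedMeanModel(torch.nn.Module):
    def forward(self, tensor_to_slice: torch.Tensor,
                slice_conditional_tensor: torch.Tensor) -> torch.Tensor:
        # (1, 10) -> (1, 10, 1) float mask; 1.0 where the row would be kept
        mask = (slice_conditional_tensor != -100).unsqueeze(-1).to(tensor_to_slice.dtype)
        # Sum only over the kept rows; every intermediate keeps the static
        # (1, 10, 20) shape instead of a data-dependent (1, u21, 20) shape.
        masked_sum = (tensor_to_slice * mask).sum()
        # Number of elements that tensor_to_slice[:, mask, :] would contain
        count = mask.sum() * tensor_to_slice.size(-1)
        return masked_sum / count
```

This should compute the same value as `tensor_to_slice[:, mask_tensor_bool, :].mean()` whenever at least one row is kept, but I'm not sure whether this kind of rewrite is the intended answer or whether dynamic boolean slicing is supposed to be supported directly.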