Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@ ft = as_folded_tensor(
[0, 1, 2],
[3],
],
pad_value=-1,
)
# FoldedTensor([[0, 1, 2],
# [3, 0, 0]])
# [3, -1, -1]])
```

`pad_value` allows changing the value used to pad the nested sequences.

You can also specify names and flattened/unflattened dimensions at the time of creation:

```python
Expand Down Expand Up @@ -87,6 +90,15 @@ print(ft.refold(("lines", "words")))
# [2, 3],
# [4, 3]])

# Use a custom value for padding when refolding
print(ft.refold(("lines", "words"), pad_value=-1))
# FoldedTensor([[ 1, -1],
# [-1, -1],
# [-1, -1],
# [-1, -1],
# [ 2, 3],
# [ 4, 3]])

# Refold on the words dim only: flatten everything
print(ft.refold(("words",)))
# FoldedTensor([1, 2, 3, 4, 3])
Expand Down
33 changes: 27 additions & 6 deletions foldedtensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def forward(
ctx,
self: "FoldedTensor",
dims: Tuple[int],
pad_value: Union[int, float, bool] = 0,
) -> "FoldedTensor":
ctx.set_materialize_grads(False)
ctx.lengths = self.lengths
Expand All @@ -82,8 +83,11 @@ def forward(
indexer = torch.from_numpy(np_new_indexer).to(device)
ctx.output_indexer = indexer
shape_suffix = data.shape[len(self.data_dims) :]
refolded_data = torch.zeros(
(*shape_prefix, *shape_suffix), dtype=data.dtype, device=device
refolded_data = torch.full(
(*shape_prefix, *shape_suffix),
pad_value,
dtype=data.dtype,
device=device,
)
refolded_data.view(-1, *shape_suffix)[indexer] = data.view(
-1, *shape_suffix
Expand Down Expand Up @@ -112,7 +116,7 @@ def backward(ctx, grad_output):
grad_input.view(-1, *shape_suffix)[ctx.input_indexer] = grad_output.reshape(
-1, *shape_suffix
).index_select(0, ctx.output_indexer)
return grad_input, None
return grad_input, None, None


type_to_dtype_dict = {
Expand Down Expand Up @@ -148,6 +152,7 @@ def as_folded_tensor(
dtype: Optional[torch.dtype] = None,
lengths: Optional[List[List[int]]] = None,
device: Optional[Union[str, torch.device]] = None,
pad_value: Union[int, float, bool] = 0,
):
"""
Converts a tensor or nested sequence into a FoldedTensor.
Expand All @@ -168,6 +173,8 @@ def as_folded_tensor(
must be provided. If `data` is a sequence, this argument must be `None`.
device: Optional[Union[str, torch.device]]
The device of the output tensor
pad_value: Union[int, float, bool]
Value used to pad the nested sequences. Defaults to ``0``.
"""
if full_names is not None:
if data_dims is not None:
Expand Down Expand Up @@ -212,6 +219,7 @@ def as_folded_tensor(
data,
data_dims,
np.dtype(dtype),
pad_value,
)
indexer = torch.from_numpy(indexer)
padded = torch.from_numpy(padded)
Expand Down Expand Up @@ -396,7 +404,20 @@ def clone(self):
cloned._mask = self._mask.clone()
return cloned

def refold(self, *dims: Union[Sequence[Union[int, str]], int, str]):
def refold(
self,
*dims: Union[Sequence[Union[int, str]], int, str],
pad_value: Union[int, float, bool] = 0,
):
"""Change which dimensions are padded.

Parameters
----------
*dims: Union[Sequence[Union[int, str]], int, str]
Dimensions to keep padded.
pad_value: Union[int, float, bool]
Value used to pad the folded dimensions. Defaults to ``0``.
"""
if not isinstance(dims[0], (int, str)):
assert len(dims) == 1, (
"Expected the first only argument to be a "
Expand All @@ -414,10 +435,10 @@ def refold(self, *dims: Union[Sequence[Union[int, str]], int, str]):
f"could not be refolded with dimensions {list(dims)}"
)

if dims == self.data_dims:
if dims == self.data_dims and pad_value == 0:
return self

return Refold.apply(self, dims)
return Refold.apply(self, dims, pad_value)


def reduce_foldedtensor(self: FoldedTensor):
Expand Down
19 changes: 16 additions & 3 deletions foldedtensor/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ std::tuple<
nested_py_list_to_padded_np_array(
const py::list &nested_list,
std::vector<int> data_dims,
py::dtype &dtype) {
py::dtype &dtype,
py::object pad_value) {
// Will contain the variable lengths of the nested lists
// One sequence per dimension, containing the lengths of the lists at that dimension
std::vector<std::vector<int64_t>> lengths;
Expand Down Expand Up @@ -236,7 +237,11 @@ nested_py_list_to_padded_np_array(

// Create the padded array from the shape inferred during `flatten_py_list`
py::array padded_array = py::array(py::dtype(dtype), shape);
padded_array[py::make_tuple(py::ellipsis())] = 0;
if (PyArray_FillWithScalar(
reinterpret_cast<PyArrayObject *>(padded_array.ptr()),
pad_value.ptr()) < 0) {
throw py::error_already_set();
}

// Get the strides of the array
const py::ssize_t *array_strides = padded_array.strides();
Expand Down Expand Up @@ -311,7 +316,15 @@ PYBIND11_MODULE(_C, m) {
init_numpy();

m.def("make_refolding_indexer", &make_refolding_indexer, "Build an indexer to refold data into a different shape");
m.def("nested_py_list_to_padded_array", &nested_py_list_to_padded_np_array, "Converts a nested Python list to a padded array");
m.def(
"nested_py_list_to_padded_array",
&nested_py_list_to_padded_np_array,
py::arg("nested_list"),
py::arg("data_dims"),
py::arg("dtype"),
py::arg("pad_value") = 0,
"Converts a nested Python list to a padded array"
);
}

#pragma clang diagnostic pop
37 changes: 37 additions & 0 deletions tests/test_folded_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,23 @@ def test_refold_lines(ft):
).all()


def test_refold_custom_pad_value(ft):
    """Refolding with ``pad_value=-1`` must fill every padded slot with -1."""
    expected = torch.tensor(
        [
            [1, -1],
            [-1, -1],
            [-1, -1],
            [-1, -1],
            [2, 3],
            [4, 3],
        ]
    )
    refolded = ft.refold("lines", "words", pad_value=-1)
    assert (refolded.data == expected).all()


def test_embedding(ft):
embedder = torch.nn.Embedding(10, 16)
embedding = embedder(ft.refold("words"))
Expand Down Expand Up @@ -299,6 +316,26 @@ def test_pad_embedding():
).all()


def test_custom_pad_value():
    """Creating a tensor with ``pad_value=-1`` pads ragged rows with -1, not 0."""
    ft = as_folded_tensor(
        [
            [0, 1, 2],
            [3, 4],
        ],
        pad_value=-1,
        dtype=torch.long,
    )
    # The second row is one element short, so its last slot gets the pad value.
    expected = torch.tensor(
        [
            [0, 1, 2],
            [3, 4, -1],
        ]
    )
    assert (ft.data == expected).all()


def test_empty_args():
ft = as_folded_tensor(
[
Expand Down
Loading