diff --git a/README.md b/README.md
index d6bf614..ceb9d3d 100644
--- a/README.md
+++ b/README.md
@@ -106,6 +106,59 @@ print(refolded_embedding.shape)
 # torch.Size([2, 5, 16]) # 2 samples, 5 words max, 16 dims
 ```
 
+### Pooling spans
+
+You can pool variable-length spans directly on a refolded view, without padding, by building flat indices and offsets and feeding them to `embedding_bag`.
+
+The helper `lengths.make_indices_ranges` expands ranges defined over one or more variable dimensions and returns three arrays:
+
+- `indices` are the flat positions in the refolded tensor viewed as a single dimension
+- `offsets` are the start positions of each span within `indices`
+- `spans` gives the span id of every expanded position, which can be useful for functions like `torch.index_add` or `torch.index_reduce`
+
+An example that mean-pools word spans to produce one vector per span:
+
+```python
+import torch
+import foldedtensor as ft
+
+# Build a 4-level tensor with named dimensions: the first word of the first context is split into three tokens, etc.
+input_ids = ft.as_folded_tensor(
+    [
+        [
+            [[0, 2, 3], [10], [4]],
+            [[0, 1, 2], [2, 3], [10, 11], [100, 101]],
+        ],
+    ],
+    full_names=("sample", "context", "word", "token"),
+).refold(
+    "token"
+)  # any refolding is fine
+
+# Create embeddings from the input ids
+embedding = torch.nn.Embedding(2048, 16)
+weight = embedding(input_ids)
+
+# Pool two word spans
+# span 1 covers words [0, 2) -> mean pooled over 4 tokens [0, 2, 3, 10]
+# span 2 covers words [5, 7) -> mean pooled over 4 tokens [10, 11, 100, 101]
+indices, offsets, spans = input_ids.lengths.make_indices_ranges(
+    begins=(torch.tensor([0, 5]),),
+    ends=(torch.tensor([2, 7]),),
+    indice_dims=("word",),
+)
+
+# Average the embeddings over each span
+pooled = torch.nn.functional.embedding_bag(
+    input=indices,
+    # Flatten embeddings so rows align with flattened token positions
+    weight=weight.view(-1, weight.size(-1)),
+    offsets=offsets,
+    mode="mean",
+)
+print(pooled)
+```
+
 ## Benchmarks
 
 View the comparisons of `foldedtensor` against various alternatives here: [docs/benchmarks](https://github.com/aphp/foldedtensor/blob/main/docs/benchmark.md).
diff --git a/changelog.md b/changelog.md
index 5450eaa..2737058 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,5 +1,12 @@
 # Changelog
 
+## Unreleased
+
+- Add `map_indices` and `make_indices_ranges` (with C++ backends), exposed as `lengths.map_indices` and `lengths.make_indices_ranges`; they handle boundary indices and return flat indices, offsets and span ids suitable for pooling with `embedding_bag`.
+- Introduce `FoldedTensorLayout`, which stores `full_names` and `data_dims`, resolves named dimensions and provides helper methods; it is now used as the `lengths` container of `FoldedTensor`.
+- Improve `as_folded_tensor`: better inference of dims and dtype from nested data, support for named `data_dims`, and better handling of names and empty structures.
+- Benchmark script: add `--cases` to run selected cases, add a new case for range-based pooling, and adjust outputs.
+
 ## v0.4.0
 
 - Fix `storage` torch warning
diff --git a/docs/benchmark.md b/docs/benchmark.md
index da4ec8a..903d9be 100644
--- a/docs/benchmark.md
+++ b/docs/benchmark.md
@@ -8,9 +8,9 @@ It compares the performance of `foldedtensor` with various alternatives for padding
 and working with nested lists and tensors.
 
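The case listings below call two helpers, `make_nested_list` and `python_padding`, that live in `scripts/benchmark.py` and are not shown in this diff. For readability, here is a minimal sketch of their assumed behavior (the real helpers may differ):

```python
import random
import torch

def make_nested_list(*dims, value):
    # An int dim is a fixed size; a (low, high) tuple draws a random size per level.
    size, *rest = dims
    if isinstance(size, tuple):
        size = random.randint(*size)
    if not rest:
        return [value] * size
    return [make_nested_list(*rest, value=value) for _ in range(size)]

def python_padding(nested_list):
    # Pure-Python baseline: recursively zero-pad ragged nested lists into one tensor.
    if not nested_list or not isinstance(nested_list[0], list):
        return torch.as_tensor(nested_list, dtype=torch.long)
    padded = [python_padding(sub) for sub in nested_list]
    shape = [max(p.shape[d] for p in padded) for d in range(padded[0].dim())]
    out = padded[0].new_zeros((len(padded), *shape))
    for i, p in enumerate(padded):
        out[(i, *(slice(0, s) for s in p.shape))] = p
    return out
```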
Environment: -- `torch.__version__ == '2.6.0'` +- `torch.__version__ == '2.8.0'` - `foldedtensor.__version__ == '0.4.0'` -- `python == 3.9.20` +- `python == 3.11.3` - `sys.platform == 'darwin'` @@ -22,13 +22,13 @@ nested_list = make_nested_list(32, (50, 100), (25, 30), value=1) Comparisons: %timeit python_padding(nested_list) -# 100 loops, best of 5: 15.09 ms per loop +# 100 loops, best of 5: 19.02 ms per loop %timeit foldedtensor.as_folded_tensor(nested_list) -# 100 loops, best of 5: 0.73 ms per loop +# 100 loops, best of 5: 0.82 ms per loop ``` -Speedup against best alternative: **20.67x** :rocket: +Speedup against best alternative: **23.24x** :rocket: ## Case 2 (same lengths nested lists) @@ -36,22 +36,22 @@ Speedup against best alternative: **20.67x** :rocket: nested_list = make_nested_list(32, 100, 30, value=1) %timeit torch.tensor(nested_list) -# 100 loops, best of 5: 6.51 ms per loop +# 100 loops, best of 5: 7.86 ms per loop %timeit torch.LongTensor(nested_list) -# 100 loops, best of 5: 2.78 ms per loop +# 100 loops, best of 5: 3.69 ms per loop %timeit python_padding(nested_list) -# 100 loops, best of 5: 18.38 ms per loop +# 100 loops, best of 5: 23.35 ms per loop %timeit torch.nested.nested_tensor([torch.LongTensor(sub) for sub in nested_list]).to_padded_tensor(0) -# 100 loops, best of 5: 3.00 ms per loop +# 100 loops, best of 5: 3.94 ms per loop %timeit foldedtensor.as_folded_tensor(nested_list) -# 100 loops, best of 5: 1.08 ms per loop +# 100 loops, best of 5: 1.18 ms per loop ``` -Speedup against best alternative: **2.58x** :rocket: +Speedup against best alternative: **3.12x** :rocket: ## Case 3 (simple list) @@ -59,19 +59,19 @@ Speedup against best alternative: **2.58x** :rocket: simple_list = make_nested_list(10000, value=1) %timeit torch.tensor(simple_list) -# 100 loops, best of 5: 0.63 ms per loop +# 100 loops, best of 5: 0.77 ms per loop %timeit torch.LongTensor(simple_list) -# 100 loops, best of 5: 0.27 ms per loop +# 100 loops, best of 5: 0.37 ms per loop %timeit python_padding(simple_list) -# 100 loops, best of 5: 0.28 ms per loop +# 100 loops, best of 5: 0.37 ms per loop %timeit foldedtensor.as_folded_tensor(simple_list) -# 100 loops, best of 5: 0.08 ms per loop +# 100 loops, best of 5: 0.10 ms per loop ``` -Speedup against best alternative: **3.32x** :rocket: +Speedup against best alternative: **3.59x** :rocket: ## Case 4 (same lengths nested lists to flat tensor) @@ -79,22 +79,22 @@ Speedup against best alternative: **3.32x** :rocket: nested_list = make_nested_list(32, 100, 30, value=1) %timeit torch.tensor(nested_list).view(-1) -# 100 loops, best of 5: 6.52 ms per loop +# 100 loops, best of 5: 7.83 ms per loop %timeit torch.LongTensor(nested_list).view(-1) -# 100 loops, best of 5: 2.76 ms per loop +# 100 loops, best of 5: 3.68 ms per loop %timeit python_padding(nested_list).view(-1) -# 100 loops, best of 5: 18.62 ms per loop +# 100 loops, best of 5: 23.17 ms per loop %timeit foldedtensor.as_folded_tensor(nested_list).view(-1) -# 100 loops, best of 5: 1.12 ms per loop +# 100 loops, best of 5: 1.19 ms per loop %timeit foldedtensor.as_folded_tensor(nested_list, data_dims=(2,)) -# 100 loops, best of 5: 1.08 ms per loop +# 100 loops, best of 5: 1.16 ms per loop ``` -Speedup against best alternative: **2.47x** :rocket: +Speedup against best alternative: **3.10x** :rocket: ## Case 5 (variable lengths nested lists) to padded embeddings Nested lists with different lengths (second level lists have lengths between 50 and 150). We compare `foldedtensor` with `torch.nested`. 
@@ -104,24 +104,24 @@ nested_list = make_nested_list(32, (50, 150), 30, value=1) # Padding with 0 %timeit torch.nested.nested_tensor([torch.LongTensor(sub) for sub in nested_list]).to_padded_tensor(0) -# 100 loops, best of 5: 3.02 ms per loop +# 100 loops, best of 5: 4.40 ms per loop %timeit foldedtensor.as_folded_tensor(nested_list).as_tensor() -# 100 loops, best of 5: 1.03 ms per loop +# 100 loops, best of 5: 1.29 ms per loop ``` -Speedup against best alternative: **2.95x** :rocket: +Speedup against best alternative: **3.41x** :rocket: ```python # Padding with 1 %timeit torch.nested.nested_tensor([torch.FloatTensor(sub) for sub in nested_list]).to_padded_tensor(1) -# 100 loops, best of 5: 3.72 ms per loop +# 100 loops, best of 5: 4.77 ms per loop %timeit x = foldedtensor.as_folded_tensor(nested_list); x.masked_fill_(x.mask, 1) -# 100 loops, best of 5: 1.62 ms per loop +# 100 loops, best of 5: 1.65 ms per loop ``` -Speedup against best alternative: **2.30x** :rocket: +Speedup against best alternative: **2.89x** :rocket: ## Case 6 (2d padding) @@ -129,16 +129,47 @@ Speedup against best alternative: **2.30x** :rocket: nested_list = make_nested_list(160, (50, 150), value=1) %timeit python_padding(nested_list) -# 100 loops, best of 5: 1.33 ms per loop +# 100 loops, best of 5: 1.73 ms per loop %timeit torch.nested.nested_tensor([torch.LongTensor(sub) for sub in nested_list]).to_padded_tensor(0) -# 100 loops, best of 5: 1.14 ms per loop +# 100 loops, best of 5: 1.48 ms per loop %timeit torch.nn.utils.rnn.pad_sequence([torch.LongTensor(sub) for sub in nested_list], batch_first=True, padding_value=0) -# 100 loops, best of 5: 0.86 ms per loop +# 100 loops, best of 5: 1.22 ms per loop %timeit foldedtensor.as_folded_tensor(nested_list) -# 100 loops, best of 5: 0.15 ms per loop +# 100 loops, best of 5: 0.18 ms per loop ``` -Speedup against best alternative: **5.88x** :rocket: +Speedup against best alternative: **6.68x** :rocket: + +## Case 7 (summing vectors inside each differently-sized sequence, all concatenated) + +```python +def sum_all_words_per_sample(t): + begins = torch.arange(len(t.lengths[1])) + ends = begins + 1 + indices, offsets, spans = t.lengths.make_indices_ranges( + begins=(begins,), ends=(ends,), indice_dims=(0,) + ) + return torch.nn.functional.embedding_bag( + input=indices, + weight=t.view(-1, t.size(-1)), + offsets=offsets, + mode="sum", + ) + +embedder = torch.nn.Embedding(500, 128) +nested_list = make_nested_list(320, (150, 250), value=1) +ft = foldedtensor.as_folded_tensor(nested_list).refold(1) +ft = embedder(ft) + + +%timeit ft.refold(0, 1).sum(-2) +# 100 loops, best of 5: 3.54 ms per loop + +%timeit sum_all_words_per_sample(ft) +# 100 loops, best of 5: 1.01 ms per loop + +``` +Speedup against pad-then-sum: **3.52x** :rocket: diff --git a/foldedtensor/__init__.py b/foldedtensor/__init__.py index 5499ab8..abe5a49 100644 --- a/foldedtensor/__init__.py +++ b/foldedtensor/__init__.py @@ -8,7 +8,183 @@ import torch from torch.autograd import Function -from . import _C +from . import _C # type: ignore[import] + +Dim = Union[int, str] + + +def map_indices( + indices: Tuple[Sequence[int], ...], + indice_dims: Tuple[int, ...], + lengths: Sequence[Sequence[int]], + data_dims: Tuple[int, ...], + *, + return_tensors: Optional[str] = None, +): + """ + Compute leaf (last-dim) flat indices given indices in other dimensions. + + Parameters + ---------- + indices: Tuple[Sequence[int], ...] + Tuple of index sequences (broadcasted together) describing positions + along `indice_dims`. 
+ indice_dims: Tuple[int, ...] + Names or indices of the addressing dims. + lengths: Sequence[Sequence[int]] + Nested lengths describing the folded structure. + data_dims: Tuple[int, ...] + Names or indices describing the padded layout used for flattening. + return_tensors: Optional[str], optional (default=None) + Return type: "pt" for torch, "np" for numpy, "list" for python list. + + Returns + ------- + Union[List[int], np.ndarray, torch.Tensor] + Returns a list of flat indices compatible with `.view(-1)` of a tensor + refolded with `data_dims`. + """ + D = len(lengths) + if data_dims[-1] != D - 1: + raise ValueError( + "data_dims must end with the last variable dimension (e.g., 'token')" + ) + + orig_shape = None + saw_pt = False + saw_np = False + np_indices: Tuple[np.ndarray, ...] = tuple( + ( + ( + lambda a: ( + (lambda arr: arr.reshape(-1))( + a.detach().cpu().numpy() + if isinstance(a, torch.Tensor) + else (np.asarray(a)) + ) + ) + )(arr) + ) + for arr in indices + ) # type: ignore[arg-type] + + # Track types and original shape from the first array + first = indices[0] + if isinstance(first, torch.Tensor): + saw_pt = True + orig_shape = tuple(first.shape) + else: + arr0 = np.asarray(first) + if arr0.ndim > 1: + orig_shape = tuple(arr0.shape) + saw_np = isinstance(first, np.ndarray) or saw_np + + if len(indice_dims) != len(np_indices): + raise ValueError("indices and indice_dims must have the same length") + + res = _C.map_indices( + lengths, + list(data_dims), + list(indice_dims), + np_indices, + ) + out_np = np.asarray(res) + # Reshape if needed + if orig_shape is not None: + out_np = out_np.reshape(orig_shape) + + if return_tensors == "pt" or return_tensors is None and saw_pt: + return torch.from_numpy(out_np) + if return_tensors == "np" or return_tensors is None and saw_np: + return out_np + return out_np.tolist() + + +def make_indices_ranges( + *, + begins, + ends, + indice_dims, + lengths, + data_dims, + return_tensors: Union[typing.Optional[str], bool] = None, +): + """ + Expand multiple ranges specified along indice_dims into: + - flat indices (compatible with `.view(-1)` of a tensor refolded with `data_dims`), + - start offsets per span, + - and span indices (the span id for each expanded position). + + Parameters use the same conventions as map_indices. `begins` and `ends` are + tuples of 1D tensors or lists corresponding to each dimension in `indice_dims`. + Ranges are half-open: [begin, end), with boundary support when the last + coordinate equals the number of children of its parent. 
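+
+    For instance (illustrative), with ``lengths == [[2], [3, 2]]`` (two samples
+    of 3 and 2 tokens), ``data_dims == (0, 1)``, ``indice_dims == (0,)``,
+    ``begins == ([0, 1],)`` and ``ends == ([1, 2],)``, each span covers one full
+    sample of the padded 2 x 3 layout, so the result is
+    ``indices == [0, 1, 2, 3, 4]``, ``offsets == [0, 3]`` and span indices
+    ``[0, 0, 0, 1, 1]``.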
+ """ + if not isinstance(begins, (list, tuple)) or not isinstance(ends, (list, tuple)): + raise TypeError("begins and ends must be tuples/lists of arrays") + if len(begins) != len(indice_dims) or len(ends) != len(indice_dims): + raise ValueError("begins/ends must match indice_dims length") + + saw_pt = False + saw_np = False + # Determine original shape from the first begins entry + first_b = begins[0] + if isinstance(first_b, torch.Tensor): + orig_shape = tuple(first_b.shape) + saw_pt = True + else: + arr0 = np.asarray(first_b) + orig_shape = tuple(arr0.shape) if arr0.ndim > 1 else None + saw_np = isinstance(first_b, np.ndarray) or saw_np + + def _to_np1d(x): + nonlocal saw_pt, saw_np + if isinstance(x, torch.Tensor): + saw_pt = True + return x.detach().cpu().numpy().reshape(-1) + a = np.asarray(x) + if isinstance(x, np.ndarray): + saw_np = True + return a.reshape(-1) + + begins_np = [_to_np1d(b) for b in begins] + ends_np = [_to_np1d(e) for e in ends] + + res = _C.make_indices_ranges( + lengths, + list(data_dims), + list(indice_dims), + begins_np, + ends_np, + ) + + indices, offsets, span_indices = res + indices_np = np.asarray(indices) + offsets_np = np.asarray(offsets) + span_indices_np = np.asarray(span_indices) + + # Reshape offsets to original input shape if multi-dimensional + if orig_shape is not None: + offsets_np = offsets_np.reshape(orig_shape) + + if return_tensors == "pt" or return_tensors is None and saw_pt: + return ( + torch.from_numpy(indices_np.astype(np.int64, copy=False)), + torch.from_numpy(offsets_np.astype(np.int64, copy=False)), + torch.from_numpy(span_indices_np.astype(np.int64, copy=False)), + ) + if return_tensors == "np" or return_tensors is None and saw_np: + return ( + indices_np, + offsets_np, + span_indices_np, + ) + return ( + indices_np.astype(np.int64, copy=False).tolist(), + offsets_np.astype(np.int64, copy=False).tolist(), + span_indices_np.astype(np.int64, copy=False).tolist(), + ) + np_to_torch_dtype = { torch.bool: bool, @@ -49,13 +225,115 @@ __version__ = "0.4.0" -class FoldedTensorLengths(UserList): +class FoldedTensorLayout(UserList): + """ + Folded tensor layout information. 
+ """ + + def __init__( + self, + initlist: Optional[Sequence[Sequence[int]]] = None, + *, + data_dims: Optional[Sequence[Union[int, str]]], + full_names: Optional[Sequence[str]], + ) -> None: + super().__init__(initlist or []) + self._full_names: Optional[Tuple[str, ...]] = ( + tuple(full_names) if full_names is not None else None + ) + if self._full_names is not None: + dd = tuple( + d if isinstance(d, int) else self._full_names.index(d) + for d in data_dims + ) + else: + # Accept ints only when no names are provided + dd = tuple(int(d) for d in data_dims) + self._data_dims: Optional[Tuple[int, ...]] = dd + def __hash__(self): return id(self) + @property + def full_names(self) -> Optional[Tuple[str, ...]]: + return self._full_names + + @property + def data_dims(self) -> Optional[Tuple[int, ...]]: + return self._data_dims + + def __getitem__(self, index: Union[int, str]) -> typing.Any: + if isinstance(index, str): + if self._full_names is None: + raise ValueError( + "Cannot resolve named index without full_names in the layout" + ) + try: + index = self._full_names.index(index) + except ValueError as exc: # pragma: no cover + raise ValueError(f"Unknown dimension name {index!r}") from exc + if not isinstance(index, int): # pragma: no cover + raise TypeError("Index must be an int or a str") + return super().__getitem__(index) + + def resolve_dim(self, dim): + if isinstance(dim, tuple): + return tuple(self.resolve_dim(d) for d in dim) + if isinstance(dim, str): + if self._full_names is None: + raise ValueError( + "Cannot resolve named dim without full_names in the layout" + ) + try: + dim = self._full_names.index(dim) + except ValueError as exc: # pragma: no cover + raise ValueError(f"Unknown dimension name {dim!r}") from exc + return int(dim) + + def map_indices( + self, + indices: Tuple[Sequence[int], ...], + indice_dims: Tuple[Union[int, str], ...], + *, + data_dims: Optional[Sequence[Union[int, str]]] = None, + return_tensors: Optional[str] = None, + ): + indice_dims = self.resolve_dim(indice_dims) + data_dims = self.resolve_dim(data_dims or self.data_dims) + + return map_indices( + indices=indices, + indice_dims=indice_dims, + lengths=self, + data_dims=data_dims, + return_tensors=return_tensors, + ) + + def make_indices_ranges( + self, + *, + begins, + ends, + indice_dims, + data_dims: Optional[Sequence[Union[int, str]]] = None, + return_tensors: Optional[str] = None, + ): + # Resolve indice_dims against this layout's names if provided + indice_dims = self.resolve_dim(indice_dims) + data_dims = self.resolve_dim(data_dims or self.data_dims) + + return make_indices_ranges( + begins=begins, + ends=ends, + indice_dims=indice_dims, + lengths=self, + data_dims=data_dims, + return_tensors=return_tensors, + ) -if typing.TYPE_CHECKING: - FoldedTensorLengths = List[List[int]] # noqa: F811 + +# Backward-compatibility alias +FoldedTensorLengths = FoldedTensorLayout # noinspection PyMethodOverriding @@ -88,11 +366,12 @@ def forward( refolded_data.view(-1, *shape_suffix)[indexer] = data.view( -1, *shape_suffix ).index_select(0, self.indexer) + lengths = FoldedTensorLayout( + self.lengths, data_dims=dims, full_names=self.full_names + ) return FoldedTensor( data=refolded_data, - lengths=self.lengths, - data_dims=dims, - full_names=self.full_names, + lengths=lengths, indexer=indexer, ) @@ -146,7 +425,7 @@ def as_folded_tensor( data_dims: Optional[Sequence[Union[int, str]]] = None, full_names: Optional[Sequence[str]] = None, dtype: Optional[torch.dtype] = None, - lengths: Optional[List[List[int]]] = None, 
+ lengths: Optional[Union[FoldedTensorLayout, List[List[int]]]] = None, device: Optional[Union[str, torch.device]] = None, ): """ @@ -169,6 +448,9 @@ def as_folded_tensor( device: Optional[Unit[str, torch.device]] The device of the output tensor """ + if isinstance(lengths, FoldedTensorLayout): + data_dims = lengths.data_dims or data_dims + full_names = lengths.full_names or full_names if full_names is not None: if data_dims is not None: data_dims = tuple( @@ -189,11 +471,10 @@ def as_folded_tensor( f"Shape inferred from lengths is not compatible with data dims: {shape}, " f"{data.shape}, {len(data_dims)}" ) + layout = FoldedTensorLayout(lengths, data_dims=data_dims, full_names=full_names) result = FoldedTensor( data=data, - lengths=FoldedTensorLengths(lengths), - data_dims=data_dims, - full_names=full_names, + lengths=layout, indexer=torch.from_numpy(np_indexer).to(data.device), ) elif isinstance(data, Sequence): @@ -217,11 +498,10 @@ def as_folded_tensor( padded = torch.from_numpy(padded) # In case of empty sequences, lengths are not computed correctly lengths = (list(lengths) + [[0]] * deepness)[:deepness] + layout = FoldedTensorLayout(lengths, data_dims=data_dims, full_names=full_names) result = FoldedTensor( data=padded, - lengths=FoldedTensorLengths(lengths), - data_dims=data_dims, - full_names=full_names, + lengths=layout, indexer=indexer, ) else: @@ -246,8 +526,6 @@ def _postprocess_func_result(result, input): return FoldedTensor( data=result, lengths=input.lengths, - data_dims=input.data_dims, - full_names=input.full_names, indexer=input.indexer, mask=input._mask, ) @@ -281,18 +559,12 @@ class FoldedTensor(torch.Tensor): def __new__( cls, data: torch.Tensor, - lengths: FoldedTensorLengths, - data_dims: Sequence[int], - full_names: Sequence[str], + lengths: FoldedTensorLayout, indexer: torch.Tensor, mask: Optional[torch.Tensor] = None, ): - data_dims = data_dims - full_names = full_names instance = data.as_subclass(cls) instance.lengths = lengths - instance.data_dims = data_dims - instance.full_names = full_names instance.indexer = indexer instance._mask = mask return instance @@ -301,12 +573,18 @@ def with_data(self, data: torch.Tensor): return FoldedTensor( data=data, lengths=self.lengths, - data_dims=self.data_dims, - full_names=self.full_names, indexer=self.indexer, mask=self._mask, ) + @property + def data_dims(self) -> Tuple[int, ...]: + return self.lengths.data_dims + + @property + def full_names(self) -> Optional[Tuple[str, ...]]: + return self.lengths.full_names + @property def mask(self): if self._mask is None: @@ -323,18 +601,14 @@ def as_tensor(self): def to(self, *args, **kwargs): with torch._C.DisableTorchFunction(): - result = super().to(*args, **kwargs) + res = super().to(*args, **kwargs) copy = kwargs.get("copy", False) - non_blocking = kwargs.get("non_blocking", False) + nb = kwargs.get("non_blocking", False) return FoldedTensor( - data=result, + data=res, lengths=self.lengths, - data_dims=self.data_dims, - full_names=self.full_names, - indexer=self.indexer.to( - result.device, copy=copy, non_blocking=non_blocking - ), - mask=self._mask.to(result.device, copy=copy, non_blocking=non_blocking) + indexer=self.indexer.to(res.device, copy=copy, non_blocking=nb), + mask=self._mask.to(res.device, copy=copy, non_blocking=nb) if self._mask is not None else None, ) @@ -361,14 +635,14 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if isinstance(arg, FoldedTensor): assert ( ft is None or ft.data_dims == arg.data_dims - ), "Cannot perform operation on 
FoldedTensors with different structures" + ), "Cannot perform operation on FoldedTensors with different layouts" ft = arg elif isinstance(arg, (list, tuple)): for item in arg: if isinstance(item, FoldedTensor): assert ft is None or ft.data_dims == item.data_dims, ( "Cannot perform operation on FoldedTensors with " - "different structures" + "different layouts" ) ft = item @@ -408,12 +682,28 @@ def refold(self, *dims: Union[Sequence[Union[int, str]], int, str]): dim if isinstance(dim, int) else self.full_names.index(dim) for dim in dims ) - except ValueError: + except ValueError: # pragma: no cover raise ValueError( f"Folded tensor with available dimensions {self.full_names} " f"could not be refolded with dimensions {list(dims)}" ) + # Ensure the leaf (last variable) dimension is last in the refolded layout + leaf = len(self.lengths) - 1 + if dims[-1] != leaf: + leaf_name = ( + self.full_names[leaf] if self.full_names is not None else str(leaf) + ) + dim_names = tuple( + self.full_names[d] if self.full_names is not None else str(d) + for d in dims + ) + raise ValueError( + "The last dimension of data_dims must be the last variable " + f"dimension {leaf_name!r} (ie. {leaf}); got data_dims={dim_names} " + f"(ie. {tuple(dims)}" + ) + if dims == self.data_dims: return self @@ -431,8 +721,6 @@ def reduce_foldedtensor(self: FoldedTensor): ( self.data.as_tensor(), self.lengths, - self.data_dims, - self.full_names, self.indexer.clone() if self.indexer.is_shared() and self.indexer.storage().is_cuda else self.indexer, diff --git a/foldedtensor/functions.cpp b/foldedtensor/functions.cpp index 718c69b..d8730b9 100644 --- a/foldedtensor/functions.cpp +++ b/foldedtensor/functions.cpp @@ -306,12 +306,547 @@ static bool init_numpy() { return true; } +static std::vector cumsum(const std::vector &v) { + std::vector out; + out.reserve(v.size() + 1); + out.push_back(0); + int64_t total = 0; + for (auto x : v) { + total += x; + out.push_back(total); + } + return out; +} + +/** + * Compute per-dimension child start offsets (exclusive prefix sums). + * + * For every variable dimension j>0, `lengths[j]` contains, for each parent + * entity at dimension j-1, the number of children at dimension j. The + * exclusive prefix-sum of this array maps a parent global id to the global id + * of its first child at the next dimension. + * + * - starts[j].size() == lengths[j].size() + 1 + * - For a parent global id g at dimension j-1, the first child global id at + * dimension j is `starts[j][g]`, and the number of children is + * `lengths[j][g]`. + * - starts[0] is left empty (unused), since there is no dimension -1. + * + * @param lengths Variable lengths per dimension. For j>0, lengths[j][g] is the + * number of children in dim j for parent g in dim j-1. + * @return For each j, starts[j] = cumsum(lengths[j]) (exclusive prefix sum). 
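+ *
+ * Example (illustrative): for lengths = {{2}, {3, 2}, {2, 1, 1, 2, 2}},
+ * starts[1] = {0, 3, 5} and starts[2] = {0, 2, 3, 4, 6, 8}, while starts[0]
+ * stays empty.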
+ */ +static std::vector> child_start_offsets( + const std::vector> &lengths) { + const size_t D = lengths.size(); + std::vector> starts(D); + for (size_t j = 1; j < D; ++j) { + starts[j] = cumsum(lengths[j]); + } + return starts; +} + +static std::vector> leaf_offsets_per_dim( + const std::vector> &lengths +) { + const size_t D = lengths.size(); + if (D < 2) { + return std::vector>(); + } + // Start with one leaf per word + size_t n_words = 0; + for (auto x : lengths[D - 2]) n_words += x; + std::vector counts(n_words, 1); + + std::vector> offsets(D); + // For words (D-2): [0,1,2,...,n_words] + offsets[D - 2].resize(n_words + 1); + for (size_t i = 0; i < n_words + 1; ++i) offsets[D - 2][i] = (int64_t)i; + + for (int d = (int)D - 3; d >= 0; --d) { + std::vector new_counts; + new_counts.reserve(lengths[d + 1].size()); + auto it = counts.begin(); + for (auto n_children : lengths[d + 1]) { + int64_t s = 0; + for (int64_t k = 0; k < n_children; ++k) { + if (it == counts.end()) break; + s += *it; + ++it; + } + new_counts.push_back(s); + } + counts.swap(new_counts); + offsets[d] = cumsum(counts); + } + return offsets; +} + +/** + * Compute the flat begin index for every leaf under a refolded layout. + * + * Given the nested `lengths` description and the list of data dimensions + * `data_dims` (which must end at the leaf dimension D-1), this function + * simulates iterating leaves (tokens) while incrementing the multi-dimensional + * index over the data layout. It returns, for each leaf (global leaf id), the + * flat index at which that leaf begins in the contiguous, refolded array. + * + * The resulting flat indices are computed using strides derived from the + * maximum extents observed during the simulated iteration of `data_dims`. + * + * @param lengths Variable lengths per dimension + * @param data_dims Contiguous data dimensions in order, must end at D-1. + * @return Vector `begins[leaf_gid]` giving the flat begin offset of each leaf. + * @throws std::invalid_argument if `data_dims` does not end with D-1. 
+ */ +static std::vector begin_idx_per_leaf( + std::vector> lengths, + const std::vector &data_dims) { + const size_t D = lengths.size(); + const size_t n_new = data_dims.size(); + if (n_new == 0) return {}; + if ((size_t)data_dims.back() != D - 1) { + throw std::invalid_argument("data_dims must end with last variable dimension"); + } + + std::vector new_dim_map(D, -1); + for (size_t i = 0; i < n_new; ++i) new_dim_map[data_dims[i]] = (int8_t)i; + + std::vector new_idx(n_new, 0); + std::vector new_shape(n_new, 0); + std::vector offsets(D - 1, 0); + + std::vector, int64_t>> ops; // (idx snapshot, leaf length) + ops.reserve(lengths.back().size()); + + for (auto leaf_len : lengths.back()) { + ops.emplace_back(new_idx, leaf_len); + + new_idx.back() += leaf_len; + if (new_idx.back() > new_shape.back()) new_shape.back() = new_idx.back(); + + int dim = (int)D - 2; + int8_t mapped = new_dim_map[dim]; + if (mapped >= 0) { + new_idx[mapped] += 1; + if (new_idx[mapped] > new_shape[mapped]) new_shape[mapped] = new_idx[mapped]; + for (size_t i = mapped + 1; i < n_new; ++i) new_idx[i] = 0; + } + + for (dim = (int)D - 2; dim >= 0; --dim) { + lengths[dim][offsets[dim]] -= 1; + if (lengths[dim][offsets[dim]] > 0) { + break; + } + offsets[dim] += 1; + if (dim == 0) break; + int next_dim = dim - 1; + int8_t next_mapped = new_dim_map[next_dim]; + if (next_mapped >= 0) { + new_idx[next_mapped] += 1; + if (new_idx[next_mapped] > new_shape[next_mapped]) new_shape[next_mapped] = new_idx[next_mapped]; + for (int8_t i = next_mapped + 1; i < (int8_t)n_new; ++i) new_idx[i] = 0; + } + } + } + + // strides + std::vector strides(n_new, 1); + for (int i = (int)n_new - 2; i >= 0; --i) { + int64_t s = new_shape[i + 1]; + if (s <= 0) s = 1; + strides[i] = strides[i + 1] * s; + } + + std::vector begins; + begins.reserve(ops.size()); + for (auto &op : ops) { + auto &idx = op.first; + int64_t base = 0; + for (size_t i = 0; i + 1 < n_new; ++i) base += idx[i] * strides[i]; + base += idx.back(); + begins.push_back(base); + } + return begins; +} + +/** + * Resolve a (possibly multi-dimensional) coordinate into a flat token index. + * + * The coordinate spans the contiguous variable dimensions given by + * `indice_dims`. Depending on the last addressed dimension, the function + * supports boundary indices (equal to the size) and maps them to the logical + * end position after the last token of the addressed entity/leaf. + * + * Single-dimension addressing rules: + * - If d == D-1 (token dimension): idx in [0, total_tokens] -> begin_of_leaf + offset. + * - If d == D-2 (leaf/word id): idx in [0, total_words] -> begin_of_leaf. + * - Else (higher level): idx in [0, leaf_offs[d].size()-1] -> first token of entity. + * In all cases, idx == size selects the end position after the last token. + * + * Multi-dimension addressing (contiguous): interpret `coord` as offsets within + * the subtree rooted at `indice_dims[0]`, descend using `starts` to compute the + * parent global id, and resolve the last coordinate either to a token offset or + * to the first token of the targeted child and boundary at the last dimension is + * supported analogously. + * + * @param lengths Variable lengths per dimension. + * @param data_dims Data dimensions (must end at D-1). + * @param indice_dims Contiguous addressed variable dimensions. + * @param starts Per-dimension child start offsets: starts[j] = cumsum(lengths[j]). + * @param leaf_offs For each dimension, offsets into the leaf (token) axis. 
+ * @param begins_per_leaf Flat begin index per leaf (from begin_idx_per_leaf). + * @param token_starts Global token cumsum across leaves. + * @param coord Coordinate values aligned with `indice_dims`. + * @return Flat token index (or end position) in the refolded layout. + * @throws std::out_of_range on invalid coordinates beyond the allowed boundary. + */ +// Helper: memoized count of descendants at a target dimension under an entity. +// cache[target_dim][from_dim] is a vector of size = number of entities at from_dim, +// storing the count of target_dim entities under each entity at from_dim. +static int64_t count_descendants_memo( + const std::vector> &lengths, + const std::vector> &starts, + int from_dim, + int64_t gid, + int target_dim, + std::vector>> &cache) { + if (from_dim == target_dim) return 1; // the entity itself counts as 1 at its own dimension + auto &level_cache = cache[target_dim][from_dim]; + if (gid < 0 || gid >= (int64_t)level_cache.size()) return 0; + int64_t val = level_cache[gid]; + if (val >= 0) return val; + // Sum descendant counts over immediate children + int next_dim = from_dim + 1; + int64_t n_children = lengths[next_dim][gid]; + int64_t start = starts[next_dim][gid]; + int64_t total = 0; + for (int64_t i = 0; i < n_children; ++i) { + total += count_descendants_memo(lengths, starts, next_dim, start + i, target_dim, cache); + } + level_cache[gid] = total; + return total; +} + +// Map a flattened offset within descendants at target_dim to a concrete child gid at target_dim. +static int64_t descendant_gid_by_flat_offset( + const std::vector> &lengths, + const std::vector> &starts, + int from_dim, + int64_t gid, + int target_dim, + int64_t offset, + std::vector>> &cache) { + if (from_dim == target_dim) return gid; + int next_dim = from_dim + 1; + int64_t n_children = lengths[next_dim][gid]; + int64_t start = starts[next_dim][gid]; + for (int64_t i = 0; i < n_children; ++i) { + int64_t child_gid = start + i; + int64_t cnt = count_descendants_memo(lengths, starts, next_dim, child_gid, target_dim, cache); + if (offset < cnt) { + return descendant_gid_by_flat_offset(lengths, starts, next_dim, child_gid, target_dim, offset, cache); + } + offset -= cnt; + } + // Should not reach here if offset < total descendants + throw std::out_of_range("Offset beyond descendant count"); +} + +static int64_t compute_flat_index( + const std::vector> &lengths, + const std::vector &data_dims, + const std::vector &indice_dims, + const std::vector> &starts, + const std::vector> &leaf_offs, + const std::vector &begins_per_leaf, + const std::vector &token_starts, + const std::vector &coord) { + const size_t D = lengths.size(); + const int dend = indice_dims.back(); + + // Single-dimension convenience addressing + if (indice_dims.size() == 1) { + const int d = indice_dims[0]; + int64_t idx = coord[0]; + if (d == (int)D - 1) { + // Global token index (pooled across leaves), boundary allowed + int64_t total_tokens = token_starts.back(); + if (idx < 0 || idx > total_tokens) throw std::out_of_range("Token index out of bounds"); + if (idx == total_tokens) { + int64_t last_leaf = (int64_t)lengths.back().size() - 1; + return begins_per_leaf[last_leaf] + lengths.back()[last_leaf]; + } + auto it = std::upper_bound(token_starts.begin(), token_starts.end(), idx); + int64_t leaf = (int64_t)(it - token_starts.begin()) - 1; + int64_t offset = idx - token_starts[leaf]; + return begins_per_leaf[leaf] + offset; + } else if (d == (int)D - 2) { + // Global word index (leaf id), boundary allowed + int64_t 
total_words = 0; for (auto x : lengths[D - 2]) total_words += x; + if (idx < 0 || idx > total_words) throw std::out_of_range("Word index out of bounds"); + if (idx == total_words) { + int64_t last_leaf = total_words - 1; + return begins_per_leaf[last_leaf] + lengths.back()[last_leaf]; + } + return begins_per_leaf[idx]; + } else { + // Higher-level entity: map to first token of its first leaf, boundary allowed + // Bounds are based on the total number of entities at this level across parents, + // which corresponds to leaf_offs[d].size() - 1, not lengths[d].size(). + int64_t total_entities = (int64_t)leaf_offs[d].size() - 1; + if (idx < 0 || idx > total_entities) throw std::out_of_range("Index out of bounds"); + if (idx == total_entities) { + int64_t leaf_end = leaf_offs[d].back(); + if (leaf_end == 0) return 0; // empty + int64_t last_leaf = leaf_end - 1; + return begins_per_leaf[last_leaf] + lengths.back()[last_leaf]; + } + int64_t leaf_idx = leaf_offs[d][idx]; + return begins_per_leaf[leaf_idx]; + } + } + + // General non-contiguous multi-dimension addressing (flatten intermediate dims) + // Validate strictly increasing dims + for (size_t i = 1; i < indice_dims.size(); ++i) { + if (indice_dims[i] <= indice_dims[i - 1]) { + throw std::invalid_argument("indice_dims must be strictly increasing"); + } + } + + // Build a cache for descendant counts for all target dims that may be addressed + // Prepare cache[target_dim][from_dim][gid] = count, initialized to -1 + std::vector>> cache; + cache.resize(D); + for (size_t t = 0; t < D; ++t) { + cache[t].resize(D); + for (size_t fd = 0; fd < D; ++fd) { + size_t n_entities = 0; + if (fd == D - 1) { + n_entities = lengths.back().size(); + } else if (fd + 1 < D) { + n_entities = lengths[fd + 1].size(); + } + cache[t][fd] = std::vector(n_entities, -1); + } + } + + // Resolve the first coordinate to a global entity id at its dim + int d0 = indice_dims[0]; + int64_t gid = coord[0]; + // Bounds for first coordinate (no boundary allowed except if last dim only, already handled) + if (d0 == (int)D - 1) { + throw std::invalid_argument("First indice_dim cannot be the leaf/token dimension when multiple dims are provided"); + } else if (d0 == (int)D - 2) { + int64_t total_words = 0; for (auto x : lengths[D - 2]) total_words += x; + if (gid < 0 || gid >= total_words) throw std::out_of_range("Index out of bounds at first dimension"); + } else { + int64_t total_entities = (int64_t)leaf_offs[d0].size() - 1; + if (gid < 0 || gid >= total_entities) throw std::out_of_range("Index out of bounds at first dimension"); + } + + if (indice_dims.size() == 2 && dend == (int)D - 1) { + // Common case: [d_parent, token] with flattening across intermediates + int parent_dim = d0; + int64_t last_idx = coord[1]; + // Number of tokens under this parent entity + int64_t token_count = count_descendants_memo(lengths, starts, parent_dim, gid, (int)D - 1, cache); + if (last_idx < 0 || last_idx > token_count) throw std::out_of_range("Token index out of bounds"); + int64_t leaf_begin = leaf_offs[parent_dim][gid]; + int64_t leaf_end = leaf_offs[parent_dim][gid + 1]; + if (last_idx == token_count) { + if (leaf_end == leaf_begin) return begins_per_leaf[leaf_begin]; + int64_t last_leaf = leaf_end - 1; + return begins_per_leaf[last_leaf] + lengths.back()[last_leaf]; + } + // Map token offset to absolute token index across global token_starts + int64_t base_tokens = token_starts[leaf_begin]; + int64_t abs_token = base_tokens + last_idx; + auto it = std::upper_bound(token_starts.begin(), 
token_starts.end(), abs_token); + int64_t leaf = (int64_t)(it - token_starts.begin()) - 1; + int64_t offset = abs_token - token_starts[leaf]; + return begins_per_leaf[leaf] + offset; + } + + // Traverse successive addressed dims, skipping/flattening intermediates + for (size_t i = 1; i + 1 < indice_dims.size(); ++i) { + int target_dim = indice_dims[i]; + int64_t off = coord[i]; + if (off < 0) throw std::out_of_range("Negative index not allowed"); + int64_t cnt = count_descendants_memo(lengths, starts, indice_dims[i - 1], gid, target_dim, cache); + if (off >= cnt) throw std::out_of_range("Index out of bounds at intermediate dimension"); + gid = descendant_gid_by_flat_offset(lengths, starts, indice_dims[i - 1], gid, target_dim, off, cache); + } + + // Handle last dimension + int prev_dim = indice_dims[indice_dims.size() - 2]; + int64_t last_idx = coord.back(); + if (dend == (int)D - 1) { + // last is token, gid is entity at prev_dim + int64_t token_count = count_descendants_memo(lengths, starts, prev_dim, gid, (int)D - 1, cache); + if (last_idx < 0 || last_idx > token_count) throw std::out_of_range("Token index out of bounds"); + int64_t leaf_begin = leaf_offs[prev_dim][gid]; + int64_t leaf_end = leaf_offs[prev_dim][gid + 1]; + if (last_idx == token_count) { + if (leaf_end == leaf_begin) return begins_per_leaf[leaf_begin]; + int64_t last_leaf = leaf_end - 1; + return begins_per_leaf[last_leaf] + lengths.back()[last_leaf]; + } + int64_t base_tokens = token_starts[leaf_begin]; + int64_t abs_token = base_tokens + last_idx; + auto it = std::upper_bound(token_starts.begin(), token_starts.end(), abs_token); + int64_t leaf = (int64_t)(it - token_starts.begin()) - 1; + int64_t offset = abs_token - token_starts[leaf]; + return begins_per_leaf[leaf] + offset; + } else { + // last is an addressed non-leaf level, select descendant at dend with boundary allowed + int64_t cnt = count_descendants_memo(lengths, starts, prev_dim, gid, dend, cache); + if (last_idx < 0 || last_idx > cnt) throw std::out_of_range("Index out of bounds at last dimension"); + if (last_idx == cnt) { + int64_t leaf_begin = leaf_offs[prev_dim][gid]; + int64_t leaf_end = leaf_offs[prev_dim][gid + 1]; + if (leaf_end == leaf_begin) return begins_per_leaf[leaf_begin]; + int64_t last_leaf = leaf_end - 1; + return begins_per_leaf[last_leaf] + lengths.back()[last_leaf]; + } + int64_t child_gid = descendant_gid_by_flat_offset(lengths, starts, prev_dim, gid, dend, last_idx, cache); + int64_t leaf_idx = (dend == (int)D - 2) ? child_gid : leaf_offs[dend][child_gid]; + return begins_per_leaf[leaf_idx]; + } +} + +static py::array_t map_indices_cpp( + const std::vector> &lengths, + const std::vector &data_dims, + const std::vector &indice_dims, + const std::vector> &indices) { + const size_t D = lengths.size(); + if (D == 0) return py::array_t(0); + + if (data_dims.empty() || (size_t)data_dims.back() != D - 1) { + throw std::invalid_argument("data_dims must end with the last variable dimension"); + } + if (indices.size() != indice_dims.size()) { + throw std::invalid_argument("indices and indice_dims must have the same length"); + } + size_t n = indices.empty() ? 
0 : indices[0].size(); + for (auto &v : indices) if (v.size() != n) throw std::invalid_argument("indices must be same length"); + + // Precompute helpers + std::vector begins = begin_idx_per_leaf(lengths, data_dims); + std::vector> leaf_offs = leaf_offsets_per_dim(lengths); + std::vector> starts = child_start_offsets(lengths); + std::vector token_starts = cumsum(lengths.back()); + + // Validate monotonic increasing dims and bounds + if (!indice_dims.empty()) { + for (size_t i = 1; i < indice_dims.size(); ++i) { + if (indice_dims[i] <= indice_dims[i - 1]) { + throw std::invalid_argument("indice_dims must be strictly increasing"); + } + } + if (indice_dims.back() > (int)D - 1) { + throw std::invalid_argument("Final indice_dim must be <= leaf dimension"); + } + } + + py::array_t out(n); + auto *out_ptr = (int64_t *) out.mutable_data(); + std::vector coord(indice_dims.size()); + for (size_t i = 0; i < n; ++i) { + for (size_t j = 0; j < indice_dims.size(); ++j) coord[j] = indices[j][i]; + out_ptr[i] = compute_flat_index(lengths, data_dims, indice_dims, starts, leaf_offs, begins, token_starts, coord); + } + return out; +} + +// Extracted from inline binding: build flat indices for spans between begins and ends +static py::tuple make_indices_ranges_cpp( + const std::vector> &lengths, + const std::vector &data_dims, + const std::vector &indice_dims, + const std::vector> &begins, + const std::vector> &ends) { + const size_t D = lengths.size(); + if (data_dims.empty() || (size_t)data_dims.back() != D - 1) { + throw std::invalid_argument("data_dims must end with the last variable dimension"); + } + if (begins.size() != ends.size() || begins.size() != indice_dims.size()) { + throw std::invalid_argument("begins/ends must match indice_dims length"); + } + size_t n = begins.empty() ? 
0 : begins[0].size(); + for (auto &v : begins) if (v.size() != n) throw std::invalid_argument("begins arrays must be same length"); + for (auto &v : ends) if (v.size() != n) throw std::invalid_argument("ends arrays must be same length as begins"); + + // Precompute helpers + std::vector begins_per_leaf = begin_idx_per_leaf(lengths, data_dims); + std::vector> leaf_offs = leaf_offsets_per_dim(lengths); + std::vector> starts = child_start_offsets(lengths); + std::vector token_starts = cumsum(lengths.back()); + + // Validate monotonic increasing dims + if (!indice_dims.empty()) { + for (size_t i = 1; i < indice_dims.size(); ++i) { + if (indice_dims[i] <= indice_dims[i - 1]) { + throw std::invalid_argument("indice_dims must be strictly increasing"); + } + } + if (indice_dims.back() > (int)D - 1) { + throw std::invalid_argument("Final indice_dim must be <= leaf dimension"); + } + } + + // First pass: compute starts and total length + std::vector starts_vec; + starts_vec.reserve(n); + std::vector> be_pairs; + be_pairs.reserve(n); + int64_t total = 0; + for (size_t i = 0; i < n; ++i) { + std::vector bcoord(indice_dims.size()); + std::vector ecoord(indice_dims.size()); + for (size_t j = 0; j < indice_dims.size(); ++j) { + bcoord[j] = begins[j][i]; + ecoord[j] = ends[j][i]; + } + int64_t b = compute_flat_index(lengths, data_dims, indice_dims, starts, leaf_offs, begins_per_leaf, token_starts, bcoord); + int64_t e = compute_flat_index(lengths, data_dims, indice_dims, starts, leaf_offs, begins_per_leaf, token_starts, ecoord); + if (e < b) throw std::invalid_argument("Range end before begin"); + starts_vec.push_back(total); + be_pairs.emplace_back(b, e); + total += (e - b); + } + + // Build outputs + py::array_t indices(total); + auto *ind_ptr = (int64_t *) indices.mutable_data(); + // Also build span indices: the span number for each expanded position + py::array_t span_indices(total); + auto *span_ptr = (int64_t *) span_indices.mutable_data(); + for (size_t i = 0; i < be_pairs.size(); ++i) { + auto &p = be_pairs[i]; + for (int64_t x = p.first; x < p.second; ++x) { + *ind_ptr++ = x; + *span_ptr++ = (int64_t) i; + } + } + py::array_t offsets(starts_vec.size()); + auto *off_ptr = (int64_t *) offsets.mutable_data(); + for (auto s : starts_vec) *off_ptr++ = s; + return py::make_tuple(indices, offsets, span_indices); +} + PYBIND11_MODULE(_C, m) { // Initialize the NumPy API. init_numpy(); m.def("make_refolding_indexer", &make_refolding_indexer, "Build an indexer to refold data into a different shape"); m.def("nested_py_list_to_padded_array", &nested_py_list_to_padded_np_array, "Converts a nested Python list to a padded array"); + m.def("map_indices", &map_indices_cpp, "Maps indices to flat leaf starts with boundary support"); + m.def("make_indices_ranges", &make_indices_ranges_cpp, "Expand ranges between begins and ends into flat indices, start offsets, and span indices"); } +// PARTS TO SIMPLIFY -- END + #pragma clang diagnostic pop diff --git a/scripts/benchmark.py b/scripts/benchmark.py index c03564f..d160dbe 100644 --- a/scripts/benchmark.py +++ b/scripts/benchmark.py @@ -1,4 +1,5 @@ # ruff: noqa: F401, E501 +import argparse import contextlib import random import subprocess @@ -132,7 +133,16 @@ def format_time(dt): if __name__ == "__main__": # fmt: off - cases = [1, 2, 3, 4, 5, 6] + parser = argparse.ArgumentParser(description="Run foldedtensor benchmarks.") + parser.add_argument( + "-c", + "--cases", + type=int, + nargs="*", + help="Space-separated case IDs to run (1-8). 
Default: all.", + ) + args = parser.parse_args() + cases = args.cases or list(range(1, 9)) if 1 in cases: print("\n## Case 1 (pad variable lengths nested list)\n") @@ -228,44 +238,50 @@ def format_time(dt): print(f"Speedup against best alternative: **{min(alt) / ft_time:.2f}x** :rocket:") if 7 in cases: - # Test case not working yet - - def sum_all_words_per_sample(ft): - lengths = ft.lengths - ids = torch.arange(lengths[0][0]) - for i in range(1, len(lengths)): - ids = torch.repeat_interleave( - ids, - lengths[i], - output_size=len(lengths[i + 1]) - if i < len(lengths) - 1 - else ft.size(len(ft.data_dims) - 1), - ) - - out = torch.zeros(lengths[0][0], ft.shape[-1]) - out.index_add_(source=ft.as_tensor(), dim=0, index=ids) - - return out - - - print("\n## Case 7 (flat sums)\n") + print("\n## Case 7 (summing vectors inside each differently-sized sequence, all concatenated)\n") with block_code(): exec_and_print( + 'def sum_all_words_per_sample(t):\n' + ' begins = torch.arange(len(t.lengths[1]))\n' + ' ends = begins + 1\n' + ' indices, offsets, spans = t.lengths.make_indices_ranges(\n' + ' begins=(begins,), ends=(ends,), indice_dims=(0,)\n' + ' )\n' + ' return torch.nn.functional.embedding_bag(\n' + ' input=indices,\n' + ' weight=t.view(-1, t.size(-1)),\n' + ' offsets=offsets,\n' + ' mode="sum",\n' + ' )\n\n' "embedder = torch.nn.Embedding(500, 128)\n" "nested_list = make_nested_list(320, (150, 250), value=1)\n" - "ft = foldedtensor.as_folded_tensor(nested_list).refold(2)\n" - "nt = torch.nested.nested_tensor([torch.LongTensor(sub) for sub in nested_list])\n" + "ft = foldedtensor.as_folded_tensor(nested_list).refold(1)\n" + #"nt = torch.nested.nested_tensor([torch.LongTensor(sub) for sub in nested_list])\n" "ft = embedder(ft)\n" - "nt = embedder(nt)\n" + #"nt = embedder(nt)\n" ) - - nt_time = timeit("nt.sum(dim=1)") + pd_time = timeit("ft.refold(0, 1).sum(-2)") ft_time = timeit("sum_all_words_per_sample(ft)") - print(f"Speedup against best alternative: **{nt_time / ft_time:.2f}x** :rocket:") + print(f"Speedup against pad-then-sum: **{pd_time / ft_time:.2f}x** :rocket:") + + if 8 in cases: + print("\n## Case 8 (CamemBERT tokenization with padding)\n") + + with block_code(): + exec_and_print( + "from transformers import AutoTokenizer\n" + "tokenizer = AutoTokenizer.from_pretrained('camembert-base')\n" + "texts = [\n" + " ('Le chat est sur le tapis. 
' * random.randint(50, 150)).strip()\n" + " for _ in range(64)\n" + "]" + ) + hf_time = timeit("tokenizer(texts, return_tensors='pt', padding=True)") + ft_time = timeit("foldedtensor.as_folded_tensor(tokenizer(texts)['input_ids'])") - # timeit("embedder(ft)") - # timeit("embedder(ft).refold(0, 1)") - # timeit("embedder(nt)") + print( + f"Speedup against baseline: **{hf_time / ft_time:.2f}x** :rocket:" + ) # fmt: on diff --git a/tests/test_folded_tensor.py b/tests/test_folded_tensor.py index a4d2e24..67139a4 100644 --- a/tests/test_folded_tensor.py +++ b/tests/test_folded_tensor.py @@ -446,3 +446,48 @@ def test_missing_dims(): tensor.refold("line", "token") assert "line" in str(e.value) + + +def test_get_lengths(): + tensor = as_folded_tensor( + [ + [0, 1, 2], + [3, 4], + ], + full_names=("sample", "token"), + dtype=torch.long, + ) + assert tensor.lengths == [[2], [3, 2]] + assert tensor.lengths["token"] == [3, 2] + + +def test_recreate_folded_tensor_manually(): + tensor = as_folded_tensor( + [ + [0, 1, 2], + [3, 4], + ], + full_names=("sample", "token"), + dtype=torch.long, + ) + as_folded_tensor( + data=tensor.data, + lengths=tensor.lengths, + data_dims=tensor.data_dims, + full_names=("sample_bis", "token_bis"), + ) + + +def test_fail_on_refold_missing_last_dim(): + tensor = as_folded_tensor( + [ + [[0], [1, 2]], + [[3, 4], [8, 9, 10, 11]], + ], + full_names=("sample", "sent", "word"), + dtype=torch.long, + ) + with pytest.raises(ValueError) as e: + tensor.refold("sent") + + assert "The last dimension" in str(e.value) diff --git a/tests/test_indices.py b/tests/test_indices.py new file mode 100644 index 0000000..94749f8 --- /dev/null +++ b/tests/test_indices.py @@ -0,0 +1,249 @@ +import numpy as np +import torch + +import foldedtensor as ft + + +def build_tensor(): + # Two samples total. First sample mirrors the example in the prompt + # and totals 14 tokens (contexts: 5 and 9). Second sample has one + # small context to ensure strides are unchanged for the (context, token) view. 
+ data = [ + [ + [ + [0, 2, 3], + [10], + [4], + ], + [ + [0, 1, 2], + [2, 3], + [10, 11], + [100, 101], + ], + ], + [ + [ + [7], + [8, 9], + ], + ], + ] + return ft.as_folded_tensor(data, full_names=("sample", "context", "word", "token")) + + +def build_tensor_single_sample(): + # Single sample as in the prompt example + data = [ + [ + [ + [0, 2, 3], + [10], + [4], + ], + [ + [0, 1, 2], + [2, 3], + [10, 11], + [100, 101], + ], + ], + ] + return ft.as_folded_tensor(data, full_names=("sample", "context", "word", "token")) + + +def test_map_indices_flat_unpadded_tokens_by_token(): + t = build_tensor() + + assert t.refold("token").lengths.map_indices( + indices=([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],), + indice_dims=("token",), + ) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + + +def test_map_indices_flat_unpadded_tokens_by_word(): + t = build_tensor() + + assert t.refold("token").lengths.map_indices( + indices=([0, 1, 2, 3, 4, 5, 6, 7],), + indice_dims=("word",), + ) == [0, 3, 4, 5, 8, 10, 12, 14] + + +def test_map_indices_flat_unpadded_tokens_by_context(): + t = build_tensor() + + assert t.refold("token").lengths.map_indices( + indices=([0, 1, 2],), + indice_dims=("context",), + ) == [0, 5, 14] + + +def test_map_indices_flat_unpadded_tokens_by_sample(): + t = build_tensor() + + assert t.refold("token").lengths.map_indices( + indices=([0, 1],), + indice_dims=("sample",), + ) == [0, 14] + + +def test_map_indices_subset_words(): + t = build_tensor() + + assert t.refold("token").lengths.map_indices( + indices=([0, 1, 2, 4, 6],), + indice_dims=("word",), + ) == [0, 3, 4, 8, 12] + + +def test_map_indices_context_word_to_padded_context_token(): + t = build_tensor() + + assert t.refold("context", "token").lengths.map_indices( + indices=([0, 0, 0, 0, 1, 1, 1, 1], [0, 1, 2, 3, 0, 1, 2, 3]), + indice_dims=("context", "word"), + ) == [0, 3, 4, 5, 9, 12, 14, 16] + + +def test_make_indices_ranges_flat_tokens(): + t = build_tensor_single_sample() + + indices, offsets, spans = t.refold("token").lengths.make_indices_ranges( + begins=(torch.as_tensor([0, 0, 1]), torch.as_tensor([0, 1, 2])), + ends=(torch.as_tensor([0, 1, 1]), torch.as_tensor([1, 3, 4])), + indice_dims=("context", "word"), + return_tensors=False, + ) + + assert indices == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 10, 11, 12, 13] + assert offsets == [0, 3, 12] + assert spans == [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2] + + +def test_make_indices_ranges_one_dim_token(): + t = build_tensor_single_sample() + indices, offsets, spans = t.refold("token").lengths.make_indices_ranges( + begins=(torch.tensor([0, 3, 12]),), + ends=(torch.tensor([3, 12, 14]),), + indice_dims=("token",), + return_tensors=False, + ) + assert indices == list(range(0, 3)) + list(range(3, 12)) + list(range(12, 14)) + assert offsets == [0, 3, 12] + assert spans == [0] * 3 + [1] * 9 + [2] * 2 + + +def test_make_indices_ranges_one_dim_word(): + t = build_tensor_single_sample() + indices, offsets, spans = t.refold("token").lengths.make_indices_ranges( + begins=(torch.tensor([0, 1, 3]),), + ends=(torch.tensor([1, 3, 7]),), + indice_dims=("word",), + return_tensors=False, + ) + assert indices == list(range(0, 3)) + list(range(3, 5)) + list(range(5, 14)) + assert offsets == [0, 3, 5] + assert spans == [0] * 3 + [1] * 2 + [2] * 9 + + +def test_word_span_mean_pooler_with_embedding_bag_flat_indices(): + t = build_tensor().refold("context", "token") + # 0 -> 2: [[0, 2, 3], [10]] + # 5 -> 7: [[10, 11], [100, 101]] + indices, offsets, spans = 
t.lengths.make_indices_ranges( + begins=(torch.tensor([0, 5]),), + ends=(torch.tensor([2, 7]),), + indice_dims=("word",), + ) + embeds = t.unsqueeze(-1).expand(-1, -1, 2).float() + res = torch.nn.functional.embedding_bag( + input=indices, + weight=embeds.view(-1, 2), + offsets=offsets, + mode="mean", + ) + assert res.tolist() == [[3.75, 3.75], [55.5, 55.5]] + + +def test_word_span_mean_pooler_with_embedding_bag_multidim_indices(): + t = build_tensor().refold("context", "token") + # 0 -> 2: [[0, 2, 3], [10]] + # 5 -> 7: [[10, 11], [100, 101]] + indices, offsets, spans = t.lengths.make_indices_ranges( + begins=( + torch.tensor([0, 1]), + torch.tensor([0, 2]), + ), + ends=( + torch.tensor([0, 1]), + torch.tensor([2, 4]), + ), + indice_dims=( + "context", + "word", + ), + ) + embeds = t.unsqueeze(-1).expand(-1, -1, 2).float() + res = torch.nn.functional.embedding_bag( + input=indices, + weight=embeds.view(-1, 2), + offsets=offsets, + mode="mean", + ) + assert res.tolist() == [[3.75, 3.75], [55.5, 55.5]] + + +def test_map_indices_format_torch_multidimensional(): + t = build_tensor() + + assert torch.allclose( + t.lengths.map_indices( + indices=( + torch.as_tensor([[0, 1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12, 13]]), + ), + indice_dims=("token",), + ), + torch.tensor([[0, 1, 2, 3, 6, 12, 13], [14, 15, 16, 18, 19, 21, 22]]), + ) + + +def test_map_indices_format_numpy_multidimensional(): + t = build_tensor() + + assert np.allclose( + t.lengths.map_indices( + indices=(np.asarray([[0, 1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12, 13]]),), + indice_dims=("token",), + ), + np.asarray([[0, 1, 2, 3, 6, 12, 13], [14, 15, 16, 18, 19, 21, 22]]), + ) + + +def test_make_indices_ranges_format_torch_multidimensional(): + t = build_tensor_single_sample() + + indices, offsets, spans = t.lengths.make_indices_ranges( + begins=(torch.as_tensor([[0, 0, 1]]), torch.as_tensor([[0, 1, 2]])), + ends=(torch.as_tensor([[0, 1, 1]]), torch.as_tensor([[1, 3, 4]])), + indice_dims=("context", "word"), + ) + + assert isinstance(indices, torch.Tensor) + assert isinstance(offsets, torch.Tensor) + assert isinstance(spans, torch.Tensor) + assert offsets.shape == (1, 3) + + +def test_make_indices_ranges_format_numpy_multidimensional(): + t = build_tensor_single_sample() + + indices, offsets, spans = t.lengths.make_indices_ranges( + begins=(np.asarray([[0, 0, 1]]), np.asarray([[0, 1, 2]])), + ends=(np.asarray([[0, 1, 1]]), np.asarray([[1, 3, 4]])), + indice_dims=("context", "word"), + ) + assert isinstance(indices, np.ndarray) + assert isinstance(offsets, np.ndarray) + assert isinstance(spans, np.ndarray) + assert offsets.shape == (1, 3)
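As a companion to the `embedding_bag` pooling tests above, and to the README note that the returned span ids can drive `torch.index_add` / `torch.index_reduce`, here is a minimal illustrative sketch (not part of the diff) that max-pools each word span with the in-place `Tensor.index_reduce_`, reusing the `build_tensor` fixture and the same two spans as the mean-pooling tests:

```python
def span_max_pool_sketch():
    # Max-pool each word span using the span ids returned by make_indices_ranges.
    t = build_tensor().refold("context", "token")
    indices, offsets, spans = t.lengths.make_indices_ranges(
        begins=(torch.tensor([0, 5]),),
        ends=(torch.tensor([2, 7]),),
        indice_dims=("word",),
    )
    embeds = t.unsqueeze(-1).expand(-1, -1, 2).float()
    flat = embeds.view(-1, 2)
    # One output row per span; include_self=False ignores the initial zeros.
    out = torch.zeros(len(offsets), 2)
    out.index_reduce_(0, spans, flat[indices], reduce="amax", include_self=False)
    return out  # expected [[10.0, 10.0], [101.0, 101.0]] for the two spans above
```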