
Commit b019d9a

feat: add indice and indice ranges mapping functions
Parent: d2c8d98

8 files changed: +1123 −103 lines

README.md

Lines changed: 53 additions & 0 deletions
@@ -106,6 +106,59 @@ print(refolded_embedding.shape)
# torch.Size([2, 5, 16]) # 2 samples, 5 words max, 16 dims
```

### Pooling spans

You can pool variable-length spans directly on a refolded view, without padding, by building flat indices and offsets and then using `embedding_bag`.

The helper `lengths.make_indices_ranges` expands ranges defined over contiguous variable dimensions into three arrays:

- `indices` are the flat positions in the refolded tensor viewed as a single dimension
- `offsets` are the start positions of each span within `indices`
- `spans` gives the span id for every expanded position, which can be useful for functions like `torch.index_add` or `torch.index_reduce` (see the sketch after the example below)

Example that mean-pools word spans to produce one vector per span:

```python
import torch
import foldedtensor as ft

# Build a 4-level tensor with names: the first word of the first context is
# split into three tokens, etc.
input_ids = ft.as_folded_tensor(
    [
        [
            [[0, 2, 3], [10], [4]],
            [[0, 1, 2], [2, 3], [10, 11], [100, 101]],
        ],
    ],
    full_names=("sample", "context", "word", "token"),
).refold("token")  # any refolding is fine

# Create embeddings from the input ids
embedding = torch.nn.Embedding(2048, 16)
weight = embedding(input_ids)

# Pool two word spans (end indices are exclusive):
# span 1 covers words 0 to 2 -> mean pool over 4 tokens [0, 2, 3, 10]
# span 2 covers words 5 to 7 -> mean pool over 4 tokens [10, 11, 100, 101]
indices, offsets, spans = input_ids.lengths.make_indices_ranges(
    begins=(torch.tensor([0, 5]),),
    ends=(torch.tensor([2, 7]),),
    indice_dims=("word",),
)

# Mean-pool embeddings over each span
pooled = torch.nn.functional.embedding_bag(
    input=indices,
    # Flatten embeddings so rows align with flattened token positions
    weight=weight.view(-1, weight.size(-1)),
    offsets=offsets,
    mode="mean",
)
print(pooled)
```
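
The example above leaves `spans` unused; as the list before it notes, it pairs with `torch.index_add`. A minimal sketch (not part of the commit), reusing `weight`, `indices`, `offsets` and `spans` from the example:

```python
# Sum-pool with index_add_ instead of embedding_bag (sketch, not from the README)
flat = weight.view(-1, weight.size(-1))  # rows align with flat token positions
pooled_sum = torch.zeros(offsets.size(0), flat.size(-1))
# spans[i] is the span id of indices[i], so each gathered token embedding
# is accumulated into the row of its span
pooled_sum.index_add_(0, spans, flat[indices])
```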

## Benchmarks

View the comparisons of `foldedtensor` against various alternatives here: [docs/benchmarks](https://github.com/aphp/foldedtensor/blob/main/docs/benchmark.md).

changelog.md

Lines changed: 7 additions & 0 deletions
@@ -1,5 +1,12 @@
# Changelog

## Unreleased

- Add `map_indices` and `make_indices_ranges` with C++ backends, exposed as `lengths.map_indices` and `lengths.make_indices_ranges`; they take contiguous `indice_dims`, handle boundaries, and return flat indices with offsets and span ids for pooling with `embedding_bag`
- Introduce `FoldedTensorLayout` to store `full_names` and `data_dims`, with named dimension resolution and helper methods; it now serves as the `lengths` container for `FoldedTensor`
- Improve `as_folded_tensor`: better inference of dims and dtype from nested data, support for named `data_dims`, and better handling of names and empty structures
- The benchmark script gains a `--cases` option to run selected cases, adds a new case for range-based pooling, and adjusts its outputs
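
Not part of the diff: a minimal sketch of the improved `as_folded_tensor` call, assuming (per the entry above) that named `data_dims` are given as dimension-name strings resolved against `full_names`:

```python
import foldedtensor

# Sketch only: dims, dtype and depth are inferred from the nested data, and
# data_dims is passed by name (an assumption based on the changelog entry)
t = foldedtensor.as_folded_tensor(
    [[[0, 1], [2]], [[3]]],
    full_names=("sample", "word", "token"),
    data_dims=("word", "token"),
)
print(t.shape)  # likely (3, 2): 3 words flattened across samples, 2 tokens max
```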

## v0.4.0

- Fix `storage` torch warning

docs/benchmark.md

Lines changed: 64 additions & 33 deletions
@@ -8,9 +8,9 @@ It compares the performance of `foldedtensor` with various alternatives for padding
and working with nested lists and tensors.

Environment:
-- `torch.__version__ == '2.6.0'`
+- `torch.__version__ == '2.7.1'`
- `foldedtensor.__version__ == '0.4.0'`
-- `python == 3.9.20`
+- `python == 3.11.3`
- `sys.platform == 'darwin'`

@@ -22,79 +22,79 @@ nested_list = make_nested_list(32, (50, 100), (25, 30), value=1)

Comparisons:
%timeit python_padding(nested_list)
-# 100 loops, best of 5: 15.09 ms per loop
+# 100 loops, best of 5: 16.96 ms per loop

%timeit foldedtensor.as_folded_tensor(nested_list)
-# 100 loops, best of 5: 0.73 ms per loop
+# 100 loops, best of 5: 0.88 ms per loop

```
-Speedup against best alternative: **20.67x** :rocket:
+Speedup against best alternative: **19.36x** :rocket:
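
For reference (not part of the diff), `make_nested_list` and `python_padding` come from the repository's benchmark script; a hypothetical sketch of what such helpers might look like:

```python
import random
import torch

def make_nested_list(*dims, value=0):
    # Hypothetical reconstruction: each dim is either a fixed size or a
    # (min, max) range from which every sublist's length is drawn
    def build(level):
        if level == len(dims):
            return value
        d = dims[level]
        n = random.randint(*d) if isinstance(d, tuple) else d
        return [build(level + 1) for _ in range(n)]
    return build(0)

def python_padding(nested_list):
    # Hypothetical reconstruction for the two-level case: pad sublists with
    # zeros in pure Python, then convert once to a tensor (the real helper
    # also handles deeper nesting)
    max_len = max(len(sub) for sub in nested_list)
    return torch.tensor([sub + [0] * (max_len - len(sub)) for sub in nested_list])
```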

## Case 2 (same lengths nested lists)

```python
nested_list = make_nested_list(32, 100, 30, value=1)

%timeit torch.tensor(nested_list)
-# 100 loops, best of 5: 6.51 ms per loop
+# 100 loops, best of 5: 7.67 ms per loop

%timeit torch.LongTensor(nested_list)
-# 100 loops, best of 5: 2.78 ms per loop
+# 100 loops, best of 5: 3.59 ms per loop

%timeit python_padding(nested_list)
-# 100 loops, best of 5: 18.38 ms per loop
+# 100 loops, best of 5: 20.02 ms per loop

%timeit torch.nested.nested_tensor([torch.LongTensor(sub) for sub in nested_list]).to_padded_tensor(0)
-# 100 loops, best of 5: 3.00 ms per loop
+# 100 loops, best of 5: 3.83 ms per loop

%timeit foldedtensor.as_folded_tensor(nested_list)
-# 100 loops, best of 5: 1.08 ms per loop
+# 100 loops, best of 5: 1.22 ms per loop

```
-Speedup against best alternative: **2.58x** :rocket:
+Speedup against best alternative: **2.94x** :rocket:

## Case 3 (simple list)

```python
simple_list = make_nested_list(10000, value=1)

%timeit torch.tensor(simple_list)
-# 100 loops, best of 5: 0.63 ms per loop
+# 100 loops, best of 5: 0.75 ms per loop

%timeit torch.LongTensor(simple_list)
-# 100 loops, best of 5: 0.27 ms per loop
+# 100 loops, best of 5: 0.36 ms per loop

%timeit python_padding(simple_list)
-# 100 loops, best of 5: 0.28 ms per loop
+# 100 loops, best of 5: 0.36 ms per loop

%timeit foldedtensor.as_folded_tensor(simple_list)
-# 100 loops, best of 5: 0.08 ms per loop
+# 100 loops, best of 5: 0.10 ms per loop

```
-Speedup against best alternative: **3.32x** :rocket:
+Speedup against best alternative: **3.57x** :rocket:

## Case 4 (same lengths nested lists to flat tensor)

```python
nested_list = make_nested_list(32, 100, 30, value=1)

%timeit torch.tensor(nested_list).view(-1)
-# 100 loops, best of 5: 6.52 ms per loop
+# 100 loops, best of 5: 7.65 ms per loop

%timeit torch.LongTensor(nested_list).view(-1)
-# 100 loops, best of 5: 2.76 ms per loop
+# 100 loops, best of 5: 3.63 ms per loop

%timeit python_padding(nested_list).view(-1)
-# 100 loops, best of 5: 18.62 ms per loop
+# 100 loops, best of 5: 20.36 ms per loop

%timeit foldedtensor.as_folded_tensor(nested_list).view(-1)
-# 100 loops, best of 5: 1.12 ms per loop
+# 100 loops, best of 5: 1.22 ms per loop

%timeit foldedtensor.as_folded_tensor(nested_list, data_dims=(2,))
-# 100 loops, best of 5: 1.08 ms per loop
+# 100 loops, best of 5: 1.20 ms per loop

```
-Speedup against best alternative: **2.47x** :rocket:
+Speedup against best alternative: **2.96x** :rocket:
## Case 5 (variable lengths nested lists) to padded embeddings

Nested lists with different lengths (second level lists have lengths between 50 and 150). We compare `foldedtensor` with `torch.nested`.

@@ -104,41 +104,72 @@ nested_list = make_nested_list(32, (50, 150), 30, value=1)
# Padding with 0

%timeit torch.nested.nested_tensor([torch.LongTensor(sub) for sub in nested_list]).to_padded_tensor(0)
-# 100 loops, best of 5: 3.02 ms per loop
+# 100 loops, best of 5: 4.10 ms per loop

%timeit foldedtensor.as_folded_tensor(nested_list).as_tensor()
-# 100 loops, best of 5: 1.03 ms per loop
+# 100 loops, best of 5: 1.23 ms per loop

```
-Speedup against best alternative: **2.95x** :rocket:
+Speedup against best alternative: **3.33x** :rocket:
```python
# Padding with 1

%timeit torch.nested.nested_tensor([torch.FloatTensor(sub) for sub in nested_list]).to_padded_tensor(1)
-# 100 loops, best of 5: 3.72 ms per loop
+# 100 loops, best of 5: 4.42 ms per loop

%timeit x = foldedtensor.as_folded_tensor(nested_list); x.masked_fill_(x.mask, 1)
-# 100 loops, best of 5: 1.62 ms per loop
+# 100 loops, best of 5: 1.58 ms per loop

```
-Speedup against best alternative: **2.30x** :rocket:
+Speedup against best alternative: **2.80x** :rocket:

## Case 6 (2d padding)

```python
nested_list = make_nested_list(160, (50, 150), value=1)

%timeit python_padding(nested_list)
-# 100 loops, best of 5: 1.33 ms per loop
+# 100 loops, best of 5: 1.48 ms per loop

%timeit torch.nested.nested_tensor([torch.LongTensor(sub) for sub in nested_list]).to_padded_tensor(0)
-# 100 loops, best of 5: 1.14 ms per loop
+# 100 loops, best of 5: 1.28 ms per loop

%timeit torch.nn.utils.rnn.pad_sequence([torch.LongTensor(sub) for sub in nested_list], batch_first=True, padding_value=0)
-# 100 loops, best of 5: 0.86 ms per loop
+# 100 loops, best of 5: 1.02 ms per loop

%timeit foldedtensor.as_folded_tensor(nested_list)
-# 100 loops, best of 5: 0.15 ms per loop
+# 100 loops, best of 5: 0.17 ms per loop

```
-Speedup against best alternative: **5.88x** :rocket:
+Speedup against best alternative: **6.03x** :rocket:
+
+## Case 7 (summing vectors inside each differently-sized sequence, all concatenated)
+
+```python
+def sum_all_words_per_sample(t):
+    begins = torch.arange(len(t.lengths[1]))
+    ends = begins + 1
+    indices, offsets, spans = t.lengths.make_indices_ranges(
+        begins=(begins,), ends=(ends,), indice_dims=(0,)
+    )
+    return torch.nn.functional.embedding_bag(
+        input=indices,
+        weight=t.view(-1, t.size(-1)),
+        offsets=offsets,
+        mode="sum",
+    )
+
+embedder = torch.nn.Embedding(500, 128)
+nested_list = make_nested_list(320, (150, 250), value=1)
+ft = foldedtensor.as_folded_tensor(nested_list).refold(1)
+ft = embedder(ft)
+
+%timeit ft.refold(0, 1).sum(-2)
+# 100 loops, best of 5: 3.56 ms per loop
+
+%timeit sum_all_words_per_sample(ft)
+# 100 loops, best of 5: 1.00 ms per loop
+
+```
+Speedup against pad-then-sum: **3.56x** :rocket:
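
Not part of the commit: a quick check that Case 7's two strategies agree, assuming `refold(0, 1)` zero-pads missing words so the padded sum equals the per-sample sum:

```python
# Sketch: compare pad-then-sum with the embedding_bag-based gather-and-sum
padded = ft.refold(0, 1).sum(-2)       # (samples, dim) after summing over words
bagged = sum_all_words_per_sample(ft)  # (samples, dim) via flat indices
assert torch.allclose(padded, bagged, atol=1e-4)
```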
