Skip to content

Commit db06b56

Browse files
committed
Saved computational costs of get_intermediate_layers() from unused blocks
1 parent 4731e4e commit db06b56

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

timm/models/vision_transformer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -635,13 +635,14 @@ def _intermediate_layers(
635635
) -> List[torch.Tensor]:
636636
outputs, num_blocks = [], len(self.blocks)
637637
take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
638+
last_index_to_take = max(take_indices)
638639

639640
# forward pass
640641
x = self.patch_embed(x)
641642
x = self._pos_embed(x)
642643
x = self.patch_drop(x)
643644
x = self.norm_pre(x)
644-
for i, blk in enumerate(self.blocks):
645+
for i, blk in enumerate(self.blocks[: last_index_to_take + 1]):
645646
x = blk(x)
646647
if i in take_indices:
647648
outputs.append(x)

0 commit comments

Comments
 (0)