187 changes: 187 additions & 0 deletions src/pruna/algorithms/shortgpt.py
@@ -0,0 +1,187 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any

import numpy as np
import torch
import torch.nn.functional as F
from ConfigSpace import CategoricalHyperparameter, UniformFloatHyperparameter
from tqdm import tqdm

from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
from pruna.algorithms.base.tags import AlgorithmTag as tags
from pruna.config.hyperparameters import Boolean
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.engine.save import SAVE_FUNCTIONS
from pruna.logging.logger import pruna_logger


class ShortGPT(PrunaAlgorithmBase):
"""
ShortGPT algorithm for pruning transformer layers using a block influence metric.

    ShortGPT identifies and prunes the least important blocks in transformer models based on their
    block influence (BI) scores, which use the similarity between a layer's input and output to
    measure its importance.
"""

algorithm_name: str = "shortgpt"
group_tags: list[str] = [tags.PRUNER]
references: dict[str, str] = {
"Paper": "https://arxiv.org/pdf/2403.03853",
}
save_fn = SAVE_FUNCTIONS.pickled
tokenizer_required: bool = True
dataset_required: bool = True
processor_required: bool = False
runs_on: list[str] = ["cuda", "cpu"]

def get_hyperparameters(self) -> list:
"""
Configure all algorithm-specific hyperparameters with ConfigSpace.

Returns
-------
list
The hyperparameters.
"""
return [
CategoricalHyperparameter(
"metric_type",
["BI"],
default_value="BI",
meta=dict(desc="Metric type for layer importance: Block Influence"),
),
UniformFloatHyperparameter(
"prune_ratio",
lower=0.0,
upper=0.8,
default_value=0.25,
meta=dict(desc="Fraction of layers to prune"),
),
Boolean("angular", meta=dict(desc="Use angular distance for BI computation")),
]

@staticmethod
@torch.inference_mode()
def compute_block_influence(model, tokenizer, dataloader, angular=False, device="cuda"):
"""
Compute the block influence scores for each transformer layer in the model.

The block influence score for a layer is given as 1 - the cosine similarity
between the layer's input and output activations, averaged over the dataset.
"""
model.eval().to(device)
num_layers = len(model.model.layers)
bis = torch.zeros(num_layers + 1, device=device)
counts = 0

# TODO: Discuss if we should keep clearing device cache in case of gpu,
# because model and data keep moving to device
        for batch in tqdm(dataloader, desc="Computing Block Influence"):
if isinstance(batch, dict) and "text" in batch:
texts = batch["text"]
elif isinstance(batch, list):
texts = batch
else:
raise ValueError(f"Unsupported batch type: {type(batch)}")

inputs = tokenizer(
texts,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True,
).to(device)
input_ids = inputs["input_ids"]

outputs = model(input_ids=input_ids, output_hidden_states=True)
hiddens = list(outputs.hidden_states)
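            # hidden_states holds the embedding output followed by each layer's output,
            # so adjacent entries are a given layer's input and output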

for i in range(len(hiddens) - 1):
in_h, out_h = hiddens[i].float(), hiddens[i + 1].float()
                cos = F.cosine_similarity(
in_h.view(-1, in_h.shape[-1]),
out_h.view(-1, out_h.shape[-1]),
dim=-1,
)
if angular:
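                    # clamp keeps acos numerically safe at +/-1; dividing by pi maps the angle to [0, 1]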
cos = cos.clamp(-1 + 1e-7, 1 - 1e-7)
bi = torch.acos(cos).mean() / np.pi
else:
bi = (1 - cos).mean()
                bis[i] += bi

            # count batches rather than layer transitions so the division below gives a per-batch mean
            counts += 1

bis /= counts
Review comment on bis /= counts:

Bug: NaNs: The Silent Killer of Predictable Behavior

If the dataloader is empty or yields no batches, counts remains 0, causing bis /= counts to produce NaN values. These NaN values propagate through np.argsort, resulting in unpredictable pruning behavior instead of a clear error message.
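A minimal guard would surface this early. Below is a sketch, written as a replacement for the bis /= counts line above; the exact error message is illustrative and not part of the PR:

        if counts == 0:
            # hypothetical guard: fail loudly instead of silently producing NaN scores
            raise ValueError("No calibration batches were processed; cannot compute block influence scores.")
        bis /= counts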

return bis.tolist()

def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
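        """
        Prune the least influential transformer layers from the model.

        Parameters
        ----------
        model : Any
            The model to prune.
        smash_config : SmashConfigPrefixWrapper
            The configuration providing the device, tokenizer, dataloader, prune_ratio and angular settings.

        Returns
        -------
        Any
            The pruned model.
        """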
device = smash_config["device"]
model = model.to(device)
model.eval()

pruna_logger.info(f"[ShortGPT] Starting layer pruning for model on device: {device}")
pruna_logger.info(f"[ShortGPT] Model depth: {len(model.model.layers)}")
pruna_logger.info(f"[ShortGPT] Model parameters: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M")
tokenizer = smash_config["tokenizer"]

dataloader = smash_config["train_dataloader"]
prune_ratio = smash_config["prune_ratio"]
angular = smash_config["angular"]

pruna_logger.info(f"[ShortGPT] Running layer pruning (ratio={prune_ratio:.2f})")

scores = self.compute_block_influence(model, tokenizer, dataloader, angular=angular, device=device)

num_layers = len(model.model.layers)
n_prune = int(prune_ratio * num_layers)

        # not using the final norm layer score, because the paper only mentions transformer layers  # noqa
# TODO: Should we even compute the norm layer score? # noqa
layer_scores = np.array(scores[:num_layers])

prune_indices = np.argsort(layer_scores)[:n_prune].tolist()
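        # layers with the lowest block influence change their input the least and are pruned first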
keep_indices = [i for i in range(num_layers) if i not in prune_indices]

        pruna_logger.info(f"[ShortGPT] Pruning {n_prune}/{num_layers} layers: {prune_indices}")

kept_layers = torch.nn.ModuleList([layer for i, layer in enumerate(model.model.layers) if i in keep_indices])
model.model.layers = kept_layers

pruna_logger.info(f"[ShortGPT] Pruned model depth: {len(model.model.layers)}")
pruna_logger.info(
f"[ShortGPT] Pruned model parameters: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M"
)

return model

def model_check_fn(self, model):
"""
Check if the model is a torch.nn.Module.

Parameters
----------
model : Any
The model to check.

Returns
-------
bool
True if the model is a torch.nn.Module, False otherwise.
"""
return isinstance(model, torch.nn.Module)
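
For intuition about the metric this diff implements, the block influence computation can be reproduced on toy tensors. The following is a standalone sketch using only torch; the shapes and the near-identity "layer" are invented for illustration and are not part of the PR:

import torch
import torch.nn.functional as F

# toy hidden states with shape (tokens, hidden_dim): a layer's input and its output
in_h = torch.randn(8, 16)
out_h = in_h + 0.05 * torch.randn(8, 16)  # a "layer" that barely changes its input

# block influence: 1 - cosine similarity, averaged over tokens
cos = F.cosine_similarity(in_h, out_h, dim=-1)
bi = (1 - cos).mean()
print(f"block influence: {bi.item():.4f}")  # close to 0, i.e. a strong pruning candidate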