diff --git a/src/pruna/algorithms/shortgpt.py b/src/pruna/algorithms/shortgpt.py
new file mode 100644
index 00000000..fd458c99
--- /dev/null
+++ b/src/pruna/algorithms/shortgpt.py
@@ -0,0 +1,187 @@
+# Copyright 2025 - Pruna AI GmbH. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from ConfigSpace import CategoricalHyperparameter, UniformFloatHyperparameter
+from tqdm import tqdm
+
+from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
+from pruna.algorithms.base.tags import AlgorithmTag as tags
+from pruna.config.hyperparameters import Boolean
+from pruna.config.smash_config import SmashConfigPrefixWrapper
+from pruna.engine.save import SAVE_FUNCTIONS
+from pruna.logging.logger import pruna_logger
+
+
+class ShortGPT(PrunaAlgorithmBase):
+    """
+    ShortGPT algorithm for pruning transformer layers using a Block Influence (BI) metric.
+
+    ShortGPT identifies and prunes the least important blocks of a transformer model based on
+    their BI scores, which measure a layer's importance via the similarity between its input
+    and output activations.
+    """
+
+    algorithm_name: str = "shortgpt"
+    group_tags: list[str] = [tags.PRUNER]
+    references: dict[str, str] = {
+        "Paper": "https://arxiv.org/pdf/2403.03853",
+    }
+    save_fn = SAVE_FUNCTIONS.pickled
+    tokenizer_required: bool = True
+    dataset_required: bool = True
+    processor_required: bool = False
+    runs_on: list[str] = ["cuda", "cpu"]
+
+    def get_hyperparameters(self) -> list:
+        """
+        Configure all algorithm-specific hyperparameters with ConfigSpace.
+
+        Returns
+        -------
+        list
+            The hyperparameters.
+        """
+        return [
+            CategoricalHyperparameter(
+                "metric_type",
+                ["BI"],
+                default_value="BI",
+                meta=dict(desc="Metric type for layer importance: Block Influence"),
+            ),
+            UniformFloatHyperparameter(
+                "prune_ratio",
+                lower=0.0,
+                upper=0.8,
+                default_value=0.25,
+                meta=dict(desc="Fraction of layers to prune"),
+            ),
+            Boolean("angular", meta=dict(desc="Use angular distance for BI computation")),
+        ]
+
+    @staticmethod
+    @torch.inference_mode()
+    def compute_block_influence(model, tokenizer, dataloader, angular=False, device="cuda"):
+        """
+        Compute the Block Influence score for each transformer layer of the model.
+
+        The Block Influence score of a layer is 1 minus the cosine similarity between the
+        layer's input and output hidden states, averaged over the calibration dataset. If
+        ``angular`` is True, the angular distance ``arccos(cos) / pi`` is used instead.
+
+        Parameters
+        ----------
+        model : Any
+            Causal language model exposing its transformer blocks as ``model.model.layers``.
+        tokenizer : Any
+            Tokenizer used to encode the calibration texts.
+        dataloader : Any
+            Iterable of batches, either dicts with a ``"text"`` key or lists of strings.
+        angular : bool
+            Whether to use the angular distance instead of ``1 - cosine similarity``.
+        device : str
+            Device on which to run the computation.
+
+        Returns
+        -------
+        list[float]
+            Block Influence scores; only the first ``num_layers`` entries are populated,
+            the trailing slot is left at zero.
+        """
+        model.eval().to(device)
+        num_layers = len(model.model.layers)
+        bis = torch.zeros(num_layers + 1, device=device)
+        counts = 0
+
+        # TODO: Discuss whether to clear the device cache between batches on GPU,
+        # since the model and data are repeatedly moved to the device.
+        for batch in tqdm(dataloader, desc="Computing Block Influence"):
+            if isinstance(batch, dict) and "text" in batch:
+                texts = batch["text"]
+            elif isinstance(batch, list):
+                texts = batch
+            else:
+                raise ValueError(f"Unsupported batch type: {type(batch)}")
+
+            inputs = tokenizer(
+                texts,
+                return_tensors="pt",
+                truncation=True,
+                max_length=512,
+                padding=True,
+            ).to(device)
+
+            # Pass the attention mask along with the input ids so padding is handled correctly.
+            outputs = model(**inputs, output_hidden_states=True)
+            hiddens = list(outputs.hidden_states)
+            counts += 1
+
+            for i in range(len(hiddens) - 1):
+                in_h, out_h = hiddens[i].float(), hiddens[i + 1].float()
+                cos = F.cosine_similarity(
+                    in_h.view(-1, in_h.shape[-1]),
+                    out_h.view(-1, out_h.shape[-1]),
+                    dim=-1,
+                )
+                if angular:
+                    cos = cos.clamp(-1 + 1e-7, 1 - 1e-7)
+                    bi = torch.acos(cos).mean() / np.pi
+                else:
+                    bi = (1 - cos).mean()
+                bis[i] += bi
+
+        # Average the accumulated scores over the number of batches.
+        bis /= counts
+        return bis.tolist()
+
+    def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
+        device = smash_config["device"]
+        model = model.to(device)
+        model.eval()
+
+        pruna_logger.info(f"[ShortGPT] Starting layer pruning for model on device: {device}")
+        pruna_logger.info(f"[ShortGPT] Model depth: {len(model.model.layers)}")
+        pruna_logger.info(f"[ShortGPT] Model parameters: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M")
+
+        tokenizer = smash_config["tokenizer"]
+        dataloader = smash_config["train_dataloader"]
+        prune_ratio = smash_config["prune_ratio"]
+        angular = smash_config["angular"]
+
+        pruna_logger.info(f"[ShortGPT] Running layer pruning (ratio={prune_ratio:.2f})")
+
+        scores = self.compute_block_influence(model, tokenizer, dataloader, angular=angular, device=device)
+
+        num_layers = len(model.model.layers)
+        n_prune = int(prune_ratio * num_layers)
+
+        # The trailing score slot is ignored because the paper prunes transformer layers only.  # noqa
+        # TODO: Should we even compute the norm layer score?  # noqa
+        layer_scores = np.array(scores[:num_layers])
+
+        prune_indices = np.argsort(layer_scores)[:n_prune].tolist()
+        keep_indices = [i for i in range(num_layers) if i not in prune_indices]
+
+        pruna_logger.info(f"[ShortGPT] Pruning {n_prune}/{num_layers} layers: {prune_indices}")
+
+        kept_layers = torch.nn.ModuleList([layer for i, layer in enumerate(model.model.layers) if i in keep_indices])
+        model.model.layers = kept_layers
+
+        pruna_logger.info(f"[ShortGPT] Pruned model depth: {len(model.model.layers)}")
+        pruna_logger.info(
+            f"[ShortGPT] Pruned model parameters: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M"
+        )
+
+        return model
+
+    def model_check_fn(self, model: Any) -> bool:
+        """
+        Check if the model is a torch.nn.Module.
+
+        Parameters
+        ----------
+        model : Any
+            The model to check.
+
+        Returns
+        -------
+        bool
+            True if the model is a torch.nn.Module, False otherwise.
+        """
+        return isinstance(model, torch.nn.Module)