Skip to content

Conversation

@sky-2002
Copy link

Description

Implements the ShortGPT pruning algorithm.
Refer #418

Related Issue

Fixes #418

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

I am figuring out how to write tests for this, found the test framework a little tricky. Would appreciate some help there. I did not understand how the run_full_integration gets the dataset required to run tests, tokenizer etc.

Also, I found the SmashConfigPrefixWrapper confusing, so I tested with an alternative way:

from transformers import AutoModelForCausalLM, AutoTokenizer
from pruna.algorithms.shortgpt import ShortGPT
from datasets import load_dataset
algo = ShortGPT()

llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B", 
llama_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", torch_dtype="auto")

dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:100]")

dl = DataLoader(dataset, batch_size=2, shuffle=True)

qmodel = algo._apply(llama_model, {
    "tokenizer": llama_tokenizer,
    "device": "cpu",
    "prune_ratio": 0.25,
    "angular": True,
    "train_dataloader": dl,
})

Copy link

@cursor cursor bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment @cursor review or bugbot run to trigger another review on this PR

bis[i] += bi
counts += 1

bis /= counts
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: NaNs: The Silent Killer of Predictable Behavior

If the dataloader is empty or yields no batches, counts remains 0, causing bis /= counts to produce NaN values. These NaN values propagate through np.argsort, resulting in unpredictable pruning behavior instead of a clear error message.

Fix in Cursor Fix in Web

@github-actions
Copy link

This PR has been inactive for 10 days and is now marked as stale.

@github-actions github-actions bot added the stale label Nov 23, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[FEATURE] ShortGPT pruning algorithm

1 participant