diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 2900f6bd3..96b400d63 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -1,3 +1,4 @@ +import inspect from typing import Any, Dict, Optional, Tuple import torch @@ -247,7 +248,16 @@ def calibrate_kv_cache_input_hook( kv_cache to singleton QuantizedKVParameterCache. """ kv_cache = getattr(module, "kv_cache") - kwargs["past_key_values"] = kv_cache + if not hasattr(module, "_past_kv_name"): + # Determine which past KV parameter name to use once and cache it + # TODO: Find a better place to cache this + module._past_kv_name = ( + "past_key_value" # transformers#39956 + if "past_key_value" in inspect.signature(module.forward).parameters + else "past_key_values" + ) + + kwargs[module._past_kv_name] = kv_cache kwargs["use_cache"] = False return args, kwargs