120 changes: 111 additions & 9 deletions tensorrt_llm/_torch/auto_deploy/models/hf.py
@@ -1,12 +1,14 @@
"""Interface to initialize and load HF models."""

import json
import os
import re
import types
from abc import abstractmethod
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import safetensors.torch
import torch
import torch.nn as nn
from accelerate import init_empty_weights, load_checkpoint_in_model
@@ -434,20 +436,120 @@ def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
# Ensure it's the first one.
model._state_dict_hooks.move_to_end(key=get_handle.id, last=False)

# reuse the load checkpoint utility from accelerate
# Choose loading method based on environment variable
# Default behavior: preload checkpoint files to CPU
# Set AD_DISABLE_PRELOAD=1 to use accelerate's load_checkpoint_in_model (no CPU preload)
disable_preload = os.environ.get("AD_DISABLE_PRELOAD", "0") == "1"
Comment on lines +439 to +442

Member:
Do you want to keep this configurability or remove it? Seems to me we don't need to keep it around.

Collaborator (Author):
I was not sure, but would it be useful for turning off the preloading on host machines with small memory? (Though the PT backend doesn't allow turning this off.)

try:
with hf_load_state_dict_with_device(device):
# Set `full_state_dict=False` to skip Accelerate's FSDP weight sync logic.
# Internally, load_checkpoint_in_model → set_model_state_dict → _load_model_state_dict,
# which collects local model params, syncs weights from checkpoint, and applies them via
# model.load_state_dict.
# This sync step can interfere with load_hooks by mixing raw checkpoint weights and
# model-transformed weights, leading to unexpected key mismatches or format issues.
load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)
if disable_preload:
# Load checkpoint directly to GPU using accelerate's load_checkpoint_in_model (no CPU preload)
ad_logger.info(
"AD_DISABLE_PRELOAD=1: Using accelerate's load_checkpoint_in_model (no CPU preload)"
)
with hf_load_state_dict_with_device(device):
load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)
else:
# Preload checkpoint files to CPU
ad_logger.info("Preloading checkpoint files to CPU")
self._load_checkpoint_with_preload(model, ckpt_file, device)
finally:
load_handle.remove()
get_handle.remove()
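
Note: the AD_DISABLE_PRELOAD toggle is read once from the process environment, so it has to be set before checkpoint loading runs. A minimal sketch of opting out of CPU preloading (the env var name comes from this diff; everything else is illustrative):

import os

# Must be set before _load_checkpoint executes; with "1", weights are
# streamed via accelerate's load_checkpoint_in_model instead of being
# preloaded into CPU memory first.
os.environ["AD_DISABLE_PRELOAD"] = "1"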

def _load_checkpoint_with_preload(
self, model: nn.Module, ckpt_file: str, device: DeviceLikeType
):
ad_logger.info("Preloading full checkpoint to CPU memory...")
all_weights = self._load_full_checkpoint_to_cpu(ckpt_file)

ad_logger.info(f"Loading weights into model (device: {device})...")
model.load_state_dict(all_weights, strict=False)

# Free CPU memory
del all_weights

ad_logger.info("Checkpoint loading completed")

def _load_full_checkpoint_to_cpu(self, checkpoint: str) -> dict:
"""Load the full checkpoint to CPU memory.

Args:
checkpoint: Can be:
- a path to a file containing a whole model state dict
- a path to a `.json` file containing the index to a sharded checkpoint
- a path to a folder containing a unique `.index.json` file and the shards
- a path to a folder containing a unique pytorch_model.bin or model.safetensors
"""
checkpoint_files = None
index_filename = None

# Fast path: Direct .index.json file (most common case for sharded checkpoints)
if os.path.isfile(checkpoint):
if checkpoint.endswith(".index.json"):
index_filename = checkpoint
else:
checkpoint_files = [checkpoint]
elif os.path.isdir(checkpoint):
# Check if the whole state dict is present (priority order matches accelerate)
potential_state_bin = [f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME]
potential_state_safetensor = [
f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME
]

# Case 1: pytorch_model.bin (WEIGHTS_NAME)
if len(potential_state_bin) == 1:
checkpoint_files = [os.path.join(checkpoint, potential_state_bin[0])]
# Case 2: model.safetensors (SAFE_WEIGHTS_NAME)
elif len(potential_state_safetensor) == 1:
checkpoint_files = [os.path.join(checkpoint, potential_state_safetensor[0])]
else:
# Case 3: Otherwise check for sharded checkpoints
potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")]
if len(potential_index) == 0:
raise ValueError(
f"{checkpoint} is not a folder containing a `.index.json` file or a "
f"{WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file"
)
elif len(potential_index) == 1:
index_filename = os.path.join(checkpoint, potential_index[0])
else:
raise ValueError(
f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones."
)
else:
raise ValueError(
f"`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got "
f"{checkpoint}."
)

# Load checkpoint files from index if needed
if index_filename is not None:
checkpoint_folder = os.path.dirname(index_filename)
with open(index_filename, "r") as f:
index = json.load(f)

if "weight_map" in index:
index = index["weight_map"]
checkpoint_files = list(set(index.values()))
checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files]

# Load all weights
all_weights = {}
for checkpoint_file in checkpoint_files:
ad_logger.info(f"Loading weight file: {checkpoint_file}")
if checkpoint_file.endswith(".safetensors"):
file_weights = safetensors.torch.load_file(checkpoint_file, device="cpu")
elif checkpoint_file.endswith((".bin", ".pth")):
file_weights = torch.load(checkpoint_file, map_location="cpu", weights_only=True)
else:
raise ValueError(f"Unsupported checkpoint format: {checkpoint_file}")

all_weights.update(file_weights)

return all_weights
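
The weight_map handling above assumes the standard Hugging Face sharded-checkpoint index layout, where the index JSON maps each parameter name to the shard file that stores it. A hedged sketch of resolving shard files from such an index (the path and shard names are illustrative):

import json
import os

index_filename = "/ckpt/model.safetensors.index.json"  # hypothetical path
with open(index_filename) as f:
    index = json.load(f)

# "weight_map" looks like:
#   {"model.embed_tokens.weight": "model-00001-of-00002.safetensors", ...}
# Deduplicating its values yields the set of shard files to load.
weight_map = index["weight_map"]
folder = os.path.dirname(index_filename)
shard_files = sorted({os.path.join(folder, f) for f in weight_map.values()})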

def _load_quantization_config(self, fetched_dir: str):
"""Load the quantization config from the model directory if not done already."""
if self._quant_config_reader is not None:
5 changes: 5 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
@@ -1,6 +1,7 @@
"""High-level entrypoint to transform a model into an efficient inference model."""

import gc
import time
from typing import Optional

import torch
@@ -10,6 +11,7 @@
from ..distributed import common as dist_ad
from ..models.factory import ModelFactory
from ..shim.interface import CachedSequenceInterface
from ..utils.logger import ad_logger
from .interface import (
InferenceOptimizerConfig,
SharedConfig,
@@ -64,11 +66,14 @@ def __call__(self, cm: CachedSequenceInterface, mod: Optional[nn.Module] = None)
mod = nn.Module()

# iterate over all transforms sorted by stage in the config
start_time = time.time()
for t_name, t_config in self.config.items():
# instantiate transform
transform = TransformRegistry.get(t_name)(t_config)
# run transform
mod = transform(mod, cm, self.factory, self.shared_config)
total_time = time.time() - start_time
ad_logger.info(f"Total time for all transforms: {total_time:.2f}s")

############################################################################################
# RETURN OPTIMIZED MODEL