Skip to content
32 changes: 23 additions & 9 deletions trackers/core/reid/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,25 @@ def _initialize_reid_model_from_timm(
return cls(model, device, transforms, model_metadata)


def _initialize_reid_model_from_checkpoint(cls, checkpoint_path: str, config_path: str):
    """Build a `ReIDModel` from a local safetensors checkpoint plus a JSON config.

    Args:
        cls: The `ReIDModel` class (this helper acts as an out-of-class
            alternate constructor).
        checkpoint_path (str): Path to the safetensors weights file.
        config_path (str): Path to the JSON config file describing the model.

    Returns:
        An instance of `cls` with the checkpoint weights loaded.

    Raises:
        ValueError: If the config file does not declare an 'architecture' key.
    """
    state_dict, config = load_safetensors_checkpoint(checkpoint_path, config_path)
    model_name = config.get("architecture")
    if model_name is None:
        raise ValueError(
            f"The config at {config_path} is missing the 'architecture' key."
        )
    # Build the bare timm backbone without downloading pretrained weights;
    # the checkpoint supplies the weights below.
    # NOTE(review): device is hard-coded to "auto" here rather than forwarded
    # from the caller — confirm this is intentional.
    init_kwargs = {"pretrained": False}
    reid_model_instance = _initialize_reid_model_from_timm(
        cls, model_name_or_checkpoint_path=model_name, device="auto", **init_kwargs
    )
    # Fetch once instead of calling .get() twice; a falsy/absent value means
    # the checkpoint was trained without a projection head.
    projection_dimension = config.get("projection_dimension")
    if projection_dimension:
        reid_model_instance._add_projection_layer(
            projection_dimension=projection_dimension
        )
    # Move every tensor to the model's device before loading. A fresh dict
    # avoids mutating state_dict while iterating over it.
    state_dict = {
        key: tensor.to(reid_model_instance.device)
        for key, tensor in state_dict.items()
    }
    # strict=False tolerates key mismatches (e.g. the optional projection
    # layer); missing/unexpected keys are silently ignored.
    reid_model_instance.backbone_model.load_state_dict(state_dict, strict=False)
    return reid_model_instance


Expand Down Expand Up @@ -122,6 +129,7 @@ def __init__(
def from_timm(
cls,
model_name_or_checkpoint_path: str,
config_path: Optional[str] = None,
device: Optional[str] = "auto",
get_pooled_features: bool = True,
**kwargs,
Expand All @@ -134,6 +142,8 @@ def from_timm(
model_name_or_checkpoint_path (str): Name of the timm model to use or
path to a safetensors checkpoint. If the exact model name is not
found, the closest match from `timm.list_models` will be used.
config_path (str): Path to the config file for the local
safetensors checkpoint.
device (str): Device to run the model on.
get_pooled_features (bool): Whether to get the pooled features from the
model or not.
Expand All @@ -143,9 +153,13 @@ def from_timm(
Returns:
ReIDModel: A new instance of `ReIDModel`.
"""
if os.path.exists(model_name_or_checkpoint_path):
if (
config_path is not None
and os.path.exists(model_name_or_checkpoint_path)
and os.path.exists(config_path)
):
return _initialize_reid_model_from_checkpoint(
cls, model_name_or_checkpoint_path
cls, model_name_or_checkpoint_path, config_path
)
else:
return _initialize_reid_model_from_timm(
Expand Down
21 changes: 12 additions & 9 deletions trackers/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,28 +59,31 @@ def parse_device_spec(device_spec: Union[str, torch.device]) -> torch.device:


def load_safetensors_checkpoint(
checkpoint_path: str, device: str = "cpu"
checkpoint_path: str,
config_path: str,
device: str = "cpu",
) -> Tuple[dict[str, torch.Tensor], dict[str, Any]]:
"""
Load a safetensors checkpoint into a dictionary of tensors and a dictionary
of metadata.
Load a safetensors checkpoint into a dictionary of tensors and a
separate JSON config file.

Args:
checkpoint_path (str): The path to the safetensors checkpoint.
config_path (str): The path to the JSON config file.
device (str): The device to load the checkpoint on.

Returns:
Tuple[dict[str, torch.Tensor], dict[str, Any]]: A tuple containing the
state_dict and the config.
state_dict (dict): model weights
config (dict): model config
"""
state_dict = {}
with safe_open(checkpoint_path, framework="pt", device=device) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key)
metadata = f.metadata()
config = json.loads(metadata["config"]) if "config" in metadata else {}
model_metadata = config.pop("model_metadata") if "model_metadata" in config else {}
if "kwargs" in model_metadata:
with open(config_path, "r") as f:
config = json.load(f)
model_metadata = config.pop("model_metadata", {})
if isinstance(model_metadata, dict) and "kwargs" in model_metadata:
kwargs = model_metadata.pop("kwargs")
model_metadata = {**kwargs, **model_metadata}
config["model_metadata"] = model_metadata
Expand Down