Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/dinotxt_stage.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This version of the YAML uses processed patches
project: /home/suraj/Repositories/DINOv2_3D
project: "."
run_name: "DINOv2_pretrain_primus"
img_size: [160, 160, 160]
hidden_size: 864
Expand Down
4 changes: 2 additions & 2 deletions configs/predict.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
project: /home/suraj/Repositories/DINOv2_3D
project: "."
num_workers: 8

trainer:
Expand All @@ -16,7 +16,7 @@ trainer:
name: "get_predictions"
callbacks:
- _target_: project.callbacks.prediction_saver.SavePredictions
path: /home/suraj/Repositories/DINOv2_3D/predictions.csv
path: ./predictions.csv

lightning_module:
_target_: project.models.DINOv2_3D_LightningModule
Expand Down
2 changes: 1 addition & 1 deletion configs/train.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This version of the YAML uses processed patches, optimized for low memory usage
project: /home/suraj/Repositories/DINOv2_3D
project: "."
run_name: "DINOv2_pretrain_primus"
img_size: [160, 160, 160] # reduced spatial dims for lower mem
hidden_size: 864
Expand Down
15 changes: 14 additions & 1 deletion scripts/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import os
from pathlib import Path
from torch.utils.data import DataLoader
from monai.bundle import ConfigParser

Expand Down Expand Up @@ -34,7 +35,19 @@ def run(mode, config_file: str, **config_overrides):
parser.update(config_overrides)

project_path = parser.get("project")
import_module_from_path("project", project_path)

# Normalize and resolve project path
# If project_path is relative, resolve it relative to the repository root
project_path = Path(project_path).expanduser()

if not project_path.is_absolute():
# Get the repository root (parent of the scripts directory)
repo_root = Path(__file__).resolve().parent.parent
project_path = (repo_root / project_path).resolve()
else:
project_path = project_path.resolve()

import_module_from_path("project", str(project_path))

trainer = parser.get_parsed_content("trainer")
lightning_module = parser.get_parsed_content("lightning_module")
Expand Down
22 changes: 18 additions & 4 deletions scripts/utility/export_ckpt_to_nnunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
import os
import sys
from pathlib import Path
from typing import Dict, Any

from utils.imports import import_module_from_path
Expand All @@ -30,12 +31,22 @@ def modify_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:

def process_checkpoint(input_path: str, output_path: str,
remove_cls_token: bool = True,
arch_class_name: str = 'PrimusM') -> None:
arch_class_name: str = 'PrimusM',
project_path: str = '.') -> None:
"""Process checkpoint file and save in nnUNet format"""

# Import project module to handle checkpoint dependencies
project_path = "/home/suraj/Repositories/DINOv2_3D"
import_module_from_path("project", project_path)
# Normalize and resolve project path
project_path = Path(project_path).expanduser()

if not project_path.is_absolute():
# Resolve relative to the repository root
script_dir = Path(__file__).resolve().parent.parent.parent
project_path = (script_dir / project_path).resolve()
else:
project_path = project_path.resolve()

import_module_from_path("project", str(project_path))

print(f"Loading checkpoint from: {input_path}")
try:
Expand Down Expand Up @@ -86,6 +97,8 @@ def main():
help="Architecture class name (default: PrimusM)")
parser.add_argument("--keep-cls-token", action="store_true",
help="Keep CLS token in positional embeddings")
parser.add_argument("--project-path", default=".",
help="Path to the project root (default: current directory)")

args = parser.parse_args()

Expand All @@ -100,7 +113,8 @@ def main():
input_path=args.input_path,
output_path=args.output_path,
remove_cls_token=not args.keep_cls_token,
arch_class_name=args.arch_class_name
arch_class_name=args.arch_class_name,
project_path=args.project_path
)
print("Checkpoint conversion completed successfully!")
except Exception as e:
Expand Down
7 changes: 6 additions & 1 deletion utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ def import_module_from_path(module_name: str, module_path: str) -> None:

module_path = Path(module_path).resolve() / "__init__.py"
if not module_path.is_file():
raise FileNotFoundError(f"No `__init__.py` in `{module_path}`.")
raise FileNotFoundError(
f"No `__init__.py` found at `{module_path}`.\n"
f"If your config contains an absolute path like '/home/suraj/...', "
f"please update it to use a relative path (e.g., 'project: \".\"') "
f"or set it to your local repository root."
)
spec = importlib.util.spec_from_file_location(module_name, str(module_path))
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
Expand Down