Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
248b58f
refactor: move data files from models/ to data/ and update imports
samrat-rm Mar 19, 2026
221c0b6
refactor: move all the loss functions from common_metrics.py file in…
samrat-rm Mar 19, 2026
7516d55
refactor: common_metrics.py is split into metrics.py and validation.p…
samrat-rm Mar 19, 2026
1a64693
refactor: config/saved_models_run.json moved into parent dir. Reasoni…
samrat-rm Mar 19, 2026
7c4b20f
cleanup: loaders was moved to the data module but original file remin…
samrat-rm Mar 19, 2026
d8d08cd
refactor: moved the mode_test.py to evaluation and models_tcloud to m…
samrat-rm Mar 19, 2026
5e3c60c
refactor: plot_landsat.py moved to utils module
samrat-rm Mar 19, 2026
fd201f6
chore: add requirements.txt with project dependencies
samrat-rm Mar 19, 2026
d4d2b22
Add __init__.py to make directories importable as packages
samrat-rm Mar 25, 2026
588932e
fix import statement for calculate_metrics function
samrat-rm Mar 25, 2026
8c4e43d
refactor: add BaseModel abstract class for model interface contract
samrat-rm Mar 26, 2026
e5b2f86
refactor: add MODEL_REGISTRY and migrate SiameseUNet to BaseModel
samrat-rm Mar 26, 2026
8e2e092
refactor: migrated CDnetV2 and HRCloudNet to BaseModel abstract class
samrat-rm Mar 26, 2026
99f436f
refactor: migrated SwinCloud model to BaseModel abstract class
samrat-rm Mar 26, 2026
dfb1015
fix: img_size for SwinCloud
samrat-rm Mar 26, 2026
3af7aab
refactor: migrated BAM_CD model to BaseModel
samrat-rm Mar 26, 2026
57809a2
fix: removed duplicate import for smp
samrat-rm Mar 26, 2026
325a548
refactor: add UnetModel wrapper and migrate to BaseModel/MODEL_REGISTRY
samrat-rm Mar 26, 2026
0c97e72
refactor: add SegFormerModel wrapper and migrate to BaseModel/MODEL_R…
samrat-rm Mar 26, 2026
6f6cd12
refactor: move SiameseUNet to models/siamese_unet.py
samrat-rm Mar 26, 2026
a043505
refactor: add DeepLabV3Model and SwinUnetModel wrappers, migrate to M…
samrat-rm Mar 26, 2026
43228e6
fix: super method called without () in SegFormerModel
samrat-rm Mar 26, 2026
a25a411
refactor: replace MODEL_REGISTRY dict with decorator-based registry
samrat-rm Mar 26, 2026
e262ed9
Add: model imports to module init for centralized registration
samrat-rm Mar 26, 2026
1cc012c
Add: loss_registry for loss functions
samrat-rm Mar 26, 2026
6fab6c2
refactor: add loss registry with decorator pattern, split builders fr…
samrat-rm Mar 26, 2026
bbb6c7f
feat: add hook system to training loop for extensible loss/logging
samrat-rm Mar 28, 2026
21ded36
chore: fix comment
samrat-rm Mar 30, 2026
247d1ae
feat: introduce hook-based trainer architecture for extensible traini…
samrat-rm Mar 30, 2026
6091f34
feat: add pipeline/runner.py entry point and smoke test
samrat-rm Mar 30, 2026
0e0b654
feat: add PipelineConfig dataclass with validation and from_json()
samrat-rm Mar 30, 2026
b7fd8d8
chore: import and error handling for model testing
samrat-rm Mar 30, 2026
3e0e120
chore: remove comments
samrat-rm Mar 30, 2026
e6c74fb
fix: remap pre-refactor checkpoint keys for BaseModel wrapper compati…
samrat-rm Mar 30, 2026
25e0b29
feat: implement mean_uncertainty in validate_all() — UncertaintyHook …
samrat-rm Mar 30, 2026
4c1e25f
chore: update requirements.txt with missing packages and Python versi…
samrat-rm Mar 30, 2026
66d229c
chore: removing and updating comments and imports
samrat-rm Mar 30, 2026
d9925ab
chore: formatting imports and removing unused imports
samrat-rm Mar 30, 2026
0bcfc75
chore: adding error handling for traintest
samrat-rm Mar 30, 2026
04ae51d
fix: refactor parent_script_dir, remove duplicate sys.path and remove…
samrat-rm Mar 30, 2026
0a948ab
Fix: delete the List and Optional import and use the modern syntax.
samrat-rm Mar 30, 2026
35e099e
feat: replace UncertaintyHook placeholder with real entropy-based unc…
samrat-rm Mar 30, 2026
6b95b66
chore: adding doc string for Uncertainty hook
samrat-rm Mar 31, 2026
4e59a44
Fix: Making the BaseTrainer train_epoch() abstract because CloudTrai…
samrat-rm Mar 31, 2026
4cc01fd
chore: replace legacy typing imports with modern syntax in pipeline_c…
samrat-rm Mar 31, 2026
b390d0e
Fix: Divide by zero error handling for mean uncertainty calculation
samrat-rm Mar 31, 2026
3ff65ec
chore: update the doc string comment and remove unnecessary comments
samrat-rm Mar 31, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions configs/pipeline_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from __future__ import annotations

import json
from dataclasses import dataclass, field
from typing import Any

VALID_LOSSES = {
"CrossEntropy", "CrossEntropyWeights",
"DiceCECombined", "Dice", "Focal", "CDnetV2Loss",
}
VALID_OPTIMIZERS = {"adam", "adamw"}


@dataclass
class PipelineConfig:

model_type: str = "Unet"
features: list[str] = field(default_factory=lambda: ["tir"])
num_classes: int = 2
traintest: str = "train"

dataset: str = ""
dataset_folder: str = ""
dataset_dir: str | None = None
target_band: str = "cloud_mask"
batch_size: int = 64
cpuworkers: int = 4
thin_cloud_class: int = -1
yshift: int = 0
transform: str | None = None

loss: str = "CrossEntropy"
class_counts: list[int] | None = None

optimizer: str = "adam"
lr: float = 1e-4
weight_decay: float = 0.0

lambda_reg: float = 0.1

patience: int = 20
max_epochs: int = 200
target_metric: str = "iou_avg"
seed: int | None = None

device: str = "cpu"
results_csv: str | None = None
model_file: str | None = None
save_model: bool = False

def __post_init__(self):
"""Validate fields immediately after construction."""
if not self.features:
raise ValueError("features must be a non-empty list of band names.")
if self.num_classes < 2:
raise ValueError(f"num_classes must be >= 2, got {self.num_classes}.")
if self.loss not in VALID_LOSSES:
raise ValueError(f"Unknown loss '{self.loss}'. Valid: {sorted(VALID_LOSSES)}")
if self.optimizer not in VALID_OPTIMIZERS:
raise ValueError(f"Unknown optimizer '{self.optimizer}'. Valid: {sorted(VALID_OPTIMIZERS)}")
if self.lr <= 0:
raise ValueError(f"lr must be > 0, got {self.lr}.")
if self.patience <= 0:
raise ValueError(f"patience must be > 0, got {self.patience}.")
if self.max_epochs <= 0:
raise ValueError(f"max_epochs must be > 0, got {self.max_epochs}.")
if self.batch_size <= 0:
raise ValueError(f"batch_size must be > 0, got {self.batch_size}.")
if not 0.0 <= self.lambda_reg <= 1.0:
raise ValueError(f"lambda_reg must be in [0, 1], got {self.lambda_reg}.")
if self.traintest not in {"train", "val", "test"}:
raise ValueError(
f"traintest must be 'train', 'val', or 'test', got '{self.traintest}'"
)

@classmethod
def from_json(cls, path: str, key: str = None) -> PipelineConfig:
"""
Load a PipelineConfig from a JSON file.

Args:
path: Path to the JSON file.
key: Top-level key to read from (e.g. "viirs_unet").
If None, the file must be a flat dict of fields.
"""
with open(path) as f:
data = json.load(f)

if key is not None:
if key not in data:
raise KeyError(f"Key '{key}' not found in {path}. Available: {list(data.keys())}")
data = data[key]

known = cls.__dataclass_fields__.keys()
unknown = set(data.keys()) - set(known)
if unknown:
raise ValueError(f"Unknown fields in config: {unknown}. Check for typos.")

return cls(**data)

def to_dict(self) -> dict[str, Any]:
d = {
"model_type": self.model_type,
"features": self.features,
"num_classes": self.num_classes,
"dataset": self.dataset,
"dataset_folder": self.dataset_folder,
"target_band": self.target_band,
"batch_size": self.batch_size,
"cpuworkers": self.cpuworkers,
"thin_cloud_class": self.thin_cloud_class,
"yshift": self.yshift,
"transform": self.transform,
"loss": self.loss,
"optimizer": self.optimizer,
"lr": self.lr,
"weight_decay": self.weight_decay,
"lambda_reg": self.lambda_reg,
"patience": self.patience,
"max_epochs": self.max_epochs,
"target_metric": self.target_metric,
"device": self.device,
"results_csv": self.results_csv,
"traintest": self.traintest,
"save_model": self.save_model,
}
if self.dataset_dir is not None:
d["dataset_dir"] = self.dataset_dir
if self.class_counts is not None:
d["class_counts"] = self.class_counts
if self.seed is not None:
d["seed"] = self.seed
if self.model_file is not None:
d["model_file"] = self.model_file
return d
File renamed without changes.
Empty file added data/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
Empty file added evaluation/__init__.py
Empty file.
100 changes: 100 additions & 0 deletions evaluation/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import numpy as np

def iou_score(preds, targets, num_classes=2):
    """Mean intersection-over-union across classes.

    Classes that never appear in either tensor contribute NaN and are
    excluded from the mean via np.nanmean.
    """
    flat_preds = preds.view(-1)
    flat_targets = targets.view(-1)

    per_class = []
    for cls_id in range(num_classes):
        in_pred = flat_preds == cls_id
        in_target = flat_targets == cls_id
        union = (in_pred | in_target).sum().item()
        if union == 0:
            per_class.append(float('nan'))
            continue
        intersection = (in_pred & in_target).sum().item()
        per_class.append(intersection / union)
    return np.nanmean(per_class)

def iou_per_class(preds, labels, num_classes=3):
    """Per-class IoU as a list of floats (NaN where a class is absent).

    Tensors are detached and moved to CPU first so this is safe to call
    on outputs that still live on an accelerator / in the autograd graph.
    """
    preds = preds.detach().cpu()
    labels = labels.detach().cpu()

    results = []
    for cls_id in range(num_classes):
        in_pred = preds == cls_id
        in_label = labels == cls_id
        union = (in_pred | in_label).sum().item()
        if union == 0:
            results.append(float('nan'))
        else:
            results.append((in_pred & in_label).sum().item() / union)
    return results

def calculate_metrics(all_preds, all_targets, num_classes, total_pixels, correct_pixels, mean_uncertainty=None):
    """
    Calculate validation metrics from predictions and targets.

    Args:
        all_preds: List of all predictions (flattened)
        all_targets: List of all targets (flattened)
        num_classes: Number of classes
        total_pixels: Total number of pixels (must be > 0)
        correct_pixels: Number of correctly predicted pixels
        mean_uncertainty: Optional mean per-pixel uncertainty; only used
            for the console summary, not stored in the returned dict.

    Returns:
        metrics: Dictionary of all computed metrics (pixel_accuracy; per-class
            iou, precision, recall, f1, tp/fp/fn/tn; and iou_avg).
    """

    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)

    # Per-class confusion counts (uint64 to avoid overflow on large rasters).
    tp = np.zeros(num_classes, dtype=np.uint64)
    fp = np.zeros(num_classes, dtype=np.uint64)
    fn = np.zeros(num_classes, dtype=np.uint64)
    tn = np.zeros(num_classes, dtype=np.uint64)

    for cls in range(num_classes):
        cls_pred = (all_preds == cls)
        cls_true = (all_targets == cls)
        tp[cls] = np.logical_and(cls_pred, cls_true).sum()
        fp[cls] = np.logical_and(cls_pred, ~cls_true).sum()
        fn[cls] = np.logical_and(~cls_pred, cls_true).sum()
        tn[cls] = np.logical_and(~cls_pred, ~cls_true).sum()

    # Epsilon keeps divisions finite for classes absent from both preds and targets.
    epsilon = 1e-7
    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)
    f1 = 2 * precision * recall / (precision + recall + epsilon)
    iou = tp / (tp + fp + fn + epsilon)
    pixel_accuracy = correct_pixels / total_pixels

    metrics = {
        "pixel_accuracy": pixel_accuracy,
    }
    avg_iou = 0
    for cls in range(num_classes):
        metrics[f"iou_{cls}"] = iou[cls]
        metrics[f"precision_{cls}"] = precision[cls]
        metrics[f"recall_{cls}"] = recall[cls]
        metrics[f"f1_{cls}"] = f1[cls]
        metrics[f"tp_{cls}"] = tp[cls]
        metrics[f"fn_{cls}"] = fn[cls]
        metrics[f"tn_{cls}"] = tn[cls]
        metrics[f"fp_{cls}"] = fp[cls]
        avg_iou += iou[cls]

    avg_iou = avg_iou / num_classes
    metrics["iou_avg"] = avg_iou

    uncertainty_str = f", mean_uncertainty: {mean_uncertainty:.4f}" if mean_uncertainty is not None else ""
    print(f"\nPixel Accuracy: {pixel_accuracy:.4f}, mIoU: {avg_iou}{uncertainty_str}")
    print(f"{'Class':<6} {'IoU':>6} {'Precision':>10} {'Recall':>8} {'F1':>6}")
    for cls in range(num_classes):
        print(f"{cls:<6} {iou[cls]:>6.3f} {precision[cls]:>10.3f} {recall[cls]:>8.3f} {f1[cls]:>6.3f}")

    return metrics
55 changes: 25 additions & 30 deletions models/model_test.py → evaluation/model_test.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
import torch
import numpy as np
import argparse
import gc
import json
import os
import pandas as pd
import sys

import numpy as np
import pandas as pd
import torch
from common_metrics import validate_all, record_validation_metrics_to_csv
import sys
from tqdm import tqdm

parent_script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(parent_script_dir)
import gc
from models_tcloud import init_model_and_loaders
sys.path.insert(0, parent_script_dir)

from evaluation.validate import validate_all, record_validation_metrics_to_csv
from libraries.utils import save_geotiff, get_preds_multi_encoders
from libraries.wandb_retrieve import get_filtered_wandb_runs, wandinit
import json
import argparse
from tqdm import tqdm

from model_builder.models_tcloud import init_model_and_loaders

def save_inference_images(ibatch, save_inference_dir, results, inputs, outputs, preds, targets, batch_size, test_df, save_logits, num_classes):
if isinstance(inputs, list):
Expand All @@ -30,18 +29,15 @@ def save_inference_images(ibatch, save_inference_dir, results, inputs, outputs,

mask_path = logits_path = None

# Save class mask
class_mask = preds[j].numpy().astype(np.uint8)[np.newaxis, ...]
mask_path = os.path.join(save_inference_dir, f"{image_id}_mask.tif")
save_geotiff(class_mask, mask_path, ref_tif, dtype="uint8", count=1)

# Save logits
if save_logits:
logits = outputs[j].cpu().numpy().astype(np.float32)
logits_path = os.path.join(save_inference_dir, f"{image_id}_logits.tif")
save_geotiff(logits, logits_path, ref_tif, dtype="float32", count=num_classes)

# Compute per-class IoU
pred_np = preds[j].numpy().flatten()
target_np = targets[j].numpy().flatten()
ious = []
Expand All @@ -67,8 +63,6 @@ def evaluate_on_test_set(

save_inference=params_dict.get("save_inference", False)
save_logits=params_dict.get("save_logits", False)

# Load test set and data

df = pd.read_csv(os.path.join(params_dict["dataset_folder"],params_dict["dataset"]))
if params_dict["traintest"]!="test":
Expand All @@ -82,7 +76,14 @@ def evaluate_on_test_set(
num_classes=params_dict["num_classes"]

model,_,test_loader=init_model_and_loaders(params_dict)
loaded_state_dict = torch.load(model_path, weights_only=True)
loaded_state_dict = torch.load(model_path, weights_only=True, map_location=device)

# Models saved before the BaseModel refactor have keys without the "_model." prefix.
model_keys = set(model.state_dict().keys())
ckpt_keys = set(loaded_state_dict.keys())
if not ckpt_keys.issubset(model_keys) and all(f"_model.{k}" in model_keys for k in ckpt_keys):
loaded_state_dict = {f"_model.{k}": v for k, v in loaded_state_dict.items()}

model.load_state_dict(loaded_state_dict)
model.eval()

Expand All @@ -109,10 +110,7 @@ def evaluate_on_test_set(
for i, (inputs, targets) in enumerate(tqdm(test_loader, desc="Inference Progress")):

outputs = get_preds_multi_encoders(model, inputs, device)
'''
inputs = inputs.to(device)
outputs = model(inputs)
'''

if isinstance(outputs, tuple):
preds = torch.argmax(outputs[0], dim=1).cpu()
else:
Expand All @@ -127,7 +125,6 @@ def evaluate_on_test_set(
batch_size, test_df, save_logits, num_classes)

if save_inference:
# Save DataFrame to CSV
results_df = pd.DataFrame(results)
csv_path = os.path.join(save_inference_dir, "inference_results.csv")
results_df.to_csv(csv_path, index=False)
Expand All @@ -139,8 +136,6 @@ def evaluate_on_test_set(
if wandbrun:
wandbrun.log(metrics)

#record_validation_metrics_to_csv(os.path.expanduser("~/shared_storage/tcloudDS/benchmarks/test_results_v2.csv"), metrics, params_dict)

del model
del test_loader
torch.cuda.empty_cache()
Expand Down Expand Up @@ -173,8 +168,11 @@ def main():
args = parser.parse_args()
test_set = args.test_set
list_only = args.list_only

if test_set is None:
parser.error("--test_set or -t is required. Choose from: viirs, landsat, landsatMA")

configfile=os.path.join(script_dir,"configs/saved_models_run.json")
configfile = os.path.join(parent_script_dir, "configs/saved_models_run.json")
with open(configfile, 'r') as file:
configdict = json.load(file)
if not test_set in configdict:
Expand Down Expand Up @@ -210,19 +208,16 @@ def main():
paramsdict={}
#model_path=find_file_recursive(os.path.basename(row["model_file"]), os.path.dirname(row["model_file"]))

#first pass config params from wandb
for k in [p for p in dfmodels if p.startswith("config_")]:
paramsdict[k[7:]]=wandb_row[k] # without config prefix

#second pass parameters from config files to override
for k in configparams:
paramsdict[k]=configparams[k]
if "config_dataset" in wandb_row and "config_trained" not in wandb_row:
paramsdict["trained"]=wandb_row["config_dataset"]

if not "device" in paramsdict:
paramsdict["device"]="cuda:0"
# force loading 'test' dataset in case not otherwise configured
if not "traintest" in configparams:
paramsdict["traintest"]="test"

Expand Down
Loading