diff --git a/8to2.py b/8to2.py
new file mode 100644
index 0000000000..4540a9524f
--- /dev/null
+++ b/8to2.py
@@ -0,0 +1,170 @@
+import os
+import numpy as np
+import rasterio
+from shutil import copy2
+from tqdm import tqdm
+
+
+def keep_last_two_bands(folder_path, backup=True, backup_suffix='_8band_backup'):
+    """
+    Keep only the last two bands of every TIF file, overwriting the original file.
+
+    Args:
+        folder_path: root directory to process
+        backup: whether to back up the original files
+        backup_suffix: suffix appended to backup file names
+    """
+ total_files = 0
+ files_processed = 0
+ files_skipped = 0
+ tif_extensions = ['.tif', '.tiff', '.TIF', '.TIFF']
+
+ # 收集所有tif文件
+ print(f"扫描路径: {folder_path}")
+ tif_files = []
+ for root, dirs, files in os.walk(folder_path):
+ for file in files:
+ if any(file.endswith(ext) for ext in tif_extensions):
+ tif_files.append(os.path.join(root, file))
+
+ total_files = len(tif_files)
+ print(f"找到 {total_files} 个TIF文件\n")
+
+ # 第一步:检查文件波段数
+ print("步骤1: 检查文件波段数...")
+ print("=" * 80)
+
+ band_info = {}
+ files_to_process = []
+
+ try:
+ iterator = tqdm(tif_files, desc="扫描文件")
+ except:
+ iterator = tif_files
+
+ for file_path in iterator:
+ try:
+ with rasterio.open(file_path) as src:
+ num_bands = src.count
+ rel_path = os.path.relpath(file_path, folder_path)
+
+ if num_bands not in band_info:
+ band_info[num_bands] = []
+ band_info[num_bands].append(rel_path)
+
+                if num_bands >= 2:  # only process files with >=2 bands (NOTE(review): files with exactly 2 bands are rewritten unchanged)
+ files_to_process.append((file_path, num_bands))
+ else:
+ files_skipped += 1
+
+ except Exception as e:
+ print(f"\n读取错误: {os.path.basename(file_path)}")
+ print(f" 错误: {e}")
+
+ # 显示波段数统计
+ print("\n波段数统计:")
+ for num_bands in sorted(band_info.keys()):
+ print(f" {num_bands}波段: {len(band_info[num_bands])} 个文件")
+
+ print(f"\n将处理 {len(files_to_process)} 个文件(波段数>=2)")
+ print(f"跳过 {files_skipped} 个文件(波段数<2)")
+
+ if len(files_to_process) == 0:
+ print("没有需要处理的文件!")
+ return
+
+ # 第二步:确认并处理
+ print("\n" + "=" * 80)
+ print("步骤2: 处理文件")
+ print("操作: 只保留最后两个波段,覆盖原文件")
+
+ if backup:
+ print(f"将创建备份文件(后缀: {backup_suffix})")
+ else:
+ print("警告: 不会创建备份!")
+
+ user_input = input("\n是否继续?(y/n): ")
+ if user_input.lower() != 'y':
+ print("操作已取消")
+ return
+
+ print("\n开始处理...")
+
+ try:
+ iterator = tqdm(files_to_process, desc="处理文件")
+ except:
+ iterator = files_to_process
+
+ for file_path, original_bands in iterator:
+ try:
+ # 备份原文件
+ if backup:
+ backup_path = file_path + backup_suffix
+ copy2(file_path, backup_path)
+
+ # 读取文件
+ with rasterio.open(file_path) as src:
+ # 保存元数据并修改波段数
+ meta = src.meta.copy()
+ meta['count'] = 2 # 修改为2波段
+
+ # 读取最后两个波段
+ band_n_1 = src.read(original_bands - 1) # 倒数第二个波段
+ band_n = src.read(original_bands) # 最后一个波段
+
+ # 创建临时文件
+ temp_path = file_path + '.tmp'
+
+ # 写入新文件(只有2个波段)
+ with rasterio.open(temp_path, 'w', **meta) as dst:
+ dst.write(band_n_1, 1) # 写入第1个波段
+ dst.write(band_n, 2) # 写入第2个波段
+
+ # 替换原文件
+ os.replace(temp_path, file_path)
+ files_processed += 1
+
+ except Exception as e:
+ print(f"\n处理错误: {os.path.basename(file_path)}")
+ print(f" 原波段数: {original_bands}")
+ print(f" 错误: {e}")
+ # 清理临时文件
+ temp_path = file_path + '.tmp'
+ if os.path.exists(temp_path):
+ os.remove(temp_path)
+
+ # 输出结果
+ print("\n" + "=" * 80)
+ print("处理完成!")
+ print(f"总文件数: {total_files}")
+ print(f"成功处理的文件: {files_processed}")
+ print(f"跳过的文件: {files_skipped}")
+
+ if backup:
+ print(f"\n原始文件已备份(后缀: {backup_suffix})")
+ print("如果确认无误,可以使用以下命令删除备份文件:")
+ print(f" find {folder_path} -name '*{backup_suffix}' -delete")
+
+ # 验证处理结果
+ print("\n步骤3: 验证处理结果...")
+ print("检查前3个文件的波段数...")
+
+ for i, (file_path, _) in enumerate(files_to_process[:3]):
+ try:
+ with rasterio.open(file_path) as src:
+ rel_path = os.path.relpath(file_path, folder_path)
+ print(f" ✓ {rel_path}: {src.count} 波段")
+ except Exception as e:
+ print(f" ✗ 验证失败: {e}")
+
+ if i >= 2: # 只检查前3个
+ break
+
+
+if __name__ == "__main__":
+ folder_path = "/mnt/d/Project/Code/Floodnet/data/mixed_dataset/SAR/"
+
+ # 执行处理
+ # backup=True 会备份原文件(推荐)
+ # backup=False 直接覆盖(不推荐)
+ keep_last_two_bands(folder_path, backup=False, backup_suffix='_8band_backup')
\ No newline at end of file
diff --git a/command b/command
new file mode 100644
index 0000000000..f774dc5546
--- /dev/null
+++ b/command
@@ -0,0 +1,148 @@
+git clone https://VIncentmuyi:<REDACTED_TOKEN>@github.com/VIncentmuyi/Floodnet.git  # SECURITY: a real PAT was committed here — revoke it on GitHub immediately
+
+# Fix: Python does not search the current folder for packages, causing:
+ModuleNotFoundError: No module named 'mmseg'
+export PYTHONPATH=.:$PYTHONPATH
+
+ps -ef | grep python
+pkill -9 python
+git fetch origin
+git reset --hard origin/main
+git clean -fd # 删除未跟踪的文件
+
+python tools/train.py ./configs/deeplabv3plus/Deeplabv3+UAVflood.py --work-dir work_dirs/SAR/Deeplabv3+
+python tools/test_full_metrics.py ./configs/deeplabv3plus/Deeplabv3+UAVflood.py work_dirs/SAR/Deeplabv3+/best_mIoU_epoch_100.pth --work-dir ./Result/SAR/Deeplabv3+ --show-dir ./Result/SAR/Deeplabv3+/vis --cfg-options visualizer.alpha=1.0
+python tools/analysis_tools/benchmark.py \
+ ./configs/deeplabv3plus/Deeplabv3+UAVflood.py \
+ work_dirs/UAVflood/Deeplabv3+/best_val_mIoU_iter_20000.pth \
+ --repeat-times 3
+python tools/analysis_tools/get_flops.py \
+ ./configs/deeplabv3plus/Deeplabv3+UAVflood.py \
+ --shape 8 256 256
+
+python tools/train.py ./configs/segformer/segformer_mit-b0_8xb1-160k_UAVflood-256x256.py --work-dir work_dirs/SAR/segformer
+python tools/test_full_metrics.py ./configs/segformer/segformer_mit-b0_8xb1-160k_UAVflood-256x256.py work_dirs/SAR/segformer/best_mIoU_epoch_100.pth --work-dir ./Result/SAR/segformer/ --show-dir ./Result/SAR/segformer/vis --cfg-options visualizer.alpha=1.0
+python tools/analysis_tools/benchmark.py \
+ ./configs/segformer/segformer_mit-b0_8xb1-160k_UAVflood-256x256.py\
+ work_dirs/UAVflood/segformer/best_val_mIoU_iter_40000.pth \
+ --repeat-times 3
+python tools/analysis_tools/get_flops.py \
+ ./configs/segformer/segformer_mit-b0_8xb1-160k_UAVflood-256x256.py \
+ --shape 5 256 256
+
+python tools/train.py ./configs/unet/Unet-Uavflood.py --work-dir work_dirs/SARflood/unet
+python tools/test_full_metrics.py ./configs/unet/Unet-Uavflood.py work_dirs/SAR/unet/best_mIoU_epoch_90.pth --work-dir ./Result/SAR/unet/ --show-dir ./Result/SAR/unet/vis --cfg-options visualizer.alpha=1.0
+python tools/analysis_tools/benchmark.py \
+ ./configs/unet/Unet-Uavflood.py \
+ work_dirs/UAVflood/unet/best_val_mIoU_iter_40000.pth \
+ --repeat-times 3
+python tools/analysis_tools/get_flops.py \
+ ./configs/unet/Unet-Uavflood.py\
+ --shape 8 256 256
+
+python tools/train.py ./configs/mae/mae-base-Uavflood.py --work-dir work_dirs/UAVflood/mae
+python tools/test_full_metrics.py ./configs/mae/mae-base-Uavflood.py work_dirs/UAVflood/mae/best_val_mIoU_iter_20000.pth --work-dir ./Result/UAV/mae/ --show-dir ./Result/UAV/mae/vis --cfg-options visualizer.alpha=1.0
+python tools/analysis_tools/benchmark.py \
+ ./configs/mae/mae-base-Uavflood.py\
+ work_dirs/UAVflood/mae/best_val_mIoU_iter_28000.pth \
+ --repeat-times 3
+python tools/analysis_tools/get_flops.py \
+ ./configs/mae/mae-base-Uavflood.py\
+ --shape 5 256 256
+
+python tools/train.py ./configs/vit/vit-Uavflood.py --work-dir work_dirs/SAR/vit
+python tools/test_full_metrics.py ./configs/vit/vit-Uavflood.py work_dirs/SAR/vit/best_mIoU_epoch_100.pth --work-dir ./Result/SAR/vit/ --show-dir ./Result/SAR/vit/vis --cfg-options visualizer.alpha=1.0
+python tools/analysis_tools/benchmark.py \
+ ./configs/vit/vit-Uavflood.py\
+ work_dirs/UAVflood/vit/best_val_mIoU_iter_36000.pth \
+ --repeat-times 3
+python tools/analysis_tools/get_flops.py \
+ ./configs/vit/vit-Uavflood.py\
+ --shape 3 256 256
+
+python tools/train.py ./configs/beit/beit-Uavflood.py --work-dir work_dirs/SARflood/beit
+python tools/test_full_metrics.py ./configs/beit/beit-Uavflood.py work_dirs/UAVflood/beit/best_val_mIoU_iter_16000.pth --work-dir ./Result/UAV/beit/ --show-dir ./Result/UAV/beit/vis --cfg-options visualizer.alpha=1.0
+python tools/analysis_tools/benchmark.py \
+ ./configs/beit/beit-Uavflood.py\
+ work_dirs/UAVflood/beit/best_val_mIoU_iter_40000.pth\
+ --repeat-times 3
+python tools/analysis_tools/get_flops.py \
+ ./configs/beit/beit-Uavflood.py\
+ --shape 3 256 256
+
+python tools/train.py ./configs/convnext/convnext-base-uavflood.py --work-dir work_dirs/SAR/convnext
+python tools/test_full_metrics.py ./configs/convnext/convnext-base-uavflood.py work_dirs/SAR/convnext/best_mIoU_epoch_100.pth --work-dir ./Result/SAR/convnext/ --show-dir ./Result/SAR/convnext/vis --cfg-options visualizer.alpha=1.0
+python tools/analysis_tools/benchmark.py \
+ ./configs/convnext/convnext-base-uavflood.py\
+ work_dirs/GFflood/convnext/best_val_mIoU_iter_20000.pth \
+ --repeat-times 3
+python tools/analysis_tools/get_flops.py \
+ ./configs/convnext/convnext-base-uavflood.py\
+ --shape 5 256 256
+
+python tools/train.py ./configs/swin/Swin-uavflood-256x256.py --work-dir work_dirs/SAR/Swin --cfg-options seed=42
+python tools/test_full_metrics.py ./configs/swin/Swin-uavflood-256x256.py work_dirs/GFflood/Swin/best_val_mIoU_iter_16000.pth --work-dir ./Result/GF/swin/ --show-dir ./Result/GF/swin/vis --cfg-options visualizer.alpha=1.0
+python tools/analysis_tools/benchmark.py \
+ ./configs/swin/Swin-uavflood-256x256.py\
+ work_dirs/GFflood/Swin/best_val_mIoU_iter_16000.pth\
+ --repeat-times 3
+python tools/analysis_tools/get_flops.py \
+ ./configs/swin/Swin-uavflood-256x256.py\
+ --shape 5 256 256
+
+python tools/train.py ./configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py --work-dir work_dirs/floodnet/SwinmoeB/655 --cfg-options seed=42
+python tools/test_full_metrics.py ./configs/floodnet/multimodal_floodnet_sar_boost_swin_moe_config.py work_dirs/floodnet/SwinmoeB/best_mIoU_epoch_100.pth --work-dir ./Result/Floodnet/SAR/ --show-dir ./Result/Floodnet/SAR/vis --cfg-options visualizer.alpha=1.0
+
+python tools/test_full_metrics.py \
+ configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py \
+ work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth \
+ --cfg-options test_dataloader.dataset.filter_modality=sar \
+ --work-dir ./Result/Floodnet/SAR/ \
+ --show-dir ./Result/Floodnet/SAR/vis --cfg-options visualizer.alpha=1.0
+
+
+python tools/train.py \
+ configs/floodnet/multimodal_floodnet_sar_only_swinbase_moe_config.py \
+ --work-dir work_dirs/floodnet/SwinmoeB_sar_only \
+ --cfg-options seed=42
+
+python tools/test_full_metrics.py \
+ configs/floodnet/multimodal_floodnet_sar_only_swinbase_moe_config.py \
+ work_dirs/floodnet/SwinmoeB_sar_only/best_mIoU_epoch_100.pth \
+
+
+ python tools/train.py configs/floodnet/continue_train_150ep.py --work-dir work_dirs/floodnet/SwinmoeB/655 --resume --cfg-options load_from="work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth"
+
+ python tools/analysis_tools/visualize_expert_routing.py \
+ configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py \
+ work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth \
+ --output-dir work_dirs/figures/expert_routing \
+ --num-samples 50
+
+python tools/train.py configs/floodnet/finetune_single_modal.py \
+ --work-dir work_dirs/generalization/LY-train-station/ \
+ --cfg-options \
+ train_dataloader.dataset.data_root="data/LY-train-station/" \
+ val_dataloader.dataset.data_root="data/LY-train-station/" \
+ test_dataloader.dataset.data_root="data/LY-train-station/"
+
+python tools/test.py \
+ configs/floodnet/finetune_single_modal.py \
+ work_dirs/generalization/LY-train-station/best_mIoU_epoch_50.pth \
+ --work-dir work_dirs/generalization/LY-train-station/test_results \
+ --cfg-options \
+ test_dataloader.dataset.data_root="data/LY-train-station/" \
+ "test_evaluator.iou_metrics=['mIoU','mDice','mFscore']" \
+ --show-dir work_dirs/generalization/LY-train-station/test_results/vis \
+ --out work_dirs/generalization/LY-train-station/test_results/predictions
+
+python tools/predict_large_tif.py \
+ configs/floodnet/finetune_single_modal.py \
+ work_dirs/generalization/LY-train-station/best_mIoU_epoch_50.pth \
+ --input data/luoyuan/result.tif \
+ --output data/luoyuan/prediction.tif \
+ --tile-size 512 \
+ --overlap 64 \
+ --modal rgb \
+ --bands 0 1 2 \
+ --batch-size 16
\ No newline at end of file
diff --git a/configs/_base_/datasets/GFflood.py b/configs/_base_/datasets/GFflood.py
new file mode 100644
index 0000000000..7478691b16
--- /dev/null
+++ b/configs/_base_/datasets/GFflood.py
@@ -0,0 +1,83 @@
+# GF 5-channel dataset settings
+dataset_type = 'UAVfloodDataset'
+data_root = '../Floodnet/data/mixed_dataset/GF/'
+crop_size = (256, 256)
+
+# GF 5-channel normalization parameters
+train_pipeline = [
+ dict(type='LoadMultiBandTiffFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(
+ type='RandomResize',
+ scale=(2048, 512),
+ ratio_range=(0.5, 2.0),
+ keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ # PhotoMetricDistortion removed - not suitable for multispectral imagery
+ dict(type='PackSegInputs')
+]
+
+test_pipeline = [
+ dict(type='LoadMultiBandTiffFromFile'),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackSegInputs')
+]
+
+img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+tta_pipeline = [
+ dict(type='LoadMultiBandTiffFromFile'),
+ dict(
+ type='TestTimeAug',
+ transforms=[
+ [
+ dict(type='Resize', scale_factor=r, keep_ratio=True)
+ for r in img_ratios
+ ],
+ [
+ dict(type='RandomFlip', prob=0., direction='horizontal'),
+ dict(type='RandomFlip', prob=1., direction='horizontal')
+ ],
+ [dict(type='LoadAnnotations')],
+ [dict(type='PackSegInputs')]
+ ])
+]
+
+train_dataloader = dict(
+ batch_size=8,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+    drop_last=True,  # drop the last incomplete batch to avoid BatchNorm errors
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline))
+
+val_dataloader = dict(
+ batch_size=8,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline))
+
+test_dataloader = dict(
+ batch_size=8,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline))
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
\ No newline at end of file
diff --git a/configs/_base_/datasets/SARflood.py b/configs/_base_/datasets/SARflood.py
new file mode 100644
index 0000000000..c4e4ac9742
--- /dev/null
+++ b/configs/_base_/datasets/SARflood.py
@@ -0,0 +1,83 @@
+# SAR 8-channel dataset settings
+dataset_type = 'UAVfloodDataset'
+data_root = '../Floodnet/data/mixed_dataset/SAR/'
+crop_size = (256, 256)
+
+# SAR 8-channel normalization parameters
+train_pipeline = [
+ dict(type='LoadMultiBandTiffFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(
+ type='RandomResize',
+ scale=(2048, 512),
+ ratio_range=(0.5, 2.0),
+ keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ # PhotoMetricDistortion removed - not suitable for SAR imagery
+ dict(type='PackSegInputs')
+]
+
+test_pipeline = [
+ dict(type='LoadMultiBandTiffFromFile'),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackSegInputs')
+]
+
+img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+tta_pipeline = [
+ dict(type='LoadMultiBandTiffFromFile'),
+ dict(
+ type='TestTimeAug',
+ transforms=[
+ [
+ dict(type='Resize', scale_factor=r, keep_ratio=True)
+ for r in img_ratios
+ ],
+ [
+ dict(type='RandomFlip', prob=0., direction='horizontal'),
+ dict(type='RandomFlip', prob=1., direction='horizontal')
+ ],
+ [dict(type='LoadAnnotations')],
+ [dict(type='PackSegInputs')]
+ ])
+]
+
+train_dataloader = dict(
+ batch_size=8,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+    drop_last=True,  # drop the last incomplete batch to avoid BatchNorm errors
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline))
+
+val_dataloader = dict(
+ batch_size=8,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline))
+
+test_dataloader = dict(
+ batch_size=8,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline))
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
\ No newline at end of file
diff --git a/configs/_base_/datasets/UAVflood.py b/configs/_base_/datasets/UAVflood.py
new file mode 100644
index 0000000000..4f8b700fb6
--- /dev/null
+++ b/configs/_base_/datasets/UAVflood.py
@@ -0,0 +1,77 @@
+# dataset settings
+dataset_type = 'UAVfloodDataset'
+data_root = '../Floodnet/data/mixed_dataset/SAR/'
+crop_size = (256, 256)
+train_pipeline = [
+ dict(type='LoadMultiBandTiffFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(
+ type='RandomResize',
+ scale=(2048, 512),
+ ratio_range=(0.5, 2.0),
+ keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ # PhotoMetricDistortion removed - not compatible with multi-band (8-channel) images
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadMultiBandTiffFromFile'),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ # add loading annotation after ``Resize`` because ground truth
+ # does not need to do resize data transform
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackSegInputs')
+]
+img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+tta_pipeline = [
+ dict(type='LoadMultiBandTiffFromFile'),
+ dict(
+ type='TestTimeAug',
+ transforms=[
+ [
+ dict(type='Resize', scale_factor=r, keep_ratio=True)
+ for r in img_ratios
+ ],
+ [
+ dict(type='RandomFlip', prob=0., direction='horizontal'),
+ dict(type='RandomFlip', prob=1., direction='horizontal')
+ ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
+ ])
+]
+train_dataloader = dict(
+ batch_size=8,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+    drop_last=True,  # drop the last incomplete batch to avoid BatchNorm errors
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=8,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline))
+test_dataloader = dict(
+ batch_size=8,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline))
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
diff --git a/configs/_base_/models/deeplabv3_unet_s5-d16.py b/configs/_base_/models/deeplabv3_unet_s5-d16.py
index 92df52c35d..f373110aa6 100644
--- a/configs/_base_/models/deeplabv3_unet_s5-d16.py
+++ b/configs/_base_/models/deeplabv3_unet_s5-d16.py
@@ -1,19 +1,24 @@
-# model settings
+# UNet adapted for 8-channel (SAR) imagery — mean/std below are 8-band SAR statistics
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
- mean=[123.675, 116.28, 103.53],
- std=[58.395, 57.12, 57.375],
- bgr_to_rgb=True,
+ #mean=[117.926186, 117.568402, 97.217239],
+ #std=[53.542876104049824, 50.084170325219176, 50.49331035114637],
+ #mean= [432.02181, 315.92948, 246.468659, 310.61462, 360.267789],
+ #std= [97.73313111900238, 85.78646917160748, 95.78015824658593, 124.84677067613467, 251.73965882246978],
+ mean=[0.23651549, 0.31761484, 0.18514981, 0.26901252, -14.57879175, -8.6098158, -14.2907338, -8.33534564],
+ std=[0.16280619, 0.20849304, 0.14008107, 0.19767644, 4.07141682, 3.94773216, 4.21006244, 4.05494136],
+ bgr_to_rgb=False, # Not RGB imagery
pad_val=0,
seg_pad_val=255)
+
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
pretrained=None,
backbone=dict(
type='UNet',
- in_channels=3,
+        in_channels=8,  # 8-channel input
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
@@ -53,6 +58,5 @@
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
- # model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='slide', crop_size=256, stride=170))
diff --git a/configs/_base_/models/deeplabv3plus_r50-d8.py b/configs/_base_/models/deeplabv3plus_r50-d8.py
index 74dbed5593..19af44fe17 100644
--- a/configs/_base_/models/deeplabv3plus_r50-d8.py
+++ b/configs/_base_/models/deeplabv3plus_r50-d8.py
@@ -2,17 +2,22 @@
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
- mean=[123.675, 116.28, 103.53],
- std=[58.395, 57.12, 57.375],
- bgr_to_rgb=True,
+ #mean= [117.926186, 117.568402, 97.217239],
+ #std= [53.542876104049824, 50.084170325219176, 50.49331035114637],
+ #mean= [432.02181, 315.92948, 246.468659, 310.61462, 360.267789],
+ #std= [97.73313111900238, 85.78646917160748, 95.78015824658593, 124.84677067613467, 251.73965882246978],
+ mean=[0.23651549, 0.31761484, 0.18514981, 0.26901252, -14.57879175, -8.6098158, -14.2907338, -8.33534564],
+ std=[0.16280619, 0.20849304, 0.14008107, 0.19767644, 4.07141682, 3.94773216, 4.21006244, 4.05494136],
+ bgr_to_rgb=False,
pad_val=0,
seg_pad_val=255)
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
- pretrained='open-mmlab://resnet50_v1c',
+ #pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
+ in_channels=8,
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
@@ -31,7 +36,7 @@
c1_in_channels=256,
c1_channels=48,
dropout_ratio=0.1,
- num_classes=19,
+ num_classes=2,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
@@ -44,7 +49,7 @@
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
- num_classes=19,
+ num_classes=2,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
diff --git a/configs/_base_/models/segformer_mit-b0.py b/configs/_base_/models/segformer_mit-b0.py
index 46841adc07..0c9dbafd83 100644
--- a/configs/_base_/models/segformer_mit-b0.py
+++ b/configs/_base_/models/segformer_mit-b0.py
@@ -1,10 +1,14 @@
-# model settings
+# model settings — NOTE: despite the mit-b0 filename, these are SegFormer-B5-sized parameters
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
- mean=[123.675, 116.28, 103.53],
- std=[58.395, 57.12, 57.375],
- bgr_to_rgb=True,
+ #mean=[117.926186, 117.568402, 97.217239],
+ #std=[53.542876104049824, 50.084170325219176, 50.49331035114637],
+ #mean= [432.02181, 315.92948, 246.468659, 310.61462, 360.267789],
+ #std= [97.73313111900238, 85.78646917160748, 95.78015824658593, 124.84677067613467, 251.73965882246978],
+ mean=[0.23651549, 0.31761484, 0.18514981, 0.26901252, -14.57879175, -8.6098158, -14.2907338, -8.33534564],
+ std=[0.16280619, 0.20849304, 0.14008107, 0.19767644, 4.07141682, 3.94773216, 4.21006244, 4.05494136],
+ bgr_to_rgb=False,
pad_val=0,
seg_pad_val=255)
model = dict(
@@ -13,13 +17,13 @@
pretrained=None,
backbone=dict(
type='MixVisionTransformer',
- in_channels=3,
- embed_dims=32,
+ in_channels=8,
+ embed_dims=64, # B0: 32 -> B5: 64
num_stages=4,
- num_layers=[2, 2, 2, 2],
- num_heads=[1, 2, 5, 8],
- patch_sizes=[7, 3, 3, 3],
- sr_ratios=[8, 4, 2, 1],
+ num_layers=[3, 6, 40, 3], # B0: [2, 2, 2, 2] -> B5: [3, 6, 40, 3]
+ num_heads=[1, 2, 5, 8], # 保持不变
+ patch_sizes=[7, 3, 3, 3], # 保持不变
+ sr_ratios=[8, 4, 2, 1], # 保持不变
out_indices=(0, 1, 2, 3),
mlp_ratio=4,
qkv_bias=True,
@@ -28,11 +32,11 @@
drop_path_rate=0.1),
decode_head=dict(
type='SegformerHead',
- in_channels=[32, 64, 160, 256],
+ in_channels=[64, 128, 320, 512], # B0: [32, 64, 160, 256] -> B5: [64, 128, 320, 512]
in_index=[0, 1, 2, 3],
- channels=256,
+ channels=768, # B0: 256 -> B5: 768
dropout_ratio=0.1,
- num_classes=19,
+ num_classes=2,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
diff --git a/configs/_base_/models/upernet_beit.py b/configs/_base_/models/upernet_beit.py
index 691e288dbf..d1ae45e11e 100644
--- a/configs/_base_/models/upernet_beit.py
+++ b/configs/_base_/models/upernet_beit.py
@@ -1,9 +1,13 @@
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
- mean=[123.675, 116.28, 103.53],
- std=[58.395, 57.12, 57.375],
- bgr_to_rgb=True,
+ #mean=[117.926186, 117.568402, 97.217239],
+ #std=[53.542876104049824, 50.084170325219176, 50.49331035114637],
+ #mean=[432.02181, 315.92948, 246.468659, 310.61462, 360.267789],
+ #std=[97.73313111900238, 85.78646917160748, 95.78015824658593, 124.84677067613467, 251.73965882246978],
+ mean=[0.23651549, 0.31761484, 0.18514981, 0.26901252, -14.57879175, -8.6098158, -14.2907338, -8.33534564],
+ std=[0.16280619, 0.20849304, 0.14008107, 0.19767644, 4.07141682, 3.94773216, 4.21006244, 4.05494136],
+ bgr_to_rgb=False,
pad_val=0,
seg_pad_val=255)
model = dict(
@@ -12,9 +16,9 @@
pretrained=None,
backbone=dict(
type='BEiT',
- img_size=(640, 640),
+ img_size=(256, 256),
patch_size=16,
- in_channels=3,
+ in_channels=8,
embed_dims=768,
num_layers=12,
num_heads=12,
@@ -35,7 +39,7 @@
pool_scales=(1, 2, 3, 6),
channels=768,
dropout_ratio=0.1,
- num_classes=150,
+ num_classes=2,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
@@ -48,7 +52,7 @@
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
- num_classes=150,
+ num_classes=2,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
diff --git a/configs/_base_/models/upernet_convnext.py b/configs/_base_/models/upernet_convnext.py
index 958994c91e..39d59b4963 100644
--- a/configs/_base_/models/upernet_convnext.py
+++ b/configs/_base_/models/upernet_convnext.py
@@ -3,9 +3,13 @@
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-base_3rdparty_32xb128-noema_in1k_20220301-2a0ee547.pth' # noqa
data_preprocessor = dict(
type='SegDataPreProcessor',
- mean=[123.675, 116.28, 103.53],
- std=[58.395, 57.12, 57.375],
- bgr_to_rgb=True,
+ #mean=[117.926186, 117.568402, 97.217239],
+ #std=[53.542876104049824, 50.084170325219176, 50.49331035114637],
+ #mean= [432.02181, 315.92948, 246.468659, 310.61462, 360.267789],
+ #std= [97.73313111900238, 85.78646917160748, 95.78015824658593, 124.84677067613467, 251.73965882246978],
+ mean=[0.23651549, 0.31761484, 0.18514981, 0.26901252, -14.57879175, -8.6098158, -14.2907338, -8.33534564],
+ std=[0.16280619, 0.20849304, 0.14008107, 0.19767644, 4.07141682, 3.94773216, 4.21006244, 4.05494136],
+ bgr_to_rgb=False,
pad_val=0,
seg_pad_val=255)
model = dict(
@@ -15,13 +19,15 @@
backbone=dict(
type='mmpretrain.ConvNeXt',
arch='base',
+ in_channels=8,
out_indices=[0, 1, 2, 3],
drop_path_rate=0.4,
layer_scale_init_value=1.0,
gap_before_final_norm=False,
- init_cfg=dict(
- type='Pretrained', checkpoint=checkpoint_file,
- prefix='backbone.')),
+ #init_cfg=dict(
+ #type='Pretrained', checkpoint=checkpoint_file,
+ #prefix='backbone.')
+ ),
decode_head=dict(
type='UPerHead',
in_channels=[128, 256, 512, 1024],
diff --git a/configs/_base_/models/upernet_mae.py b/configs/_base_/models/upernet_mae.py
index b833b67645..a7d6e6b242 100644
--- a/configs/_base_/models/upernet_mae.py
+++ b/configs/_base_/models/upernet_mae.py
@@ -1,9 +1,13 @@
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
- mean=[123.675, 116.28, 103.53],
- std=[58.395, 57.12, 57.375],
- bgr_to_rgb=True,
+ mean=[117.926186, 117.568402, 97.217239],
+ std=[53.542876104049824, 50.084170325219176, 50.49331035114637],
+ #mean= [432.02181, 315.92948, 246.468659, 310.61462, 360.267789],
+ #std= [97.73313111900238, 85.78646917160748, 95.78015824658593, 124.84677067613467, 251.73965882246978],
+ # mean=[0.23651549, 0.31761484, 0.18514981, 0.26901252, -14.57879175, -8.6098158, -14.2907338, -8.33534564],
+ # std=[0.16280619, 0.20849304, 0.14008107, 0.19767644, 4.07141682, 3.94773216, 4.21006244, 4.05494136],
+ bgr_to_rgb=False,
pad_val=0,
seg_pad_val=255)
model = dict(
@@ -12,7 +16,7 @@
pretrained=None,
backbone=dict(
type='MAE',
- img_size=(640, 640),
+ img_size=(256, 256),
patch_size=16,
in_channels=3,
embed_dims=768,
@@ -47,7 +51,7 @@
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
- num_classes=19,
+ num_classes=2,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
diff --git a/configs/_base_/models/upernet_vit-b16_ln_mln.py b/configs/_base_/models/upernet_vit-b16_ln_mln.py
index 776525ad98..0d43c1ffae 100644
--- a/configs/_base_/models/upernet_vit-b16_ln_mln.py
+++ b/configs/_base_/models/upernet_vit-b16_ln_mln.py
@@ -2,20 +2,24 @@
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
- mean=[123.675, 116.28, 103.53],
- std=[58.395, 57.12, 57.375],
- bgr_to_rgb=True,
+ #mean=[117.926186, 117.568402, 97.217239],
+ #std=[53.542876104049824, 50.084170325219176, 50.49331035114637],
+ #mean= [432.02181, 315.92948, 246.468659, 310.61462, 360.267789],
+ #std= [97.73313111900238, 85.78646917160748, 95.78015824658593, 124.84677067613467, 251.73965882246978],
+ mean=[0.23651549, 0.31761484, 0.18514981, 0.26901252, -14.57879175, -8.6098158, -14.2907338, -8.33534564],
+ std=[0.16280619, 0.20849304, 0.14008107, 0.19767644, 4.07141682, 3.94773216, 4.21006244, 4.05494136],
+ bgr_to_rgb=False,
pad_val=0,
seg_pad_val=255)
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
- pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth',
+ #pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth',
backbone=dict(
type='VisionTransformer',
- img_size=(512, 512),
+ img_size=(256, 256),
patch_size=16,
- in_channels=3,
+ in_channels=8,
embed_dims=768,
num_layers=12,
num_heads=12,
@@ -42,7 +46,7 @@
pool_scales=(1, 2, 3, 6),
channels=512,
dropout_ratio=0.1,
- num_classes=19,
+ num_classes=2,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
@@ -55,7 +59,7 @@
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
- num_classes=19,
+ num_classes=2,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
diff --git a/configs/_base_/schedules/schedule_20k.py b/configs/_base_/schedules/schedule_20k.py
index e809e3e880..45d8942e0b 100644
--- a/configs/_base_/schedules/schedule_20k.py
+++ b/configs/_base_/schedules/schedule_20k.py
@@ -19,6 +19,6 @@
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
- checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
+ checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000, max_keep_ckpts=1, save_best='val/mIoU'),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))
diff --git a/configs/_base_/schedules/schedule_40k.py b/configs/_base_/schedules/schedule_40k.py
index 4b823339a2..a2b7269844 100644
--- a/configs/_base_/schedules/schedule_40k.py
+++ b/configs/_base_/schedules/schedule_40k.py
@@ -17,8 +17,8 @@
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
- logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
- checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000),
+ checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000, max_keep_ckpts=1, save_best='val/mIoU'),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))
diff --git a/configs/beit/beit-UAVflood.py b/configs/beit/beit-UAVflood.py
new file mode 100644
index 0000000000..16cb337835
--- /dev/null
+++ b/configs/beit/beit-UAVflood.py
@@ -0,0 +1,43 @@
+_base_ = [
+ '../_base_/models/upernet_beit.py', '../_base_/datasets/SARflood.py',
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_20k.py'
+]
+crop_size = (256, 256)
+data_preprocessor = dict(size=crop_size)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ backbone=dict(
+ img_size=(256, 256)),
+ decode_head=dict(
+ num_classes=2),
+ auxiliary_head=dict(
+ num_classes=2),
+ test_cfg=dict(mode='slide', crop_size=(256, 256), stride=(170, 170)))
+
+optim_wrapper = dict(
+ _delete_=True,
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW', lr=3e-5, betas=(0.9, 0.999), weight_decay=0.05),
+ constructor='LayerDecayOptimizerConstructor',
+ paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))
+
+param_scheduler = [
+ dict(
+ type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
+ dict(
+ type='PolyLR',
+ power=1.0,
+ begin=1500,
+ end=20000,
+ eta_min=0.0,
+ by_epoch=False,
+ )
+]
+
+
+train_dataloader = dict(batch_size=8, num_workers=8)
+val_dataloader = dict(batch_size=8, num_workers=8)
+test_dataloader = val_dataloader
+
+
diff --git a/configs/convnext/convnext-base-uavflood.py b/configs/convnext/convnext-base-uavflood.py
new file mode 100644
index 0000000000..24c4924644
--- /dev/null
+++ b/configs/convnext/convnext-base-uavflood.py
@@ -0,0 +1,91 @@
+_base_ = [
+ '../_base_/models/upernet_convnext.py',
+ '../_base_/datasets/SARflood.py',
+ '../_base_/default_runtime.py'
+]
+
+crop_size = (256, 256)
+data_preprocessor = dict(size=crop_size)
+
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(
+ in_channels=[128, 256, 512, 1024],
+ num_classes=2 # UAVflood二分类:background, flood
+ ),
+ auxiliary_head=dict(
+ in_channels=512,
+ num_classes=2
+ ),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+# 使用AMP优化器包装器以提高训练效率
+optim_wrapper = dict(
+ type='AmpOptimWrapper',
+ optimizer=dict(
+ type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05
+ ),
+ paramwise_cfg={
+ 'decay_rate': 0.9,
+ 'decay_type': 'stage_wise',
+ 'num_layers': 12
+ },
+ constructor='LearningRateDecayOptimizerConstructor',
+ loss_scale='dynamic'
+)
+
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1e-6,
+ by_epoch=True,
+ begin=0,
+ end=5 # warmup前5个epoch
+ ),
+ dict(
+ type='PolyLR',
+ power=1.0,
+ begin=5,
+ end=100,
+ eta_min=0.0,
+ by_epoch=True,
+ )
+]
+
+# 使用EpochBasedTrainLoop训练100个epoch
+train_cfg = dict(
+ type='EpochBasedTrainLoop',
+ max_epochs=100,
+ val_interval=10) # 每10个epoch验证一次
+
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+# 修改hooks为基于epoch
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=1, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(
+ type='CheckpointHook',
+ by_epoch=True,
+ interval=10, # 每10个epoch保存一次
+ max_keep_ckpts=3,
+ save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'))
+
+# 设置日志按epoch显示
+log_processor = dict(by_epoch=True)
+
+# 数据加载器配置,适配UAVflood数据集
+train_dataloader = dict(batch_size=8, num_workers=8)
+val_dataloader = dict(batch_size=8, num_workers=8)
+test_dataloader = val_dataloader
+
+# 可选:设置随机种子以提高可复现性
+randomness = dict(
+ seed=42,
+ deterministic=False,
+)
\ No newline at end of file
diff --git a/configs/deeplabv3plus/Deeplabv3+UAVflood.py b/configs/deeplabv3plus/Deeplabv3+UAVflood.py
new file mode 100644
index 0000000000..d22cc9fc3b
--- /dev/null
+++ b/configs/deeplabv3plus/Deeplabv3+UAVflood.py
@@ -0,0 +1,66 @@
+_base_ = [
+ '../_base_/models/deeplabv3plus_r50-d8.py',
+ '../_base_/datasets/UAVflood.py',
+ '../_base_/default_runtime.py'
+]
+crop_size = (256, 256)
+data_preprocessor = dict(size=crop_size)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=dict(num_classes=2))
+
+# 优化器配置
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005),
+ clip_grad=None)
+
+# 修改为epoch训练,训练100个epoch
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1e-6,
+ by_epoch=True,
+ begin=0,
+ end=5), # warmup前5个epoch
+ dict(
+ type='PolyLR',
+ eta_min=1e-4,
+ power=0.9,
+ begin=5,
+ end=100,
+ by_epoch=True,
+ )
+]
+
+# 使用EpochBasedTrainLoop训练100个epoch
+train_cfg = dict(
+ type='EpochBasedTrainLoop',
+ max_epochs=100,
+ val_interval=10) # 每10个epoch验证一次
+
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+# 修改hooks为基于epoch
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=200, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(
+ type='CheckpointHook',
+ by_epoch=True,
+ interval=10, # 每10个epoch保存一次
+ max_keep_ckpts=3,
+ save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'))
+
+randomness = dict(
+ seed=42,
+ deterministic=False, # 如需完全可复现,设为True
+)
+
+# 设置日志按epoch显示
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_e6_k1.py b/configs/floodnet/ablations/ablation_e6_k1.py
new file mode 100644
index 0000000000..dad0e3d267
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_e6_k1.py
@@ -0,0 +1,180 @@
+"""
+MoE Hyperparameter Study: 6 experts, top_k=1
+Table 3
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_e6_k1.py \
+ --work-dir work_dirs/moe_hyper/e6_k1/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
+num_experts = 6
+top_k = 1
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True, use_modal_bias=True,
+ moe_balance_weight=1.0, moe_diversity_weight=0.1,
+ multi_tasks_reweight=None, decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS, training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True,
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='FixedRatioModalSampler',
+ modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5},
+ modal_order=['sar', 'rgb', 'GF'],
+ reference_modal='GF', batch_size=16),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_e6_k2.py b/configs/floodnet/ablations/ablation_e6_k2.py
new file mode 100644
index 0000000000..ae0dcb467d
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_e6_k2.py
@@ -0,0 +1,180 @@
+"""
+MoE Hyperparameter Study: 6 experts, top_k=2
+Table 3
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_e6_k2.py \
+ --work-dir work_dirs/moe_hyper/e6_k2/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
+num_experts = 6
+top_k = 2
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True, use_modal_bias=True,
+ moe_balance_weight=1.0, moe_diversity_weight=0.1,
+ multi_tasks_reweight=None, decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS, training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True,
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='FixedRatioModalSampler',
+ modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5},
+ modal_order=['sar', 'rgb', 'GF'],
+ reference_modal='GF', batch_size=16),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_e6_k3.py b/configs/floodnet/ablations/ablation_e6_k3.py
new file mode 100644
index 0000000000..8527debdc2
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_e6_k3.py
@@ -0,0 +1,180 @@
+"""
+MoE Hyperparameter Study: 6 experts, top_k=3
+Table 3
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_e6_k3.py \
+ --work-dir work_dirs/moe_hyper/e6_k3/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
+num_experts = 6
+top_k = 3
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True, use_modal_bias=True,
+ moe_balance_weight=1.0, moe_diversity_weight=0.1,
+ multi_tasks_reweight=None, decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS, training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True,
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='FixedRatioModalSampler',
+ modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5},
+ modal_order=['sar', 'rgb', 'GF'],
+ reference_modal='GF', batch_size=16),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_e8_k1.py b/configs/floodnet/ablations/ablation_e8_k1.py
new file mode 100644
index 0000000000..f084f5f53b
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_e8_k1.py
@@ -0,0 +1,180 @@
+"""
+MoE Hyperparameter Study: 8 experts, top_k=1
+Table 3
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_e8_k1.py \
+ --work-dir work_dirs/moe_hyper/e8_k1/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
+num_experts = 8
+top_k = 1
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True, use_modal_bias=True,
+ moe_balance_weight=1.0, moe_diversity_weight=0.1,
+ multi_tasks_reweight=None, decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS, training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True,
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='FixedRatioModalSampler',
+ modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5},
+ modal_order=['sar', 'rgb', 'GF'],
+ reference_modal='GF', batch_size=16),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_e8_k2.py b/configs/floodnet/ablations/ablation_e8_k2.py
new file mode 100644
index 0000000000..1b90ffae68
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_e8_k2.py
@@ -0,0 +1,180 @@
+"""
+MoE Hyperparameter Study: 8 experts, top_k=2
+Table 3 Row: num_experts=8, top_k=2
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_e8_k2.py \
+ --work-dir work_dirs/moe_hyper/e8_k2/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
+num_experts = 8
+top_k = 2
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True, use_modal_bias=True,
+ moe_balance_weight=1.0, moe_diversity_weight=0.1,
+ multi_tasks_reweight=None, decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS, training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True,
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='FixedRatioModalSampler',
+ modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5},
+ modal_order=['sar', 'rgb', 'GF'],
+ reference_modal='GF', batch_size=16),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_experts_16.py b/configs/floodnet/ablations/ablation_experts_16.py
new file mode 100644
index 0000000000..30c3f39ae5
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_experts_16.py
@@ -0,0 +1,186 @@
+"""
+Ablation Study: 16 Experts (vs default 8)
+Increases MoE capacity to study effect of expert count.
+
+Table 3 Row: num_experts=16, top_k=4
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_experts_16.py \
+ --work-dir work_dirs/ablations/experts_16/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
+num_experts = 16 # <<< 16 instead of 8
+top_k = 4           # <<< 4 instead of 3 (4/16 = 25% experts active; baseline is 3/8 = 37.5%)
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True,
+ use_modal_bias=True,
+ moe_balance_weight=1.0,
+ moe_diversity_weight=0.1,
+ multi_tasks_reweight=None,
+ decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True,
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='FixedRatioModalSampler',
+ modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5},
+ modal_order=['sar', 'rgb', 'GF'],
+ reference_modal='GF', batch_size=16),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_experts_4.py b/configs/floodnet/ablations/ablation_experts_4.py
new file mode 100644
index 0000000000..f9ce12078c
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_experts_4.py
@@ -0,0 +1,186 @@
+"""
+Ablation Study: 4 Experts (vs default 8)
+Reduces MoE capacity to study effect of expert count.
+
+Table 3 Row: num_experts=4, top_k=2
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_experts_4.py \
+ --work-dir work_dirs/ablations/experts_4/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
+num_experts = 4 # <<< 4 instead of 8
+top_k = 2           # <<< 2 instead of 3 (2/4 = 50% experts active; baseline is 3/8 = 37.5%)
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True,
+ use_modal_bias=True,
+ moe_balance_weight=1.0,
+ moe_diversity_weight=0.1,
+ multi_tasks_reweight=None,
+ decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True,
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='FixedRatioModalSampler',
+ modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5},
+ modal_order=['sar', 'rgb', 'GF'],
+ reference_modal='GF', batch_size=16),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_gf_only.py b/configs/floodnet/ablations/ablation_gf_only.py
new file mode 100644
index 0000000000..38c4e5fc4e
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_gf_only.py
@@ -0,0 +1,187 @@
+"""
+Ablation Study: GaoFen-Only Training
+Single-modal baseline: dataloaders restrict data to GaoFen via filter_modality='GF'.
+
+Table 4 Row: GF-only
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_gf_only.py \
+ --work-dir work_dirs/ablations/gf_only/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
+num_experts = 8
+top_k = 3
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True,
+ use_modal_bias=True,
+ moe_balance_weight=1.0,
+ moe_diversity_weight=0.1,
+ multi_tasks_reweight=None,
+ decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True,
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+# GF-only: filter_modality='GF'
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ filter_modality='GF',
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ filter_modality='GF',
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ filter_modality='GF',
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_no_diversity_loss.py b/configs/floodnet/ablations/ablation_no_diversity_loss.py
new file mode 100644
index 0000000000..e28bca1a9a
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_no_diversity_loss.py
@@ -0,0 +1,187 @@
+"""
+Ablation Study: No Expert Diversity Loss
+MoE enabled with modal bias but expert diversity loss disabled.
+
+Table 2 Row (e): w/o Expert Diversity Loss
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_no_diversity_loss.py \
+ --work-dir work_dirs/ablations/no_diversity_loss/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
+num_experts = 8
+top_k = 3
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True,
+ use_modal_bias=True,
+ moe_balance_weight=1.0,
+ moe_diversity_weight=0.0, # <<< DISABLED
+ multi_tasks_reweight=None,
+ decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds,
+ use_expert_diversity_loss=False, # <<< DISABLED
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='FixedRatioModalSampler',
+ modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5},
+ modal_order=['sar', 'rgb', 'GF'],
+ reference_modal='GF', batch_size=16),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_no_modal_bias.py b/configs/floodnet/ablations/ablation_no_modal_bias.py
new file mode 100644
index 0000000000..dffd5e8032
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_no_modal_bias.py
@@ -0,0 +1,185 @@
+"""
+Ablation Study: No Modal Bias
+MoE is enabled but modal-specific routing bias is disabled.
+
+Table 2 Row (c): w/o Modal Bias
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_no_modal_bias.py \
+ --work-dir work_dirs/ablations/no_modal_bias/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
+num_experts = 8
+top_k = 3
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True,
+ use_modal_bias=False, # <<< DISABLED
+ moe_balance_weight=1.0,
+ moe_diversity_weight=0.1,
+ multi_tasks_reweight=None,
+ decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True,
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='FixedRatioModalSampler',
+ modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5},
+ modal_order=['sar', 'rgb', 'GF'],
+ reference_modal='GF', batch_size=16),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_no_modal_specific_stem.py b/configs/floodnet/ablations/ablation_no_modal_specific_stem.py
new file mode 100644
index 0000000000..6f16ab855b
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_no_modal_specific_stem.py
@@ -0,0 +1,187 @@
+"""
+Ablation Study: No ModalSpecificStem (Unified Patch Embedding)
+Uses a single shared Conv2d with zero-padding instead of per-modal patch embedding.
+
+Table 2 Row: w/o ModalSpecificStem
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_no_modal_specific_stem.py \
+ --work-dir work_dirs/ablations/no_modal_specific_stem/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
+num_experts = 8
+top_k = 3
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True,
+ use_modal_bias=True,
+ moe_balance_weight=1.0,
+ moe_diversity_weight=0.1,
+ multi_tasks_reweight=None,
+ decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_modal_specific_stem=False, # <<< DISABLED: use UnifiedPatchEmbed
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True,
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='FixedRatioModalSampler',
+ modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5},
+ modal_order=['sar', 'rgb', 'GF'],
+ reference_modal='GF', batch_size=16),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_no_moe.py b/configs/floodnet/ablations/ablation_no_moe.py
new file mode 100644
index 0000000000..c18e7bf0a3
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_no_moe.py
@@ -0,0 +1,257 @@
+"""
+Ablation Study: No MoE (Standard Swin-Base + UPerNet baseline)
+All MoE components disabled - uses standard FFN instead of MoE FFN.
+
+Table 2 Row (b): w/o MoE
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_no_moe.py \
+ --work-dir work_dirs/ablations/no_moe/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+# ==================== Dataset Definition ====================
+DATASETS_CONFIG = dict(
+ names=['sar', 'rgb', 'GF'],
+)
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar',
+ 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb',
+ 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF',
+ 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+# ==================== MoE DISABLED ====================
+depths = [2, 2, 18, 2]
+
+# ==================== Model Config ====================
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor',
+ pad_val=0,
+ seg_pad_val=255,
+ size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=False, # <<< DISABLED
+ use_modal_bias=False, # No MoE means no modal bias
+ moe_balance_weight=0.0,
+ moe_diversity_weight=0.0,
+ multi_tasks_reweight=None,
+ decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+
+ pretrain_img_size=224,
+ patch_size=4,
+ embed_dims=128,
+ depths=depths,
+ num_heads=[4, 8, 16, 32],
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.3,
+ patch_norm=True,
+ out_indices=[0, 1, 2, 3],
+
+ # ---- MoE DISABLED ----
+ use_moe=False,
+ num_experts=8,
+ num_shared_experts_config={0: 0, 1: 0, 2: 0, 3: 0},
+ top_k=3,
+ noisy_gating=False,
+ MoE_Block_inds=[[], [], [], []],
+ use_expert_diversity_loss=False,
+
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead',
+ in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3],
+ pool_scales=(1, 2, 3, 6),
+ channels=512,
+ dropout_ratio=0.1,
+ num_classes=num_classes,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0
+ )
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=512,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=num_classes,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=0.4
+ )
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+# ==================== Dataset Config ====================
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(
+ type='FixedRatioModalSampler',
+ modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5},
+ modal_order=['sar', 'rgb', 'GF'],
+ reference_modal='GF',
+ batch_size=16,
+ ),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='train/images',
+ seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='val/images',
+ seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='test/images',
+ seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+# ==================== Optimizer Config ====================
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW',
+ lr=0.00006,
+ betas=(0.9, 0.999),
+ weight_decay=0.01),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'patch_embed': dict(lr_mult=2.0),
+ 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.),
+ }
+ )
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_no_shared_experts.py b/configs/floodnet/ablations/ablation_no_shared_experts.py
new file mode 100644
index 0000000000..60c0fa7090
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_no_shared_experts.py
@@ -0,0 +1,186 @@
+"""
+Ablation Study: No Shared Experts
+MoE enabled but shared experts removed from all stages.
+
+Table 2 Row (d): w/o Shared Experts
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_no_shared_experts.py \
+ --work-dir work_dirs/ablations/no_shared_experts/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 0, 3: 0} # <<< ALL ZEROS
+num_experts = 8
+top_k = 3
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True,
+ use_modal_bias=True,
+ moe_balance_weight=1.0,
+ moe_diversity_weight=0.1,
+ multi_tasks_reweight=None,
+ decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True,
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='FixedRatioModalSampler',
+ modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5},
+ modal_order=['sar', 'rgb', 'GF'],
+ reference_modal='GF', batch_size=16),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_rgb_only.py b/configs/floodnet/ablations/ablation_rgb_only.py
new file mode 100644
index 0000000000..070b6114a0
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_rgb_only.py
@@ -0,0 +1,187 @@
+"""
+Ablation Study: RGB-Only Training
+Single-modal baseline using only RGB optical data.
+
+Table 4 Row: RGB-only
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_rgb_only.py \
+ --work-dir work_dirs/ablations/rgb_only/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
+num_experts = 8
+top_k = 3
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True,
+ use_modal_bias=True,
+ moe_balance_weight=1.0,
+ moe_diversity_weight=0.1,
+ multi_tasks_reweight=None,
+ decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True,
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+# RGB-only: filter_modality='rgb'
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ filter_modality='rgb',
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ filter_modality='rgb',
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ filter_modality='rgb',
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_sar_only.py b/configs/floodnet/ablations/ablation_sar_only.py
new file mode 100644
index 0000000000..84f2ffd2bd
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_sar_only.py
@@ -0,0 +1,187 @@
+"""
+Ablation Study: SAR-Only Training
+Single-modal baseline using only Synthetic Aperture Radar (SAR) data.
+
+Table 4 Row: SAR-only
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_sar_only.py \
+ --work-dir work_dirs/ablations/sar_only/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
+num_experts = 8
+top_k = 3
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True,
+ use_modal_bias=True,
+ moe_balance_weight=1.0,
+ moe_diversity_weight=0.1,
+ multi_tasks_reweight=None,
+ decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True,
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+# SAR-only: filter_modality='sar'
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ filter_modality='sar',
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ filter_modality='sar',
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ filter_modality='sar',
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_shared_decoder.py b/configs/floodnet/ablations/ablation_shared_decoder.py
new file mode 100644
index 0000000000..4b5646146b
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_shared_decoder.py
@@ -0,0 +1,186 @@
+"""
+Ablation Study: Shared Decoder
+Full model but with a single shared decode head for all modalities.
+
+Table 2 Row (g): Shared Decoder (vs Separate)
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_shared_decoder.py \
+ --work-dir work_dirs/ablations/shared_decoder/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
+num_experts = 8
+top_k = 3
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True,
+ use_modal_bias=True,
+ moe_balance_weight=1.0,
+ moe_diversity_weight=0.1,
+ multi_tasks_reweight=None,
+ decoder_mode='shared', # <<< SHARED instead of separate
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True,
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='FixedRatioModalSampler',
+ modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5},
+ modal_order=['sar', 'rgb', 'GF'],
+ reference_modal='GF', batch_size=16),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_topk_1.py b/configs/floodnet/ablations/ablation_topk_1.py
new file mode 100644
index 0000000000..4386f53fda
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_topk_1.py
@@ -0,0 +1,186 @@
+"""
+Ablation Study: Top-k=1 (vs default k=3)
+Only one expert activated per token.
+
+Table 3 Row: num_experts=8, top_k=1
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_topk_1.py \
+ --work-dir work_dirs/ablations/topk_1/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
+num_experts = 8
+top_k = 1 # <<< 1 instead of 3
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True,
+ use_modal_bias=True,
+ moe_balance_weight=1.0,
+ moe_diversity_weight=0.1,
+ multi_tasks_reweight=None,
+ decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True,
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='FixedRatioModalSampler',
+ modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5},
+ modal_order=['sar', 'rgb', 'GF'],
+ reference_modal='GF', batch_size=16),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/ablations/ablation_uniform_sampling.py b/configs/floodnet/ablations/ablation_uniform_sampling.py
new file mode 100644
index 0000000000..e7d8fe27ab
--- /dev/null
+++ b/configs/floodnet/ablations/ablation_uniform_sampling.py
@@ -0,0 +1,184 @@
+"""
+Ablation Study: Uniform Sampling (No SAR Boost)
+Full model but with uniform random sampling instead of SAR-boosted ratio.
+
+Table 2 Row (f): w/o SAR Boost Sampling
+
+Usage:
+ python tools/train.py configs/floodnet/ablations/ablation_uniform_sampling.py \
+ --work-dir work_dirs/ablations/uniform_sampling/ --seed 42
+"""
+
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF'])
+
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]]
+num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
+num_experts = 8
+top_k = 3
+noisy_gating = True
+
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True,
+ use_modal_bias=True,
+ moe_balance_weight=1.0,
+ moe_diversity_weight=0.1,
+ multi_tasks_reweight=None,
+ decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+ pretrain_img_size=224, patch_size=4,
+ embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
+ patch_norm=True, out_indices=[0, 1, 2, 3],
+ use_moe=True, num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k, noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True,
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead', in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512,
+ dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead', in_channels=512, in_index=2, channels=256,
+ num_convs=1, concat_input=False, dropout_ratio=0.1,
+ num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+# <<< KEY CHANGE: DefaultSampler instead of FixedRatioModalSampler >>>
+train_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='train/images', seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='val/images', seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16, num_workers=8, persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(type=dataset_type, data_root=data_root,
+ data_prefix=dict(img_path='test/images', seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(custom_keys={
+ 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ })
+)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),
+ dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True),
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10,
+ max_keep_ckpts=1, save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/continue_train_150ep.py b/configs/floodnet/continue_train_150ep.py
new file mode 100644
index 0000000000..d269d17019
--- /dev/null
+++ b/configs/floodnet/continue_train_150ep.py
@@ -0,0 +1,28 @@
+"""
+Continue training Full Model for 50 more epochs (101-150).
+"""
+
+_base_ = ['./multimodal_floodnet_sar_boost_swinbase_moe_config.py']
+
+# Extend to 150 epochs
+train_cfg = dict(
+ type='EpochBasedTrainLoop',
+ max_epochs=150,
+ val_interval=10)
+
+# Extend PolyLR end to 150
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1e-6,
+ by_epoch=True,
+ begin=0,
+ end=5),
+ dict(
+ type='PolyLR',
+ eta_min=0.0,
+ power=1.0,
+ begin=5,
+ end=150,
+ by_epoch=True),
+]
diff --git a/configs/floodnet/finetune_sen1floods11_s1.py b/configs/floodnet/finetune_sen1floods11_s1.py
new file mode 100644
index 0000000000..bec6510938
--- /dev/null
+++ b/configs/floodnet/finetune_sen1floods11_s1.py
@@ -0,0 +1,187 @@
+"""
+Sen1Floods11 fine-tune: S1Hand (2-band VV/VH SAR).
+
+Expected on-disk layout (see tools/setup_sen1floods11.py)::
+
+ data/Sen1Floods11/
+ S1Hand/_S1Hand.tif # 2-band SAR (VV, VH in dB)
+ S2Hand/_S2Hand.tif # 13-band Sentinel-2 MSI (unused here)
+ LabelHand/_LabelHand.tif # 1-band label: -1=nodata, 0=bg, 1=flood
+ splits/train.txt # one per line, generated by
+ splits/val.txt # tools/setup_sen1floods11.py
+ splits/test.txt
+
+All tiles are 512x512. Labels are signed TIFFs; the ``-1`` nodata
+value is remapped to ``ignore_index=255`` by
+:class:`LoadSen1Floods11Annotation`.
+
+Setup (run once, before the first training run):
+
+ # 1. Generate train/val/test splits + compute mean/std for s1 & s2.
+ python tools/setup_sen1floods11.py --data-root data/Sen1Floods11
+
+ # 2. Paste the printed NORM_CONFIGS block into the 's1' entry in
+ # mmseg/datasets/transforms/multimodal_pipelines.py
+ # (skip this if you're happy with the shipped defaults).
+
+Training:
+
+ python tools/train.py configs/floodnet/finetune_sen1floods11_s1.py \\
+ --work-dir work_dirs/generalization/sen1floods11_s1/ \\
+ --cfg-options \\
+ load_from="work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth"
+
+The pretrained Swin-Base + MoE checkpoint has modal-specific patch
+embeds / decode heads for ``sar / rgb / GF``. Here we redefine the
+single trainable modal as ``s1`` (2-channel) and the single decode-head
+key as ``s1``; the mismatched pretrained keys are ignored on load
+(``strict=False``), so the new patch-embed and decode head train from
+scratch while the frozen Swin stages reuse their pretrained weights.
+"""
+
+_base_ = ['./finetune_stem_decoder.py']
+
+# ==================== Modal / dataset identifiers ====================
+MODAL_NAME = 's1'
+MODAL_CHANNELS = 2
+
+ALL_KNOWN_MODALS = {
+ # _delete_ forces a full replacement so the base config's
+ # sar/rgb/GF entries don't leak through.
+ '_delete_': True,
+ MODAL_NAME: {
+ 'channels': MODAL_CHANNELS,
+ 'pattern': 's1hand',
+ 'description': 'Sen1Floods11 S1Hand (VV/VH dB)',
+ },
+}
+TRAINING_MODALS = [MODAL_NAME]
+DATASET_NAMES = [MODAL_NAME]
+
+# 512x512 input -> crop to 256x256 for training (gives ~4x position
+# diversity per tile and keeps each training step cheap).
+crop_size = (256, 256)
+
+# ==================== Model override ====================
+model = dict(
+ dataset_names=DATASET_NAMES,
+ backbone=dict(
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+ # frozen_stages / freeze_patch_embed inherit from base
+ ),
+)
+
+# ==================== Dataset / pipelines ====================
+dataset_type = 'Sen1Floods11Dataset'
+data_root = 'data/Sen1Floods11/'
+
+# All tiles are already 512x512, so no RandomResize is needed.
+# Pipeline order:
+# 1. Load image (sets 'img')
+# 2. Load annotation (sets 'gt_seg_map', 'seg_fields')
+# - must come BEFORE RandomCrop so cat_max_ratio can check label
+# distribution and avoid degenerate all-background crops.
+# 3. RandomCrop 256x256
+# 4. RandomFlip
+# 5. MultiModalNormalize (NaN/Inf-safe, modality-aware mean/std)
+# 6. MultiModalPad to crop_size (no-op if RandomCrop matched exactly,
+#    but handles any edge case where the crop ends up smaller than crop_size)
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadSen1Floods11Annotation'),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+# Test / val: feed the full 512x512 tile; sliding-window inference
+# (configured in the base config as mode='slide', crop_size=256,
+# stride=170) handles the cropping internally.
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadSen1Floods11Annotation'),
+ dict(type='MultiModalNormalize'),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+# ==================== Dataloaders ====================
+# _delete_=True on sampler forces full replacement so the base
+# FixedRatioModalSampler config is discarded - Sen1Floods11 is single
+# modal, so we just use the standard shuffling sampler.
+#
+# ann_file='splits/train.txt' is resolved relative to data_root by
+# BaseSegDataset._join_prefix, so the actual path read is
+# data/Sen1Floods11/splits/train.txt . Generate these files with
+# tools/setup_sen1floods11.py before training.
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(
+ _delete_=True,
+ type='DefaultSampler',
+ shuffle=True,
+ ),
+ dataset=dict(
+ _delete_=True,
+ type=dataset_type,
+ data_root=data_root,
+ modality=MODAL_NAME,
+ ann_file='splits/train.txt',
+ data_prefix=dict(
+ img_path='S1Hand',
+ seg_map_path='LabelHand',
+ ),
+ pipeline=train_pipeline,
+ ),
+)
+
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ _delete_=True,
+ type=dataset_type,
+ data_root=data_root,
+ modality=MODAL_NAME,
+ ann_file='splits/val.txt',
+ data_prefix=dict(
+ img_path='S1Hand',
+ seg_map_path='LabelHand',
+ ),
+ pipeline=test_pipeline,
+ ),
+)
+
+test_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ _delete_=True,
+ type=dataset_type,
+ data_root=data_root,
+ modality=MODAL_NAME,
+ ann_file='splits/test.txt',
+ data_prefix=dict(
+ img_path='S1Hand',
+ seg_map_path='LabelHand',
+ ),
+ pipeline=test_pipeline,
+ ),
+)
diff --git a/configs/floodnet/finetune_sen1floods11_s2.py b/configs/floodnet/finetune_sen1floods11_s2.py
new file mode 100644
index 0000000000..a2e4fe1a41
--- /dev/null
+++ b/configs/floodnet/finetune_sen1floods11_s2.py
@@ -0,0 +1,168 @@
+"""
+Sen1Floods11 fine-tune: S2Hand (13-band Sentinel-2 MSI).
+
+Sibling of ``finetune_sen1floods11_s1.py`` - same freeze / LR / schedule
+strategy, but swaps the trainable modal to ``s2`` (13 ch) and points
+the dataset at the ``S2Hand`` subdirectory. LabelHand files are shared
+between S1 and S2, so the exact same splits/train.txt / splits/val.txt
+/ splits/test.txt files are used - this means any S1 vs S2 comparison
+is done on the same tiles.
+
+Expected on-disk layout (see tools/setup_sen1floods11.py)::
+
+ data/Sen1Floods11/
+ S1Hand/_S1Hand.tif # 2-band SAR (unused here)
+ S2Hand/_S2Hand.tif # 13-band Sentinel-2 MSI
+ LabelHand/_LabelHand.tif # -1=nodata, 0=bg, 1=flood
+ splits/{train,val,test}.txt
+
+All tiles are 512x512.
+
+Setup (run once, before the first training run):
+
+ # 1. Generate splits + compute mean/std (covers both s1 and s2).
+ python tools/setup_sen1floods11.py --data-root data/Sen1Floods11
+
+ # 2. Paste the printed NORM_CONFIGS block into the 's2' entry in
+ # mmseg/datasets/transforms/multimodal_pipelines.py
+ # (skip this if you're happy with the shipped defaults).
+
+Training:
+
+ python tools/train.py configs/floodnet/finetune_sen1floods11_s2.py \\
+ --work-dir work_dirs/generalization/sen1floods11_s2/ \\
+ --cfg-options \\
+ load_from="work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth"
+"""
+
+_base_ = ['./finetune_stem_decoder.py']
+
+# ==================== Modal / dataset identifiers ====================
+MODAL_NAME = 's2'
+MODAL_CHANNELS = 13
+
+ALL_KNOWN_MODALS = {
+ # _delete_ forces a full replacement so the base config's
+ # sar/rgb/GF entries don't leak through.
+ '_delete_': True,
+ MODAL_NAME: {
+ 'channels': MODAL_CHANNELS,
+ 'pattern': 's2hand',
+ 'description': 'Sen1Floods11 S2Hand (13-band Sentinel-2 MSI)',
+ },
+}
+TRAINING_MODALS = [MODAL_NAME]
+DATASET_NAMES = [MODAL_NAME]
+
+# 512x512 input -> crop to 256x256 for training.
+crop_size = (256, 256)
+
+# ==================== Model override ====================
+# See finetune_sen1floods11_s1.py for the rationale - the single
+# trainable modal here is `s2` with a 13-channel stem conv. The
+# s2 mean/std entry in MultiModalNormalize.NORM_CONFIGS is used
+# by default; re-run tools/setup_sen1floods11.py to refresh it
+# for your local data.
+model = dict(
+ dataset_names=DATASET_NAMES,
+ backbone=dict(
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+ ),
+)
+
+# ==================== Dataset / pipelines ====================
+dataset_type = 'Sen1Floods11Dataset'
+data_root = 'data/Sen1Floods11/'
+
+# Pipeline order matches finetune_sen1floods11_s1.py. See that file
+# for the rationale behind each step.
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadSen1Floods11Annotation'),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadSen1Floods11Annotation'),
+ dict(type='MultiModalNormalize'),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+# ==================== Dataloaders ====================
+# Shares splits/train.txt / splits/val.txt / splits/test.txt with
+# finetune_sen1floods11_s1.py so the S1 / S2 results are directly
+# comparable (same tiles in each split, just different sensors).
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(
+ _delete_=True,
+ type='DefaultSampler',
+ shuffle=True,
+ ),
+ dataset=dict(
+ _delete_=True,
+ type=dataset_type,
+ data_root=data_root,
+ modality=MODAL_NAME,
+ ann_file='splits/train.txt',
+ data_prefix=dict(
+ img_path='S2Hand',
+ seg_map_path='LabelHand',
+ ),
+ pipeline=train_pipeline,
+ ),
+)
+
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ _delete_=True,
+ type=dataset_type,
+ data_root=data_root,
+ modality=MODAL_NAME,
+ ann_file='splits/val.txt',
+ data_prefix=dict(
+ img_path='S2Hand',
+ seg_map_path='LabelHand',
+ ),
+ pipeline=test_pipeline,
+ ),
+)
+
+test_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ _delete_=True,
+ type=dataset_type,
+ data_root=data_root,
+ modality=MODAL_NAME,
+ ann_file='splits/test.txt',
+ data_prefix=dict(
+ img_path='S2Hand',
+ seg_map_path='LabelHand',
+ ),
+ pipeline=test_pipeline,
+ ),
+)
diff --git a/configs/floodnet/finetune_single_modal.py b/configs/floodnet/finetune_single_modal.py
new file mode 100644
index 0000000000..b756f317a8
--- /dev/null
+++ b/configs/floodnet/finetune_single_modal.py
@@ -0,0 +1,48 @@
+"""
+Generalization Fine-tuning Config: Single-Modal (RGB only)
+
+Freeze backbone, retrain stem + decoder on a new RGB-only flood event.
+Based on finetune_stem_decoder.py but adapted for single-modality data:
+ - filter_modality='rgb' to only load RGB images
+ - DefaultSampler instead of FixedRatioModalSampler
+ - Uses the rgb decode_head from the pretrained separate-decoder model
+
+Usage:
+ python tools/train.py configs/floodnet/finetune_single_modal.py \
+ --work-dir work_dirs/generalization/LY-train-station/ \
+ --cfg-options \
+ train_dataloader.dataset.data_root="data/LY-train-station/" \
+ val_dataloader.dataset.data_root="data/LY-train-station/" \
+ test_dataloader.dataset.data_root="data/LY-train-station/"
+"""
+
+_base_ = ['./finetune_stem_decoder.py']
+
+# ==================== Override sampler to DefaultSampler ====================
+# _delete_=True forces full replacement instead of recursive merge,
+# otherwise FixedRatioModalSampler's modal_ratios etc. leak through
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(
+ _delete_=True,
+ type='DefaultSampler',
+ shuffle=True,
+ ),
+ dataset=dict(
+ filter_modality='rgb',
+ ),
+)
+
+val_dataloader = dict(
+ dataset=dict(
+ filter_modality='rgb',
+ ),
+)
+
+test_dataloader = dict(
+ dataset=dict(
+ filter_modality='rgb',
+ ),
+)
diff --git a/configs/floodnet/finetune_stem_decoder.py b/configs/floodnet/finetune_stem_decoder.py
new file mode 100644
index 0000000000..ba149f71f5
--- /dev/null
+++ b/configs/floodnet/finetune_stem_decoder.py
@@ -0,0 +1,87 @@
+"""
+Generalization Fine-tuning Config: Freeze Backbone, Retrain Stem + Decoder
+
+Loads pretrained Full Model checkpoint, freezes all 4 Swin Transformer stages
+(attention, MoE, patch merging) via backbone.frozen_stages=3, and only retrains:
+ 1. ModalSpecificPatchEmbed (backbone.patch_embed) - adapts to new sensor inputs
+ 2. UPerHead decode_heads - adapts segmentation output to new domain
+ 3. FCNHead auxiliary_heads - auxiliary segmentation head
+
+Frozen stages use requires_grad=False + eval() mode for proper BN/Dropout behavior.
+
+Usage:
+ python tools/train.py configs/floodnet/finetune_stem_decoder.py \
+ --work-dir work_dirs/generalization/event_name/ \
+ --cfg-options \
+ load_from="work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth" \
+ train_dataloader.dataset.data_root="../floodnet/data/new_event/" \
+ val_dataloader.dataset.data_root="../floodnet/data/new_event/" \
+ test_dataloader.dataset.data_root="../floodnet/data/new_event/"
+"""
+
+_base_ = ['./multimodal_floodnet_sar_boost_swinbase_moe_config.py']
+
+# ==================== Load pretrained checkpoint ====================
+# Override this via --cfg-options load_from="path/to/checkpoint.pth"
+load_from = 'work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth'
+
+# ==================== Freeze all 4 backbone stages ====================
+# frozen_stages=3 freezes stages 0-3 (all Swin blocks, MoE, patch merging, norms)
+# freeze_patch_embed=False keeps ModalSpecificPatchEmbed trainable
+model = dict(
+ backbone=dict(
+ frozen_stages=3, # Freeze all 4 stages (0,1,2,3)
+ freeze_patch_embed=False # Keep stem trainable
+ ),
+)
+
+# ==================== Optimizer for fine-tuning ====================
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW',
+ lr=0.0001, # Higher base LR for fine-tuning fewer params
+ betas=(0.9, 0.999),
+ weight_decay=0.01),
+ paramwise_cfg=dict(
+ custom_keys={
+ # Stem: moderate LR boost
+ 'backbone.patch_embed': dict(lr_mult=2.0),
+ # Decode heads: high LR for fast adaptation
+ 'decode_heads': dict(lr_mult=10.0),
+ 'auxiliary_heads': dict(lr_mult=10.0),
+ }
+ )
+)
+
+# ==================== Shorter schedule for fine-tuning ====================
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1e-4,
+ by_epoch=True,
+ begin=0,
+ end=10), # Short warmup
+ dict(
+ type='PolyLR',
+ eta_min=0.0,
+ power=1.0,
+ begin=2,
+ end=30,
+ by_epoch=True),
+]
+
+train_cfg = dict(
+ type='EpochBasedTrainLoop',
+ max_epochs=100, # Much shorter than full training (100 epochs)
+ val_interval=10)
+
+# ==================== Checkpoint hook ====================
+default_hooks = dict(
+ checkpoint=dict(
+ type='CheckpointHook',
+ by_epoch=True,
+ interval=10,
+ max_keep_ckpts=1,
+ save_best='mIoU'),
+)
diff --git a/configs/floodnet/multimodal_floodnet_sar_boost_swin_moe_config.py b/configs/floodnet/multimodal_floodnet_sar_boost_swin_moe_config.py
new file mode 100644
index 0000000000..936e4eaed6
--- /dev/null
+++ b/configs/floodnet/multimodal_floodnet_sar_boost_swin_moe_config.py
@@ -0,0 +1,298 @@
+"""
+SwinMoE: Multi-Dataset FloodNet Training - SAR-Boosted Configuration
+MMSeg 1.x Version
+
+Usage:
+ python tools/train.py configs/floodnet/multimodal_floodnet_sar_boost_swin_moe_config.py \
+ --work-dir work_dirs/swin_moe_upernet/SAR_Boost/ --seed 42
+"""
+
+_base_ = [
+ '../_base_/default_runtime.py',
+]
+
+# ==================== Dataset Definition ====================
+DATASETS_CONFIG = dict(
+ names=['sar', 'rgb', 'GF'],
+)
+
+# ==================== Modal Definition ====================
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar',
+ 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb',
+ 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF',
+ 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+# ==================== MoE Config ====================
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [
+ [],
+ [1],
+ [1, 3, 5, 7, 9, 11, 13, 15, 17],
+ [0, 1]
+]
+
+num_shared_experts_config = {
+ 0: 0,
+ 1: 0,
+ 2: 2,
+ 3: 1
+}
+
+num_experts = 8
+top_k = 3
+noisy_gating = True
+
+# ==================== Model Config ====================
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor',
+ pad_val=0,
+ seg_pad_val=255,
+ size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True,
+ use_modal_bias=True,
+ moe_balance_weight=1.0,
+ moe_diversity_weight=0.1,
+ multi_tasks_reweight=None,
+ decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+
+ pretrain_img_size=224,
+ patch_size=4,
+ embed_dims=96,
+ depths=depths,
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.2,
+ patch_norm=True,
+ out_indices=[0, 1, 2, 3],
+
+ use_moe=True,
+ num_experts=8,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=3,
+ noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds,
+ use_expert_diversity_loss=True,
+
+ pretrained=None,
+ ),
+
+ decode_head=dict(
+ type='UPerHead',
+ in_channels=[96, 192, 384, 768],
+ in_index=[0, 1, 2, 3],
+ pool_scales=(1, 2, 3, 6),
+ channels=512,
+ dropout_ratio=0.1,
+ num_classes=num_classes,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0
+ )
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=384,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=num_classes,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=0.4
+ )
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole')
+)
+
+# ==================== Dataset Config ====================
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/SAR'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='Resize', scale=(256, 256), keep_ratio=False),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(256, 256), keep_ratio=False),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(
+ type='FixedRatioModalSampler',
+ modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5},
+ modal_order=['sar', 'rgb', 'GF'],
+ reference_modal='sar',
+ batch_size=16,
+ ),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='train/images',
+ seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='val/images',
+ seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='test/images',
+ seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+# ==================== Optimizer Config ====================
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW',
+ lr=0.00006,
+ betas=(0.9, 0.999),
+ weight_decay=0.01),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'patch_embed': dict(lr_mult=2.0),
+ 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.),
+ 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5),
+ 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ }
+ )
+)
+
+# ==================== Learning Rate Config ====================
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1e-6,
+ by_epoch=True,
+ begin=0,
+ end=5),
+ dict(
+ type='PolyLR',
+ eta_min=0.0,
+ power=1.0,
+ begin=5,
+ end=200,
+ by_epoch=True),
+]
+
+# ==================== Training Loop ====================
+train_cfg = dict(
+ type='EpochBasedTrainLoop',
+ max_epochs=100,
+ val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+# ==================== Hooks ====================
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(
+ type='CheckpointHook',
+ by_epoch=True,
+ interval=10,
+ max_keep_ckpts=1,
+ save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+# ==================== Runtime Overrides ====================
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py b/configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py
new file mode 100644
index 0000000000..0109a0a30d
--- /dev/null
+++ b/configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py
@@ -0,0 +1,305 @@
+"""
+Swin-Base + MoE: Multi-Dataset FloodNet Training - SAR-Boosted Configuration
+MMSeg 1.x Version
+
+Based on Swin-uavflood-256x256 (Swin-Base) upgraded with MoE.
+Backbone: Swin-Base (embed_dims=128) + MoE (8 experts, top_k=3)
+Estimated params: ~456M (vs Swin-T+MoE ~278M, vs Swin-B ~109M)
+
+Usage:
+ python tools/train.py configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py \
+ --work-dir work_dirs/swinbase_moe_upernet/SAR_Boost/ --seed 42
+"""
+
+_base_ = [
+ '../_base_/default_runtime.py',
+]
+
+# ==================== Dataset Definition ====================
+DATASETS_CONFIG = dict(
+ names=['sar', 'rgb', 'GF'],
+)
+
+# ==================== Modal Definition ====================
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar',
+ 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb',
+ 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF',
+ 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']
+num_classes = 2
+
+# ==================== MoE Config ====================
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [
+ [],
+ [1],
+ [1, 3, 5, 7, 9, 11, 13, 15, 17],
+ [0, 1]
+]
+
+num_shared_experts_config = {
+ 0: 0,
+ 1: 0,
+ 2: 2,
+ 3: 1
+}
+
+num_experts = 8
+top_k = 3
+noisy_gating = True
+
+# ==================== Model Config ====================
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor',
+ pad_val=0,
+ seg_pad_val=255,
+ size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True,
+ use_modal_bias=True,
+ moe_balance_weight=1.0,
+ moe_diversity_weight=0.1,
+ multi_tasks_reweight=None,
+ decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+
+ pretrain_img_size=224,
+ patch_size=4,
+ # ---- Swin-Base dimensions ----
+ embed_dims=128,
+ depths=depths,
+ num_heads=[4, 8, 16, 32],
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.3, # Swin-B uses 0.3 (vs Swin-T 0.2)
+ patch_norm=True,
+ out_indices=[0, 1, 2, 3],
+
+ # ---- MoE ----
+ use_moe=True,
+ num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k,
+ noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds,
+ use_expert_diversity_loss=True,
+
+ pretrained=None,
+ ),
+
+ # ---- Swin-Base output channels: [128, 256, 512, 1024] ----
+ decode_head=dict(
+ type='UPerHead',
+ in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3],
+ pool_scales=(1, 2, 3, 6),
+ channels=512,
+ dropout_ratio=0.1,
+ num_classes=num_classes,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0
+ )
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=512, # Swin-B stage2 output
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=num_classes,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=0.4
+ )
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+# ==================== Dataset Config ====================
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(
+ type='FixedRatioModalSampler',
+ modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5},
+ modal_order=['sar', 'rgb', 'GF'],
+ reference_modal='GF',
+ batch_size=16,
+ ),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='train/images',
+ seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='val/images',
+ seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='test/images',
+ seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+# ==================== Optimizer Config ====================
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW',
+ lr=0.00006,
+ betas=(0.9, 0.999),
+ weight_decay=0.01),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'patch_embed': dict(lr_mult=2.0),
+ 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.),
+ 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5),
+ 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ }
+ )
+)
+
+# ==================== Learning Rate Config ====================
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1e-6,
+ by_epoch=True,
+ begin=0,
+ end=5),
+ dict(
+ type='PolyLR',
+ eta_min=0.0,
+ power=1.0,
+ begin=5,
+ end=100,
+ by_epoch=True),
+]
+
+# ==================== Training Loop ====================
+train_cfg = dict(
+ type='EpochBasedTrainLoop',
+ max_epochs=100,
+ val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+# ==================== Hooks ====================
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(
+ type='CheckpointHook',
+ by_epoch=True,
+ interval=10,
+ max_keep_ckpts=1,
+ save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+# ==================== Runtime Overrides ====================
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/floodnet/multimodal_floodnet_sar_only_swinbase_moe_config.py b/configs/floodnet/multimodal_floodnet_sar_only_swinbase_moe_config.py
new file mode 100644
index 0000000000..d1db87be7c
--- /dev/null
+++ b/configs/floodnet/multimodal_floodnet_sar_only_swinbase_moe_config.py
@@ -0,0 +1,300 @@
+"""
+Swin-Base + MoE: SAR-Only FloodNet Training Configuration
+MMSeg 1.x Version
+
+Only trains on SAR modality (100 epochs).
+
+Usage:
+ python tools/train.py configs/floodnet/multimodal_floodnet_sar_only_swinbase_moe_config.py \
+    --work-dir work_dirs/floodnet/SwinmoeB_sar_only/ --cfg-options randomness.seed=42
+"""
+
+_base_ = [
+ '../_base_/default_runtime.py',
+]
+
+# ==================== Dataset Definition ====================
+DATASETS_CONFIG = dict(
+ names=['sar', 'rgb', 'GF'],
+)
+
+# ==================== Modal Definition ====================
+ALL_KNOWN_MODALS = {
+ 'sar': {'channels': 8, 'pattern': 'sar',
+ 'description': 'Synthetic Aperture Radar'},
+ 'rgb': {'channels': 3, 'pattern': 'rgb',
+ 'description': 'RGB Optical'},
+ 'GF': {'channels': 5, 'pattern': 'GF',
+ 'description': 'GaoFen Satellite'},
+}
+
+TRAINING_MODALS = ['sar', 'rgb', 'GF']  # NOTE(review): all three stems are built even though filter_modality='sar' loads only SAR batches - confirm this is intended (e.g. for checkpoint shape compatibility)
+num_classes = 2
+
+# ==================== MoE Config ====================
+depths = [2, 2, 18, 2]
+MoE_Block_inds = [
+ [],
+ [1],
+ [1, 3, 5, 7, 9, 11, 13, 15, 17],
+ [0, 1]
+]
+
+num_shared_experts_config = {
+ 0: 0,
+ 1: 0,
+ 2: 2,
+ 3: 1
+}
+
+num_experts = 8
+top_k = 3
+noisy_gating = True
+
+# ==================== Model Config ====================
+norm_cfg = dict(type='BN', requires_grad=True)
+crop_size = (256, 256)
+
+data_preprocessor = dict(
+ type='MultiModalDataPreProcessor',
+ pad_val=0,
+ seg_pad_val=255,
+ size=crop_size,
+)
+
+model = dict(
+ type='MultiModalEncoderDecoderV2',
+ data_preprocessor=data_preprocessor,
+ use_moe=True,
+ use_modal_bias=True,
+ moe_balance_weight=1.0,
+ moe_diversity_weight=0.1,
+ multi_tasks_reweight=None,
+ decoder_mode='separate',
+ dataset_names=DATASETS_CONFIG['names'],
+
+ backbone=dict(
+ type='MultiModalSwinMoE',
+ modal_configs=ALL_KNOWN_MODALS,
+ training_modals=TRAINING_MODALS,
+
+ pretrain_img_size=224,
+ patch_size=4,
+ # ---- Swin-Base dimensions ----
+ embed_dims=128,
+ depths=depths,
+ num_heads=[4, 8, 16, 32],
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.3,
+ patch_norm=True,
+ out_indices=[0, 1, 2, 3],
+
+ # ---- MoE ----
+ use_moe=True,
+ num_experts=num_experts,
+ num_shared_experts_config=num_shared_experts_config,
+ top_k=top_k,
+ noisy_gating=noisy_gating,
+ MoE_Block_inds=MoE_Block_inds,
+ use_expert_diversity_loss=True,
+
+ pretrained=None,
+ ),
+
+ # ---- Swin-Base output channels: [128, 256, 512, 1024] ----
+ decode_head=dict(
+ type='UPerHead',
+ in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3],
+ pool_scales=(1, 2, 3, 6),
+ channels=512,
+ dropout_ratio=0.1,
+ num_classes=num_classes,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0
+ )
+ ),
+
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=512,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=num_classes,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=0.4
+ )
+ ),
+
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
+)
+
+# ==================== Dataset Config ====================
+dataset_type = 'MultiModalDeepflood'
+data_root = '../floodnet/data/mixed_dataset/'
+
+train_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='GenerateBoundary', thickness=3),
+ dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='MultiModalNormalize'),
+ dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+test_pipeline = [
+ dict(type='LoadMultiModalImageFromFile', to_float32=True),
+ dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
+ dict(type='MultiModalNormalize'),
+ dict(type='LoadAnnotations', reduce_zero_label=False),
+ dict(type='PackMultiModalSegInputs',
+ meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'modal_type', 'actual_channels', 'dataset_name',
+ 'img_norm_cfg', 'reduce_zero_label')),
+]
+
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ filter_modality='sar',
+ data_prefix=dict(
+ img_path='train/images',
+ seg_map_path='train/labels'),
+ pipeline=train_pipeline),
+)
+
+val_dataloader = dict(
+ batch_size=16,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ filter_modality='sar',
+ data_prefix=dict(
+ img_path='val/images',
+ seg_map_path='val/labels'),
+ pipeline=test_pipeline),
+)
+
+test_dataloader = dict(
+ batch_size=16,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ filter_modality='sar',
+ data_prefix=dict(
+ img_path='test/images',
+ seg_map_path='test/labels'),
+ pipeline=test_pipeline),
+)
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
+
+# ==================== Optimizer Config ====================
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW',
+ lr=0.00006,
+ betas=(0.9, 0.999),
+ weight_decay=0.01),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'patch_embed': dict(lr_mult=2.0),
+ 'modal_patch_embeds': dict(lr_mult=2.0),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.),
+ 'gating': dict(lr_mult=2.0),
+ 'experts': dict(lr_mult=1.5),
+ 'modal_bias': dict(lr_mult=3.0),
+ 'shared_experts': dict(lr_mult=2.0),
+ }
+ )
+)
+
+# ==================== Learning Rate Config ====================
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1e-6,
+ by_epoch=True,
+ begin=0,
+ end=5),
+ dict(
+ type='PolyLR',
+ eta_min=0.0,
+ power=1.0,
+ begin=5,
+ end=100,
+ by_epoch=True),
+]
+
+# ==================== Training Loop ====================
+train_cfg = dict(
+ type='EpochBasedTrainLoop',
+ max_epochs=100,
+ val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+# ==================== Hooks ====================
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(
+ type='CheckpointHook',
+ by_epoch=True,
+ interval=10,
+ max_keep_ckpts=1,
+ save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'),
+)
+
+# ==================== Runtime Overrides ====================
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+log_processor = dict(by_epoch=True)
diff --git a/configs/mae/mae-base-Uavflood.py b/configs/mae/mae-base-Uavflood.py
new file mode 100644
index 0000000000..828fd929f1
--- /dev/null
+++ b/configs/mae/mae-base-Uavflood.py
@@ -0,0 +1,53 @@
# UPerNet with an MAE-pretrained ViT-Base backbone on the UAVflood dataset
# (binary flood / non-flood segmentation), 20k-iteration schedule.
_base_ = [
    '../_base_/models/upernet_mae.py', '../_base_/datasets/UAVflood.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_20k.py'
]
crop_size = (256, 256)
data_preprocessor = dict(size=crop_size)
model = dict(
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type='MAE',
        img_size=(256, 256),
        patch_size=16,
        embed_dims=768,  # ViT-Base dimensions
        num_layers=12,
        num_heads=12,
        mlp_ratio=4,
        init_values=1.0,
        drop_path_rate=0.1,
        out_indices=[3, 5, 7, 11]),  # four intermediate levels feed UPerNet
    neck=dict(embed_dim=768, rescales=[4, 2, 1, 0.5]),
    decode_head=dict(
        in_channels=[768, 768, 768, 768], num_classes=2, channels=768),
    auxiliary_head=dict(in_channels=768, num_classes=2),
    # Sliding-window inference; stride < crop size yields overlapping windows.
    test_cfg=dict(mode='slide', crop_size=(256, 256), stride=(170, 170)))

# Layer-wise LR decay (0.65 per transformer layer), the standard recipe for
# MAE fine-tuning; replaces the base schedule's optimizer entirely.
optim_wrapper = dict(
    _delete_=True,
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW', lr=1e-4, betas=(0.9, 0.999), weight_decay=0.05),
    paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65),
    constructor='LayerDecayOptimizerConstructor')

# 1500-iteration linear warmup, then linear decay to 0 at 20k iterations.
param_scheduler = [
    dict(
        type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
    dict(
        type='PolyLR',
        eta_min=0.0,
        power=1.0,
        begin=1500,
        end=20000,
        by_epoch=False,
    )
]

# mixed precision
fp16 = dict(loss_scale='dynamic')

# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=8)
val_dataloader = dict(batch_size=8)
test_dataloader = val_dataloader
diff --git a/configs/segformer/segformer_mit-b0_8xb1-160k_UAVflood-256x256.py b/configs/segformer/segformer_mit-b0_8xb1-160k_UAVflood-256x256.py
new file mode 100644
index 0000000000..0d42f170be
--- /dev/null
+++ b/configs/segformer/segformer_mit-b0_8xb1-160k_UAVflood-256x256.py
@@ -0,0 +1,73 @@
# SegFormer MiT-B0 on the UAVflood dataset, 100-epoch schedule.
_base_ = [
    '../_base_/models/segformer_mit-b0.py',
    '../_base_/datasets/UAVflood.py',
    '../_base_/default_runtime.py'
]
crop_size = (256, 256)
data_preprocessor = dict(size=crop_size)

model = dict(
    data_preprocessor=data_preprocessor,
    decode_head=dict(
        num_classes=2  # set to the number of classes you need
    )
)

randomness = dict(
    seed=42,
    deterministic=False,  # set True if full reproducibility is required
)

# AdamW; positional blocks and norm layers are excluded from weight decay,
# and the randomly initialised head trains with a 10x learning rate.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
    paramwise_cfg=dict(
        custom_keys={
            'pos_block': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=10.)
        }))

param_scheduler = [
    dict(
        type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),  # warmup over the first 5 epochs
    dict(
        type='PolyLR',
        eta_min=0.0,
        power=1.0,  # power=1.0 -> linear decay
        begin=5,
        end=100,
        by_epoch=True,
    )
]

# Train for 100 epochs with EpochBasedTrainLoop.
train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=100,
    val_interval=10)  # validate every 10 epochs

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# Epoch-based hooks.
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=200, log_metric_by_epoch=True),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook',
        by_epoch=True,
        interval=10,  # save every 10 epochs
        max_keep_ckpts=3,
        save_best='mIoU'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

# Show logs by epoch.
log_processor = dict(by_epoch=True)

train_dataloader = dict(batch_size=8, num_workers=8)
val_dataloader = dict(batch_size=8, num_workers=8)
test_dataloader = val_dataloader
diff --git a/configs/swin/Swin-uavflood-256x256.py b/configs/swin/Swin-uavflood-256x256.py
new file mode 100644
index 0000000000..2826cfc2e0
--- /dev/null
+++ b/configs/swin/Swin-uavflood-256x256.py
@@ -0,0 +1,154 @@
# UPerNet + Swin-Base for binary flood segmentation on 5-channel imagery.
# NOTE(review): the filename says "uavflood" but _base_ pulls GFflood.py and
# the preprocessor uses 5-band statistics -- confirm the intended dataset.
_base_ = [
    '../_base_/datasets/GFflood.py',
    '../_base_/default_runtime.py',
    '../_base_/models/upernet_swin.py'
]

# ===== Basic settings =====
norm_cfg = dict(type='SyncBN', requires_grad=True)
backbone_norm_cfg = dict(type='LN', requires_grad=True)
crop_size = (256, 256)

# ===== Data preprocessor =====
data_preprocessor = dict(
    type='SegDataPreProcessor',
    size=crop_size,
    # Commented-out alternatives are per-band statistics for other sensor
    # configurations (3-band RGB, 8-band SAR); the active values are the
    # 5-band statistics matching in_channels=5 below.
    #mean=[117.926186, 117.568402, 97.217239],
    #std=[53.542876104049824, 50.084170325219176, 50.49331035114637],
    mean=[432.02181, 315.92948, 246.468659, 310.61462, 360.267789],
    std=[97.73313111900238, 85.78646917160748, 95.78015824658593, 124.84677067613467, 251.73965882246978],
    #mean=[0.23651549, 0.31761484, 0.18514981, 0.26901252, -14.57879175, -8.6098158, -14.2907338, -8.33534564],
    #std=[0.16280619, 0.20849304, 0.14008107, 0.19767644, 4.07141682, 3.94773216, 4.21006244, 4.05494136],
    bgr_to_rgb=False,  # input is not 3-channel BGR imagery
    pad_val=0,
    seg_pad_val=255
)

# ===== Model config - Swin-Base =====
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    pretrained=None,
    backbone=dict(
        type='SwinTransformer',
        pretrain_img_size=224,
        embed_dims=128,  # Base: 128
        in_channels=5,
        patch_size=4,
        window_size=7,
        mlp_ratio=4,
        depths=[2, 2, 18, 2],  # Base: [2, 2, 18, 2]
        num_heads=[4, 8, 16, 32],  # Base: [4, 8, 16, 32]
        strides=(4, 2, 2, 2),
        out_indices=(0, 1, 2, 3),
        qkv_bias=True,
        qk_scale=None,
        patch_norm=True,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.3,
        use_abs_pos_embed=False,
        act_cfg=dict(type='GELU'),
        norm_cfg=backbone_norm_cfg
    ),
    decode_head=dict(
        type='UPerHead',
        in_channels=[128, 256, 512, 1024],  # Swin-Base stage output channels
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=512,
        dropout_ratio=0.1,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
    ),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=512,  # Swin-Base stage-2 output
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
    ),
    train_cfg=dict(),
    test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170))
)

# ===== Optimizer =====
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW',
        lr=0.00006,
        betas=(0.9, 0.999),
        weight_decay=0.01
    ),
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }
    )
)

# ===== Learning-rate schedule =====
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-6,
        by_epoch=True,
        begin=0,
        end=5  # warmup over the first 5 epochs
    ),
    dict(
        type='PolyLR',
        eta_min=0.0,
        power=1.0,
        begin=5,
        end=100,
        by_epoch=True,
    )
]

# Train for 100 epochs with EpochBasedTrainLoop.
train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=100,
    val_interval=10)  # validate every 10 epochs

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# Epoch-based hooks.
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=1, log_metric_by_epoch=True),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook',
        by_epoch=True,
        interval=10,  # save every 10 epochs
        max_keep_ckpts=3,
        save_best='mIoU'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

# Show logs by epoch.
log_processor = dict(by_epoch=True)

# ===== Data loaders =====
train_dataloader = dict(batch_size=8, num_workers=8)  # Base model is large; batch size kept modest
val_dataloader = dict(batch_size=8, num_workers=8)
test_dataloader = val_dataloader

# ===== Random seed =====
randomness = dict(seed=42, deterministic=False)
diff --git a/configs/unet/Unet-Uavflood.py b/configs/unet/Unet-Uavflood.py
new file mode 100644
index 0000000000..2b7d902c71
--- /dev/null
+++ b/configs/unet/Unet-Uavflood.py
@@ -0,0 +1,66 @@
# DeepLabV3-style UNet (s5-d16) on the SARflood dataset, 100-epoch schedule.
# NOTE(review): filename says "Uavflood" but the dataset base is SARflood.py
# -- confirm which dataset is intended.
_base_ = [
    '../_base_/models/deeplabv3_unet_s5-d16.py',
    '../_base_/datasets/SARflood.py',
    '../_base_/default_runtime.py'
]

crop_size = (256, 256)
data_preprocessor = dict(size=crop_size)
model = dict(
    data_preprocessor=data_preprocessor,
    decode_head=dict(num_classes=2),
    auxiliary_head=dict(num_classes=2),
    # Sliding-window inference with overlapping 256x256 windows.
    test_cfg=dict(mode='slide', crop_size=(256, 256), stride=(170, 170)))

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005))

param_scheduler = [
    dict(
        type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),  # warmup over the first 5 epochs
    dict(
        type='PolyLR',
        eta_min=1e-4,
        power=0.9,
        begin=5,
        end=100,
        by_epoch=True,
    )
]

# Train for 100 epochs with EpochBasedTrainLoop.
train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=100,
    val_interval=10)  # validate every 10 epochs

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# Epoch-based hooks.
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=1, log_metric_by_epoch=True),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook',
        by_epoch=True,
        interval=10,  # save every 10 epochs
        max_keep_ckpts=3,
        save_best='mIoU'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

randomness = dict(
    seed=42,
    deterministic=False,  # set True if full reproducibility is required
)

# Show logs by epoch.
log_processor = dict(by_epoch=True)

train_dataloader = dict(batch_size=8, num_workers=8)
val_dataloader = dict(batch_size=8, num_workers=8)
test_dataloader = val_dataloader
\ No newline at end of file
diff --git a/configs/vit/vit-Uavflood.py b/configs/vit/vit-Uavflood.py
new file mode 100644
index 0000000000..24da97f435
--- /dev/null
+++ b/configs/vit/vit-Uavflood.py
@@ -0,0 +1,77 @@
# UPerNet + ViT-B/16 on the SARflood dataset, 100-epoch schedule.
# NOTE(review): filename says "Uavflood" but the dataset base is SARflood.py
# -- confirm which dataset is intended.
_base_ = [
    '../_base_/models/upernet_vit-b16_ln_mln.py',
    '../_base_/datasets/SARflood.py',
    '../_base_/default_runtime.py'
]

crop_size = (256, 256)
data_preprocessor = dict(size=crop_size)
model = dict(
    data_preprocessor=data_preprocessor,
    backbone=dict(
        img_size=(256, 256),
        drop_path_rate=0.1,
        final_norm=True),
    decode_head=dict(num_classes=2),
    auxiliary_head=dict(num_classes=2))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
    paramwise_cfg=dict(
        custom_keys={
            'pos_embed': dict(decay_mult=0.),
            'cls_token': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

param_scheduler = [
    dict(
        type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5),  # warmup over the first 5 epochs
    dict(
        type='PolyLR',
        eta_min=0.0,
        power=1.0,
        begin=5,
        end=100,
        by_epoch=True,
    )
]

# Train for 100 epochs with EpochBasedTrainLoop.
train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=100,
    val_interval=10)  # validate every 10 epochs

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# Epoch-based hooks.
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=1, log_metric_by_epoch=True),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook',
        by_epoch=True,
        interval=10,  # save every 10 epochs
        max_keep_ckpts=3,
        save_best='mIoU'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

randomness = dict(
    seed=42,
    deterministic=False,  # set True if full reproducibility is required
)

# Show logs by epoch.
log_processor = dict(by_epoch=True)

train_dataloader = dict(batch_size=8, num_workers=8)
val_dataloader = dict(batch_size=8, num_workers=8)
test_dataloader = val_dataloader
\ No newline at end of file
diff --git a/cropSAR.py b/cropSAR.py
new file mode 100644
index 0000000000..a8f4c788ea
--- /dev/null
+++ b/cropSAR.py
@@ -0,0 +1,124 @@
+import os
+import numpy as np
+from PIL import Image
+import rasterio
+from rasterio.transform import from_bounds
+import random
+
+
+
# Input/output locations for the urban SAR floods "03_FU" scene.
gt_folder = '/mnt/d//Data/DLdata/urban_sar_floods/urban_sar_floods/03_FU/GT'  # label folder
sar_folder = '/mnt/d//Data/DLdata/urban_sar_floods/urban_sar_floods/03_FU/SAR'  # image folder
output_folder = '/mnt/d//Data/DLdata/urban_sar_floods/urban_sar_floods/03_FU/output'  # output folder

# Create the output directory tree: {train,val,test}/{images,labels}.
for subset in ['train', 'val', 'test']:
    os.makedirs(os.path.join(output_folder, subset, 'images'), exist_ok=True)
    os.makedirs(os.path.join(output_folder, subset, 'labels'), exist_ok=True)
+
+
# Tiling helper; handles both multi-band and single-band arrays.
def crop_image(img_array, crop_size=256):
    """Tile an image array into non-overlapping crop_size x crop_size patches.

    Accepts either a multi-band array shaped (bands, height, width) or a
    single-band array shaped (height, width). Partial tiles at the right and
    bottom edges are discarded. Patches come back in row-major scan order.
    """
    if len(img_array.shape) == 3:  # multi-band: (bands, H, W)
        _, height, width = img_array.shape
        window = lambda r, c: img_array[:, r:r + crop_size, c:c + crop_size]
    else:  # single-band: (H, W)
        height, width = img_array.shape
        window = lambda r, c: img_array[r:r + crop_size, c:c + crop_size]
    return [
        window(row, col)
        for row in range(0, height, crop_size)
        for col in range(0, width, crop_size)
        if window(row, col).shape[-2:] == (crop_size, crop_size)
    ]
+
+
# Collect all GT/SAR files; the pairing below relies on identical sort order.
gt_files = sorted([f for f in os.listdir(gt_folder) if f.endswith('.tif')])
sar_files = sorted([f for f in os.listdir(sar_folder) if f.endswith('.tif')])

# Generate all crops.
# NOTE(review): zip() assumes the n-th GT file corresponds to the n-th SAR
# file after sorting -- confirm the two folders use matching filenames.
all_crops = []
for gt_file, sar_file in zip(gt_files, sar_files):
    gt_path = os.path.join(gt_folder, gt_file)
    sar_path = os.path.join(sar_folder, sar_file)

    # Read the label (single band).
    with Image.open(gt_path) as gt_img:
        gt_array = np.array(gt_img)

    # Read the SAR image (multi-band).
    with rasterio.open(sar_path) as sar_src:
        sar_array = sar_src.read()  # all bands, shaped (bands, height, width)
        sar_profile = sar_src.profile

    # Crop both rasters; crop_image yields tiles in the same scan order,
    # so gt_crops[i] and sar_crops[i] cover the same window.
    gt_crops = crop_image(gt_array)
    sar_crops = crop_image(sar_array)

    # Record each crop together with the source raster profile.
    base_name = os.path.splitext(gt_file)[0]
    for idx, (gt_crop, sar_crop) in enumerate(zip(gt_crops, sar_crops)):
        all_crops.append({
            'gt': gt_crop,
            'sar': sar_crop,
            'name': f'{base_name}_{idx}',
            'profile': sar_profile
        })

print(f'共生成 {len(all_crops)} 个裁剪块')

# Shuffle before splitting.
# NOTE(review): no random.seed() is set, so the train/val/test split differs
# on every run -- consider seeding for reproducibility.
random.shuffle(all_crops)

# Split 6:2:2 into train/val/test.
total = len(all_crops)
train_end = int(total * 0.6)
val_end = int(total * 0.8)

train_crops = all_crops[:train_end]
val_crops = all_crops[train_end:val_end]
test_crops = all_crops[val_end:]


def save_crops(crops, subset):
    """Write one subset to disk: SAR tiles as GeoTIFF, labels as PNG."""
    for crop_data in crops:
        # Save the multi-band SAR crop as TIF.
        sar_path = os.path.join(output_folder, subset, 'images', f"sar_{crop_data['name']}.tif")

        # Update the raster profile for the 256x256 tile.
        # NOTE(review): from_bounds(0, 0, 256, 256, 256, 256) assigns a dummy
        # unit-pixel georeference and discards the original geolocation --
        # confirm downstream training does not need real coordinates.
        profile = crop_data['profile'].copy()
        profile.update({
            'height': 256,
            'width': 256,
            'transform': from_bounds(0, 0, 256, 256, 256, 256)
        })

        with rasterio.open(sar_path, 'w', **profile) as dst:
            dst.write(crop_data['sar'])

        # Relabel: collapse class value 2 into 1 (binary flood mask).
        gt_array = crop_data['gt'].copy()
        gt_array[gt_array == 2] = 1

        # Save the label as PNG.
        gt_img = Image.fromarray(gt_array.astype(np.uint8))
        gt_path = os.path.join(output_folder, subset, 'labels', f"sar_{crop_data['name']}.png")
        gt_img.save(gt_path)

    print(f'{subset}: {len(crops)} 张图像')


# Write all three subsets.
save_crops(train_crops, 'train')
save_crops(val_crops, 'val')
save_crops(test_crops, 'test')

print(f'\n总共处理了 {total} 张裁剪图像')
print(f'train: {len(train_crops)}, val: {len(val_crops)}, test: {len(test_crops)}')
\ No newline at end of file
diff --git a/dealNaN.py b/dealNaN.py
new file mode 100644
index 0000000000..0d83c44dc7
--- /dev/null
+++ b/dealNaN.py
@@ -0,0 +1,143 @@
+import os
+import numpy as np
+import rasterio
+from shutil import copy2
+from tqdm import tqdm
+
+
def replace_nan_with_zero(folder_path, backup=False, backup_suffix='_backup'):
    """Walk *folder_path* and rewrite, in place, every TIF containing NaN,
    replacing the NaN pixels with 0.

    Two-phase run: first scan all TIFs and collect those that contain NaN,
    then (after interactive confirmation) rewrite only the affected files
    through a temporary file that replaces the original atomically.

    Args:
        folder_path: Root directory to process (searched recursively).
        backup: If True, copy each original to ``<name><backup_suffix>``
            before rewriting it.
        backup_suffix: Suffix appended to backup file names.
    """
    # str.endswith accepts a tuple, so one call covers all extensions.
    tif_extensions = ('.tif', '.tiff', '.TIF', '.TIFF')
    files_with_nan = 0
    files_processed = 0

    # Collect every TIF under folder_path.
    tif_files = []
    for root, dirs, files in os.walk(folder_path):
        for file in files:
            if file.endswith(tif_extensions):
                tif_files.append(os.path.join(root, file))

    total_files = len(tif_files)
    print(f"找到 {total_files} 个TIF文件")
    print("=" * 80)

    # Phase 1: find the files that actually contain NaN.
    # (Fix: the original wrapped tqdm in a bare `try/except:` fallback, but
    # tqdm is a hard module-level import -- the fallback was dead code and
    # the bare except an anti-pattern, so both were removed.)
    print("\n步骤1: 检查NaN值...")
    files_to_process = []
    for file_path in tqdm(tif_files, desc="扫描文件"):
        try:
            with rasterio.open(file_path) as src:
                # any() short-circuits at the first band containing NaN.
                has_nan = any(
                    np.isnan(src.read(band_idx)).any()
                    for band_idx in range(1, src.count + 1))
            if has_nan:
                files_with_nan += 1
                files_to_process.append(file_path)
        except Exception as e:
            print(f"\n读取错误: {file_path}")
            print(f"  错误: {e}")

    print(f"\n检查完成: {files_with_nan} 个文件包含NaN值")

    if files_with_nan == 0:
        print("没有需要处理的文件!")
        return

    # Phase 2: rewrite the affected files.
    print(f"\n步骤2: 替换NaN值为0...")
    print(f"需要处理 {len(files_to_process)} 个文件")

    if backup:
        print(f"将创建备份文件(后缀: {backup_suffix})")

    user_input = input("\n是否继续?(y/n): ")
    if user_input.lower() != 'y':
        print("操作已取消")
        return

    print("\n开始处理...")
    for file_path in tqdm(files_to_process, desc="处理文件"):
        temp_path = file_path + '.tmp'
        try:
            # Optional backup of the untouched original.
            if backup:
                copy2(file_path, file_path + backup_suffix)

            with rasterio.open(file_path) as src:
                meta = src.meta.copy()
                # Write the cleaned bands to a temp file while the source is
                # still open; the swap happens only after both handles close
                # (os.replace would fail on Windows with open handles).
                with rasterio.open(temp_path, 'w', **meta) as dst:
                    for band_idx in range(1, src.count + 1):
                        band_data = np.nan_to_num(src.read(band_idx), nan=0.0)
                        dst.write(band_data, band_idx)

            os.replace(temp_path, file_path)
            files_processed += 1

        except Exception as e:
            print(f"\n处理错误: {file_path}")
            print(f"  错误: {e}")
            # Clean up a partially written temp file.
            if os.path.exists(temp_path):
                os.remove(temp_path)

    # Summary.
    print("\n" + "=" * 80)
    print("处理完成!")
    print(f"总文件数: {total_files}")
    print(f"包含NaN的文件: {files_with_nan}")
    print(f"成功处理的文件: {files_processed}")

    if backup:
        print(f"\n原始文件已备份(后缀: {backup_suffix})")
        print("如果确认无误,可以删除备份文件")
+
+
if __name__ == "__main__":
    # Directory whose TIFs will be cleaned in place.
    folder_path = "/mnt/d/Project/Code/Floodnet/data/mixed_dataset/val/"

    # Run the replacement.
    # backup=True keeps a copy of each original file;
    # backup=False overwrites in place (not recommended).
    replace_nan_with_zero(folder_path, backup=False, backup_suffix='_backup')
\ No newline at end of file
diff --git a/delete8X256.py b/delete8X256.py
new file mode 100644
index 0000000000..25041c4fcd
--- /dev/null
+++ b/delete8X256.py
@@ -0,0 +1,45 @@
import os
from PIL import Image

# Root of the SAR dataset (contains train/val/test subfolders).
base_folder = '../Floodnet/data/mixed_dataset/SAR/'
subsets = ['train', 'val', 'test']

total_deleted = 0

# Remove every image that is not exactly 256x256, together with its label.
for subset in subsets:
    image_folder = os.path.join(base_folder, subset, 'images')
    label_folder = os.path.join(base_folder, subset, 'labels')

    if not os.path.exists(image_folder):
        print(f'{image_folder} 不存在,跳过')
        continue

    deleted_count = 0

    for filename in os.listdir(image_folder):
        if filename.endswith('.tif'):
            img_path = os.path.join(image_folder, filename)

            # Read only the image size (PIL decodes lazily, so this is cheap).
            with Image.open(img_path) as img:
                width, height = img.size

            # Delete the image and its label when the size is not 256x256.
            if width != 256 or height != 256:
                os.remove(img_path)

                # Fix: swap only the extension -- str.replace('.tif', '.png')
                # would hit every '.tif' occurrence inside the name.
                label_filename = os.path.splitext(filename)[0] + '.png'
                label_path = os.path.join(label_folder, label_filename)
                if os.path.exists(label_path):
                    os.remove(label_path)

                deleted_count += 1
                # Fix: report the actual filename instead of the literal
                # '(unknown)' placeholder left in the original print.
                print(f'[{subset}] 已删除: {filename} (尺寸: {width}x{height})')

    print(f'{subset} 删除了 {deleted_count} 对文件\n')
    total_deleted += deleted_count

print(f'总共删除了 {total_deleted} 对文件')
\ No newline at end of file
diff --git a/docs/experiment_plan.md b/docs/experiment_plan.md
new file mode 100644
index 0000000000..0364909821
--- /dev/null
+++ b/docs/experiment_plan.md
@@ -0,0 +1,188 @@
+# FloodNet Paper Experiment Plan
+## Multi-Modal Flood Segmentation via Mixture-of-Experts with Modal-Aware Routing
+
+---
+
+## 1. Model Overview (Final Model)
+
+**Config**: `multimodal_floodnet_sar_boost_swinbase_moe_config.py`
+
+| Component | Configuration |
+|-----------|--------------|
+| Backbone | Swin-Base (embed_dims=128, depths=[2,2,18,2]) |
+| Patch Embed | ModalSpecificStem (独立卷积 per modality) |
+| MoE | 8 experts, top_k=3, noisy gating |
+| Shared Experts | Stage 2: 2, Stage 3: 1 |
+| Modal Bias | Learnable per-modal routing bias |
+| Diversity Loss | weight=0.1; Balance Loss weight=1.0 |
+| Decoder | Separate UPerHead per modality |
+| Sampling | SAR Boost (6:5:5) |
+| Modalities | SAR (8ch), RGB (3ch), GaoFen (5ch) |
+
+---
+
+## 2. Experiment Design
+
+### Table 2: Component Ablation Study (组件消融)
+
+**目的**: 隔离每个组件的贡献,验证设计合理性。
+
+**评估**: 每个变体训练后,分别在 SAR / RGB / GF 三个模态上独立测试 mIoU。
+
+| Row | 配置 | 变更 | SAR | RGB | GF | Avg |
+|-----|------|------|-----|-----|-----|-----|
+| (a) | **Full Model** | — | - | - | - | - |
+| (b) | w/o MoE | 标准 FFN, 无专家路由 | - | - | - | - |
+| (c) | w/o ModalSpecificStem | 统一卷积+零填充 替代 模态独立卷积 | - | - | - | - |
+| (d) | w/o Modal Bias | MoE gating 无模态偏置 | - | - | - | - |
+| (e) | w/o Shared Experts | 所有专家均为路由专家,无共享 | - | - | - | - |
+| (f) | w/o Separate Decoder | 一个 UPerHead 共享所有模态 | - | - | - | - |
+
+**Config Files**:
+- (a) `multimodal_floodnet_sar_boost_swinbase_moe_config.py`
+- (b) `ablations/ablation_no_moe.py`
+- (c) `ablations/ablation_no_modal_specific_stem.py`
+- (d) `ablations/ablation_no_modal_bias.py`
+- (e) `ablations/ablation_no_shared_experts.py`
+- (f) `ablations/ablation_shared_decoder.py`
+
+**运行**:
+```bash
+bash scripts/run_all_experiments.sh table2
+```
+
+**讨论要点**:
+
+**(b) w/o MoE**: 预期最大下降。标准 FFN 无法适配不同模态的特征分布差异。所有 token 共享相同权重,无论来自 SAR 后向散射还是 RGB 反射率。
+
+**(c) w/o ModalSpecificStem**: 统一 patch embedding 将所有模态零填充到 max_channels(8) 再卷积,丢失了模态特异性的早期特征提取。SAR 的 8 通道数据被原样保留,但 RGB(3ch) 和 GF(5ch) 被大量零填充,引入噪声。
+
+**(d) w/o Modal Bias**: 去除模态路由偏置后,gating 网络对所有模态一视同仁。模态偏置允许模型学习到某些专家专门处理 SAR 的散斑噪声和后向散射强度 vs RGB 的颜色纹理。
+
+**(e) w/o Shared Experts**: 共享专家捕获模态不变特征(如水体空间结构、边界形态)。移除后,路由专家需冗余学习通用特征,降低专门化能力。
+
+**(f) w/o Separate Decoder**: 单一解码器须在不同模态特征分布间妥协。独立解码器允许每个模态有定制化的上采样和分类边界。
+
+---
+
+### Table 3: MoE Hyperparameter Study (MoE超参数研究)
+
+**目的**: 研究专家数量和 top-k 路由对性能的影响。
+
+**评估**: 每个变体训练后,分别在 SAR / RGB / GF 三个模态上独立测试 mIoU。
+
+| num_experts | top_k | SAR | RGB | GF | Avg |
+|-------------|-------|-----|-----|-----|-----|
+| 6 | 1 | - | - | - | - |
+| 6 | 2 | - | - | - | - |
+| 6 | 3 | - | - | - | - |
+| 8 | 1 | - | - | - | - |
+| 8 | 2 | - | - | - | - |
+| **8** | **3** | **-** | **-** | **-** | **-** |
+
+**Config Files**:
+- `ablations/ablation_e6_k1.py`
+- `ablations/ablation_e6_k2.py`
+- `ablations/ablation_e6_k3.py`
+- `ablations/ablation_e8_k1.py`
+- `ablations/ablation_e8_k2.py`
+- Full Model (E=8, K=3)
+
+**运行**:
+```bash
+bash scripts/run_all_experiments.sh table3
+```
+
+**讨论要点**:
+- **top_k=1**: 过于稀疏,每个 token 仅使用 1 个专家,限制了多专家集成效果
+- **top_k=2**: 中等稀疏度,基本满足 3 模态路由需求
+- **top_k=3 (ours)**: 允许 token 同时受益于多个专门化专家
+- **6 experts**: 对 3 种模态来说容量偏小,平均每模态仅 2 个专家
+- **8 experts (ours)**: 最佳平衡点,每模态约 2-3 个专家并允许跨模态共享
+
+---
+
+### Table 4: Single-Modal vs Multi-Modal Training (单模态 vs 多模态)
+
+**目的**: 证明多模态联合训练通过跨模态知识迁移提升了每个模态的性能。
+
+**评估**: 单模态训练只测试本模态;多模态训练测试所有三个模态。
+
+| 训练数据 | 测试模态 | mIoU |
+|----------|----------|------|
+| SAR-only | SAR | - |
+| RGB-only | RGB | - |
+| GF-only | GF | - |
+| Multi-modal (Ours) | SAR | - |
+| Multi-modal (Ours) | RGB | - |
+| Multi-modal (Ours) | GF | - |
+
+**Config Files**:
+- SAR-only → `multimodal_floodnet_sar_only_swinbase_moe_config.py`
+- RGB-only → `ablations/ablation_rgb_only.py`
+- GF-only → `ablations/ablation_gf_only.py`
+- Multi-modal → Full Model
+
+**运行**:
+```bash
+bash scripts/run_all_experiments.sh table4
+```
+
+**讨论要点**:
+- 多模态训练预期在**每个**模态上都优于单模态训练
+- SAR 受益最大:有限的 SAR 数据通过 RGB/GF 知识迁移得到增强
+- MoE 架构防止负迁移:模态专用专家避免了不同传感器类型间的干扰
+- 共享专家捕获通用洪水模式(水体几何、边界结构),在所有模态间迁移
+
+---
+
+## 3. Running Experiments
+
+```bash
+# 分表运行(推荐)
+bash scripts/run_all_experiments.sh table2 # 组件消融 (6 实验 × 3 模态测试)
+bash scripts/run_all_experiments.sh table3 # MoE 超参数 (6 实验 × 3 模态测试)
+bash scripts/run_all_experiments.sh table4 # 单/多模态 (4 实验)
+
+# 全部运行
+bash scripts/run_all_experiments.sh all
+
+# 指定 GPU
+GPU_IDS=0,1 bash scripts/run_all_experiments.sh table2
+```
+
+**结果目录结构**:
+```
+work_dirs/paper_experiments/
+├── results_summary.txt # 汇总日志
+├── table2/
+│ ├── full_model/
+│ │ ├── best_mIoU_*.pth
+│ │ ├── test_sar/test_log.txt
+│ │ ├── test_rgb/test_log.txt
+│ │ └── test_GF/test_log.txt
+│ ├── no_moe/
+│ ├── no_modal_specific_stem/
+│ ├── no_modal_bias/
+│ ├── no_shared_experts/
+│ └── shared_decoder/
+├── table3/
+│ ├── e6_k1/
+│ ├── e6_k2/
+│ ├── e6_k3/
+│ ├── e8_k1/
+│ └── e8_k2/
+└── table4/
+ ├── sar_only/
+ │ └── test_sar/test_log.txt
+ ├── rgb_only/
+ │ └── test_rgb/test_log.txt
+ ├── gf_only/
+ │ └── test_GF/test_log.txt
+ └── (multi_modal reuses table2/full_model)
+```
+
+**实验总量**: 14 个训练 + 42 次测试(与下方分解一致:训练 6+5+3,测试 18+18+6)
+- Table 2: 6 训练 × 3 模态测试 = 6 + 18
+- Table 3: 5 训练 × 3 模态测试 + 1 复用 = 5 + 18
+- Table 4: 3 训练 × 1 模态测试 + 1 复用 × 3 测试 = 3 + 6
diff --git a/docs/method_section_details.md b/docs/method_section_details.md
new file mode 100644
index 0000000000..b97fb905d6
--- /dev/null
+++ b/docs/method_section_details.md
@@ -0,0 +1,370 @@
+# 论文 Method 部分 — 技术内容详细提纲与草稿素材
+
+> 以下内容基于代码实现逐一提取,所有数字、公式、维度均与代码一致。
+> 供写作时参考,可直接改写为论文正文。
+
+---
+
+## 3.1 Overall Framework(整体框架)
+
+### 需要写的内容
+
+一段话概述整体流水线,配合 Fig.1(架构图)。
+
+### 技术事实
+
+整体流水线为:
+
+```
+多模态输入(SAR/RGB/GF,通道数各异)
+ → Modal-Specific Stem(模态独立 Patch Embedding)
+ → 4-Stage Swin Transformer(选定 block 中 FFN 替换为 Sparse MoE)
+ → Modality-Separate UPerNet Decoder(每种模态独立解码头)
+ → 二分类分割输出(flood / non-flood)
+```
+
+模型类:`MultiModalEncoderDecoderV2`。
+
+**输入处理**:`MultiModalDataPreProcessor` 不对多模态输入做通道堆叠(因为各模态通道数不同),而是保持为 `List[Tensor]` 传入骨干网络。每个 tensor 仅做空间维度 padding 到 crop_size (256×256)。
+
+**训练采样**:采用 `FixedRatioModalSampler` 按固定比例(SAR:RGB:GF = 6:5:5)组成每个 batch,以 GF 为参考模态确定 epoch 长度。这确保 SAR 数据(样本量最少)在每个 batch 中占 37.5%(6/16),高于其在数据集中的自然比例,缓解模态不均衡问题。
+
+---
+
+## 3.2 Modal-Specific Stem(模态特定嵌入层)
+
+### 需要写的内容
+
+解释为什么不能直接用标准 Patch Embedding,以及你的解决方案。
+
+### 技术事实
+
+**问题**:三种模态通道数不同(SAR=8, RGB=3, GF=5),标准 Swin Transformer 的 Patch Embedding 是一个固定输入通道的 Conv2d,无法处理可变通道输入。
+
+**朴素方案(UnifiedPatchEmbed,用于消融对比)**:将所有模态零填充到最大通道数(8),然后共享一个 `Conv2d(8, embed_dims, kernel_size=4, stride=4)`。缺点是 RGB(3ch→8ch)和 GF(5ch→8ch)引入大量零值,稀释有效信息。
+
+**本文方案(ModalSpecificPatchEmbed)**:为每种模态配置独立的投影卷积:
+
+```
+对于模态 m(通道数为 C_m):
+ Conv2d(C_m, embed_dims, kernel_size=patch_size, stride=patch_size) + LayerNorm(embed_dims)
+```
+
+具体实例化:
+- SAR: `Conv2d(8, 128, kernel_size=4, stride=4)` + `LayerNorm(128)`
+- RGB: `Conv2d(3, 128, kernel_size=4, stride=4)` + `LayerNorm(128)`
+- GF: `Conv2d(5, 128, kernel_size=4, stride=4)` + `LayerNorm(128)`
+
+输入图像经 patch embedding 后,空间分辨率缩小为 H/4 × W/4,特征维度统一为 embed_dims=128。之后所有模态在同一特征空间中处理,共享 Transformer 的注意力层。
+
+### 公式
+
+$$
+\mathbf{z}_0^{(m)} = \text{LN}\left(\text{Conv}_{m}(\mathbf{x}^{(m)})\right), \quad \mathbf{z}_0^{(m)} \in \mathbb{R}^{\frac{HW}{P^2} \times D}
+$$
+
+其中 $\text{Conv}_{m}: \mathbb{R}^{C_m \times H \times W} \to \mathbb{R}^{D \times \frac{H}{P} \times \frac{W}{P}}$ 为模态 $m$ 专属的卷积投影,$P=4$ 为 patch 大小,$D=128$ 为嵌入维度。
+
+---
+
+## 3.3 Swin Transformer Backbone with Sparse MoE(骨干网络)
+
+### 需要写的内容
+
+先简要介绍 Swin Transformer 基本结构(审稿人可能不熟悉),然后重点描述你在哪些位置插入了 MoE,以及为什么是稀疏放置。
+
+### 技术事实
+
+**Swin-Base 配置**:
+
+| 参数 | 值 |
+|------|-----|
+| embed_dims | 128 |
+| depths | [2, 2, 18, 2] |
+| num_heads | [4, 8, 16, 32] |
+| window_size | 7 |
+| mlp_ratio | 4.0 |
+| drop_path_rate | 0.3 |
+| 各 stage 输出通道 | [128, 256, 512, 1024] |
+
+**每个 Swin Block 的结构**(标准部分,简要提及即可):
+
+```
+x = x + DropPath(W-MSA(LN(x))) // 窗口多头自注意力
+x = x + DropPath(FFN(LN(x))) // 前馈网络(在选定 block 中替换为 MoE)
+```
+
+偶数 block 用 W-MSA,奇数 block 用 SW-MSA(shifted window)。
+
+**MoE 替换位置(Sparse Placement)**:
+
+并非所有 block 的 FFN 都替换为 MoE,而是选择性稀疏放置:
+
+| Stage | 总 block 数 | MoE block 索引 | MoE 数量 | 说明 |
+|-------|------------|---------------|---------|------|
+| 0 | 2 | [] | 0 | 浅层特征通用性强,无需 MoE |
+| 1 | 2 | [1] | 1 | 仅最后一个 block |
+| 2 | 18 | [1,3,5,7,9,11,13,15,17] | 9 | 每隔一个 block 放置(交替) |
+| 3 | 2 | [0, 1] | 2 | 全部 block |
+
+**共享专家配置**:
+
+| Stage | 共享专家数 |
+|-------|-----------|
+| 0 | 0 |
+| 1 | 0 |
+| 2 | 2 |
+| 3 | 1 |
+
+Stage 2 是主力计算阶段(18 个 block),交替放置 MoE 允许普通 FFN block 做特征整合,MoE block 做模态专门化。
+
+**Stage 间下采样**:使用 PatchMerging,将空间分辨率减半、通道数加倍。
+
+---
+
+## 3.4 Sparse Mixture-of-Experts Module(稀疏 MoE 模块)
+
+### 需要写的内容
+
+这是方法的核心。需要分三个子部分详细描述:(1) 余弦相似度门控 (2) 模态偏置 (3) 稀疏路由与专家计算。
+
+### 3.4.1 Cosine Similarity Gating(余弦相似度门控)
+
+**技术事实**:
+
+门控网络 `CosineTopKGate` 的完整计算流程:
+
+1. **特征池化**:将输入 $\mathbf{X} \in \mathbb{R}^{B \times N \times C}$ 沿 token 维度均值池化得到 $\bar{\mathbf{x}} \in \mathbb{R}^{B \times C}$
+
+2. **余弦相似度计算**:
+ $$\mathbf{l} = \frac{\mathbf{W}_p \bar{\mathbf{x}}}{\|\mathbf{W}_p \bar{\mathbf{x}}\|_2} \cdot \frac{\mathbf{S}}{\|\mathbf{S}\|_2}$$
+ 其中 $\mathbf{W}_p \in \mathbb{R}^{C \times d}$ 为投影矩阵(`cosine_projector`),$d = \min(C/2, 256)$;$\mathbf{S} \in \mathbb{R}^{d \times E}$ 为相似度矩阵(`sim_matrix`),$E$ 为专家数。
+
+3. **温度缩放**:
+ $$\mathbf{l} = \mathbf{l} \cdot \exp\left(\text{clamp}(\tau, \max=\ln 100)\right)$$
+ 其中 $\tau$ 为可学习温度参数,初始化为 $\ln(1/0.5) \approx 0.693$。
+
+4. **模态偏置注入**(见 3.4.2)
+
+5. **噪声注入**(仅训练阶段):
+ $$\tilde{\mathbf{l}} = \mathbf{l} + \epsilon \cdot \text{Softplus}(\bar{\mathbf{x}} \mathbf{W}_{\text{noise}}), \quad \epsilon \sim \mathcal{N}(0, 1)$$
+ 其中 $\mathbf{W}_{\text{noise}} \in \mathbb{R}^{C \times E}$。噪声鼓励探索,防止路由僵化。
+
+6. **Top-K 选择与归一化**:
+ $$\text{TopK}(\tilde{\mathbf{l}}, k) \to (\mathbf{l}_{\text{top}}, \mathbf{I}_{\text{top}})$$
+ $$\mathbf{g}_{\text{top}} = \text{Softmax}(\mathbf{l}_{\text{top}})$$
+ 将 $\mathbf{g}_{\text{top}}$ scatter 回完整的 $E$ 维向量,未被选中的专家权重为 0。
+
+本文配置:$E=8$, $k=3$。
+
+### 公式(汇总版,适合论文)
+
+$$
+\mathbf{g} = \text{TopK-Softmax}\Big(\underbrace{\text{CosSim}(\mathbf{W}_p \bar{\mathbf{x}},\ \mathbf{S})}_{\text{content-based routing}} \cdot e^{\tau} + \underbrace{\mathbf{b}_{m}}_{\text{modal bias}},\ k\Big)
+$$
+
+### 3.4.2 Learnable Modal Bias(可学习模态偏置)
+
+**技术事实**:
+
+参数:$\mathbf{B}_{\text{modal}} \in \mathbb{R}^{M \times E}$,其中 $M=3$(三种模态),$E=8$(专家数)。初始化为零矩阵。
+
+应用方式:对于模态 $m$ 的输入样本,在门控 logits 上叠加偏置:
+$$\mathbf{l}_i = \mathbf{l}_i + \mathbf{B}_{\text{modal}}[m, :]$$
+
+每个 MoE 层有独立的 modal_bias 参数(不共享)。
+
+**作用**:即使两个不同模态的 token 在特征空间中相近(余弦相似度门控给出相似路由),modal bias 也能将它们导向不同的专家。这为路由网络提供了显式的模态先验。
+
+**学习率**:modal_bias 使用 3× 的学习率倍率(`lr_mult=3.0`),加速模态偏好的学习。
+
+### 3.4.3 Expert Computation & Shared Experts(专家计算与共享专家)
+
+**路由专家(Routed Experts)**:
+
+每个 MoE 层包含 $E=8$ 个结构相同但参数独立的 FFN:
+$$\text{FFN}_i(\mathbf{x}) = \mathbf{W}_2^{(i)} \cdot \text{GELU}(\mathbf{W}_1^{(i)} \mathbf{x}) + \mathbf{bias}_2^{(i)}$$
+其中 $\mathbf{W}_1^{(i)} \in \mathbb{R}^{C \times 4C}$, $\mathbf{W}_2^{(i)} \in \mathbb{R}^{4C \times C}$(mlp_ratio=4)。
+
+**稀疏派发(Sparse Dispatch)**:
+
+通过 `SparseDispatcher` 实现,仅将每个样本发送到被选中的 top-k 专家:
+1. 根据 gate 矩阵 $\mathbf{G} \in \mathbb{R}^{B \times E}$ 的非零位置确定路由
+2. 按专家分组派发输入
+3. 各专家独立计算
+4. 按 gate 权重加权合并:$\mathbf{y}_{\text{routed}} = \sum_{i \in \text{TopK}} g_i \cdot \text{FFN}_i(\mathbf{x})$
+
+**共享专家(Shared Experts)**:
+
+共享专家不经过门控路由,对所有输入无条件执行:
+$$\mathbf{y}_{\text{shared}} = \text{FFN}_{\text{shared}}(\mathbf{x})$$
+
+当 Stage 2 配置 2 个共享专家时,其 FFN hidden_dim 为 $2 \times 4C$(等效于将两个专家的 FFN 拼接为一个更宽的 FFN)。
+
+**最终输出**:
+$$\mathbf{y} = \mathbf{y}_{\text{routed}} + \mathbf{y}_{\text{shared}}$$
+
+---
+
+## 3.5 Modality-Separate Decoder(模态独立解码头)
+
+### 需要写的内容
+
+解释为什么不用共享解码头,以及独立解码头的实现方式。
+
+### 技术事实
+
+**结构**:`decoder_mode='separate'` 时,为每种模态(sar/rgb/GF)各创建一个独立的 UPerHead + FCNHead(辅助头)。
+
+**主解码头(UPerHead)配置**:
+
+| 参数 | 值 |
+|------|-----|
+| in_channels | [128, 256, 512, 1024](对应 4 个 stage 输出) |
+| pool_scales | (1, 2, 3, 6)(PPM 模块) |
+| channels | 512 |
+| dropout_ratio | 0.1 |
+| num_classes | 2(flood / non-flood) |
+| loss | CrossEntropyLoss (weight=1.0) |
+
+**辅助解码头(FCNHead)配置**:
+
+| 参数 | 值 |
+|------|-----|
+| in_channels | 512(取 Stage 2 输出) |
+| channels | 256 |
+| num_convs | 1 |
+| loss | CrossEntropyLoss (weight=0.4) |
+
+**路由逻辑**:训练时,根据每个样本 metainfo 中的 `dataset_name` 字段,将同一 batch 内的样本按模态分组,分别送入对应的解码头计算 loss。推理时,根据输入样本的模态类型选择对应的解码头。
+
+**与共享解码头对比**(消融实验 Table 2(f)):共享模式下所有模态使用同一组 UPerHead 参数。
+
+---
+
+## 3.6 Training Objectives(训练目标)
+
+### 需要写的内容
+
+详细描述总损失函数的组成。
+
+### 技术事实
+
+总损失由三部分组成:
+
+$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{seg}} + \lambda_{\text{bal}} \mathcal{L}_{\text{balance}} + \lambda_{\text{div}} \mathcal{L}_{\text{diversity}}$$
+
+**1. 分割损失 $\mathcal{L}_{\text{seg}}$**:
+
+各模态独立计算 CrossEntropy loss,主头权重 1.0 + 辅助头权重 0.4:
+$$\mathcal{L}_{\text{seg}} = \sum_{m \in \{sar, rgb, GF\}} \left(\mathcal{L}_{\text{CE}}^{m} + 0.4 \cdot \mathcal{L}_{\text{CE,aux}}^{m}\right)$$
+
+**2. 负载均衡损失 $\mathcal{L}_{\text{balance}}$**($\lambda_{\text{bal}}=1.0$):
+
+鼓励专家负载均匀,防止少数专家被过度使用:
+$$\mathcal{L}_{\text{balance}} = \text{CV}^2(\text{importance}) + \text{CV}^2(\text{load})$$
+
+其中:
+- $\text{importance}_i = \sum_{b=1}^{B} g_{b,i}$(专家 $i$ 的总 gate 权重)
+- $\text{load}_i = \sum_{b=1}^{B} \mathbb{1}[g_{b,i} > 0]$(专家 $i$ 被选中的次数)
+- $\text{CV}^2(\mathbf{x}) = \frac{\text{Var}(\mathbf{x})}{(\text{Mean}(\mathbf{x}))^2 + \epsilon}$(变异系数的平方)
+
+该损失在所有 MoE 层上求平均。
+
+**3. 专家多样性损失 $\mathcal{L}_{\text{diversity}}$**($\lambda_{\text{div}}=0.1$):
+
+防止不同专家学到相似的表征(专家坍缩):
+
+$$\mathcal{L}_{\text{diversity}} = \text{ReLU}\left(\frac{1}{E(E-1)/2} \sum_{i<j} \text{sim}\left(\mathbf{e}_i, \mathbf{e}_j\right) - m\right)$$
+
+其中 $\text{sim}(\cdot,\cdot)$ 为专家表征间的余弦相似度,$m$ 为相似度容忍阈值(margin);即惩罚所有专家对的平均两两相似度超出阈值的部分,防止专家坍缩。(注:原文此处公式在转换中损坏,以上按上下文重建,发表前请对照代码核实。)
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py
index f8ad750d76..d9c562805e 100644
--- a/mmseg/datasets/__init__.py
+++ b/mmseg/datasets/__init__.py
@@ -26,6 +26,10 @@
from .refuge import REFUGEDataset
from .stare import STAREDataset
from .synapse import SynapseDataset
+from .UAVflood import UAVfloodDataset
+from .multimodal_deepflood import MultiModalDeepflood
+from .sen1floods11 import Sen1Floods11Dataset
+from .fixed_ratio_modal_sampler import FixedRatioModalSampler
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
@@ -61,5 +65,6 @@
'MapillaryDataset_v2', 'Albu', 'LEVIRCDDataset',
'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile',
'ConcatCDInput', 'BaseCDDataset', 'DSDLSegDataset', 'BDD100KDataset',
- 'NYUDataset', 'HSIDrive20Dataset'
+ 'NYUDataset', 'HSIDrive20Dataset', 'UAVfloodDataset',
+ 'MultiModalDeepflood', 'Sen1Floods11Dataset', 'FixedRatioModalSampler'
]
diff --git a/mmseg/datasets/fixed_ratio_modal_sampler.py b/mmseg/datasets/fixed_ratio_modal_sampler.py
new file mode 100644
index 0000000000..9f1ab3c61f
--- /dev/null
+++ b/mmseg/datasets/fixed_ratio_modal_sampler.py
@@ -0,0 +1,193 @@
+"""
+Fixed Ratio Modal Sampler for Multi-Dataset Training - MMSeg 1.x Version
+
+Registered as DATA_SAMPLERS for mmengine compatibility.
+"""
+from typing import Dict, Iterator, List, Optional, Sequence, Sized, Union
+
+import torch
+from torch.utils.data import Sampler
+
+from mmseg.registry import DATA_SAMPLERS
+
+
@DATA_SAMPLERS.register_module()
class FixedRatioModalSampler(Sampler):
    """Fixed ratio modal sampler.

    Ensures each batch has a fixed modal composition: every batch is
    built from ``batch_size // sum(modal_ratios)`` repetitions of the
    ratio pattern, drawing ``modal_ratios[i]`` consecutive samples of
    modality ``modal_order[i]`` per repetition.

    Args:
        dataset: Dataset object, must have data_list with modal_type.
        modal_ratios: Modal sampling ratios.
        modal_order: Modal ordering.
        reference_modal: Reference modal for epoch length calculation.
        seed: Random seed.
        batch_size: Batch size (replaces samples_per_gpu).

    Note:
        This sampler is not distributed-aware: it never consults
        rank/world_size, so under DDP every rank would iterate the same
        index sequence — shard externally if needed.
    """

    def __init__(
        self,
        dataset: Sized,
        modal_ratios: Optional[Union[Sequence[int], Dict[str, int]]] = None,
        modal_order: Optional[Sequence[str]] = None,
        reference_modal: Optional[str] = None,
        seed: Optional[int] = None,
        batch_size: int = 1,
        **kwargs,
    ):
        # modal_ratios defaults to None only to keep the registry
        # signature uniform; it is effectively mandatory.
        if modal_ratios is None:
            raise ValueError(
                "modal_ratios must be provided for FixedRatioModalSampler.")

        self.dataset = dataset
        self.batch_size = batch_size
        self.seed = 0 if seed is None else seed
        self.epoch = 0
        self.reference_modal = reference_modal

        # Normalize (ratios, order) into two parallel lists regardless
        # of whether modal_ratios came in as a dict or a sequence.
        if isinstance(modal_ratios, dict):
            if modal_order is None:
                modal_order = list(modal_ratios.keys())
            self.modal_order = list(modal_order)
            self.modal_ratios = [
                modal_ratios[modal] for modal in self.modal_order]
        else:
            if modal_order is None:
                raise ValueError(
                    "modal_order must be provided when "
                    "modal_ratios is a list.")
            self.modal_order = list(modal_order)
            self.modal_ratios = list(modal_ratios)

        # Extra registry kwargs (**kwargs) are accepted and ignored.
        self._validate_ratios()
        self.modal_indices = self._group_by_modal()
        self.num_samples = self._calculate_num_samples()

        self._print_statistics()

    def _validate_ratios(self):
        """Check that the ratio pattern tiles the batch size exactly."""
        if len(self.modal_order) != len(self.modal_ratios):
            raise ValueError(
                "modal_order and modal_ratios must have the same length.")
        if any(ratio <= 0 for ratio in self.modal_ratios):
            raise ValueError("All modal ratios must be positive.")
        if sum(self.modal_ratios) > self.batch_size:
            raise ValueError(
                "Sum of modal ratios must be <= batch_size "
                f"(got {sum(self.modal_ratios)} > {self.batch_size}).")
        if self.batch_size % sum(self.modal_ratios) != 0:
            raise ValueError(
                "batch_size must be divisible by sum(modal_ratios) "
                f"(got {self.batch_size} % {sum(self.modal_ratios)} != 0).")

    def _group_by_modal(self) -> Dict[str, List[int]]:
        """Map each configured modality to the dataset indices of that
        modality; samples whose modal_type is not in modal_order are
        silently dropped."""
        modal_indices = {modal: [] for modal in self.modal_order}
        # In 1.x, dataset.data_list contains the data info dicts
        for idx in range(len(self.dataset)):
            data_info = self.dataset.get_data_info(idx)
            modal_type = data_info.get('modal_type', 'unknown')
            if modal_type in modal_indices:
                modal_indices[modal_type].append(idx)
        return modal_indices

    def _calculate_num_samples(self) -> int:
        """Number of indices yielded per epoch (whole batches only)."""
        total_ratio = sum(self.modal_ratios)
        # How many times the ratio pattern repeats inside one batch.
        batch_repeats = self.batch_size // total_ratio

        if self.reference_modal is None:
            # No reference modality: epoch length follows the overall
            # number of grouped samples.
            total_original = sum(
                len(v) for v in self.modal_indices.values())
            full_batches = total_original // self.batch_size
            return full_batches * self.batch_size

        if self.reference_modal not in self.modal_order:
            raise ValueError(
                f"reference_modal '{self.reference_modal}' "
                f"is not in modal_order.")

        reference_count = len(self.modal_indices[self.reference_modal])
        if reference_count == 0:
            raise ValueError(
                f"No samples found for reference_modal "
                f"'{self.reference_modal}'.")

        # Epoch length is driven by how many ratio-groups the reference
        # modality can fill without repeating its samples.
        reference_ratio = self.modal_ratios[
            self.modal_order.index(self.reference_modal)]
        total_groups = reference_count // reference_ratio
        full_batches = total_groups // batch_repeats
        return full_batches * self.batch_size

    def _print_statistics(self):
        """Log the sampler configuration once at construction time."""
        print("\n" + "=" * 60)
        print("Fixed Ratio Modal Sampler Statistics")
        print("=" * 60)
        print(f"Batch size: {self.batch_size}")
        print("Modal Ratios:")
        for modal, ratio in zip(self.modal_order, self.modal_ratios):
            print(f" {modal}: {ratio}")
        print("Modal Distribution (Original):")
        total_original = sum(len(v) for v in self.modal_indices.values())
        for modal in self.modal_order:
            count = len(self.modal_indices[modal])
            percentage = (
                (count / total_original * 100)
                if total_original > 0 else 0.0)
            print(f" {modal}: {count:5d} samples ({percentage:5.2f}%)")
        if self.reference_modal is not None:
            reference_count = len(
                self.modal_indices[self.reference_modal])
            reference_ratio = self.modal_ratios[
                self.modal_order.index(self.reference_modal)]
            print(
                f"Reference modal: {self.reference_modal} "
                f"(count={reference_count}, ratio={reference_ratio})")
        print(f"Total samples per epoch: {self.num_samples}")
        print(f"Iterations per epoch: {len(self)}")
        print("=" * 60 + "\n")

    def __iter__(self) -> Iterator[int]:
        """Yield one epoch of indices in fixed modal order."""
        # Seed depends on epoch so shuffling differs between epochs but
        # stays reproducible for a fixed (seed, epoch) pair.
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        per_modal_pools = {}
        per_modal_pos = {}
        for modal, indices in self.modal_indices.items():
            if len(indices) == 0:
                raise ValueError(
                    f"No samples found for modal '{modal}'.")
            perm = torch.randperm(len(indices), generator=g)
            per_modal_pools[modal] = [indices[i] for i in perm.tolist()]
            per_modal_pos[modal] = 0

        indices = []
        batch_repeats = self.batch_size // sum(self.modal_ratios)
        num_batches = len(self)

        for _ in range(num_batches):
            batch_indices = []
            for _ in range(batch_repeats):
                for modal, ratio in zip(
                        self.modal_order, self.modal_ratios):
                    pool = per_modal_pools[modal]
                    pos = per_modal_pos[modal]
                    for _ in range(ratio):
                        # An exhausted pool is reshuffled and reused, so
                        # minority modalities oversample within an epoch.
                        if pos >= len(pool):
                            perm = torch.randperm(
                                len(pool), generator=g)
                            pool = [pool[i] for i in perm.tolist()]
                            per_modal_pools[modal] = pool
                            pos = 0
                        batch_indices.append(pool[pos])
                        pos += 1
                    per_modal_pos[modal] = pos

            indices.extend(batch_indices)

        return iter(indices)

    def __len__(self) -> int:
        # NOTE(review): this returns the number of *batches* per epoch,
        # while __iter__ yields len(self) * batch_size indices. A plain
        # torch Sampler conventionally reports the number of indices it
        # yields — confirm the consuming DataLoader/runner expects
        # batch-count semantics here.
        return self.num_samples // self.batch_size

    def set_epoch(self, epoch: int) -> None:
        """Called by the training loop so shuffling varies per epoch."""
        self.epoch = epoch
diff --git a/mmseg/datasets/multimodal_deepflood.py b/mmseg/datasets/multimodal_deepflood.py
new file mode 100644
index 0000000000..520be4e03a
--- /dev/null
+++ b/mmseg/datasets/multimodal_deepflood.py
@@ -0,0 +1,185 @@
+"""
+Multi-Modal Deepflood Dataset - MMSeg 1.x Version
+
+Key changes from 0.x:
+- Base class: CustomDataset -> BaseSegDataset
+- Registry: from builder import DATASETS -> from mmseg.registry import DATASETS
+- load_annotations() -> load_data_list()
+- img_dir/ann_dir -> data_prefix dict
+- img_infos -> data_list
+- evaluate() -> IoUMetric (external evaluator)
+"""
+import os.path as osp
+from collections import OrderedDict
+from typing import List
+
+import mmengine
+import mmengine.fileio as fileio
+import numpy as np
+from mmengine.logging import print_log
+
+from mmseg.registry import DATASETS
+from .basesegdataset import BaseSegDataset
+
+
@DATASETS.register_module()
class MultiModalDeepflood(BaseSegDataset):
    """Multi-Modal Deepflood dataset for flood segmentation.

    Supports SAR, RGB, and GaoFen (GF) modalities with automatic
    modality identification from filenames.

    Args:
        filter_modality (str | None): If set (e.g. ``'sar'``, ``'rgb'``,
            ``'GF'``), only samples of that modality are kept.
        **kwargs: Forwarded to :class:`BaseSegDataset`. ``img_suffix``,
            ``seg_map_suffix`` and ``reduce_zero_label`` default to
            ``'.tif'``, ``'.png'`` and ``False``.
    """

    METAINFO = dict(
        classes=('Background', 'Flood'),
        palette=[[0, 0, 0], [255, 0, 0]]
    )

    # Modal configuration: declared channel count plus the filename
    # substring (matched case-insensitively) that identifies a modality.
    MODAL_CONFIGS = {
        'sar': {'channels': 8, 'pattern': 'sar'},
        'rgb': {'channels': 3, 'pattern': 'rgb'},
        'GF': {'channels': 5, 'pattern': 'GF'},
    }

    MODAL_TO_DATASET = {
        'sar': {'dataset_source': 0, 'dataset_name': 'sar'},
        'rgb': {'dataset_source': 1, 'dataset_name': 'rgb'},
        'GF': {'dataset_source': 2, 'dataset_name': 'GF'},
    }

    def __init__(self, filter_modality=None, **kwargs):
        # Set default suffixes for flood data
        kwargs.setdefault('img_suffix', '.tif')
        kwargs.setdefault('seg_map_suffix', '.png')
        kwargs.setdefault('reduce_zero_label', False)
        # filter_modality: str or None, e.g. 'sar', 'rgb', 'GF'.
        # Must be set *before* super().__init__(), which triggers
        # load_data_list() and therefore reads this attribute.
        self.filter_modality = filter_modality
        super().__init__(**kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation from an ann_file or by scanning the image dir.

        Returns:
            list[dict]: All data info of dataset, each dict contains:
                - img_path (str)
                - seg_map_path (str)
                - modal_type (str)
                - actual_channels (int)
                - dataset_source (int)
                - dataset_name (str)
                - label_map (dict or None)
                - reduce_zero_label (bool)
                - seg_fields (list)
        """
        img_dir = self.data_prefix.get('img_path', None)
        ann_dir = self.data_prefix.get('seg_map_path', None)

        if not osp.isdir(self.ann_file) and self.ann_file:
            # Split file: one sample stem (no suffix) per line.
            assert osp.isfile(self.ann_file), \
                f'Failed to load `ann_file` {self.ann_file}'
            lines = mmengine.list_from_file(
                self.ann_file, backend_args=self.backend_args)
            stems = [line.strip() for line in lines]
        else:
            # Directory scan: strip the image suffix to get the stem
            # (may include subdirectories because recursive=True).
            stems = [
                img[:-len(self.img_suffix)]
                for img in fileio.list_dir_or_file(
                    dir_path=img_dir,
                    list_dir=False,
                    suffix=self.img_suffix,
                    recursive=True,
                    backend_args=self.backend_args)
            ]

        # Both code paths build identical records; the shared helper
        # replaces the previously duplicated construction logic.
        data_list = [
            self._build_data_info(stem, img_dir, ann_dir) for stem in stems
        ]
        data_list = sorted(data_list, key=lambda x: x['img_path'])

        # Filter by modality if specified
        if self.filter_modality is not None:
            data_list = [
                d for d in data_list
                if d['modal_type'] == self.filter_modality
            ]
            print_log(
                f'Filtered to modality "{self.filter_modality}": '
                f'{len(data_list)} images',
                logger='current')

        print_log(f'Loaded {len(data_list)} images', logger='current')
        self._print_modal_statistics(data_list)

        return data_list

    def _build_data_info(self, stem, img_dir, ann_dir):
        """Build one data-info dict from a sample stem (path w/o suffix)."""
        data_info = dict(img_path=osp.join(img_dir, stem + self.img_suffix))
        # Identify modality from the suffix-less name; none of the
        # configured patterns ('sar'/'rgb'/'gf') can occur in '.tif',
        # so dropping the suffix does not change the match.
        data_info.update(self._identify_modality(stem))
        if ann_dir is not None:
            data_info['seg_map_path'] = osp.join(
                ann_dir, stem + self.seg_map_suffix)
        data_info['label_map'] = self.label_map
        data_info['reduce_zero_label'] = self.reduce_zero_label
        data_info['seg_fields'] = []
        return data_info

    def _identify_modality(self, img_name):
        """Identify modality from filename (case-insensitive substring)."""
        img_name_lower = img_name.lower()

        for modal_name, config in self.MODAL_CONFIGS.items():
            if config['pattern'].lower() in img_name_lower:
                dataset_info = self.MODAL_TO_DATASET.get(modal_name, {
                    'dataset_source': 1,
                    'dataset_name': 'rgb'
                })

                return {
                    'modal_type': modal_name,
                    'actual_channels': config['channels'],
                    'dataset_source': dataset_info['dataset_source'],
                    'dataset_name': dataset_info['dataset_name'],
                }

        # Default: RGB
        return {
            'modal_type': 'rgb',
            'actual_channels': 3,
            'dataset_source': 1,
            'dataset_name': 'rgb',
        }

    def _print_modal_statistics(self, data_list):
        """Print dataset modal statistics."""
        modal_counts = {}
        for info in data_list:
            modal = info['modal_type']
            modal_counts[modal] = modal_counts.get(modal, 0) + 1

        print_log("\n=== Dataset Modal Statistics ===", logger='current')
        for modal, count in sorted(modal_counts.items()):
            channels = self.MODAL_CONFIGS.get(
                modal, {}).get('channels', 'unknown')
            print_log(
                f" {modal}: {count} images ({channels} channels)",
                logger='current')
        print_log("================================\n", logger='current')
diff --git a/mmseg/datasets/sen1floods11.py b/mmseg/datasets/sen1floods11.py
new file mode 100644
index 0000000000..5f31a4cccc
--- /dev/null
+++ b/mmseg/datasets/sen1floods11.py
@@ -0,0 +1,169 @@
+"""
+Sen1Floods11 single-modal flood segmentation dataset - MMSeg 1.x.
+
+Expected on-disk layout (the default from the official Sen1Floods11 release)::
+
+ data/Sen1Floods11/
+ S1Hand/
+ __S1Hand.tif # 2-band SAR (VV, VH in dB)
+ S2Hand/
+ __S2Hand.tif # 13-band Sentinel-2 MSI
+ LabelHand/
+ __LabelHand.tif # 1-band label (-1=nodata, 0=bg, 1=flood)
+
+All images are 512x512. Labels are signed TIFFs; the ``-1`` nodata value
+is mapped to ``ignore_index=255`` by the companion
+``LoadSen1Floods11Annotation`` transform.
+
+This dataset plugs into the existing multi-modal Swin+MoE pipeline by
+exposing ``modal_type`` / ``actual_channels`` / ``dataset_name`` on each
+sample, so the only things a config needs to do to fine-tune on
+Sen1Floods11 is to:
+
+ 1. set ``dataset_type='Sen1Floods11Dataset'`` and pick
+ ``modality='s1'`` or ``'s2'``;
+ 2. register a new modal in ``model.backbone.modal_configs`` /
+ ``training_modals`` with the matching channel count;
+ 3. use ``model.dataset_names=['s1']`` (or ``['s2']``) so the
+ per-dataset decode head key matches.
+"""
+import os.path as osp
+from typing import List
+
+import mmengine
+import mmengine.fileio as fileio
+from mmengine.logging import print_log
+
+from mmseg.registry import DATASETS
+from .basesegdataset import BaseSegDataset
+
+
@DATASETS.register_module()
class Sen1Floods11Dataset(BaseSegDataset):
    """Sen1Floods11 dataset for single-modal fine-tuning.

    Args:
        modality (str): ``'s1'`` (2-band SAR) or ``'s2'`` (13-band MSI).
        ann_file (str, optional): Path to a split file listing one
            sample base-name per line (e.g. ``Bolivia_23014``). The
            path may be relative to ``data_root``. If empty or absent,
            the whole ``data_prefix['img_path']`` directory is scanned.
        **kwargs: forwarded to :class:`BaseSegDataset`. ``img_suffix``,
            ``seg_map_suffix``, ``data_prefix`` all default to values
            that match the Sen1Floods11 layout described above.
    """

    METAINFO = dict(
        classes=('Background', 'Flood'),
        palette=[[0, 0, 0], [255, 0, 0]],
    )

    # Per-modality layout / shape info.
    MODAL_CONFIG = {
        's1': {
            'channels': 2,
            'img_subdir': 'S1Hand',
            'img_suffix': '_S1Hand.tif',
            'dataset_source': 0,
            'dataset_name': 's1',
        },
        's2': {
            'channels': 13,
            'img_subdir': 'S2Hand',
            'img_suffix': '_S2Hand.tif',
            'dataset_source': 1,
            'dataset_name': 's2',
        },
    }

    # Labels live in one shared subdirectory regardless of modality.
    SEG_SUBDIR = 'LabelHand'
    SEG_SUFFIX = '_LabelHand.tif'

    def __init__(self, modality: str = 's1', **kwargs):
        if modality not in self.MODAL_CONFIG:
            raise ValueError(
                f'Unknown modality "{modality}". Expected one of '
                f'{list(self.MODAL_CONFIG.keys())}')
        # Stored before super().__init__() because the base class calls
        # load_data_list() during construction.
        self.modality = modality

        modal_cfg = self.MODAL_CONFIG[modality]

        # Defaults follow the official on-disk layout; any explicitly
        # passed kwargs win over these.
        kwargs.setdefault('img_suffix', modal_cfg['img_suffix'])
        kwargs.setdefault('seg_map_suffix', self.SEG_SUFFIX)
        kwargs.setdefault('reduce_zero_label', False)

        data_prefix = kwargs.pop('data_prefix', None)
        if not data_prefix:
            data_prefix = dict(
                img_path=modal_cfg['img_subdir'],
                seg_map_path=self.SEG_SUBDIR,
            )

        super().__init__(data_prefix=data_prefix, **kwargs)

    # ------------------------------------------------------------------
    # core list building
    # ------------------------------------------------------------------
    def load_data_list(self) -> List[dict]:
        """Build the per-sample info list from a split file or dir scan."""
        modal_cfg = self.MODAL_CONFIG[self.modality]
        img_dir = self.data_prefix.get('img_path', None)
        ann_dir = self.data_prefix.get('seg_map_path', None)
        assert img_dir is not None, \
            'Sen1Floods11Dataset requires data_prefix["img_path"]'

        data_list = []

        # NOTE(review): assumes the base dataset has already resolved
        # ann_file against data_root before load_data_list() runs —
        # confirm against the BaseDataset join behavior in use.
        ann_file = self.ann_file
        use_ann_file = bool(ann_file) and osp.isfile(ann_file)

        if use_ann_file:
            lines = mmengine.list_from_file(
                ann_file, backend_args=self.backend_args)
            for line in lines:
                base = line.strip()
                if not base:
                    continue
                # Accept entries that include any of the known suffixes.
                for suf in (modal_cfg['img_suffix'], self.SEG_SUFFIX):
                    if base.endswith(suf):
                        base = base[:-len(suf)]
                        break
                img_name = base + modal_cfg['img_suffix']
                data_list.append(
                    self._build_info(img_name, img_dir, ann_dir))
        else:
            for img in fileio.list_dir_or_file(
                    dir_path=img_dir,
                    list_dir=False,
                    suffix=modal_cfg['img_suffix'],
                    recursive=True,
                    backend_args=self.backend_args):
                data_list.append(
                    self._build_info(img, img_dir, ann_dir))

        # Deterministic ordering for reproducible sampling/evaluation.
        data_list = sorted(data_list, key=lambda x: x['img_path'])

        print_log(
            f'[Sen1Floods11] modality="{self.modality}" '
            f'loaded {len(data_list)} images from "{img_dir}"',
            logger='current')

        return data_list

    def _build_info(self, img_name: str, img_dir: str,
                    ann_dir: str) -> dict:
        """Assemble one sample dict (img/seg paths + modal metadata)."""
        modal_cfg = self.MODAL_CONFIG[self.modality]
        info = dict(
            img_path=osp.join(img_dir, img_name),
            modal_type=self.modality,
            actual_channels=modal_cfg['channels'],
            dataset_source=modal_cfg['dataset_source'],
            dataset_name=modal_cfg['dataset_name'],
            label_map=self.label_map,
            reduce_zero_label=self.reduce_zero_label,
            seg_fields=[],
        )
        if ann_dir is not None:
            # Labels share the sample stem but carry the LabelHand suffix.
            base = img_name[:-len(modal_cfg['img_suffix'])]
            info['seg_map_path'] = osp.join(ann_dir, base + self.SEG_SUFFIX)
        return info
diff --git a/mmseg/datasets/transforms/__init__.py b/mmseg/datasets/transforms/__init__.py
index 125f070818..33d8f5c242 100644
--- a/mmseg/datasets/transforms/__init__.py
+++ b/mmseg/datasets/transforms/__init__.py
@@ -1,8 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .formatting import PackSegInputs
+from .multimodal_pipelines import (LoadMultiModalImageFromFile,
+ MultiModalNormalize,
+ GenerateBoundary,
+ MultiModalPad,
+ PackMultiModalSegInputs)
+from .sen1floods11_pipelines import LoadSen1Floods11Annotation
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
LoadBiomedicalData, LoadBiomedicalImageFromFile,
LoadDepthAnnotation, LoadImageFromNDArray,
+ LoadMultiBandTiffFromFile,
LoadMultipleRSImageFromFile, LoadSingleRSImageFromFile)
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
@@ -25,6 +32,9 @@
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad',
'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput',
- 'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation', 'RandomDepthMix',
- 'RandomFlip', 'Resize'
+ 'LoadMultiBandTiffFromFile', 'LoadMultipleRSImageFromFile',
+ 'LoadDepthAnnotation', 'RandomDepthMix', 'RandomFlip', 'Resize',
+ 'LoadMultiModalImageFromFile', 'MultiModalNormalize',
+ 'GenerateBoundary', 'MultiModalPad', 'PackMultiModalSegInputs',
+ 'LoadSen1Floods11Annotation'
]
diff --git a/mmseg/datasets/transforms/loading.py b/mmseg/datasets/transforms/loading.py
index c28937e55e..a846f6b5f9 100644
--- a/mmseg/datasets/transforms/loading.py
+++ b/mmseg/datasets/transforms/loading.py
@@ -18,6 +18,11 @@
except ImportError:
gdal = None
+try:
+ import tifffile
+except ImportError:
+ tifffile = None
+
@TRANSFORMS.register_module()
class LoadAnnotations(MMCV_LoadAnnotations):
@@ -769,3 +774,72 @@ def transform(self, results: dict) -> Optional[dict]:
results['img_shape'] = img.shape[:2]
results['ori_shape'] = img.shape[:2]
return results
+
+
@TRANSFORMS.register_module()
class LoadMultiBandTiffFromFile(BaseTransform):
    """Load a multi-band TIFF image from file using tifffile.

    This loader supports TIFF images with more than 4 channels, which are
    not supported by OpenCV. It uses the tifffile library to read the image.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image keeps its original
            data type. Defaults to True.
    """

    def __init__(self, to_float32: bool = True):
        self.to_float32 = to_float32

        # Fail fast at construction time rather than on the first sample.
        if tifffile is None:
            raise RuntimeError('tifffile is not installed. Please install it '
                               'using: pip install tifffile')

    def transform(self, results: Dict) -> Dict:
        """Functions to load image.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded image and meta information.

        Raises:
            RuntimeError: If the file cannot be read or decoded; the
                original exception is chained for debugging.
        """

        filename = results['img_path']
        try:
            # Read multi-band TIFF using tifffile
            img = tifffile.imread(filename)

            # tifffile may return (H, W, C) or (C, H, W) depending on how
            # the file stores its planes. Heuristic: if the first axis is
            # the smallest, assume channel-first and move channels last.
            if img.ndim == 3 and img.shape[0] < img.shape[2]:
                img = np.transpose(img, (1, 2, 0))

            if self.to_float32:
                img = img.astype(np.float32)

            results['img'] = img
            results['img_shape'] = img.shape[:2]
            results['ori_shape'] = img.shape[:2]
            return results
        except Exception as e:
            # Fix: include the actual path (previously lost) and chain
            # the original exception instead of discarding it.
            raise RuntimeError(
                f'Failed to load image from {filename}: {str(e)}') from e

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str
diff --git a/mmseg/datasets/transforms/multimodal_pipelines.py b/mmseg/datasets/transforms/multimodal_pipelines.py
new file mode 100644
index 0000000000..947bd47250
--- /dev/null
+++ b/mmseg/datasets/transforms/multimodal_pipelines.py
@@ -0,0 +1,435 @@
+"""
+Multi-Modal Data Loading Pipeline - MMSeg 1.x Version
+
+Key changes from 0.x:
+- Registry: PIPELINES -> TRANSFORMS
+- __call__ -> transform (BaseTransform)
+- Normalize moved to SegDataPreProcessor
+- Collect + FormatBundle -> PackMultiModalSegInputs (SegDataSample)
+"""
+import copy
+
+import numpy as np
+from mmcv.transforms import BaseTransform, to_tensor
+from mmengine.structures import PixelData
+
+from mmseg.registry import TRANSFORMS
+from mmseg.structures import SegDataSample
+
+try:
+ import tifffile
+except ImportError:
+ tifffile = None
+
+try:
+ import mmcv
+except ImportError:
+ mmcv = None
+
+import os.path as osp
+
+
@TRANSFORMS.register_module()
class LoadMultiModalImageFromFile(BaseTransform):
    """Load multi-modal image - no zero padding version.

    Multi-band TIFFs are decoded with :mod:`tifffile` when available
    (OpenCV handles at most 4 channels); other formats go through mmcv,
    with a PIL fallback when mmcv is missing.

    Required Keys:
        - img_path (str)
        - modal_type (str)
        - actual_channels (int)

    Added Keys:
        - img (np.ndarray)
        - img_shape (tuple)
        - ori_shape (tuple)

    Args:
        to_float32 (bool): Convert the loaded image to float32.
        color_type (str): Decode flag passed to mmcv (e.g. 'unchanged').
        imdecode_backend (str): Decode backend passed to mmcv (e.g. 'cv2').
    """

    def __init__(self,
                 to_float32=False,
                 color_type='unchanged',
                 imdecode_backend='cv2'):
        self.to_float32 = to_float32
        self.color_type = color_type
        self.imdecode_backend = imdecode_backend

    def _read_with_mmcv(self, filename):
        """Decode one image via mmcv (legacy FileClient byte-read path).

        Extracted because the identical read/decode code was previously
        duplicated in the TIFF-fallback and non-TIFF branches.
        """
        img_bytes = mmcv.FileClient.infer_client(None, filename).get(filename)
        return mmcv.imfrombytes(
            img_bytes,
            flag=self.color_type,
            backend=self.imdecode_backend)

    def transform(self, results: dict) -> dict:
        filename = results['img_path']

        # Support TIFF format
        if filename.endswith('.tif') or filename.endswith('.tiff'):
            if tifffile is not None:
                img = tifffile.imread(filename)
                # Heuristic: assume channel-first layout when the first
                # axis is the smallest; convert to (H, W, C).
                if len(img.shape) == 3 and img.shape[0] < img.shape[2]:
                    img = np.transpose(img, (1, 2, 0))
            elif mmcv is not None:
                img = self._read_with_mmcv(filename)
            else:
                raise ImportError(
                    'tifffile or mmcv required for TIFF loading')
        else:
            if mmcv is not None:
                img = self._read_with_mmcv(filename)
            else:
                from PIL import Image
                img = np.array(Image.open(filename))

        if self.to_float32:
            img = img.astype(np.float32)

        actual_channels = results.get('actual_channels', 3)

        # Grayscale -> single-channel HWC so downstream code can always
        # index img.shape[2].
        if len(img.shape) == 2:
            img = img[:, :, np.newaxis]

        current_channels = img.shape[2]

        # Trust the file over the dataset metadata: if the decoded
        # channel count disagrees with the declared one, update it.
        if current_channels != actual_channels:
            actual_channels = current_channels
            results['actual_channels'] = actual_channels

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        results['ori_filename'] = osp.basename(filename)

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(to_float32={self.to_float32}, '
        repr_str += f"color_type='{self.color_type}')"
        return repr_str
+
+
@TRANSFORMS.register_module()
class MultiModalNormalize(BaseTransform):
    """Multi-modal normalization - supports dynamic channel count.

    Statistics are looked up per ``modal_type`` and truncated to the
    sample's ``actual_channels`` (a prefix of the configured lists).
    Unknown modalities fall back to a generic mean=128 / std=50.

    Required Keys:
        - img (np.ndarray)
        - modal_type (str)
        - actual_channels (int)

    Modified Keys:
        - img (np.ndarray)

    Added Keys:
        - img_norm_cfg (dict)
    """

    # Per-modality per-channel (mean, std). Lists may be longer than a
    # given sample's channel count; only the leading actual_channels
    # entries are used.
    NORM_CONFIGS = {
        'rgb': {
            'mean': [123.675, 116.28, 103.53],
            'std': [58.395, 57.12, 57.375],
        },
        'sar': {
            'mean': [0.23651549, 0.31761484, 0.18514981, 0.26901252,
                     -14.57879175, -8.6098158, -14.2907338, -8.33534564],
            'std': [0.16280619, 0.20849304, 0.14008107, 0.19767644,
                    4.07141682, 3.94773216, 4.21006244, 4.05494136],
        },
        'multispectral': {
            'mean': [1353., 1329., 1627., 1935., 2268., 2723., 3154.,
                     3541., 3652., 3416., 1112., 2619., 2060.],
            'std': [1108., 942., 976., 1164., 1196., 1351., 1500.,
                    1605., 1611., 1288., 770., 1325., 1186.],
        },
        'GF': {
            'mean': [432.02181, 315.92948, 246.468659,
                     310.61462, 360.267789],
            'std': [97.73313111900238, 85.78646917160748,
                    95.78015824658593,
                    124.84677067613467, 251.73965882246978],
        },
        # ---- Sen1Floods11: S1Hand (2-band VV/VH SAR, dB) ----
        # Defaults computed from the Sen1Floods11 train split with
        # the nodata pixels (-1) masked out. Re-run
        # tools/compute_sen1floods11_stats.py on your own split to
        # refresh these numbers if needed.
        's1': {
            'mean': [-10.483032437170175, -17.362463068117055],
            'std': [4.178513068178825, 4.863193681650141],
        },
        # ---- Sen1Floods11: S2Hand (13-band Sentinel-2 MSI, TOA*10000) ----
        's2': {
            'mean': [1483.1443242989628, 1234.2590152666212,
                     1204.8650733526135, 1034.8886830055903,
                     1305.9293935728826, 2257.5830489084565,
                     2723.9229150471333, 2515.8173451011917,
                     2957.957558849772, 447.04848745605636,
                     57.01156794081842, 1893.1678433033005,
                     1040.2051810757566],
            'std': [314.8831761693824, 341.1134613706103,
                    367.08026898219816, 524.9257541131836,
                    446.1063090649358, 697.2651115379473,
                    897.1816646156293, 868.071173572936,
                    1031.6451147331716, 287.1789646791449,
                    130.50110493994782, 830.2058546614454,
                    610.0878654966048],
        },
    }

    def __init__(self, to_rgb=True):
        # NOTE(review): to_rgb is only recorded in img_norm_cfg; no
        # BGR<->RGB channel swap happens in this transform — confirm the
        # swap (if any) is handled elsewhere in the pipeline.
        self.to_rgb = to_rgb

    def transform(self, results: dict) -> dict:
        """Normalize ``img`` in place with modality-specific statistics."""
        img = results['img']
        modal_type = results.get('modal_type', 'rgb')
        actual_channels = results['actual_channels']

        if modal_type in self.NORM_CONFIGS:
            config = self.NORM_CONFIGS[modal_type]
            mean = np.array(
                config['mean'][:actual_channels], dtype=np.float32)
            std = np.array(
                config['std'][:actual_channels], dtype=np.float32)
        else:
            # Generic fallback for modalities with no configured stats.
            mean = np.array([128.0] * actual_channels, dtype=np.float32)
            std = np.array([50.0] * actual_channels, dtype=np.float32)

        # Broadcastable (1, 1, C) views for HWC arithmetic.
        mean_b = mean.reshape(1, 1, -1)
        std_b = std.reshape(1, 1, -1)

        # Sen1Floods11 S1Hand (and many other SAR/MSI products) encode
        # nodata pixels as NaN / ±Inf inside the TIFF. Left alone, the
        # NaN propagates through `(img - mean) / std` and poisons every
        # downstream feature map, so both the CE loss and the MoE
        # balance loss become NaN from the very first training step.
        # Replace non-finite pixels with the per-channel mean so they
        # normalize to 0 (a neutral value the network will learn to
        # ignore alongside the 255-ignored label pixels).
        if not np.all(np.isfinite(img)):
            img = np.where(
                np.isfinite(img),
                img,
                np.broadcast_to(mean_b, img.shape),
            ).astype(np.float32)

        img = (img - mean_b) / std_b

        # Final safety net: some SAR products use a finite sentinel
        # (e.g. -9999) instead of NaN for nodata, which would otherwise
        # survive normalization as a huge outlier. Clipping to ±10σ is
        # well outside any legitimate value for the modalities
        # configured above, so real data is untouched.
        img = np.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0)
        np.clip(img, -10.0, 10.0, out=img)

        results['img'] = img
        results['img_norm_cfg'] = dict(
            mean=mean.tolist(),
            std=std.tolist(),
            to_rgb=self.to_rgb,
        )

        return results
+
+
@TRANSFORMS.register_module()
class GenerateBoundary(BaseTransform):
    """Derive a boundary map from the segmentation label.

    For every class present in ``gt_seg_map``, the difference between a
    dilated and an eroded class mask (a ``thickness`` x ``thickness``
    square kernel) marks the class contour; per-class contours are
    merged and ignore pixels are propagated unchanged.

    Required Keys:
        - gt_seg_map (np.ndarray)

    Added Keys:
        - gt_boundary_map (np.ndarray)
    """

    def __init__(self, thickness=3, ignore_index=255):
        self.ignore_index = ignore_index
        self.thickness = thickness

    def transform(self, results: dict) -> dict:
        # Nothing to do for samples without a label.
        if 'gt_seg_map' not in results:
            return results

        label = results['gt_seg_map']
        if label.ndim == 3:
            label = label.squeeze(-1)

        results['gt_boundary_map'] = self._generate_boundary(label)
        if 'seg_fields' in results:
            results['seg_fields'].append('gt_boundary_map')
        return results

    def _generate_boundary(self, seg_mask):
        import cv2

        edge_map = np.zeros_like(seg_mask, dtype=np.uint8)
        struct_elem = np.ones((self.thickness, self.thickness), np.uint8)

        # Iterate only over real class ids, never the ignore value.
        present = np.unique(seg_mask)
        for cls_id in present[present != self.ignore_index]:
            cls_mask = (seg_mask == cls_id).astype(np.uint8)
            grown = cv2.dilate(cls_mask, struct_elem, iterations=1)
            shrunk = cv2.erode(cls_mask, struct_elem, iterations=1)
            # dilation - erosion == contour band of this class.
            edge_map = np.maximum(edge_map, grown - shrunk)

        # Keep ignore pixels flagged as ignore in the boundary map too.
        edge_map[seg_mask == self.ignore_index] = self.ignore_index
        return edge_map
+
+
@TRANSFORMS.register_module()
class MultiModalPad(BaseTransform):
    """Pad images with arbitrary channel counts using numpy.

    mmcv.transforms.Pad uses cv2.copyMakeBorder which only supports up to
    4 channels. This transform uses numpy padding instead, so it works
    with SAR (8ch), GF (5ch), etc. Padding is applied on the bottom and
    right edges only.

    Required Keys:
        - img (np.ndarray)

    Modified Keys:
        - img (np.ndarray)
        - img_shape (tuple)

    Added Keys:
        - pad_shape (tuple)
        - padding_size (tuple)

    Args:
        size (tuple): Target (H, W).
        pad_val (float): Padding value for images. Default: 0.
        seg_pad_val (float): Padding value for seg maps. Default: 255.
    """

    def __init__(self, size, pad_val=0, seg_pad_val=255):
        self.size = size
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val

    @staticmethod
    def _pad_bottom_right(array, pad_h, pad_w, value):
        """Constant-pad a 2D or 3D array on the bottom/right edges."""
        widths = ((0, pad_h), (0, pad_w))
        if array.ndim == 3:
            widths += ((0, 0),)
        return np.pad(array, widths, mode='constant',
                      constant_values=value)

    def transform(self, results: dict) -> dict:
        img = results['img']
        target_h, target_w = self.size
        pad_h = max(target_h - img.shape[0], 0)
        pad_w = max(target_w - img.shape[1], 0)

        if pad_h > 0 or pad_w > 0:
            img = self._pad_bottom_right(img, pad_h, pad_w, self.pad_val)
            results['img'] = img

            # Pad every registered seg field with the label pad value.
            seg_fields = results.get('seg_fields', [])
            for key in seg_fields:
                if key in results:
                    results[key] = self._pad_bottom_right(
                        results[key], pad_h, pad_w, self.seg_pad_val)

            # gt_seg_map may not be registered in seg_fields; make sure
            # it is padded exactly once either way.
            if 'gt_seg_map' in results and 'gt_seg_map' not in seg_fields:
                results['gt_seg_map'] = self._pad_bottom_right(
                    results['gt_seg_map'], pad_h, pad_w, self.seg_pad_val)

        results['img_shape'] = img.shape[:2]
        results['pad_shape'] = img.shape[:2]
        # (left, right, top, bottom) convention with left/top always 0.
        results['padding_size'] = (0, pad_w, 0, pad_h)

        return results
+
+
@TRANSFORMS.register_module()
class PackMultiModalSegInputs(BaseTransform):
    """Pack multi-modal data into SegDataSample format.

    This replaces CollectMultiModalData + MultiModalFormatBundle from 0.x.

    Required Keys:
        - img (np.ndarray)

    Optional Keys:
        - gt_seg_map (np.ndarray)
        - gt_boundary_map (np.ndarray)
        - dataset_name (str)
        - modal_type (str)
        - actual_channels (int)

    Added Keys:
        - inputs (torch.Tensor)
        - data_samples (SegDataSample)
    """

    def __init__(self,
                 meta_keys=('img_path', 'ori_filename', 'ori_shape',
                            'img_shape', 'pad_shape', 'scale_factor',
                            'flip', 'flip_direction',
                            'modal_type', 'actual_channels',
                            'dataset_name', 'img_norm_cfg',
                            'reduce_zero_label')):
        # Keys copied from the results dict into SegDataSample.metainfo;
        # missing keys are silently skipped in transform().
        self.meta_keys = meta_keys

    def transform(self, results: dict) -> dict:
        """Pack img / labels / metainfo into (inputs, data_samples)."""
        packed_results = dict()

        if 'img' in results:
            img = results['img']
            # Ensure HWC before converting to CHW tensor layout.
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            # Two paths to a contiguous CHW tensor: copy the array into
            # contiguous memory before to_tensor when needed, otherwise
            # transpose first and make the tensor contiguous.
            if not img.flags.c_contiguous:
                img = to_tensor(
                    np.ascontiguousarray(img.transpose(2, 0, 1)))
            else:
                img = img.transpose(2, 0, 1)
                img = to_tensor(img).contiguous()
            packed_results['inputs'] = img

        data_sample = SegDataSample()

        if 'gt_seg_map' in results:
            gt_seg_map = results['gt_seg_map']
            # Labels are stored as int64; a 2D map gets a leading
            # channel dim so gt_sem_seg is always (1, H, W)-like.
            if len(gt_seg_map.shape) == 2:
                data = to_tensor(gt_seg_map[None, ...].astype(np.int64))
            else:
                data = to_tensor(gt_seg_map.astype(np.int64))
            gt_sem_seg_data = dict(data=data)
            data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)

        if 'gt_boundary_map' in results:
            # NOTE(review): assumes gt_boundary_map is 2D (H, W); a 3D
            # map would get an extra leading dim here — confirm upstream
            # (GenerateBoundary emits 2D maps).
            gt_boundary_data = dict(
                data=to_tensor(
                    results['gt_boundary_map'][None, ...].astype(np.int64)))
            data_sample.set_data(
                dict(gt_boundary_map=PixelData(**gt_boundary_data)))

        # Set meta info
        img_meta = {}
        for key in self.meta_keys:
            if key in results:
                img_meta[key] = results[key]
        data_sample.set_metainfo(img_meta)

        packed_results['data_samples'] = data_sample

        return packed_results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(meta_keys={self.meta_keys})'
        return repr_str
diff --git a/mmseg/datasets/transforms/sen1floods11_pipelines.py b/mmseg/datasets/transforms/sen1floods11_pipelines.py
new file mode 100644
index 0000000000..a8efe54a8e
--- /dev/null
+++ b/mmseg/datasets/transforms/sen1floods11_pipelines.py
@@ -0,0 +1,86 @@
+"""
+Sen1Floods11 loading transforms.
+
+The ``LabelHand`` TIFFs store signed values in ``{-1, 0, 1}``:
+
+ -1: nodata -> mapped to ``ignore_index`` (255 by default)
+ 0: non-flood / background
+ 1: flood
+
+Default :class:`LoadAnnotations` uses ``mmcv.imfrombytes`` which is not
+reliable for signed-int TIFF, so we decode with :mod:`tifffile` and
+explicitly remap nodata. The returned ``gt_seg_map`` is ``np.uint8``
+which is what the rest of the mmseg 1.x pipeline expects.
+"""
+import numpy as np
+from mmcv.transforms import BaseTransform
+
+from mmseg.registry import TRANSFORMS
+
+try:
+ import tifffile
+except ImportError:
+ tifffile = None
+
+
@TRANSFORMS.register_module()
class LoadSen1Floods11Annotation(BaseTransform):
    """Load a Sen1Floods11 LabelHand TIFF segmentation map.

    Required Keys:
        - seg_map_path

    Added Keys:
        - gt_seg_map (np.uint8, shape (H, W))
        - seg_fields (list)

    Args:
        ignore_index (int): Value that replaces ``nodata_value`` in the
            output mask. Default: 255.
        nodata_value (int): Raw label value that indicates "no data".
            Default: -1.
    """

    def __init__(self, ignore_index: int = 255, nodata_value: int = -1):
        # Fail fast at construction time rather than on the first sample.
        if tifffile is None:
            raise ImportError(
                'tifffile is required to load Sen1Floods11 labels. '
                'Install with `pip install tifffile`.')
        self.ignore_index = int(ignore_index)
        self.nodata_value = int(nodata_value)

    def transform(self, results: dict) -> dict:
        """Decode the TIFF at ``results['seg_map_path']`` into ``gt_seg_map``."""
        seg_path = results['seg_map_path']
        raw = tifffile.imread(seg_path)

        # Some TIFFs come as (1, H, W) / (H, W, 1)
        raw = np.squeeze(raw)
        if raw.ndim != 2:
            raise RuntimeError(
                f'Expected a 2D label for {seg_path}, got shape {raw.shape}')

        # Widen before comparisons so matching a negative nodata value is
        # safe even if the file stores a narrow/unsigned integer type.
        raw = raw.astype(np.int32)

        # Start from ignore_index so ANY raw value outside {0, 1} ends up
        # ignored, not just the declared nodata value.
        gt = np.full(raw.shape, self.ignore_index, dtype=np.uint8)
        gt[raw == 0] = 0
        gt[raw == 1] = 1
        # Explicit nodata remap - catches anything that isn't {0, 1}.
        gt[raw == self.nodata_value] = self.ignore_index

        # Optional label remapping (preserves parity with LoadAnnotations)
        if results.get('label_map', None):
            # Remap from a snapshot so chained old->new ids don't cascade.
            gt_copy = gt.copy()
            for old_id, new_id in results['label_map'].items():
                gt[gt_copy == old_id] = new_id

        results['gt_seg_map'] = gt
        results.setdefault('seg_fields', [])
        if 'gt_seg_map' not in results['seg_fields']:
            results['seg_fields'].append('gt_seg_map')

        return results

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'ignore_index={self.ignore_index}, '
                f'nodata_value={self.nodata_value})')
diff --git a/mmseg/engine/hooks/visualization_hook.py b/mmseg/engine/hooks/visualization_hook.py
index 21cddde89d..ac9282a38c 100644
--- a/mmseg/engine/hooks/visualization_hook.py
+++ b/mmseg/engine/hooks/visualization_hook.py
@@ -4,9 +4,12 @@
from typing import Optional, Sequence
import mmcv
+import numpy as np
+import torch
from mmengine.fileio import get
from mmengine.hooks import Hook
from mmengine.runner import Runner
+from mmengine.utils import mkdir_or_exist
from mmengine.visualization import Visualizer
from mmseg.registry import HOOKS
@@ -15,14 +18,8 @@
@HOOKS.register_module()
class SegVisualizationHook(Hook):
- """Segmentation Visualization Hook. Used to visualize validation and
- testing process prediction results.
-
- In the testing phase:
-
- 1. If ``show`` is True, it means that only the prediction results are
- visualized without storing data, so ``vis_backends`` needs to
- be excluded.
+ """Segmentation Visualization Hook. This hook visualize the prediction
+ results during validation and testing.
Args:
draw (bool): whether to draw prediction results. If it is False,
@@ -30,10 +27,8 @@ class SegVisualizationHook(Hook):
interval (int): The interval of visualization. Defaults to 50.
show (bool): Whether to display the drawn image. Default to False.
wait_time (float): The interval of show (s). Defaults to 0.
- backend_args (dict, Optional): Arguments to instantiate a file backend.
- See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
- for details. Defaults to None.
- Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
+ backend_args (dict, Optional): Arguments to instantiate a file client.
+ Defaults to None.
"""
def __init__(self,
@@ -52,16 +47,47 @@ def __init__(self,
'the prediction results are visualized '
'without storing data, so vis_backends '
'needs to be excluded.')
-
- self.wait_time = wait_time
- self.backend_args = backend_args.copy() if backend_args else None
+ self.wait_time = wait_time
+ else:
+ self.wait_time = 0.
self.draw = draw
- if not self.draw:
- warnings.warn('The draw is False, it means that the '
- 'hook for visualization will not take '
- 'effect. The results will NOT be '
- 'visualized or stored.')
- self._test_index = 0
+ self.backend_args = backend_args
+
+ IGNORE_INDEX = 255
+
+ def _create_dummy_image(self, height: int, width: int) -> np.ndarray:
+ """Create a dummy black image for visualization.
+
+ Args:
+ height: Image height
+ width: Image width
+
+ Returns:
+ Black RGB image of shape (H, W, 3)
+ """
+ return np.zeros((height, width, 3), dtype=np.uint8)
+
    @staticmethod
    def _mask_nodata(output: SegDataSample,
                     ignore_index: int = 255) -> None:
        """Set prediction to ignore_index where GT is nodata.

        This prevents nodata pixels from being visualized as a real
        class (e.g. Flood). Called AFTER evaluator.process() so
        metrics are unaffected. Mutates ``output`` in place; returns None.
        """
        # Nothing to do for samples missing either map.
        if not (hasattr(output, 'gt_sem_seg')
                and hasattr(output, 'pred_sem_seg')):
            return
        gt = output.gt_sem_seg.data  # (1, H, W) or (H, W)
        pred = output.pred_sem_seg.data  # same shape
        nodata_mask = (gt == ignore_index)
        if nodata_mask.any():
            # The fill scalar is created with pred's dtype/device so
            # torch.where neither promotes dtypes nor crosses devices.
            output.pred_sem_seg.data = torch.where(
                nodata_mask,
                torch.tensor(ignore_index, dtype=pred.dtype,
                             device=pred.device),
                pred)
def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[SegDataSample]) -> None:
@@ -71,59 +97,84 @@ def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
runner (:obj:`Runner`): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (dict): Data from dataloader.
- outputs (Sequence[:obj:`SegDataSample`]]): A batch of data samples
- that contain annotations and predictions.
+ outputs (Sequence[:obj:`SegDataSample`]): Outputs from model.
"""
if self.draw is False:
return
- # There is no guarantee that the same batch of images
- # is visualized for each evaluation.
- total_curr_iter = runner.iter + batch_idx
-
- # Visualize only the first data
- img_path = outputs[0].img_path
- img_bytes = get(img_path, backend_args=self.backend_args)
- img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
- window_name = f'val_{osp.basename(img_path)}'
-
- if total_curr_iter % self.interval == 0:
- self._visualizer.add_datasample(
- window_name,
- img,
- data_sample=outputs[0],
- show=self.show,
- wait_time=self.wait_time,
- step=total_curr_iter)
+ if self.every_n_inner_iters(batch_idx, self.interval):
+ for output in outputs:
+ self._mask_nodata(output, self.IGNORE_INDEX)
+
+ img_path = output.img_path
+ img_name = osp.basename(img_path)
+
+ # Get image size from prediction
+ if hasattr(output, 'pred_sem_seg'):
+ pred_shape = output.pred_sem_seg.data.shape
+ h, w = pred_shape[-2], pred_shape[-1]
+ elif hasattr(output, 'gt_sem_seg'):
+ gt_shape = output.gt_sem_seg.data.shape
+ h, w = gt_shape[-2], gt_shape[-1]
+ else:
+ # Fallback to default size
+ h, w = 512, 512
+
+ # Create dummy image instead of loading original
+ img = self._create_dummy_image(h, w)
+
+ # Only draw prediction, not ground truth
+ self._visualizer.add_datasample(
+ img_name,
+ img,
+ data_sample=output,
+ draw_gt=False, # Don't draw GT
+ draw_pred=True, # Only draw prediction
+ show=self.show,
+ wait_time=self.wait_time,
+ step=runner.iter,
+ with_labels=False)
def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[SegDataSample]) -> None:
- """Run after every testing iterations.
+ """Run after every testing iteration.
Args:
runner (:obj:`Runner`): The runner of the testing process.
- batch_idx (int): The index of the current batch in the val loop.
+ batch_idx (int): The index of the current batch in the test loop.
data_batch (dict): Data from dataloader.
- outputs (Sequence[:obj:`SegDataSample`]): A batch of data samples
- that contain annotations and predictions.
+ outputs (Sequence[:obj:`SegDataSample`]]): Outputs from model.
"""
if self.draw is False:
return
- for data_sample in outputs:
- self._test_index += 1
+ for output in outputs:
+ self._mask_nodata(output, self.IGNORE_INDEX)
+
+ img_path = output.img_path
+ img_name = osp.basename(img_path)
- img_path = data_sample.img_path
- window_name = f'test_{osp.basename(img_path)}'
+ # Get image size from prediction
+ if hasattr(output, 'pred_sem_seg'):
+ pred_shape = output.pred_sem_seg.data.shape
+ h, w = pred_shape[-2], pred_shape[-1]
+ elif hasattr(output, 'gt_sem_seg'):
+ gt_shape = output.gt_sem_seg.data.shape
+ h, w = gt_shape[-2], gt_shape[-1]
+ else:
+ # Fallback to default size
+ h, w = 512, 512
- img_path = data_sample.img_path
- img_bytes = get(img_path, backend_args=self.backend_args)
- img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
+ # Create dummy image instead of loading original
+ img = self._create_dummy_image(h, w)
self._visualizer.add_datasample(
- window_name,
+ img_name,
img,
- data_sample=data_sample,
+ data_sample=output,
+ draw_gt=False, # Don't draw GT
+ draw_pred=True, # Only draw prediction
show=self.show,
wait_time=self.wait_time,
- step=self._test_index)
+ step=runner.iter,
+ with_labels=False) # 移除了 out_file 参数
\ No newline at end of file
diff --git a/mmseg/models/__init__.py b/mmseg/models/__init__.py
index a98951283c..e993e600c3 100644
--- a/mmseg/models/__init__.py
+++ b/mmseg/models/__init__.py
@@ -4,6 +4,7 @@
from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
build_head, build_loss, build_segmentor)
from .data_preprocessor import SegDataPreProcessor
+from .multimodal_data_preprocessor import MultiModalDataPreProcessor
from .decode_heads import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
@@ -12,5 +13,6 @@
__all__ = [
'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
- 'build_head', 'build_loss', 'build_segmentor', 'SegDataPreProcessor'
+ 'build_head', 'build_loss', 'build_segmentor', 'SegDataPreProcessor',
+ 'MultiModalDataPreProcessor'
]
diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py
index 784d3dfdb7..3838e1a584 100644
--- a/mmseg/models/backbones/__init__.py
+++ b/mmseg/models/backbones/__init__.py
@@ -24,6 +24,7 @@
from .unet import UNet
from .vit import VisionTransformer
from .vpd import VPD
+from .multimodal_swin_moe_backbone import MultiModalSwinMoE
__all__ = [
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
@@ -31,5 +32,5 @@
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN',
- 'DDRNet', 'VPD'
+ 'DDRNet', 'VPD', 'MultiModalSwinMoE'
]
diff --git a/mmseg/models/backbones/multimodal_swin_moe_backbone.py b/mmseg/models/backbones/multimodal_swin_moe_backbone.py
new file mode 100644
index 0000000000..51b738c2bb
--- /dev/null
+++ b/mmseg/models/backbones/multimodal_swin_moe_backbone.py
@@ -0,0 +1,1092 @@
+"""
+Multi-Modal Swin Transformer with MoE - MMSeg 1.x Version
+
+Migrated from mmseg 0.x to 1.x:
+- Registry: BACKBONES -> MODELS
+- BaseModule: mmcv.runner -> mmengine.model
+- load_checkpoint: mmcv.runner -> mmengine.runner
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from collections import OrderedDict
+import warnings
+
+from mmengine.model import BaseModule
+from mmengine.runner import load_checkpoint
+from mmcv.cnn import build_norm_layer
+from mmengine.utils import to_2tuple
+from timm.models.layers import DropPath
+
+from mmseg.registry import MODELS
+
+
+# ==================== Modal-Specific Patch Embedding ====================
class ModalSpecificPatchEmbed(nn.Module):
    """Modality-specific patch embedding — no zero-padding version.

    Each training modality gets its own Conv2d(stride=patch_size) +
    LayerNorm projection sized to that modality's true channel count, so
    narrow inputs are embedded directly instead of being zero-padded up to
    a shared channel width.

    Args:
        modal_configs (dict): modal name -> config dict; only the
            ``'channels'`` entry is read here.
        training_modals (Iterable[str] | None): modalities that receive a
            dedicated, trained patch embed.
        embed_dims (int): output token dimension. Default: 96.
        patch_size (int): conv kernel size and stride. Default: 4.
        norm_cfg (dict): not used in this module; kept for config
            compatibility.
    """

    def __init__(self,
                 modal_configs,
                 training_modals,
                 embed_dims=96,
                 patch_size=4,
                 norm_cfg=dict(type='LN')):
        super().__init__()
        self.modal_configs = modal_configs
        self.training_modals = set(training_modals) if training_modals else set()
        self.embed_dims = embed_dims
        self.patch_size = patch_size

        # One (conv, LN) pair per training modality. The Sequential is
        # never invoked as a whole: forward() indexes [0]/[1] separately
        # and applies the LayerNorm only after flattening to (B, N, C).
        self.modal_patch_embeds = nn.ModuleDict()
        for modal_name in self.training_modals:
            if modal_name in modal_configs:
                in_ch = modal_configs[modal_name]['channels']
                self.modal_patch_embeds[modal_name] = nn.Sequential(
                    nn.Conv2d(in_ch, embed_dims,
                              kernel_size=patch_size,
                              stride=patch_size),
                    nn.LayerNorm(embed_dims, eps=1e-6)
                )

        # Buffer so the counter follows .to(device)/state_dict round-trips.
        self.register_buffer('forward_count', torch.zeros(1, dtype=torch.long))
        self._print_init_info()

    def _print_init_info(self):
        """Log the per-modality projections and rough parameter counts."""
        print(f"\n{'=' * 80}")
        print(f"Modal-Specific Patch Embedding (NO ZERO PADDING)")
        print(f"{'=' * 80}")
        print(f"Embed dims: {self.embed_dims}")
        print(f"Patch size: {self.patch_size}")
        print(f"Training modals: {sorted(self.training_modals)}")

        total_params = 0
        for modal_name in sorted(self.modal_patch_embeds.keys()):
            in_ch = self.modal_configs[modal_name]['channels']
            # Conv weight only — conv bias and LayerNorm params are not
            # included in this estimate.
            params = in_ch * self.embed_dims * self.patch_size * self.patch_size
            total_params += params
            print(f"  {modal_name:>10s}: {in_ch:>2d}ch -> {self.embed_dims:>3d}d "
                  f"({params:>6,d} params)")

        print(f"\nTotal parameters: {total_params:,}")
        print(f"Strategy: Direct processing (NO zero-padding!)")
        print(f"{'=' * 80}\n")

    def forward(self, imgs, modal_types):
        """Embed a list of per-sample images.

        Args:
            imgs: List[Tensor] - each [C_i, H, W]; channel counts may
                differ between entries.
            modal_types: List[str] - modality name per entry.
        Returns:
            x: [B, H_p*W_p, embed_dims]
            hw_shape: (H_p, W_p) taken from the first sample — assumes all
                samples in the batch share the same spatial size.
        """
        self.forward_count += 1

        if self.forward_count == 1 and self.training:
            self._print_first_forward(imgs, modal_types)

        outputs = []
        hw_shape = None

        for img, modal in zip(imgs, modal_types):
            x_i = img.unsqueeze(0)

            if modal in self.modal_patch_embeds:
                conv = self.modal_patch_embeds[modal][0]
                norm = self.modal_patch_embeds[modal][1]

                out = conv(x_i)
                _, C, H, W = out.shape
                if hw_shape is None:
                    hw_shape = (H, W)

                # (1, C, H, W) -> (1, H*W, C), then LN over channels.
                out = out.flatten(2).transpose(1, 2)
                out = norm(out)
            else:
                if self.training:
                    raise ValueError(
                        f"Modal '{modal}' not in training modals: "
                        f"{self.training_modals}"
                    )
                else:
                    # NOTE(review): inference fallback for an unseen modal
                    # builds a freshly random-initialized conv per call, so
                    # its output is untrained and nondeterministic across
                    # runs — confirm this fallback is intentional.
                    in_ch = x_i.shape[1]
                    temp_conv = nn.Conv2d(
                        in_ch, self.embed_dims,
                        kernel_size=self.patch_size,
                        stride=self.patch_size).to(x_i.device)
                    temp_norm = nn.LayerNorm(self.embed_dims).to(x_i.device)

                    nn.init.trunc_normal_(temp_conv.weight, std=0.02)
                    nn.init.constant_(temp_conv.bias, 0)

                    out = temp_conv(x_i)
                    _, C, H, W = out.shape
                    if hw_shape is None:
                        hw_shape = (H, W)
                    out = out.flatten(2).transpose(1, 2)
                    out = temp_norm(out)

            outputs.append(out)

        x = torch.cat(outputs, dim=0)
        return x, hw_shape

    def _print_first_forward(self, imgs, modal_types):
        """One-off (first training forward) batch-composition printout."""
        from collections import Counter
        print(f"\n{'=' * 80}")
        print(f"Modal-Specific Patch Embedding - First Forward (NO PADDING)")
        print(f"{'=' * 80}")
        print(f"Batch size: {len(imgs)}")

        modal_counts = Counter(modal_types)
        print(f"\nBatch composition:")
        for modal, count in sorted(modal_counts.items()):
            idx = modal_types.index(modal)
            ch = imgs[idx].shape[0]
            print(f"  {modal}: {count} samples x {ch}ch (direct, no padding!)")
        print(f"{'=' * 80}\n")
+
+
class UnifiedPatchEmbed(nn.Module):
    """Unified patch embedding shared by all modalities (ablation baseline).

    A single Conv2d + LayerNorm projection is shared by every modality;
    inputs with fewer channels than the widest training modality are
    zero-padded up to ``max_channels`` before the shared projection.
    """

    def __init__(self,
                 modal_configs,
                 training_modals,
                 embed_dims=96,
                 patch_size=4,
                 norm_cfg=dict(type='LN')):
        super().__init__()
        self.modal_configs = modal_configs
        self.training_modals = set(training_modals) if training_modals else set()
        self.embed_dims = embed_dims
        self.patch_size = patch_size

        # The widest channel count among training modalities fixes the
        # shared conv's input width; fall back to RGB width otherwise.
        if training_modals and modal_configs:
            self.max_channels = max(
                modal_configs[m]['channels'] for m in training_modals)
        else:
            self.max_channels = 3

        self.modal_channels = {
            name: modal_configs[name]['channels']
            for name in (training_modals or [])
            if name in modal_configs
        }

        self.unified_patch_embed = nn.Sequential(
            nn.Conv2d(self.max_channels, embed_dims,
                      kernel_size=patch_size, stride=patch_size),
            nn.LayerNorm(embed_dims, eps=1e-6)
        )

        # Buffer so the counter follows device moves / state_dict.
        self.register_buffer('forward_count', torch.zeros(1, dtype=torch.long))

    def forward(self, imgs, modal_types):
        """Embed a list of per-sample images.

        Args:
            imgs: List[Tensor] - each (C_i, H, W).
            modal_types: List[str] - modality name per entry (unused for
                routing; all samples share one projection).

        Returns:
            tuple: ``(x, hw_shape)`` with ``x`` of shape
            (B, H_p*W_p, embed_dims) and ``hw_shape`` from the first sample.
        """
        self.forward_count += 1
        token_batches = []
        hw_shape = None

        # The Sequential is indexed, not called: LN must run on (B, N, C)
        # tokens, after the conv output has been flattened.
        conv = self.unified_patch_embed[0]
        norm = self.unified_patch_embed[1]

        for img, modal in zip(imgs, modal_types):
            sample = img.unsqueeze(0)
            missing = self.max_channels - sample.shape[1]
            if missing > 0:
                # Zero-pad narrow modalities up to the shared conv width.
                pad = torch.zeros(
                    1, missing, sample.shape[2], sample.shape[3],
                    device=sample.device, dtype=sample.dtype)
                sample = torch.cat([sample, pad], dim=1)

            feat = conv(sample)
            _, _, H, W = feat.shape
            if hw_shape is None:
                hw_shape = (H, W)

            tokens = feat.flatten(2).transpose(1, 2)
            token_batches.append(norm(tokens))

        return torch.cat(token_batches, dim=0), hw_shape
+
+
+# ==================== MoE Components ====================
class CosineTopKGate(nn.Module):
    """Cosine-similarity expert gate with optional per-modality bias.

    Logits are the cosine similarity between a linear projection of the
    pooled token features and a learned per-expert prototype matrix,
    scaled by a learned (clamped) temperature and optionally shifted by a
    learned per-modality bias row.

    Args:
        model_dim (int): input feature dim.
        num_experts (int): number of experts to score.
        modal_configs (dict | None): modal name -> config; enables the
            per-modality bias table when given.
        training_modals (Iterable[str] | None): modalities whose bias row
            is actually applied at forward time.
        init_t (float): initial temperature. Default: 0.5.
        use_modal_bias (bool): enable the modality bias. Default: True.
    """

    def __init__(self, model_dim, num_experts,
                 modal_configs=None, training_modals=None, init_t=0.5,
                 use_modal_bias=True):
        super().__init__()
        # Project down (at most 256 dims) before the cosine comparison.
        proj_dim = min(model_dim // 2, 256)

        # Stored as log(1/t) so the learned scale is positive after exp().
        self.temperature = nn.Parameter(
            torch.log(torch.full([1], 1.0 / init_t)),
            requires_grad=True
        )
        self.cosine_projector = nn.Linear(model_dim, proj_dim)
        self.sim_matrix = nn.Parameter(
            torch.randn(size=(proj_dim, num_experts)),
            requires_grad=True
        )
        # Upper bound on the log-scale, i.e. a max effective scale of 100.
        self.clamp_max = torch.log(torch.tensor(1. / 0.01)).item()

        self.modal_configs = modal_configs
        self.training_modals = set(training_modals) if training_modals else set()
        self.use_modal_bias = use_modal_bias

        if modal_configs is not None and use_modal_bias:
            # One learnable logit-bias row per configured modality.
            self.modal_bias = nn.Parameter(
                torch.zeros(len(modal_configs), num_experts),
                requires_grad=True
            )
            self.modal_name_to_idx = {
                name: i for i, name in enumerate(modal_configs.keys())
            }

        nn.init.normal_(self.sim_matrix, 0, 0.01)

    def forward(self, x, modal_types=None):
        """Return per-sample expert logits of shape (B, num_experts).

        ``x`` may be (B, H, W, C), (B, N, C) or already-pooled (B, C);
        token dimensions are mean-pooled away first.
        """
        if len(x.shape) == 4:
            B, H, W, C = x.shape
            x = x.reshape(B, -1, C).mean(dim=1)
        elif len(x.shape) == 3:
            x = x.mean(dim=1)

        # eps guards against NaN from 0/0 when a row of the projection
        # (or a column of sim_matrix) has zero L2 norm. This matters during
        # fine-tuning on a new modal whose patch embed is freshly
        # initialized and can emit all-zero pooled features.
        logits = torch.matmul(
            F.normalize(self.cosine_projector(x), dim=1, eps=1e-6),
            F.normalize(self.sim_matrix, dim=0, eps=1e-6)
        )

        logit_scale = torch.clamp(self.temperature, max=self.clamp_max).exp()
        logits = logits * logit_scale

        # Bias is applied only for modalities seen during training; unseen
        # modalities fall back to pure cosine logits.
        if (self.use_modal_bias and modal_types is not None
                and self.modal_configs is not None):
            modal_bias = torch.zeros_like(logits)
            for i, modal in enumerate(modal_types):
                if (modal in self.modal_name_to_idx
                        and modal in self.training_modals):
                    modal_idx = self.modal_name_to_idx[modal]
                    modal_bias[i] = self.modal_bias[modal_idx]
            logits = logits + modal_bias

        return logits
+
+
class SparseDispatcher:
    """Routes batch rows to experts according to a (B, E) gate matrix.

    Each sample is replicated once per expert with a positive gate;
    ``dispatch`` hands every expert its slice of rows and ``combine``
    scatter-adds the expert outputs back, optionally weighted by the
    gate values.
    """

    def __init__(self, num_experts, gates):
        # NaN/Inf can leak in from F.normalize on zero-norm vectors inside
        # CosineTopKGate when a freshly-initialized modal patch embed
        # produces degenerate features early in fine-tuning. Replace them
        # with 0 so those samples are simply not routed anywhere.
        if not torch.isfinite(gates).all():
            gates = torch.nan_to_num(gates, nan=0.0, posinf=0.0, neginf=0.0)

        self._gates = gates
        self._num_experts = num_experts

        # The SAME positive mask must drive both `_batch_index` and
        # `_part_sizes`: mixing `torch.nonzero(gates)` (which counts NaN,
        # since NaN != 0) with `gates > 0` (which excludes NaN) makes the
        # split sizes disagree with the index count and crashes torch.split.
        positive = gates > 0
        sorted_experts, order = torch.nonzero(positive).sort(0)
        _, self._expert_index = sorted_experts.split(1, dim=1)
        self._batch_index = sorted_experts[order[:, 1], 0]
        self._part_sizes = list(positive.sum(0).cpu().numpy())

        replicated_gates = gates[self._batch_index.flatten()]
        self._nonzero_gates = torch.gather(
            replicated_gates, 1, self._expert_index)

    def dispatch(self, inp):
        """Split replicated input rows into one tensor per expert."""
        replicated = inp[self._batch_index].squeeze(1)
        return torch.split(replicated, self._part_sizes, dim=0)

    def combine(self, expert_out, multiply_by_gates=True):
        """Scatter-add expert outputs back to their source rows."""
        stitched = torch.cat(expert_out, 0)
        if multiply_by_gates:
            stitched = stitched.mul(self._nonzero_gates)
        zeros = torch.zeros(
            self._gates.size(0), expert_out[-1].size(1),
            requires_grad=True, device=stitched.device
        )
        return zeros.index_add(0, self._batch_index, stitched.float())
+
+
class SwinFFN(nn.Module):
    """Standard two-layer Swin MLP (one FFN "expert").

    Linear -> activation -> dropout -> Linear -> dropout, preserving the
    input feature dimension.
    """

    def __init__(self, in_features, hidden_features,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
+
+
class SwinMoELayer(nn.Module):
    """Mixture-of-Experts FFN for a Swin block.

    Routes each SAMPLE (not each token) to ``top_k`` of ``num_experts``
    :class:`SwinFFN` experts via a cosine gate, optionally adds an
    always-on dense "shared expert" branch, and reports a load-balancing
    auxiliary loss.

    Args:
        in_features (int): token dim C.
        hidden_features (int): hidden dim of each routed expert.
        num_experts (int): number of routed experts. Default: 8.
        num_shared_experts (int): width multiplier of the dense shared
            branch; 0 disables it. Default: 0.
        top_k (int): experts selected per sample. Default: 2.
        noisy_gating (bool): add learned input-dependent Gaussian noise to
            train-time gate logits. Default: True.
        modal_configs / training_modals / use_modal_bias: forwarded to
            :class:`CosineTopKGate`.
        act_layer, drop: forwarded to each :class:`SwinFFN`.
        return_expert_outputs (bool): during training, additionally return
            pooled per-expert features (for auxiliary losses).
    """

    def __init__(self,
                 in_features,
                 hidden_features,
                 num_experts=8,
                 num_shared_experts=0,
                 top_k=2,
                 noisy_gating=True,
                 modal_configs=None,
                 training_modals=None,
                 use_modal_bias=True,
                 act_layer=nn.GELU,
                 drop=0.,
                 return_expert_outputs=False):
        super().__init__()
        self.num_experts = num_experts
        self.num_shared_experts = num_shared_experts
        self.top_k = top_k
        self.noisy_gating = noisy_gating
        self.return_expert_outputs = return_expert_outputs

        self.experts = nn.ModuleList([
            SwinFFN(in_features, hidden_features, act_layer, drop)
            for _ in range(num_experts)
        ])

        self.gating = CosineTopKGate(
            in_features, num_experts,
            modal_configs, training_modals,
            use_modal_bias=use_modal_bias
        )

        if noisy_gating:
            # Per-expert linear map producing the noise std-dev from the
            # pooled features (Shazeer et al.-style noisy top-k gating).
            self.w_noise = nn.Parameter(
                torch.zeros(in_features, num_experts),
                requires_grad=True
            )

        if num_shared_experts > 0:
            # One dense FFN whose hidden width scales with the number of
            # shared experts; applied to every sample unconditionally.
            shared_hidden = hidden_features * num_shared_experts
            self.shared_experts = SwinFFN(
                in_features, shared_hidden, act_layer, drop
            )
        else:
            self.shared_experts = None

        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(-1)

    def cv_squared(self, x):
        """Squared coefficient of variation of ``x`` (0 for one element).

        Small when expert importance/load is evenly distributed, so it
        serves as the load-balancing penalty.
        """
        eps = 1e-10
        # NOTE: returns a shape-(1,) tensor on this branch vs a 0-dim
        # tensor on the general branch.
        if x.shape[0] == 1:
            return torch.tensor([0.0], device=x.device)
        return x.float().var() / (x.float().mean() ** 2 + eps)

    def noisy_top_k_gating(self, x, modal_types=None, noise_epsilon=1e-2):
        """Compute sparse (B, E) gates and the per-expert selection count.

        ``x`` is (B, N, C); tokens are mean-pooled first so routing is
        per-sample. The gate pools again internally, which is a no-op on
        the already-2D input.
        """
        x_pooled = x.mean(dim=1)
        clean_logits = self.gating(x_pooled, modal_types)

        # Train-time exploration: input-dependent Gaussian logit noise
        # with a softplus-positive std-dev floored at noise_epsilon.
        if self.noisy_gating and self.training:
            raw_noise_stddev = x_pooled @ self.w_noise
            noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
            noisy_logits = clean_logits + (
                torch.randn_like(clean_logits) * noise_stddev)
            logits = noisy_logits
        else:
            logits = clean_logits

        # Keep only the top-k logits, softmax over just those, and scatter
        # back so every non-selected expert gets an exact 0 gate.
        top_logits, top_indices = logits.topk(
            min(self.top_k, self.num_experts), dim=-1
        )
        top_k_gates = self.softmax(top_logits)

        zeros = torch.zeros_like(logits, requires_grad=True)
        gates = zeros.scatter(-1, top_indices, top_k_gates)

        # load[e] = number of samples that selected expert e.
        load = (gates > 0).sum(0)
        return gates, load

    def forward(self, x, modal_types=None, loss_coef=1e-2):
        """Run MoE FFN on (B, N, C) tokens.

        Returns:
            tuple: ``(y, balance_loss, expert_features)`` where
            ``expert_features`` is a list of pooled per-expert features
            (training with ``return_expert_outputs``) or ``None``.
        """
        identity = x
        B, N, C = x.shape

        gates, load = self.noisy_top_k_gating(x, modal_types)
        importance = gates.sum(0)

        # Penalize imbalance in both gate mass (importance) and selection
        # count (load) across experts.
        balance_loss = (self.cv_squared(importance)
                        + self.cv_squared(load.float()))
        balance_loss = balance_loss * loss_coef

        dispatcher = SparseDispatcher(self.num_experts, gates)

        # Flatten tokens so the dispatcher routes whole samples as rows.
        x_for_dispatch = x.reshape(B, -1)
        expert_inputs = dispatcher.dispatch(x_for_dispatch)

        expert_outputs = []
        expert_features_list = []

        for i, expert_input in enumerate(expert_inputs):
            if expert_input.size(0) > 0:
                n_samples = expert_input.size(0)
                # Restore (n, N, C) token shape for the expert FFN.
                expert_input_reshaped = expert_input.reshape(n_samples, N, C)

                expert_out = self.experts[i](expert_input_reshaped)

                if self.return_expert_outputs and self.training:
                    # Pool to (n, C) and expand to (n, C, 1, 1) so the
                    # features slot into conv-style auxiliary heads.
                    expert_feat_pooled = expert_out.mean(dim=1)
                    expert_feat_4d = expert_feat_pooled.unsqueeze(-1).unsqueeze(-1)
                    expert_features_list.append(expert_feat_4d)

                expert_out = expert_out.reshape(n_samples, -1)
                expert_outputs.append(expert_out)
            else:
                # Keep an empty placeholder so combine() sees one entry
                # per expert with the right row width.
                expert_outputs.append(
                    torch.empty(0, N * C, device=x.device))

        y_routed = dispatcher.combine(expert_outputs)
        y_routed = y_routed.reshape(B, N, C)

        if self.shared_experts is not None:
            y_shared = self.shared_experts(identity)
        else:
            y_shared = 0

        y = y_routed + y_shared

        if (self.return_expert_outputs and self.training
                and len(expert_features_list) > 0):
            if self.shared_experts is not None:
                y_shared_pooled = y_shared.mean(dim=1)
                y_shared_4d = y_shared_pooled.unsqueeze(-1).unsqueeze(-1)
                expert_features_list.append(y_shared_4d)

            return y, balance_loss, expert_features_list
        else:
            return y, balance_loss, None
+
+
+# ==================== Swin Transformer Components ====================
class WindowMSA(nn.Module):
    """Window-based multi-head self-attention with relative position bias.

    Attention is computed independently inside each (Wh, Ww) window and a
    learned relative-position bias is added to the attention logits
    (Swin Transformer style).

    Args:
        dim (int): channel dimension.
        window_size (tuple[int, int]): (Wh, Ww) of the attention window.
        num_heads (int): number of attention heads.
        qkv_bias (bool): bias on the qkv projection. Default: True.
        qk_scale (float | None): override for the 1/sqrt(head_dim) scale.
        attn_drop (float): dropout on the attention map. Default: 0.
        proj_drop (float): dropout on the output projection. Default: 0.
    """

    def __init__(self, dim, window_size, num_heads,
                 qkv_bias=True, qk_scale=None,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # One learnable bias per relative (dy, dx) offset per head:
        # (2*Wh-1) * (2*Ww-1) distinct offsets.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(
                (2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                num_heads)
        )

        # Precompute, for every ordered pair of positions in a window, the
        # flat index into the bias table.
        # NOTE(review): torch.meshgrid is called without `indexing=`; the
        # default is 'ij' (matching the intent) but newer torch emits a
        # deprecation warning — consider passing indexing='ij'.
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = (coords_flatten[:, :, None]
                           - coords_flatten[:, None, :])
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # Shift offsets to be non-negative, then fold (dy, dx) into one id.
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer(
            "relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Attend within windows.

        Args:
            x (Tensor): (num_windows*B, N, C) with N = Wh*Ww.
            mask (Tensor | None): (num_windows, N, N) additive mask
                (0 / -100) used for shifted windows.

        Returns:
            Tensor: (num_windows*B, N, C).
        """
        B_, N, C = x.shape
        # (B_, N, 3C) -> (3, B_, heads, N, head_dim).
        qkv = self.qkv(x).reshape(
            B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # Gather the (N, N, heads) bias from the table and add it to every
        # window's attention logits.
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1], -1
        )
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            # NOTE(review): when B_ is not a multiple of nW the mask is
            # silently skipped — confirm callers guarantee divisibility.
            if B_ % nW == 0:
                attn = attn.view(B_ // nW, nW, self.num_heads, N, N)
                attn = attn + mask.unsqueeze(1).unsqueeze(0)
                attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
+
+
def window_partition(x, window_size):
    """Split a (B, H, W, C) map into (num_windows*B, ws, ws, C) tiles.

    Windows are ordered row-major over the (H/ws, W/ws) grid, batch-major
    overall — the inverse of :func:`window_reverse`.
    """
    B, H, W, C = x.shape
    grid_h, grid_w = H // window_size, W // window_size
    tiles = x.view(B, grid_h, window_size, grid_w, window_size, C)
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, C)
+
+
def window_reverse(windows, window_size, H, W):
    """Reassemble (num_windows*B, ws, ws, C) tiles into (B, H, W, C).

    Inverse of :func:`window_partition` for the same window_size/H/W.
    """
    windows_per_image = (H // window_size) * (W // window_size)
    B = int(windows.shape[0] / windows_per_image)
    x = windows.view(B, H // window_size, W // window_size,
                     window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+
+
class SwinBlockWithMoE(nn.Module):
    """Swin Transformer block whose FFN can be a sparse MoE layer.

    Standard Swin block: LN -> (shifted-)window attention -> residual,
    then LN -> FFN/MoE -> residual. When ``use_moe`` is True the MLP is a
    :class:`SwinMoELayer` and the block additionally returns the MoE
    balance loss (and, optionally, per-expert features).

    Args:
        dim (int): channel dimension.
        num_heads (int): attention heads.
        window_size (int): attention window size. Default: 7.
        shift_size (int): cyclic shift, must satisfy 0 <= shift < window.
        mlp_ratio (float): FFN hidden ratio. Default: 4.
        qkv_bias, qk_scale, drop, attn_drop, drop_path: standard Swin
            regularization knobs.
        act_layer, norm_layer: activation / norm constructors.
        use_moe (bool): use an MoE FFN instead of a dense one.
        num_experts, num_shared_experts, top_k, noisy_gating,
        modal_configs, training_modals, use_modal_bias,
        return_expert_outputs: forwarded to :class:`SwinMoELayer`.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=7,
                 shift_size=0,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 use_moe=False,
                 num_experts=8,
                 num_shared_experts=0,
                 top_k=2,
                 noisy_gating=True,
                 modal_configs=None,
                 training_modals=None,
                 use_modal_bias=True,
                 return_expert_outputs=False):
        super().__init__()

        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_moe = use_moe

        assert 0 <= self.shift_size < self.window_size, \
            "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowMSA(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop
        )

        self.drop_path = (DropPath(drop_path)
                          if drop_path > 0. else nn.Identity())
        self.norm2 = norm_layer(dim)

        mlp_hidden_dim = int(dim * mlp_ratio)

        if use_moe:
            self.mlp = SwinMoELayer(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                num_experts=num_experts,
                num_shared_experts=num_shared_experts,
                top_k=top_k,
                noisy_gating=noisy_gating,
                modal_configs=modal_configs,
                training_modals=training_modals,
                use_modal_bias=use_modal_bias,
                act_layer=act_layer,
                drop=drop,
                return_expert_outputs=return_expert_outputs
            )
        else:
            self.mlp = SwinFFN(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                drop=drop
            )

        # Lazily-built SW-MSA mask, cached per padded resolution.
        self.register_buffer("attn_mask", None)
        # Padded (Hp, Wp) the cached mask was built for.
        self._mask_hw = None

    def forward(self, x, hw_shape, modal_types=None):
        """Run the block on tokens ``x`` of shape (B, H*W, C).

        Args:
            x (Tensor): (B, H*W, C) token sequence.
            hw_shape (tuple[int, int]): (H, W) of the token grid.
            modal_types (list[str] | None): per-sample modality names,
                forwarded to the MoE gate.

        Returns:
            tuple: ``(x, moe_loss, expert_features)``; the last two are
            ``None`` when ``use_moe`` is False.
        """
        H, W = hw_shape
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Pad bottom/right so both dims are multiples of the window size.
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        if self.shift_size > 0:
            shifted_x = torch.roll(
                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            # BUGFIX: the mask used to be computed once and cached forever,
            # so a later forward at a different (padded) resolution reused
            # a stale, size-mismatched mask. Rebuild whenever (Hp, Wp)
            # changes.
            if self.attn_mask is None or self._mask_hw != (Hp, Wp):
                self.attn_mask = self._calculate_mask(Hp, Wp).to(x.device)
                self._mask_hw = (Hp, Wp)
            attn_mask = self.attn_mask
        else:
            shifted_x = x
            attn_mask = None

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(
            -1, self.window_size * self.window_size, C)

        attn_windows = self.attn(x_windows, mask=attn_mask)

        attn_windows = attn_windows.view(
            -1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)

        # Undo the cyclic shift.
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x,
                shifts=(self.shift_size, self.shift_size),
                dims=(1, 2))
        else:
            x = shifted_x

        # Drop the padded rows/cols before the residual add.
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        x = shortcut + self.drop_path(x)

        moe_loss = None
        expert_features = None

        if self.use_moe:
            mlp_out = self.mlp(self.norm2(x), modal_types)
            if len(mlp_out) == 3:
                mlp_x, moe_loss, expert_features = mlp_out
            else:
                mlp_x, moe_loss = mlp_out
            x = x + self.drop_path(mlp_x)
        else:
            x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x, moe_loss, expert_features

    def _calculate_mask(self, H, W):
        """Build the (nW, N, N) additive attention mask for SW-MSA.

        Positions belonging to different pre-shift regions within the same
        window get -100 so softmax suppresses their interaction.
        """
        img_mask = torch.zeros((1, H, W, 1))
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.view(
            -1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0))
        attn_mask = attn_mask.masked_fill(attn_mask == 0, float(0.0))
        return attn_mask
+
+
class PatchMerging(nn.Module):
    """Patch Merging Layer.

    Concatenates each 2x2 neighbourhood of tokens (giving 4C channels),
    applies LayerNorm and linearly projects back to 2C, halving the
    spatial resolution.
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, hw_shape):
        """Merge patches.

        Args:
            x (Tensor): (B, H*W, C) token sequence.
            hw_shape (tuple): Current spatial size (H, W); both even.

        Returns:
            tuple: ((B, H*W/4, 2C) tensor, (H // 2, W // 2)).
        """
        H, W = hw_shape
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, \
            f"x size ({H}*{W}) are not even."

        grid = x.view(B, H, W, C)

        # Gather the four interleaved sub-grids of every 2x2 block and
        # concatenate them along the channel axis (order: TL, BL, TR, BR).
        quadrants = [grid[:, r::2, c::2, :]
                     for r, c in ((0, 0), (1, 0), (0, 1), (1, 1))]
        merged = torch.cat(quadrants, dim=-1).view(B, -1, 4 * C)

        merged = self.norm(merged)
        merged = self.reduction(merged)

        return merged, (H // 2, W // 2)
+
+
+# ==================== Main Backbone ====================
@MODELS.register_module()
class MultiModalSwinMoE(BaseModule):
    """Multi-Modal Swin Transformer with MoE - MMSeg 1.x.

    A Swin backbone whose FFN layers can be replaced per-block by
    modality-aware mixture-of-experts layers.  Consumes a list of
    per-sample images (channel counts may differ per modality) plus a
    parallel list of modality tags, and returns multi-scale features
    together with MoE auxiliary outputs.

    Args:
        modal_configs (dict | None): Per-modality configuration consumed by
            the patch embedding and MoE FFNs.
        training_modals (list[str] | None): Modalities to train on; defaults
            to all keys of ``modal_configs``.
        depths (list[int]): Number of blocks per stage.
        num_heads (list[int]): Attention heads per stage.
        use_moe (bool): Globally enable MoE FFNs.
        MoE_Block_inds (list[list[int]] | None): Per-stage block indices
            that use MoE; defaults to the second half of each stage.
        num_shared_experts_config (dict | None): Stage index -> number of
            shared experts. Default ``{0: 0, 1: 0, 2: 2, 3: 1}``.
        use_expert_diversity_loss (bool): Collect last-stage expert outputs
            for the diversity loss.
        frozen_stages (int): Freeze stages ``[0..frozen_stages]``;
            -1 disables freezing.
        freeze_patch_embed (bool): Also freeze the stem when True.
        pretrained (str | None): Checkpoint path loaded by ``init_weights``.
    """

    def __init__(self,
                 modal_configs=None,
                 training_modals=None,
                 pretrain_img_size=224,
                 patch_size=4,
                 embed_dims=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 patch_norm=True,
                 out_indices=[0, 1, 2, 3],
                 use_moe=True,
                 use_modal_bias=True,
                 num_experts=8,
                 num_shared_experts_config=None,
                 top_k=2,
                 noisy_gating=True,
                 MoE_Block_inds=None,
                 use_expert_diversity_loss=False,
                 use_modal_specific_stem=True,
                 frozen_stages=-1,
                 freeze_patch_embed=False,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.frozen_stages = frozen_stages
        self.freeze_patch_embed = freeze_patch_embed

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dims = embed_dims
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.depths = depths
        self.num_heads = num_heads
        self.use_moe = use_moe
        self.use_expert_diversity_loss = use_expert_diversity_loss

        self.modal_configs = modal_configs
        if modal_configs is not None:
            if training_modals is None:
                # Default to training on every configured modality.
                self.training_modals = list(modal_configs.keys())
            else:
                self.training_modals = training_modals
        else:
            self.training_modals = []

        if MoE_Block_inds is None:
            # Default: the second half of each stage's blocks use MoE.
            MoE_Block_inds = []
            for depth in depths:
                start_idx = depth // 2
                MoE_Block_inds.append(list(range(start_idx, depth)))
        self.MoE_Block_inds = MoE_Block_inds

        if num_shared_experts_config is None:
            num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1}
        self.num_shared_experts_config = num_shared_experts_config

        self.use_modal_specific_stem = use_modal_specific_stem
        if use_modal_specific_stem:
            # One stem per modality, handling differing channel counts.
            self.patch_embed = ModalSpecificPatchEmbed(
                modal_configs=modal_configs if modal_configs else {},
                training_modals=self.training_modals,
                embed_dims=embed_dims,
                patch_size=patch_size,
                norm_cfg=dict(type='LN') if patch_norm else None
            )
        else:
            self.patch_embed = UnifiedPatchEmbed(
                modal_configs=modal_configs if modal_configs else {},
                training_modals=self.training_modals,
                embed_dims=embed_dims,
                patch_size=patch_size,
                norm_cfg=dict(type='LN') if patch_norm else None
            )

        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth: linearly increasing drop-path over all blocks.
        dpr = [x.item() for x in torch.linspace(
            0, drop_path_rate, sum(depths))]

        self.stages = nn.ModuleList()
        # Channel count doubles at every stage.
        num_features = [int(embed_dims * 2 ** i)
                        for i in range(self.num_layers)]

        for i_stage in range(self.num_layers):
            stage_blocks = []
            dim = num_features[i_stage]
            num_shared = num_shared_experts_config.get(i_stage, 0)

            for i_block in range(depths[i_stage]):
                use_moe_block = (use_moe
                                 and (i_block in MoE_Block_inds[i_stage]))

                block = SwinBlockWithMoE(
                    dim=dim,
                    num_heads=num_heads[i_stage],
                    window_size=window_size,
                    # Alternate between W-MSA (even) and SW-MSA (odd).
                    shift_size=(0 if (i_block % 2 == 0)
                                else window_size // 2),
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[sum(depths[:i_stage]) + i_block],
                    norm_layer=norm_layer,
                    use_moe=use_moe_block,
                    num_experts=num_experts,
                    num_shared_experts=(num_shared
                                        if use_moe_block else 0),
                    top_k=top_k,
                    noisy_gating=noisy_gating,
                    modal_configs=(modal_configs
                                   if use_moe_block else None),
                    training_modals=(self.training_modals
                                     if use_moe_block else None),
                    use_modal_bias=use_modal_bias,
                    return_expert_outputs=use_expert_diversity_loss
                )
                stage_blocks.append(block)

            # No downsampling after the last stage.
            if i_stage < self.num_layers - 1:
                downsample = PatchMerging(dim=dim, norm_layer=norm_layer)
            else:
                downsample = None

            self.stages.append(nn.ModuleDict({
                'blocks': nn.ModuleList(stage_blocks),
                'downsample': downsample
            }))

        # Per-output-stage norm layers, registered as 'norm{i}'.
        for i in out_indices:
            layer = norm_layer(num_features[i])
            layer_name = f'norm{i}'
            self.add_module(layer_name, layer)

        self.num_features = num_features
        self.pretrained = pretrained
        self._init_weights()
        self._freeze_stages()
        self._print_architecture_info()

    def _freeze_stages(self):
        """Freeze stages to stop gradient and set eval mode.

        frozen_stages (int): Stages to freeze. -1 means no freezing.
            0 freezes stage 0, 1 freezes stages 0-1, etc.
        freeze_patch_embed (bool): If True, also freeze patch_embed (stem).
            Default False (stem remains trainable for domain adaptation).
        """
        if self.freeze_patch_embed:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        for i in range(0, self.frozen_stages + 1):
            if i >= self.num_layers:
                break
            # Freeze stage blocks
            stage = self.stages[i]
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False

            # Freeze corresponding output norm layer
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                norm_layer.eval()
                for param in norm_layer.parameters():
                    param.requires_grad = False

        if self.frozen_stages >= 0:
            # Counts parameter *tensors*, not scalar elements.
            frozen_params = sum(
                1 for p in self.parameters() if not p.requires_grad)
            total_params = sum(1 for _ in self.parameters())
            print(f'\n[Freeze] frozen_stages={self.frozen_stages}, '
                  f'freeze_patch_embed={self.freeze_patch_embed}')
            print(f'[Freeze] {frozen_params}/{total_params} params frozen\n')

    def _init_weights(self):
        """Random init: trunc-normal Linear weights, unit LayerNorm.

        Runs at construction time; may later be overwritten by
        ``init_weights`` when a pretrained checkpoint is configured.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    def init_weights(self):
        """Load pretrained Swin weights (non-strict), if configured."""
        if self.pretrained is not None:
            print(f'Loading pretrained Swin weights from {self.pretrained}')
            load_checkpoint(
                self, self.pretrained,
                map_location='cpu',
                strict=False
            )

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Drop cached 'attn_mask' buffers from saved checkpoints.

        The masks are shape-dependent caches that are recomputed on the
        fly, so persisting them is pointless and size-fragile.
        NOTE(review): positional arguments to ``super().state_dict()`` are
        deprecated in newer PyTorch — confirm against the torch version
        in use.
        """
        state_dict = super().state_dict(destination, prefix, keep_vars)
        keys_to_remove = [k for k in state_dict.keys() if 'attn_mask' in k]
        for k in keys_to_remove:
            state_dict.pop(k)
        return state_dict

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Ignore 'attn_mask' entries in incoming checkpoints.

        Mirrors the filtering applied when saving: cached masks are
        recomputed at runtime and must not be restored.
        """
        attn_mask_keys = [k for k in state_dict.keys() if 'attn_mask' in k]
        for k in attn_mask_keys:
            state_dict.pop(k)
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs
        )

    def _print_architecture_info(self):
        """Print a human-readable summary of the configuration."""
        print(f"\n{'=' * 80}")
        print(f"Multi-Modal Swin Transformer with MoE")
        print(f"{'=' * 80}")
        print(f"Training modals: {self.training_modals}")
        print(f"Embed dims: {self.embed_dims}")
        print(f"Depths: {self.depths}")
        print(f"Use MoE: {self.use_moe}")

        print(f"\nShared Experts:")
        for i in range(self.num_layers):
            num_moe = len(self.MoE_Block_inds[i])
            num_shared = self.num_shared_experts_config.get(i, 0)
            print(f"  Stage {i}: {self.depths[i]} blocks, "
                  f"{num_moe} MoE blocks, {num_shared} shared experts")

        print(f"\nOutput indices: {self.out_indices}")
        print(f"Output features: "
              f"{[self.num_features[i] for i in self.out_indices]}")
        print(f"{'=' * 80}\n")

    def train(self, mode=True):
        """Override train to keep frozen stages in eval mode."""
        super().train(mode)
        self._freeze_stages()

    def forward(self, imgs, modal_types=None, **kwargs):
        """
        Args:
            imgs: List[Tensor] - each [C_i, H, W]
            modal_types: List[str]
        Returns:
            tuple of features, moe_balance_loss, expert_features
        """
        B = len(imgs)

        if modal_types is None:
            # Default every sample to RGB when no tags are supplied.
            modal_types = ['rgb'] * B

        x, hw_shape = self.patch_embed(imgs, modal_types)
        x = self.pos_drop(x)

        outs = []
        moe_balance_losses = []
        all_expert_features = []

        for i_stage, stage_dict in enumerate(self.stages):
            blocks = stage_dict['blocks']
            downsample = stage_dict['downsample']

            for block in blocks:
                x, moe_loss, expert_features = block(
                    x, hw_shape, modal_types)

                if moe_loss is not None:
                    moe_balance_losses.append(moe_loss)

                # Expert features are only collected from the last stage;
                # the diversity loss is applied there only.
                if (expert_features is not None
                        and self.use_expert_diversity_loss
                        and i_stage == self.num_layers - 1):
                    all_expert_features.extend(expert_features)

            if i_stage in self.out_indices:
                norm_layer = getattr(self, f'norm{i_stage}')
                x_out = norm_layer(x)

                # (B, H*W, C) -> (B, C, H, W) for the decode heads.
                H, W = hw_shape
                x_out = x_out.view(B, H, W, -1).permute(
                    0, 3, 1, 2).contiguous()
                outs.append(x_out)

            if downsample is not None:
                x, hw_shape = downsample(x, hw_shape)

        # Average the per-block router balance losses.
        avg_balance_loss = None
        if len(moe_balance_losses) > 0:
            avg_balance_loss = (sum(moe_balance_losses)
                                / len(moe_balance_losses))

        return tuple(outs), avg_balance_loss, all_expert_features
diff --git a/mmseg/models/losses/__init__.py b/mmseg/models/losses/__init__.py
index 0467cb3ad8..564c95039c 100644
--- a/mmseg/models/losses/__init__.py
+++ b/mmseg/models/losses/__init__.py
@@ -10,6 +10,7 @@
from .ohem_cross_entropy_loss import OhemCrossEntropy
from .silog_loss import SiLogLoss
from .tversky_loss import TverskyLoss
+from .moe_losses import CombinedMoELoss, ExpertDiversityLoss, MoEBalanceLoss
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
__all__ = [
@@ -17,5 +18,6 @@
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss',
- 'HuasdorffDisstanceLoss', 'SiLogLoss'
+ 'HuasdorffDisstanceLoss', 'SiLogLoss',
+ 'MoEBalanceLoss', 'ExpertDiversityLoss', 'CombinedMoELoss'
]
diff --git a/mmseg/models/losses/moe_losses.py b/mmseg/models/losses/moe_losses.py
new file mode 100644
index 0000000000..6901e3aec8
--- /dev/null
+++ b/mmseg/models/losses/moe_losses.py
@@ -0,0 +1,110 @@
+"""
+MoE-related loss functions - MMSeg 1.x Version
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
class MoEBalanceLoss(nn.Module):
    """MoE load balancing loss wrapper.

    Rescales the balance loss produced by the backbone's routers; a zero
    scalar is returned when no balance loss is available.
    """

    def __init__(self, loss_weight=1.0):
        super().__init__()
        self.loss_weight = loss_weight

    def forward(self, balance_loss):
        """Scale ``balance_loss`` by ``loss_weight`` (0 when it is None)."""
        if balance_loss is None:
            return torch.tensor(0.0)
        return balance_loss * self.loss_weight
+
+
class ExpertDiversityLoss(nn.Module):
    """Expert diversity loss - encourages different expert representations.

    Penalises the average pairwise similarity between the pooled,
    batch-averaged output vectors of different experts.
    """

    def __init__(self,
                 loss_weight=0.01,
                 similarity_type='cosine',
                 normalize=True):
        super().__init__()
        self.loss_weight = loss_weight
        self.similarity_type = similarity_type
        self.normalize = normalize

    def forward(self, expert_features_list):
        """Compute the diversity penalty.

        Args:
            expert_features_list (list[Tensor] | None): One (B, C, H, W)
                feature map per expert; fewer than two experts yields 0.

        Returns:
            Tensor: Scalar loss (weighted, ReLU-clipped average pairwise
            similarity).
        """
        if expert_features_list is None or len(expert_features_list) < 2:
            dev = (expert_features_list[0].device
                   if expert_features_list else None)
            return torch.tensor(0.0, device=dev)

        # Collapse each expert's feature map to a single C-dim vector:
        # spatial average pool, optional L2 normalisation, batch mean.
        vectors = []
        for feat in expert_features_list:
            pooled = F.adaptive_avg_pool2d(feat, 1).squeeze(-1).squeeze(-1)
            if self.normalize:
                pooled = F.normalize(pooled, dim=1)
            vectors.append(pooled.sum(dim=0) / max(pooled.shape[0], 1))

        pair_sims = []
        for i in range(len(vectors)):
            for j in range(i + 1, len(vectors)):
                if self.similarity_type == 'cosine':
                    pair_sims.append(F.cosine_similarity(
                        vectors[i].unsqueeze(0),
                        vectors[j].unsqueeze(0),
                        dim=1))
                elif self.similarity_type == 'l2':
                    # Map an L2 distance into a [0, 1] similarity.
                    gap = torch.norm(vectors[i] - vectors[j], p=2)
                    pair_sims.append(
                        torch.clamp(1.0 - gap / 2.0, min=0.0, max=1.0))
                else:
                    raise ValueError(
                        f"Unknown similarity type: {self.similarity_type}")

        if pair_sims:
            avg_sim = sum(pair_sims) / len(pair_sims)
        else:
            avg_sim = torch.tensor(0.0, device=vectors[0].device)

        # Only positive (similar) pairs are penalised.
        return self.loss_weight * F.relu(avg_sim)
+
+
class CombinedMoELoss(nn.Module):
    """Combined MoE loss: balance + diversity.

    Thin aggregator over ``MoEBalanceLoss`` and ``ExpertDiversityLoss``;
    each sub-loss carries its own weight.
    """

    def __init__(self,
                 balance_weight=1.0,
                 diversity_weight=0.01,
                 diversity_similarity='cosine'):
        super().__init__()
        self.balance_loss = MoEBalanceLoss(loss_weight=balance_weight)
        self.diversity_loss = ExpertDiversityLoss(
            loss_weight=diversity_weight,
            similarity_type=diversity_similarity)

    def forward(self, balance_loss, expert_features_list):
        """Return ``(total_loss, per-component dict)``."""
        parts = {
            'loss_moe_balance': self.balance_loss(balance_loss),
            'loss_moe_diversity': self.diversity_loss(expert_features_list),
        }
        total = parts['loss_moe_balance'] + parts['loss_moe_diversity']
        return total, parts
diff --git a/mmseg/models/multimodal_data_preprocessor.py b/mmseg/models/multimodal_data_preprocessor.py
new file mode 100644
index 0000000000..798efda31c
--- /dev/null
+++ b/mmseg/models/multimodal_data_preprocessor.py
@@ -0,0 +1,83 @@
+"""
+Multi-Modal Data PreProcessor - handles images with different channel counts.
+
+Unlike SegDataPreProcessor which stacks all images into a [B,C,H,W] tensor,
+this preprocessor keeps them as a list of [C_i,H,W] tensors since different
+modalities have different channel counts (e.g., SAR:8ch, RGB:3ch, GF:5ch).
+"""
+from numbers import Number
+from typing import Any, Dict, List, Optional, Sequence
+
+import torch
+import torch.nn.functional as F
+from mmengine.model import BaseDataPreprocessor
+
+from mmseg.registry import MODELS
+
+
@MODELS.register_module()
class MultiModalDataPreProcessor(BaseDataPreprocessor):
    """Data pre-processor for multi-modal segmentation.

    Unlike SegDataPreProcessor, this does NOT stack inputs into a single
    tensor because different modalities may have different channel counts.
    Instead, it returns inputs as a list of tensors.

    Note that no mean/std normalization is applied here: inputs are only
    cast to float and padded.  NOTE(review): presumably normalization
    happens elsewhere in the pipeline — confirm.

    Args:
        size (tuple, optional): Fixed padding size (H, W).
        pad_val (float): Padding value for images. Default: 0.
        seg_pad_val (float): Padding value for segmentation maps. Default: 255.
    """

    def __init__(
        self,
        size: Optional[tuple] = None,
        pad_val: Number = 0,
        seg_pad_val: Number = 255,
    ):
        super().__init__()
        self.size = size
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val

    def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
        """Cast inputs to float and pad images/GT up to ``self.size``.

        Args:
            data (dict): Contains 'inputs' (list of (C_i, H, W) tensors)
                and optionally 'data_samples'.
            training (bool): Unused; padding is identical in both modes.

        Returns:
            dict: 'inputs' as a list (NOT a stacked tensor) plus the
            updated data_samples.
        """
        data = self.cast_data(data)
        inputs = data['inputs']
        data_samples = data.get('data_samples', None)

        inputs = [_input.float() for _input in inputs]

        # Pad spatial dimensions only (no stacking across channels)
        padded_inputs = []
        for i, tensor in enumerate(inputs):
            if self.size is not None:
                # Bottom/right padding up to the fixed target size.
                width = max(self.size[-1] - tensor.shape[-1], 0)
                height = max(self.size[-2] - tensor.shape[-2], 0)
                padding_size = (0, width, 0, height)
            else:
                padding_size = (0, 0, 0, 0)

            pad_img = F.pad(tensor, padding_size, value=self.pad_val)
            padded_inputs.append(pad_img)

            if data_samples is not None:
                ds = data_samples[i]
                if 'gt_sem_seg' in ds:
                    gt = ds.gt_sem_seg.data
                    # del-then-set: presumably sidesteps PixelData's shape
                    # validation on re-assignment — TODO confirm against
                    # the mmengine version in use.
                    del ds.gt_sem_seg.data
                    ds.gt_sem_seg.data = F.pad(
                        gt, padding_size, value=self.seg_pad_val)
                if hasattr(ds, 'gt_boundary_map') and 'gt_boundary_map' in ds:
                    gt_b = ds.gt_boundary_map.data
                    del ds.gt_boundary_map.data
                    ds.gt_boundary_map.data = F.pad(
                        gt_b, padding_size, value=self.seg_pad_val)

                # Record pre-/post-padding shapes so inference can un-pad.
                ds.set_metainfo({
                    'img_shape': tensor.shape[-2:],
                    'pad_shape': pad_img.shape[-2:],
                    'padding_size': padding_size
                })

        # Return list (NOT stacked tensor) — the model handles the list
        return dict(inputs=padded_inputs, data_samples=data_samples)
diff --git a/mmseg/models/segmentors/__init__.py b/mmseg/models/segmentors/__init__.py
index 59b012f417..bb253edf27 100644
--- a/mmseg/models/segmentors/__init__.py
+++ b/mmseg/models/segmentors/__init__.py
@@ -5,8 +5,9 @@
from .encoder_decoder import EncoderDecoder
from .multimodal_encoder_decoder import MultimodalEncoderDecoder
from .seg_tta import SegTTAModel
+from .multimodal_encoder_decoder_v2 import MultiModalEncoderDecoderV2
__all__ = [
'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel',
- 'MultimodalEncoderDecoder', 'DepthEstimator'
+ 'MultimodalEncoderDecoder', 'DepthEstimator', 'MultiModalEncoderDecoderV2'
]
diff --git a/mmseg/models/segmentors/multimodal_encoder_decoder_v2.py b/mmseg/models/segmentors/multimodal_encoder_decoder_v2.py
new file mode 100644
index 0000000000..d6a67176ef
--- /dev/null
+++ b/mmseg/models/segmentors/multimodal_encoder_decoder_v2.py
@@ -0,0 +1,588 @@
+"""
+Multi-Modal EncoderDecoder - MMSeg 1.x Version
+
+Key changes from 0.x:
+- Registry: SEGMENTORS -> MODELS
+- Base class: 0.x EncoderDecoder -> 1.x EncoderDecoder
+- forward_train(img, img_metas, gt) -> loss(inputs, data_samples)
+- simple_test/forward_test -> predict(inputs, data_samples)
+- img_metas dict -> SegDataSample.metainfo
+- gt_semantic_seg -> data_samples[i].gt_sem_seg.data
+- train_step removed (handled by mmengine Runner)
+- _parse_losses removed (handled by mmengine)
+- DataContainer removed
+"""
+from collections import OrderedDict
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from mmseg.registry import MODELS
+from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
+ OptSampleList, SampleList, add_prefix)
+from .encoder_decoder import EncoderDecoder
+from ..losses.moe_losses import CombinedMoELoss
+
+
+@MODELS.register_module()
+class MultiModalEncoderDecoderV2(EncoderDecoder):
+ """Multi-modal encoder-decoder segmentor with MoE support.
+
+ Supports shared/separate decode heads per modality.
+
+ Args:
+ use_moe: Enable MoE in backbone.
+ use_modal_bias: Enable modal bias in MoE gating.
+ moe_balance_weight: Weight for MoE balance loss.
+ moe_diversity_weight: Weight for MoE diversity loss.
+ multi_tasks_reweight: Reweighting strategy
+ ('equal', 'uncertainty', None).
+ dataset_names: List of dataset/modality names.
+ decoder_mode: 'shared' or 'separate'.
+ """
+
    def __init__(self,
                 backbone: ConfigType,
                 decode_head: ConfigType,
                 neck: OptConfigType = None,
                 auxiliary_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 pretrained: Optional[str] = None,
                 init_cfg: OptMultiConfig = None,
                 # Custom args
                 use_moe: bool = False,
                 use_modal_bias: bool = True,
                 moe_balance_weight: float = 1.0,
                 moe_diversity_weight: float = 0.01,
                 moe_diversity_similarity: str = 'cosine',
                 multi_tasks_reweight: Optional[str] = None,
                 mtl_sigma_init: float = 1.0,
                 dataset_names: list = None,
                 decoder_mode: str = 'separate'):
        """Build the multi-modal segmentor (see class docstring for args)."""
        # These plain attributes must exist before _init_decode_head /
        # _init_auxiliary_head (and the head properties) are used below.
        self.dataset_names = dataset_names or ['sar', 'rgb', 'GF']
        self.decoder_mode = decoder_mode
        self._use_moe = use_moe
        self._use_modal_bias = use_modal_bias
        self._multi_tasks_reweight = multi_tasks_reweight
        self._mtl_sigma_init = mtl_sigma_init

        if decoder_mode not in ['shared', 'separate']:
            raise ValueError(
                f"decoder_mode must be 'shared' or 'separate', "
                f"got '{decoder_mode}'")

        # Call grandparent __init__ to skip EncoderDecoder's
        # _init_decode_head / _init_auxiliary_head
        # We need custom initialization for multi-head support
        from .base import BaseSegmentor
        BaseSegmentor.__init__(
            self,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # Legacy-style pretrained arg: injected into the backbone config.
        if pretrained is not None:
            assert backbone.get('pretrained') is None
            backbone.pretrained = pretrained

        self.backbone = MODELS.build(backbone)
        if neck is not None:
            self.neck = MODELS.build(neck)

        self._init_decode_head(decode_head)
        self._init_auxiliary_head(auxiliary_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # MoE loss
        if use_moe:
            self.moe_loss_fn = CombinedMoELoss(
                balance_weight=moe_balance_weight,
                diversity_weight=moe_diversity_weight,
                diversity_similarity=moe_diversity_similarity
            )

        # Multi-task reweighting
        self.multi_tasks_reweight = multi_tasks_reweight
        self.mtl_sigma_eps = 1e-6
        if multi_tasks_reweight == 'uncertainty':
            # One learnable sigma per modality; presumably used for
            # uncertainty-style loss weighting — confirm in
            # _apply_uncertainty_reweight.
            self.mtl_sigma = nn.Parameter(
                torch.full((len(self.dataset_names),),
                           float(mtl_sigma_init)))
+
    def _init_decode_head(self, decode_head):
        """Build decode head(s) from config.

        In 'shared' mode one head serves every modality; a per-modality
        config dict (recognised by the absence of a 'type' key) falls back
        to the first dataset's entry.  In 'separate' mode one head is built
        per dataset name, either from a per-modality config dict or by
        replicating a single config.  Also mirrors align_corners /
        num_classes / out_channels onto the segmentor for inference code.
        """
        if decode_head is None:
            return

        if self.decoder_mode == 'shared':
            # A dict without 'type' is a {dataset_name: cfg} mapping;
            # shared mode simply uses the first modality's config.
            if isinstance(decode_head, dict) and 'type' not in decode_head:
                decode_head = decode_head[self.dataset_names[0]]
            self._shared_decode_head = MODELS.build(decode_head)
            self.align_corners = self._shared_decode_head.align_corners
            self.num_classes = self._shared_decode_head.num_classes
            self.out_channels = self._shared_decode_head.out_channels
        else:
            self.decode_heads = nn.ModuleDict()
            if isinstance(decode_head, dict) and 'type' not in decode_head:
                for name in self.dataset_names:
                    self.decode_heads[name] = MODELS.build(
                        decode_head[name])
            else:
                # Single config: replicate one head per modality.
                for name in self.dataset_names:
                    self.decode_heads[name] = MODELS.build(decode_head)

            # NOTE(review): attributes are taken from the first head only;
            # assumes all per-modality heads agree on them — confirm.
            first_head = self.decode_heads[self.dataset_names[0]]
            self.align_corners = first_head.align_corners
            self.num_classes = first_head.num_classes
            self.out_channels = first_head.out_channels
+
    def _init_auxiliary_head(self, auxiliary_head):
        """Build auxiliary head(s), mirroring ``_init_decode_head``.

        Shared mode builds one head (falling back to the first dataset's
        entry for per-modality config dicts); separate mode builds one
        head per dataset name.
        """
        if auxiliary_head is None:
            return

        if self.decoder_mode == 'shared':
            # A dict without 'type' is a {dataset_name: cfg} mapping.
            if isinstance(auxiliary_head, dict) and 'type' not in auxiliary_head:
                auxiliary_head = auxiliary_head[self.dataset_names[0]]
            self._shared_auxiliary_head = MODELS.build(auxiliary_head)
        else:
            self.auxiliary_heads = nn.ModuleDict()
            if isinstance(auxiliary_head, dict) and 'type' not in auxiliary_head:
                for name in self.dataset_names:
                    self.auxiliary_heads[name] = MODELS.build(
                        auxiliary_head[name])
            else:
                # Single config: replicate one head per modality.
                for name in self.dataset_names:
                    self.auxiliary_heads[name] = MODELS.build(
                        auxiliary_head)
+
+ @property
+ def decode_head(self):
+ if self.decoder_mode == 'shared':
+ return getattr(self, '_shared_decode_head', None)
+ else:
+ if hasattr(self, 'decode_heads') and len(self.decode_heads) > 0:
+ return self.decode_heads[self.dataset_names[0]]
+ return None
+
+ @property
+ def auxiliary_head(self):
+ if self.decoder_mode == 'shared':
+ return getattr(self, '_shared_auxiliary_head', None)
+ else:
+ if (hasattr(self, 'auxiliary_heads')
+ and len(self.auxiliary_heads) > 0):
+ return self.auxiliary_heads[self.dataset_names[0]]
+ return None
+
+ @property
+ def with_decode_head(self):
+ if self.decoder_mode == 'shared':
+ return hasattr(self, '_shared_decode_head')
+ else:
+ return (hasattr(self, 'decode_heads')
+ and len(self.decode_heads) > 0)
+
+ @property
+ def with_auxiliary_head(self):
+ if self.decoder_mode == 'shared':
+ return hasattr(self, '_shared_auxiliary_head')
+ else:
+ return (hasattr(self, 'auxiliary_heads')
+ and len(self.auxiliary_heads) > 0)
+
+ def loss(self, inputs, data_samples: SampleList) -> dict:
+ """Calculate losses from a batch of inputs and data samples.
+
+ Args:
+ inputs: Input images (Tensor or List[Tensor]).
+ data_samples: list[SegDataSample] with metainfo and gt_sem_seg.
+
+ Returns:
+ dict[str, Tensor]: Loss components.
+ """
+ # Extract modal types from data_samples
+ modal_types = [
+ ds.metainfo.get('modal_type', 'rgb') for ds in data_samples
+ ]
+
+ # Extract dataset names
+ ds_names = [
+ ds.metainfo.get('dataset_name',
+ ds.metainfo.get('modal_type', 'rgb'))
+ for ds in data_samples
+ ]
+
+ # For multi-modal backbone: pass List[Tensor] + modal_types
+ if isinstance(inputs, torch.Tensor):
+ # Standard tensor - pass directly
+ # But backbone expects List[Tensor] for multi-modal
+ imgs_list = list(inputs)
+ else:
+ imgs_list = inputs
+
+ x, moe_balance_loss, expert_features = self.extract_feat(
+ imgs_list, modal_types=modal_types)
+
+ losses = dict()
+
+ # MoE losses
+ if self._use_moe and moe_balance_loss is not None:
+ if expert_features is None:
+ losses['loss_moe'] = moe_balance_loss
+ else:
+ total_moe_loss, moe_loss_dict = self.moe_loss_fn(
+ moe_balance_loss, expert_features)
+ losses.update(moe_loss_dict)
+
+ modal_losses = {}
+
+ if self.decoder_mode == 'shared':
+ self._loss_shared_mode(
+ x, data_samples, ds_names, losses, modal_losses)
+ else:
+ self._loss_separate_mode(
+ x, data_samples, ds_names, losses, modal_losses)
+
+ # Apply multi-task reweighting
+ if self.multi_tasks_reweight == 'equal' and modal_losses:
+ losses.update(self._apply_equal_reweight(modal_losses))
+ elif self.multi_tasks_reweight == 'uncertainty' and modal_losses:
+ losses.update(
+ self._apply_uncertainty_reweight(modal_losses))
+
+ return losses
+
+ def _loss_shared_mode(self, x, data_samples, ds_names,
+ losses, modal_losses):
+ """Compute loss in shared decode head mode."""
+ if self.multi_tasks_reweight is not None:
+ for dataset_name in sorted(set(ds_names)):
+ mask = [n == dataset_name for n in ds_names]
+ if not any(mask):
+ continue
+
+ # Filter data_samples for this modality
+ modal_ds = [ds for ds, m in zip(data_samples, mask) if m]
+ # Filter features
+ modal_x = tuple(
+ feat[torch.tensor(mask)] for feat in x)
+
+ loss_decode = self._shared_decode_head.loss(
+ modal_x, modal_ds, self.train_cfg)
+ losses.update(add_prefix(
+ loss_decode, f'decode_{dataset_name}'))
+
+ modal_loss = self._sum_loss_dict(loss_decode)
+
+ if self.with_auxiliary_head:
+ loss_aux = self._shared_auxiliary_head.loss(
+ modal_x, modal_ds, self.train_cfg)
+ losses.update(add_prefix(
+ loss_aux, f'aux_{dataset_name}'))
+ modal_loss = modal_loss + self._sum_loss_dict(loss_aux)
+
+ modal_losses[dataset_name] = modal_loss
+ else:
+ loss_decode = self._shared_decode_head.loss(
+ x, data_samples, self.train_cfg)
+ losses.update(add_prefix(loss_decode, 'decode'))
+
+ if self.with_auxiliary_head:
+ loss_aux = self._shared_auxiliary_head.loss(
+ x, data_samples, self.train_cfg)
+ losses.update(add_prefix(loss_aux, 'aux'))
+
+ def _loss_separate_mode(self, x, data_samples, ds_names,
+ losses, modal_losses):
+ """Compute loss in separate decode head mode."""
+ for dataset_name in sorted(set(ds_names)):
+ if dataset_name not in self.decode_heads:
+ continue
+
+ mask = [n == dataset_name for n in ds_names]
+ if not any(mask):
+ continue
+
+ modal_ds = [ds for ds, m in zip(data_samples, mask) if m]
+ modal_x = tuple(
+ feat[torch.tensor(mask)] for feat in x)
+
+ decode_head = self.decode_heads[dataset_name]
+ loss_decode = decode_head.loss(
+ modal_x, modal_ds, self.train_cfg)
+ losses.update(add_prefix(
+ loss_decode, f'decode_{dataset_name}'))
+
+ modal_loss = self._sum_loss_dict(loss_decode)
+
+ if self.with_auxiliary_head:
+ aux_head = self.auxiliary_heads[dataset_name]
+ loss_aux = aux_head.loss(
+ modal_x, modal_ds, self.train_cfg)
+ losses.update(add_prefix(
+ loss_aux, f'aux_{dataset_name}'))
+ modal_loss = modal_loss + self._sum_loss_dict(loss_aux)
+
+ modal_losses[dataset_name] = modal_loss
+
    def extract_feat(self, inputs, modal_types=None):
        """Extract features, supporting multi-modal input.

        Args:
            inputs: List[Tensor] or Tensor
            modal_types: List[str] or None

        Returns:
            x, moe_balance_loss, expert_features
        """
        # Check if backbone supports modal_types
        # (so plain single-modal backbones keep working unchanged).
        if modal_types is not None and hasattr(self.backbone, 'forward'):
            import inspect
            sig = inspect.signature(self.backbone.forward)
            if 'modal_types' in sig.parameters:
                backbone_output = self.backbone(
                    inputs, modal_types=modal_types)
            else:
                backbone_output = self.backbone(inputs)
        else:
            backbone_output = self.backbone(inputs)

        moe_balance_loss = None
        expert_features = None

        # Disambiguate the backbone's return value by structure:
        # 3-tuple -> (features, balance_loss, expert_features);
        # 2-tuple with a 0-dim tensor second -> (features, balance_loss);
        # anything else is treated as plain features.
        if isinstance(backbone_output, tuple):
            if len(backbone_output) == 3:
                x, moe_balance_loss, expert_features = backbone_output
            elif len(backbone_output) == 2:
                if (isinstance(backbone_output[1], torch.Tensor)
                        and backbone_output[1].dim() == 0):
                    x, moe_balance_loss = backbone_output
                else:
                    # NOTE(review): a 2-tuple whose second element is not a
                    # scalar is assumed to be a plain 2-level feature tuple
                    # — confirm this holds for every backbone used here.
                    x = backbone_output
            else:
                x = backbone_output
        else:
            x = backbone_output
        # Downstream code (neck/heads) always expects a tuple of features.
        if not isinstance(x, tuple):
            x = (x,)

        if self.with_neck:
            x = self.neck(x)

        return x, moe_balance_loss, expert_features
+
    def encode_decode(self, inputs, batch_img_metas: List[dict]):
        """Encode images and decode into segmentation map.

        Selects the per-dataset decode head in 'separate' mode, falling
        back to the representative head when the dataset name is unknown.
        """
        modal_types = [
            meta.get('modal_type', 'rgb') for meta in batch_img_metas
        ]

        # Backbone expects a list of per-sample tensors.
        if isinstance(inputs, torch.Tensor):
            imgs_list = list(inputs)
        else:
            imgs_list = inputs

        # MoE auxiliary outputs are irrelevant at inference time.
        x, _, _ = self.extract_feat(imgs_list, modal_types=modal_types)

        # Select appropriate decode head
        if self.decoder_mode == 'shared':
            decode_head = self._shared_decode_head
        else:
            # _get_dataset_name_from_metas is defined elsewhere in this
            # class; presumably reduces the batch to one dataset name.
            dataset_name = self._get_dataset_name_from_metas(
                batch_img_metas)
            if dataset_name in self.decode_heads:
                decode_head = self.decode_heads[dataset_name]
            else:
                decode_head = self.decode_head

        seg_logits = decode_head.predict(
            x, batch_img_metas, self.test_cfg)

        return seg_logits
+
    def slide_inference(self, inputs, batch_img_metas: List[dict]):
        """Inference by sliding-window with overlap.

        Overrides base class to handle list inputs from multimodal pipeline.
        When inputs have different channel sizes (e.g. mixed modalities),
        each sample is processed individually to avoid stack errors.
        """
        # Handle list inputs - check if channels are uniform
        if isinstance(inputs, (list, tuple)):
            channel_sizes = [t.shape[0] for t in inputs]
            if len(set(channel_sizes)) > 1:
                # Mixed channels: process each sample individually
                return self._slide_inference_per_sample(
                    inputs, batch_img_metas)
            else:
                # Uniform channels: safe to batch as (B, C, H, W).
                inputs = torch.stack(inputs, dim=0)

        h_stride, w_stride = self.test_cfg.stride
        h_crop, w_crop = self.test_cfg.crop_size
        batch_size, _, h_img, w_img = inputs.size()
        out_channels = self.out_channels
        # Number of window positions along each axis (ceil division).
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                # Shift the window back so it is always full crop-size.
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = inputs[:, :, y1:y2, x1:x2]
                # In-place mutation of the first meta only; presumably
                # mirrors upstream mmseg slide_inference — confirm.
                batch_img_metas[0]['img_shape'] = crop_img.shape[2:]
                crop_seg_logit = self.encode_decode(crop_img, batch_img_metas)
                # Scatter the crop's logits back into the full canvas.
                preds += F.pad(crop_seg_logit,
                               (int(x1), int(preds.shape[3] - x2), int(y1),
                                int(preds.shape[2] - y2)))
                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        # Average overlapping window predictions.
        seg_logits = preds / count_mat
        return seg_logits
+
def _slide_inference_per_sample(self, inputs, batch_img_metas):
    """Slide inference processing each sample individually.

    Used when the batch mixes modalities whose channel counts differ,
    so the samples cannot be stacked into one 4-D tensor.

    Args:
        inputs: sequence of per-sample tensors shaped (C, H, W); C may
            vary between samples.
        batch_img_metas: one meta dict per sample. Each dict's
            ``img_shape`` entry is overwritten in place for every crop
            window, mirroring ``slide_inference``.

    Returns:
        Tensor of shape (N, out_channels, H, W); ``torch.cat`` requires
        every sample to share the same spatial size.
    """
    h_stride, w_stride = self.test_cfg.stride
    h_crop, w_crop = self.test_cfg.crop_size
    out_channels = self.out_channels

    seg_logits_list = []
    # The enumerate index in the original loop was unused; iterate the
    # (image, meta) pairs directly.
    for img, img_meta in zip(inputs, batch_img_metas):
        # img shape: (C, H, W) -> (1, C, H, W)
        img = img.unsqueeze(0)
        _, _, h_img, w_img = img.size()
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = img.new_zeros((1, out_channels, h_img, w_img))
        count_mat = img.new_zeros((1, 1, h_img, w_img))
        meta_list = [img_meta]
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                # Shift the window back so it always covers a full crop.
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = img[:, :, y1:y2, x1:x2]
                meta_list[0]['img_shape'] = crop_img.shape[2:]
                crop_seg_logit = self.encode_decode(crop_img, meta_list)
                preds += F.pad(crop_seg_logit,
                               (int(x1), int(preds.shape[3] - x2),
                                int(y1), int(preds.shape[2] - y2)))
                count_mat[:, :, y1:y2, x1:x2] += 1
        # Every pixel must be covered by at least one window.
        assert (count_mat == 0).sum() == 0
        seg_logits_list.append(preds / count_mat)

    return torch.cat(seg_logits_list, dim=0)
+
def predict(self, inputs, data_samples: OptSampleList = None):
    """Predict results from inputs and data samples.

    When ``data_samples`` is None, fabricates one meta dict per sample.
    Two fixes over the original fallback:

    * ``[dict(...)] * N`` repeated one shared dict object; downstream
      slide inference mutates ``img_shape`` in place, so each sample
      now gets its own independent dict.
    * list inputs from the multimodal pipeline have no ``.shape``;
      per-sample shapes are read from each (C, H, W) tensor instead.
    """
    if data_samples is not None:
        batch_img_metas = [
            data_sample.metainfo for data_sample in data_samples
        ]
    else:
        if isinstance(inputs, torch.Tensor):
            # (N, C, H, W): every sample shares the batch spatial size.
            shapes = [inputs.shape[2:]] * inputs.shape[0]
        else:
            # List of (C, H, W) tensors: read each sample's own size.
            shapes = [img.shape[1:] for img in inputs]
        batch_img_metas = [
            dict(
                ori_shape=shape,
                img_shape=shape,
                pad_shape=shape,
                padding_size=[0, 0, 0, 0]) for shape in shapes
        ]

    seg_logits = self.inference(inputs, batch_img_metas)

    return self.postprocess_result(seg_logits, data_samples)
+
def _forward(self, inputs, data_samples: OptSampleList = None):
    """Network forward process (tensor mode)."""
    if data_samples is None:
        modal_types = None
    else:
        # Default to 'rgb' when a sample carries no modal_type meta.
        modal_types = [
            sample.metainfo.get('modal_type', 'rgb')
            for sample in data_samples
        ]

    # extract_feat expects a list of per-sample tensors.
    imgs_list = list(inputs) if isinstance(inputs, torch.Tensor) else inputs

    x, _, _ = self.extract_feat(imgs_list, modal_types=modal_types)
    return self.decode_head.forward(x)
+
def _get_dataset_name_from_metas(self, batch_img_metas):
    """Resolve the dataset name from the first meta dict.

    Prefers ``dataset_name``, then ``modal_type``; falls back to the
    first configured dataset name when neither key is present.
    """
    if batch_img_metas:
        first_meta = batch_img_metas[0]
        for key in ('dataset_name', 'modal_type'):
            if key in first_meta:
                return first_meta[key]
    return self.dataset_names[0]
+
+ @staticmethod
+ def _sum_loss_dict(loss_dict):
+ total_loss = None
+ for loss_name, loss_value in loss_dict.items():
+ if 'loss' not in loss_name:
+ continue
+ if isinstance(loss_value, torch.Tensor):
+ current = loss_value
+ elif isinstance(loss_value, list):
+ current = sum(loss for loss in loss_value)
+ else:
+ raise TypeError(
+ f'{loss_name} is not a tensor or list of tensors')
+ total_loss = (current if total_loss is None
+ else total_loss + current)
+ if total_loss is None:
+ total_loss = torch.tensor(0.0)
+ return total_loss
+
def _apply_equal_reweight(self, modal_losses):
    """Average per-modality losses with uniform 1/N weights.

    Returns a dict with the averaged total under
    ``equal_reweighted_total_loss`` plus one ``equal_weight_<name>``
    tensor (= 1/N) per modality; empty input yields an empty dict.
    """
    if not modal_losses:
        return {}
    num_modals = len(modal_losses)
    averaged = sum(modal_losses.values()) / num_modals
    reweighted = {'equal_reweighted_total_loss': averaged}
    uniform_weight = 1.0 / num_modals
    for modal_name in modal_losses:
        reweighted[f'equal_weight_{modal_name}'] = torch.tensor(
            uniform_weight, device=averaged.device)
    return reweighted
+
def _apply_uncertainty_reweight(self, modal_losses):
    """Reweight per-modality losses by learned uncertainty parameters.

    For each configured dataset with a reported loss, computes
    ``0.5 / sigma^2 * loss + log1p(sigma^2)`` using the learnable
    ``self.mtl_sigma`` entries, and exposes the detached sigma^2 values
    for logging. Returns {} when no configured modality reported a loss.
    """
    total = None
    reweighted = {}

    for idx, dataset_name in enumerate(self.dataset_names):
        if dataset_name not in modal_losses:
            continue
        sigma_sq = self.mtl_sigma[idx] ** 2 + self.mtl_sigma_eps
        term = (0.5 / sigma_sq * modal_losses[dataset_name]
                + torch.log1p(sigma_sq))
        total = term if total is None else total + term
        # Detached so the logged value does not participate in autograd.
        reweighted[f'mtl_sigma_{dataset_name}'] = sigma_sq.detach()

    if total is None:
        return {}

    reweighted['reweighted_total_losses'] = total
    return reweighted
diff --git a/scripts/collect_results_to_excel.py b/scripts/collect_results_to_excel.py
new file mode 100644
index 0000000000..aae23d666f
--- /dev/null
+++ b/scripts/collect_results_to_excel.py
@@ -0,0 +1,536 @@
+"""
+Collect experiment results from test logs and export to Excel.
+
+Parses metrics_log.txt / test_log.txt from all Table 2/3/4 experiments
+and the Full Model, then generates one Excel file per table with
+per-modality metrics (aAcc, mIoU, mAcc, mDice, mFscore, mPrecision, mRecall)
+plus per-class breakdowns.
+
+Usage:
+ python scripts/collect_results_to_excel.py
+ python scripts/collect_results_to_excel.py --work-root work_dirs/paper_experiments
+ python scripts/collect_results_to_excel.py --tables table2 table3 table4
+
+Output (one file per table; .csv instead of .xlsx when openpyxl is missing):
+ work_dirs/paper_experiments/Table2_ComponentAblation.xlsx
+ work_dirs/paper_experiments/Table3_MoE_Hyperparameters.xlsx
+ work_dirs/paper_experiments/Table4_SingleModal_vs_MultiModal.xlsx
+"""
+
+import argparse
+import os
+import re
+import glob
+from collections import OrderedDict
+
+try:
+ import openpyxl
+ from openpyxl.styles import Font, Alignment, PatternFill, Border, Side
+ from openpyxl.utils import get_column_letter
+ HAS_OPENPYXL = True
+except ImportError:
+ HAS_OPENPYXL = False
+
+try:
+ import pandas as pd
+ HAS_PANDAS = True
+except ImportError:
+ HAS_PANDAS = False
+
+
+# ============================================================================
+# Log Parsing
+# ============================================================================
+
def parse_summary_metrics(log_text):
    """Extract summary metric values from an mmengine log dump.

    Scans for ``<name>: <number>`` pairs such as ``mIoU: 65.32``.
    When a metric appears several times (multiple test runs in one
    log), the last occurrence wins.
    """
    pattern = (r'(aAcc|mIoU|mAcc|mDice|mFscore|mPrecision|mRecall)'
               r'\s*:\s*([\d.]+)')
    # Dict comprehension keeps only the final value for repeated keys,
    # matching the original loop's overwrite semantics.
    return {name: float(value)
            for name, value in re.findall(pattern, log_text)}
+
+
def parse_per_class_table(log_text):
    """Parse the per-class PrettyTable that follows 'per class results:'.

    Returns ``{class_name: {metric: value}}``; numeric cells become
    floats, anything unparsable is kept verbatim. The last such table in
    the log is used; {} is returned when the marker is absent.
    """
    marker_pos = log_text.rfind('per class results:')
    if marker_pos == -1:
        return {}

    parsed = {}
    header = None
    for raw_line in log_text[marker_pos:].split('\n'):
        stripped = raw_line.strip()
        # Skip blanks and the +----+ border rows of the PrettyTable.
        if not stripped or stripped.startswith('+'):
            continue
        if not stripped.startswith('|'):
            continue
        cells = [cell.strip() for cell in stripped.split('|')
                 if cell.strip()]
        if header is None:
            # First pipe-delimited row is the column header.
            header = cells
            continue
        if len(cells) != len(header):
            continue
        row_metrics = {}
        for column_name, cell in zip(header[1:], cells[1:]):
            try:
                row_metrics[column_name] = float(cell)
            except ValueError:
                row_metrics[column_name] = cell
        parsed[cells[0]] = row_metrics
    return parsed
+
+
def parse_log_file(log_path):
    """Parse a single log file and return its metrics.

    Args:
        log_path: path to a metrics/test log file.

    Returns:
        Tuple ``(summary, per_class)`` of dicts, or ``(None, None)``
        when the file does not exist.
    """
    if not os.path.exists(log_path):
        return None, None

    # Logs may contain non-ASCII bytes (progress bars, localized text);
    # decode as UTF-8 and replace undecodable bytes instead of crashing
    # with UnicodeDecodeError under a non-UTF-8 default locale.
    with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
        text = f.read()

    summary = parse_summary_metrics(text)
    per_class = parse_per_class_table(text)

    return summary, per_class
+
+
def find_log_file(work_dir, modal, prefer_metrics=True):
    """Find the log file for a given experiment and modality.

    Search order: ``metrics_<modal>/metrics_log.txt`` (first when
    *prefer_metrics* is true, otherwise after the test log), then
    ``test_<modal>/test_log.txt``, then the lexicographically last
    ``*.log`` file inside either sub-directory. Returns None when
    nothing is found.
    """
    metrics_log = os.path.join(work_dir, f'metrics_{modal}',
                               'metrics_log.txt')
    test_log = os.path.join(work_dir, f'test_{modal}', 'test_log.txt')
    # The original candidate list appended the metrics log a second time
    # as a "fallback" — an exact duplicate; order preserved, dup dropped.
    if prefer_metrics:
        candidates = [metrics_log, test_log]
    else:
        candidates = [test_log, metrics_log]

    for path in candidates:
        if os.path.exists(path):
            return path

    # Last resort: search for any mmengine .log file in the directory.
    for subdir_name in (f'metrics_{modal}', f'test_{modal}'):
        subdir = os.path.join(work_dir, subdir_name)
        if os.path.isdir(subdir):
            log_files = glob.glob(os.path.join(subdir, '*.log'))
            if log_files:
                return sorted(log_files)[-1]

    return None
+
+
+# ============================================================================
+# Experiment Definitions
+# ============================================================================
+
# Table 2 (component ablation): display name -> work_dir to scan,
# modalities to report, and a short description of the variant.
TABLE2_EXPERIMENTS = OrderedDict([
    ('Full Model', {
        'work_dir': 'work_dirs/floodnet/SwinmoeB/655',
        'modals': ['sar', 'rgb', 'GF'],
        'desc': 'Swin-B + MoE (E=8, K=3)',
    }),
    ('w/o MoE', {
        'work_dir': 'work_dirs/paper_experiments/table2/no_moe',
        'modals': ['sar', 'rgb', 'GF'],
        'desc': 'Remove MoE, use standard FFN',
    }),
    ('w/o ModalSpecificStem', {
        'work_dir': 'work_dirs/paper_experiments/table2/no_modal_specific_stem',
        'modals': ['sar', 'rgb', 'GF'],
        'desc': 'Shared patch embedding',
    }),
    ('w/o Modal Bias', {
        'work_dir': 'work_dirs/paper_experiments/table2/no_modal_bias',
        'modals': ['sar', 'rgb', 'GF'],
        'desc': 'No modal bias in gating',
    }),
    ('w/o Shared Experts', {
        'work_dir': 'work_dirs/paper_experiments/table2/no_shared_experts',
        'modals': ['sar', 'rgb', 'GF'],
        'desc': 'No shared experts',
    }),
    ('w/o Separate Decoder', {
        'work_dir': 'work_dirs/paper_experiments/table2/shared_decoder',
        'modals': ['sar', 'rgb', 'GF'],
        'desc': 'Shared decoder head',
    }),
])

# Table 3 (MoE hyper-parameter grid): num_experts E x top-k K.
# E=8, K=3 reuses the Full Model run's work_dir.
TABLE3_EXPERIMENTS = OrderedDict([
    ('E=6, K=1', {
        'work_dir': 'work_dirs/paper_experiments/table3/e6_k1',
        'modals': ['sar', 'rgb', 'GF'],
    }),
    ('E=6, K=2', {
        'work_dir': 'work_dirs/paper_experiments/table3/e6_k2',
        'modals': ['sar', 'rgb', 'GF'],
    }),
    ('E=6, K=3', {
        'work_dir': 'work_dirs/paper_experiments/table3/e6_k3',
        'modals': ['sar', 'rgb', 'GF'],
    }),
    ('E=8, K=1', {
        'work_dir': 'work_dirs/paper_experiments/table3/e8_k1',
        'modals': ['sar', 'rgb', 'GF'],
    }),
    ('E=8, K=2', {
        'work_dir': 'work_dirs/paper_experiments/table3/e8_k2',
        'modals': ['sar', 'rgb', 'GF'],
    }),
    ('E=8, K=3 (Full)', {
        'work_dir': 'work_dirs/floodnet/SwinmoeB/655',
        'modals': ['sar', 'rgb', 'GF'],
    }),
])

# Table 4 (single-modal vs multi-modal training); single-modal runs are
# only evaluated on their own modality.
TABLE4_EXPERIMENTS = OrderedDict([
    ('SAR-only', {
        'work_dir': 'work_dirs/paper_experiments/table4/sar_only',
        'modals': ['sar'],
    }),
    ('RGB-only', {
        'work_dir': 'work_dirs/paper_experiments/table4/rgb_only',
        'modals': ['rgb'],
    }),
    ('GF-only', {
        'work_dir': 'work_dirs/paper_experiments/table4/gf_only',
        'modals': ['GF'],
    }),
    ('Multi-Modal (Full)', {
        'work_dir': 'work_dirs/floodnet/SwinmoeB/655',
        'modals': ['sar', 'rgb', 'GF'],
    }),
])

# Summary metrics expected in the logs, in display order.
SUMMARY_METRICS = ['aAcc', 'mIoU', 'mAcc', 'mDice', 'mFscore', 'mPrecision', 'mRecall']
# Internal modality key -> display name used in the output tables.
MODAL_DISPLAY = {'sar': 'SAR', 'rgb': 'RGB', 'GF': 'GaoFen'}
+
+
+# ============================================================================
+# Excel Generation
+# ============================================================================
+
def collect_table_data(experiments, work_root=''):
    """Collect parsed metrics for every experiment/modality of a table.

    Args:
        experiments: OrderedDict mapping experiment display name to a
            dict with 'work_dir' and 'modals' entries.
        work_root: optional root directory; when non-empty, relative
            ``work_dir`` values are resolved against it. The default ''
            preserves the previous behavior (paths used as-is). The
            original implementation accepted this parameter but its
            handling block was dead code that never changed anything.

    Returns:
        OrderedDict mapping experiment name to
        ``{modal: {'summary': dict, 'per_class': dict, 'log_path': str|None}}``.
    """
    results = OrderedDict()

    for exp_name, exp_info in experiments.items():
        work_dir = exp_info['work_dir']
        if work_root and not os.path.isabs(work_dir):
            work_dir = os.path.join(work_root, work_dir)

        results[exp_name] = {}
        for modal in exp_info['modals']:
            log_path = find_log_file(work_dir, modal)
            if log_path:
                summary, per_class = parse_log_file(log_path)
                results[exp_name][modal] = {
                    'summary': summary or {},
                    'per_class': per_class or {},
                    'log_path': log_path,
                }
                n_metrics = len(summary or {})
                print(f" [OK] {exp_name} / {MODAL_DISPLAY.get(modal, modal)}: "
                      f"{log_path} ({n_metrics} metrics)")
            else:
                results[exp_name][modal] = {
                    'summary': {},
                    'per_class': {},
                    'log_path': None,
                }
                print(f" [--] {exp_name} / {MODAL_DISPLAY.get(modal, modal)}: "
                      f"no log found in {work_dir}")

    return results
+
+
def write_excel(results, experiments, output_path, table_name):
    """Write results to Excel with openpyxl for better formatting.

    Sheet 'Summary': one row per experiment; metric columns grouped under
    a merged per-modality header; the best value of each
    (modality, metric) column is bolded in red when more than one
    experiment reported it.
    Sheet 'Per-Class Details': one section per experiment/modality with
    the per-class metric table.

    Note: ``table_name`` is accepted for signature parity with the CSV
    fallback writer but is not used here.
    """
    wb = openpyxl.Workbook()

    # ---- Style definitions ----
    header_font = Font(bold=True, size=11)
    header_fill = PatternFill(start_color='4472C4', end_color='4472C4', fill_type='solid')
    header_font_white = Font(bold=True, size=11, color='FFFFFF')
    # Background color per modality, used for headers and section titles.
    modal_fill = {
        'sar': PatternFill(start_color='E2EFDA', end_color='E2EFDA', fill_type='solid'),
        'rgb': PatternFill(start_color='D6E4F0', end_color='D6E4F0', fill_type='solid'),
        'GF': PatternFill(start_color='FCE4D6', end_color='FCE4D6', fill_type='solid'),
    }
    exp_fill = PatternFill(start_color='F2F2F2', end_color='F2F2F2', fill_type='solid')
    best_font = Font(bold=True, color='C00000')
    thin_border = Border(
        left=Side(style='thin'), right=Side(style='thin'),
        top=Side(style='thin'), bottom=Side(style='thin'))
    center_align = Alignment(horizontal='center', vertical='center')

    # ==================== Sheet 1: Summary Table ====================
    ws = wb.active
    ws.title = 'Summary'

    # Determine all modals used (first-seen order preserved)
    all_modals = []
    for exp_info in experiments.values():
        for m in exp_info['modals']:
            if m not in all_modals:
                all_modals.append(m)

    # Header row 1: Experiment | SAR (spanning) | RGB (spanning) | GF (spanning)
    row = 1
    ws.cell(row=row, column=1, value='Experiment').font = header_font_white
    ws.cell(row=row, column=1).fill = header_fill
    ws.cell(row=row, column=1).border = thin_border

    col = 2
    for modal in all_modals:
        modal_name = MODAL_DISPLAY.get(modal, modal)
        start_col = col
        # Row 2 carries the individual metric names for this modality.
        for metric in SUMMARY_METRICS:
            ws.cell(row=row + 1, column=col, value=metric).font = header_font
            ws.cell(row=row + 1, column=col).fill = modal_fill.get(modal, exp_fill)
            ws.cell(row=row + 1, column=col).border = thin_border
            ws.cell(row=row + 1, column=col).alignment = center_align
            col += 1
        end_col = col - 1
        # Merge header for modal name
        ws.merge_cells(start_row=row, start_column=start_col,
                       end_row=row, end_column=end_col)
        cell = ws.cell(row=row, column=start_col, value=modal_name)
        cell.font = header_font_white
        cell.fill = header_fill
        cell.alignment = center_align
        cell.border = thin_border

    # 'Experiment' cell spans both header rows.
    ws.merge_cells(start_row=row, start_column=1, end_row=row + 1, end_column=1)

    # Data rows
    data_row = row + 2
    # Track best values per (modal, metric) for highlighting
    metric_values = {(m, met): [] for m in all_modals for met in SUMMARY_METRICS}

    # First pass: gather every reported value so the max can be bolded.
    for exp_name in results:
        for modal in all_modals:
            if modal in results[exp_name]:
                summary = results[exp_name][modal]['summary']
                for met in SUMMARY_METRICS:
                    val = summary.get(met)
                    if val is not None:
                        metric_values[(modal, met)].append((exp_name, val))

    # Second pass: write one row per experiment.
    for exp_name, exp_data in results.items():
        ws.cell(row=data_row, column=1, value=exp_name).font = Font(bold=True)
        ws.cell(row=data_row, column=1).border = thin_border

        col = 2
        for modal in all_modals:
            if modal in exp_data and exp_data[modal]['summary']:
                summary = exp_data[modal]['summary']
                for metric in SUMMARY_METRICS:
                    val = summary.get(metric)
                    cell = ws.cell(row=data_row, column=col)
                    if val is not None:
                        cell.value = val
                        cell.number_format = '0.00'
                        # Bold best value (only when more than one
                        # experiment reported this metric).
                        best_vals = metric_values.get((modal, metric), [])
                        if best_vals:
                            best_val = max(v for _, v in best_vals)
                            if val == best_val and len(best_vals) > 1:
                                cell.font = best_font
                    else:
                        cell.value = '-'
                    cell.border = thin_border
                    cell.alignment = center_align
                    col += 1
            else:
                # No data for this modality: fill the whole group with '-'.
                for _ in SUMMARY_METRICS:
                    cell = ws.cell(row=data_row, column=col, value='-')
                    cell.border = thin_border
                    cell.alignment = center_align
                    col += 1

        data_row += 1

    # Auto-width (column A widened afterwards for experiment names)
    for col_idx in range(1, col + 1):
        ws.column_dimensions[get_column_letter(col_idx)].width = 13
    ws.column_dimensions['A'].width = 25

    # ==================== Sheet 2: Per-Class Details ====================
    ws2 = wb.create_sheet('Per-Class Details')
    row = 1

    for exp_name, exp_data in results.items():
        for modal in all_modals:
            if modal not in exp_data or not exp_data[modal]['per_class']:
                continue

            per_class = exp_data[modal]['per_class']
            modal_name = MODAL_DISPLAY.get(modal, modal)

            # Section header
            ws2.cell(row=row, column=1,
                     value=f'{exp_name} - {modal_name}').font = Font(bold=True, size=12)
            ws2.cell(row=row, column=1).fill = modal_fill.get(modal, exp_fill)
            row += 1

            # Get metric columns from first class
            first_class = list(per_class.keys())[0]
            metric_cols = list(per_class[first_class].keys())

            # Header
            ws2.cell(row=row, column=1, value='Class').font = header_font
            ws2.cell(row=row, column=1).border = thin_border
            for j, met in enumerate(metric_cols, 2):
                ws2.cell(row=row, column=j, value=met).font = header_font
                ws2.cell(row=row, column=j).border = thin_border
                ws2.cell(row=row, column=j).alignment = center_align
            row += 1

            # Data
            for cls_name, cls_metrics in per_class.items():
                ws2.cell(row=row, column=1, value=cls_name).border = thin_border
                for j, met in enumerate(metric_cols, 2):
                    cell = ws2.cell(row=row, column=j)
                    cell.value = cls_metrics.get(met, '-')
                    cell.border = thin_border
                    cell.alignment = center_align
                    if isinstance(cell.value, float):
                        cell.number_format = '0.00'
                row += 1

            row += 1  # Blank row between sections

    # Auto-width sheet 2
    ws2.column_dimensions['A'].width = 20
    for col_idx in range(2, 15):
        ws2.column_dimensions[get_column_letter(col_idx)].width = 13

    # Save
    wb.save(output_path)
    print(f'\n => Saved: {output_path}')
+
+
def write_csv_fallback(results, experiments, output_path, table_name):
    """Fallback writer: emit a flat CSV when openpyxl is unavailable.

    One row per experiment; columns are <Modal>_<metric> pairs. Missing
    values are written as '-'. ``table_name`` is accepted for signature
    parity with ``write_excel`` but unused.
    """
    import csv

    # Collect the modalities in first-seen order across all experiments.
    all_modals = []
    for exp_info in experiments.values():
        for modal in exp_info['modals']:
            if modal not in all_modals:
                all_modals.append(modal)

    csv_path = output_path.replace('.xlsx', '.csv')

    with open(csv_path, 'w', newline='') as fh:
        writer = csv.writer(fh)

        # Header row: Experiment, then <Modal>_<metric> for each pair.
        header_row = ['Experiment']
        for modal in all_modals:
            label = MODAL_DISPLAY.get(modal, modal)
            header_row.extend(
                f'{label}_{metric}' for metric in SUMMARY_METRICS)
        writer.writerow(header_row)

        # One data row per experiment.
        for exp_name, exp_data in results.items():
            data_row = [exp_name]
            for modal in all_modals:
                summary = (exp_data.get(modal) or {}).get('summary')
                if summary:
                    for metric in SUMMARY_METRICS:
                        value = summary.get(metric)
                        data_row.append(
                            '-' if value is None else f'{value:.2f}')
                else:
                    data_row.extend('-' for _ in SUMMARY_METRICS)
            writer.writerow(data_row)

    print(f'\n => Saved (CSV fallback): {csv_path}')
+
+
+# ============================================================================
+# Main
+# ============================================================================
+
def main():
    """CLI entry point: parse args, collect each requested table, write output."""
    parser = argparse.ArgumentParser(
        description='Collect experiment results into Excel tables')
    parser.add_argument('--work-root', default='.',
                        help='Root directory (default: current dir)')
    parser.add_argument('--output-dir', default=None,
                        help='Output directory (default: work_dirs/paper_experiments/)')
    parser.add_argument('--tables', nargs='+',
                        default=['table2', 'table3', 'table4'],
                        help='Which tables to collect (default: all)')
    args = parser.parse_args()

    # All experiment work_dirs are relative to the work root.
    os.chdir(args.work_root)

    output_dir = args.output_dir or 'work_dirs/paper_experiments'
    os.makedirs(output_dir, exist_ok=True)

    if HAS_OPENPYXL:
        write_fn = write_excel
    else:
        write_fn = write_csv_fallback
        print('[WARN] openpyxl not installed. Will output CSV instead of Excel.')
        print(' Install with: pip install openpyxl')

    table_configs = {
        'table2': ('Table2_ComponentAblation', TABLE2_EXPERIMENTS),
        'table3': ('Table3_MoE_Hyperparameters', TABLE3_EXPERIMENTS),
        'table4': ('Table4_SingleModal_vs_MultiModal', TABLE4_EXPERIMENTS),
    }

    for table_key in args.tables:
        config = table_configs.get(table_key)
        if config is None:
            print(f'[WARN] Unknown table: {table_key}, skipping')
            continue

        table_name, experiments = config
        print(f'\n{"="*60}')
        print(f'Collecting {table_name}')
        print(f'{"="*60}')

        results = collect_table_data(experiments)

        suffix = '.xlsx' if HAS_OPENPYXL else '.csv'
        write_fn(results, experiments,
                 os.path.join(output_dir, f'{table_name}{suffix}'),
                 table_name)

    print(f'\nDone! Results saved to {output_dir}/')


if __name__ == '__main__':
    main()
diff --git a/scripts/run_all_experiments.sh b/scripts/run_all_experiments.sh
new file mode 100755
index 0000000000..70aba534f5
--- /dev/null
+++ b/scripts/run_all_experiments.sh
@@ -0,0 +1,321 @@
+#!/bin/bash
+# ============================================================================
+# FloodNet Paper Experiments - Complete Run Script
+# Multi-Modal Flood Segmentation with Swin-Base + MoE
+# ============================================================================
+#
+# NOTE: Full Model (Swin-B + MoE, E=8 K=3) is already trained and tested.
+# This script only runs ablation/hyperparameter/single-modal experiments.
+#
+# Usage:
+# bash scripts/run_all_experiments.sh table2 # Component Ablation
+# bash scripts/run_all_experiments.sh table3 # MoE Hyperparameter Study
+# bash scripts/run_all_experiments.sh table4 # Single-Modal vs Multi-Modal
+# bash scripts/run_all_experiments.sh all # Run all tables
+#
+# Environment Variables:
+# GPU_IDS GPU device IDs (default: 0)
+#
+# Example:
+# GPU_IDS=0 bash scripts/run_all_experiments.sh table2
+#
+# Each experiment uses --seed 42 for reproducibility.
+# After training, each experiment runs per-modality testing automatically.
+# ============================================================================
+
# Abort on the first failing command. NOTE(review): a failure inside a
# pipeline (e.g. `python | tee`) is only caught for the last command
# unless `set -o pipefail` is also enabled.
set -e

# Fixed seed for reproducibility across all experiments.
SEED=42
# GPU device id(s); override via environment, e.g. GPU_IDS=1.
GPU_IDS=${GPU_IDS:-0}
CONFIG_DIR="configs/floodnet"
ABLATION_DIR="${CONFIG_DIR}/ablations"
WORK_ROOT="work_dirs/paper_experiments"
RESULTS_LOG="${WORK_ROOT}/results_summary.txt"

# First positional argument selects the experiment group (default: all).
GROUP=${1:-"all"}

mkdir -p "${WORK_ROOT}"
+
+# ============================================================================
+# Helper Functions
+# ============================================================================
+
run_train() {
    # Train one experiment config with the fixed seed.
    # $1=config path, $2=work dir, $3=human-readable description
    local config=$1
    local work_dir=$2
    local desc=$3

    echo "============================================================"
    echo "[TRAIN] ${desc}"
    echo "[CFG] ${config}"
    echo "[DIR] ${work_dir}"
    echo "============================================================"

    # Quote path arguments so configs/work dirs containing spaces do not
    # undergo word splitting.
    CUDA_VISIBLE_DEVICES=${GPU_IDS} python tools/train.py \
        "${config}" \
        --work-dir "${work_dir}" \
        --cfg-options randomness.seed=${SEED}

    echo "[TRAIN DONE] ${desc}"
    echo ""
}
+
find_best_ckpt() {
    # Echo the preferred checkpoint of a work dir: a best_mIoU_* file if
    # one exists, otherwise the highest-numbered epoch_* checkpoint.
    local work_dir=$1
    # Quote the directory part of the pattern (the glob itself must stay
    # unquoted to expand); prevents word splitting on spaced paths.
    local best_ckpt=$(ls "${work_dir}"/best_mIoU_*.pth 2>/dev/null | head -1)
    if [ -z "$best_ckpt" ]; then
        best_ckpt=$(ls "${work_dir}"/epoch_*.pth 2>/dev/null | sort -V | tail -1)
    fi
    echo "$best_ckpt"
}
+
run_test_modal() {
    # Test one checkpoint on a single modality; full output is logged to
    # <work_dir>/test_<modal>/test_log.txt and a pointer line is appended
    # to the global RESULTS_LOG.
    local config=$1
    local checkpoint=$2
    local work_dir=$3
    local modal=$4
    local desc=$5

    local test_work_dir="${work_dir}/test_${modal}"
    mkdir -p "${test_work_dir}"

    echo "------------------------------------------------------------"
    echo "[TEST] ${desc} | Modality: ${modal}"
    echo "[CKPT] ${checkpoint}"
    echo "------------------------------------------------------------"

    # Quote path arguments against word splitting.
    # NOTE(review): the pipeline's exit status is tee's, so `set -e`
    # will not abort when test.py itself fails — confirm this is intended.
    CUDA_VISIBLE_DEVICES=${GPU_IDS} python tools/test.py \
        "${config}" \
        "${checkpoint}" \
        --work-dir "${test_work_dir}" \
        --cfg-options \
        test_dataloader.dataset.filter_modality="${modal}" \
        2>&1 | tee "${test_work_dir}/test_log.txt"

    echo "[${desc}] modal=${modal} -> see ${test_work_dir}/test_log.txt" >> "${RESULTS_LOG}"
    echo ""
}
+
run_test_all_modals() {
    # Run per-modality testing (sar, rgb, GF) for one checkpoint.
    local config=$1
    local checkpoint=$2
    local work_dir=$3
    local desc=$4

    echo "============================================================"
    echo "[TEST ALL MODALS] ${desc}"
    echo "============================================================"

    local modal
    for modal in sar rgb GF; do
        run_test_modal "${config}" "${checkpoint}" "${work_dir}" "${modal}" "${desc}"
    done
}
+
train_and_test_all_modals() {
    # Train a config, locate its best checkpoint, then test every modality.
    local config=$1 work_dir=$2 desc=$3

    run_train "${config}" "${work_dir}" "${desc}"

    local ckpt
    ckpt=$(find_best_ckpt "${work_dir}")
    if [ -z "$ckpt" ]; then
        echo "[ERROR] No checkpoint found in ${work_dir} after training"
        return 1
    fi

    run_test_all_modals "${config}" "${ckpt}" "${work_dir}" "${desc}"
}
+
train_and_test_single_modal() {
    # Train a single-modal config, then test only that modality.
    local config=$1 work_dir=$2 modal=$3 desc=$4

    run_train "${config}" "${work_dir}" "${desc}"

    local ckpt
    ckpt=$(find_best_ckpt "${work_dir}")
    if [ -z "$ckpt" ]; then
        echo "[ERROR] No checkpoint found in ${work_dir} after training"
        return 1
    fi

    run_test_modal "${config}" "${ckpt}" "${work_dir}" "${modal}" "${desc}"
}
+
# ============================================================================
# TABLE 2: Component Ablation Study
# ============================================================================
# Full Model (a) is already trained and tested — NOT included here.
# Only the ablation variants (b)-(f) are trained and tested.
#
# (b) w/o MoE — train + test SAR/RGB/GF
# (c) w/o ModalSpecificStem— train + test SAR/RGB/GF
# (d) w/o Modal Bias — train + test SAR/RGB/GF
# (e) w/o Shared Experts — train + test SAR/RGB/GF
# (f) w/o Separate Decoder — train + test SAR/RGB/GF
# ============================================================================
if [[ "$GROUP" == "all" || "$GROUP" == "table2" ]]; then
    echo ""
    echo "################################################################"
    echo "# TABLE 2: Component Ablation Study #"
    echo "# (Full Model result is already available) #"
    echo "################################################################"
    echo ""

    # (b) w/o MoE
    train_and_test_all_modals \
        "${ABLATION_DIR}/ablation_no_moe.py" \
        "${WORK_ROOT}/table2/no_moe" \
        "Table2(b) w/o MoE"

    # (c) w/o ModalSpecificStem
    train_and_test_all_modals \
        "${ABLATION_DIR}/ablation_no_modal_specific_stem.py" \
        "${WORK_ROOT}/table2/no_modal_specific_stem" \
        "Table2(c) w/o ModalSpecificStem"

    # (d) w/o Modal Bias
    train_and_test_all_modals \
        "${ABLATION_DIR}/ablation_no_modal_bias.py" \
        "${WORK_ROOT}/table2/no_modal_bias" \
        "Table2(d) w/o Modal Bias"

    # (e) w/o Shared Experts
    train_and_test_all_modals \
        "${ABLATION_DIR}/ablation_no_shared_experts.py" \
        "${WORK_ROOT}/table2/no_shared_experts" \
        "Table2(e) w/o Shared Experts"

    # (f) w/o Separate Decoder (use shared decoder)
    train_and_test_all_modals \
        "${ABLATION_DIR}/ablation_shared_decoder.py" \
        "${WORK_ROOT}/table2/shared_decoder" \
        "Table2(f) w/o Separate Decoder"

    echo ""
    echo "[TABLE 2 COMPLETE] Results in ${WORK_ROOT}/table2/"
    echo ""
fi
+
# ============================================================================
# TABLE 3: MoE Hyperparameter Study
# ============================================================================
# Grid: num_experts={6, 8} x top_k={1, 2, 3}
# (8, 3) = Full Model — already done, NOT included here.
#
# Each variant tested on SAR / RGB / GF separately.
# ============================================================================
if [[ "$GROUP" == "all" || "$GROUP" == "table3" ]]; then
    echo ""
    echo "################################################################"
    echo "# TABLE 3: MoE Hyperparameter Study #"
    echo "# (E=8 K=3 Full Model result is already available) #"
    echo "################################################################"
    echo ""

    # E=6, K=1
    train_and_test_all_modals \
        "${ABLATION_DIR}/ablation_e6_k1.py" \
        "${WORK_ROOT}/table3/e6_k1" \
        "Table3 E=6 K=1"

    # E=6, K=2
    train_and_test_all_modals \
        "${ABLATION_DIR}/ablation_e6_k2.py" \
        "${WORK_ROOT}/table3/e6_k2" \
        "Table3 E=6 K=2"

    # E=6, K=3
    train_and_test_all_modals \
        "${ABLATION_DIR}/ablation_e6_k3.py" \
        "${WORK_ROOT}/table3/e6_k3" \
        "Table3 E=6 K=3"

    # E=8, K=1
    train_and_test_all_modals \
        "${ABLATION_DIR}/ablation_e8_k1.py" \
        "${WORK_ROOT}/table3/e8_k1" \
        "Table3 E=8 K=1"

    # E=8, K=2
    train_and_test_all_modals \
        "${ABLATION_DIR}/ablation_e8_k2.py" \
        "${WORK_ROOT}/table3/e8_k2" \
        "Table3 E=8 K=2"

    # E=8, K=3 = Full Model — SKIP (already trained and tested)
    echo "[SKIP] E=8 K=3 = Full Model (already trained and tested)"

    echo ""
    echo "[TABLE 3 COMPLETE] Results in ${WORK_ROOT}/table3/"
    echo ""
fi
+
# ============================================================================
# TABLE 4: Single-Modal vs Multi-Modal Training
# ============================================================================
# SAR-only → train + test SAR
# RGB-only → train + test RGB
# GF-only → train + test GF
# Multi-modal (Full Model) → SKIP (already trained and tested)
# ============================================================================
if [[ "$GROUP" == "all" || "$GROUP" == "table4" ]]; then
    echo ""
    echo "################################################################"
    echo "# TABLE 4: Single-Modal vs Multi-Modal #"
    echo "# (Multi-modal Full Model result is already available) #"
    echo "################################################################"
    echo ""

    # SAR-only → test SAR
    # NOTE(review): the SAR config lives under CONFIG_DIR while RGB/GF use
    # ABLATION_DIR — confirm the asymmetry is intentional.
    train_and_test_single_modal \
        "${CONFIG_DIR}/multimodal_floodnet_sar_only_swinbase_moe_config.py" \
        "${WORK_ROOT}/table4/sar_only" \
        "sar" \
        "Table4 SAR-only"

    # RGB-only → test RGB
    train_and_test_single_modal \
        "${ABLATION_DIR}/ablation_rgb_only.py" \
        "${WORK_ROOT}/table4/rgb_only" \
        "rgb" \
        "Table4 RGB-only"

    # GF-only → test GF
    train_and_test_single_modal \
        "${ABLATION_DIR}/ablation_gf_only.py" \
        "${WORK_ROOT}/table4/gf_only" \
        "GF" \
        "Table4 GF-only"

    # Multi-modal = Full Model — SKIP
    echo "[SKIP] Multi-modal = Full Model (already trained and tested)"

    echo ""
    echo "[TABLE 4 COMPLETE] Results in ${WORK_ROOT}/table4/"
    echo ""
fi
+
# ============================================================================
# Summary — print where each group's results were written.
# ============================================================================
echo ""
echo "============================================================"
echo "EXPERIMENTS COMPLETED"
echo "Results log: ${RESULTS_LOG}"
echo "============================================================"
echo ""
echo "Result directories:"
if [[ "$GROUP" == "all" || "$GROUP" == "table2" ]]; then
    echo " Table 2 (Ablation): ${WORK_ROOT}/table2/"
fi
if [[ "$GROUP" == "all" || "$GROUP" == "table3" ]]; then
    echo " Table 3 (MoE Hyper): ${WORK_ROOT}/table3/"
fi
if [[ "$GROUP" == "all" || "$GROUP" == "table4" ]]; then
    echo " Table 4 (Single-Modal): ${WORK_ROOT}/table4/"
fi
echo ""
echo "Per-modality test logs: /test_{sar,rgb,GF}/test_log.txt"
new file mode 100755
index 0000000000..fb133466cd
--- /dev/null
+++ b/scripts/test_all_metrics.sh
@@ -0,0 +1,324 @@
+#!/bin/bash
+# ============================================================================
+# Re-test All Table 2/3/4 Experiments with Expanded Metrics
+# Metrics: mIoU, mDice, mFscore (Precision, Recall, F1), aAcc (OA)
+#
+# Usage:
+# bash scripts/test_all_metrics.sh table2
+# bash scripts/test_all_metrics.sh table3
+# bash scripts/test_all_metrics.sh table4
+# bash scripts/test_all_metrics.sh full # Full Model only
+# bash scripts/test_all_metrics.sh all # All tables + Full Model
+#
+# Environment Variables:
+# GPU_IDS GPU device IDs (default: 0)
+#
+# Example:
+# GPU_IDS=0 bash scripts/test_all_metrics.sh table2
+# ============================================================================
+
# Abort on the first failing command.
set -e

# Runtime knobs: GPU list via environment, experiment group via first arg.
: "${GPU_IDS:=0}"
GROUP=${1:-all}

# Fixed locations used by every experiment group.
CONFIG_DIR="configs/floodnet"
ABLATION_DIR="${CONFIG_DIR}/ablations"
WORK_ROOT="work_dirs/paper_experiments"
METRICS_LOG="${WORK_ROOT}/metrics_summary.txt"

mkdir -p "${WORK_ROOT}"
echo "========== Metrics Test Run: $(date) ==========" >> "${METRICS_LOG}"
+
+# ============================================================================
+# Helper Functions
+# ============================================================================
+
find_best_ckpt() {
    # Echo the preferred checkpoint inside "$1": the first best_mIoU_*.pth
    # if one exists, otherwise the highest-numbered epoch_*.pth, else "".
    local work_dir=$1
    local best_ckpt
    # FIX: quote ${work_dir} so paths containing spaces survive word
    # splitting (the glob itself must stay unquoted to expand), and split
    # declaration from assignment so `local` does not mask the command
    # substitution's exit status (ShellCheck SC2155).
    best_ckpt=$(ls "${work_dir}"/best_mIoU_*.pth 2>/dev/null | head -n 1)
    if [ -z "$best_ckpt" ]; then
        best_ckpt=$(ls "${work_dir}"/epoch_*.pth 2>/dev/null | sort -V | tail -n 1)
    fi
    echo "$best_ckpt"
}
+
run_test_metrics() {
    # Evaluate one checkpoint on a single modality with the expanded metric
    # set and tee the console output into <work_dir>/metrics_<modal>/.
    local config=$1 checkpoint=$2 work_dir=$3 modal=$4 desc=$5
    local out_dir="${work_dir}/metrics_${modal}"
    local rule="------------------------------------------------------------"

    mkdir -p "${out_dir}"

    echo "${rule}"
    echo "[TEST] ${desc} | Modality: ${modal}"
    echo "[CKPT] ${checkpoint}"
    echo "[OUT] ${out_dir}"
    echo "${rule}"

    CUDA_VISIBLE_DEVICES=${GPU_IDS} python tools/test.py \
        ${config} \
        ${checkpoint} \
        --work-dir ${out_dir} \
        --cfg-options \
        test_dataloader.dataset.filter_modality="${modal}" \
        "test_evaluator.iou_metrics=['mIoU','mDice','mFscore']" \
        2>&1 | tee "${out_dir}/metrics_log.txt"

    # Leave a breadcrumb in the global summary log.
    echo "[${desc}] modal=${modal} -> ${out_dir}/metrics_log.txt" >> "${METRICS_LOG}"
    echo ""
}
+
run_test_all_modals_metrics() {
    # Run the expanded-metric evaluation of one model on every modality.
    local config=$1 checkpoint=$2 work_dir=$3 desc=$4
    local modality

    echo "============================================================"
    echo "[METRICS TEST] ${desc}"
    echo "============================================================"

    for modality in sar rgb GF; do
        run_test_metrics "${config}" "${checkpoint}" "${work_dir}" "${modality}" "${desc}"
    done
}
+
+# ============================================================================
+# Full Model
+# ============================================================================
+
run_full_model() {
    # Re-test the already-trained Full Model (Swin-B + MoE) on all modalities.
    echo ""
    echo "############################################################"
    echo "# Full Model (Swin-B + MoE, E=8 K=3)"
    echo "############################################################"

    local config="${CONFIG_DIR}/multimodal_floodnet_sar_boost_swinbase_moe_config.py"
    local work_dir="work_dirs/floodnet/SwinmoeB/655"
    local ckpt="${work_dir}/best_mIoU_epoch_100.pth"

    # Fall back to whatever checkpoint find_best_ckpt locates when the
    # expected epoch-100 file is missing.
    if [ ! -f "$ckpt" ]; then
        echo "[WARN] Full Model checkpoint not found: ${ckpt}"
        echo "[WARN] Trying to find best checkpoint in ${work_dir}..."
        ckpt=$(find_best_ckpt "${work_dir}")
    fi

    if [[ -z "$ckpt" || ! -f "$ckpt" ]]; then
        echo "[ERROR] No checkpoint found for Full Model. Skipping."
        return 0
    fi

    run_test_all_modals_metrics "${config}" "${ckpt}" "${work_dir}" "Full Model"
}
+
+# ============================================================================
+# Table 2: Component Ablation
+# ============================================================================
+
run_table2() {
    # Table 2: re-test every component-ablation run on all three modalities.
    echo ""
    echo "############################################################"
    echo "# Table 2: Component Ablation Study"
    echo "############################################################"

    # FIX: dropped the unused local `base_config` the original declared.
    declare -A T2_CONFIGS
    T2_CONFIGS=(
        ["no_moe"]="${ABLATION_DIR}/ablation_no_moe.py"
        ["no_modal_specific_stem"]="${ABLATION_DIR}/ablation_no_modal_specific_stem.py"
        ["no_modal_bias"]="${ABLATION_DIR}/ablation_no_modal_bias.py"
        ["no_shared_experts"]="${ABLATION_DIR}/ablation_no_shared_experts.py"
        ["shared_decoder"]="${ABLATION_DIR}/ablation_shared_decoder.py"
    )

    declare -A T2_DESCS
    T2_DESCS=(
        ["no_moe"]="Table2: w/o MoE"
        ["no_modal_specific_stem"]="Table2: w/o ModalSpecificStem"
        ["no_modal_bias"]="Table2: w/o Modal Bias"
        ["no_shared_experts"]="Table2: w/o Shared Experts"
        ["shared_decoder"]="Table2: w/o Separate Decoder"
    )

    for key in no_moe no_modal_specific_stem no_modal_bias no_shared_experts shared_decoder; do
        local config="${T2_CONFIGS[$key]}"
        local work_dir="${WORK_ROOT}/table2/${key}"
        local desc="${T2_DESCS[$key]}"
        local ckpt

        # Declare and assign separately so `local` does not mask the
        # substitution's exit status (ShellCheck SC2155).
        ckpt=$(find_best_ckpt "${work_dir}")
        if [ -z "$ckpt" ]; then
            echo "[ERROR] No checkpoint found for ${desc} in ${work_dir}. Skipping."
            continue
        fi

        run_test_all_modals_metrics "${config}" "${ckpt}" "${work_dir}" "${desc}"
    done
}
+
+# ============================================================================
+# Table 3: MoE Hyperparameter Study
+# ============================================================================
+
run_table3() {
    # Table 3: re-test every MoE hyperparameter run (E experts, top-K).
    echo ""
    echo "############################################################"
    echo "# Table 3: MoE Hyperparameter Study"
    echo "############################################################"

    declare -A cfg_by_key=(
        ["e6_k1"]="${ABLATION_DIR}/ablation_e6_k1.py"
        ["e6_k2"]="${ABLATION_DIR}/ablation_e6_k2.py"
        ["e6_k3"]="${ABLATION_DIR}/ablation_e6_k3.py"
        ["e8_k1"]="${ABLATION_DIR}/ablation_e8_k1.py"
        ["e8_k2"]="${ABLATION_DIR}/ablation_e8_k2.py"
    )
    declare -A desc_by_key=(
        ["e6_k1"]="Table3: E=6 K=1"
        ["e6_k2"]="Table3: E=6 K=2"
        ["e6_k3"]="Table3: E=6 K=3"
        ["e8_k1"]="Table3: E=8 K=1"
        ["e8_k2"]="Table3: E=8 K=2"
    )

    local variant ckpt
    for variant in e6_k1 e6_k2 e6_k3 e8_k1 e8_k2; do
        ckpt=$(find_best_ckpt "${WORK_ROOT}/table3/${variant}")
        if [ -z "$ckpt" ]; then
            echo "[ERROR] No checkpoint found for ${desc_by_key[$variant]} in ${WORK_ROOT}/table3/${variant}. Skipping."
            continue
        fi

        run_test_all_modals_metrics "${cfg_by_key[$variant]}" "${ckpt}" \
            "${WORK_ROOT}/table3/${variant}" "${desc_by_key[$variant]}"
    done
}
+
+# ============================================================================
+# Table 4: Single-Modal vs Multi-Modal
+# ============================================================================
+
_table4_test_single() {
    # Test one single-modal Table 4 model on its own modality only.
    #   $1 = human label (e.g. "SAR-Only")
    #   $2 = modality key passed to the dataset (sar / rgb / GF)
    #   $3 = config file path
    #   $4 = training work dir holding the checkpoints
    local label=$1 modal=$2 config=$3 work_dir=$4
    local ckpt

    ckpt=$(find_best_ckpt "${work_dir}")
    if [ -z "$ckpt" ]; then
        echo "[ERROR] No checkpoint found for ${label} in ${work_dir}. Skipping."
        return 0
    fi

    local test_dir="${work_dir}/metrics_${modal}"
    mkdir -p "${test_dir}"
    echo "------------------------------------------------------------"
    echo "[TEST] Table4: ${label} | Modality: ${modal}"
    echo "[CKPT] ${ckpt}"
    echo "------------------------------------------------------------"
    # No filter_modality override here: a single-modal config already
    # restricts the test set to its own modality.
    CUDA_VISIBLE_DEVICES=${GPU_IDS} python tools/test.py \
        ${config} ${ckpt} \
        --work-dir ${test_dir} \
        --cfg-options \
        "test_evaluator.iou_metrics=['mIoU','mDice','mFscore']" \
        2>&1 | tee "${test_dir}/metrics_log.txt"
    echo "[Table4: ${label}] modal=${modal} -> ${test_dir}/metrics_log.txt" >> "${METRICS_LOG}"
}

run_table4() {
    # Table 4: single-modal baselines, each evaluated on its own modality.
    # REFACTOR: the original repeated the same 20-line stanza three times;
    # the per-modality work now lives in _table4_test_single above.
    echo ""
    echo "############################################################"
    echo "# Table 4: Single-Modal vs Multi-Modal"
    echo "############################################################"

    _table4_test_single "SAR-Only" "sar" \
        "${CONFIG_DIR}/multimodal_floodnet_sar_only_swinbase_moe_config.py" \
        "${WORK_ROOT}/table4/sar_only"
    _table4_test_single "RGB-Only" "rgb" \
        "${ABLATION_DIR}/ablation_rgb_only.py" \
        "${WORK_ROOT}/table4/rgb_only"
    _table4_test_single "GF-Only" "GF" \
        "${ABLATION_DIR}/ablation_gf_only.py" \
        "${WORK_ROOT}/table4/gf_only"
}
+
# ============================================================================
# Main
# ============================================================================

echo "============================================================"
echo " Expanded Metrics Testing"
echo " Metrics: mIoU, mDice, mFscore (Precision/Recall/F1), aAcc"
echo " Group: ${GROUP}"
echo " GPU: ${GPU_IDS}"
echo "============================================================"

# Dispatch to the requested experiment group.
case "${GROUP}" in
    full)   run_full_model ;;
    table2) run_table2 ;;
    table3) run_table3 ;;
    table4) run_table4 ;;
    all)
        run_full_model
        run_table2
        run_table3
        run_table4
        ;;
    *)
        echo "Unknown group: ${GROUP}"
        echo "Usage: bash scripts/test_all_metrics.sh {full|table2|table3|table4|all}"
        exit 1
        ;;
esac

echo ""
echo "============================================================"
echo " All metrics tests completed!"
echo " Results log: ${METRICS_LOG}"
echo "============================================================"
diff --git a/tools/analysis_tools/benchmark_multimodal.py b/tools/analysis_tools/benchmark_multimodal.py
new file mode 100644
index 0000000000..e607e5fcb7
--- /dev/null
+++ b/tools/analysis_tools/benchmark_multimodal.py
@@ -0,0 +1,368 @@
+"""
+Compute Inference FPS and FLOPs for Multi-Modal Swin-MoE model.
+
+Usage:
+ python tools/analysis_tools/benchmark_multimodal.py \
+ configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py \
+ work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth \
+ --shape 256 256 \
+ --repeat-times 3 \
+ --num-iters 200
+"""
+
+import argparse
+import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+from mmengine import Config
+from mmengine.model.utils import revert_sync_batchnorm
+from mmengine.registry import init_default_scope
+from mmengine.runner import load_checkpoint
+
+from mmseg.registry import MODELS
+from mmseg.structures import SegDataSample
+
+try:
+ from fvcore.nn import FlopCountAnalysis
+ HAS_FVCORE = True
+except ImportError:
+ HAS_FVCORE = False
+
+
+MODAL_CHANNELS = {
+ 'sar': 8,
+ 'rgb': 3,
+ 'GF': 5,
+}
+
+
def parse_args():
    """Parse command-line options for the FPS/FLOPs benchmark."""
    cli = argparse.ArgumentParser(
        description='Benchmark Multi-Modal Segmentor (FPS + FLOPs)')
    cli.add_argument('config', help='config file path')
    cli.add_argument('checkpoint', help='checkpoint file path')
    cli.add_argument('--shape', type=int, nargs=2, default=[256, 256],
                     help='input H W (default: 256 256)')
    cli.add_argument('--repeat-times', type=int, default=3,
                     help='number of FPS measurement runs')
    cli.add_argument('--num-iters', type=int, default=200,
                     help='iterations per FPS run')
    cli.add_argument('--num-warmup', type=int, default=10,
                     help='warmup iterations before timing')
    cli.add_argument('--modals', nargs='+', default=['sar', 'rgb', 'GF'],
                     help='modalities to benchmark')
    return cli.parse_args()
+
+
def build_model(cfg, checkpoint_path):
    """Build the segmentor from ``cfg``, load weights, return it in eval mode.

    The model is moved to GPU when one is available.
    """
    cfg.model.train_cfg = None
    model = MODELS.build(cfg.model)

    load_checkpoint(model, checkpoint_path, map_location='cpu')

    if torch.cuda.is_available():
        model = model.cuda()
    # FIX: convert SyncBN -> BN unconditionally. The original only converted
    # inside the CUDA branch, yet a single-process CPU benchmark is exactly
    # where SyncBatchNorm layers cannot run.
    model = revert_sync_batchnorm(model)
    model.eval()
    return model
+
+
def make_dummy_input(modal, shape, device='cuda'):
    """Create a random image plus metadata mimicking the data preprocessor."""
    height, width = shape
    num_ch = MODAL_CHANNELS[modal]
    img = torch.randn(num_ch, height, width, device=device)

    # Metadata fields the predict pipeline expects on every sample.
    meta = dict(
        img_shape=(height, width),
        ori_shape=(height, width),
        pad_shape=(height, width),
        scale_factor=(1.0, 1.0),
        flip=False,
        flip_direction=None,
        modal_type=modal,
        actual_channels=num_ch,
        dataset_name=modal,
        reduce_zero_label=False,
    )
    data_sample = SegDataSample()
    data_sample.set_metainfo(meta)

    return img, data_sample
+
+
def measure_fps(model, modal, shape, num_iters=200, num_warmup=10):
    """Time the full predict pipeline and return throughput in images/s."""
    device = next(model.parameters()).device
    img, sample = make_dummy_input(modal, shape, device)

    def _run(iterations):
        # Forward passes without autograd bookkeeping; sync afterwards so
        # queued CUDA kernels are included in the wall-clock window.
        with torch.no_grad():
            for _ in range(iterations):
                model([img], [sample], mode='predict')
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # Warmup excludes allocator / kernel start-up cost from the timing.
    _run(num_warmup)

    tic = time.perf_counter()
    _run(num_iters)
    elapsed = time.perf_counter() - tic

    return num_iters / elapsed
+
+
class BackboneDecoderWrapper(nn.Module):
    """Wrapper that runs backbone + decode_head forward only.

    Avoids predict/inference path so FLOPs tools (fvcore/thop)
    don't encounter SegDataSample or slide_inference logic.
    """

    def __init__(self, model, modal):
        """Store the backbone and the decode head matching ``modal``.

        Args:
            model: segmentor exposing ``backbone`` and, optionally, either a
                ``decode_heads`` mapping (per-modality heads) or a single
                ``_shared_decode_head``.
            modal: modality key ('sar' / 'rgb' / 'GF').
        """
        super().__init__()
        self.backbone = model.backbone
        self.modal = modal
        # Get the correct decode head; None when the model exposes neither.
        if hasattr(model, 'decode_heads') and modal in model.decode_heads:
            self.decode_head = model.decode_heads[modal]
        elif hasattr(model, '_shared_decode_head'):
            self.decode_head = model._shared_decode_head
        else:
            self.decode_head = None

    def forward(self, x):
        """x: (1, C, H, W) tensor for a single modality."""
        imgs_list = [x[0]]  # list of (C, H, W)
        modal_types = [self.modal]
        features, _, _ = self.backbone(imgs_list, modal_types)
        # FIX: the original returned the local `out`, which was only bound
        # inside the `if` — raising UnboundLocalError whenever no decode
        # head was found. Return the backbone features instead so
        # backbone-only FLOPs are still measurable.
        if self.decode_head is None:
            return features
        return self.decode_head(features)
+
+
def measure_flops_manual(model, modal, shape):
    """Estimate FLOPs of backbone + decoder via forward hooks.

    Registers a counting hook per supported layer type and drives one
    forward pass; returns None when the wrapped forward fails.
    """
    device = next(model.parameters()).device
    height, width = shape
    in_ch = MODAL_CHANNELS[modal]
    dummy = torch.randn(1, in_ch, height, width, device=device)

    wrapper = BackboneDecoderWrapper(model, modal)
    wrapper.eval()

    flop_count = 0

    def _conv_hook(mod, inp, out):
        nonlocal flop_count
        batch = inp[0].shape[0]
        out_h, out_w = out.shape[2], out.shape[3]
        per_pos = mod.kernel_size[0] * mod.kernel_size[1] * (mod.in_channels // mod.groups)
        flop_count += batch * mod.out_channels * out_h * out_w * per_pos

    def _linear_hook(mod, inp, out):
        nonlocal flop_count
        rows = inp[0].numel() // mod.in_features
        flop_count += rows * mod.in_features * mod.out_features

    def _ln_hook(mod, inp, out):
        nonlocal flop_count
        flop_count += inp[0].numel() * 2  # mean + variance

    def _gelu_hook(mod, inp, out):
        nonlocal flop_count
        flop_count += inp[0].numel() * 4  # approximate

    def _bn_hook(mod, inp, out):
        nonlocal flop_count
        flop_count += inp[0].numel() * 2

    # (module types, counting hook) pairs; first match wins, mirroring the
    # original isinstance chain.
    hook_table = (
        (nn.Conv2d, _conv_hook),
        (nn.Linear, _linear_hook),
        (nn.LayerNorm, _ln_hook),
        (nn.GELU, _gelu_hook),
        ((nn.BatchNorm2d, nn.SyncBatchNorm), _bn_hook),
    )

    handles = []
    for mod in wrapper.modules():
        for types, counter in hook_table:
            if isinstance(mod, types):
                handles.append(mod.register_forward_hook(counter))
                break

    try:
        with torch.no_grad():
            wrapper(dummy)
    except Exception as e:
        print(f' [manual WARN] forward failed: {e}')
        return None
    finally:
        # Always detach the hooks, whether the pass succeeded or not.
        for handle in handles:
            handle.remove()

    return flop_count
+
+
def measure_flops_fvcore(model, modal, shape):
    """Measure FLOPs with fvcore on the backbone + decoder wrapper.

    Returns None when fvcore fails to trace the model.
    """
    device = next(model.parameters()).device
    height, width = shape
    dummy = torch.randn(1, MODAL_CHANNELS[modal], height, width, device=device)

    wrapper = BackboneDecoderWrapper(model, modal)
    wrapper.eval()

    try:
        analysis = FlopCountAnalysis(wrapper, (dummy,))
        # Silence per-op warnings; unsupported ops simply contribute zero.
        analysis.unsupported_ops_warnings(False)
        analysis.uncalled_modules_warnings(False)
        return analysis.total()
    except Exception as e:
        print(f' [fvcore WARN] {e}')
        return None
+
+
def format_flops(flops):
    """Format a FLOP count as a human readable string ('N/A' for None)."""
    if flops is None:
        return 'N/A'
    # Largest matching scale wins, mirroring the original threshold chain.
    for scale, unit in ((1e12, 'TFLOPs'), (1e9, 'GFLOPs'), (1e6, 'MFLOPs')):
        if flops >= scale:
            return f'{flops / scale:.2f} {unit}'
    return f'{flops:.0f} FLOPs'
+
+
def format_params(params):
    """Format a parameter count with M/K suffixes."""
    for scale, suffix in ((1e6, ' M'), (1e3, ' K')):
        if params >= scale:
            return f'{params / scale:.2f}{suffix}'
    return f'{params:.0f}'
+
+
def main():
    """Benchmark entry point: report params, per-modality FLOPs and FPS."""
    args = parse_args()

    cfg = Config.fromfile(args.config)
    init_default_scope(cfg.get('default_scope', 'mmseg'))

    print('=' * 60)
    print('Multi-Modal Model Benchmark')
    print(f'Config: {args.config}')
    print(f'Checkpoint: {args.checkpoint}')
    print(f'Input size: {args.shape[0]} x {args.shape[1]}')
    print(f'Modalities: {args.modals}')
    print('=' * 60)
    print()

    # Build model
    print('Building model...')
    model = build_model(cfg, args.checkpoint)

    # Count parameters (total vs. requires_grad-only)
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Total params: {format_params(total_params)}')
    print(f'Trainable params: {format_params(trainable_params)}')
    print()

    # ======================== FLOPs ========================
    print('=' * 60)
    print('FLOPs Measurement (backbone + decode_head)')
    print('=' * 60)

    flops_results = {}
    for modal in args.modals:
        ch = MODAL_CHANNELS[modal]
        print(f'\n[{modal.upper()}] input: ({ch}, {args.shape[0]}, {args.shape[1]})')

        flops_val = None

        # Try fvcore first (wraps backbone+decoder only, no SegDataSample)
        if HAS_FVCORE:
            print(' Method: fvcore')
            flops_val = measure_flops_fvcore(model, modal, args.shape)
            if flops_val is not None:
                print(f' FLOPs: {format_flops(flops_val)}')

        # Fallback: manual hook-based counting
        if flops_val is None:
            print(' Method: manual (hook-based)')
            flops_val = measure_flops_manual(model, modal, args.shape)
            if flops_val is not None:
                print(f' FLOPs: {format_flops(flops_val)}')
            else:
                print(' FLOPs: N/A')

        flops_results[modal] = flops_val

    # ======================== FPS ========================
    print()
    print('=' * 60)
    print('FPS Measurement (full predict pipeline)')
    print(f' Warmup: {args.num_warmup} iters')
    print(f' Timed: {args.num_iters} iters x {args.repeat_times} runs')
    print('=' * 60)

    fps_results = {}
    for modal in args.modals:
        ch = MODAL_CHANNELS[modal]
        print(f'\n[{modal.upper()}] input: ({ch}, {args.shape[0]}, {args.shape[1]})')

        fps_list = []
        for run_idx in range(args.repeat_times):
            fps = measure_fps(model, modal, args.shape,
                              num_iters=args.num_iters,
                              num_warmup=args.num_warmup)
            fps_list.append(fps)
            print(f' Run {run_idx + 1}: {fps:.2f} img/s')

        # Mean +/- std over the repeated timing runs
        avg_fps = np.mean(fps_list)
        std_fps = np.std(fps_list)
        print(f' >> Average: {avg_fps:.2f} +/- {std_fps:.2f} img/s')
        fps_results[modal] = (avg_fps, std_fps)

    # ======================== Summary ========================
    print()
    print('=' * 60)
    print('Summary')
    print('=' * 60)
    print(f'{"Modality":<10} {"Channels":<10} {"FLOPs":<18} {"FPS (avg)":>15}')
    print('-' * 55)
    for modal in args.modals:
        ch = MODAL_CHANNELS[modal]
        flops_str = format_flops(flops_results.get(modal))
        # std is kept in fps_results but only the mean is tabulated here
        fps_avg, fps_std = fps_results.get(modal, (0, 0))
        print(f'{modal:<10} {ch:<10} {flops_str:<18} {fps_avg:>10.2f} img/s')
    print('-' * 55)
    print(f'Params: {format_params(total_params)}')
    print(f'Input: {args.shape[0]} x {args.shape[1]}')
    print()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/analysis_tools/get_flops.py b/tools/analysis_tools/get_flops.py
index 78a73988d4..694e55f79c 100644
--- a/tools/analysis_tools/get_flops.py
+++ b/tools/analysis_tools/get_flops.py
@@ -62,8 +62,11 @@ def inference(args: argparse.Namespace, logger: MMLogger) -> dict:
input_shape = (3, args.shape[0], args.shape[0])
elif len(args.shape) == 2:
input_shape = (3, ) + tuple(args.shape)
+    elif len(args.shape) == 3:  # added: also accept an explicit (C, H, W) shape
+ input_shape = tuple(args.shape)
else:
raise ValueError('invalid input shape')
+    print(f"DEBUG: Full input_shape = {input_shape}")  # TODO(review): leftover debug print — remove before merge
result = {}
model: BaseSegmentor = MODELS.build(cfg.model)
diff --git a/tools/analysis_tools/visualize_expert_routing.py b/tools/analysis_tools/visualize_expert_routing.py
new file mode 100644
index 0000000000..bc31e2dfa1
--- /dev/null
+++ b/tools/analysis_tools/visualize_expert_routing.py
@@ -0,0 +1,696 @@
+"""
+Figure 4: Expert Routing Analysis — Publication-quality visualization.
+
+Generates three sub-figures:
+ (a) Expert activation probability heatmap per modality per stage
+ (b) Learned Modal Bias matrix visualization
+ (c) Spatial expert assignment map for a single image
+
+Usage:
+ python tools/analysis_tools/visualize_expert_routing.py \
+ configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py \
+ work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth \
+ --data-root ../floodnet/data/mixed_dataset/ \
+ --output-dir work_dirs/figures/expert_routing \
+ --num-samples 50
+"""
+
+import argparse
+import os
+import os.path as osp
+from collections import defaultdict
+
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+import numpy as np
+import torch
+import torch.nn.functional as F
+from matplotlib.colors import LinearSegmentedColormap
+from mmengine import Config
+from mmengine.model.utils import revert_sync_batchnorm
+from mmengine.registry import init_default_scope
+from mmengine.runner import Runner, load_checkpoint
+
+from mmseg.registry import MODELS
+
+# ======================== Publication Style ========================
+plt.rcParams.update({
+ 'font.family': 'serif',
+ 'font.serif': ['Times New Roman', 'DejaVu Serif'],
+ 'font.size': 10,
+ 'axes.labelsize': 11,
+ 'axes.titlesize': 12,
+ 'xtick.labelsize': 9,
+ 'ytick.labelsize': 9,
+ 'legend.fontsize': 9,
+ 'figure.dpi': 300,
+ 'savefig.dpi': 300,
+ 'savefig.bbox': 'tight',
+ 'savefig.pad_inches': 0.05,
+ 'axes.linewidth': 0.8,
+ 'axes.grid': False,
+})
+
+MODAL_DISPLAY = {
+ 'sar': 'UrbanSARflood',
+ 'rgb': 'FloodNet',
+ 'GF': 'GF-Floodnet',
+}
+
+MODAL_CHANNELS = {
+ 'sar': 8,
+ 'rgb': 3,
+ 'GF': 5,
+}
+
+MODAL_COLORS = {
+ 'sar': '#2196F3',
+ 'rgb': '#4CAF50',
+ 'GF': '#FF9800',
+}
+
+
def parse_args():
    """Parse command-line options for the routing visualization."""
    cli = argparse.ArgumentParser(
        description='Visualize Expert Routing Patterns')
    cli.add_argument('config', help='config file path')
    cli.add_argument('checkpoint', help='checkpoint file path')
    cli.add_argument('--data-root', type=str,
                     default='../floodnet/data/mixed_dataset/',
                     help='data root for test set')
    cli.add_argument('--output-dir', type=str,
                     default='work_dirs/figures/expert_routing',
                     help='output directory for figures')
    cli.add_argument('--num-samples', type=int, default=50,
                     help='number of test samples per modality for stats')
    cli.add_argument('--spatial-image-idx', type=int, default=0,
                     help='index of image for spatial assignment map')
    return cli.parse_args()
+
+
+# ======================== Model Building ========================
+
def build_model(cfg, checkpoint_path):
    """Build the segmentor, load weights, and return it in eval mode."""
    cfg.model.train_cfg = None
    model = MODELS.build(cfg.model)
    load_checkpoint(model, checkpoint_path, map_location='cpu')
    if torch.cuda.is_available():
        model = model.cuda()
    # FIX: convert SyncBN -> BN on the CPU path too — the original skipped
    # the conversion when CUDA was unavailable, which is exactly where
    # single-process SyncBatchNorm cannot run.
    model = revert_sync_batchnorm(model)
    model.eval()
    return model
+
+
+# ======================== Hook-based Gate Extraction ========================
+
+class GateHookManager:
+ """Register hooks on all MoE gating layers to capture routing weights."""
+
+ def __init__(self, model):
+ self.model = model
+ self.hooks = []
+ self.gate_records = [] # list of (stage, block, gates_tensor)
+ self._register_hooks()
+
+ def _register_hooks(self):
+ backbone = self._get_backbone(self.model)
+ for stage_idx, stage_dict in enumerate(backbone.stages):
+ blocks = stage_dict['blocks']
+ for block_idx, block in enumerate(blocks):
+ if block.use_moe:
+ moe_layer = block.mlp
+ hook = moe_layer.register_forward_hook(
+ self._make_hook(stage_idx, block_idx))
+ self.hooks.append(hook)
+
+ def _get_backbone(self, model):
+ if hasattr(model, 'module'):
+ model = model.module
+ if hasattr(model, 'backbone'):
+ return model.backbone
+ return model
+
+ def _make_hook(self, stage_idx, block_idx):
+ def hook_fn(module, input_args, output):
+ # Re-compute gates (eval mode, no noise)
+ x = input_args[0]
+ modal_types = input_args[1] if len(input_args) > 1 else None
+
+ with torch.no_grad():
+ x_pooled = x.mean(dim=1)
+ clean_logits = module.gating(x_pooled, modal_types)
+ top_logits, top_indices = clean_logits.topk(
+ min(module.top_k, module.num_experts), dim=-1)
+ top_k_gates = F.softmax(top_logits, dim=-1)
+ zeros = torch.zeros_like(clean_logits)
+ gates = zeros.scatter(-1, top_indices, top_k_gates)
+
+ self.gate_records.append({
+ 'stage': stage_idx,
+ 'block': block_idx,
+ 'gates': gates.detach().cpu(),
+ 'modal_types': list(modal_types) if modal_types else None,
+ })
+
+ return hook_fn
+
+ def clear(self):
+ self.gate_records = []
+
+ def remove_hooks(self):
+ for h in self.hooks:
+ h.remove()
+ self.hooks = []
+
+
+# ======================== Data Collection ========================
+
def collect_routing_stats(model, cfg, hook_mgr, data_root, num_samples):
    """Run inference on real test data and collect per-modality routing stats.

    Loads actual test images from the dataset to capture true data-driven
    routing patterns (feature-dependent + modal_bias), not just modal_bias.
    Falls back to random noise if no images are found for a modality.

    Args:
        model: segmentor in eval mode, invoked as
            ``model([img], [data_sample], mode='predict')``.
        cfg: full config; only ``test_dataloader.dataset.data_prefix`` is read.
        hook_mgr (GateHookManager): already attached to ``model``; cleared
            per modality, its records are aggregated here.
        data_root (str): dataset root containing the test images.
        num_samples (int): max number of images per modality.

    Returns:
        dict: ``stats[modal][(stage, block)]`` -> list of gate tensors
        captured by the hooks.
    """
    from mmseg.structures import SegDataSample
    from mmengine.dataset import Compose
    import glob

    device = next(model.parameters()).device
    stats = defaultdict(lambda: defaultdict(list))

    # Build a minimal inference pipeline: load image, resize, normalize, pack.
    # Skip LoadAnnotations (no labels needed for routing analysis).
    # Use 256x256 (matching training crop) for efficient single-pass inference.
    pipeline = Compose([
        dict(type='LoadMultiModalImageFromFile', to_float32=True),
        dict(type='Resize', scale=(256, 256), keep_ratio=False),
        dict(type='MultiModalNormalize'),
        dict(type='PackMultiModalSegInputs',
             meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape',
                        'pad_shape', 'scale_factor', 'flip', 'flip_direction',
                        'modal_type', 'actual_channels', 'dataset_name',
                        'img_norm_cfg', 'reduce_zero_label')),
    ])

    # Dataset directory
    data_prefix = cfg.test_dataloader.dataset.get(
        'data_prefix', dict(img_path='test/images'))
    img_dir = osp.join(data_root, data_prefix['img_path'])

    for modal in ['sar', 'rgb', 'GF']:
        hook_mgr.clear()
        ch = MODAL_CHANNELS[modal]

        # Find images for this modality by filename pattern.
        # NOTE(review): assumes the modality name appears as a
        # case-insensitive substring of the file name — confirm this matches
        # the mixed-dataset naming scheme.
        all_files = sorted(glob.glob(osp.join(img_dir, '*')))
        modal_key = modal.lower()
        modal_files = [f for f in all_files
                       if modal_key in osp.basename(f).lower()]

        if modal_files:
            n = min(num_samples, len(modal_files))
            print(f' [{modal}] Using {n} real test images '
                  f'(found {len(modal_files)} total)')

            count = 0
            for img_path in modal_files[:num_samples]:
                fname = osp.basename(img_path)

                # Seed the pipeline with the metadata its transforms expect.
                data = dict(
                    img_path=img_path,
                    modal_type=modal,
                    actual_channels=ch,
                    dataset_name=modal,
                    reduce_zero_label=False,
                    ori_filename=fname,
                    flip=False,
                    flip_direction=None,
                )

                try:
                    data = pipeline(data)
                except Exception as e:
                    # Report only the first failure to keep the log short.
                    if count == 0:
                        print(f' [DEBUG] Pipeline error: {e}')
                        print(f' [DEBUG] img_path: {img_path}')
                    continue

                # Pipeline outputs: {'inputs': Tensor, 'data_samples': SegDataSample}
                img = data['inputs'].float().to(device)
                data_sample = data['data_samples']

                with torch.no_grad():
                    model([img], [data_sample], mode='predict')
                count += 1

            print(f' [{modal}] Successfully processed {count} images')

        else:
            print(f' [WARN] No images for {modal} in {img_dir}, '
                  f'using random noise')
            for _ in range(num_samples):
                # Random input still exercises the modal-dependent routing.
                img = torch.randn(ch, 256, 256, device=device)
                ds = SegDataSample()
                ds.set_metainfo(dict(
                    img_shape=(256, 256), ori_shape=(256, 256),
                    pad_shape=(256, 256), scale_factor=(1.0, 1.0),
                    flip=False, flip_direction=None,
                    modal_type=modal, actual_channels=ch,
                    dataset_name=modal, reduce_zero_label=False,
                ))
                with torch.no_grad():
                    model([img], [ds], mode='predict')

        # Aggregate per (stage, block)
        for record in hook_mgr.gate_records:
            key = (record['stage'], record['block'])
            stats[modal][key].append(record['gates'])

    return stats
+
+
def collect_spatial_gates(model, hook_mgr, modal='sar', shape=(256, 256)):
    """Collect per-token spatial gating for one image.

    NOTE(review): despite the name, the gating below is computed from the
    token-pooled feature, so each record holds ONE dominant expert per
    (stage, block) for the whole image rather than a per-token map —
    confirm this granularity is what the spatial figure expects.
    """
    from mmseg.structures import SegDataSample

    device = next(model.parameters()).device
    ch = MODAL_CHANNELS[modal]
    h, w = shape

    # We need per-token gates, not per-sample gates.
    # Modify hook to capture per-token routing.
    backbone = hook_mgr._get_backbone(model)
    spatial_records = []

    hooks = []
    for stage_idx, stage_dict in enumerate(backbone.stages):
        for block_idx, block in enumerate(stage_dict['blocks']):
            if block.use_moe:
                moe_layer = block.mlp

                def make_spatial_hook(s_idx, b_idx):
                    # Factory binds the stage/block indices at definition
                    # time, avoiding the late-binding closure pitfall.
                    def hook_fn(module, input_args, output):
                        x = input_args[0]  # [B, N, C]
                        modal_types = (input_args[1]
                                       if len(input_args) > 1 else None)

                        with torch.no_grad():
                            # Per-sample gating (same as normal)
                            x_pooled = x.mean(dim=1)
                            logits = module.gating(x_pooled, modal_types)
                            top_logits, top_indices = logits.topk(
                                min(module.top_k, module.num_experts), dim=-1)
                            top_k_gates = F.softmax(top_logits, dim=-1)

                            # The dominant expert for this sample
                            dominant_expert = top_indices[0, 0].item()
                            # Dense [num_experts] weight vector with the
                            # softmaxed top-k weights scattered in.
                            gate_weights = torch.zeros(module.num_experts)
                            gate_weights.scatter_(
                                0, top_indices[0],
                                top_k_gates[0].cpu())

                            spatial_records.append({
                                'stage': s_idx,
                                'block': b_idx,
                                'dominant_expert': dominant_expert,
                                'gate_weights': gate_weights,
                            })
                    return hook_fn

                h_handle = moe_layer.register_forward_hook(
                    make_spatial_hook(stage_idx, block_idx))
                hooks.append(h_handle)

    # Drive a single predict pass with a random image of the right channel
    # count; the hooks above record the routing as a side effect.
    img = torch.randn(ch, h, w, device=device)
    data_sample = SegDataSample()
    data_sample.set_metainfo(dict(
        img_shape=(h, w), ori_shape=(h, w), pad_shape=(h, w),
        scale_factor=(1.0, 1.0), flip=False, flip_direction=None,
        modal_type=modal, actual_channels=ch,
        dataset_name=modal, reduce_zero_label=False,
    ))

    with torch.no_grad():
        model([img], [data_sample], mode='predict')

    for h_handle in hooks:
        h_handle.remove()

    return spatial_records
+
+
+# ======================== Figure (a): Activation Heatmap ========================
+
def plot_activation_heatmap(stats, num_experts, output_dir):
    """Plot expert activation probability heatmap per modality per stage.

    Args:
        stats: nested dict ``stats[modal][(stage, block)] -> list of
            [B, num_experts] gate-weight tensors`` collected by hooks.
        num_experts: number of experts (heatmap columns).
        output_dir: directory for the saved PDF/PNG.

    Each cell is the fraction of samples for which the expert received a
    non-zero gate weight (i.e. was selected by top-k routing), pooled
    over all MoE blocks belonging to the stage.
    """

    # Aggregate: for each modal and each (stage, block), compute mean gate
    modals = ['sar', 'rgb', 'GF']

    # Collect all unique (stage, block) keys sorted
    all_keys = set()
    for modal in modals:
        all_keys.update(stats[modal].keys())
    all_keys = sorted(all_keys)

    # Group keys by stage so each stage gets one subplot.
    stage_keys = defaultdict(list)
    for s, b in all_keys:
        stage_keys[s].append((s, b))

    stages_with_moe = sorted(stage_keys.keys())
    num_stages = len(stages_with_moe)

    fig, axes = plt.subplots(
        1, num_stages,
        figsize=(3.2 * num_stages + 0.8, 2.8),
        gridspec_kw={'wspace': 0.35}
    )
    if num_stages == 1:
        axes = [axes]  # uniform indexing for the single-subplot case

    cmap = plt.cm.YlOrRd

    for ax_idx, stage in enumerate(stages_with_moe):
        keys = stage_keys[stage]
        # Build matrix: [num_modals, num_experts], averaged over all blocks
        heatmap = np.zeros((len(modals), num_experts))

        for m_idx, modal in enumerate(modals):
            all_gates = []
            for key in keys:
                if key in stats[modal]:
                    gate_list = stats[modal][key]
                    # Each is [B, num_experts], concat and mean
                    stacked = torch.cat(gate_list, dim=0)  # [N, E]
                    all_gates.append(stacked)
            if all_gates:
                combined = torch.cat(all_gates, dim=0)
                # Activation probability = fraction of times expert is selected
                activation_prob = (combined > 0).float().mean(dim=0).numpy()
                heatmap[m_idx] = activation_prob

        im = axes[ax_idx].imshow(
            heatmap, cmap=cmap, aspect='auto', vmin=0, vmax=1.0)

        axes[ax_idx].set_xticks(range(num_experts))
        axes[ax_idx].set_xticklabels(
            [f'E{i}' for i in range(num_experts)], fontsize=8)
        axes[ax_idx].set_yticks(range(len(modals)))
        if ax_idx == 0:
            axes[ax_idx].set_yticklabels(
                [MODAL_DISPLAY[m] for m in modals], fontsize=9)
        else:
            axes[ax_idx].set_yticklabels([])
        axes[ax_idx].set_xlabel('Expert Index', fontsize=9)
        axes[ax_idx].set_title(f'Stage {stage}', fontsize=11, fontweight='bold')

        # Annotate cells; switch to white text on dark (high-value) cells.
        for i in range(len(modals)):
            for j in range(num_experts):
                val = heatmap[i, j]
                color = 'white' if val > 0.5 else 'black'
                axes[ax_idx].text(
                    j, i, f'{val:.2f}',
                    ha='center', va='center', fontsize=7, color=color)

    # Shared colorbar for all stage subplots.
    cbar = fig.colorbar(
        im, ax=axes, shrink=0.85, aspect=25, pad=0.03)
    cbar.set_label('Activation Probability', fontsize=9)

    fig.savefig(
        osp.join(output_dir, 'fig4a_expert_activation_heatmap.pdf'),
        format='pdf')
    fig.savefig(
        osp.join(output_dir, 'fig4a_expert_activation_heatmap.png'),
        format='png')
    plt.close(fig)
    # Plain string: the original used an f-string with no placeholders.
    print('[Saved] fig4a_expert_activation_heatmap.pdf/png')
+
+
+# ======================== Figure (b): Modal Bias ========================
+
def plot_modal_bias(model, output_dir):
    """Visualize the learned modal_bias parameters from all MoE layers.

    For every stage that contains MoE blocks, the per-block
    ``gating.modal_bias`` tensors (assumed [num_modals, num_experts])
    are averaged and rendered as a diverging heatmap centered at zero.

    Args:
        model: segmentor (possibly DDP-wrapped) with an MoE backbone.
        output_dir: directory for the saved PDF/PNG.
    """

    # Unwrap a potential DistributedDataParallel container.
    backbone = (model.module.backbone
                if hasattr(model, 'module') else model.backbone)

    bias_per_stage = defaultdict(list)
    modal_names = None

    for stage_idx, stage_dict in enumerate(backbone.stages):
        for block_idx, block in enumerate(stage_dict['blocks']):
            if block.use_moe and hasattr(block.mlp, 'gating'):
                gate = block.mlp.gating
                if (hasattr(gate, 'modal_bias')
                        and gate.modal_bias is not None):
                    bias = gate.modal_bias.detach().cpu().numpy()
                    bias_per_stage[stage_idx].append(bias)
                    if modal_names is None:
                        modal_names = gate.modal_name_to_idx

    if not bias_per_stage:
        print('[WARN] No modal_bias parameters found.')
        return

    # Sort modal names by their gating index so row order is stable.
    sorted_modals = sorted(modal_names.items(), key=lambda x: x[1])
    modal_order = [m[0] for m in sorted_modals]

    stages = sorted(bias_per_stage.keys())
    num_stages = len(stages)
    num_experts = bias_per_stage[stages[0]][0].shape[1]

    fig, axes = plt.subplots(
        1, num_stages,
        figsize=(3.2 * num_stages + 0.8, 2.8),
        gridspec_kw={'wspace': 0.35}
    )
    if num_stages == 1:
        axes = [axes]

    # Diverging colormap centered at 0, symmetric limits across stages.
    cmap = plt.cm.RdBu_r
    all_biases = np.concatenate(
        [np.stack(v).mean(axis=0) for v in bias_per_stage.values()])
    vmax = max(abs(all_biases.min()), abs(all_biases.max()))
    if vmax < 1e-6:
        vmax = 1.0  # all-zero biases: avoid a degenerate color scale

    for ax_idx, stage in enumerate(stages):
        # Average across all blocks in this stage
        avg_bias = np.stack(bias_per_stage[stage]).mean(axis=0)
        # avg_bias shape: [num_modals, num_experts]

        im = axes[ax_idx].imshow(
            avg_bias, cmap=cmap, aspect='auto', vmin=-vmax, vmax=vmax)

        axes[ax_idx].set_xticks(range(num_experts))
        axes[ax_idx].set_xticklabels(
            [f'E{i}' for i in range(num_experts)], fontsize=8)
        axes[ax_idx].set_yticks(range(len(modal_order)))
        if ax_idx == 0:
            axes[ax_idx].set_yticklabels(
                [MODAL_DISPLAY.get(m, m) for m in modal_order], fontsize=9)
        else:
            axes[ax_idx].set_yticklabels([])
        axes[ax_idx].set_xlabel('Expert Index', fontsize=9)
        axes[ax_idx].set_title(f'Stage {stage}', fontsize=11, fontweight='bold')

        # Annotate each cell; white text on strongly-colored cells.
        for i in range(avg_bias.shape[0]):
            for j in range(avg_bias.shape[1]):
                val = avg_bias[i, j]
                color = 'white' if abs(val) > vmax * 0.6 else 'black'
                axes[ax_idx].text(
                    j, i, f'{val:.2f}',
                    ha='center', va='center', fontsize=7, color=color)

    cbar = fig.colorbar(im, ax=axes, shrink=0.85, aspect=25, pad=0.03)
    cbar.set_label('Bias Value', fontsize=9)

    fig.savefig(
        osp.join(output_dir, 'fig4b_modal_bias_matrix.pdf'), format='pdf')
    fig.savefig(
        osp.join(output_dir, 'fig4b_modal_bias_matrix.png'), format='png')
    plt.close(fig)
    # Plain string: the original used an f-string with no placeholders.
    print('[Saved] fig4b_modal_bias_matrix.pdf/png')
+
+
+# ======================== Figure (c): Expert Assignment Bar ========================
+
def plot_expert_assignment_comparison(stats, num_experts, output_dir):
    """Grouped bar chart: expert selection frequency per modality.

    Aggregated across all MoE layers, showing which experts are preferred
    by each modality. More intuitive than heatmap for presentations.

    Args:
        stats: nested dict ``stats[modal][(stage, block)] -> list of
            [B, num_experts] gate-weight tensors``.
        num_experts: number of experts (bar groups on the x axis).
        output_dir: directory for the saved PDF/PNG.
    """
    modals = ['sar', 'rgb', 'GF']

    # Aggregate across all stages/blocks
    freq = {}
    for modal in modals:
        all_gates = []
        for key, gate_list in stats[modal].items():
            stacked = torch.cat(gate_list, dim=0)
            all_gates.append(stacked)
        if all_gates:
            combined = torch.cat(all_gates, dim=0)
            # Selection frequency = share of samples with a non-zero gate
            freq[modal] = (combined > 0).float().mean(dim=0).numpy()
        else:
            freq[modal] = np.zeros(num_experts)

    x = np.arange(num_experts)
    width = 0.25
    fig, ax = plt.subplots(figsize=(5.5, 3.2))

    # One group per expert, three bars per group centered on x.
    for i, modal in enumerate(modals):
        ax.bar(
            x + (i - 1) * width, freq[modal], width,
            label=MODAL_DISPLAY[modal],
            color=MODAL_COLORS[modal], edgecolor='white', linewidth=0.5)

    ax.set_xlabel('Expert Index')
    ax.set_ylabel('Selection Frequency')
    ax.set_xticks(x)
    ax.set_xticklabels([f'E{i}' for i in range(num_experts)])
    ax.legend(frameon=True, edgecolor='gray', fancybox=False)
    ax.set_ylim(0, 1.0)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    fig.savefig(
        osp.join(output_dir, 'fig4c_expert_selection_frequency.pdf'),
        format='pdf')
    fig.savefig(
        osp.join(output_dir, 'fig4c_expert_selection_frequency.png'),
        format='png')
    plt.close(fig)
    # Plain string: the original used an f-string with no placeholders.
    print('[Saved] fig4c_expert_selection_frequency.pdf/png')
+
+
+# ======================== Figure (d): Gate Weight Distribution ========================
+
def plot_gate_weight_distribution(stats, num_experts, output_dir):
    """Box plot of non-zero gate weight distribution per modality per expert.

    Args:
        stats: nested dict ``stats[modal][(stage, block)] -> list of
            [B, num_experts] gate-weight tensors``.
        num_experts: number of experts (one box per expert).
        output_dir: directory for the saved PDF/PNG.
    """
    modals = ['sar', 'rgb', 'GF']

    fig, axes = plt.subplots(
        1, len(modals), figsize=(4.0 * len(modals), 3.0),
        sharey=True, gridspec_kw={'wspace': 0.1})

    for m_idx, modal in enumerate(modals):
        all_gates = []
        for key, gate_list in stats[modal].items():
            stacked = torch.cat(gate_list, dim=0)
            all_gates.append(stacked)

        if not all_gates:
            continue  # no routing data collected for this modality

        combined = torch.cat(all_gates, dim=0).numpy()  # [N, E]

        # Only keep nonzero weights (selected experts) for the box plot;
        # a zero placeholder keeps never-selected experts from breaking it.
        data_per_expert = []
        for e in range(num_experts):
            weights = combined[:, e]
            nonzero = weights[weights > 0]
            data_per_expert.append(nonzero if len(nonzero) > 0
                                   else np.array([0.0]))

        bp = axes[m_idx].boxplot(
            data_per_expert, patch_artist=True, widths=0.6,
            showfliers=False, medianprops=dict(color='black', linewidth=1.2))

        color = MODAL_COLORS[modal]
        for patch in bp['boxes']:
            patch.set_facecolor(color)
            patch.set_alpha(0.7)

        axes[m_idx].set_xlabel('Expert Index')
        # Pin ticks to boxplot positions (1..N) before relabeling, so the
        # labels cannot drift if the tick locator changes.
        axes[m_idx].set_xticks(range(1, num_experts + 1))
        axes[m_idx].set_xticklabels(
            [f'E{i}' for i in range(num_experts)], fontsize=8)
        axes[m_idx].set_title(MODAL_DISPLAY[modal], fontsize=11,
                              fontweight='bold', color=color)
        axes[m_idx].spines['top'].set_visible(False)
        axes[m_idx].spines['right'].set_visible(False)

    axes[0].set_ylabel('Gate Weight (non-zero)')

    fig.savefig(
        osp.join(output_dir, 'fig4d_gate_weight_distribution.pdf'),
        format='pdf')
    fig.savefig(
        osp.join(output_dir, 'fig4d_gate_weight_distribution.png'),
        format='png')
    plt.close(fig)
    # Plain string: the original used an f-string with no placeholders.
    print('[Saved] fig4d_gate_weight_distribution.pdf/png')
+
+
+# ======================== Main ========================
+
def main():
    """Entry point: build the model once, then render figures (a)-(d)."""
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    cfg = Config.fromfile(args.config)
    init_default_scope(cfg.get('default_scope', 'mmseg'))

    print('=' * 60)
    print('Expert Routing Visualization')
    print(f'Config: {args.config}')
    print(f'Checkpoint: {args.checkpoint}')
    print(f'Output: {args.output_dir}')
    print(f'Samples: {args.num_samples} per modality')
    print('=' * 60)

    # Use 'whole' mode to avoid slide_inference complexity for routing analysis
    cfg.model.test_cfg = dict(mode='whole')

    # Build model
    print('\nBuilding model...')
    model = build_model(cfg, args.checkpoint)

    # Read num_experts from config (falls back to 8 when unspecified)
    num_experts = cfg.model.backbone.get('num_experts', 8)

    # ---- Figure (b): Modal Bias (no inference needed) ----
    print('\n--- Figure (b): Modal Bias Matrix ---')
    plot_modal_bias(model, args.output_dir)

    # ---- Collect routing stats via hooks ----
    print('\n--- Collecting routing statistics ---')
    hook_mgr = GateHookManager(model)
    stats = collect_routing_stats(
        model, cfg, hook_mgr, args.data_root, args.num_samples)
    hook_mgr.remove_hooks()  # detach hooks before plotting

    # ---- Figure (a): Activation Heatmap ----
    print('\n--- Figure (a): Expert Activation Heatmap ---')
    plot_activation_heatmap(stats, num_experts, args.output_dir)

    # ---- Figure (c): Expert Selection Frequency Bar ----
    print('\n--- Figure (c): Expert Selection Frequency ---')
    plot_expert_assignment_comparison(stats, num_experts, args.output_dir)

    # ---- Figure (d): Gate Weight Distribution ----
    print('\n--- Figure (d): Gate Weight Distribution ---')
    plot_gate_weight_distribution(stats, num_experts, args.output_dir)

    print('\n' + '=' * 60)
    print(f'All figures saved to: {args.output_dir}/')
    print(' fig4a_expert_activation_heatmap.pdf/png')
    print(' fig4b_modal_bias_matrix.pdf/png')
    print(' fig4c_expert_selection_frequency.pdf/png')
    print(' fig4d_gate_weight_distribution.pdf/png')
    print('=' * 60)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/compute_sen1floods11_stats.py b/tools/compute_sen1floods11_stats.py
new file mode 100644
index 0000000000..d007490216
--- /dev/null
+++ b/tools/compute_sen1floods11_stats.py
@@ -0,0 +1,192 @@
+"""Compute channel-wise mean / std for Sen1Floods11 S1Hand or S2Hand.
+
+The tool walks ``<data-root>/<subdir>/*.tif`` (default subdir: S1Hand or S2Hand)
+and ignores pixels that are flagged as nodata in the matching
+``LabelHand`` file (label value == -1). Pass the resulting mean / std
+arrays into ``MultiModalNormalize.NORM_CONFIGS`` under the ``s1`` / ``s2``
+entries in ``mmseg/datasets/transforms/multimodal_pipelines.py``.
+
+Usage::
+
+ python tools/compute_sen1floods11_stats.py \
+ --data-root data/Sen1Floods11 --modality s1
+
+ python tools/compute_sen1floods11_stats.py \
+ --data-root data/Sen1Floods11 --modality s2 --split splits/train.txt
+"""
+import argparse
+import os
+import os.path as osp
+from typing import List
+
+import numpy as np
+
+try:
+ import tifffile
+except ImportError as e:
+ raise SystemExit(
+ 'tifffile is required. Install with `pip install tifffile`.') from e
+
+
# File-layout description per modality: sub-directory name, filename
# suffix, and the expected channel count.
MODALITIES = {
    's1': {
        'subdir': 'S1Hand',
        'suffix': '_S1Hand.tif',
        'channels': 2,
    },
    's2': {
        'subdir': 'S2Hand',
        'suffix': '_S2Hand.tif',
        'channels': 13,
    },
}

# Label rasters live beside the image folders under their own suffix.
LABEL_SUBDIR = 'LabelHand'
LABEL_SUFFIX = '_LabelHand.tif'


def parse_args():
    """Build and evaluate the command-line interface."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data-root', required=True,
        help='Sen1Floods11 root containing S1Hand/S2Hand/LabelHand')
    parser.add_argument('--modality', choices=list(MODALITIES),
                        required=True)
    parser.add_argument(
        '--split', default=None,
        help='Optional txt file listing sample base-names '
             '(one per line). Paths are resolved against '
             'data-root.')
    parser.add_argument(
        '--max-samples', type=int, default=None,
        help='Optional cap on the number of images to scan.')
    # Masking defaults to on; --no-mask-nodata flips the same dest off.
    parser.add_argument(
        '--mask-nodata', action='store_true', default=True,
        help='Exclude pixels where LabelHand == -1 '
             '(default: True).')
    parser.add_argument('--no-mask-nodata', dest='mask_nodata',
                        action='store_false')
    return parser.parse_args()
+
+
def load_image_list(data_root: str, modality_cfg: dict,
                    split: str) -> List[str]:
    """Resolve the image paths to scan for one modality.

    With *split*, base-names are read from the txt file (a trailing
    image or label suffix is stripped if present); otherwise the
    modality directory is scanned for files ending in its suffix.
    """
    img_dir = osp.join(data_root, modality_cfg['subdir'])
    suffix = modality_cfg['suffix']

    if not split:
        names = sorted(n for n in os.listdir(img_dir)
                       if n.endswith(suffix))
        return [osp.join(img_dir, n) for n in names]

    split_path = split if osp.isabs(split) else osp.join(data_root, split)
    if not osp.isfile(split_path):
        raise FileNotFoundError(f'split file not found: {split_path}')

    names = []
    with open(split_path, 'r') as fh:
        for raw in fh:
            stem = raw.strip()
            if not stem:
                continue  # ignore blank lines
            # Tolerate entries that already carry a known suffix.
            for known in (suffix, LABEL_SUFFIX):
                if stem.endswith(known):
                    stem = stem[:-len(known)]
                    break
            names.append(stem + suffix)
    return [osp.join(img_dir, n) for n in names]
+
+
def label_path_for(img_path: str, modality_cfg: dict, data_root: str) -> str:
    """Map an image path to its matching LabelHand raster path."""
    suffix = modality_cfg['suffix']
    stem = osp.basename(img_path)[:-len(suffix)]
    return osp.join(data_root, LABEL_SUBDIR, stem + LABEL_SUFFIX)
+
+
+def _hwc(img: np.ndarray) -> np.ndarray:
+ """Ensure (H, W, C)."""
+ if img.ndim == 2:
+ return img[:, :, None]
+ if img.ndim == 3 and img.shape[0] < img.shape[-1]:
+ return np.transpose(img, (1, 2, 0))
+ return img
+
+
def compute_stats(img_paths, data_root, modality_cfg, mask_nodata):
    """Stream per-channel mean / std over all valid pixels.

    Args:
        img_paths: absolute image paths to scan.
        data_root: dataset root, used to find each matching LabelHand file.
        modality_cfg: an entry of MODALITIES (expected channel count, suffix).
        mask_nodata: if True, drop pixels whose label value is -1.

    Returns:
        Tuple ``(mean, std, count)`` of float64 arrays with one entry per
        channel; ``count`` holds the number of valid pixels accumulated.
    """
    num_channels = modality_cfg['channels']

    # Streaming count / sum / sum-of-squares accumulators per channel.
    # NOTE(review): this computes var = E[x^2] - E[x]^2, the classic
    # single-pass formula — NOT a true Welford update. float64
    # accumulators plus the clip below keep it serviceable, but it can
    # lose precision when std << |mean|.
    total_count = np.zeros(num_channels, dtype=np.float64)
    total_sum = np.zeros(num_channels, dtype=np.float64)
    total_sqsum = np.zeros(num_channels, dtype=np.float64)

    for i, img_path in enumerate(img_paths):
        img = tifffile.imread(img_path)
        img = _hwc(img).astype(np.float64)
        if img.shape[-1] != num_channels:
            # Channel mismatch: warn and skip rather than corrupt the stats.
            print(f'[warn] skip {img_path}: got {img.shape[-1]} ch '
                  f'(expected {num_channels})')
            continue

        valid_mask = None
        if mask_nodata:
            lbl_path = label_path_for(img_path, modality_cfg, data_root)
            if osp.isfile(lbl_path):
                lbl = tifffile.imread(lbl_path)
                lbl = np.squeeze(lbl)
                valid_mask = (lbl != -1)
        # also drop NaN / Inf pixels that crop up in SAR dB
        finite_mask = np.all(np.isfinite(img), axis=-1)
        if valid_mask is None:
            valid_mask = finite_mask
        else:
            valid_mask &= finite_mask

        if not valid_mask.any():
            continue  # nothing valid in this tile

        valid = img[valid_mask]  # (N, C)
        total_count += valid.shape[0]
        total_sum += valid.sum(axis=0)
        total_sqsum += (valid ** 2).sum(axis=0)

        if (i + 1) % 50 == 0:
            print(f' processed {i + 1}/{len(img_paths)} images')

    # Avoid divide-by-zero when no valid pixels were seen at all.
    total_count = np.where(total_count == 0, 1, total_count)
    mean = total_sum / total_count
    var = total_sqsum / total_count - mean ** 2
    var = np.clip(var, a_min=0.0, a_max=None)  # guard rounding below zero
    std = np.sqrt(var)
    return mean, std, total_count
+
+
def main():
    """Scan the images and print stats plus a paste-ready config snippet."""
    args = parse_args()
    modality_cfg = MODALITIES[args.modality]

    img_paths = load_image_list(args.data_root, modality_cfg, args.split)
    if args.max_samples:
        img_paths = img_paths[:args.max_samples]

    print(f'Found {len(img_paths)} images for modality={args.modality}')
    if not img_paths:
        raise SystemExit('No images matched - check --data-root / --split')

    mean, std, count = compute_stats(
        img_paths, args.data_root, modality_cfg, args.mask_nodata)

    print()
    print('=' * 60)
    # count is per-channel but identical across channels; report channel 0.
    print(f'Sen1Floods11 {args.modality} statistics '
          f'(valid pixels: {int(count[0])})')
    print('=' * 60)
    print(f'mean = {mean.tolist()}')
    print(f'std = {std.tolist()}')
    print()
    print('Paste into NORM_CONFIGS in '
          'mmseg/datasets/transforms/multimodal_pipelines.py:')
    print(f" '{args.modality}': {{")
    print(f" 'mean': {mean.tolist()},")
    print(f" 'std': {std.tolist()},")
    print(' },')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/predict_large_tif.py b/tools/predict_large_tif.py
new file mode 100644
index 0000000000..2ecc2f5125
--- /dev/null
+++ b/tools/predict_large_tif.py
@@ -0,0 +1,408 @@
+"""
+Large TIF Inference for Multi-Modal Segmentor.
+
+Reads a large GeoTIFF (e.g. 20803x36986), tiles it into patches,
+runs inference with the multi-modal model, and stitches results
+back into a full-size GeoTIFF. Flood pixels are colored red (255,0,0),
+non-flood pixels are black (0,0,0).
+
+Usage:
+ python tools/predict_large_tif.py \
+ configs/floodnet/finetune_single_modal.py \
+ work_dirs/generalization/LY-train-station/best_mIoU_epoch_30.pth \
+ --input data/luoyuan/result.tif \
+ --output data/luoyuan/prediction.tif \
+ --tile-size 512 \
+ --overlap 64 \
+ --modal rgb \
+ --bands 0 1 2 \
+ --batch-size 4
+"""
+
+import argparse
+import math
+import os
+import time
+
+import numpy as np
+import torch
+from mmengine import Config
+from mmengine.model.utils import revert_sync_batchnorm
+from mmengine.registry import init_default_scope
+from mmengine.runner import load_checkpoint
+
+from mmseg.registry import MODELS
+from mmseg.structures import SegDataSample
+
+try:
+ from osgeo import gdal
+ HAS_GDAL = True
+except ImportError:
+ HAS_GDAL = False
+
+try:
+ import rasterio
+ HAS_RASTERIO = True
+except ImportError:
+ HAS_RASTERIO = False
+
+
def parse_args():
    """Parse CLI options for tiled large-image inference."""
    cli = argparse.ArgumentParser(
        description='Large TIF inference with tile stitching')
    cli.add_argument('config', help='config file path')
    cli.add_argument('checkpoint', help='checkpoint file path')
    cli.add_argument('--input', required=True, help='input TIF file')
    cli.add_argument('--output', default=None,
                     help='output TIF file (default: input_pred.tif)')
    cli.add_argument('--tile-size', type=int, default=512,
                     help='tile size for inference (default: 512)')
    cli.add_argument('--overlap', type=int, default=64,
                     help='overlap between tiles (default: 64)')
    cli.add_argument('--batch-size', type=int, default=4,
                     help='batch size for inference (default: 4)')
    cli.add_argument('--modal', default='rgb',
                     help='modality type for model (default: rgb)')
    cli.add_argument('--bands', type=int, nargs='+', default=[0, 1, 2],
                     help='band indices to read (0-indexed, default: 0 1 2)')
    cli.add_argument('--flood-class', type=int, default=1,
                     help='class index for flood (default: 1)')
    cli.add_argument('--device', default='cuda:0',
                     help='device for inference (default: cuda:0)')
    return cli.parse_args()
+
+
def build_model(cfg, checkpoint_path, device):
    """Instantiate the segmentor, load weights, and prepare for eval.

    SyncBN layers are converted to plain BN so the model can run on a
    single device; the model is returned in eval mode.
    """
    cfg.model.train_cfg = None
    segmentor = MODELS.build(cfg.model)
    load_checkpoint(segmentor, checkpoint_path, map_location='cpu')
    segmentor.to(device)
    segmentor = revert_sync_batchnorm(segmentor)
    segmentor.eval()
    return segmentor
+
+
def read_tif_info(tif_path):
    """Read TIF metadata via rasterio (preferred) or GDAL.

    NOTE(review): the two backends return DIFFERENT key sets — rasterio
    yields 'dtype'/'crs'/'transform'/'profile', GDAL yields
    'geo_transform'/'projection'. Callers must branch on HAS_RASTERIO
    before relying on backend-specific keys (main() does).
    """
    if HAS_RASTERIO:
        with rasterio.open(tif_path) as src:
            return {
                'height': src.height,
                'width': src.width,
                'bands': src.count,
                'dtype': src.dtypes[0],
                'crs': src.crs,
                'transform': src.transform,
                'profile': src.profile.copy(),
            }
    elif HAS_GDAL:
        # GDAL handle is released when `ds` is garbage-collected.
        ds = gdal.Open(tif_path, gdal.GA_ReadOnly)
        return {
            'height': ds.RasterYSize,
            'width': ds.RasterXSize,
            'bands': ds.RasterCount,
            'geo_transform': ds.GetGeoTransform(),
            'projection': ds.GetProjection(),
        }
    else:
        raise ImportError('Neither rasterio nor GDAL available. '
                          'Install with: pip install rasterio')
+
+
def read_tile_rasterio(tif_path, x, y, w, h, band_indices):
    """Read one (C, H, W) float32 tile via rasterio."""
    with rasterio.open(tif_path) as src:
        win = rasterio.windows.Window(x, y, w, h)
        # rasterio bands are 1-indexed
        one_based = [idx + 1 for idx in band_indices]
        tile = src.read(one_based, window=win)  # (C, H, W)
    return tile.astype(np.float32)
+
+
def read_tile_gdal(tif_path, x, y, w, h, band_indices):
    """Read one (C, H, W) float32 tile via GDAL."""
    ds = gdal.Open(tif_path, gdal.GA_ReadOnly)
    # GDAL bands are 1-indexed
    arrays = [
        ds.GetRasterBand(idx + 1).ReadAsArray(x, y, w, h).astype(np.float32)
        for idx in band_indices
    ]
    return np.stack(arrays, axis=0)  # (C, H, W)
+
+
def generate_tiles(img_h, img_w, tile_size, overlap):
    """Generate unique tile coordinates covering the whole image.

    Tiles are ``tile_size`` squares placed every ``tile_size - overlap``
    pixels. Tiles that would run past the right/bottom edge are shifted
    back inside the image (so edge tiles overlap more than ``overlap``);
    for images smaller than ``tile_size`` a single clipped tile is
    returned.

    Args:
        img_h: image height in pixels.
        img_w: image width in pixels.
        tile_size: square tile edge length.
        overlap: overlap between adjacent tiles; must be < tile_size.

    Returns:
        list of unique ``(x, y, w, h)`` tuples in scan order.

    Raises:
        ValueError: if ``overlap >= tile_size`` (the stride would be
            non-positive, previously yielding an empty or bogus tiling).
    """
    if overlap >= tile_size:
        raise ValueError(
            f'overlap ({overlap}) must be smaller than '
            f'tile_size ({tile_size})')

    stride = tile_size - overlap
    tiles = []
    for y in range(0, img_h, stride):
        for x in range(0, img_w, stride):
            # Shift edge tiles back inside the image, then clip for
            # images smaller than tile_size.
            x1 = max(min(x, img_w - tile_size), 0)
            y1 = max(min(y, img_h - tile_size), 0)
            w = min(tile_size, img_w - x1)
            h = min(tile_size, img_h - y1)
            tiles.append((x1, y1, w, h))

    # Shifted edge tiles can repeat; keep first occurrences in order.
    return list(dict.fromkeys(tiles))
+
+
def normalize_tile(tile_data, modal='rgb'):
    """Normalize a tile in place with the training-time channel stats.

    Must match MultiModalNormalize in multimodal_pipelines.py exactly.

    Args:
        tile_data: (C, H, W) float32 array of raw pixel values
            (e.g. 0-255 for RGB). Modified in place.
        modal: modality key; unknown modalities fall back to a generic
            mean=128 / std=50 normalization.

    Returns:
        The same array object, normalized per channel.
    """
    NORM_CONFIGS = {
        'rgb': {
            'mean': [123.675, 116.28, 103.53],
            'std': [58.395, 57.12, 57.375],
        },
        'sar': {
            'mean': [0.23651549, 0.31761484, 0.18514981, 0.26901252,
                     -14.57879175, -8.6098158, -14.2907338, -8.33534564],
            'std': [0.16280619, 0.20849304, 0.14008107, 0.19767644,
                    4.07141682, 3.94773216, 4.21006244, 4.05494136],
        },
        'GF': {
            'mean': [432.02181, 315.92948, 246.468659,
                     310.61462, 360.267789],
            'std': [97.73313111900238, 85.78646917160748,
                    95.78015824658593,
                    124.84677067613467, 251.73965882246978],
        }
    }

    channels = tile_data.shape[0]
    if modal in NORM_CONFIGS:
        mean = np.array(NORM_CONFIGS[modal]['mean'][:channels],
                        dtype=np.float32)
        std = np.array(NORM_CONFIGS[modal]['std'][:channels],
                       dtype=np.float32)
    else:
        mean = np.full(channels, 128.0, dtype=np.float32)
        std = np.full(channels, 50.0, dtype=np.float32)

    # Vectorized in-place broadcast over (C, H, W) replaces the original
    # per-channel Python loop (same result, one pass in C).
    tile_data -= mean[:, None, None]
    tile_data /= std[:, None, None]

    return tile_data
+
+
def predict_batch(model, batch_imgs, modal, device):
    """Run inference on a batch of tile images.

    Args:
        model: segmentor model.
        batch_imgs: list of (C, H, W) numpy arrays (already normalized).
        modal: modality string recorded in each sample's metainfo.
        device: torch device.

    Returns:
        list of (H, W) numpy prediction masks, one per input tile.
    """
    tensors = []
    samples = []

    for arr in batch_imgs:
        tensors.append(torch.from_numpy(arr).float().to(device))

        height, width = arr.shape[1], arr.shape[2]
        sample = SegDataSample()
        sample.set_metainfo(dict(
            img_shape=(height, width),
            ori_shape=(height, width),
            pad_shape=(height, width),
            scale_factor=(1.0, 1.0),
            flip=False,
            flip_direction=None,
            modal_type=modal,
            actual_channels=arr.shape[0],
            dataset_name=modal,
            reduce_zero_label=False,
        ))
        samples.append(sample)

    with torch.no_grad():
        outputs = model(tensors, samples, mode='predict')

    # Strip the leading channel axis of each (1, H, W) mask.
    return [out.pred_sem_seg.data.cpu().numpy()[0] for out in outputs]
+
+
def main():
    """Tile a large GeoTIFF, run batched inference, stitch a binary mask,
    and write a 3-band (flood=red, background=black) GeoTIFF."""
    args = parse_args()

    if not HAS_RASTERIO and not HAS_GDAL:
        raise ImportError('Install rasterio or GDAL: pip install rasterio')

    # Pick the tile reader matching the available backend.
    read_tile = read_tile_rasterio if HAS_RASTERIO else read_tile_gdal

    # Output path
    if args.output is None:
        base, ext = os.path.splitext(args.input)
        args.output = f'{base}_pred{ext}'

    # Load config and model
    cfg = Config.fromfile(args.config)
    init_default_scope(cfg.get('default_scope', 'mmseg'))

    # Override test_cfg to use 'whole' mode (we handle tiling ourselves)
    cfg.model.test_cfg = dict(mode='whole')

    print('=' * 60)
    print('Large TIF Inference')
    print(f'Input: {args.input}')
    print(f'Output: {args.output}')
    print(f'Tile size: {args.tile_size}')
    print(f'Overlap: {args.overlap}')
    print(f'Modality: {args.modal}')
    print(f'Bands: {args.bands}')
    print(f'Batch size: {args.batch_size}')
    print('=' * 60)

    # Read TIF info
    info = read_tif_info(args.input)
    img_h, img_w = info['height'], info['width']
    print(f'\nImage size: {img_w} x {img_h} (W x H)')
    print(f'Bands: {info["bands"]}')

    # Build model
    print('\nBuilding model...')
    model = build_model(cfg, args.checkpoint, args.device)
    print('Model loaded.')

    # Generate tiles
    tiles = generate_tiles(img_h, img_w, args.tile_size, args.overlap)
    num_tiles = len(tiles)
    num_batches = math.ceil(num_tiles / args.batch_size)
    print(f'\nTotal tiles: {num_tiles}')
    print(f'Batches: {num_batches}')

    # Verify normalization on one corner tile (sanity-check the stats).
    print(f'\nNormalization check (modal={args.modal}):')
    test_tile = read_tile(args.input, 0, 0,
                          min(args.tile_size, img_w),
                          min(args.tile_size, img_h),
                          args.bands)
    print(f' Raw pixel range: [{test_tile.min():.2f}, {test_tile.max():.2f}]')
    test_normed = normalize_tile(test_tile.copy(), modal=args.modal)
    print(f' After normalize: [{test_normed.min():.2f}, {test_normed.max():.2f}]')
    del test_tile, test_normed

    # Allocate output: prediction mask + count for overlap voting.
    # NOTE(review): full-resolution float32 maps — ~2 x H x W x 4 bytes
    # of RAM (several GB for the 20803x36986 example in the docstring).
    pred_sum = np.zeros((img_h, img_w), dtype=np.float32)
    count_map = np.zeros((img_h, img_w), dtype=np.float32)

    # Inference
    print('\nStarting inference...')
    t_start = time.time()

    for batch_idx in range(num_batches):
        start_i = batch_idx * args.batch_size
        end_i = min(start_i + args.batch_size, num_tiles)
        batch_tiles = tiles[start_i:end_i]

        # Read tiles
        batch_imgs = []
        for (x, y, w, h) in batch_tiles:
            tile_data = read_tile(args.input, x, y, w, h, args.bands)
            # Handle edge tiles smaller than tile_size: zero-pad to full size
            if h < args.tile_size or w < args.tile_size:
                padded = np.zeros(
                    (len(args.bands), args.tile_size, args.tile_size),
                    dtype=np.float32)
                padded[:, :h, :w] = tile_data
                tile_data = padded
            tile_data = normalize_tile(tile_data, modal=args.modal)
            batch_imgs.append(tile_data)

        # Predict
        preds = predict_batch(model, batch_imgs, args.modal, args.device)

        # Stitch predictions (vote-based: accumulate class labels).
        # Padded regions of edge tiles are cropped away by [:h, :w].
        for (x, y, w, h), pred in zip(batch_tiles, preds):
            pred_sum[y:y+h, x:x+w] += pred[:h, :w].astype(np.float32)
            count_map[y:y+h, x:x+w] += 1.0

        # Progress (single line updated in place via \r)
        done = end_i
        elapsed = time.time() - t_start
        eta = elapsed / done * (num_tiles - done) if done > 0 else 0
        print(f'\r [{done}/{num_tiles}] '
              f'{done/num_tiles*100:.1f}% | '
              f'Elapsed: {elapsed:.0f}s | ETA: {eta:.0f}s',
              end='', flush=True)

    print()

    # Average overlapping predictions and binarize.
    # NOTE(review): averaging class labels then thresholding at 0.5 only
    # makes sense for binary {0,1} labels, and --flood-class is parsed
    # but never used below — confirm intended behavior for multi-class
    # checkpoints.
    count_map = np.maximum(count_map, 1.0)
    pred_avg = pred_sum / count_map
    pred_mask = (pred_avg >= 0.5).astype(np.uint8)  # threshold

    # Count statistics
    flood_pixels = (pred_mask == 1).sum()
    total_pixels = pred_mask.size
    print(f'\nFlood pixels: {flood_pixels:,} '
          f'({flood_pixels/total_pixels*100:.2f}%)')
    print(f'Non-flood pixels: {total_pixels - flood_pixels:,}')

    # Write output TIF (3-band RGB: flood=red, non-flood=black)
    print(f'\nWriting output: {args.output}')
    os.makedirs(os.path.dirname(args.output) or '.', exist_ok=True)

    if HAS_RASTERIO:
        # Reuse the input's georeferencing profile, switch to 3x uint8.
        profile = info['profile'].copy()
        profile.update(
            count=3,
            dtype='uint8',
            compress='lzw',
        )
        with rasterio.open(args.output, 'w', **profile) as dst:
            # Red channel: 255 for flood
            red = (pred_mask * 255).astype(np.uint8)
            # Green and Blue: 0
            green = np.zeros_like(red)
            blue = np.zeros_like(red)
            dst.write(red, 1)
            dst.write(green, 2)
            dst.write(blue, 3)
    elif HAS_GDAL:
        driver = gdal.GetDriverByName('GTiff')
        out_ds = driver.Create(
            args.output, img_w, img_h, 3, gdal.GDT_Byte,
            options=['COMPRESS=LZW'])
        out_ds.SetGeoTransform(info['geo_transform'])
        out_ds.SetProjection(info['projection'])
        red = (pred_mask * 255).astype(np.uint8)
        out_ds.GetRasterBand(1).WriteArray(red)
        out_ds.GetRasterBand(2).WriteArray(np.zeros_like(red))
        out_ds.GetRasterBand(3).WriteArray(np.zeros_like(red))
        out_ds.FlushCache()
        out_ds = None  # close the GDAL dataset (flushes to disk)

    elapsed_total = time.time() - t_start
    print(f'\nDone! Total time: {elapsed_total:.1f}s')
    print(f'Output: {args.output}')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/remap_pred_colors.py b/tools/remap_pred_colors.py
new file mode 100644
index 0000000000..6f5e6a30c8
--- /dev/null
+++ b/tools/remap_pred_colors.py
@@ -0,0 +1,97 @@
+"""
+Remap prediction image colors.
+
+Converts the two-color palette output from SegVisualizationHook:
+ - Black [ 0, 0, 0] (Background / Nodata) -> #7c7c7c [124, 124, 124]
+ - Red [255, 0, 0] (Flood) -> #000bc5 [ 0, 11, 197]
+
+Nodata pixels are treated the same as Background (both render as black
+after SegVisualizationHook._mask_nodata, then both remap to #7c7c7c).
+
+Usage:
+ # In-place: remap all PNGs under a directory
+ python tools/remap_pred_colors.py --src vis_pred/vis_data/vis_image/
+
+ # Save to a new directory
+ python tools/remap_pred_colors.py \
+ --src vis_pred/vis_data/vis_image/ \
+ --dst vis_pred/vis_data/vis_image_remapped/
+
+ # Process a single file
+ python tools/remap_pred_colors.py --src path/to/image.png
+"""
+
+import argparse
+import sys
+from pathlib import Path
+
+import numpy as np
+from PIL import Image
+
# Color substitutions applied by remap_image (source RGB -> target RGB).
COLOR_MAP = {
    (0, 0, 0): (124, 124, 124),  # Background / Nodata -> #7c7c7c
    (255, 0, 0): (0, 11, 197),   # Flood -> #000bc5
}


def remap_image(img_array: np.ndarray) -> np.ndarray:
    """Return a copy of *img_array* with COLOR_MAP applied pixel-wise.

    Matching is done against the ORIGINAL array, so substitutions never
    cascade (a remapped color cannot be remapped again).
    """
    remapped = img_array.copy()
    for old_rgb, new_rgb in COLOR_MAP.items():
        hit = np.all(
            img_array == np.array(old_rgb, dtype=img_array.dtype),
            axis=-1)
        remapped[hit] = new_rgb
    return remapped
+
+
def process_file(src_path: Path, dst_path: Path) -> None:
    """Remap one PNG and write it to *dst_path* (dirs created as needed)."""
    rgb = np.array(Image.open(src_path).convert('RGB'), dtype=np.uint8)
    remapped = remap_image(rgb)
    dst_path.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(remapped).save(dst_path)
+
+
def collect_png(root: Path):
    """Return every *.png under *root* (recursive), sorted by path."""
    return sorted(p for p in root.rglob('*.png'))
+
+
def main():
    """CLI entry point: gather PNGs, remap colors, report progress."""
    parser = argparse.ArgumentParser(
        description='Remap prediction image colors '
                    '(black→#7c7c7c, red→#000bc5)')
    parser.add_argument('--src', required=True,
                        help='Source directory (or single PNG).')
    parser.add_argument('--dst', default=None,
                        help='Destination directory. Default: in-place.')
    opts = parser.parse_args()

    src_root = Path(opts.src)
    if not src_root.exists():
        print(f'ERROR: --src "{src_root}" does not exist.', file=sys.stderr)
        sys.exit(1)

    if src_root.is_file():
        # Single-file mode: write to --dst if given, else overwrite in place.
        jobs = [(src_root, Path(opts.dst) if opts.dst else src_root)]
    else:
        pngs = collect_png(src_root)
        if not pngs:
            print(f'No PNG files found under "{src_root}".', file=sys.stderr)
            sys.exit(0)
        if opts.dst is None:
            jobs = [(p, p) for p in pngs]  # in-place
        else:
            # Mirror the source tree under the destination root.
            out_root = Path(opts.dst)
            jobs = [(p, out_root / p.relative_to(src_root)) for p in pngs]

    for idx, (in_path, out_path) in enumerate(jobs, 1):
        process_file(in_path, out_path)
        print(f'[{idx}/{len(jobs)}] {in_path} -> {out_path}')

    print(f'\nDone. {len(jobs)} image(s) remapped.')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/setup_sen1floods11.py b/tools/setup_sen1floods11.py
new file mode 100644
index 0000000000..129e816f0d
--- /dev/null
+++ b/tools/setup_sen1floods11.py
@@ -0,0 +1,337 @@
+"""
+One-shot Sen1Floods11 setup for fine-tuning.
+
+This script bridges the gap between a freshly-downloaded Sen1Floods11
+folder and a training run of
+``configs/floodnet/finetune_sen1floods11_{s1,s2}.py``.
+
+Expected on-disk layout::
+
+ data/Sen1Floods11/
        S1Hand/<name>_S1Hand.tif        # 2-band SAR (VV, VH in dB)
        S2Hand/<name>_S2Hand.tif        # 13-band Sentinel-2 MSI
        LabelHand/<name>_LabelHand.tif  # 1-band label, -1=nodata, 0/1 classes
+
+What the script does (all four steps are independent and can be
+re-run):
+
  1. Scan ``<data-root>/LabelHand`` for every available base-name.
  2. Write deterministic train/val/test splits to
     ``<data-root>/splits/{train,val,test}.txt`` based on an MD5 hash
+ of the base-name (so the same tile always lands in the same
+ split on re-run).
+ 3. Compute per-channel mean / std for S1Hand (2 ch) and S2Hand
+ (13 ch) *using only training-split tiles*, with NaN / Inf pixels
+ and label == -1 pixels masked out. This is important because
+ the shipped NORM_CONFIGS values in
+ ``mmseg/datasets/transforms/multimodal_pipelines.py`` were
+ computed on a different split and may not match your copy of
+ the data.
+ 4. Print the NORM_CONFIGS snippet so you can paste it into
+ ``multimodal_pipelines.py``.
+
+Usage::
+
+ # full setup (splits + stats for both modalities)
+ python tools/setup_sen1floods11.py --data-root data/Sen1Floods11
+
+ # only regenerate splits (e.g. after adding more tiles)
+ python tools/setup_sen1floods11.py --data-root data/Sen1Floods11 \\
+ --skip-stats
+
+ # only recompute stats (splits already exist)
+ python tools/setup_sen1floods11.py --data-root data/Sen1Floods11 \\
+ --skip-splits
+
+ # only one modality
+ python tools/setup_sen1floods11.py --data-root data/Sen1Floods11 \\
+ --modalities s1
+
+The split ratio defaults to 70 / 15 / 15 train / val / test. Pass
+``--train-ratio`` / ``--val-ratio`` to change it.
+"""
+import argparse
+import hashlib
+import os
+import os.path as osp
+from typing import List, Tuple
+
+import numpy as np
+
+try:
+ import tifffile
+except ImportError as e:
+ raise SystemExit(
+ 'tifffile is required. Install with `pip install tifffile`.') from e
+
+
+# ---------------------------------------------------------------------------
+# Layout constants - keep in sync with Sen1Floods11Dataset.MODAL_CONFIG and
+# multimodal_pipelines.MultiModalNormalize.NORM_CONFIGS.
+# ---------------------------------------------------------------------------
# Label rasters live in <data-root>/LabelHand as <base>_LabelHand.tif.
LABEL_SUBDIR = 'LabelHand'
LABEL_SUFFIX = '_LabelHand.tif'

# Per-modality layout: image subdirectory, filename suffix appended to a
# tile base-name, and the expected band count.
MODALITIES = {
    's1': {  # Sentinel-1 SAR (2 bands: VV, VH)
        'subdir': 'S1Hand',
        'suffix': '_S1Hand.tif',
        'channels': 2,
    },
    's2': {  # Sentinel-2 MSI (13 bands)
        'subdir': 'S2Hand',
        'suffix': '_S2Hand.tif',
        'channels': 13,
    },
}


def parse_args():
    """Build the CLI parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description='Sen1Floods11 setup: splits + normalization stats.')
    add = parser.add_argument
    add('--data-root', required=True,
        help='Sen1Floods11 root with S1Hand/S2Hand/LabelHand.')
    add('--train-ratio', type=float, default=0.70,
        help='Train fraction (default: 0.70).')
    add('--val-ratio', type=float, default=0.15,
        help='Val fraction (default: 0.15). '
             'Test fraction = 1 - train - val.')
    add('--modalities', nargs='+', default=['s1', 's2'],
        choices=list(MODALITIES),
        help='Which modalities to compute stats for.')
    add('--skip-splits', action='store_true',
        help='Reuse existing splits/{train,val,test}.txt.')
    add('--skip-stats', action='store_true',
        help='Only generate splits; do not compute stats.')
    add('--seed', type=int, default=42,
        help='Salt prepended to base-names when hashing. '
             'Change to get a different split assignment.')
    return parser.parse_args()
+
+
+# ---------------------------------------------------------------------------
+# Step 1: scan
+# ---------------------------------------------------------------------------
def scan_basenames(data_root: str) -> List[str]:
    """Return sorted tile base-names found under ``<data_root>/LabelHand``.

    A base-name is a label filename with the ``_LabelHand.tif`` suffix
    stripped. Exits via SystemExit when the directory is absent or holds
    no label files.
    """
    label_dir = osp.join(data_root, LABEL_SUBDIR)
    if not osp.isdir(label_dir):
        raise SystemExit(f'LabelHand dir not found: {label_dir}')

    bases = [name[:-len(LABEL_SUFFIX)]
             for name in sorted(os.listdir(label_dir))
             if name.endswith(LABEL_SUFFIX)]

    if not bases:
        raise SystemExit(
            f'No *{LABEL_SUFFIX} files found in {label_dir}. '
            'Check --data-root.')
    return bases
+
+
+# ---------------------------------------------------------------------------
+# Step 2: splits
+# ---------------------------------------------------------------------------
def assign_split(base: str, seed: int,
                 train_ratio: float, val_ratio: float) -> str:
    """Hash-based deterministic split assignment.

    Hashing (instead of shuffling + slicing) keeps assignments stable as
    the set of base-names grows: a tile that once landed in ``train``
    keeps landing in ``train`` on every re-run.
    """
    digest = hashlib.md5(f'{seed}:{base}'.encode('utf-8')).hexdigest()
    # Leading 8 hex chars give a uniform 32-bit value; scale into [0, 1).
    fraction = int(digest[:8], 16) / 0x100000000
    for cutoff, split in ((train_ratio, 'train'),
                          (train_ratio + val_ratio, 'val')):
        if fraction < cutoff:
            return split
    return 'test'
+
+
def write_splits(data_root: str, bases: List[str],
                 train_ratio: float, val_ratio: float, seed: int
                 ) -> Tuple[List[str], List[str], List[str]]:
    """Assign every base-name to a split and write splits/{train,val,test}.txt.

    Returns the (train, val, test) base-name lists in assignment order;
    the files themselves are written sorted, one base-name per line.
    Exits via SystemExit on nonsensical ratios.
    """
    bad_ratios = (train_ratio <= 0 or val_ratio < 0
                  or train_ratio + val_ratio >= 1.0)
    if bad_ratios:
        raise SystemExit(
            'Invalid split ratios: need train_ratio > 0, val_ratio >= 0, '
            'train + val < 1.')

    buckets = {'train': [], 'val': [], 'test': []}
    for base in bases:
        buckets[assign_split(base, seed, train_ratio, val_ratio)].append(base)

    splits_dir = osp.join(data_root, 'splits')
    os.makedirs(splits_dir, exist_ok=True)
    for name in ('train', 'val', 'test'):
        members = sorted(buckets[name])
        out_path = osp.join(splits_dir, f'{name}.txt')
        with open(out_path, 'w') as fh:
            if members:
                fh.write('\n'.join(members) + '\n')
        print(f'[splits] {name:5s}: {len(members):4d} -> {out_path}')

    print(f'[splits] total: {sum(len(v) for v in buckets.values())}')
    return buckets['train'], buckets['val'], buckets['test']
+
+
def load_existing_splits(data_root: str
                         ) -> Tuple[List[str], List[str], List[str]]:
    """Read back splits/{train,val,test}.txt written by a previous run.

    Blank lines are ignored; exits via SystemExit when any of the three
    files is missing.
    """

    def read(name: str) -> List[str]:
        path = osp.join(data_root, 'splits', f'{name}.txt')
        if not osp.isfile(path):
            raise SystemExit(
                f'--skip-splits was set but {path} is missing. '
                'Run without --skip-splits first.')
        with open(path) as fh:
            return [line.strip() for line in fh if line.strip()]

    return tuple(read(name) for name in ('train', 'val', 'test'))
+
+
+# ---------------------------------------------------------------------------
+# Step 3: stats
+# ---------------------------------------------------------------------------
+def _hwc(img: np.ndarray) -> np.ndarray:
+ """Normalize raster shape to (H, W, C)."""
+ if img.ndim == 2:
+ return img[:, :, None]
+ # tifffile typically returns (C, H, W) for multi-band TIFFs -
+ # transpose when the leading axis is the smallest.
+ if img.ndim == 3 and img.shape[0] < img.shape[-1]:
+ return np.transpose(img, (1, 2, 0))
+ return img
+
+
def compute_stats_for_bases(data_root: str, modality: str,
                            bases: List[str]
                            ) -> Tuple[np.ndarray, np.ndarray, int]:
    """Per-channel mean/std of one modality over the given tile base-names.

    Args:
        data_root: Sen1Floods11 root directory.
        modality: Key into ``MODALITIES`` ('s1' or 's2').
        bases: Tile base-names to aggregate (normally the train split).

    Returns:
        ``(mean, std, count)`` — float64 arrays of length
        ``MODALITIES[modality]['channels']`` plus the number of valid
        pixels that contributed.

    Raises:
        SystemExit: If the image directory is missing or no valid pixel
            is found across all tiles.

    A pixel is "valid" when every channel is finite (no NaN/Inf) and,
    when a same-sized label raster exists, its label is not -1 (nodata).
    """
    modal_cfg = MODALITIES[modality]
    img_dir = osp.join(data_root, modal_cfg['subdir'])
    label_dir = osp.join(data_root, LABEL_SUBDIR)
    num_channels = modal_cfg['channels']

    if not osp.isdir(img_dir):
        raise SystemExit(
            f'[{modality}] image dir not found: {img_dir}')

    # Streaming per-channel sum / sum-of-squares accumulators (float64 for
    # precision) — avoids holding every tile's pixels in memory at once.
    total_count = 0
    total_sum = np.zeros(num_channels, dtype=np.float64)
    total_sqsum = np.zeros(num_channels, dtype=np.float64)

    # Bookkeeping for the summary line printed at the end.
    n_missing_img = 0
    n_missing_label = 0
    n_channel_mismatch = 0

    for i, base in enumerate(bases):
        img_path = osp.join(img_dir, base + modal_cfg['suffix'])
        lbl_path = osp.join(label_dir, base + LABEL_SUFFIX)

        if not osp.isfile(img_path):
            n_missing_img += 1
            continue

        img = _hwc(tifffile.imread(img_path)).astype(np.float64)
        if img.shape[-1] != num_channels:
            print(f' [warn] skip {base}: got {img.shape[-1]} channels '
                  f'(expected {num_channels})')
            n_channel_mismatch += 1
            continue

        # Valid = finite in every channel AND not label nodata.
        valid_mask = np.all(np.isfinite(img), axis=-1)
        if osp.isfile(lbl_path):
            lbl = np.squeeze(tifffile.imread(lbl_path))
            # Only apply the label mask when shapes agree; a mismatched
            # label raster is silently ignored (pixels still count).
            if lbl.shape == img.shape[:2]:
                valid_mask &= (lbl != -1)
        else:
            n_missing_label += 1

        if not valid_mask.any():
            continue

        valid = img[valid_mask]  # (N, C)
        total_count += int(valid.shape[0])
        total_sum += valid.sum(axis=0)
        total_sqsum += (valid ** 2).sum(axis=0)

        # Lightweight progress report every 50 tiles.
        if (i + 1) % 50 == 0:
            print(f' [{modality}] processed {i + 1}/{len(bases)} tiles, '
                  f'valid pixels so far: {total_count}')

    if total_count == 0:
        raise SystemExit(
            f'[{modality}] no valid pixels across {len(bases)} tiles '
            f'(missing imgs: {n_missing_img}, '
            f'channel mismatches: {n_channel_mismatch}).')

    mean = total_sum / total_count
    # Population variance via E[x^2] - E[x]^2; the clip guards against
    # tiny negative values caused by floating-point round-off.
    var = np.clip(total_sqsum / total_count - mean ** 2, 0.0, None)
    std = np.sqrt(var)

    print(f' [{modality}] done. valid pixels={total_count}, '
          f'missing imgs={n_missing_img}, '
          f'missing labels={n_missing_label}, '
          f'channel mismatches={n_channel_mismatch}')
    return mean, std, total_count
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
def main():
    """Orchestrate setup: scan tiles, build/reuse splits, compute stats."""
    args = parse_args()
    data_root = args.data_root.rstrip('/')

    if not osp.isdir(data_root):
        raise SystemExit(f'--data-root not found: {data_root}')

    # Step 1: enumerate every tile that has a label raster.
    bases = scan_basenames(data_root)
    print(f'Found {len(bases)} tiles under {data_root}/{LABEL_SUBDIR}/')

    # Step 2: deterministic train/val/test splits (or reuse existing ones).
    if args.skip_splits:
        train_bases, val_bases, test_bases = load_existing_splits(data_root)
        print(f'[splits] reusing existing splits: '
              f'{len(train_bases)}/{len(val_bases)}/{len(test_bases)}')
    else:
        train_bases, val_bases, test_bases = write_splits(
            data_root, bases,
            train_ratio=args.train_ratio,
            val_ratio=args.val_ratio,
            seed=args.seed)

    # Steps 3 & 4: per-modality stats plus a paste-ready config snippet.
    if args.skip_stats:
        print('\n[stats] skipped (--skip-stats)')
        return

    all_stats = {}
    for modality in args.modalities:
        print()
        print(f'[stats] computing {modality} mean/std over '
              f'{len(train_bases)} train tiles (NaN/Inf/label==-1 masked)')
        all_stats[modality] = compute_stats_for_bases(
            data_root, modality, train_bases)

    rule = '=' * 72
    print()
    print(rule)
    print('Paste the following into NORM_CONFIGS in')
    print(' mmseg/datasets/transforms/multimodal_pipelines.py')
    print("(replacing the existing 's1' / 's2' entries)")
    print(rule)
    for modality in args.modalities:
        mean, std, count = all_stats[modality]
        print(f" '{modality}': {{ # {count} valid train pixels")
        print(f" 'mean': {mean.tolist()},")
        print(f" 'std': {std.tolist()},")
        print(' },')


if __name__ == '__main__':
    main()
diff --git a/tools/test.py b/tools/test.py
index 0d7f39b3a8..36c35d90f0 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -6,6 +6,11 @@
from mmengine.config import Config, DictAction
from mmengine.runner import Runner
+from mmseg.utils import register_all_modules
+
+# register all modules in mmseg into the registries
+register_all_modules()
+
# TODO: support fuse_conv_bn, visualization, and format_only
def parse_args():
diff --git a/tools/test_full_metrics.py b/tools/test_full_metrics.py
new file mode 100644
index 0000000000..09a9299d75
--- /dev/null
+++ b/tools/test_full_metrics.py
@@ -0,0 +1,143 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Test script with full metrics: Precision, Recall, OA, F1 Score, mIoU.
+
+This script extends the default test.py by automatically configuring the
+evaluator to compute all five metrics:
+ - mIoU (Mean Intersection over Union)
+ - mPrecision (Mean Precision)
+ - mRecall (Mean Recall)
+ - mFscore (F1 Score)
+ - aAcc (Overall Accuracy, OA)
+
+Usage:
+ python tools/test_full_metrics.py [options]
+
+Example:
+ python tools/test_full_metrics.py \\
+ ./configs/deeplabv3plus/Deeplabv3+UAVflood.py \\
+ work_dirs/SAR/Deeplabv3+/best_mIoU_epoch_100.pth \\
+ --work-dir ./Result/SAR/Deeplabv3+ \\
+ --show-dir ./Result/SAR/Deeplabv3+/vis \\
+ --cfg-options visualizer.alpha=1.0
+"""
+import argparse
+import os
+import os.path as osp
+
+from mmengine.config import Config, DictAction
+from mmengine.runner import Runner
+
+
def parse_args():
    """Parse CLI arguments for full-metric testing.

    Also exports ``LOCAL_RANK`` into the environment when the launcher
    did not set it, so distributed utilities can pick it up.
    """
    parser = argparse.ArgumentParser(
        description='MMSeg test with full metrics '
        '(Precision, Recall, OA, F1 Score, mIoU)')
    add = parser.add_argument
    add('config', help='train config file path')
    add('checkpoint', help='checkpoint file')
    add('--work-dir',
        help='if specified, the evaluation metric results will be dumped '
        'into the directory as json')
    add('--out',
        type=str,
        help='The directory to save output prediction for offline evaluation')
    add('--show', action='store_true', help='show prediction results')
    add('--show-dir',
        help='directory where painted images will be saved. '
        'If specified, it will be automatically saved '
        'to the work_dir/timestamp/show_dir')
    add('--wait-time', type=float, default=2, help='the interval of show (s)')
    add('--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    add('--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    add('--tta', action='store_true', help='Test time augmentation')
    add('--local_rank', '--local-rank', type=int, default=0)
    parsed = parser.parse_args()

    os.environ.setdefault('LOCAL_RANK', str(parsed.local_rank))
    return parsed
+
+
def trigger_visualization_hook(cfg, args):
    """Enable the VisualizationHook so predictions get drawn/saved.

    Mutates *cfg* in place and returns it. Raises RuntimeError when the
    config has no 'visualization' entry in ``default_hooks``.
    """
    hooks = cfg.default_hooks
    if 'visualization' not in hooks:
        raise RuntimeError(
            'VisualizationHook must be included in default_hooks. '
            'refer to usage '
            '"visualization=dict(type=\'VisualizationHook\')"')

    vis_hook = hooks['visualization']
    vis_hook['draw'] = True
    if args.show:
        vis_hook['show'] = True
        vis_hook['wait_time'] = args.wait_time
    if args.show_dir:
        cfg.visualizer['save_dir'] = args.show_dir

    return cfg
+
+
def main():
    """Entry point: load config, force full-metric evaluation, run test."""
    args = parse_args()

    # load config
    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    # Evaluate the given checkpoint instead of training from scratch.
    cfg.load_from = args.checkpoint

    if args.show or args.show_dir:
        cfg = trigger_visualization_hook(cfg, args)

    # Wrap the model for test-time augmentation when requested.
    if args.tta:
        cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
        cfg.tta_model.module = cfg.model
        cfg.model = cfg.tta_model

    # add output_dir in metric (dump raw predictions for offline use)
    if args.out is not None:
        cfg.test_evaluator['output_dir'] = args.out
        cfg.test_evaluator['keep_results'] = True

    # Override test_evaluator to compute full metrics:
    # mIoU, mFscore (which includes Precision, Recall, F1), and aAcc (OA)
    # NOTE(review): assumes test_evaluator is a single dict-style config,
    # not a list of metrics — confirm against the configs used with this
    # script.
    cfg.test_evaluator['type'] = 'IoUMetric'
    cfg.test_evaluator['iou_metrics'] = ['mIoU', 'mFscore']

    # build the runner from config
    runner = Runner.from_cfg(cfg)

    # start testing
    runner.test()


if __name__ == '__main__':
    main()
diff --git a/tools/train.py b/tools/train.py
index 10fdaa1874..08ce040c08 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -9,6 +9,7 @@
from mmengine.runner import Runner
from mmseg.registry import RUNNERS
+from mmseg.utils import register_all_modules
def parse_args():
@@ -54,6 +55,9 @@ def parse_args():
def main():
args = parse_args()
+ # register all modules
+ register_all_modules(init_default_scope=True)
+
# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher