diff --git a/8to2.py b/8to2.py new file mode 100644 index 0000000000..4540a9524f --- /dev/null +++ b/8to2.py @@ -0,0 +1,170 @@ +import os +import numpy as np +import rasterio +from shutil import copy2 +from tqdm import tqdm + + +def keep_last_two_bands(folder_path, backup=True, backup_suffix='_8band_backup'): + """ + 将所有TIF文件只保留最后两个波段,覆盖原文件 + + 参数: + folder_path: 要处理的根目录路径 + backup: 是否备份原文件 + backup_suffix: 备份文件后缀 + """ + total_files = 0 + files_processed = 0 + files_skipped = 0 + tif_extensions = ['.tif', '.tiff', '.TIF', '.TIFF'] + + # 收集所有tif文件 + print(f"扫描路径: {folder_path}") + tif_files = [] + for root, dirs, files in os.walk(folder_path): + for file in files: + if any(file.endswith(ext) for ext in tif_extensions): + tif_files.append(os.path.join(root, file)) + + total_files = len(tif_files) + print(f"找到 {total_files} 个TIF文件\n") + + # 第一步:检查文件波段数 + print("步骤1: 检查文件波段数...") + print("=" * 80) + + band_info = {} + files_to_process = [] + + try: + iterator = tqdm(tif_files, desc="扫描文件") + except: + iterator = tif_files + + for file_path in iterator: + try: + with rasterio.open(file_path) as src: + num_bands = src.count + rel_path = os.path.relpath(file_path, folder_path) + + if num_bands not in band_info: + band_info[num_bands] = [] + band_info[num_bands].append(rel_path) + + if num_bands >= 2: # 只处理波段数>=2的文件 + files_to_process.append((file_path, num_bands)) + else: + files_skipped += 1 + + except Exception as e: + print(f"\n读取错误: {os.path.basename(file_path)}") + print(f" 错误: {e}") + + # 显示波段数统计 + print("\n波段数统计:") + for num_bands in sorted(band_info.keys()): + print(f" {num_bands}波段: {len(band_info[num_bands])} 个文件") + + print(f"\n将处理 {len(files_to_process)} 个文件(波段数>=2)") + print(f"跳过 {files_skipped} 个文件(波段数<2)") + + if len(files_to_process) == 0: + print("没有需要处理的文件!") + return + + # 第二步:确认并处理 + print("\n" + "=" * 80) + print("步骤2: 处理文件") + print("操作: 只保留最后两个波段,覆盖原文件") + + if backup: + print(f"将创建备份文件(后缀: {backup_suffix})") + else: + print("警告: 不会创建备份!") + + user_input = input("\n是否继续?(y/n): ") + if user_input.lower() != 'y': + print("操作已取消") + return + + print("\n开始处理...") + + try: + iterator = tqdm(files_to_process, desc="处理文件") + except: + iterator = files_to_process + + for file_path, original_bands in iterator: + try: + # 备份原文件 + if backup: + backup_path = file_path + backup_suffix + copy2(file_path, backup_path) + + # 读取文件 + with rasterio.open(file_path) as src: + # 保存元数据并修改波段数 + meta = src.meta.copy() + meta['count'] = 2 # 修改为2波段 + + # 读取最后两个波段 + band_n_1 = src.read(original_bands - 1) # 倒数第二个波段 + band_n = src.read(original_bands) # 最后一个波段 + + # 创建临时文件 + temp_path = file_path + '.tmp' + + # 写入新文件(只有2个波段) + with rasterio.open(temp_path, 'w', **meta) as dst: + dst.write(band_n_1, 1) # 写入第1个波段 + dst.write(band_n, 2) # 写入第2个波段 + + # 替换原文件 + os.replace(temp_path, file_path) + files_processed += 1 + + except Exception as e: + print(f"\n处理错误: {os.path.basename(file_path)}") + print(f" 原波段数: {original_bands}") + print(f" 错误: {e}") + # 清理临时文件 + temp_path = file_path + '.tmp' + if os.path.exists(temp_path): + os.remove(temp_path) + + # 输出结果 + print("\n" + "=" * 80) + print("处理完成!") + print(f"总文件数: {total_files}") + print(f"成功处理的文件: {files_processed}") + print(f"跳过的文件: {files_skipped}") + + if backup: + print(f"\n原始文件已备份(后缀: {backup_suffix})") + print("如果确认无误,可以使用以下命令删除备份文件:") + print(f" find {folder_path} -name '*{backup_suffix}' -delete") + + # 验证处理结果 + print("\n步骤3: 验证处理结果...") + print("检查前3个文件的波段数...") + + for i, (file_path, _) in enumerate(files_to_process[:3]): + try: + with rasterio.open(file_path) as src: + rel_path = os.path.relpath(file_path, folder_path) + print(f" ✓ {rel_path}: {src.count} 波段") + except Exception as e: + print(f" ✗ 验证失败: {e}") + + if i >= 2: # 只检查前3个 + break + + +if __name__ == "__main__": + folder_path = "/mnt/d/Project/Code/Floodnet/data/mixed_dataset/SAR/" + + # 执行处理 + # backup=True 会备份原文件(推荐) + # backup=False 直接覆盖(不推荐) + keep_last_two_bands(folder_path, backup=False, backup_suffix='_8band_backup') \ No newline at end of file diff --git a/command b/command new file mode 100644 index 0000000000..f774dc5546 --- /dev/null +++ b/command @@ -0,0 +1,148 @@ +git clone https://VIncentmuyi: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@github.com/VIncentmuyi/Floodnet.git + +/解决python不搜索本文件夹包的问题 +ModuleNotFoundError: No module named 'mmseg' +export PYTHONPATH=.:$PYTHONPATH + +ps -ef | grep python +pkill -9 python +git fetch origin +git reset --hard origin/main +git clean -fd # 删除未跟踪的文件 + +python tools/train.py ./configs/deeplabv3plus/Deeplabv3+UAVflood.py --work-dir work_dirs/SAR/Deeplabv3+ +python tools/test_full_metrics.py ./configs/deeplabv3plus/Deeplabv3+UAVflood.py work_dirs/SAR/Deeplabv3+/best_mIoU_epoch_100.pth --work-dir ./Result/SAR/Deeplabv3+ --show-dir ./Result/SAR/Deeplabv3+/vis --cfg-options visualizer.alpha=1.0 +python tools/analysis_tools/benchmark.py \ + ./configs/deeplabv3plus/Deeplabv3+UAVflood.py \ + work_dirs/UAVflood/Deeplabv3+/best_val_mIoU_iter_20000.pth \ + --repeat-times 3 +python tools/analysis_tools/get_flops.py \ + ./configs/deeplabv3plus/Deeplabv3+UAVflood.py \ + --shape 8 256 256 + +python tools/train.py ./configs/segformer/segformer_mit-b0_8xb1-160k_UAVflood-256x256.py --work-dir work_dirs/SAR/segformer +python tools/test_full_metrics.py ./configs/segformer/segformer_mit-b0_8xb1-160k_UAVflood-256x256.py work_dirs/SAR/segformer/best_mIoU_epoch_100.pth --work-dir ./Result/SAR/segformer/ --show-dir ./Result/SAR/segformer/vis --cfg-options visualizer.alpha=1.0 +python tools/analysis_tools/benchmark.py \ + ./configs/segformer/segformer_mit-b0_8xb1-160k_UAVflood-256x256.py\ + work_dirs/UAVflood/segformer/best_val_mIoU_iter_40000.pth \ + --repeat-times 3 +python tools/analysis_tools/get_flops.py \ + ./configs/segformer/segformer_mit-b0_8xb1-160k_UAVflood-256x256.py \ + --shape 5 256 256 + +python tools/train.py ./configs/unet/Unet-Uavflood.py --work-dir work_dirs/SARflood/unet +python tools/test_full_metrics.py ./configs/unet/Unet-Uavflood.py work_dirs/SAR/unet/best_mIoU_epoch_90.pth --work-dir ./Result/SAR/unet/ --show-dir ./Result/SAR/unet/vis --cfg-options visualizer.alpha=1.0 +python tools/analysis_tools/benchmark.py \ + ./configs/unet/Unet-Uavflood.py \ + work_dirs/UAVflood/unet/best_val_mIoU_iter_40000.pth \ + --repeat-times 3 +python tools/analysis_tools/get_flops.py \ + ./configs/unet/Unet-Uavflood.py\ + --shape 8 256 256 + +python tools/train.py ./configs/mae/mae-base-Uavflood.py --work-dir work_dirs/UAVflood/mae +python tools/test_full_metrics.py ./configs/mae/mae-base-Uavflood.py work_dirs/UAVflood/mae/best_val_mIoU_iter_20000.pth --work-dir ./Result/UAV/mae/ --show-dir ./Result/UAV/mae/vis --cfg-options visualizer.alpha=1.0 +python tools/analysis_tools/benchmark.py \ + ./configs/mae/mae-base-Uavflood.py\ + work_dirs/UAVflood/mae/best_val_mIoU_iter_28000.pth \ + --repeat-times 3 +python tools/analysis_tools/get_flops.py \ + ./configs/mae/mae-base-Uavflood.py\ + --shape 5 256 256 + +python tools/train.py ./configs/vit/vit-Uavflood.py --work-dir work_dirs/SAR/vit +python tools/test_full_metrics.py ./configs/vit/vit-Uavflood.py work_dirs/SAR/vit/best_mIoU_epoch_100.pth --work-dir ./Result/SAR/vit/ --show-dir ./Result/SAR/vit/vis --cfg-options visualizer.alpha=1.0 +python tools/analysis_tools/benchmark.py \ + ./configs/vit/vit-Uavflood.py\ + work_dirs/UAVflood/vit/best_val_mIoU_iter_36000.pth \ + --repeat-times 3 +python tools/analysis_tools/get_flops.py \ + ./configs/vit/vit-Uavflood.py\ + --shape 3 256 256 + +python tools/train.py ./configs/beit/beit-Uavflood.py --work-dir work_dirs/SARflood/beit +python tools/test_full_metrics.py ./configs/beit/beit-Uavflood.py work_dirs/UAVflood/beit/best_val_mIoU_iter_16000.pth --work-dir ./Result/UAV/beit/ --show-dir ./Result/UAV/beit/vis --cfg-options visualizer.alpha=1.0 +python tools/analysis_tools/benchmark.py \ + ./configs/beit/beit-Uavflood.py\ + work_dirs/UAVflood/beit/best_val_mIoU_iter_40000.pth\ + --repeat-times 3 +python tools/analysis_tools/get_flops.py \ + ./configs/beit/beit-Uavflood.py\ + --shape 3 256 256 + +python tools/train.py ./configs/convnext/convnext-base-uavflood.py --work-dir work_dirs/SAR/convnext +python tools/test_full_metrics.py ./configs/convnext/convnext-base-uavflood.py work_dirs/SAR/convnext/best_mIoU_epoch_100.pth --work-dir ./Result/SAR/convnext/ --show-dir ./Result/SAR/convnext/vis --cfg-options visualizer.alpha=1.0 +python tools/analysis_tools/benchmark.py \ + ./configs/convnext/convnext-base-uavflood.py\ + work_dirs/GFflood/convnext/best_val_mIoU_iter_20000.pth \ + --repeat-times 3 +python tools/analysis_tools/get_flops.py \ + ./configs/convnext/convnext-base-uavflood.py\ + --shape 5 256 256 + +python tools/train.py ./configs/swin/Swin-uavflood-256x256.py --work-dir work_dirs/SAR/Swin --cfg-options seed=42 +python tools/test_full_metrics.py ./configs/swin/Swin-uavflood-256x256.py work_dirs/GFflood/Swin/best_val_mIoU_iter_16000.pth --work-dir ./Result/GF/swin/ --show-dir ./Result/GF/swin/vis --cfg-options visualizer.alpha=1.0 +python tools/analysis_tools/benchmark.py \ + ./configs/swin/Swin-uavflood-256x256.py\ + work_dirs/GFflood/Swin/best_val_mIoU_iter_16000.pth\ + --repeat-times 3 +python tools/analysis_tools/get_flops.py \ + ./configs/swin/Swin-uavflood-256x256.py\ + --shape 5 256 256 + +python tools/train.py ./configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py --work-dir work_dirs/floodnet/SwinmoeB/655 --cfg-options seed=42 +python tools/test_full_metrics.py ./configs/floodnet/multimodal_floodnet_sar_boost_swin_moe_config.py work_dirs/floodnet/SwinmoeB/best_mIoU_epoch_100.pth --work-dir ./Result/Floodnet/SAR/ --show-dir ./Result/Floodnet/SAR/vis --cfg-options visualizer.alpha=1.0 + +python tools/test_full_metrics.py \ + configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py \ + work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth \ + --cfg-options test_dataloader.dataset.filter_modality=sar \ + --work-dir ./Result/Floodnet/SAR/ \ + --show-dir ./Result/Floodnet/SAR/vis --cfg-options visualizer.alpha=1.0 + + +python tools/train.py \ + configs/floodnet/multimodal_floodnet_sar_only_swinbase_moe_config.py \ + --work-dir work_dirs/floodnet/SwinmoeB_sar_only \ + --cfg-options seed=42 + +python tools/test_full_metrics.py \ + configs/floodnet/multimodal_floodnet_sar_only_swinbase_moe_config.py \ + work_dirs/floodnet/SwinmoeB_sar_only/best_mIoU_epoch_100.pth \ + + + python tools/train.py configs/floodnet/continue_train_150ep.py --work-dir work_dirs/floodnet/SwinmoeB/655 --resume --cfg-options load_from="work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth" + + python tools/analysis_tools/visualize_expert_routing.py \ + configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py \ + work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth \ + --output-dir work_dirs/figures/expert_routing \ + --num-samples 50 + +python tools/train.py configs/floodnet/finetune_single_modal.py \ + --work-dir work_dirs/generalization/LY-train-station/ \ + --cfg-options \ + train_dataloader.dataset.data_root="data/LY-train-station/" \ + val_dataloader.dataset.data_root="data/LY-train-station/" \ + test_dataloader.dataset.data_root="data/LY-train-station/" + +python tools/test.py \ + configs/floodnet/finetune_single_modal.py \ + work_dirs/generalization/LY-train-station/best_mIoU_epoch_50.pth \ + --work-dir work_dirs/generalization/LY-train-station/test_results \ + --cfg-options \ + test_dataloader.dataset.data_root="data/LY-train-station/" \ + "test_evaluator.iou_metrics=['mIoU','mDice','mFscore']" \ + --show-dir work_dirs/generalization/LY-train-station/test_results/vis \ + --out work_dirs/generalization/LY-train-station/test_results/predictions + +python tools/predict_large_tif.py \ + configs/floodnet/finetune_single_modal.py \ + work_dirs/generalization/LY-train-station/best_mIoU_epoch_50.pth \ + --input data/luoyuan/result.tif \ + --output data/luoyuan/prediction.tif \ + --tile-size 512 \ + --overlap 64 \ + --modal rgb \ + --bands 0 1 2 \ + --batch-size 16 \ No newline at end of file diff --git a/configs/_base_/datasets/GFflood.py b/configs/_base_/datasets/GFflood.py new file mode 100644 index 0000000000..7478691b16 --- /dev/null +++ b/configs/_base_/datasets/GFflood.py @@ -0,0 +1,83 @@ +# GF 5-channel dataset settings +dataset_type = 'UAVfloodDataset' +data_root = '../Floodnet/data/mixed_dataset/GF/' +crop_size = (256, 256) + +# GF 5-channel normalization parameters +train_pipeline = [ + dict(type='LoadMultiBandTiffFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict( + type='RandomResize', + scale=(2048, 512), + ratio_range=(0.5, 2.0), + keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + # PhotoMetricDistortion removed - not suitable for multispectral imagery + dict(type='PackSegInputs') +] + +test_pipeline = [ + dict(type='LoadMultiBandTiffFromFile'), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackSegInputs') +] + +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadMultiBandTiffFromFile'), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], + [dict(type='LoadAnnotations')], + [dict(type='PackSegInputs')] + ]) +] + +train_dataloader = dict( + batch_size=8, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + drop_last=True, # 丢弃最后一个不完整的batch,避免BatchNorm错误 + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=8, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline)) + +test_dataloader = dict( + batch_size=8, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline)) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/_base_/datasets/SARflood.py b/configs/_base_/datasets/SARflood.py new file mode 100644 index 0000000000..c4e4ac9742 --- /dev/null +++ b/configs/_base_/datasets/SARflood.py @@ -0,0 +1,83 @@ +# SAR 8-channel dataset settings +dataset_type = 'UAVfloodDataset' +data_root = '../Floodnet/data/mixed_dataset/SAR/' +crop_size = (256, 256) + +# SAR 8-channel normalization parameters +train_pipeline = [ + dict(type='LoadMultiBandTiffFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict( + type='RandomResize', + scale=(2048, 512), + ratio_range=(0.5, 2.0), + keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + # PhotoMetricDistortion removed - not suitable for SAR imagery + dict(type='PackSegInputs') +] + +test_pipeline = [ + dict(type='LoadMultiBandTiffFromFile'), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackSegInputs') +] + +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadMultiBandTiffFromFile'), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], + [dict(type='LoadAnnotations')], + [dict(type='PackSegInputs')] + ]) +] + +train_dataloader = dict( + batch_size=8, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + drop_last=True, # 丢弃最后一个不完整的batch,避免BatchNorm错误 + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=8, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline)) + +test_dataloader = dict( + batch_size=8, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline)) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/configs/_base_/datasets/UAVflood.py b/configs/_base_/datasets/UAVflood.py new file mode 100644 index 0000000000..4f8b700fb6 --- /dev/null +++ b/configs/_base_/datasets/UAVflood.py @@ -0,0 +1,77 @@ +# dataset settings +dataset_type = 'UAVfloodDataset' +data_root = '../Floodnet/data/mixed_dataset/SAR/' +crop_size = (256, 256) +train_pipeline = [ + dict(type='LoadMultiBandTiffFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict( + type='RandomResize', + scale=(2048, 512), + ratio_range=(0.5, 2.0), + keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + # PhotoMetricDistortion removed - not compatible with multi-band (8-channel) images + dict(type='PackSegInputs') +] +test_pipeline = [ + dict(type='LoadMultiBandTiffFromFile'), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackSegInputs') +] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadMultiBandTiffFromFile'), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] +train_dataloader = dict( + batch_size=8, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + drop_last=True, # 丢弃最后一个不完整的batch,避免BatchNorm错误 + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=8, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline)) +test_dataloader = dict( + batch_size=8, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline)) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + diff --git a/configs/_base_/models/deeplabv3_unet_s5-d16.py b/configs/_base_/models/deeplabv3_unet_s5-d16.py index 92df52c35d..f373110aa6 100644 --- a/configs/_base_/models/deeplabv3_unet_s5-d16.py +++ b/configs/_base_/models/deeplabv3_unet_s5-d16.py @@ -1,19 +1,24 @@ -# model settings +# UNet for GF 5-channel imagery norm_cfg = dict(type='SyncBN', requires_grad=True) data_preprocessor = dict( type='SegDataPreProcessor', - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375], - bgr_to_rgb=True, + #mean=[117.926186, 117.568402, 97.217239], + #std=[53.542876104049824, 50.084170325219176, 50.49331035114637], + #mean= [432.02181, 315.92948, 246.468659, 310.61462, 360.267789], + #std= [97.73313111900238, 85.78646917160748, 95.78015824658593, 124.84677067613467, 251.73965882246978], + mean=[0.23651549, 0.31761484, 0.18514981, 0.26901252, -14.57879175, -8.6098158, -14.2907338, -8.33534564], + std=[0.16280619, 0.20849304, 0.14008107, 0.19767644, 4.07141682, 3.94773216, 4.21006244, 4.05494136], + bgr_to_rgb=False, # Not RGB imagery pad_val=0, seg_pad_val=255) + model = dict( type='EncoderDecoder', data_preprocessor=data_preprocessor, pretrained=None, backbone=dict( type='UNet', - in_channels=3, + in_channels=8, # 5-channel input base_channels=64, num_stages=5, strides=(1, 1, 1, 1, 1), @@ -53,6 +58,5 @@ align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), - # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='slide', crop_size=256, stride=170)) diff --git a/configs/_base_/models/deeplabv3plus_r50-d8.py b/configs/_base_/models/deeplabv3plus_r50-d8.py index 74dbed5593..19af44fe17 100644 --- a/configs/_base_/models/deeplabv3plus_r50-d8.py +++ b/configs/_base_/models/deeplabv3plus_r50-d8.py @@ -2,17 +2,22 @@ norm_cfg = dict(type='SyncBN', requires_grad=True) data_preprocessor = dict( type='SegDataPreProcessor', - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375], - bgr_to_rgb=True, + #mean= [117.926186, 117.568402, 97.217239], + #std= [53.542876104049824, 50.084170325219176, 50.49331035114637], + #mean= [432.02181, 315.92948, 246.468659, 310.61462, 360.267789], + #std= [97.73313111900238, 85.78646917160748, 95.78015824658593, 124.84677067613467, 251.73965882246978], + mean=[0.23651549, 0.31761484, 0.18514981, 0.26901252, -14.57879175, -8.6098158, -14.2907338, -8.33534564], + std=[0.16280619, 0.20849304, 0.14008107, 0.19767644, 4.07141682, 3.94773216, 4.21006244, 4.05494136], + bgr_to_rgb=False, pad_val=0, seg_pad_val=255) model = dict( type='EncoderDecoder', data_preprocessor=data_preprocessor, - pretrained='open-mmlab://resnet50_v1c', + #pretrained='open-mmlab://resnet50_v1c', backbone=dict( type='ResNetV1c', + in_channels=8, depth=50, num_stages=4, out_indices=(0, 1, 2, 3), @@ -31,7 +36,7 @@ c1_in_channels=256, c1_channels=48, dropout_ratio=0.1, - num_classes=19, + num_classes=2, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( @@ -44,7 +49,7 @@ num_convs=1, concat_input=False, dropout_ratio=0.1, - num_classes=19, + num_classes=2, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( diff --git a/configs/_base_/models/segformer_mit-b0.py b/configs/_base_/models/segformer_mit-b0.py index 46841adc07..0c9dbafd83 100644 --- a/configs/_base_/models/segformer_mit-b0.py +++ b/configs/_base_/models/segformer_mit-b0.py @@ -1,10 +1,14 @@ -# model settings +# model settings 这实际上是b5的参数量 norm_cfg = dict(type='SyncBN', requires_grad=True) data_preprocessor = dict( type='SegDataPreProcessor', - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375], - bgr_to_rgb=True, + #mean=[117.926186, 117.568402, 97.217239], + #std=[53.542876104049824, 50.084170325219176, 50.49331035114637], + #mean= [432.02181, 315.92948, 246.468659, 310.61462, 360.267789], + #std= [97.73313111900238, 85.78646917160748, 95.78015824658593, 124.84677067613467, 251.73965882246978], + mean=[0.23651549, 0.31761484, 0.18514981, 0.26901252, -14.57879175, -8.6098158, -14.2907338, -8.33534564], + std=[0.16280619, 0.20849304, 0.14008107, 0.19767644, 4.07141682, 3.94773216, 4.21006244, 4.05494136], + bgr_to_rgb=False, pad_val=0, seg_pad_val=255) model = dict( @@ -13,13 +17,13 @@ pretrained=None, backbone=dict( type='MixVisionTransformer', - in_channels=3, - embed_dims=32, + in_channels=8, + embed_dims=64, # B0: 32 -> B5: 64 num_stages=4, - num_layers=[2, 2, 2, 2], - num_heads=[1, 2, 5, 8], - patch_sizes=[7, 3, 3, 3], - sr_ratios=[8, 4, 2, 1], + num_layers=[3, 6, 40, 3], # B0: [2, 2, 2, 2] -> B5: [3, 6, 40, 3] + num_heads=[1, 2, 5, 8], # 保持不变 + patch_sizes=[7, 3, 3, 3], # 保持不变 + sr_ratios=[8, 4, 2, 1], # 保持不变 out_indices=(0, 1, 2, 3), mlp_ratio=4, qkv_bias=True, @@ -28,11 +32,11 @@ drop_path_rate=0.1), decode_head=dict( type='SegformerHead', - in_channels=[32, 64, 160, 256], + in_channels=[64, 128, 320, 512], # B0: [32, 64, 160, 256] -> B5: [64, 128, 320, 512] in_index=[0, 1, 2, 3], - channels=256, + channels=768, # B0: 256 -> B5: 768 dropout_ratio=0.1, - num_classes=19, + num_classes=2, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( diff --git a/configs/_base_/models/upernet_beit.py b/configs/_base_/models/upernet_beit.py index 691e288dbf..d1ae45e11e 100644 --- a/configs/_base_/models/upernet_beit.py +++ b/configs/_base_/models/upernet_beit.py @@ -1,9 +1,13 @@ norm_cfg = dict(type='SyncBN', requires_grad=True) data_preprocessor = dict( type='SegDataPreProcessor', - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375], - bgr_to_rgb=True, + #mean=[117.926186, 117.568402, 97.217239], + #std=[53.542876104049824, 50.084170325219176, 50.49331035114637], + #mean=[432.02181, 315.92948, 246.468659, 310.61462, 360.267789], + #std=[97.73313111900238, 85.78646917160748, 95.78015824658593, 124.84677067613467, 251.73965882246978], + mean=[0.23651549, 0.31761484, 0.18514981, 0.26901252, -14.57879175, -8.6098158, -14.2907338, -8.33534564], + std=[0.16280619, 0.20849304, 0.14008107, 0.19767644, 4.07141682, 3.94773216, 4.21006244, 4.05494136], + bgr_to_rgb=False, pad_val=0, seg_pad_val=255) model = dict( @@ -12,9 +16,9 @@ pretrained=None, backbone=dict( type='BEiT', - img_size=(640, 640), + img_size=(256, 256), patch_size=16, - in_channels=3, + in_channels=8, embed_dims=768, num_layers=12, num_heads=12, @@ -35,7 +39,7 @@ pool_scales=(1, 2, 3, 6), channels=768, dropout_ratio=0.1, - num_classes=150, + num_classes=2, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( @@ -48,7 +52,7 @@ num_convs=1, concat_input=False, dropout_ratio=0.1, - num_classes=150, + num_classes=2, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( diff --git a/configs/_base_/models/upernet_convnext.py b/configs/_base_/models/upernet_convnext.py index 958994c91e..39d59b4963 100644 --- a/configs/_base_/models/upernet_convnext.py +++ b/configs/_base_/models/upernet_convnext.py @@ -3,9 +3,13 @@ checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-base_3rdparty_32xb128-noema_in1k_20220301-2a0ee547.pth' # noqa data_preprocessor = dict( type='SegDataPreProcessor', - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375], - bgr_to_rgb=True, + #mean=[117.926186, 117.568402, 97.217239], + #std=[53.542876104049824, 50.084170325219176, 50.49331035114637], + #mean= [432.02181, 315.92948, 246.468659, 310.61462, 360.267789], + #std= [97.73313111900238, 85.78646917160748, 95.78015824658593, 124.84677067613467, 251.73965882246978], + mean=[0.23651549, 0.31761484, 0.18514981, 0.26901252, -14.57879175, -8.6098158, -14.2907338, -8.33534564], + std=[0.16280619, 0.20849304, 0.14008107, 0.19767644, 4.07141682, 3.94773216, 4.21006244, 4.05494136], + bgr_to_rgb=False, pad_val=0, seg_pad_val=255) model = dict( @@ -15,13 +19,15 @@ backbone=dict( type='mmpretrain.ConvNeXt', arch='base', + in_channels=8, out_indices=[0, 1, 2, 3], drop_path_rate=0.4, layer_scale_init_value=1.0, gap_before_final_norm=False, - init_cfg=dict( - type='Pretrained', checkpoint=checkpoint_file, - prefix='backbone.')), + #init_cfg=dict( + #type='Pretrained', checkpoint=checkpoint_file, + #prefix='backbone.') + ), decode_head=dict( type='UPerHead', in_channels=[128, 256, 512, 1024], diff --git a/configs/_base_/models/upernet_mae.py b/configs/_base_/models/upernet_mae.py index b833b67645..a7d6e6b242 100644 --- a/configs/_base_/models/upernet_mae.py +++ b/configs/_base_/models/upernet_mae.py @@ -1,9 +1,13 @@ norm_cfg = dict(type='SyncBN', requires_grad=True) data_preprocessor = dict( type='SegDataPreProcessor', - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375], - bgr_to_rgb=True, + mean=[117.926186, 117.568402, 97.217239], + std=[53.542876104049824, 50.084170325219176, 50.49331035114637], + #mean= [432.02181, 315.92948, 246.468659, 310.61462, 360.267789], + #std= [97.73313111900238, 85.78646917160748, 95.78015824658593, 124.84677067613467, 251.73965882246978], + # mean=[0.23651549, 0.31761484, 0.18514981, 0.26901252, -14.57879175, -8.6098158, -14.2907338, -8.33534564], + # std=[0.16280619, 0.20849304, 0.14008107, 0.19767644, 4.07141682, 3.94773216, 4.21006244, 4.05494136], + bgr_to_rgb=False, pad_val=0, seg_pad_val=255) model = dict( @@ -12,7 +16,7 @@ pretrained=None, backbone=dict( type='MAE', - img_size=(640, 640), + img_size=(256, 256), patch_size=16, in_channels=3, embed_dims=768, @@ -47,7 +51,7 @@ num_convs=1, concat_input=False, dropout_ratio=0.1, - num_classes=19, + num_classes=2, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( diff --git a/configs/_base_/models/upernet_vit-b16_ln_mln.py b/configs/_base_/models/upernet_vit-b16_ln_mln.py index 776525ad98..0d43c1ffae 100644 --- a/configs/_base_/models/upernet_vit-b16_ln_mln.py +++ b/configs/_base_/models/upernet_vit-b16_ln_mln.py @@ -2,20 +2,24 @@ norm_cfg = dict(type='SyncBN', requires_grad=True) data_preprocessor = dict( type='SegDataPreProcessor', - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375], - bgr_to_rgb=True, + #mean=[117.926186, 117.568402, 97.217239], + #std=[53.542876104049824, 50.084170325219176, 50.49331035114637], + #mean= [432.02181, 315.92948, 246.468659, 310.61462, 360.267789], + #std= [97.73313111900238, 85.78646917160748, 95.78015824658593, 124.84677067613467, 251.73965882246978], + mean=[0.23651549, 0.31761484, 0.18514981, 0.26901252, -14.57879175, -8.6098158, -14.2907338, -8.33534564], + std=[0.16280619, 0.20849304, 0.14008107, 0.19767644, 4.07141682, 3.94773216, 4.21006244, 4.05494136], + bgr_to_rgb=False, pad_val=0, seg_pad_val=255) model = dict( type='EncoderDecoder', data_preprocessor=data_preprocessor, - pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth', + #pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth', backbone=dict( type='VisionTransformer', - img_size=(512, 512), + img_size=(256, 256), patch_size=16, - in_channels=3, + in_channels=8, embed_dims=768, num_layers=12, num_heads=12, @@ -42,7 +46,7 @@ pool_scales=(1, 2, 3, 6), channels=512, dropout_ratio=0.1, - num_classes=19, + num_classes=2, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( @@ -55,7 +59,7 @@ num_convs=1, concat_input=False, dropout_ratio=0.1, - num_classes=19, + num_classes=2, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( diff --git a/configs/_base_/schedules/schedule_20k.py b/configs/_base_/schedules/schedule_20k.py index e809e3e880..45d8942e0b 100644 --- a/configs/_base_/schedules/schedule_20k.py +++ b/configs/_base_/schedules/schedule_20k.py @@ -19,6 +19,6 @@ timer=dict(type='IterTimerHook'), logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False), param_scheduler=dict(type='ParamSchedulerHook'), - checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000), + checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000, max_keep_ckpts=1, save_best='val/mIoU'), sampler_seed=dict(type='DistSamplerSeedHook'), visualization=dict(type='SegVisualizationHook')) diff --git a/configs/_base_/schedules/schedule_40k.py b/configs/_base_/schedules/schedule_40k.py index 4b823339a2..a2b7269844 100644 --- a/configs/_base_/schedules/schedule_40k.py +++ b/configs/_base_/schedules/schedule_40k.py @@ -17,8 +17,8 @@ test_cfg = dict(type='TestLoop') default_hooks = dict( timer=dict(type='IterTimerHook'), - logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=False), param_scheduler=dict(type='ParamSchedulerHook'), - checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000), + checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000, max_keep_ckpts=1, save_best='val/mIoU'), sampler_seed=dict(type='DistSamplerSeedHook'), visualization=dict(type='SegVisualizationHook')) diff --git a/configs/beit/beit-UAVflood.py b/configs/beit/beit-UAVflood.py new file mode 100644 index 0000000000..16cb337835 --- /dev/null +++ b/configs/beit/beit-UAVflood.py @@ -0,0 +1,43 @@ +_base_ = [ + '../_base_/models/upernet_beit.py', '../_base_/datasets/SARflood.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_20k.py' +] +crop_size = (256, 256) +data_preprocessor = dict(size=crop_size) +model = dict( + data_preprocessor=data_preprocessor, + backbone=dict( + img_size=(256, 256)), + decode_head=dict( + num_classes=2), + auxiliary_head=dict( + num_classes=2), + test_cfg=dict(mode='slide', crop_size=(256, 256), stride=(170, 170))) + +optim_wrapper = dict( + _delete_=True, + type='OptimWrapper', + optimizer=dict( + type='AdamW', lr=3e-5, betas=(0.9, 0.999), weight_decay=0.05), + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9)) + +param_scheduler = [ + dict( + type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500), + dict( + type='PolyLR', + power=1.0, + begin=1500, + end=20000, + eta_min=0.0, + by_epoch=False, + ) +] + + +train_dataloader = dict(batch_size=8, num_workers=8) +val_dataloader = dict(batch_size=8, num_workers=8) +test_dataloader = val_dataloader + + diff --git a/configs/convnext/convnext-base-uavflood.py b/configs/convnext/convnext-base-uavflood.py new file mode 100644 index 0000000000..24c4924644 --- /dev/null +++ b/configs/convnext/convnext-base-uavflood.py @@ -0,0 +1,91 @@ +_base_ = [ + '../_base_/models/upernet_convnext.py', + '../_base_/datasets/SARflood.py', + '../_base_/default_runtime.py' +] + +crop_size = (256, 256) +data_preprocessor = dict(size=crop_size) + +model = dict( + data_preprocessor=data_preprocessor, + decode_head=dict( + in_channels=[128, 256, 512, 1024], + num_classes=2 # UAVflood二分类:background, flood + ), + auxiliary_head=dict( + in_channels=512, + num_classes=2 + ), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +# 使用AMP优化器包装器以提高训练效率 +optim_wrapper = dict( + type='AmpOptimWrapper', + optimizer=dict( + type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05 + ), + paramwise_cfg={ + 'decay_rate': 0.9, + 'decay_type': 'stage_wise', + 'num_layers': 12 + }, + constructor='LearningRateDecayOptimizerConstructor', + loss_scale='dynamic' +) + +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-6, + by_epoch=True, + begin=0, + end=5 # warmup前5个epoch + ), + dict( + type='PolyLR', + power=1.0, + begin=5, + end=100, + eta_min=0.0, + by_epoch=True, + ) +] + +# 使用EpochBasedTrainLoop训练100个epoch +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=100, + val_interval=10) # 每10个epoch验证一次 + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +# 修改hooks为基于epoch +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=1, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict( + type='CheckpointHook', + by_epoch=True, + interval=10, # 每10个epoch保存一次 + max_keep_ckpts=3, + save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook')) + +# 设置日志按epoch显示 +log_processor = dict(by_epoch=True) + +# 数据加载器配置,适配UAVflood数据集 +train_dataloader = dict(batch_size=8, num_workers=8) +val_dataloader = dict(batch_size=8, num_workers=8) +test_dataloader = val_dataloader + +# 可选:设置随机种子以提高可复现性 +randomness = dict( + seed=42, + deterministic=False, +) \ No newline at end of file diff --git a/configs/deeplabv3plus/Deeplabv3+UAVflood.py b/configs/deeplabv3plus/Deeplabv3+UAVflood.py new file mode 100644 index 0000000000..d22cc9fc3b --- /dev/null +++ b/configs/deeplabv3plus/Deeplabv3+UAVflood.py @@ -0,0 +1,66 @@ +_base_ = [ + '../_base_/models/deeplabv3plus_r50-d8.py', + '../_base_/datasets/UAVflood.py', + '../_base_/default_runtime.py' +] +crop_size = (256, 256) +data_preprocessor = dict(size=crop_size) +model = dict( + data_preprocessor=data_preprocessor, + decode_head=dict(num_classes=2), + auxiliary_head=dict(num_classes=2)) + +# 优化器配置 +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005), + clip_grad=None) + +# 修改为epoch训练,训练100个epoch +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-6, + by_epoch=True, + begin=0, + end=5), # warmup前5个epoch + dict( + type='PolyLR', + eta_min=1e-4, + power=0.9, + begin=5, + end=100, + by_epoch=True, + ) +] + +# 使用EpochBasedTrainLoop训练100个epoch +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=100, + val_interval=10) # 每10个epoch验证一次 + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +# 修改hooks为基于epoch +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=200, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict( + type='CheckpointHook', + by_epoch=True, + interval=10, # 每10个epoch保存一次 + max_keep_ckpts=3, + save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook')) + +randomness = dict( + seed=42, + deterministic=False, # 如需完全可复现,设为True +) + +# 设置日志按epoch显示 +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_e6_k1.py b/configs/floodnet/ablations/ablation_e6_k1.py new file mode 100644 index 0000000000..dad0e3d267 --- /dev/null +++ b/configs/floodnet/ablations/ablation_e6_k1.py @@ -0,0 +1,180 @@ +""" +MoE Hyperparameter Study: 6 experts, top_k=1 +Table 3 + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_e6_k1.py \ + --work-dir work_dirs/moe_hyper/e6_k1/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} +num_experts = 6 +top_k = 1 +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, use_modal_bias=True, + moe_balance_weight=1.0, moe_diversity_weight=0.1, + multi_tasks_reweight=None, decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True, + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='FixedRatioModalSampler', + modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5}, + modal_order=['sar', 'rgb', 'GF'], + reference_modal='GF', batch_size=16), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_e6_k2.py b/configs/floodnet/ablations/ablation_e6_k2.py new file mode 100644 index 0000000000..ae0dcb467d --- /dev/null +++ b/configs/floodnet/ablations/ablation_e6_k2.py @@ -0,0 +1,180 @@ +""" +MoE Hyperparameter Study: 6 experts, top_k=2 +Table 3 + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_e6_k2.py \ + --work-dir work_dirs/moe_hyper/e6_k2/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} +num_experts = 6 +top_k = 2 +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, use_modal_bias=True, + moe_balance_weight=1.0, moe_diversity_weight=0.1, + multi_tasks_reweight=None, decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True, + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='FixedRatioModalSampler', + modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5}, + modal_order=['sar', 'rgb', 'GF'], + reference_modal='GF', batch_size=16), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_e6_k3.py b/configs/floodnet/ablations/ablation_e6_k3.py new file mode 100644 index 0000000000..8527debdc2 --- /dev/null +++ b/configs/floodnet/ablations/ablation_e6_k3.py @@ -0,0 +1,180 @@ +""" +MoE Hyperparameter Study: 6 experts, top_k=3 +Table 3 + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_e6_k3.py \ + --work-dir work_dirs/moe_hyper/e6_k3/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} +num_experts = 6 +top_k = 3 +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, use_modal_bias=True, + moe_balance_weight=1.0, moe_diversity_weight=0.1, + multi_tasks_reweight=None, decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True, + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='FixedRatioModalSampler', + modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5}, + modal_order=['sar', 'rgb', 'GF'], + reference_modal='GF', batch_size=16), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_e8_k1.py b/configs/floodnet/ablations/ablation_e8_k1.py new file mode 100644 index 0000000000..f084f5f53b --- /dev/null +++ b/configs/floodnet/ablations/ablation_e8_k1.py @@ -0,0 +1,180 @@ +""" +MoE Hyperparameter Study: 8 experts, top_k=1 +Table 3 + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_e8_k1.py \ + --work-dir work_dirs/moe_hyper/e8_k1/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} +num_experts = 8 +top_k = 1 +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, use_modal_bias=True, + moe_balance_weight=1.0, moe_diversity_weight=0.1, + multi_tasks_reweight=None, decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True, + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='FixedRatioModalSampler', + modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5}, + modal_order=['sar', 'rgb', 'GF'], + reference_modal='GF', batch_size=16), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_e8_k2.py b/configs/floodnet/ablations/ablation_e8_k2.py new file mode 100644 index 0000000000..1b90ffae68 --- /dev/null +++ b/configs/floodnet/ablations/ablation_e8_k2.py @@ -0,0 +1,180 @@ +""" +MoE Hyperparameter Study: 8 experts, top_k=2 +Table 3 + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_e8_k2.py \ + --work-dir work_dirs/moe_hyper/e8_k2/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} +num_experts = 8 +top_k = 2 +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, use_modal_bias=True, + moe_balance_weight=1.0, moe_diversity_weight=0.1, + multi_tasks_reweight=None, decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True, + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='FixedRatioModalSampler', + modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5}, + modal_order=['sar', 'rgb', 'GF'], + reference_modal='GF', batch_size=16), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_experts_16.py b/configs/floodnet/ablations/ablation_experts_16.py new file mode 100644 index 0000000000..30c3f39ae5 --- /dev/null +++ b/configs/floodnet/ablations/ablation_experts_16.py @@ -0,0 +1,186 @@ +""" +Ablation Study: 16 Experts (vs default 8) +Increases MoE capacity to study effect of expert count. + +Table 3 Row: num_experts=16, top_k=4 + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_experts_16.py \ + --work-dir work_dirs/ablations/experts_16/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} +num_experts = 16 # <<< 16 instead of 8 +top_k = 4 # <<< 4 instead of 3 (keep ~25% ratio) +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, + use_modal_bias=True, + moe_balance_weight=1.0, + moe_diversity_weight=0.1, + multi_tasks_reweight=None, + decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True, + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='FixedRatioModalSampler', + modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5}, + modal_order=['sar', 'rgb', 'GF'], + reference_modal='GF', batch_size=16), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_experts_4.py b/configs/floodnet/ablations/ablation_experts_4.py new file mode 100644 index 0000000000..f9ce12078c --- /dev/null +++ b/configs/floodnet/ablations/ablation_experts_4.py @@ -0,0 +1,186 @@ +""" +Ablation Study: 4 Experts (vs default 8) +Reduces MoE capacity to study effect of expert count. + +Table 3 Row: num_experts=4, top_k=2 + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_experts_4.py \ + --work-dir work_dirs/ablations/experts_4/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} +num_experts = 4 # <<< 4 instead of 8 +top_k = 2 # <<< 2 instead of 3 (keep ~50% ratio) +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, + use_modal_bias=True, + moe_balance_weight=1.0, + moe_diversity_weight=0.1, + multi_tasks_reweight=None, + decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True, + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='FixedRatioModalSampler', + modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5}, + modal_order=['sar', 'rgb', 'GF'], + reference_modal='GF', batch_size=16), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_gf_only.py b/configs/floodnet/ablations/ablation_gf_only.py new file mode 100644 index 0000000000..38c4e5fc4e --- /dev/null +++ b/configs/floodnet/ablations/ablation_gf_only.py @@ -0,0 +1,187 @@ +""" +Ablation Study: GaoFen-Only Training +Single-modal baseline using only GaoFen satellite data. + +Table 4 Row: GF-only + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_gf_only.py \ + --work-dir work_dirs/ablations/gf_only/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} +num_experts = 8 +top_k = 3 +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, + use_modal_bias=True, + moe_balance_weight=1.0, + moe_diversity_weight=0.1, + multi_tasks_reweight=None, + decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True, + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +# GF-only: filter_modality='GF' +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict(type=dataset_type, data_root=data_root, + filter_modality='GF', + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + filter_modality='GF', + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + filter_modality='GF', + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_no_diversity_loss.py b/configs/floodnet/ablations/ablation_no_diversity_loss.py new file mode 100644 index 0000000000..e28bca1a9a --- /dev/null +++ b/configs/floodnet/ablations/ablation_no_diversity_loss.py @@ -0,0 +1,187 @@ +""" +Ablation Study: No Expert Diversity Loss +MoE enabled with modal bias but expert diversity loss disabled. + +Table 2 Row (e): w/o Expert Diversity Loss + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_no_diversity_loss.py \ + --work-dir work_dirs/ablations/no_diversity_loss/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} +num_experts = 8 +top_k = 3 +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, + use_modal_bias=True, + moe_balance_weight=1.0, + moe_diversity_weight=0.0, # <<< DISABLED + multi_tasks_reweight=None, + decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, + use_expert_diversity_loss=False, # <<< DISABLED + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='FixedRatioModalSampler', + modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5}, + modal_order=['sar', 'rgb', 'GF'], + reference_modal='GF', batch_size=16), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_no_modal_bias.py b/configs/floodnet/ablations/ablation_no_modal_bias.py new file mode 100644 index 0000000000..dffd5e8032 --- /dev/null +++ b/configs/floodnet/ablations/ablation_no_modal_bias.py @@ -0,0 +1,185 @@ +""" +Ablation Study: No Modal Bias +MoE is enabled but modal-specific routing bias is disabled. + +Table 2 Row (c): w/o Modal Bias + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_no_modal_bias.py \ + --work-dir work_dirs/ablations/no_modal_bias/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} +num_experts = 8 +top_k = 3 +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, + use_modal_bias=False, # <<< DISABLED + moe_balance_weight=1.0, + moe_diversity_weight=0.1, + multi_tasks_reweight=None, + decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True, + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='FixedRatioModalSampler', + modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5}, + modal_order=['sar', 'rgb', 'GF'], + reference_modal='GF', batch_size=16), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_no_modal_specific_stem.py b/configs/floodnet/ablations/ablation_no_modal_specific_stem.py new file mode 100644 index 0000000000..6f16ab855b --- /dev/null +++ b/configs/floodnet/ablations/ablation_no_modal_specific_stem.py @@ -0,0 +1,187 @@ +""" +Ablation Study: No ModalSpecificStem (Unified Patch Embedding) +Uses a single shared Conv2d with zero-padding instead of per-modal patch embedding. + +Table 2 Row: w/o ModalSpecificStem + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_no_modal_specific_stem.py \ + --work-dir work_dirs/ablations/no_modal_specific_stem/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} +num_experts = 8 +top_k = 3 +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, + use_modal_bias=True, + moe_balance_weight=1.0, + moe_diversity_weight=0.1, + multi_tasks_reweight=None, + decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_modal_specific_stem=False, # <<< DISABLED: use UnifiedPatchEmbed + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True, + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='FixedRatioModalSampler', + modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5}, + modal_order=['sar', 'rgb', 'GF'], + reference_modal='GF', batch_size=16), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_no_moe.py b/configs/floodnet/ablations/ablation_no_moe.py new file mode 100644 index 0000000000..c18e7bf0a3 --- /dev/null +++ b/configs/floodnet/ablations/ablation_no_moe.py @@ -0,0 +1,257 @@ +""" +Ablation Study: No MoE (Standard Swin-Base + UPerNet baseline) +All MoE components disabled - uses standard FFN instead of MoE FFN. + +Table 2 Row (b): w/o MoE + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_no_moe.py \ + --work-dir work_dirs/ablations/no_moe/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +# ==================== Dataset Definition ==================== +DATASETS_CONFIG = dict( + names=['sar', 'rgb', 'GF'], +) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', + 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', + 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', + 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +# ==================== MoE DISABLED ==================== +depths = [2, 2, 18, 2] + +# ==================== Model Config ==================== +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', + pad_val=0, + seg_pad_val=255, + size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=False, # <<< DISABLED + use_modal_bias=False, # No MoE means no modal bias + moe_balance_weight=0.0, + moe_diversity_weight=0.0, + multi_tasks_reweight=None, + decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + + pretrain_img_size=224, + patch_size=4, + embed_dims=128, + depths=depths, + num_heads=[4, 8, 16, 32], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.3, + patch_norm=True, + out_indices=[0, 1, 2, 3], + + # ---- MoE DISABLED ---- + use_moe=False, + num_experts=8, + num_shared_experts_config={0: 0, 1: 0, 2: 0, 3: 0}, + top_k=3, + noisy_gating=False, + MoE_Block_inds=[[], [], [], []], + use_expert_diversity_loss=False, + + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', + in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=num_classes, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0 + ) + ), + + auxiliary_head=dict( + type='FCNHead', + in_channels=512, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=num_classes, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=0.4 + ) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +# ==================== Dataset Config ==================== +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict( + type='FixedRatioModalSampler', + modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5}, + modal_order=['sar', 'rgb', 'GF'], + reference_modal='GF', + batch_size=16, + ), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='train/images', + seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='val/images', + seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='test/images', + seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +# ==================== Optimizer Config ==================== +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='AdamW', + lr=0.00006, + betas=(0.9, 0.999), + weight_decay=0.01), + paramwise_cfg=dict( + custom_keys={ + 'patch_embed': dict(lr_mult=2.0), + 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), + } + ) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_no_shared_experts.py b/configs/floodnet/ablations/ablation_no_shared_experts.py new file mode 100644 index 0000000000..60c0fa7090 --- /dev/null +++ b/configs/floodnet/ablations/ablation_no_shared_experts.py @@ -0,0 +1,186 @@ +""" +Ablation Study: No Shared Experts +MoE enabled but shared experts removed from all stages. + +Table 2 Row (d): w/o Shared Experts + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_no_shared_experts.py \ + --work-dir work_dirs/ablations/no_shared_experts/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 0, 3: 0} # <<< ALL ZEROS +num_experts = 8 +top_k = 3 +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, + use_modal_bias=True, + moe_balance_weight=1.0, + moe_diversity_weight=0.1, + multi_tasks_reweight=None, + decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True, + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='FixedRatioModalSampler', + modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5}, + modal_order=['sar', 'rgb', 'GF'], + reference_modal='GF', batch_size=16), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_rgb_only.py b/configs/floodnet/ablations/ablation_rgb_only.py new file mode 100644 index 0000000000..070b6114a0 --- /dev/null +++ b/configs/floodnet/ablations/ablation_rgb_only.py @@ -0,0 +1,187 @@ +""" +Ablation Study: RGB-Only Training +Single-modal baseline using only RGB optical data. + +Table 4 Row: RGB-only + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_rgb_only.py \ + --work-dir work_dirs/ablations/rgb_only/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} +num_experts = 8 +top_k = 3 +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, + use_modal_bias=True, + moe_balance_weight=1.0, + moe_diversity_weight=0.1, + multi_tasks_reweight=None, + decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True, + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +# RGB-only: filter_modality='rgb' +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict(type=dataset_type, data_root=data_root, + filter_modality='rgb', + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + filter_modality='rgb', + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + filter_modality='rgb', + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_sar_only.py b/configs/floodnet/ablations/ablation_sar_only.py new file mode 100644 index 0000000000..84f2ffd2bd --- /dev/null +++ b/configs/floodnet/ablations/ablation_sar_only.py @@ -0,0 +1,187 @@ +""" +Ablation Study: SAR-Only Training +Single-modal baseline using only Synthetic Aperture Radar (SAR) data. + +Table 4 Row: SAR-only + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_sar_only.py \ + --work-dir work_dirs/ablations/sar_only/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} +num_experts = 8 +top_k = 3 +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, + use_modal_bias=True, + moe_balance_weight=1.0, + moe_diversity_weight=0.1, + multi_tasks_reweight=None, + decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True, + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +# SAR-only: filter_modality='sar' +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict(type=dataset_type, data_root=data_root, + filter_modality='sar', + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + filter_modality='sar', + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + filter_modality='sar', + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_shared_decoder.py b/configs/floodnet/ablations/ablation_shared_decoder.py new file mode 100644 index 0000000000..4b5646146b --- /dev/null +++ b/configs/floodnet/ablations/ablation_shared_decoder.py @@ -0,0 +1,186 @@ +""" +Ablation Study: Shared Decoder +Full model but with a single shared decode head for all modalities. + +Table 2 Row (g): Shared Decoder (vs Separate) + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_shared_decoder.py \ + --work-dir work_dirs/ablations/shared_decoder/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} +num_experts = 8 +top_k = 3 +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, + use_modal_bias=True, + moe_balance_weight=1.0, + moe_diversity_weight=0.1, + multi_tasks_reweight=None, + decoder_mode='shared', # <<< SHARED instead of separate + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True, + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='FixedRatioModalSampler', + modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5}, + modal_order=['sar', 'rgb', 'GF'], + reference_modal='GF', batch_size=16), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_topk_1.py b/configs/floodnet/ablations/ablation_topk_1.py new file mode 100644 index 0000000000..4386f53fda --- /dev/null +++ b/configs/floodnet/ablations/ablation_topk_1.py @@ -0,0 +1,186 @@ +""" +Ablation Study: Top-k=1 (vs default k=3) +Only one expert activated per token. + +Table 3 Row: num_experts=8, top_k=1 + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_topk_1.py \ + --work-dir work_dirs/ablations/topk_1/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} +num_experts = 8 +top_k = 1 # <<< 1 instead of 3 +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, + use_modal_bias=True, + moe_balance_weight=1.0, + moe_diversity_weight=0.1, + multi_tasks_reweight=None, + decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True, + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='FixedRatioModalSampler', + modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5}, + modal_order=['sar', 'rgb', 'GF'], + reference_modal='GF', batch_size=16), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/ablations/ablation_uniform_sampling.py b/configs/floodnet/ablations/ablation_uniform_sampling.py new file mode 100644 index 0000000000..e7d8fe27ab --- /dev/null +++ b/configs/floodnet/ablations/ablation_uniform_sampling.py @@ -0,0 +1,184 @@ +""" +Ablation Study: Uniform Sampling (No SAR Boost) +Full model but with uniform random sampling instead of SAR-boosted ratio. + +Table 2 Row (f): w/o SAR Boost Sampling + +Usage: + python tools/train.py configs/floodnet/ablations/ablation_uniform_sampling.py \ + --work-dir work_dirs/ablations/uniform_sampling/ --seed 42 +""" + +_base_ = [ + '../../_base_/default_runtime.py', +] + +DATASETS_CONFIG = dict(names=['sar', 'rgb', 'GF']) + +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +depths = [2, 2, 18, 2] +MoE_Block_inds = [[], [1], [1, 3, 5, 7, 9, 11, 13, 15, 17], [0, 1]] +num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} +num_experts = 8 +top_k = 3 +noisy_gating = True + +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', pad_val=0, seg_pad_val=255, size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, + use_modal_bias=True, + moe_balance_weight=1.0, + moe_diversity_weight=0.1, + multi_tasks_reweight=None, + decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + pretrain_img_size=224, patch_size=4, + embed_dims=128, depths=depths, num_heads=[4, 8, 16, 32], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + patch_norm=True, out_indices=[0, 1, 2, 3], + use_moe=True, num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, use_expert_diversity_loss=True, + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, + dropout_ratio=0.1, num_classes=num_classes, norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + + auxiliary_head=dict( + type='FCNHead', in_channels=512, in_index=2, channels=256, + num_convs=1, concat_input=False, dropout_ratio=0.1, + num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False, + loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +# <<< KEY CHANGE: DefaultSampler instead of FixedRatioModalSampler >>> +train_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='train/images', seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='val/images', seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, num_workers=8, persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type=dataset_type, data_root=data_root, + data_prefix=dict(img_path='test/images', seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'patch_embed': dict(lr_mult=2.0), 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + }) +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), + dict(type='PolyLR', eta_min=0.0, power=1.0, begin=5, end=100, by_epoch=True), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, + max_keep_ckpts=1, save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/continue_train_150ep.py b/configs/floodnet/continue_train_150ep.py new file mode 100644 index 0000000000..d269d17019 --- /dev/null +++ b/configs/floodnet/continue_train_150ep.py @@ -0,0 +1,28 @@ +""" +Continue training Full Model for 50 more epochs (101-150). +""" + +_base_ = ['./multimodal_floodnet_sar_boost_swinbase_moe_config.py'] + +# Extend to 150 epochs +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=150, + val_interval=10) + +# Extend PolyLR end to 150 +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-6, + by_epoch=True, + begin=0, + end=5), + dict( + type='PolyLR', + eta_min=0.0, + power=1.0, + begin=5, + end=150, + by_epoch=True), +] diff --git a/configs/floodnet/finetune_sen1floods11_s1.py b/configs/floodnet/finetune_sen1floods11_s1.py new file mode 100644 index 0000000000..bec6510938 --- /dev/null +++ b/configs/floodnet/finetune_sen1floods11_s1.py @@ -0,0 +1,187 @@ +""" +Sen1Floods11 fine-tune: S1Hand (2-band VV/VH SAR). + +Expected on-disk layout (see tools/setup_sen1floods11.py):: + + data/Sen1Floods11/ + S1Hand/_S1Hand.tif # 2-band SAR (VV, VH in dB) + S2Hand/_S2Hand.tif # 13-band Sentinel-2 MSI (unused here) + LabelHand/_LabelHand.tif # 1-band label: -1=nodata, 0=bg, 1=flood + splits/train.txt # one per line, generated by + splits/val.txt # tools/setup_sen1floods11.py + splits/test.txt + +All tiles are 512x512. Labels are signed TIFFs; the ``-1`` nodata +value is remapped to ``ignore_index=255`` by +:class:`LoadSen1Floods11Annotation`. + +Setup (run once, before the first training run): + + # 1. Generate train/val/test splits + compute mean/std for s1 & s2. + python tools/setup_sen1floods11.py --data-root data/Sen1Floods11 + + # 2. Paste the printed NORM_CONFIGS block into the 's1' entry in + # mmseg/datasets/transforms/multimodal_pipelines.py + # (skip this if you're happy with the shipped defaults). + +Training: + + python tools/train.py configs/floodnet/finetune_sen1floods11_s1.py \\ + --work-dir work_dirs/generalization/sen1floods11_s1/ \\ + --cfg-options \\ + load_from="work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth" + +The pretrained Swin-Base + MoE checkpoint has modal-specific patch +embeds / decode heads for ``sar / rgb / GF``. Here we redefine the +single trainable modal as ``s1`` (2-channel) and the single decode-head +key as ``s1``; the mismatched pretrained keys are ignored on load +(``strict=False``), so the new patch-embed and decode head train from +scratch while the frozen Swin stages reuse their pretrained weights. +""" + +_base_ = ['./finetune_stem_decoder.py'] + +# ==================== Modal / dataset identifiers ==================== +MODAL_NAME = 's1' +MODAL_CHANNELS = 2 + +ALL_KNOWN_MODALS = { + # _delete_ forces a full replacement so the base config's + # sar/rgb/GF entries don't leak through. + '_delete_': True, + MODAL_NAME: { + 'channels': MODAL_CHANNELS, + 'pattern': 's1hand', + 'description': 'Sen1Floods11 S1Hand (VV/VH dB)', + }, +} +TRAINING_MODALS = [MODAL_NAME] +DATASET_NAMES = [MODAL_NAME] + +# 512x512 input -> crop to 256x256 for training (gives ~4x position +# diversity per tile and keeps each training step cheap). +crop_size = (256, 256) + +# ==================== Model override ==================== +model = dict( + dataset_names=DATASET_NAMES, + backbone=dict( + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + # frozen_stages / freeze_patch_embed inherit from base + ), +) + +# ==================== Dataset / pipelines ==================== +dataset_type = 'Sen1Floods11Dataset' +data_root = 'data/Sen1Floods11/' + +# All tiles are already 512x512, so no RandomResize is needed. +# Pipeline order: +# 1. Load image (sets 'img') +# 2. Load annotation (sets 'gt_seg_map', 'seg_fields') +# - must come BEFORE RandomCrop so cat_max_ratio can check label +# distribution and avoid degenerate all-background crops. +# 3. RandomCrop 256x256 +# 4. RandomFlip +# 5. MultiModalNormalize (NaN/Inf-safe, modality-aware mean/std) +# 6. MultiModalPad to crop_size (no-op if RandomCrop matched exactly, +# but handles any edge case where the tile is SegDataSample +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadSen1Floods11Annotation'), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +# Test / val: feed the full 512x512 tile; sliding-window inference +# (configured in the base config as mode='slide', crop_size=256, +# stride=170) handles the cropping internally. +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadSen1Floods11Annotation'), + dict(type='MultiModalNormalize'), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +# ==================== Dataloaders ==================== +# _delete_=True on sampler forces full replacement so the base +# FixedRatioModalSampler config is discarded - Sen1Floods11 is single +# modal, so we just use the standard shuffling sampler. +# +# ann_file='splits/train.txt' is resolved relative to data_root by +# BaseSegDataset._join_prefix, so the actual path read is +# data/Sen1Floods11/splits/train.txt . Generate these files with +# tools/setup_sen1floods11.py before training. +train_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict( + _delete_=True, + type='DefaultSampler', + shuffle=True, + ), + dataset=dict( + _delete_=True, + type=dataset_type, + data_root=data_root, + modality=MODAL_NAME, + ann_file='splits/train.txt', + data_prefix=dict( + img_path='S1Hand', + seg_map_path='LabelHand', + ), + pipeline=train_pipeline, + ), +) + +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + _delete_=True, + type=dataset_type, + data_root=data_root, + modality=MODAL_NAME, + ann_file='splits/val.txt', + data_prefix=dict( + img_path='S1Hand', + seg_map_path='LabelHand', + ), + pipeline=test_pipeline, + ), +) + +test_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + _delete_=True, + type=dataset_type, + data_root=data_root, + modality=MODAL_NAME, + ann_file='splits/test.txt', + data_prefix=dict( + img_path='S1Hand', + seg_map_path='LabelHand', + ), + pipeline=test_pipeline, + ), +) diff --git a/configs/floodnet/finetune_sen1floods11_s2.py b/configs/floodnet/finetune_sen1floods11_s2.py new file mode 100644 index 0000000000..a2e4fe1a41 --- /dev/null +++ b/configs/floodnet/finetune_sen1floods11_s2.py @@ -0,0 +1,168 @@ +""" +Sen1Floods11 fine-tune: S2Hand (13-band Sentinel-2 MSI). + +Sibling of ``finetune_sen1floods11_s1.py`` - same freeze / LR / schedule +strategy, but swaps the trainable modal to ``s2`` (13 ch) and points +the dataset at the ``S2Hand`` subdirectory. LabelHand files are shared +between S1 and S2, so the exact same splits/train.txt / splits/val.txt +/ splits/test.txt files are used - this means any S1 vs S2 comparison +is done on the same tiles. + +Expected on-disk layout (see tools/setup_sen1floods11.py):: + + data/Sen1Floods11/ + S1Hand/_S1Hand.tif # 2-band SAR (unused here) + S2Hand/_S2Hand.tif # 13-band Sentinel-2 MSI + LabelHand/_LabelHand.tif # -1=nodata, 0=bg, 1=flood + splits/{train,val,test}.txt + +All tiles are 512x512. + +Setup (run once, before the first training run): + + # 1. Generate splits + compute mean/std (covers both s1 and s2). + python tools/setup_sen1floods11.py --data-root data/Sen1Floods11 + + # 2. Paste the printed NORM_CONFIGS block into the 's2' entry in + # mmseg/datasets/transforms/multimodal_pipelines.py + # (skip this if you're happy with the shipped defaults). + +Training: + + python tools/train.py configs/floodnet/finetune_sen1floods11_s2.py \\ + --work-dir work_dirs/generalization/sen1floods11_s2/ \\ + --cfg-options \\ + load_from="work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth" +""" + +_base_ = ['./finetune_stem_decoder.py'] + +# ==================== Modal / dataset identifiers ==================== +MODAL_NAME = 's2' +MODAL_CHANNELS = 13 + +ALL_KNOWN_MODALS = { + # _delete_ forces a full replacement so the base config's + # sar/rgb/GF entries don't leak through. + '_delete_': True, + MODAL_NAME: { + 'channels': MODAL_CHANNELS, + 'pattern': 's2hand', + 'description': 'Sen1Floods11 S2Hand (13-band Sentinel-2 MSI)', + }, +} +TRAINING_MODALS = [MODAL_NAME] +DATASET_NAMES = [MODAL_NAME] + +# 512x512 input -> crop to 256x256 for training. +crop_size = (256, 256) + +# ==================== Model override ==================== +# See finetune_sen1floods11_s1.py for the rationale - the single +# trainable modal here is `s2` with a 13-channel stem conv. The +# s2 mean/std entry in MultiModalNormalize.NORM_CONFIGS is used +# by default; re-run tools/setup_sen1floods11.py to refresh it +# for your local data. +model = dict( + dataset_names=DATASET_NAMES, + backbone=dict( + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + ), +) + +# ==================== Dataset / pipelines ==================== +dataset_type = 'Sen1Floods11Dataset' +data_root = 'data/Sen1Floods11/' + +# Pipeline order matches finetune_sen1floods11_s1.py. See that file +# for the rationale behind each step. +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadSen1Floods11Annotation'), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadSen1Floods11Annotation'), + dict(type='MultiModalNormalize'), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +# ==================== Dataloaders ==================== +# Shares splits/train.txt / splits/val.txt / splits/test.txt with +# finetune_sen1floods11_s1.py so the S1 / S2 results are directly +# comparable (same tiles in each split, just different sensors). +train_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict( + _delete_=True, + type='DefaultSampler', + shuffle=True, + ), + dataset=dict( + _delete_=True, + type=dataset_type, + data_root=data_root, + modality=MODAL_NAME, + ann_file='splits/train.txt', + data_prefix=dict( + img_path='S2Hand', + seg_map_path='LabelHand', + ), + pipeline=train_pipeline, + ), +) + +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + _delete_=True, + type=dataset_type, + data_root=data_root, + modality=MODAL_NAME, + ann_file='splits/val.txt', + data_prefix=dict( + img_path='S2Hand', + seg_map_path='LabelHand', + ), + pipeline=test_pipeline, + ), +) + +test_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + _delete_=True, + type=dataset_type, + data_root=data_root, + modality=MODAL_NAME, + ann_file='splits/test.txt', + data_prefix=dict( + img_path='S2Hand', + seg_map_path='LabelHand', + ), + pipeline=test_pipeline, + ), +) diff --git a/configs/floodnet/finetune_single_modal.py b/configs/floodnet/finetune_single_modal.py new file mode 100644 index 0000000000..b756f317a8 --- /dev/null +++ b/configs/floodnet/finetune_single_modal.py @@ -0,0 +1,48 @@ +""" +Generalization Fine-tuning Config: Single-Modal (RGB only) + +Freeze backbone, retrain stem + decoder on a new RGB-only flood event. +Based on finetune_stem_decoder.py but adapted for single-modality data: + - filter_modality='rgb' to only load RGB images + - DefaultSampler instead of FixedRatioModalSampler + - Uses the rgb decode_head from the pretrained separate-decoder model + +Usage: + python tools/train.py configs/floodnet/finetune_single_modal.py \ + --work-dir work_dirs/generalization/LY-train-station/ \ + --cfg-options \ + train_dataloader.dataset.data_root="data/LY-train-station/" \ + val_dataloader.dataset.data_root="data/LY-train-station/" \ + test_dataloader.dataset.data_root="data/LY-train-station/" +""" + +_base_ = ['./finetune_stem_decoder.py'] + +# ==================== Override sampler to DefaultSampler ==================== +# _delete_=True forces full replacement instead of recursive merge, +# otherwise FixedRatioModalSampler's modal_ratios etc. leak through +train_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict( + _delete_=True, + type='DefaultSampler', + shuffle=True, + ), + dataset=dict( + filter_modality='rgb', + ), +) + +val_dataloader = dict( + dataset=dict( + filter_modality='rgb', + ), +) + +test_dataloader = dict( + dataset=dict( + filter_modality='rgb', + ), +) diff --git a/configs/floodnet/finetune_stem_decoder.py b/configs/floodnet/finetune_stem_decoder.py new file mode 100644 index 0000000000..ba149f71f5 --- /dev/null +++ b/configs/floodnet/finetune_stem_decoder.py @@ -0,0 +1,87 @@ +""" +Generalization Fine-tuning Config: Freeze Backbone, Retrain Stem + Decoder + +Loads pretrained Full Model checkpoint, freezes all 4 Swin Transformer stages +(attention, MoE, patch merging) via backbone.frozen_stages=3, and only retrains: + 1. ModalSpecificPatchEmbed (backbone.patch_embed) - adapts to new sensor inputs + 2. UPerHead decode_heads - adapts segmentation output to new domain + 3. FCNHead auxiliary_heads - auxiliary segmentation head + +Frozen stages use requires_grad=False + eval() mode for proper BN/Dropout behavior. + +Usage: + python tools/train.py configs/floodnet/finetune_stem_decoder.py \ + --work-dir work_dirs/generalization/event_name/ \ + --cfg-options \ + load_from="work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth" \ + train_dataloader.dataset.data_root="../floodnet/data/new_event/" \ + val_dataloader.dataset.data_root="../floodnet/data/new_event/" \ + test_dataloader.dataset.data_root="../floodnet/data/new_event/" +""" + +_base_ = ['./multimodal_floodnet_sar_boost_swinbase_moe_config.py'] + +# ==================== Load pretrained checkpoint ==================== +# Override this via --cfg-options load_from="path/to/checkpoint.pth" +load_from = 'work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth' + +# ==================== Freeze all 4 backbone stages ==================== +# frozen_stages=3 freezes stages 0-3 (all Swin blocks, MoE, patch merging, norms) +# freeze_patch_embed=False keeps ModalSpecificPatchEmbed trainable +model = dict( + backbone=dict( + frozen_stages=3, # Freeze all 4 stages (0,1,2,3) + freeze_patch_embed=False # Keep stem trainable + ), +) + +# ==================== Optimizer for fine-tuning ==================== +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='AdamW', + lr=0.0001, # Higher base LR for fine-tuning fewer params + betas=(0.9, 0.999), + weight_decay=0.01), + paramwise_cfg=dict( + custom_keys={ + # Stem: moderate LR boost + 'backbone.patch_embed': dict(lr_mult=2.0), + # Decode heads: high LR for fast adaptation + 'decode_heads': dict(lr_mult=10.0), + 'auxiliary_heads': dict(lr_mult=10.0), + } + ) +) + +# ==================== Shorter schedule for fine-tuning ==================== +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10), # Short warmup + dict( + type='PolyLR', + eta_min=0.0, + power=1.0, + begin=2, + end=30, + by_epoch=True), +] + +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=100, # Much shorter than full training (100 epochs) + val_interval=10) + +# ==================== Checkpoint hook ==================== +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + by_epoch=True, + interval=10, + max_keep_ckpts=1, + save_best='mIoU'), +) diff --git a/configs/floodnet/multimodal_floodnet_sar_boost_swin_moe_config.py b/configs/floodnet/multimodal_floodnet_sar_boost_swin_moe_config.py new file mode 100644 index 0000000000..936e4eaed6 --- /dev/null +++ b/configs/floodnet/multimodal_floodnet_sar_boost_swin_moe_config.py @@ -0,0 +1,298 @@ +""" +SwinMoE: Multi-Dataset FloodNet Training - SAR-Boosted Configuration +MMSeg 1.x Version + +Usage: + python tools/train.py configs/floodnet/multimodal_floodnet_sar_boost_swin_moe_config.py \ + --work-dir work_dirs/swin_moe_upernet/SAR_Boost/ --seed 42 +""" + +_base_ = [ + '../_base_/default_runtime.py', +] + +# ==================== Dataset Definition ==================== +DATASETS_CONFIG = dict( + names=['sar', 'rgb', 'GF'], +) + +# ==================== Modal Definition ==================== +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', + 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', + 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', + 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +# ==================== MoE Config ==================== +depths = [2, 2, 18, 2] +MoE_Block_inds = [ + [], + [1], + [1, 3, 5, 7, 9, 11, 13, 15, 17], + [0, 1] +] + +num_shared_experts_config = { + 0: 0, + 1: 0, + 2: 2, + 3: 1 +} + +num_experts = 8 +top_k = 3 +noisy_gating = True + +# ==================== Model Config ==================== +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', + pad_val=0, + seg_pad_val=255, + size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, + use_modal_bias=True, + moe_balance_weight=1.0, + moe_diversity_weight=0.1, + multi_tasks_reweight=None, + decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + + pretrain_img_size=224, + patch_size=4, + embed_dims=96, + depths=depths, + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + patch_norm=True, + out_indices=[0, 1, 2, 3], + + use_moe=True, + num_experts=8, + num_shared_experts_config=num_shared_experts_config, + top_k=3, + noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, + use_expert_diversity_loss=True, + + pretrained=None, + ), + + decode_head=dict( + type='UPerHead', + in_channels=[96, 192, 384, 768], + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=num_classes, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0 + ) + ), + + auxiliary_head=dict( + type='FCNHead', + in_channels=384, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=num_classes, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=0.4 + ) + ), + + train_cfg=dict(), + test_cfg=dict(mode='whole') +) + +# ==================== Dataset Config ==================== +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/SAR' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='Resize', scale=(256, 256), keep_ratio=False), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(256, 256), keep_ratio=False), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict( + type='FixedRatioModalSampler', + modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5}, + modal_order=['sar', 'rgb', 'GF'], + reference_modal='sar', + batch_size=16, + ), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='train/images', + seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='val/images', + seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='test/images', + seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +# ==================== Optimizer Config ==================== +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='AdamW', + lr=0.00006, + betas=(0.9, 0.999), + weight_decay=0.01), + paramwise_cfg=dict( + custom_keys={ + 'patch_embed': dict(lr_mult=2.0), + 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), + 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), + 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + } + ) +) + +# ==================== Learning Rate Config ==================== +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-6, + by_epoch=True, + begin=0, + end=5), + dict( + type='PolyLR', + eta_min=0.0, + power=1.0, + begin=5, + end=200, + by_epoch=True), +] + +# ==================== Training Loop ==================== +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=100, + val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +# ==================== Hooks ==================== +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict( + type='CheckpointHook', + by_epoch=True, + interval=10, + max_keep_ckpts=1, + save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +# ==================== Runtime Overrides ==================== +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py b/configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py new file mode 100644 index 0000000000..0109a0a30d --- /dev/null +++ b/configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py @@ -0,0 +1,305 @@ +""" +Swin-Base + MoE: Multi-Dataset FloodNet Training - SAR-Boosted Configuration +MMSeg 1.x Version + +Based on Swin-uavflood-256x256 (Swin-Base) upgraded with MoE. +Backbone: Swin-Base (embed_dims=128) + MoE (8 experts, top_k=3) +Estimated params: ~456M (vs Swin-T+MoE ~278M, vs Swin-B ~109M) + +Usage: + python tools/train.py configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py \ + --work-dir work_dirs/swinbase_moe_upernet/SAR_Boost/ --seed 42 +""" + +_base_ = [ + '../_base_/default_runtime.py', +] + +# ==================== Dataset Definition ==================== +DATASETS_CONFIG = dict( + names=['sar', 'rgb', 'GF'], +) + +# ==================== Modal Definition ==================== +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', + 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', + 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', + 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +# ==================== MoE Config ==================== +depths = [2, 2, 18, 2] +MoE_Block_inds = [ + [], + [1], + [1, 3, 5, 7, 9, 11, 13, 15, 17], + [0, 1] +] + +num_shared_experts_config = { + 0: 0, + 1: 0, + 2: 2, + 3: 1 +} + +num_experts = 8 +top_k = 3 +noisy_gating = True + +# ==================== Model Config ==================== +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', + pad_val=0, + seg_pad_val=255, + size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, + use_modal_bias=True, + moe_balance_weight=1.0, + moe_diversity_weight=0.1, + multi_tasks_reweight=None, + decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + + pretrain_img_size=224, + patch_size=4, + # ---- Swin-Base dimensions ---- + embed_dims=128, + depths=depths, + num_heads=[4, 8, 16, 32], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.3, # Swin-B uses 0.3 (vs Swin-T 0.2) + patch_norm=True, + out_indices=[0, 1, 2, 3], + + # ---- MoE ---- + use_moe=True, + num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, + noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, + use_expert_diversity_loss=True, + + pretrained=None, + ), + + # ---- Swin-Base output channels: [128, 256, 512, 1024] ---- + decode_head=dict( + type='UPerHead', + in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=num_classes, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0 + ) + ), + + auxiliary_head=dict( + type='FCNHead', + in_channels=512, # Swin-B stage2 output + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=num_classes, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=0.4 + ) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +# ==================== Dataset Config ==================== +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict( + type='FixedRatioModalSampler', + modal_ratios={'sar': 6, 'rgb': 5, 'GF': 5}, + modal_order=['sar', 'rgb', 'GF'], + reference_modal='GF', + batch_size=16, + ), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='train/images', + seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='val/images', + seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='test/images', + seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +# ==================== Optimizer Config ==================== +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='AdamW', + lr=0.00006, + betas=(0.9, 0.999), + weight_decay=0.01), + paramwise_cfg=dict( + custom_keys={ + 'patch_embed': dict(lr_mult=2.0), + 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), + 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), + 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + } + ) +) + +# ==================== Learning Rate Config ==================== +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-6, + by_epoch=True, + begin=0, + end=5), + dict( + type='PolyLR', + eta_min=0.0, + power=1.0, + begin=5, + end=100, + by_epoch=True), +] + +# ==================== Training Loop ==================== +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=100, + val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +# ==================== Hooks ==================== +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict( + type='CheckpointHook', + by_epoch=True, + interval=10, + max_keep_ckpts=1, + save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +# ==================== Runtime Overrides ==================== +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/floodnet/multimodal_floodnet_sar_only_swinbase_moe_config.py b/configs/floodnet/multimodal_floodnet_sar_only_swinbase_moe_config.py new file mode 100644 index 0000000000..d1db87be7c --- /dev/null +++ b/configs/floodnet/multimodal_floodnet_sar_only_swinbase_moe_config.py @@ -0,0 +1,300 @@ +""" +Swin-Base + MoE: SAR-Only FloodNet Training Configuration +MMSeg 1.x Version + +Only trains on SAR modality (100 epochs). + +Usage: + python tools/train.py configs/floodnet/multimodal_floodnet_sar_only_swinbase_moe_config.py \ + --work-dir work_dirs/floodnet/SwinmoeB_sar_only/ --cfg-options seed=42 +""" + +_base_ = [ + '../_base_/default_runtime.py', +] + +# ==================== Dataset Definition ==================== +DATASETS_CONFIG = dict( + names=['sar', 'rgb', 'GF'], +) + +# ==================== Modal Definition ==================== +ALL_KNOWN_MODALS = { + 'sar': {'channels': 8, 'pattern': 'sar', + 'description': 'Synthetic Aperture Radar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb', + 'description': 'RGB Optical'}, + 'GF': {'channels': 5, 'pattern': 'GF', + 'description': 'GaoFen Satellite'}, +} + +TRAINING_MODALS = ['sar', 'rgb', 'GF'] +num_classes = 2 + +# ==================== MoE Config ==================== +depths = [2, 2, 18, 2] +MoE_Block_inds = [ + [], + [1], + [1, 3, 5, 7, 9, 11, 13, 15, 17], + [0, 1] +] + +num_shared_experts_config = { + 0: 0, + 1: 0, + 2: 2, + 3: 1 +} + +num_experts = 8 +top_k = 3 +noisy_gating = True + +# ==================== Model Config ==================== +norm_cfg = dict(type='BN', requires_grad=True) +crop_size = (256, 256) + +data_preprocessor = dict( + type='MultiModalDataPreProcessor', + pad_val=0, + seg_pad_val=255, + size=crop_size, +) + +model = dict( + type='MultiModalEncoderDecoderV2', + data_preprocessor=data_preprocessor, + use_moe=True, + use_modal_bias=True, + moe_balance_weight=1.0, + moe_diversity_weight=0.1, + multi_tasks_reweight=None, + decoder_mode='separate', + dataset_names=DATASETS_CONFIG['names'], + + backbone=dict( + type='MultiModalSwinMoE', + modal_configs=ALL_KNOWN_MODALS, + training_modals=TRAINING_MODALS, + + pretrain_img_size=224, + patch_size=4, + # ---- Swin-Base dimensions ---- + embed_dims=128, + depths=depths, + num_heads=[4, 8, 16, 32], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.3, + patch_norm=True, + out_indices=[0, 1, 2, 3], + + # ---- MoE ---- + use_moe=True, + num_experts=num_experts, + num_shared_experts_config=num_shared_experts_config, + top_k=top_k, + noisy_gating=noisy_gating, + MoE_Block_inds=MoE_Block_inds, + use_expert_diversity_loss=True, + + pretrained=None, + ), + + # ---- Swin-Base output channels: [128, 256, 512, 1024] ---- + decode_head=dict( + type='UPerHead', + in_channels=[128, 256, 512, 1024], + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=num_classes, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0 + ) + ), + + auxiliary_head=dict( + type='FCNHead', + in_channels=512, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=num_classes, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=0.4 + ) + ), + + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +# ==================== Dataset Config ==================== +dataset_type = 'MultiModalDeepflood' +data_root = '../floodnet/data/mixed_dataset/' + +train_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='GenerateBoundary', thickness=3), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='MultiModalNormalize'), + dict(type='MultiModalPad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +test_pipeline = [ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='MultiModalNormalize'), + dict(type='LoadAnnotations', reduce_zero_label=False), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), +] + +train_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + filter_modality='sar', + data_prefix=dict( + img_path='train/images', + seg_map_path='train/labels'), + pipeline=train_pipeline), +) + +val_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + filter_modality='sar', + data_prefix=dict( + img_path='val/images', + seg_map_path='val/labels'), + pipeline=test_pipeline), +) + +test_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + filter_modality='sar', + data_prefix=dict( + img_path='test/images', + seg_map_path='test/labels'), + pipeline=test_pipeline), +) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator + +# ==================== Optimizer Config ==================== +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='AdamW', + lr=0.00006, + betas=(0.9, 0.999), + weight_decay=0.01), + paramwise_cfg=dict( + custom_keys={ + 'patch_embed': dict(lr_mult=2.0), + 'modal_patch_embeds': dict(lr_mult=2.0), + 'relative_position_bias_table': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.), + 'gating': dict(lr_mult=2.0), + 'experts': dict(lr_mult=1.5), + 'modal_bias': dict(lr_mult=3.0), + 'shared_experts': dict(lr_mult=2.0), + } + ) +) + +# ==================== Learning Rate Config ==================== +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-6, + by_epoch=True, + begin=0, + end=5), + dict( + type='PolyLR', + eta_min=0.0, + power=1.0, + begin=5, + end=100, + by_epoch=True), +] + +# ==================== Training Loop ==================== +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=100, + val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +# ==================== Hooks ==================== +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=500, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict( + type='CheckpointHook', + by_epoch=True, + interval=10, + max_keep_ckpts=1, + save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook'), +) + +# ==================== Runtime Overrides ==================== +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +log_processor = dict(by_epoch=True) diff --git a/configs/mae/mae-base-Uavflood.py b/configs/mae/mae-base-Uavflood.py new file mode 100644 index 0000000000..828fd929f1 --- /dev/null +++ b/configs/mae/mae-base-Uavflood.py @@ -0,0 +1,53 @@ +_base_ = [ + '../_base_/models/upernet_mae.py', '../_base_/datasets/UAVflood.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_20k.py' +] +crop_size = (256, 256) +data_preprocessor = dict(size=crop_size) +model = dict( + data_preprocessor=data_preprocessor, + backbone=dict( + type='MAE', + img_size=(256, 256), + patch_size=16, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + init_values=1.0, + drop_path_rate=0.1, + out_indices=[3, 5, 7, 11]), + neck=dict(embed_dim=768, rescales=[4, 2, 1, 0.5]), + decode_head=dict( + in_channels=[768, 768, 768, 768], num_classes=2, channels=768), + auxiliary_head=dict(in_channels=768, num_classes=2), + test_cfg=dict(mode='slide', crop_size=(256, 256), stride=(170, 170))) + +optim_wrapper = dict( + _delete_=True, + type='OptimWrapper', + optimizer=dict( + type='AdamW', lr=1e-4, betas=(0.9, 0.999), weight_decay=0.05), + paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65), + constructor='LayerDecayOptimizerConstructor') + +param_scheduler = [ + dict( + type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500), + dict( + type='PolyLR', + eta_min=0.0, + power=1.0, + begin=1500, + end=20000, + by_epoch=False, + ) +] + +# mixed precision +fp16 = dict(loss_scale='dynamic') + +# By default, models are trained on 8 GPUs with 2 images per GPU +train_dataloader = dict(batch_size=8) +val_dataloader = dict(batch_size=8) +test_dataloader = val_dataloader diff --git a/configs/segformer/segformer_mit-b0_8xb1-160k_UAVflood-256x256.py b/configs/segformer/segformer_mit-b0_8xb1-160k_UAVflood-256x256.py new file mode 100644 index 0000000000..0d42f170be --- /dev/null +++ b/configs/segformer/segformer_mit-b0_8xb1-160k_UAVflood-256x256.py @@ -0,0 +1,73 @@ +_base_ = [ + '../_base_/models/segformer_mit-b0.py', + '../_base_/datasets/UAVflood.py', + '../_base_/default_runtime.py' +] +crop_size = (256, 256) +data_preprocessor = dict(size=crop_size) + +model = dict( + data_preprocessor=data_preprocessor, + decode_head=dict( + num_classes=2 # 修改为你需要的类别数 + ) +) + +randomness = dict( + seed=42, + deterministic=False, # 如需完全可复现,设为True +) + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict( + custom_keys={ + 'pos_block': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.) + })) + +param_scheduler = [ + dict( + type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), # warmup前5个epoch + dict( + type='PolyLR', + eta_min=0.0, + power=1.0, + begin=5, + end=100, + by_epoch=True, + ) +] + +# 使用EpochBasedTrainLoop训练100个epoch +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=100, + val_interval=10) # 每10个epoch验证一次 + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +# 修改hooks为基于epoch +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=200, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict( + type='CheckpointHook', + by_epoch=True, + interval=10, # 每10个epoch保存一次 + max_keep_ckpts=3, + save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook')) + +# 设置日志按epoch显示 +log_processor = dict(by_epoch=True) + +train_dataloader = dict(batch_size=8, num_workers=8) +val_dataloader = dict(batch_size=8, num_workers=8) +test_dataloader = val_dataloader diff --git a/configs/swin/Swin-uavflood-256x256.py b/configs/swin/Swin-uavflood-256x256.py new file mode 100644 index 0000000000..2826cfc2e0 --- /dev/null +++ b/configs/swin/Swin-uavflood-256x256.py @@ -0,0 +1,154 @@ +_base_ = [ + '../_base_/datasets/GFflood.py', + '../_base_/default_runtime.py', + '../_base_/models/upernet_swin.py' +] + +# ===== 基础配置 ===== +norm_cfg = dict(type='SyncBN', requires_grad=True) +backbone_norm_cfg = dict(type='LN', requires_grad=True) +crop_size = (256, 256) + +# ===== 数据预处理器 ===== +data_preprocessor = dict( + type='SegDataPreProcessor', + size=crop_size, + #mean=[117.926186, 117.568402, 97.217239], + #std=[53.542876104049824, 50.084170325219176, 50.49331035114637], + mean= [432.02181, 315.92948, 246.468659, 310.61462, 360.267789], + std= [97.73313111900238, 85.78646917160748, 95.78015824658593, 124.84677067613467, 251.73965882246978], + #mean=[0.23651549, 0.31761484, 0.18514981, 0.26901252, -14.57879175, -8.6098158, -14.2907338, -8.33534564], + #std=[0.16280619, 0.20849304, 0.14008107, 0.19767644, 4.07141682, 3.94773216, 4.21006244, 4.05494136], + bgr_to_rgb=False, + pad_val=0, + seg_pad_val=255 +) + +# ===== 模型配置 - Swin-Base ===== +model = dict( + type='EncoderDecoder', + data_preprocessor=data_preprocessor, + pretrained=None, + backbone=dict( + type='SwinTransformer', + pretrain_img_size=224, + embed_dims=128, # Base: 128 + in_channels=5, + patch_size=4, + window_size=7, + mlp_ratio=4, + depths=[2, 2, 18, 2], # Base: [2, 2, 18, 2] + num_heads=[4, 8, 16, 32], # Base: [4, 8, 16, 32] + strides=(4, 2, 2, 2), + out_indices=(0, 1, 2, 3), + qkv_bias=True, + qk_scale=None, + patch_norm=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.3, + use_abs_pos_embed=False, + act_cfg=dict(type='GELU'), + norm_cfg=backbone_norm_cfg + ), + decode_head=dict( + type='UPerHead', + in_channels=[128, 256, 512, 1024], # Base输出通道 + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + ), + auxiliary_head=dict( + type='FCNHead', + in_channels=512, # Base stage2输出 + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) + ), + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(170, 170)) +) + +# ===== 优化器配置 ===== +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='AdamW', + lr=0.00006, + betas=(0.9, 0.999), + weight_decay=0.01 + ), + paramwise_cfg=dict( + custom_keys={ + 'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + } + ) +) + +# ===== 学习率调度器 ===== +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-6, + by_epoch=True, + begin=0, + end=5 # warmup前5个epoch + ), + dict( + type='PolyLR', + eta_min=0.0, + power=1.0, + begin=5, + end=100, + by_epoch=True, + ) +] + +# 使用EpochBasedTrainLoop训练100个epoch +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=100, + val_interval=10) # 每10个epoch验证一次 + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +# 修改hooks为基于epoch +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=1, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict( + type='CheckpointHook', + by_epoch=True, + interval=10, # 每10个epoch保存一次 + max_keep_ckpts=3, + save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook')) + +# 设置日志按epoch显示 +log_processor = dict(by_epoch=True) + +# ===== 数据加载器配置 ===== +train_dataloader = dict(batch_size=8, num_workers=8) # Base模型较大,减小batch_size +val_dataloader = dict(batch_size=8, num_workers=8) +test_dataloader = val_dataloader + +# ===== 随机种子 ===== +randomness = dict(seed=42, deterministic=False) diff --git a/configs/unet/Unet-Uavflood.py b/configs/unet/Unet-Uavflood.py new file mode 100644 index 0000000000..2b7d902c71 --- /dev/null +++ b/configs/unet/Unet-Uavflood.py @@ -0,0 +1,66 @@ +_base_ = [ + '../_base_/models/deeplabv3_unet_s5-d16.py', + '../_base_/datasets/SARflood.py', + '../_base_/default_runtime.py' +] + +crop_size = (256, 256) +data_preprocessor = dict(size=crop_size) +model = dict( + data_preprocessor=data_preprocessor, + decode_head=dict(num_classes=2), + auxiliary_head=dict(num_classes=2), + test_cfg=dict(mode='slide', crop_size=(256, 256), stride=(170, 170))) + +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)) + +param_scheduler = [ + dict( + type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), # warmup前5个epoch + dict( + type='PolyLR', + eta_min=1e-4, + power=0.9, + begin=5, + end=100, + by_epoch=True, + ) +] + +# 使用EpochBasedTrainLoop训练100个epoch +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=100, + val_interval=10) # 每10个epoch验证一次 + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +# 修改hooks为基于epoch +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=1, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict( + type='CheckpointHook', + by_epoch=True, + interval=10, # 每10个epoch保存一次 + max_keep_ckpts=3, + save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook')) + +randomness = dict( + seed=42, + deterministic=False, # 如需完全可复现,设为True +) + +# 设置日志按epoch显示 +log_processor = dict(by_epoch=True) + +train_dataloader = dict(batch_size=8, num_workers=8) +val_dataloader = dict(batch_size=8, num_workers=8) +test_dataloader = val_dataloader \ No newline at end of file diff --git a/configs/vit/vit-Uavflood.py b/configs/vit/vit-Uavflood.py new file mode 100644 index 0000000000..24da97f435 --- /dev/null +++ b/configs/vit/vit-Uavflood.py @@ -0,0 +1,77 @@ +_base_ = [ + '../_base_/models/upernet_vit-b16_ln_mln.py', + '../_base_/datasets/SARflood.py', + '../_base_/default_runtime.py' +] + +crop_size = (256, 256) +data_preprocessor = dict(size=crop_size) +model = dict( + data_preprocessor=data_preprocessor, + backbone=dict( + img_size=(256, 256), + drop_path_rate=0.1, + final_norm=True), + decode_head=dict(num_classes=2), + auxiliary_head=dict(num_classes=2)) + +# AdamW optimizer, no weight decay for position embedding & layer norm +# in backbone +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict( + custom_keys={ + 'pos_embed': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + })) + +param_scheduler = [ + dict( + type='LinearLR', start_factor=1e-6, by_epoch=True, begin=0, end=5), # warmup前5个epoch + dict( + type='PolyLR', + eta_min=0.0, + power=1.0, + begin=5, + end=100, + by_epoch=True, + ) +] + +# 使用EpochBasedTrainLoop训练100个epoch +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=100, + val_interval=10) # 每10个epoch验证一次 + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +# 修改hooks为基于epoch +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=1, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict( + type='CheckpointHook', + by_epoch=True, + interval=10, # 每10个epoch保存一次 + max_keep_ckpts=3, + save_best='mIoU'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook')) + +randomness = dict( + seed=42, + deterministic=False, # 如需完全可复现,设为True +) + +# 设置日志按epoch显示 +log_processor = dict(by_epoch=True) + +train_dataloader = dict(batch_size=8, num_workers=8) +val_dataloader = dict(batch_size=8, num_workers=8) +test_dataloader = val_dataloader \ No newline at end of file diff --git a/cropSAR.py b/cropSAR.py new file mode 100644 index 0000000000..a8f4c788ea --- /dev/null +++ b/cropSAR.py @@ -0,0 +1,124 @@ +import os +import numpy as np +from PIL import Image +import rasterio +from rasterio.transform import from_bounds +import random + + + +# 设置路径 +gt_folder = '/mnt/d//Data/DLdata/urban_sar_floods/urban_sar_floods/03_FU/GT' # 标签文件夹 +sar_folder = '/mnt/d//Data/DLdata/urban_sar_floods/urban_sar_floods/03_FU/SAR' # 影像文件夹 +output_folder = '/mnt/d//Data/DLdata/urban_sar_floods/urban_sar_floods/03_FU/output' # 输出文件夹 + +# 创建输出目录结构 +for subset in ['train', 'val', 'test']: + os.makedirs(os.path.join(output_folder, subset, 'images'), exist_ok=True) + os.makedirs(os.path.join(output_folder, subset, 'labels'), exist_ok=True) + + +# 裁剪函数 - 适配多波段 +def crop_image(img_array, crop_size=256): + crops = [] + if len(img_array.shape) == 3: # 多波段 (bands, height, width) + bands, h, w = img_array.shape + for i in range(0, h, crop_size): + for j in range(0, w, crop_size): + crop = img_array[:, i:i + crop_size, j:j + crop_size] + if crop.shape[1] == crop_size and crop.shape[2] == crop_size: + crops.append(crop) + else: # 单波段 (height, width) + h, w = img_array.shape + for i in range(0, h, crop_size): + for j in range(0, w, crop_size): + crop = img_array[i:i + crop_size, j:j + crop_size] + if crop.shape[0] == crop_size and crop.shape[1] == crop_size: + crops.append(crop) + return crops + + +# 获取所有文件 +gt_files = sorted([f for f in os.listdir(gt_folder) if f.endswith('.tif')]) +sar_files = sorted([f for f in os.listdir(sar_folder) if f.endswith('.tif')]) + +# 生成所有裁剪块 +all_crops = [] +for gt_file, sar_file in zip(gt_files, sar_files): + gt_path = os.path.join(gt_folder, gt_file) + sar_path = os.path.join(sar_folder, sar_file) + + # 读取标签(单波段) + with Image.open(gt_path) as gt_img: + gt_array = np.array(gt_img) + + # 读取SAR影像(多波段) + with rasterio.open(sar_path) as sar_src: + sar_array = sar_src.read() # 读取所有波段 (bands, height, width) + sar_profile = sar_src.profile + + # 裁剪 + gt_crops = crop_image(gt_array) + sar_crops = crop_image(sar_array) + + # 保存裁剪信息 + base_name = os.path.splitext(gt_file)[0] + for idx, (gt_crop, sar_crop) in enumerate(zip(gt_crops, sar_crops)): + all_crops.append({ + 'gt': gt_crop, + 'sar': sar_crop, + 'name': f'{base_name}_{idx}', + 'profile': sar_profile + }) + +print(f'共生成 {len(all_crops)} 个裁剪块') + +# 打乱数据 +random.shuffle(all_crops) + +# 按6:2:2划分 +total = len(all_crops) +train_end = int(total * 0.6) +val_end = int(total * 0.8) + +train_crops = all_crops[:train_end] +val_crops = all_crops[train_end:val_end] +test_crops = all_crops[val_end:] + + +# 保存函数 +def save_crops(crops, subset): + for crop_data in crops: + # 保存多波段SAR影像为TIF + sar_path = os.path.join(output_folder, subset, 'images', f"sar_{crop_data['name']}.tif") + + # 更新profile + profile = crop_data['profile'].copy() + profile.update({ + 'height': 256, + 'width': 256, + 'transform': from_bounds(0, 0, 256, 256, 256, 256) + }) + + with rasterio.open(sar_path, 'w', **profile) as dst: + dst.write(crop_data['sar']) + + # 处理标签:将值为2的像元改为1 + gt_array = crop_data['gt'].copy() + gt_array[gt_array == 2] = 1 + + # 保存标签为PNG + gt_img = Image.fromarray(gt_array.astype(np.uint8)) + gt_path = os.path.join(output_folder, subset, 'labels', f"sar_{crop_data['name']}.png") + gt_img.save(gt_path) + + print(f'{subset}: {len(crops)} 张图像') + + +# 保存所有数据集 +save_crops(train_crops, 'train') +save_crops(val_crops, 'val') +save_crops(test_crops, 'test') + +print(f'\n总共处理了 {total} 张裁剪图像') +print(f'train: {len(train_crops)}, val: {len(val_crops)}, test: {len(test_crops)}') \ No newline at end of file diff --git a/dealNaN.py b/dealNaN.py new file mode 100644 index 0000000000..0d83c44dc7 --- /dev/null +++ b/dealNaN.py @@ -0,0 +1,143 @@ +import os +import numpy as np +import rasterio +from shutil import copy2 +from tqdm import tqdm + + +def replace_nan_with_zero(folder_path, backup=False, backup_suffix='_backup'): + """ + 遍历文件夹,将所有TIF文件中的NaN值替换为0 + + 参数: + folder_path: 要处理的根目录路径 + backup: 是否备份原文件 + backup_suffix: 备份文件后缀 + """ + total_files = 0 + files_with_nan = 0 + files_processed = 0 + tif_extensions = ['.tif', '.tiff', '.TIF', '.TIFF'] + + # 收集所有tif文件 + tif_files = [] + for root, dirs, files in os.walk(folder_path): + for file in files: + if any(file.endswith(ext) for ext in tif_extensions): + tif_files.append(os.path.join(root, file)) + + total_files = len(tif_files) + print(f"找到 {total_files} 个TIF文件") + print("=" * 80) + + # 第一步:检查哪些文件包含NaN + print("\n步骤1: 检查NaN值...") + files_to_process = [] + + try: + iterator = tqdm(tif_files, desc="扫描文件") + except: + iterator = tif_files + + for file_path in iterator: + try: + with rasterio.open(file_path) as src: + has_nan = False + for band_idx in range(1, src.count + 1): + band_data = src.read(band_idx) + if np.isnan(band_data).any(): + has_nan = True + break + + if has_nan: + files_with_nan += 1 + files_to_process.append(file_path) + + except Exception as e: + print(f"\n读取错误: {file_path}") + print(f" 错误: {e}") + + print(f"\n检查完成: {files_with_nan} 个文件包含NaN值") + + if files_with_nan == 0: + print("没有需要处理的文件!") + return + + # 第二步:处理包含NaN的文件 + print(f"\n步骤2: 替换NaN值为0...") + print(f"需要处理 {len(files_to_process)} 个文件") + + if backup: + print(f"将创建备份文件(后缀: {backup_suffix})") + + user_input = input("\n是否继续?(y/n): ") + if user_input.lower() != 'y': + print("操作已取消") + return + + print("\n开始处理...") + + try: + iterator = tqdm(files_to_process, desc="处理文件") + except: + iterator = files_to_process + + for file_path in iterator: + try: + # 备份原文件 + if backup: + backup_path = file_path + backup_suffix + copy2(file_path, backup_path) + + # 读取文件 + with rasterio.open(file_path) as src: + # 保存元数据 + meta = src.meta.copy() + num_bands = src.count + + # 创建临时文件路径 + temp_path = file_path + '.tmp' + + # 写入处理后的数据 + with rasterio.open(temp_path, 'w', **meta) as dst: + for band_idx in range(1, num_bands + 1): + # 读取波段数据 + band_data = src.read(band_idx) + + # 将NaN替换为0 + band_data = np.nan_to_num(band_data, nan=0.0) + + # 写入波段 + dst.write(band_data, band_idx) + + # 替换原文件 + os.replace(temp_path, file_path) + files_processed += 1 + + except Exception as e: + print(f"\n处理错误: {file_path}") + print(f" 错误: {e}") + # 如果临时文件存在,删除它 + temp_path = file_path + '.tmp' + if os.path.exists(temp_path): + os.remove(temp_path) + + # 输出结果 + print("\n" + "=" * 80) + print("处理完成!") + print(f"总文件数: {total_files}") + print(f"包含NaN的文件: {files_with_nan}") + print(f"成功处理的文件: {files_processed}") + + if backup: + print(f"\n原始文件已备份(后缀: {backup_suffix})") + print("如果确认无误,可以删除备份文件") + + +if __name__ == "__main__": + folder_path = "/mnt/d/Project/Code/Floodnet/data/mixed_dataset/val/" + + # 执行替换 + # backup=True 会备份原文件 + # backup=False 直接覆盖原文件(不推荐) + replace_nan_with_zero(folder_path, backup=False, backup_suffix='_backup') \ No newline at end of file diff --git a/delete8X256.py b/delete8X256.py new file mode 100644 index 0000000000..25041c4fcd --- /dev/null +++ b/delete8X256.py @@ -0,0 +1,45 @@ +import os +from PIL import Image + +# 设置基础路径 +base_folder = '../Floodnet/data/mixed_dataset/SAR/' # 修改为你的SAR文件夹路径 +subsets = ['train', 'val', 'test'] + +total_deleted = 0 + +for subset in subsets: + image_folder = os.path.join(base_folder, subset, 'images') + label_folder = os.path.join(base_folder, subset, 'labels') + + if not os.path.exists(image_folder): + print(f'{image_folder} 不存在,跳过') + continue + + deleted_count = 0 + + for filename in os.listdir(image_folder): + if filename.endswith('.tif'): + img_path = os.path.join(image_folder, filename) + + # 读取图像尺寸 + with Image.open(img_path) as img: + width, height = img.size + + # 如果不是256x256,删除图像和对应标签 + if width != 256 or height != 256: + # 删除图像 + os.remove(img_path) + + # 删除对应标签 + label_filename = filename.replace('.tif', '.png') + label_path = os.path.join(label_folder, label_filename) + if os.path.exists(label_path): + os.remove(label_path) + + deleted_count += 1 + print(f'[{subset}] 已删除: {filename} (尺寸: {width}x{height})') + + print(f'{subset} 删除了 {deleted_count} 对文件\n') + total_deleted += deleted_count + +print(f'总共删除了 {total_deleted} 对文件') \ No newline at end of file diff --git a/docs/experiment_plan.md b/docs/experiment_plan.md new file mode 100644 index 0000000000..0364909821 --- /dev/null +++ b/docs/experiment_plan.md @@ -0,0 +1,188 @@ +# FloodNet Paper Experiment Plan +## Multi-Modal Flood Segmentation via Mixture-of-Experts with Modal-Aware Routing + +--- + +## 1. Model Overview (Final Model) + +**Config**: `multimodal_floodnet_sar_boost_swinbase_moe_config.py` + +| Component | Configuration | +|-----------|--------------| +| Backbone | Swin-Base (embed_dims=128, depths=[2,2,18,2]) | +| Patch Embed | ModalSpecificStem (独立卷积 per modality) | +| MoE | 8 experts, top_k=3, noisy gating | +| Shared Experts | Stage 2: 2, Stage 3: 1 | +| Modal Bias | Learnable per-modal routing bias | +| Diversity Loss | weight=0.1; Balance Loss weight=1.0 | +| Decoder | Separate UPerHead per modality | +| Sampling | SAR Boost (6:5:5) | +| Modalities | SAR (8ch), RGB (3ch), GaoFen (5ch) | + +--- + +## 2. Experiment Design + +### Table 2: Component Ablation Study (组件消融) + +**目的**: 隔离每个组件的贡献,验证设计合理性。 + +**评估**: 每个变体训练后,分别在 SAR / RGB / GF 三个模态上独立测试 mIoU。 + +| Row | 配置 | 变更 | SAR | RGB | GF | Avg | +|-----|------|------|-----|-----|-----|-----| +| (a) | **Full Model** | — | - | - | - | - | +| (b) | w/o MoE | 标准 FFN, 无专家路由 | - | - | - | - | +| (c) | w/o ModalSpecificStem | 统一卷积+零填充 替代 模态独立卷积 | - | - | - | - | +| (d) | w/o Modal Bias | MoE gating 无模态偏置 | - | - | - | - | +| (e) | w/o Shared Experts | 所有专家均为路由专家,无共享 | - | - | - | - | +| (f) | w/o Separate Decoder | 一个 UPerHead 共享所有模态 | - | - | - | - | + +**Config Files**: +- (a) `multimodal_floodnet_sar_boost_swinbase_moe_config.py` +- (b) `ablations/ablation_no_moe.py` +- (c) `ablations/ablation_no_modal_specific_stem.py` +- (d) `ablations/ablation_no_modal_bias.py` +- (e) `ablations/ablation_no_shared_experts.py` +- (f) `ablations/ablation_shared_decoder.py` + +**运行**: +```bash +bash scripts/run_all_experiments.sh table2 +``` + +**讨论要点**: + +**(b) w/o MoE**: 预期最大下降。标准 FFN 无法适配不同模态的特征分布差异。所有 token 共享相同权重,无论来自 SAR 后向散射还是 RGB 反射率。 + +**(c) w/o ModalSpecificStem**: 统一 patch embedding 将所有模态零填充到 max_channels(8) 再卷积,丢失了模态特异性的早期特征提取。SAR 的 8 通道数据被原样保留,但 RGB(3ch) 和 GF(5ch) 被大量零填充,引入噪声。 + +**(d) w/o Modal Bias**: 去除模态路由偏置后,gating 网络对所有模态一视同仁。模态偏置允许模型学习到某些专家专门处理 SAR 的散斑噪声和后向散射强度 vs RGB 的颜色纹理。 + +**(e) w/o Shared Experts**: 共享专家捕获模态不变特征(如水体空间结构、边界形态)。移除后,路由专家需冗余学习通用特征,降低专门化能力。 + +**(f) w/o Separate Decoder**: 单一解码器须在不同模态特征分布间妥协。独立解码器允许每个模态有定制化的上采样和分类边界。 + +--- + +### Table 3: MoE Hyperparameter Study (MoE超参数研究) + +**目的**: 研究专家数量和 top-k 路由对性能的影响。 + +**评估**: 每个变体训练后,分别在 SAR / RGB / GF 三个模态上独立测试 mIoU。 + +| num_experts | top_k | SAR | RGB | GF | Avg | +|-------------|-------|-----|-----|-----|-----| +| 6 | 1 | - | - | - | - | +| 6 | 2 | - | - | - | - | +| 6 | 3 | - | - | - | - | +| 8 | 1 | - | - | - | - | +| 8 | 2 | - | - | - | - | +| **8** | **3** | **-** | **-** | **-** | **-** | + +**Config Files**: +- `ablations/ablation_e6_k1.py` +- `ablations/ablation_e6_k2.py` +- `ablations/ablation_e6_k3.py` +- `ablations/ablation_e8_k1.py` +- `ablations/ablation_e8_k2.py` +- Full Model (E=8, K=3) + +**运行**: +```bash +bash scripts/run_all_experiments.sh table3 +``` + +**讨论要点**: +- **top_k=1**: 过于稀疏,每个 token 仅使用 1 个专家,限制了多专家集成效果 +- **top_k=2**: 中等稀疏度,基本满足 3 模态路由需求 +- **top_k=3 (ours)**: 允许 token 同时受益于多个专门化专家 +- **6 experts**: 对 3 种模态来说容量偏小,平均每模态仅 2 个专家 +- **8 experts (ours)**: 最佳平衡点,每模态约 2-3 个专家并允许跨模态共享 + +--- + +### Table 4: Single-Modal vs Multi-Modal Training (单模态 vs 多模态) + +**目的**: 证明多模态联合训练通过跨模态知识迁移提升了每个模态的性能。 + +**评估**: 单模态训练只测试本模态;多模态训练测试所有三个模态。 + +| 训练数据 | 测试模态 | mIoU | +|----------|----------|------| +| SAR-only | SAR | - | +| RGB-only | RGB | - | +| GF-only | GF | - | +| Multi-modal (Ours) | SAR | - | +| Multi-modal (Ours) | RGB | - | +| Multi-modal (Ours) | GF | - | + +**Config Files**: +- SAR-only → `multimodal_floodnet_sar_only_swinbase_moe_config.py` +- RGB-only → `ablations/ablation_rgb_only.py` +- GF-only → `ablations/ablation_gf_only.py` +- Multi-modal → Full Model + +**运行**: +```bash +bash scripts/run_all_experiments.sh table4 +``` + +**讨论要点**: +- 多模态训练预期在**每个**模态上都优于单模态训练 +- SAR 受益最大:有限的 SAR 数据通过 RGB/GF 知识迁移得到增强 +- MoE 架构防止负迁移:模态专用专家避免了不同传感器类型间的干扰 +- 共享专家捕获通用洪水模式(水体几何、边界结构),在所有模态间迁移 + +--- + +## 3. Running Experiments + +```bash +# 分表运行(推荐) +bash scripts/run_all_experiments.sh table2 # 组件消融 (6 实验 × 3 模态测试) +bash scripts/run_all_experiments.sh table3 # MoE 超参数 (6 实验 × 3 模态测试) +bash scripts/run_all_experiments.sh table4 # 单/多模态 (4 实验) + +# 全部运行 +bash scripts/run_all_experiments.sh all + +# 指定 GPU +GPU_IDS=0,1 bash scripts/run_all_experiments.sh table2 +``` + +**结果目录结构**: +``` +work_dirs/paper_experiments/ +├── results_summary.txt # 汇总日志 +├── table2/ +│ ├── full_model/ +│ │ ├── best_mIoU_*.pth +│ │ ├── test_sar/test_log.txt +│ │ ├── test_rgb/test_log.txt +│ │ └── test_GF/test_log.txt +│ ├── no_moe/ +│ ├── no_modal_specific_stem/ +│ ├── no_modal_bias/ +│ ├── no_shared_experts/ +│ └── shared_decoder/ +├── table3/ +│ ├── e6_k1/ +│ ├── e6_k2/ +│ ├── e6_k3/ +│ ├── e8_k1/ +│ └── e8_k2/ +└── table4/ + ├── sar_only/ + │ └── test_sar/test_log.txt + ├── rgb_only/ + │ └── test_rgb/test_log.txt + ├── gf_only/ + │ └── test_GF/test_log.txt + └── (multi_modal reuses table2/full_model) +``` + +**实验总量**: 15 个训练 + 48 次测试 +- Table 2: 6 训练 × 3 模态测试 = 6 + 18 +- Table 3: 5 训练 × 3 模态测试 + 1 复用 = 5 + 18 +- Table 4: 3 训练 × 1 模态测试 + 1 复用 × 3 测试 = 3 + 6 diff --git a/docs/method_section_details.md b/docs/method_section_details.md new file mode 100644 index 0000000000..b97fb905d6 --- /dev/null +++ b/docs/method_section_details.md @@ -0,0 +1,370 @@ +# 论文 Method 部分 — 技术内容详细提纲与草稿素材 + +> 以下内容基于代码实现逐一提取,所有数字、公式、维度均与代码一致。 +> 供写作时参考,可直接改写为论文正文。 + +--- + +## 3.1 Overall Framework(整体框架) + +### 需要写的内容 + +一段话概述整体流水线,配合 Fig.1(架构图)。 + +### 技术事实 + +整体流水线为: + +``` +多模态输入(SAR/RGB/GF,通道数各异) + → Modal-Specific Stem(模态独立 Patch Embedding) + → 4-Stage Swin Transformer(选定 block 中 FFN 替换为 Sparse MoE) + → Modality-Separate UPerNet Decoder(每种模态独立解码头) + → 二分类分割输出(flood / non-flood) +``` + +模型类:`MultiModalEncoderDecoderV2`。 + +**输入处理**:`MultiModalDataPreProcessor` 不对多模态输入做通道堆叠(因为各模态通道数不同),而是保持为 `List[Tensor]` 传入骨干网络。每个 tensor 仅做空间维度 padding 到 crop_size (256×256)。 + +**训练采样**:采用 `FixedRatioModalSampler` 按固定比例(SAR:RGB:GF = 6:5:5)组成每个 batch,以 GF 为参考模态确定 epoch 长度。这确保 SAR 数据(样本量最少)在每个 batch 中占 37.5%(6/16),高于其在数据集中的自然比例,缓解模态不均衡问题。 + +--- + +## 3.2 Modal-Specific Stem(模态特定嵌入层) + +### 需要写的内容 + +解释为什么不能直接用标准 Patch Embedding,以及你的解决方案。 + +### 技术事实 + +**问题**:三种模态通道数不同(SAR=8, RGB=3, GF=5),标准 Swin Transformer 的 Patch Embedding 是一个固定输入通道的 Conv2d,无法处理可变通道输入。 + +**朴素方案(UnifiedPatchEmbed,用于消融对比)**:将所有模态零填充到最大通道数(8),然后共享一个 `Conv2d(8, embed_dims, kernel_size=4, stride=4)`。缺点是 RGB(3ch→8ch)和 GF(5ch→8ch)引入大量零值,稀释有效信息。 + +**本文方案(ModalSpecificPatchEmbed)**:为每种模态配置独立的投影卷积: + +``` +对于模态 m(通道数为 C_m): + Conv2d(C_m, embed_dims, kernel_size=patch_size, stride=patch_size) + LayerNorm(embed_dims) +``` + +具体实例化: +- SAR: `Conv2d(8, 128, kernel_size=4, stride=4)` + `LayerNorm(128)` +- RGB: `Conv2d(3, 128, kernel_size=4, stride=4)` + `LayerNorm(128)` +- GF: `Conv2d(5, 128, kernel_size=4, stride=4)` + `LayerNorm(128)` + +输入图像经 patch embedding 后,空间分辨率缩小为 H/4 × W/4,特征维度统一为 embed_dims=128。之后所有模态在同一特征空间中处理,共享 Transformer 的注意力层。 + +### 公式 + +$$ +\mathbf{z}_0^{(m)} = \text{LN}\left(\text{Conv}_{m}(\mathbf{x}^{(m)})\right), \quad \mathbf{z}_0^{(m)} \in \mathbb{R}^{\frac{HW}{P^2} \times D} +$$ + +其中 $\text{Conv}_{m}: \mathbb{R}^{C_m \times H \times W} \to \mathbb{R}^{D \times \frac{H}{P} \times \frac{W}{P}}$ 为模态 $m$ 专属的卷积投影,$P=4$ 为 patch 大小,$D=128$ 为嵌入维度。 + +--- + +## 3.3 Swin Transformer Backbone with Sparse MoE(骨干网络) + +### 需要写的内容 + +先简要介绍 Swin Transformer 基本结构(审稿人可能不熟悉),然后重点描述你在哪些位置插入了 MoE,以及为什么是稀疏放置。 + +### 技术事实 + +**Swin-Base 配置**: + +| 参数 | 值 | +|------|-----| +| embed_dims | 128 | +| depths | [2, 2, 18, 2] | +| num_heads | [4, 8, 16, 32] | +| window_size | 7 | +| mlp_ratio | 4.0 | +| drop_path_rate | 0.3 | +| 各 stage 输出通道 | [128, 256, 512, 1024] | + +**每个 Swin Block 的结构**(标准部分,简要提及即可): + +``` +x = x + DropPath(W-MSA(LN(x))) // 窗口多头自注意力 +x = x + DropPath(FFN(LN(x))) // 前馈网络(在选定 block 中替换为 MoE) +``` + +偶数 block 用 W-MSA,奇数 block 用 SW-MSA(shifted window)。 + +**MoE 替换位置(Sparse Placement)**: + +并非所有 block 的 FFN 都替换为 MoE,而是选择性稀疏放置: + +| Stage | 总 block 数 | MoE block 索引 | MoE 数量 | 说明 | +|-------|------------|---------------|---------|------| +| 0 | 2 | [] | 0 | 浅层特征通用性强,无需 MoE | +| 1 | 2 | [1] | 1 | 仅最后一个 block | +| 2 | 18 | [1,3,5,7,9,11,13,15,17] | 9 | 每隔一个 block 放置(交替) | +| 3 | 2 | [0, 1] | 2 | 全部 block | + +**共享专家配置**: + +| Stage | 共享专家数 | +|-------|-----------| +| 0 | 0 | +| 1 | 0 | +| 2 | 2 | +| 3 | 1 | + +Stage 2 是主力计算阶段(18 个 block),交替放置 MoE 允许普通 FFN block 做特征整合,MoE block 做模态专门化。 + +**Stage 间下采样**:使用 PatchMerging,将空间分辨率减半、通道数加倍。 + +--- + +## 3.4 Sparse Mixture-of-Experts Module(稀疏 MoE 模块) + +### 需要写的内容 + +这是方法的核心。需要分三个子部分详细描述:(1) 余弦相似度门控 (2) 模态偏置 (3) 稀疏路由与专家计算。 + +### 3.4.1 Cosine Similarity Gating(余弦相似度门控) + +**技术事实**: + +门控网络 `CosineTopKGate` 的完整计算流程: + +1. **特征池化**:将输入 $\mathbf{X} \in \mathbb{R}^{B \times N \times C}$ 沿 token 维度均值池化得到 $\bar{\mathbf{x}} \in \mathbb{R}^{B \times C}$ + +2. **余弦相似度计算**: + $$\mathbf{l} = \frac{\mathbf{W}_p \bar{\mathbf{x}}}{\|\mathbf{W}_p \bar{\mathbf{x}}\|_2} \cdot \frac{\mathbf{S}}{\|\mathbf{S}\|_2}$$ + 其中 $\mathbf{W}_p \in \mathbb{R}^{C \times d}$ 为投影矩阵(`cosine_projector`),$d = \min(C/2, 256)$;$\mathbf{S} \in \mathbb{R}^{d \times E}$ 为相似度矩阵(`sim_matrix`),$E$ 为专家数。 + +3. **温度缩放**: + $$\mathbf{l} = \mathbf{l} \cdot \exp\left(\text{clamp}(\tau, \max=\ln 100)\right)$$ + 其中 $\tau$ 为可学习温度参数,初始化为 $\ln(1/0.5) \approx 0.693$。 + +4. **模态偏置注入**(见 3.4.2) + +5. **噪声注入**(仅训练阶段): + $$\tilde{\mathbf{l}} = \mathbf{l} + \epsilon \cdot \text{Softplus}(\bar{\mathbf{x}} \mathbf{W}_{\text{noise}}), \quad \epsilon \sim \mathcal{N}(0, 1)$$ + 其中 $\mathbf{W}_{\text{noise}} \in \mathbb{R}^{C \times E}$。噪声鼓励探索,防止路由僵化。 + +6. **Top-K 选择与归一化**: + $$\text{TopK}(\tilde{\mathbf{l}}, k) \to (\mathbf{l}_{\text{top}}, \mathbf{I}_{\text{top}})$$ + $$\mathbf{g}_{\text{top}} = \text{Softmax}(\mathbf{l}_{\text{top}})$$ + 将 $\mathbf{g}_{\text{top}}$ scatter 回完整的 $E$ 维向量,未被选中的专家权重为 0。 + +本文配置:$E=8$, $k=3$。 + +### 公式(汇总版,适合论文) + +$$ +\mathbf{g} = \text{TopK-Softmax}\Big(\underbrace{\text{CosSim}(\mathbf{W}_p \bar{\mathbf{x}},\ \mathbf{S})}_{\text{content-based routing}} \cdot e^{\tau} + \underbrace{\mathbf{b}_{m}}_{\text{modal bias}},\ k\Big) +$$ + +### 3.4.2 Learnable Modal Bias(可学习模态偏置) + +**技术事实**: + +参数:$\mathbf{B}_{\text{modal}} \in \mathbb{R}^{M \times E}$,其中 $M=3$(三种模态),$E=8$(专家数)。初始化为零矩阵。 + +应用方式:对于模态 $m$ 的输入样本,在门控 logits 上叠加偏置: +$$\mathbf{l}_i = \mathbf{l}_i + \mathbf{B}_{\text{modal}}[m, :]$$ + +每个 MoE 层有独立的 modal_bias 参数(不共享)。 + +**作用**:即使两个不同模态的 token 在特征空间中相近(余弦相似度门控给出相似路由),modal bias 也能将它们导向不同的专家。这为路由网络提供了显式的模态先验。 + +**学习率**:modal_bias 使用 3× 的学习率倍率(`lr_mult=3.0`),加速模态偏好的学习。 + +### 3.4.3 Expert Computation & Shared Experts(专家计算与共享专家) + +**路由专家(Routed Experts)**: + +每个 MoE 层包含 $E=8$ 个结构相同但参数独立的 FFN: +$$\text{FFN}_i(\mathbf{x}) = \mathbf{W}_2^{(i)} \cdot \text{GELU}(\mathbf{W}_1^{(i)} \mathbf{x}) + \mathbf{bias}_2^{(i)}$$ +其中 $\mathbf{W}_1^{(i)} \in \mathbb{R}^{C \times 4C}$, $\mathbf{W}_2^{(i)} \in \mathbb{R}^{4C \times C}$(mlp_ratio=4)。 + +**稀疏派发(Sparse Dispatch)**: + +通过 `SparseDispatcher` 实现,仅将每个样本发送到被选中的 top-k 专家: +1. 根据 gate 矩阵 $\mathbf{G} \in \mathbb{R}^{B \times E}$ 的非零位置确定路由 +2. 按专家分组派发输入 +3. 各专家独立计算 +4. 按 gate 权重加权合并:$\mathbf{y}_{\text{routed}} = \sum_{i \in \text{TopK}} g_i \cdot \text{FFN}_i(\mathbf{x})$ + +**共享专家(Shared Experts)**: + +共享专家不经过门控路由,对所有输入无条件执行: +$$\mathbf{y}_{\text{shared}} = \text{FFN}_{\text{shared}}(\mathbf{x})$$ + +当 Stage 2 配置 2 个共享专家时,其 FFN hidden_dim 为 $2 \times 4C$(等效于将两个专家的 FFN 拼接为一个更宽的 FFN)。 + +**最终输出**: +$$\mathbf{y} = \mathbf{y}_{\text{routed}} + \mathbf{y}_{\text{shared}}$$ + +--- + +## 3.5 Modality-Separate Decoder(模态独立解码头) + +### 需要写的内容 + +解释为什么不用共享解码头,以及独立解码头的实现方式。 + +### 技术事实 + +**结构**:`decoder_mode='separate'` 时,为每种模态(sar/rgb/GF)各创建一个独立的 UPerHead + FCNHead(辅助头)。 + +**主解码头(UPerHead)配置**: + +| 参数 | 值 | +|------|-----| +| in_channels | [128, 256, 512, 1024](对应 4 个 stage 输出) | +| pool_scales | (1, 2, 3, 6)(PPM 模块) | +| channels | 512 | +| dropout_ratio | 0.1 | +| num_classes | 2(flood / non-flood) | +| loss | CrossEntropyLoss (weight=1.0) | + +**辅助解码头(FCNHead)配置**: + +| 参数 | 值 | +|------|-----| +| in_channels | 512(取 Stage 2 输出) | +| channels | 256 | +| num_convs | 1 | +| loss | CrossEntropyLoss (weight=0.4) | + +**路由逻辑**:训练时,根据每个样本 metainfo 中的 `dataset_name` 字段,将同一 batch 内的样本按模态分组,分别送入对应的解码头计算 loss。推理时,根据输入样本的模态类型选择对应的解码头。 + +**与共享解码头对比**(消融实验 Table 2(f)):共享模式下所有模态使用同一组 UPerHead 参数。 + +--- + +## 3.6 Training Objectives(训练目标) + +### 需要写的内容 + +详细描述总损失函数的组成。 + +### 技术事实 + +总损失由三部分组成: + +$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{seg}} + \lambda_{\text{bal}} \mathcal{L}_{\text{balance}} + \lambda_{\text{div}} \mathcal{L}_{\text{diversity}}$$ + +**1. 分割损失 $\mathcal{L}_{\text{seg}}$**: + +各模态独立计算 CrossEntropy loss,主头权重 1.0 + 辅助头权重 0.4: +$$\mathcal{L}_{\text{seg}} = \sum_{m \in \{sar, rgb, GF\}} \left(\mathcal{L}_{\text{CE}}^{m} + 0.4 \cdot \mathcal{L}_{\text{CE,aux}}^{m}\right)$$ + +**2. 负载均衡损失 $\mathcal{L}_{\text{balance}}$**($\lambda_{\text{bal}}=1.0$): + +鼓励专家负载均匀,防止少数专家被过度使用: +$$\mathcal{L}_{\text{balance}} = \text{CV}^2(\text{importance}) + \text{CV}^2(\text{load})$$ + +其中: +- $\text{importance}_i = \sum_{b=1}^{B} g_{b,i}$(专家 $i$ 的总 gate 权重) +- $\text{load}_i = \sum_{b=1}^{B} \mathbb{1}[g_{b,i} > 0]$(专家 $i$ 被选中的次数) +- $\text{CV}^2(\mathbf{x}) = \frac{\text{Var}(\mathbf{x})}{(\text{Mean}(\mathbf{x}))^2 + \epsilon}$(变异系数的平方) + +该损失在所有 MoE 层上求平均。 + +**3. 专家多样性损失 $\mathcal{L}_{\text{diversity}}$**($\lambda_{\text{div}}=0.1$): + +防止不同专家学到相似的表征(专家坍缩): + +$$\mathcal{L}_{\text{diversity}} = \text{ReLU}\left(\frac{1}{E(E-1)/2} \sum_{i None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py index f8ad750d76..d9c562805e 100644 --- a/mmseg/datasets/__init__.py +++ b/mmseg/datasets/__init__.py @@ -26,6 +26,10 @@ from .refuge import REFUGEDataset from .stare import STAREDataset from .synapse import SynapseDataset +from .UAVflood import UAVfloodDataset +from .multimodal_deepflood import MultiModalDeepflood +from .sen1floods11 import Sen1Floods11Dataset +from .fixed_ratio_modal_sampler import FixedRatioModalSampler # yapf: disable from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad, BioMedical3DRandomCrop, BioMedical3DRandomFlip, @@ -61,5 +65,6 @@ 'MapillaryDataset_v2', 'Albu', 'LEVIRCDDataset', 'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile', 'ConcatCDInput', 'BaseCDDataset', 'DSDLSegDataset', 'BDD100KDataset', - 'NYUDataset', 'HSIDrive20Dataset' + 'NYUDataset', 'HSIDrive20Dataset', 'UAVfloodDataset', + 'MultiModalDeepflood', 'Sen1Floods11Dataset', 'FixedRatioModalSampler' ] diff --git a/mmseg/datasets/fixed_ratio_modal_sampler.py b/mmseg/datasets/fixed_ratio_modal_sampler.py new file mode 100644 index 0000000000..9f1ab3c61f --- /dev/null +++ b/mmseg/datasets/fixed_ratio_modal_sampler.py @@ -0,0 +1,193 @@ +""" +Fixed Ratio Modal Sampler for Multi-Dataset Training - MMSeg 1.x Version + +Registered as DATA_SAMPLERS for mmengine compatibility. +""" +from typing import Dict, Iterator, List, Optional, Sequence, Sized, Union + +import torch +from torch.utils.data import Sampler + +from mmseg.registry import DATA_SAMPLERS + + +@DATA_SAMPLERS.register_module() +class FixedRatioModalSampler(Sampler): + """Fixed ratio modal sampler. + + Ensures each batch has a fixed modal composition. + + Args: + dataset: Dataset object, must have data_list with modal_type. + modal_ratios: Modal sampling ratios. + modal_order: Modal ordering. + reference_modal: Reference modal for epoch length calculation. + seed: Random seed. + batch_size: Batch size (replaces samples_per_gpu). + """ + + def __init__( + self, + dataset: Sized, + modal_ratios: Optional[Union[Sequence[int], Dict[str, int]]] = None, + modal_order: Optional[Sequence[str]] = None, + reference_modal: Optional[str] = None, + seed: Optional[int] = None, + batch_size: int = 1, + **kwargs, + ): + if modal_ratios is None: + raise ValueError( + "modal_ratios must be provided for FixedRatioModalSampler.") + + self.dataset = dataset + self.batch_size = batch_size + self.seed = 0 if seed is None else seed + self.epoch = 0 + self.reference_modal = reference_modal + + if isinstance(modal_ratios, dict): + if modal_order is None: + modal_order = list(modal_ratios.keys()) + self.modal_order = list(modal_order) + self.modal_ratios = [ + modal_ratios[modal] for modal in self.modal_order] + else: + if modal_order is None: + raise ValueError( + "modal_order must be provided when " + "modal_ratios is a list.") + self.modal_order = list(modal_order) + self.modal_ratios = list(modal_ratios) + + self._validate_ratios() + self.modal_indices = self._group_by_modal() + self.num_samples = self._calculate_num_samples() + + self._print_statistics() + + def _validate_ratios(self): + if len(self.modal_order) != len(self.modal_ratios): + raise ValueError( + "modal_order and modal_ratios must have the same length.") + if any(ratio <= 0 for ratio in self.modal_ratios): + raise ValueError("All modal ratios must be positive.") + if sum(self.modal_ratios) > self.batch_size: + raise ValueError( + "Sum of modal ratios must be <= batch_size " + f"(got {sum(self.modal_ratios)} > {self.batch_size}).") + if self.batch_size % sum(self.modal_ratios) != 0: + raise ValueError( + "batch_size must be divisible by sum(modal_ratios) " + f"(got {self.batch_size} % {sum(self.modal_ratios)} != 0).") + + def _group_by_modal(self) -> Dict[str, List[int]]: + modal_indices = {modal: [] for modal in self.modal_order} + # In 1.x, dataset.data_list contains the data info dicts + for idx in range(len(self.dataset)): + data_info = self.dataset.get_data_info(idx) + modal_type = data_info.get('modal_type', 'unknown') + if modal_type in modal_indices: + modal_indices[modal_type].append(idx) + return modal_indices + + def _calculate_num_samples(self) -> int: + total_ratio = sum(self.modal_ratios) + batch_repeats = self.batch_size // total_ratio + + if self.reference_modal is None: + total_original = sum( + len(v) for v in self.modal_indices.values()) + full_batches = total_original // self.batch_size + return full_batches * self.batch_size + + if self.reference_modal not in self.modal_order: + raise ValueError( + f"reference_modal '{self.reference_modal}' " + f"is not in modal_order.") + + reference_count = len(self.modal_indices[self.reference_modal]) + if reference_count == 0: + raise ValueError( + f"No samples found for reference_modal " + f"'{self.reference_modal}'.") + + reference_ratio = self.modal_ratios[ + self.modal_order.index(self.reference_modal)] + total_groups = reference_count // reference_ratio + full_batches = total_groups // batch_repeats + return full_batches * self.batch_size + + def _print_statistics(self): + print("\n" + "=" * 60) + print("Fixed Ratio Modal Sampler Statistics") + print("=" * 60) + print(f"Batch size: {self.batch_size}") + print("Modal Ratios:") + for modal, ratio in zip(self.modal_order, self.modal_ratios): + print(f" {modal}: {ratio}") + print("Modal Distribution (Original):") + total_original = sum(len(v) for v in self.modal_indices.values()) + for modal in self.modal_order: + count = len(self.modal_indices[modal]) + percentage = ( + (count / total_original * 100) + if total_original > 0 else 0.0) + print(f" {modal}: {count:5d} samples ({percentage:5.2f}%)") + if self.reference_modal is not None: + reference_count = len( + self.modal_indices[self.reference_modal]) + reference_ratio = self.modal_ratios[ + self.modal_order.index(self.reference_modal)] + print( + f"Reference modal: {self.reference_modal} " + f"(count={reference_count}, ratio={reference_ratio})") + print(f"Total samples per epoch: {self.num_samples}") + print(f"Iterations per epoch: {len(self)}") + print("=" * 60 + "\n") + + def __iter__(self) -> Iterator[int]: + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + + per_modal_pools = {} + per_modal_pos = {} + for modal, indices in self.modal_indices.items(): + if len(indices) == 0: + raise ValueError( + f"No samples found for modal '{modal}'.") + perm = torch.randperm(len(indices), generator=g) + per_modal_pools[modal] = [indices[i] for i in perm.tolist()] + per_modal_pos[modal] = 0 + + indices = [] + batch_repeats = self.batch_size // sum(self.modal_ratios) + num_batches = len(self) + + for _ in range(num_batches): + batch_indices = [] + for _ in range(batch_repeats): + for modal, ratio in zip( + self.modal_order, self.modal_ratios): + pool = per_modal_pools[modal] + pos = per_modal_pos[modal] + for _ in range(ratio): + if pos >= len(pool): + perm = torch.randperm( + len(pool), generator=g) + pool = [pool[i] for i in perm.tolist()] + per_modal_pools[modal] = pool + pos = 0 + batch_indices.append(pool[pos]) + pos += 1 + per_modal_pos[modal] = pos + + indices.extend(batch_indices) + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples // self.batch_size + + def set_epoch(self, epoch: int) -> None: + self.epoch = epoch diff --git a/mmseg/datasets/multimodal_deepflood.py b/mmseg/datasets/multimodal_deepflood.py new file mode 100644 index 0000000000..520be4e03a --- /dev/null +++ b/mmseg/datasets/multimodal_deepflood.py @@ -0,0 +1,185 @@ +""" +Multi-Modal Deepflood Dataset - MMSeg 1.x Version + +Key changes from 0.x: +- Base class: CustomDataset -> BaseSegDataset +- Registry: from builder import DATASETS -> from mmseg.registry import DATASETS +- load_annotations() -> load_data_list() +- img_dir/ann_dir -> data_prefix dict +- img_infos -> data_list +- evaluate() -> IoUMetric (external evaluator) +""" +import os.path as osp +from collections import OrderedDict +from typing import List + +import mmengine +import mmengine.fileio as fileio +import numpy as np +from mmengine.logging import print_log + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class MultiModalDeepflood(BaseSegDataset): + """Multi-Modal Deepflood dataset for flood segmentation. + + Supports SAR, RGB, and GaoFen (GF) modalities with automatic + modality identification from filenames. + """ + + METAINFO = dict( + classes=('Background', 'Flood'), + palette=[[0, 0, 0], [255, 0, 0]] + ) + + # Modal configuration + MODAL_CONFIGS = { + 'sar': {'channels': 8, 'pattern': 'sar'}, + 'rgb': {'channels': 3, 'pattern': 'rgb'}, + 'GF': {'channels': 5, 'pattern': 'GF'}, + } + + MODAL_TO_DATASET = { + 'sar': {'dataset_source': 0, 'dataset_name': 'sar'}, + 'rgb': {'dataset_source': 1, 'dataset_name': 'rgb'}, + 'GF': {'dataset_source': 2, 'dataset_name': 'GF'}, + } + + def __init__(self, filter_modality=None, **kwargs): + # Set default suffixes for flood data + kwargs.setdefault('img_suffix', '.tif') + kwargs.setdefault('seg_map_suffix', '.png') + kwargs.setdefault('reduce_zero_label', False) + # filter_modality: str or None, e.g. 'sar', 'rgb', 'GF' + # When set, only samples of this modality are kept. + self.filter_modality = filter_modality + super().__init__(**kwargs) + + def load_data_list(self) -> List[dict]: + """Load annotation from directory. + + Returns: + list[dict]: All data info of dataset, each dict contains: + - img_path (str) + - seg_map_path (str) + - modal_type (str) + - actual_channels (int) + - dataset_source (int) + - dataset_name (str) + - label_map (dict or None) + - reduce_zero_label (bool) + - seg_fields (list) + """ + data_list = [] + img_dir = self.data_prefix.get('img_path', None) + ann_dir = self.data_prefix.get('seg_map_path', None) + + if not osp.isdir(self.ann_file) and self.ann_file: + # Load from annotation file + assert osp.isfile(self.ann_file), \ + f'Failed to load `ann_file` {self.ann_file}' + lines = mmengine.list_from_file( + self.ann_file, backend_args=self.backend_args) + for line in lines: + img_name = line.strip() + data_info = dict( + img_path=osp.join(img_dir, + img_name + self.img_suffix)) + + # Identify modality + modal_info = self._identify_modality(img_name) + data_info.update(modal_info) + + if ann_dir is not None: + seg_map = img_name + self.seg_map_suffix + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + + data_info['label_map'] = self.label_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + else: + # Scan directory + for img in fileio.list_dir_or_file( + dir_path=img_dir, + list_dir=False, + suffix=self.img_suffix, + recursive=True, + backend_args=self.backend_args): + data_info = dict(img_path=osp.join(img_dir, img)) + + # Identify modality + modal_info = self._identify_modality(img) + data_info.update(modal_info) + + if ann_dir is not None: + seg_map = img[:-len(self.img_suffix)] + self.seg_map_suffix + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + + data_info['label_map'] = self.label_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + + data_list = sorted(data_list, key=lambda x: x['img_path']) + + # Filter by modality if specified + if self.filter_modality is not None: + data_list = [ + d for d in data_list + if d['modal_type'] == self.filter_modality + ] + print_log( + f'Filtered to modality "{self.filter_modality}": ' + f'{len(data_list)} images', + logger='current') + + print_log(f'Loaded {len(data_list)} images', logger='current') + self._print_modal_statistics(data_list) + + return data_list + + def _identify_modality(self, img_name): + """Identify modality from filename.""" + img_name_lower = img_name.lower() + + for modal_name, config in self.MODAL_CONFIGS.items(): + if config['pattern'].lower() in img_name_lower: + dataset_info = self.MODAL_TO_DATASET.get(modal_name, { + 'dataset_source': 1, + 'dataset_name': 'rgb' + }) + + return { + 'modal_type': modal_name, + 'actual_channels': config['channels'], + 'dataset_source': dataset_info['dataset_source'], + 'dataset_name': dataset_info['dataset_name'], + } + + # Default: RGB + return { + 'modal_type': 'rgb', + 'actual_channels': 3, + 'dataset_source': 1, + 'dataset_name': 'rgb', + } + + def _print_modal_statistics(self, data_list): + """Print dataset modal statistics.""" + modal_counts = {} + for info in data_list: + modal = info['modal_type'] + modal_counts[modal] = modal_counts.get(modal, 0) + 1 + + print_log("\n=== Dataset Modal Statistics ===", logger='current') + for modal, count in sorted(modal_counts.items()): + channels = self.MODAL_CONFIGS.get( + modal, {}).get('channels', 'unknown') + print_log( + f" {modal}: {count} images ({channels} channels)", + logger='current') + print_log("================================\n", logger='current') diff --git a/mmseg/datasets/sen1floods11.py b/mmseg/datasets/sen1floods11.py new file mode 100644 index 0000000000..5f31a4cccc --- /dev/null +++ b/mmseg/datasets/sen1floods11.py @@ -0,0 +1,169 @@ +""" +Sen1Floods11 single-modal flood segmentation dataset - MMSeg 1.x. + +Expected on-disk layout (the default from the official Sen1Floods11 release):: + + data/Sen1Floods11/ + S1Hand/ + __S1Hand.tif # 2-band SAR (VV, VH in dB) + S2Hand/ + __S2Hand.tif # 13-band Sentinel-2 MSI + LabelHand/ + __LabelHand.tif # 1-band label (-1=nodata, 0=bg, 1=flood) + +All images are 512x512. Labels are signed TIFFs; the ``-1`` nodata value +is mapped to ``ignore_index=255`` by the companion +``LoadSen1Floods11Annotation`` transform. + +This dataset plugs into the existing multi-modal Swin+MoE pipeline by +exposing ``modal_type`` / ``actual_channels`` / ``dataset_name`` on each +sample, so the only things a config needs to do to fine-tune on +Sen1Floods11 is to: + + 1. set ``dataset_type='Sen1Floods11Dataset'`` and pick + ``modality='s1'`` or ``'s2'``; + 2. register a new modal in ``model.backbone.modal_configs`` / + ``training_modals`` with the matching channel count; + 3. use ``model.dataset_names=['s1']`` (or ``['s2']``) so the + per-dataset decode head key matches. +""" +import os.path as osp +from typing import List + +import mmengine +import mmengine.fileio as fileio +from mmengine.logging import print_log + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class Sen1Floods11Dataset(BaseSegDataset): + """Sen1Floods11 dataset for single-modal fine-tuning. + + Args: + modality (str): ``'s1'`` (2-band SAR) or ``'s2'`` (13-band MSI). + ann_file (str, optional): Path to a split file listing one + sample base-name per line (e.g. ``Bolivia_23014``). The + path may be relative to ``data_root``. If empty or absent, + the whole ``data_prefix['img_path']`` directory is scanned. + **kwargs: forwarded to :class:`BaseSegDataset`. ``img_suffix``, + ``seg_map_suffix``, ``data_prefix`` all default to values + that match the Sen1Floods11 layout described above. + """ + + METAINFO = dict( + classes=('Background', 'Flood'), + palette=[[0, 0, 0], [255, 0, 0]], + ) + + # Per-modality layout / shape info. + MODAL_CONFIG = { + 's1': { + 'channels': 2, + 'img_subdir': 'S1Hand', + 'img_suffix': '_S1Hand.tif', + 'dataset_source': 0, + 'dataset_name': 's1', + }, + 's2': { + 'channels': 13, + 'img_subdir': 'S2Hand', + 'img_suffix': '_S2Hand.tif', + 'dataset_source': 1, + 'dataset_name': 's2', + }, + } + + SEG_SUBDIR = 'LabelHand' + SEG_SUFFIX = '_LabelHand.tif' + + def __init__(self, modality: str = 's1', **kwargs): + if modality not in self.MODAL_CONFIG: + raise ValueError( + f'Unknown modality "{modality}". Expected one of ' + f'{list(self.MODAL_CONFIG.keys())}') + self.modality = modality + + modal_cfg = self.MODAL_CONFIG[modality] + + kwargs.setdefault('img_suffix', modal_cfg['img_suffix']) + kwargs.setdefault('seg_map_suffix', self.SEG_SUFFIX) + kwargs.setdefault('reduce_zero_label', False) + + data_prefix = kwargs.pop('data_prefix', None) + if not data_prefix: + data_prefix = dict( + img_path=modal_cfg['img_subdir'], + seg_map_path=self.SEG_SUBDIR, + ) + + super().__init__(data_prefix=data_prefix, **kwargs) + + # ------------------------------------------------------------------ + # core list building + # ------------------------------------------------------------------ + def load_data_list(self) -> List[dict]: + modal_cfg = self.MODAL_CONFIG[self.modality] + img_dir = self.data_prefix.get('img_path', None) + ann_dir = self.data_prefix.get('seg_map_path', None) + assert img_dir is not None, \ + 'Sen1Floods11Dataset requires data_prefix["img_path"]' + + data_list = [] + + ann_file = self.ann_file + use_ann_file = bool(ann_file) and osp.isfile(ann_file) + + if use_ann_file: + lines = mmengine.list_from_file( + ann_file, backend_args=self.backend_args) + for line in lines: + base = line.strip() + if not base: + continue + # Accept entries that include any of the known suffixes. + for suf in (modal_cfg['img_suffix'], self.SEG_SUFFIX): + if base.endswith(suf): + base = base[:-len(suf)] + break + img_name = base + modal_cfg['img_suffix'] + data_list.append( + self._build_info(img_name, img_dir, ann_dir)) + else: + for img in fileio.list_dir_or_file( + dir_path=img_dir, + list_dir=False, + suffix=modal_cfg['img_suffix'], + recursive=True, + backend_args=self.backend_args): + data_list.append( + self._build_info(img, img_dir, ann_dir)) + + data_list = sorted(data_list, key=lambda x: x['img_path']) + + print_log( + f'[Sen1Floods11] modality="{self.modality}" ' + f'loaded {len(data_list)} images from "{img_dir}"', + logger='current') + + return data_list + + def _build_info(self, img_name: str, img_dir: str, + ann_dir: str) -> dict: + modal_cfg = self.MODAL_CONFIG[self.modality] + info = dict( + img_path=osp.join(img_dir, img_name), + modal_type=self.modality, + actual_channels=modal_cfg['channels'], + dataset_source=modal_cfg['dataset_source'], + dataset_name=modal_cfg['dataset_name'], + label_map=self.label_map, + reduce_zero_label=self.reduce_zero_label, + seg_fields=[], + ) + if ann_dir is not None: + base = img_name[:-len(modal_cfg['img_suffix'])] + info['seg_map_path'] = osp.join(ann_dir, base + self.SEG_SUFFIX) + return info diff --git a/mmseg/datasets/transforms/__init__.py b/mmseg/datasets/transforms/__init__.py index 125f070818..33d8f5c242 100644 --- a/mmseg/datasets/transforms/__init__.py +++ b/mmseg/datasets/transforms/__init__.py @@ -1,8 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. from .formatting import PackSegInputs +from .multimodal_pipelines import (LoadMultiModalImageFromFile, + MultiModalNormalize, + GenerateBoundary, + MultiModalPad, + PackMultiModalSegInputs) +from .sen1floods11_pipelines import LoadSen1Floods11Annotation from .loading import (LoadAnnotations, LoadBiomedicalAnnotation, LoadBiomedicalData, LoadBiomedicalImageFromFile, LoadDepthAnnotation, LoadImageFromNDArray, + LoadMultiBandTiffFromFile, LoadMultipleRSImageFromFile, LoadSingleRSImageFromFile) # yapf: disable from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad, @@ -25,6 +32,9 @@ 'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur', 'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput', - 'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation', 'RandomDepthMix', - 'RandomFlip', 'Resize' + 'LoadMultiBandTiffFromFile', 'LoadMultipleRSImageFromFile', + 'LoadDepthAnnotation', 'RandomDepthMix', 'RandomFlip', 'Resize', + 'LoadMultiModalImageFromFile', 'MultiModalNormalize', + 'GenerateBoundary', 'MultiModalPad', 'PackMultiModalSegInputs', + 'LoadSen1Floods11Annotation' ] diff --git a/mmseg/datasets/transforms/loading.py b/mmseg/datasets/transforms/loading.py index c28937e55e..a846f6b5f9 100644 --- a/mmseg/datasets/transforms/loading.py +++ b/mmseg/datasets/transforms/loading.py @@ -18,6 +18,11 @@ except ImportError: gdal = None +try: + import tifffile +except ImportError: + tifffile = None + @TRANSFORMS.register_module() class LoadAnnotations(MMCV_LoadAnnotations): @@ -769,3 +774,72 @@ def transform(self, results: dict) -> Optional[dict]: results['img_shape'] = img.shape[:2] results['ori_shape'] = img.shape[:2] return results + + +@TRANSFORMS.register_module() +class LoadMultiBandTiffFromFile(BaseTransform): + """Load a multi-band TIFF image from file using tifffile. + + This loader supports TIFF images with more than 4 channels, which are + not supported by OpenCV. It uses the tifffile library to read the image. + + Required Keys: + + - img_path + + Modified Keys: + + - img + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image keeps its original + data type. Defaults to True. + """ + + def __init__(self, to_float32: bool = True): + self.to_float32 = to_float32 + + if tifffile is None: + raise RuntimeError('tifffile is not installed. Please install it ' + 'using: pip install tifffile') + + def transform(self, results: Dict) -> Dict: + """Functions to load image. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + filename = results['img_path'] + try: + # Read multi-band TIFF using tifffile + img = tifffile.imread(filename) + + # tifffile returns (height, width, channels) for multi-channel images + # or (channels, height, width) depending on the file structure + # We need to ensure the format is (height, width, channels) + if img.ndim == 3 and img.shape[0] < img.shape[2]: + # If first dimension is smaller, it's likely (channels, height, width) + # Convert to (height, width, channels) + img = np.transpose(img, (1, 2, 0)) + + if self.to_float32: + img = img.astype(np.float32) + + results['img'] = img + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + return results + except Exception as e: + raise RuntimeError(f'Failed to load image from {filename}: {str(e)}') + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'to_float32={self.to_float32})') + return repr_str diff --git a/mmseg/datasets/transforms/multimodal_pipelines.py b/mmseg/datasets/transforms/multimodal_pipelines.py new file mode 100644 index 0000000000..947bd47250 --- /dev/null +++ b/mmseg/datasets/transforms/multimodal_pipelines.py @@ -0,0 +1,435 @@ +""" +Multi-Modal Data Loading Pipeline - MMSeg 1.x Version + +Key changes from 0.x: +- Registry: PIPELINES -> TRANSFORMS +- __call__ -> transform (BaseTransform) +- Normalize moved to SegDataPreProcessor +- Collect + FormatBundle -> PackMultiModalSegInputs (SegDataSample) +""" +import copy + +import numpy as np +from mmcv.transforms import BaseTransform, to_tensor +from mmengine.structures import PixelData + +from mmseg.registry import TRANSFORMS +from mmseg.structures import SegDataSample + +try: + import tifffile +except ImportError: + tifffile = None + +try: + import mmcv +except ImportError: + mmcv = None + +import os.path as osp + + +@TRANSFORMS.register_module() +class LoadMultiModalImageFromFile(BaseTransform): + """Load multi-modal image - no zero padding version. + + Required Keys: + - img_path (str) + - modal_type (str) + - actual_channels (int) + + Added Keys: + - img (np.ndarray) + - img_shape (tuple) + - ori_shape (tuple) + """ + + def __init__(self, + to_float32=False, + color_type='unchanged', + imdecode_backend='cv2'): + self.to_float32 = to_float32 + self.color_type = color_type + self.imdecode_backend = imdecode_backend + + def transform(self, results: dict) -> dict: + filename = results['img_path'] + + # Support TIFF format + if filename.endswith('.tif') or filename.endswith('.tiff'): + if tifffile is not None: + img = tifffile.imread(filename) + if len(img.shape) == 3 and img.shape[0] < img.shape[2]: + img = np.transpose(img, (1, 2, 0)) + elif mmcv is not None: + img_bytes = mmcv.FileClient.infer_client( + None, filename).get(filename) + img = mmcv.imfrombytes( + img_bytes, + flag=self.color_type, + backend=self.imdecode_backend) + else: + raise ImportError( + 'tifffile or mmcv required for TIFF loading') + else: + if mmcv is not None: + img_bytes = mmcv.FileClient.infer_client( + None, filename).get(filename) + img = mmcv.imfrombytes( + img_bytes, + flag=self.color_type, + backend=self.imdecode_backend) + else: + from PIL import Image + img = np.array(Image.open(filename)) + + if self.to_float32: + img = img.astype(np.float32) + + actual_channels = results.get('actual_channels', 3) + + if len(img.shape) == 2: + img = img[:, :, np.newaxis] + + current_channels = img.shape[2] + + if current_channels != actual_channels: + actual_channels = current_channels + results['actual_channels'] = actual_channels + + results['img'] = img + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + results['ori_filename'] = osp.basename(filename) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(to_float32={self.to_float32}, ' + repr_str += f"color_type='{self.color_type}')" + return repr_str + + +@TRANSFORMS.register_module() +class MultiModalNormalize(BaseTransform): + """Multi-modal normalization - supports dynamic channel count. + + Required Keys: + - img (np.ndarray) + - modal_type (str) + - actual_channels (int) + + Modified Keys: + - img (np.ndarray) + + Added Keys: + - img_norm_cfg (dict) + """ + + NORM_CONFIGS = { + 'rgb': { + 'mean': [123.675, 116.28, 103.53], + 'std': [58.395, 57.12, 57.375], + }, + 'sar': { + 'mean': [0.23651549, 0.31761484, 0.18514981, 0.26901252, + -14.57879175, -8.6098158, -14.2907338, -8.33534564], + 'std': [0.16280619, 0.20849304, 0.14008107, 0.19767644, + 4.07141682, 3.94773216, 4.21006244, 4.05494136], + }, + 'multispectral': { + 'mean': [1353., 1329., 1627., 1935., 2268., 2723., 3154., + 3541., 3652., 3416., 1112., 2619., 2060.], + 'std': [1108., 942., 976., 1164., 1196., 1351., 1500., + 1605., 1611., 1288., 770., 1325., 1186.], + }, + 'GF': { + 'mean': [432.02181, 315.92948, 246.468659, + 310.61462, 360.267789], + 'std': [97.73313111900238, 85.78646917160748, + 95.78015824658593, + 124.84677067613467, 251.73965882246978], + }, + # ---- Sen1Floods11: S1Hand (2-band VV/VH SAR, dB) ---- + # Defaults computed from the Sen1Floods11 train split with + # the nodata pixels (-1) masked out. Re-run + # tools/compute_sen1floods11_stats.py on your own split to + # refresh these numbers if needed. + 's1': { + 'mean': [-10.483032437170175, -17.362463068117055], + 'std': [4.178513068178825, 4.863193681650141], + }, + # ---- Sen1Floods11: S2Hand (13-band Sentinel-2 MSI, TOA*10000) ---- + 's2': { + 'mean': [1483.1443242989628, 1234.2590152666212, 1204.8650733526135, 1034.8886830055903, 1305.9293935728826, 2257.5830489084565, 2723.9229150471333, 2515.8173451011917, 2957.957558849772, 447.04848745605636, 57.01156794081842, 1893.1678433033005, 1040.2051810757566], + 'std': [314.8831761693824, 341.1134613706103, 367.08026898219816, 524.9257541131836, 446.1063090649358, 697.2651115379473, 897.1816646156293, 868.071173572936, 1031.6451147331716, 287.1789646791449, 130.50110493994782, 830.2058546614454, 610.0878654966048], + }, + } + + def __init__(self, to_rgb=True): + self.to_rgb = to_rgb + + def transform(self, results: dict) -> dict: + img = results['img'] + modal_type = results.get('modal_type', 'rgb') + actual_channels = results['actual_channels'] + + if modal_type in self.NORM_CONFIGS: + config = self.NORM_CONFIGS[modal_type] + mean = np.array( + config['mean'][:actual_channels], dtype=np.float32) + std = np.array( + config['std'][:actual_channels], dtype=np.float32) + else: + mean = np.array([128.0] * actual_channels, dtype=np.float32) + std = np.array([50.0] * actual_channels, dtype=np.float32) + + mean_b = mean.reshape(1, 1, -1) + std_b = std.reshape(1, 1, -1) + + # Sen1Floods11 S1Hand (and many other SAR/MSI products) encode + # nodata pixels as NaN / ±Inf inside the TIFF. Left alone, the + # NaN propagates through `(img - mean) / std` and poisons every + # downstream feature map, so both the CE loss and the MoE + # balance loss become NaN from the very first training step. + # Replace non-finite pixels with the per-channel mean so they + # normalize to 0 (a neutral value the network will learn to + # ignore alongside the 255-ignored label pixels). + if not np.all(np.isfinite(img)): + img = np.where( + np.isfinite(img), + img, + np.broadcast_to(mean_b, img.shape), + ).astype(np.float32) + + img = (img - mean_b) / std_b + + # Final safety net: some SAR products use a finite sentinel + # (e.g. -9999) instead of NaN for nodata, which would otherwise + # survive normalization as a huge outlier. Clipping to ±10σ is + # well outside any legitimate value for the modalities + # configured above, so real data is untouched. + img = np.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0) + np.clip(img, -10.0, 10.0, out=img) + + results['img'] = img + results['img_norm_cfg'] = dict( + mean=mean.tolist(), + std=std.tolist(), + to_rgb=self.to_rgb, + ) + + return results + + +@TRANSFORMS.register_module() +class GenerateBoundary(BaseTransform): + """Generate boundary map from segmentation label. + + Required Keys: + - gt_seg_map (np.ndarray) + + Added Keys: + - gt_boundary_map (np.ndarray) + """ + + def __init__(self, thickness=3, ignore_index=255): + self.thickness = thickness + self.ignore_index = ignore_index + + def transform(self, results: dict) -> dict: + if 'gt_seg_map' not in results: + return results + + seg = results['gt_seg_map'] + + if len(seg.shape) == 3: + seg = seg.squeeze(-1) + + boundary = self._generate_boundary(seg) + + results['gt_boundary_map'] = boundary + if 'seg_fields' in results: + results['seg_fields'].append('gt_boundary_map') + + return results + + def _generate_boundary(self, seg_mask): + import cv2 + + boundary = np.zeros_like(seg_mask, dtype=np.uint8) + + unique_labels = np.unique(seg_mask) + unique_labels = unique_labels[unique_labels != self.ignore_index] + + kernel = np.ones((self.thickness, self.thickness), np.uint8) + + for label in unique_labels: + class_mask = (seg_mask == label).astype(np.uint8) + dilated = cv2.dilate(class_mask, kernel, iterations=1) + eroded = cv2.erode(class_mask, kernel, iterations=1) + class_boundary = dilated - eroded + boundary = np.maximum(boundary, class_boundary) + + boundary[seg_mask == self.ignore_index] = self.ignore_index + + return boundary + + +@TRANSFORMS.register_module() +class MultiModalPad(BaseTransform): + """Pad images with arbitrary channel counts using numpy. + + mmcv.transforms.Pad uses cv2.copyMakeBorder which only supports up to + 4 channels. This transform uses numpy padding instead, so it works + with SAR (8ch), GF (5ch), etc. + + Required Keys: + - img (np.ndarray) + + Modified Keys: + - img (np.ndarray) + - img_shape (tuple) + + Added Keys: + - pad_shape (tuple) + - padding_size (tuple) + + Args: + size (tuple): Target (H, W). + pad_val (float): Padding value for images. Default: 0. + seg_pad_val (float): Padding value for seg maps. Default: 255. + """ + + def __init__(self, size, pad_val=0, seg_pad_val=255): + self.size = size + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + + def transform(self, results: dict) -> dict: + img = results['img'] + h, w = img.shape[:2] + target_h, target_w = self.size + + pad_h = max(target_h - h, 0) + pad_w = max(target_w - w, 0) + + if pad_h > 0 or pad_w > 0: + if len(img.shape) == 3: + pad_width = ((0, pad_h), (0, pad_w), (0, 0)) + else: + pad_width = ((0, pad_h), (0, pad_w)) + + img = np.pad(img, pad_width, mode='constant', + constant_values=self.pad_val) + results['img'] = img + + # Pad seg maps + for key in results.get('seg_fields', []): + if key in results: + seg = results[key] + if len(seg.shape) == 3: + seg_pad = ((0, pad_h), (0, pad_w), (0, 0)) + else: + seg_pad = ((0, pad_h), (0, pad_w)) + results[key] = np.pad( + seg, seg_pad, mode='constant', + constant_values=self.seg_pad_val) + + if 'gt_seg_map' in results and 'gt_seg_map' not in results.get( + 'seg_fields', []): + seg = results['gt_seg_map'] + if len(seg.shape) == 3: + seg_pad = ((0, pad_h), (0, pad_w), (0, 0)) + else: + seg_pad = ((0, pad_h), (0, pad_w)) + results['gt_seg_map'] = np.pad( + seg, seg_pad, mode='constant', + constant_values=self.seg_pad_val) + + results['img_shape'] = img.shape[:2] + results['pad_shape'] = img.shape[:2] + results['padding_size'] = (0, pad_w, 0, pad_h) + + return results + + +@TRANSFORMS.register_module() +class PackMultiModalSegInputs(BaseTransform): + """Pack multi-modal data into SegDataSample format. + + This replaces CollectMultiModalData + MultiModalFormatBundle from 0.x. + + Required Keys: + - img (np.ndarray) + + Optional Keys: + - gt_seg_map (np.ndarray) + - gt_boundary_map (np.ndarray) + - dataset_name (str) + - modal_type (str) + - actual_channels (int) + + Added Keys: + - inputs (torch.Tensor) + - data_samples (SegDataSample) + """ + + def __init__(self, + meta_keys=('img_path', 'ori_filename', 'ori_shape', + 'img_shape', 'pad_shape', 'scale_factor', + 'flip', 'flip_direction', + 'modal_type', 'actual_channels', + 'dataset_name', 'img_norm_cfg', + 'reduce_zero_label')): + self.meta_keys = meta_keys + + def transform(self, results: dict) -> dict: + packed_results = dict() + + if 'img' in results: + img = results['img'] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + if not img.flags.c_contiguous: + img = to_tensor( + np.ascontiguousarray(img.transpose(2, 0, 1))) + else: + img = img.transpose(2, 0, 1) + img = to_tensor(img).contiguous() + packed_results['inputs'] = img + + data_sample = SegDataSample() + + if 'gt_seg_map' in results: + gt_seg_map = results['gt_seg_map'] + if len(gt_seg_map.shape) == 2: + data = to_tensor(gt_seg_map[None, ...].astype(np.int64)) + else: + data = to_tensor(gt_seg_map.astype(np.int64)) + gt_sem_seg_data = dict(data=data) + data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data) + + if 'gt_boundary_map' in results: + gt_boundary_data = dict( + data=to_tensor( + results['gt_boundary_map'][None, ...].astype(np.int64))) + data_sample.set_data( + dict(gt_boundary_map=PixelData(**gt_boundary_data))) + + # Set meta info + img_meta = {} + for key in self.meta_keys: + if key in results: + img_meta[key] = results[key] + data_sample.set_metainfo(img_meta) + + packed_results['data_samples'] = data_sample + + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(meta_keys={self.meta_keys})' + return repr_str diff --git a/mmseg/datasets/transforms/sen1floods11_pipelines.py b/mmseg/datasets/transforms/sen1floods11_pipelines.py new file mode 100644 index 0000000000..a8efe54a8e --- /dev/null +++ b/mmseg/datasets/transforms/sen1floods11_pipelines.py @@ -0,0 +1,86 @@ +""" +Sen1Floods11 loading transforms. + +The ``LabelHand`` TIFFs store signed values in ``{-1, 0, 1}``: + + -1: nodata -> mapped to ``ignore_index`` (255 by default) + 0: non-flood / background + 1: flood + +Default :class:`LoadAnnotations` uses ``mmcv.imfrombytes`` which is not +reliable for signed-int TIFF, so we decode with :mod:`tifffile` and +explicitly remap nodata. The returned ``gt_seg_map`` is ``np.uint8`` +which is what the rest of the mmseg 1.x pipeline expects. +""" +import numpy as np +from mmcv.transforms import BaseTransform + +from mmseg.registry import TRANSFORMS + +try: + import tifffile +except ImportError: + tifffile = None + + +@TRANSFORMS.register_module() +class LoadSen1Floods11Annotation(BaseTransform): + """Load a Sen1Floods11 LabelHand TIFF segmentation map. + + Required Keys: + - seg_map_path + + Added Keys: + - gt_seg_map (np.uint8, shape (H, W)) + - seg_fields (list) + + Args: + ignore_index (int): Value that replaces ``nodata_value`` in the + output mask. Default: 255. + nodata_value (int): Raw label value that indicates "no data". + Default: -1. + """ + + def __init__(self, ignore_index: int = 255, nodata_value: int = -1): + if tifffile is None: + raise ImportError( + 'tifffile is required to load Sen1Floods11 labels. ' + 'Install with `pip install tifffile`.') + self.ignore_index = int(ignore_index) + self.nodata_value = int(nodata_value) + + def transform(self, results: dict) -> dict: + seg_path = results['seg_map_path'] + raw = tifffile.imread(seg_path) + + # Some TIFFs come as (1, H, W) / (H, W, 1) + raw = np.squeeze(raw) + if raw.ndim != 2: + raise RuntimeError( + f'Expected a 2D label for {seg_path}, got shape {raw.shape}') + + raw = raw.astype(np.int32) + + gt = np.full(raw.shape, self.ignore_index, dtype=np.uint8) + gt[raw == 0] = 0 + gt[raw == 1] = 1 + # Explicit nodata remap - catches anything that isn't {0, 1}. + gt[raw == self.nodata_value] = self.ignore_index + + # Optional label remapping (preserves parity with LoadAnnotations) + if results.get('label_map', None): + gt_copy = gt.copy() + for old_id, new_id in results['label_map'].items(): + gt[gt_copy == old_id] = new_id + + results['gt_seg_map'] = gt + results.setdefault('seg_fields', []) + if 'gt_seg_map' not in results['seg_fields']: + results['seg_fields'].append('gt_seg_map') + + return results + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(' + f'ignore_index={self.ignore_index}, ' + f'nodata_value={self.nodata_value})') diff --git a/mmseg/engine/hooks/visualization_hook.py b/mmseg/engine/hooks/visualization_hook.py index 21cddde89d..ac9282a38c 100644 --- a/mmseg/engine/hooks/visualization_hook.py +++ b/mmseg/engine/hooks/visualization_hook.py @@ -4,9 +4,12 @@ from typing import Optional, Sequence import mmcv +import numpy as np +import torch from mmengine.fileio import get from mmengine.hooks import Hook from mmengine.runner import Runner +from mmengine.utils import mkdir_or_exist from mmengine.visualization import Visualizer from mmseg.registry import HOOKS @@ -15,14 +18,8 @@ @HOOKS.register_module() class SegVisualizationHook(Hook): - """Segmentation Visualization Hook. Used to visualize validation and - testing process prediction results. - - In the testing phase: - - 1. If ``show`` is True, it means that only the prediction results are - visualized without storing data, so ``vis_backends`` needs to - be excluded. + """Segmentation Visualization Hook. This hook visualize the prediction + results during validation and testing. Args: draw (bool): whether to draw prediction results. If it is False, @@ -30,10 +27,8 @@ class SegVisualizationHook(Hook): interval (int): The interval of visualization. Defaults to 50. show (bool): Whether to display the drawn image. Default to False. wait_time (float): The interval of show (s). Defaults to 0. - backend_args (dict, Optional): Arguments to instantiate a file backend. - See https://mmengine.readthedocs.io/en/latest/api/fileio.htm - for details. Defaults to None. - Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + backend_args (dict, Optional): Arguments to instantiate a file client. + Defaults to None. """ def __init__(self, @@ -52,16 +47,47 @@ def __init__(self, 'the prediction results are visualized ' 'without storing data, so vis_backends ' 'needs to be excluded.') - - self.wait_time = wait_time - self.backend_args = backend_args.copy() if backend_args else None + self.wait_time = wait_time + else: + self.wait_time = 0. self.draw = draw - if not self.draw: - warnings.warn('The draw is False, it means that the ' - 'hook for visualization will not take ' - 'effect. The results will NOT be ' - 'visualized or stored.') - self._test_index = 0 + self.backend_args = backend_args + + IGNORE_INDEX = 255 + + def _create_dummy_image(self, height: int, width: int) -> np.ndarray: + """Create a dummy black image for visualization. + + Args: + height: Image height + width: Image width + + Returns: + Black RGB image of shape (H, W, 3) + """ + return np.zeros((height, width, 3), dtype=np.uint8) + + @staticmethod + def _mask_nodata(output: SegDataSample, + ignore_index: int = 255) -> None: + """Set prediction to ignore_index where GT is nodata. + + This prevents nodata pixels from being visualized as a real + class (e.g. Flood). Called AFTER evaluator.process() so + metrics are unaffected. + """ + if not (hasattr(output, 'gt_sem_seg') + and hasattr(output, 'pred_sem_seg')): + return + gt = output.gt_sem_seg.data # (1, H, W) or (H, W) + pred = output.pred_sem_seg.data # same shape + nodata_mask = (gt == ignore_index) + if nodata_mask.any(): + output.pred_sem_seg.data = torch.where( + nodata_mask, + torch.tensor(ignore_index, dtype=pred.dtype, + device=pred.device), + pred) def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, outputs: Sequence[SegDataSample]) -> None: @@ -71,59 +97,84 @@ def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, runner (:obj:`Runner`): The runner of the validation process. batch_idx (int): The index of the current batch in the val loop. data_batch (dict): Data from dataloader. - outputs (Sequence[:obj:`SegDataSample`]]): A batch of data samples - that contain annotations and predictions. + outputs (Sequence[:obj:`SegDataSample`]): Outputs from model. """ if self.draw is False: return - # There is no guarantee that the same batch of images - # is visualized for each evaluation. - total_curr_iter = runner.iter + batch_idx - - # Visualize only the first data - img_path = outputs[0].img_path - img_bytes = get(img_path, backend_args=self.backend_args) - img = mmcv.imfrombytes(img_bytes, channel_order='rgb') - window_name = f'val_{osp.basename(img_path)}' - - if total_curr_iter % self.interval == 0: - self._visualizer.add_datasample( - window_name, - img, - data_sample=outputs[0], - show=self.show, - wait_time=self.wait_time, - step=total_curr_iter) + if self.every_n_inner_iters(batch_idx, self.interval): + for output in outputs: + self._mask_nodata(output, self.IGNORE_INDEX) + + img_path = output.img_path + img_name = osp.basename(img_path) + + # Get image size from prediction + if hasattr(output, 'pred_sem_seg'): + pred_shape = output.pred_sem_seg.data.shape + h, w = pred_shape[-2], pred_shape[-1] + elif hasattr(output, 'gt_sem_seg'): + gt_shape = output.gt_sem_seg.data.shape + h, w = gt_shape[-2], gt_shape[-1] + else: + # Fallback to default size + h, w = 512, 512 + + # Create dummy image instead of loading original + img = self._create_dummy_image(h, w) + + # Only draw prediction, not ground truth + self._visualizer.add_datasample( + img_name, + img, + data_sample=output, + draw_gt=False, # Don't draw GT + draw_pred=True, # Only draw prediction + show=self.show, + wait_time=self.wait_time, + step=runner.iter, + with_labels=False) def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, outputs: Sequence[SegDataSample]) -> None: - """Run after every testing iterations. + """Run after every testing iteration. Args: runner (:obj:`Runner`): The runner of the testing process. - batch_idx (int): The index of the current batch in the val loop. + batch_idx (int): The index of the current batch in the test loop. data_batch (dict): Data from dataloader. - outputs (Sequence[:obj:`SegDataSample`]): A batch of data samples - that contain annotations and predictions. + outputs (Sequence[:obj:`SegDataSample`]]): Outputs from model. """ if self.draw is False: return - for data_sample in outputs: - self._test_index += 1 + for output in outputs: + self._mask_nodata(output, self.IGNORE_INDEX) + + img_path = output.img_path + img_name = osp.basename(img_path) - img_path = data_sample.img_path - window_name = f'test_{osp.basename(img_path)}' + # Get image size from prediction + if hasattr(output, 'pred_sem_seg'): + pred_shape = output.pred_sem_seg.data.shape + h, w = pred_shape[-2], pred_shape[-1] + elif hasattr(output, 'gt_sem_seg'): + gt_shape = output.gt_sem_seg.data.shape + h, w = gt_shape[-2], gt_shape[-1] + else: + # Fallback to default size + h, w = 512, 512 - img_path = data_sample.img_path - img_bytes = get(img_path, backend_args=self.backend_args) - img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + # Create dummy image instead of loading original + img = self._create_dummy_image(h, w) self._visualizer.add_datasample( - window_name, + img_name, img, - data_sample=data_sample, + data_sample=output, + draw_gt=False, # Don't draw GT + draw_pred=True, # Only draw prediction show=self.show, wait_time=self.wait_time, - step=self._test_index) + step=runner.iter, + with_labels=False) # 移除了 out_file 参数 \ No newline at end of file diff --git a/mmseg/models/__init__.py b/mmseg/models/__init__.py index a98951283c..e993e600c3 100644 --- a/mmseg/models/__init__.py +++ b/mmseg/models/__init__.py @@ -4,6 +4,7 @@ from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, build_head, build_loss, build_segmentor) from .data_preprocessor import SegDataPreProcessor +from .multimodal_data_preprocessor import MultiModalDataPreProcessor from .decode_heads import * # noqa: F401,F403 from .losses import * # noqa: F401,F403 from .necks import * # noqa: F401,F403 @@ -12,5 +13,6 @@ __all__ = [ 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone', - 'build_head', 'build_loss', 'build_segmentor', 'SegDataPreProcessor' + 'build_head', 'build_loss', 'build_segmentor', 'SegDataPreProcessor', + 'MultiModalDataPreProcessor' ] diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py index 784d3dfdb7..3838e1a584 100644 --- a/mmseg/models/backbones/__init__.py +++ b/mmseg/models/backbones/__init__.py @@ -24,6 +24,7 @@ from .unet import UNet from .vit import VisionTransformer from .vpd import VPD +from .multimodal_swin_moe_backbone import MultiModalSwinMoE __all__ = [ 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', @@ -31,5 +32,5 @@ 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', 'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT', 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN', - 'DDRNet', 'VPD' + 'DDRNet', 'VPD', 'MultiModalSwinMoE' ] diff --git a/mmseg/models/backbones/multimodal_swin_moe_backbone.py b/mmseg/models/backbones/multimodal_swin_moe_backbone.py new file mode 100644 index 0000000000..51b738c2bb --- /dev/null +++ b/mmseg/models/backbones/multimodal_swin_moe_backbone.py @@ -0,0 +1,1092 @@ +""" +Multi-Modal Swin Transformer with MoE - MMSeg 1.x Version + +Migrated from mmseg 0.x to 1.x: +- Registry: BACKBONES -> MODELS +- BaseModule: mmcv.runner -> mmengine.model +- load_checkpoint: mmcv.runner -> mmengine.runner +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from collections import OrderedDict +import warnings + +from mmengine.model import BaseModule +from mmengine.runner import load_checkpoint +from mmcv.cnn import build_norm_layer +from mmengine.utils import to_2tuple +from timm.models.layers import DropPath + +from mmseg.registry import MODELS + + +# ==================== Modal-Specific Patch Embedding ==================== +class ModalSpecificPatchEmbed(nn.Module): + """模态专用Patch Embedding - 无零填充版本""" + + def __init__(self, + modal_configs, + training_modals, + embed_dims=96, + patch_size=4, + norm_cfg=dict(type='LN')): + super().__init__() + self.modal_configs = modal_configs + self.training_modals = set(training_modals) if training_modals else set() + self.embed_dims = embed_dims + self.patch_size = patch_size + + self.modal_patch_embeds = nn.ModuleDict() + for modal_name in self.training_modals: + if modal_name in modal_configs: + in_ch = modal_configs[modal_name]['channels'] + self.modal_patch_embeds[modal_name] = nn.Sequential( + nn.Conv2d(in_ch, embed_dims, + kernel_size=patch_size, + stride=patch_size), + nn.LayerNorm(embed_dims, eps=1e-6) + ) + + self.register_buffer('forward_count', torch.zeros(1, dtype=torch.long)) + self._print_init_info() + + def _print_init_info(self): + print(f"\n{'=' * 80}") + print(f"Modal-Specific Patch Embedding (NO ZERO PADDING)") + print(f"{'=' * 80}") + print(f"Embed dims: {self.embed_dims}") + print(f"Patch size: {self.patch_size}") + print(f"Training modals: {sorted(self.training_modals)}") + + total_params = 0 + for modal_name in sorted(self.modal_patch_embeds.keys()): + in_ch = self.modal_configs[modal_name]['channels'] + params = in_ch * self.embed_dims * self.patch_size * self.patch_size + total_params += params + print(f" {modal_name:>10s}: {in_ch:>2d}ch -> {self.embed_dims:>3d}d " + f"({params:>6,d} params)") + + print(f"\nTotal parameters: {total_params:,}") + print(f"Strategy: Direct processing (NO zero-padding!)") + print(f"{'=' * 80}\n") + + def forward(self, imgs, modal_types): + """ + Args: + imgs: List[Tensor] - each [C_i, H, W] + modal_types: List[str] + Returns: + x: [B, H_p*W_p, embed_dims] + hw_shape: (H_p, W_p) + """ + self.forward_count += 1 + + if self.forward_count == 1 and self.training: + self._print_first_forward(imgs, modal_types) + + outputs = [] + hw_shape = None + + for img, modal in zip(imgs, modal_types): + x_i = img.unsqueeze(0) + + if modal in self.modal_patch_embeds: + conv = self.modal_patch_embeds[modal][0] + norm = self.modal_patch_embeds[modal][1] + + out = conv(x_i) + _, C, H, W = out.shape + if hw_shape is None: + hw_shape = (H, W) + + out = out.flatten(2).transpose(1, 2) + out = norm(out) + else: + if self.training: + raise ValueError( + f"Modal '{modal}' not in training modals: " + f"{self.training_modals}" + ) + else: + in_ch = x_i.shape[1] + temp_conv = nn.Conv2d( + in_ch, self.embed_dims, + kernel_size=self.patch_size, + stride=self.patch_size).to(x_i.device) + temp_norm = nn.LayerNorm(self.embed_dims).to(x_i.device) + + nn.init.trunc_normal_(temp_conv.weight, std=0.02) + nn.init.constant_(temp_conv.bias, 0) + + out = temp_conv(x_i) + _, C, H, W = out.shape + if hw_shape is None: + hw_shape = (H, W) + out = out.flatten(2).transpose(1, 2) + out = temp_norm(out) + + outputs.append(out) + + x = torch.cat(outputs, dim=0) + return x, hw_shape + + def _print_first_forward(self, imgs, modal_types): + from collections import Counter + print(f"\n{'=' * 80}") + print(f"Modal-Specific Patch Embedding - First Forward (NO PADDING)") + print(f"{'=' * 80}") + print(f"Batch size: {len(imgs)}") + + modal_counts = Counter(modal_types) + print(f"\nBatch composition:") + for modal, count in sorted(modal_counts.items()): + idx = modal_types.index(modal) + ch = imgs[idx].shape[0] + print(f" {modal}: {count} samples x {ch}ch (direct, no padding!)") + print(f"{'=' * 80}\n") + + +class UnifiedPatchEmbed(nn.Module): + """统一Patch Embedding - 所有模态共享一个卷积核(消融baseline)""" + + def __init__(self, + modal_configs, + training_modals, + embed_dims=96, + patch_size=4, + norm_cfg=dict(type='LN')): + super().__init__() + self.modal_configs = modal_configs + self.training_modals = set(training_modals) if training_modals else set() + self.embed_dims = embed_dims + self.patch_size = patch_size + + self.max_channels = max( + modal_configs[m]['channels'] for m in training_modals + ) if training_modals and modal_configs else 3 + + self.modal_channels = {} + for modal_name in (training_modals or []): + if modal_name in modal_configs: + self.modal_channels[modal_name] = modal_configs[modal_name]['channels'] + + self.unified_patch_embed = nn.Sequential( + nn.Conv2d(self.max_channels, embed_dims, + kernel_size=patch_size, stride=patch_size), + nn.LayerNorm(embed_dims, eps=1e-6) + ) + + self.register_buffer('forward_count', torch.zeros(1, dtype=torch.long)) + + def forward(self, imgs, modal_types): + self.forward_count += 1 + outputs = [] + hw_shape = None + + for img, modal in zip(imgs, modal_types): + x_i = img.unsqueeze(0) + actual_ch = x_i.shape[1] + + if actual_ch < self.max_channels: + pad_size = self.max_channels - actual_ch + padding = torch.zeros( + 1, pad_size, x_i.shape[2], x_i.shape[3], + device=x_i.device, dtype=x_i.dtype + ) + x_i = torch.cat([x_i, padding], dim=1) + + conv = self.unified_patch_embed[0] + norm = self.unified_patch_embed[1] + + out = conv(x_i) + _, C, H, W = out.shape + if hw_shape is None: + hw_shape = (H, W) + + out = out.flatten(2).transpose(1, 2) + out = norm(out) + outputs.append(out) + + x = torch.cat(outputs, dim=0) + return x, hw_shape + + +# ==================== MoE Components ==================== +class CosineTopKGate(nn.Module): + """Cosine similarity Gating with modal bias""" + + def __init__(self, model_dim, num_experts, + modal_configs=None, training_modals=None, init_t=0.5, + use_modal_bias=True): + super().__init__() + proj_dim = min(model_dim // 2, 256) + + self.temperature = nn.Parameter( + torch.log(torch.full([1], 1.0 / init_t)), + requires_grad=True + ) + self.cosine_projector = nn.Linear(model_dim, proj_dim) + self.sim_matrix = nn.Parameter( + torch.randn(size=(proj_dim, num_experts)), + requires_grad=True + ) + self.clamp_max = torch.log(torch.tensor(1. / 0.01)).item() + + self.modal_configs = modal_configs + self.training_modals = set(training_modals) if training_modals else set() + self.use_modal_bias = use_modal_bias + + if modal_configs is not None and use_modal_bias: + self.modal_bias = nn.Parameter( + torch.zeros(len(modal_configs), num_experts), + requires_grad=True + ) + self.modal_name_to_idx = { + name: i for i, name in enumerate(modal_configs.keys()) + } + + nn.init.normal_(self.sim_matrix, 0, 0.01) + + def forward(self, x, modal_types=None): + if len(x.shape) == 4: + B, H, W, C = x.shape + x = x.reshape(B, -1, C).mean(dim=1) + elif len(x.shape) == 3: + x = x.mean(dim=1) + + # eps guards against NaN from 0/0 when a row of the projection + # (or a column of sim_matrix) has zero L2 norm. This matters during + # fine-tuning on a new modal whose patch embed is freshly + # initialized and can emit all-zero pooled features. + logits = torch.matmul( + F.normalize(self.cosine_projector(x), dim=1, eps=1e-6), + F.normalize(self.sim_matrix, dim=0, eps=1e-6) + ) + + logit_scale = torch.clamp(self.temperature, max=self.clamp_max).exp() + logits = logits * logit_scale + + if (self.use_modal_bias and modal_types is not None + and self.modal_configs is not None): + modal_bias = torch.zeros_like(logits) + for i, modal in enumerate(modal_types): + if (modal in self.modal_name_to_idx + and modal in self.training_modals): + modal_idx = self.modal_name_to_idx[modal] + modal_bias[i] = self.modal_bias[modal_idx] + logits = logits + modal_bias + + return logits + + +class SparseDispatcher: + """Sparse dispatcher for MoE""" + + def __init__(self, num_experts, gates): + # Sanitize gates: NaN/Inf can leak in from F.normalize on zero-norm + # vectors inside CosineTopKGate when a freshly-initialized modal + # patch embed produces degenerate features early in fine-tuning. + # Replace them with 0 so they are simply not routed anywhere. + if not torch.isfinite(gates).all(): + gates = torch.nan_to_num(gates, nan=0.0, posinf=0.0, neginf=0.0) + + self._gates = gates + self._num_experts = num_experts + + # IMPORTANT: use the SAME positive-mask for both `_batch_index` + # and `_part_sizes`. Previously line 287 used `torch.nonzero(gates)` + # (which includes NaN because `NaN != 0`) while line 290 used + # `gates > 0` (which excludes NaN), making `split_sizes` disagree + # with `_batch_index.numel()` and crashing `torch.split`. + positive_mask = gates > 0 + sorted_experts, index_sorted_experts = torch.nonzero(positive_mask).sort(0) + _, self._expert_index = sorted_experts.split(1, dim=1) + self._batch_index = sorted_experts[index_sorted_experts[:, 1], 0] + self._part_sizes = list(positive_mask.sum(0).cpu().numpy()) + + gates_exp = gates[self._batch_index.flatten()] + self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index) + + def dispatch(self, inp): + inp_exp = inp[self._batch_index].squeeze(1) + return torch.split(inp_exp, self._part_sizes, dim=0) + + def combine(self, expert_out, multiply_by_gates=True): + stitched = torch.cat(expert_out, 0) + + if multiply_by_gates: + stitched = stitched.mul(self._nonzero_gates) + + zeros = torch.zeros( + self._gates.size(0), expert_out[-1].size(1), + requires_grad=True, device=stitched.device + ) + combined = zeros.index_add(0, self._batch_index, stitched.float()) + return combined + + +class SwinFFN(nn.Module): + """Swin-style FFN (single expert)""" + + def __init__(self, in_features, hidden_features, + act_layer=nn.GELU, drop=0.): + super().__init__() + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, in_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class SwinMoELayer(nn.Module): + """MoE Layer for Swin Transformer""" + + def __init__(self, + in_features, + hidden_features, + num_experts=8, + num_shared_experts=0, + top_k=2, + noisy_gating=True, + modal_configs=None, + training_modals=None, + use_modal_bias=True, + act_layer=nn.GELU, + drop=0., + return_expert_outputs=False): + super().__init__() + self.num_experts = num_experts + self.num_shared_experts = num_shared_experts + self.top_k = top_k + self.noisy_gating = noisy_gating + self.return_expert_outputs = return_expert_outputs + + self.experts = nn.ModuleList([ + SwinFFN(in_features, hidden_features, act_layer, drop) + for _ in range(num_experts) + ]) + + self.gating = CosineTopKGate( + in_features, num_experts, + modal_configs, training_modals, + use_modal_bias=use_modal_bias + ) + + if noisy_gating: + self.w_noise = nn.Parameter( + torch.zeros(in_features, num_experts), + requires_grad=True + ) + + if num_shared_experts > 0: + shared_hidden = hidden_features * num_shared_experts + self.shared_experts = SwinFFN( + in_features, shared_hidden, act_layer, drop + ) + else: + self.shared_experts = None + + self.softplus = nn.Softplus() + self.softmax = nn.Softmax(-1) + + def cv_squared(self, x): + eps = 1e-10 + if x.shape[0] == 1: + return torch.tensor([0.0], device=x.device) + return x.float().var() / (x.float().mean() ** 2 + eps) + + def noisy_top_k_gating(self, x, modal_types=None, noise_epsilon=1e-2): + x_pooled = x.mean(dim=1) + clean_logits = self.gating(x_pooled, modal_types) + + if self.noisy_gating and self.training: + raw_noise_stddev = x_pooled @ self.w_noise + noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon + noisy_logits = clean_logits + ( + torch.randn_like(clean_logits) * noise_stddev) + logits = noisy_logits + else: + logits = clean_logits + + top_logits, top_indices = logits.topk( + min(self.top_k, self.num_experts), dim=-1 + ) + top_k_gates = self.softmax(top_logits) + + zeros = torch.zeros_like(logits, requires_grad=True) + gates = zeros.scatter(-1, top_indices, top_k_gates) + + load = (gates > 0).sum(0) + return gates, load + + def forward(self, x, modal_types=None, loss_coef=1e-2): + identity = x + B, N, C = x.shape + + gates, load = self.noisy_top_k_gating(x, modal_types) + importance = gates.sum(0) + + balance_loss = (self.cv_squared(importance) + + self.cv_squared(load.float())) + balance_loss = balance_loss * loss_coef + + dispatcher = SparseDispatcher(self.num_experts, gates) + + x_for_dispatch = x.reshape(B, -1) + expert_inputs = dispatcher.dispatch(x_for_dispatch) + + expert_outputs = [] + expert_features_list = [] + + for i, expert_input in enumerate(expert_inputs): + if expert_input.size(0) > 0: + n_samples = expert_input.size(0) + expert_input_reshaped = expert_input.reshape(n_samples, N, C) + + expert_out = self.experts[i](expert_input_reshaped) + + if self.return_expert_outputs and self.training: + expert_feat_pooled = expert_out.mean(dim=1) + expert_feat_4d = expert_feat_pooled.unsqueeze(-1).unsqueeze(-1) + expert_features_list.append(expert_feat_4d) + + expert_out = expert_out.reshape(n_samples, -1) + expert_outputs.append(expert_out) + else: + expert_outputs.append( + torch.empty(0, N * C, device=x.device)) + + y_routed = dispatcher.combine(expert_outputs) + y_routed = y_routed.reshape(B, N, C) + + if self.shared_experts is not None: + y_shared = self.shared_experts(identity) + else: + y_shared = 0 + + y = y_routed + y_shared + + if (self.return_expert_outputs and self.training + and len(expert_features_list) > 0): + if self.shared_experts is not None: + y_shared_pooled = y_shared.mean(dim=1) + y_shared_4d = y_shared_pooled.unsqueeze(-1).unsqueeze(-1) + expert_features_list.append(y_shared_4d) + + return y, balance_loss, expert_features_list + else: + return y, balance_loss, None + + +# ==================== Swin Transformer Components ==================== +class WindowMSA(nn.Module): + """Window-based Multi-head Self Attention""" + + def __init__(self, dim, window_size, num_heads, + qkv_bias=True, qk_scale=None, + attn_drop=0., proj_drop=0.): + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads) + ) + + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = (coords_flatten[:, :, None] + - coords_flatten[:, None, :]) + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer( + "relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + nn.init.trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + B_, N, C = x.shape + qkv = self.qkv(x).reshape( + B_, N, 3, self.num_heads, C // self.num_heads) + qkv = qkv.permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1 + ) + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + if B_ % nW == 0: + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + attn = attn + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def window_partition(x, window_size): + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, + W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, + window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class SwinBlockWithMoE(nn.Module): + """Swin Transformer Block with optional MoE""" + + def __init__(self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + use_moe=False, + num_experts=8, + num_shared_experts=0, + top_k=2, + noisy_gating=True, + modal_configs=None, + training_modals=None, + use_modal_bias=True, + return_expert_outputs=False): + super().__init__() + + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_moe = use_moe + + assert 0 <= self.shift_size < self.window_size, \ + "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowMSA( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop + ) + + self.drop_path = (DropPath(drop_path) + if drop_path > 0. else nn.Identity()) + self.norm2 = norm_layer(dim) + + mlp_hidden_dim = int(dim * mlp_ratio) + + if use_moe: + self.mlp = SwinMoELayer( + in_features=dim, + hidden_features=mlp_hidden_dim, + num_experts=num_experts, + num_shared_experts=num_shared_experts, + top_k=top_k, + noisy_gating=noisy_gating, + modal_configs=modal_configs, + training_modals=training_modals, + use_modal_bias=use_modal_bias, + act_layer=act_layer, + drop=drop, + return_expert_outputs=return_expert_outputs + ) + else: + self.mlp = SwinFFN( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop + ) + + self.register_buffer("attn_mask", None) + + def forward(self, x, hw_shape, modal_types=None): + H, W = hw_shape + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + if self.attn_mask is None: + self.attn_mask = self._calculate_mask(Hp, Wp).to(x.device) + attn_mask = self.attn_mask + else: + shifted_x = x + attn_mask = None + + x_windows = window_partition(shifted_x, self.window_size) + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C) + + attn_windows = self.attn(x_windows, mask=attn_mask) + + attn_windows = attn_windows.view( + -1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) + + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = shortcut + self.drop_path(x) + + moe_loss = None + expert_features = None + + if self.use_moe: + mlp_out = self.mlp(self.norm2(x), modal_types) + if len(mlp_out) == 3: + mlp_x, moe_loss, expert_features = mlp_out + else: + mlp_x, moe_loss = mlp_out + x = x + self.drop_path(mlp_x) + else: + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x, moe_loss, expert_features + + def _calculate_mask(self, H, W): + img_mask = torch.zeros((1, H, W, 1)) + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view( + -1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)) + attn_mask = attn_mask.masked_fill(attn_mask == 0, float(0.0)) + return attn_mask + + +class PatchMerging(nn.Module): + """Patch Merging Layer""" + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, hw_shape): + H, W = hw_shape + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, \ + f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] + x1 = x[:, 1::2, 0::2, :] + x2 = x[:, 0::2, 1::2, :] + x3 = x[:, 1::2, 1::2, :] + x = torch.cat([x0, x1, x2, x3], -1) + x = x.view(B, -1, 4 * C) + + x = self.norm(x) + x = self.reduction(x) + + return x, (H // 2, W // 2) + + +# ==================== Main Backbone ==================== +@MODELS.register_module() +class MultiModalSwinMoE(BaseModule): + """Multi-Modal Swin Transformer with MoE - MMSeg 1.x""" + + def __init__(self, + modal_configs=None, + training_modals=None, + pretrain_img_size=224, + patch_size=4, + embed_dims=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + patch_norm=True, + out_indices=[0, 1, 2, 3], + use_moe=True, + use_modal_bias=True, + num_experts=8, + num_shared_experts_config=None, + top_k=2, + noisy_gating=True, + MoE_Block_inds=None, + use_expert_diversity_loss=False, + use_modal_specific_stem=True, + frozen_stages=-1, + freeze_patch_embed=False, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + self.frozen_stages = frozen_stages + self.freeze_patch_embed = freeze_patch_embed + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dims = embed_dims + self.patch_norm = patch_norm + self.out_indices = out_indices + self.depths = depths + self.num_heads = num_heads + self.use_moe = use_moe + self.use_expert_diversity_loss = use_expert_diversity_loss + + self.modal_configs = modal_configs + if modal_configs is not None: + if training_modals is None: + self.training_modals = list(modal_configs.keys()) + else: + self.training_modals = training_modals + else: + self.training_modals = [] + + if MoE_Block_inds is None: + MoE_Block_inds = [] + for depth in depths: + start_idx = depth // 2 + MoE_Block_inds.append(list(range(start_idx, depth))) + self.MoE_Block_inds = MoE_Block_inds + + if num_shared_experts_config is None: + num_shared_experts_config = {0: 0, 1: 0, 2: 2, 3: 1} + self.num_shared_experts_config = num_shared_experts_config + + self.use_modal_specific_stem = use_modal_specific_stem + if use_modal_specific_stem: + self.patch_embed = ModalSpecificPatchEmbed( + modal_configs=modal_configs if modal_configs else {}, + training_modals=self.training_modals, + embed_dims=embed_dims, + patch_size=patch_size, + norm_cfg=dict(type='LN') if patch_norm else None + ) + else: + self.patch_embed = UnifiedPatchEmbed( + modal_configs=modal_configs if modal_configs else {}, + training_modals=self.training_modals, + embed_dims=embed_dims, + patch_size=patch_size, + norm_cfg=dict(type='LN') if patch_norm else None + ) + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace( + 0, drop_path_rate, sum(depths))] + + self.stages = nn.ModuleList() + num_features = [int(embed_dims * 2 ** i) + for i in range(self.num_layers)] + + for i_stage in range(self.num_layers): + stage_blocks = [] + dim = num_features[i_stage] + num_shared = num_shared_experts_config.get(i_stage, 0) + + for i_block in range(depths[i_stage]): + use_moe_block = (use_moe + and (i_block in MoE_Block_inds[i_stage])) + + block = SwinBlockWithMoE( + dim=dim, + num_heads=num_heads[i_stage], + window_size=window_size, + shift_size=(0 if (i_block % 2 == 0) + else window_size // 2), + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_stage]) + i_block], + norm_layer=norm_layer, + use_moe=use_moe_block, + num_experts=num_experts, + num_shared_experts=(num_shared + if use_moe_block else 0), + top_k=top_k, + noisy_gating=noisy_gating, + modal_configs=(modal_configs + if use_moe_block else None), + training_modals=(self.training_modals + if use_moe_block else None), + use_modal_bias=use_modal_bias, + return_expert_outputs=use_expert_diversity_loss + ) + stage_blocks.append(block) + + if i_stage < self.num_layers - 1: + downsample = PatchMerging(dim=dim, norm_layer=norm_layer) + else: + downsample = None + + self.stages.append(nn.ModuleDict({ + 'blocks': nn.ModuleList(stage_blocks), + 'downsample': downsample + })) + + for i in out_indices: + layer = norm_layer(num_features[i]) + layer_name = f'norm{i}' + self.add_module(layer_name, layer) + + self.num_features = num_features + self.pretrained = pretrained + self._init_weights() + self._freeze_stages() + self._print_architecture_info() + + def _freeze_stages(self): + """Freeze stages to stop gradient and set eval mode. + + frozen_stages (int): Stages to freeze. -1 means no freezing. + 0 freezes stage 0, 1 freezes stages 0-1, etc. + freeze_patch_embed (bool): If True, also freeze patch_embed (stem). + Default False (stem remains trainable for domain adaptation). + """ + if self.freeze_patch_embed: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(0, self.frozen_stages + 1): + if i >= self.num_layers: + break + # Freeze stage blocks + stage = self.stages[i] + stage.eval() + for param in stage.parameters(): + param.requires_grad = False + + # Freeze corresponding output norm layer + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 0: + frozen_params = sum( + 1 for p in self.parameters() if not p.requires_grad) + total_params = sum(1 for _ in self.parameters()) + print(f'\n[Freeze] frozen_stages={self.frozen_stages}, ' + f'freeze_patch_embed={self.freeze_patch_embed}') + print(f'[Freeze] {frozen_params}/{total_params} params frozen\n') + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def init_weights(self): + if self.pretrained is not None: + print(f'Loading pretrained Swin weights from {self.pretrained}') + load_checkpoint( + self, self.pretrained, + map_location='cpu', + strict=False + ) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + state_dict = super().state_dict(destination, prefix, keep_vars) + keys_to_remove = [k for k in state_dict.keys() if 'attn_mask' in k] + for k in keys_to_remove: + state_dict.pop(k) + return state_dict + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + attn_mask_keys = [k for k in state_dict.keys() if 'attn_mask' in k] + for k in attn_mask_keys: + state_dict.pop(k) + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs + ) + + def _print_architecture_info(self): + print(f"\n{'=' * 80}") + print(f"Multi-Modal Swin Transformer with MoE") + print(f"{'=' * 80}") + print(f"Training modals: {self.training_modals}") + print(f"Embed dims: {self.embed_dims}") + print(f"Depths: {self.depths}") + print(f"Use MoE: {self.use_moe}") + + print(f"\nShared Experts:") + for i in range(self.num_layers): + num_moe = len(self.MoE_Block_inds[i]) + num_shared = self.num_shared_experts_config.get(i, 0) + print(f" Stage {i}: {self.depths[i]} blocks, " + f"{num_moe} MoE blocks, {num_shared} shared experts") + + print(f"\nOutput indices: {self.out_indices}") + print(f"Output features: " + f"{[self.num_features[i] for i in self.out_indices]}") + print(f"{'=' * 80}\n") + + def train(self, mode=True): + """Override train to keep frozen stages in eval mode.""" + super().train(mode) + self._freeze_stages() + + def forward(self, imgs, modal_types=None, **kwargs): + """ + Args: + imgs: List[Tensor] - each [C_i, H, W] + modal_types: List[str] + Returns: + tuple of features, moe_balance_loss, expert_features + """ + B = len(imgs) + + if modal_types is None: + modal_types = ['rgb'] * B + + x, hw_shape = self.patch_embed(imgs, modal_types) + x = self.pos_drop(x) + + outs = [] + moe_balance_losses = [] + all_expert_features = [] + + for i_stage, stage_dict in enumerate(self.stages): + blocks = stage_dict['blocks'] + downsample = stage_dict['downsample'] + + for block in blocks: + x, moe_loss, expert_features = block( + x, hw_shape, modal_types) + + if moe_loss is not None: + moe_balance_losses.append(moe_loss) + + if (expert_features is not None + and self.use_expert_diversity_loss + and i_stage == self.num_layers - 1): + all_expert_features.extend(expert_features) + + if i_stage in self.out_indices: + norm_layer = getattr(self, f'norm{i_stage}') + x_out = norm_layer(x) + + H, W = hw_shape + x_out = x_out.view(B, H, W, -1).permute( + 0, 3, 1, 2).contiguous() + outs.append(x_out) + + if downsample is not None: + x, hw_shape = downsample(x, hw_shape) + + avg_balance_loss = None + if len(moe_balance_losses) > 0: + avg_balance_loss = (sum(moe_balance_losses) + / len(moe_balance_losses)) + + return tuple(outs), avg_balance_loss, all_expert_features diff --git a/mmseg/models/losses/__init__.py b/mmseg/models/losses/__init__.py index 0467cb3ad8..564c95039c 100644 --- a/mmseg/models/losses/__init__.py +++ b/mmseg/models/losses/__init__.py @@ -10,6 +10,7 @@ from .ohem_cross_entropy_loss import OhemCrossEntropy from .silog_loss import SiLogLoss from .tversky_loss import TverskyLoss +from .moe_losses import CombinedMoELoss, ExpertDiversityLoss, MoEBalanceLoss from .utils import reduce_loss, weight_reduce_loss, weighted_loss __all__ = [ @@ -17,5 +18,6 @@ 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss', 'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss', - 'HuasdorffDisstanceLoss', 'SiLogLoss' + 'HuasdorffDisstanceLoss', 'SiLogLoss', + 'MoEBalanceLoss', 'ExpertDiversityLoss', 'CombinedMoELoss' ] diff --git a/mmseg/models/losses/moe_losses.py b/mmseg/models/losses/moe_losses.py new file mode 100644 index 0000000000..6901e3aec8 --- /dev/null +++ b/mmseg/models/losses/moe_losses.py @@ -0,0 +1,110 @@ +""" +MoE-related loss functions - MMSeg 1.x Version +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MoEBalanceLoss(nn.Module): + """MoE load balancing loss wrapper""" + + def __init__(self, loss_weight=1.0): + super().__init__() + self.loss_weight = loss_weight + + def forward(self, balance_loss): + if balance_loss is None: + return torch.tensor(0.0) + return self.loss_weight * balance_loss + + +class ExpertDiversityLoss(nn.Module): + """Expert diversity loss - encourages different expert representations""" + + def __init__(self, + loss_weight=0.01, + similarity_type='cosine', + normalize=True): + super().__init__() + self.loss_weight = loss_weight + self.similarity_type = similarity_type + self.normalize = normalize + + def forward(self, expert_features_list): + if expert_features_list is None or len(expert_features_list) < 2: + return torch.tensor( + 0.0, + device=(expert_features_list[0].device + if expert_features_list else None)) + + expert_vectors = [] + for expert_feat in expert_features_list: + pooled = F.adaptive_avg_pool2d( + expert_feat, 1).squeeze(-1).squeeze(-1) + + if self.normalize: + pooled = F.normalize(pooled, dim=1) + + n_samples = pooled.shape[0] + expert_vector = pooled.sum(dim=0) / max(n_samples, 1) + expert_vectors.append(expert_vector) + + num_experts = len(expert_vectors) + total_sim = 0.0 + count = 0 + + for i in range(num_experts): + for j in range(i + 1, num_experts): + if self.similarity_type == 'cosine': + sim = F.cosine_similarity( + expert_vectors[i].unsqueeze(0), + expert_vectors[j].unsqueeze(0), + dim=1 + ) + elif self.similarity_type == 'l2': + dist = torch.norm( + expert_vectors[i] - expert_vectors[j], p=2) + sim = torch.clamp(1.0 - dist / 2.0, min=0.0, max=1.0) + else: + raise ValueError( + f"Unknown similarity type: {self.similarity_type}") + + total_sim += sim + count += 1 + + if count > 0: + avg_sim = total_sim / count + else: + avg_sim = torch.tensor(0.0, device=expert_vectors[0].device) + + diversity_loss = F.relu(avg_sim) + return self.loss_weight * diversity_loss + + +class CombinedMoELoss(nn.Module): + """Combined MoE loss: balance + diversity""" + + def __init__(self, + balance_weight=1.0, + diversity_weight=0.01, + diversity_similarity='cosine'): + super().__init__() + self.balance_loss = MoEBalanceLoss(loss_weight=balance_weight) + self.diversity_loss = ExpertDiversityLoss( + loss_weight=diversity_weight, + similarity_type=diversity_similarity + ) + + def forward(self, balance_loss, expert_features_list): + loss_balance = self.balance_loss(balance_loss) + loss_diversity = self.diversity_loss(expert_features_list) + + total_loss = loss_balance + loss_diversity + + loss_dict = { + 'loss_moe_balance': loss_balance, + 'loss_moe_diversity': loss_diversity + } + + return total_loss, loss_dict diff --git a/mmseg/models/multimodal_data_preprocessor.py b/mmseg/models/multimodal_data_preprocessor.py new file mode 100644 index 0000000000..798efda31c --- /dev/null +++ b/mmseg/models/multimodal_data_preprocessor.py @@ -0,0 +1,83 @@ +""" +Multi-Modal Data PreProcessor - handles images with different channel counts. + +Unlike SegDataPreProcessor which stacks all images into a [B,C,H,W] tensor, +this preprocessor keeps them as a list of [C_i,H,W] tensors since different +modalities have different channel counts (e.g., SAR:8ch, RGB:3ch, GF:5ch). +""" +from numbers import Number +from typing import Any, Dict, List, Optional, Sequence + +import torch +import torch.nn.functional as F +from mmengine.model import BaseDataPreprocessor + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class MultiModalDataPreProcessor(BaseDataPreprocessor): + """Data pre-processor for multi-modal segmentation. + + Unlike SegDataPreProcessor, this does NOT stack inputs into a single + tensor because different modalities may have different channel counts. + Instead, it returns inputs as a list of tensors. + + Args: + size (tuple, optional): Fixed padding size (H, W). + pad_val (float): Padding value for images. Default: 0. + seg_pad_val (float): Padding value for segmentation maps. Default: 255. + """ + + def __init__( + self, + size: Optional[tuple] = None, + pad_val: Number = 0, + seg_pad_val: Number = 255, + ): + super().__init__() + self.size = size + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + + def forward(self, data: dict, training: bool = False) -> Dict[str, Any]: + data = self.cast_data(data) + inputs = data['inputs'] + data_samples = data.get('data_samples', None) + + inputs = [_input.float() for _input in inputs] + + # Pad spatial dimensions only (no stacking across channels) + padded_inputs = [] + for i, tensor in enumerate(inputs): + if self.size is not None: + width = max(self.size[-1] - tensor.shape[-1], 0) + height = max(self.size[-2] - tensor.shape[-2], 0) + padding_size = (0, width, 0, height) + else: + padding_size = (0, 0, 0, 0) + + pad_img = F.pad(tensor, padding_size, value=self.pad_val) + padded_inputs.append(pad_img) + + if data_samples is not None: + ds = data_samples[i] + if 'gt_sem_seg' in ds: + gt = ds.gt_sem_seg.data + del ds.gt_sem_seg.data + ds.gt_sem_seg.data = F.pad( + gt, padding_size, value=self.seg_pad_val) + if hasattr(ds, 'gt_boundary_map') and 'gt_boundary_map' in ds: + gt_b = ds.gt_boundary_map.data + del ds.gt_boundary_map.data + ds.gt_boundary_map.data = F.pad( + gt_b, padding_size, value=self.seg_pad_val) + + ds.set_metainfo({ + 'img_shape': tensor.shape[-2:], + 'pad_shape': pad_img.shape[-2:], + 'padding_size': padding_size + }) + + # Return list (NOT stacked tensor) — the model handles the list + return dict(inputs=padded_inputs, data_samples=data_samples) diff --git a/mmseg/models/segmentors/__init__.py b/mmseg/models/segmentors/__init__.py index 59b012f417..bb253edf27 100644 --- a/mmseg/models/segmentors/__init__.py +++ b/mmseg/models/segmentors/__init__.py @@ -5,8 +5,9 @@ from .encoder_decoder import EncoderDecoder from .multimodal_encoder_decoder import MultimodalEncoderDecoder from .seg_tta import SegTTAModel +from .multimodal_encoder_decoder_v2 import MultiModalEncoderDecoderV2 __all__ = [ 'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel', - 'MultimodalEncoderDecoder', 'DepthEstimator' + 'MultimodalEncoderDecoder', 'DepthEstimator', 'MultiModalEncoderDecoderV2' ] diff --git a/mmseg/models/segmentors/multimodal_encoder_decoder_v2.py b/mmseg/models/segmentors/multimodal_encoder_decoder_v2.py new file mode 100644 index 0000000000..d6a67176ef --- /dev/null +++ b/mmseg/models/segmentors/multimodal_encoder_decoder_v2.py @@ -0,0 +1,588 @@ +""" +Multi-Modal EncoderDecoder - MMSeg 1.x Version + +Key changes from 0.x: +- Registry: SEGMENTORS -> MODELS +- Base class: 0.x EncoderDecoder -> 1.x EncoderDecoder +- forward_train(img, img_metas, gt) -> loss(inputs, data_samples) +- simple_test/forward_test -> predict(inputs, data_samples) +- img_metas dict -> SegDataSample.metainfo +- gt_semantic_seg -> data_samples[i].gt_sem_seg.data +- train_step removed (handled by mmengine Runner) +- _parse_losses removed (handled by mmengine) +- DataContainer removed +""" +from collections import OrderedDict +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from mmseg.registry import MODELS +from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig, + OptSampleList, SampleList, add_prefix) +from .encoder_decoder import EncoderDecoder +from ..losses.moe_losses import CombinedMoELoss + + +@MODELS.register_module() +class MultiModalEncoderDecoderV2(EncoderDecoder): + """Multi-modal encoder-decoder segmentor with MoE support. + + Supports shared/separate decode heads per modality. + + Args: + use_moe: Enable MoE in backbone. + use_modal_bias: Enable modal bias in MoE gating. + moe_balance_weight: Weight for MoE balance loss. + moe_diversity_weight: Weight for MoE diversity loss. + multi_tasks_reweight: Reweighting strategy + ('equal', 'uncertainty', None). + dataset_names: List of dataset/modality names. + decoder_mode: 'shared' or 'separate'. + """ + + def __init__(self, + backbone: ConfigType, + decode_head: ConfigType, + neck: OptConfigType = None, + auxiliary_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + pretrained: Optional[str] = None, + init_cfg: OptMultiConfig = None, + # Custom args + use_moe: bool = False, + use_modal_bias: bool = True, + moe_balance_weight: float = 1.0, + moe_diversity_weight: float = 0.01, + moe_diversity_similarity: str = 'cosine', + multi_tasks_reweight: Optional[str] = None, + mtl_sigma_init: float = 1.0, + dataset_names: list = None, + decoder_mode: str = 'separate'): + self.dataset_names = dataset_names or ['sar', 'rgb', 'GF'] + self.decoder_mode = decoder_mode + self._use_moe = use_moe + self._use_modal_bias = use_modal_bias + self._multi_tasks_reweight = multi_tasks_reweight + self._mtl_sigma_init = mtl_sigma_init + + if decoder_mode not in ['shared', 'separate']: + raise ValueError( + f"decoder_mode must be 'shared' or 'separate', " + f"got '{decoder_mode}'") + + # Call grandparent __init__ to skip EncoderDecoder's + # _init_decode_head / _init_auxiliary_head + # We need custom initialization for multi-head support + from .base import BaseSegmentor + BaseSegmentor.__init__( + self, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + if pretrained is not None: + assert backbone.get('pretrained') is None + backbone.pretrained = pretrained + + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + + self._init_decode_head(decode_head) + self._init_auxiliary_head(auxiliary_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + # MoE loss + if use_moe: + self.moe_loss_fn = CombinedMoELoss( + balance_weight=moe_balance_weight, + diversity_weight=moe_diversity_weight, + diversity_similarity=moe_diversity_similarity + ) + + # Multi-task reweighting + self.multi_tasks_reweight = multi_tasks_reweight + self.mtl_sigma_eps = 1e-6 + if multi_tasks_reweight == 'uncertainty': + self.mtl_sigma = nn.Parameter( + torch.full((len(self.dataset_names),), + float(mtl_sigma_init))) + + def _init_decode_head(self, decode_head): + if decode_head is None: + return + + if self.decoder_mode == 'shared': + if isinstance(decode_head, dict) and 'type' not in decode_head: + decode_head = decode_head[self.dataset_names[0]] + self._shared_decode_head = MODELS.build(decode_head) + self.align_corners = self._shared_decode_head.align_corners + self.num_classes = self._shared_decode_head.num_classes + self.out_channels = self._shared_decode_head.out_channels + else: + self.decode_heads = nn.ModuleDict() + if isinstance(decode_head, dict) and 'type' not in decode_head: + for name in self.dataset_names: + self.decode_heads[name] = MODELS.build( + decode_head[name]) + else: + for name in self.dataset_names: + self.decode_heads[name] = MODELS.build(decode_head) + + first_head = self.decode_heads[self.dataset_names[0]] + self.align_corners = first_head.align_corners + self.num_classes = first_head.num_classes + self.out_channels = first_head.out_channels + + def _init_auxiliary_head(self, auxiliary_head): + if auxiliary_head is None: + return + + if self.decoder_mode == 'shared': + if isinstance(auxiliary_head, dict) and 'type' not in auxiliary_head: + auxiliary_head = auxiliary_head[self.dataset_names[0]] + self._shared_auxiliary_head = MODELS.build(auxiliary_head) + else: + self.auxiliary_heads = nn.ModuleDict() + if isinstance(auxiliary_head, dict) and 'type' not in auxiliary_head: + for name in self.dataset_names: + self.auxiliary_heads[name] = MODELS.build( + auxiliary_head[name]) + else: + for name in self.dataset_names: + self.auxiliary_heads[name] = MODELS.build( + auxiliary_head) + + @property + def decode_head(self): + if self.decoder_mode == 'shared': + return getattr(self, '_shared_decode_head', None) + else: + if hasattr(self, 'decode_heads') and len(self.decode_heads) > 0: + return self.decode_heads[self.dataset_names[0]] + return None + + @property + def auxiliary_head(self): + if self.decoder_mode == 'shared': + return getattr(self, '_shared_auxiliary_head', None) + else: + if (hasattr(self, 'auxiliary_heads') + and len(self.auxiliary_heads) > 0): + return self.auxiliary_heads[self.dataset_names[0]] + return None + + @property + def with_decode_head(self): + if self.decoder_mode == 'shared': + return hasattr(self, '_shared_decode_head') + else: + return (hasattr(self, 'decode_heads') + and len(self.decode_heads) > 0) + + @property + def with_auxiliary_head(self): + if self.decoder_mode == 'shared': + return hasattr(self, '_shared_auxiliary_head') + else: + return (hasattr(self, 'auxiliary_heads') + and len(self.auxiliary_heads) > 0) + + def loss(self, inputs, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs: Input images (Tensor or List[Tensor]). + data_samples: list[SegDataSample] with metainfo and gt_sem_seg. + + Returns: + dict[str, Tensor]: Loss components. + """ + # Extract modal types from data_samples + modal_types = [ + ds.metainfo.get('modal_type', 'rgb') for ds in data_samples + ] + + # Extract dataset names + ds_names = [ + ds.metainfo.get('dataset_name', + ds.metainfo.get('modal_type', 'rgb')) + for ds in data_samples + ] + + # For multi-modal backbone: pass List[Tensor] + modal_types + if isinstance(inputs, torch.Tensor): + # Standard tensor - pass directly + # But backbone expects List[Tensor] for multi-modal + imgs_list = list(inputs) + else: + imgs_list = inputs + + x, moe_balance_loss, expert_features = self.extract_feat( + imgs_list, modal_types=modal_types) + + losses = dict() + + # MoE losses + if self._use_moe and moe_balance_loss is not None: + if expert_features is None: + losses['loss_moe'] = moe_balance_loss + else: + total_moe_loss, moe_loss_dict = self.moe_loss_fn( + moe_balance_loss, expert_features) + losses.update(moe_loss_dict) + + modal_losses = {} + + if self.decoder_mode == 'shared': + self._loss_shared_mode( + x, data_samples, ds_names, losses, modal_losses) + else: + self._loss_separate_mode( + x, data_samples, ds_names, losses, modal_losses) + + # Apply multi-task reweighting + if self.multi_tasks_reweight == 'equal' and modal_losses: + losses.update(self._apply_equal_reweight(modal_losses)) + elif self.multi_tasks_reweight == 'uncertainty' and modal_losses: + losses.update( + self._apply_uncertainty_reweight(modal_losses)) + + return losses + + def _loss_shared_mode(self, x, data_samples, ds_names, + losses, modal_losses): + """Compute loss in shared decode head mode.""" + if self.multi_tasks_reweight is not None: + for dataset_name in sorted(set(ds_names)): + mask = [n == dataset_name for n in ds_names] + if not any(mask): + continue + + # Filter data_samples for this modality + modal_ds = [ds for ds, m in zip(data_samples, mask) if m] + # Filter features + modal_x = tuple( + feat[torch.tensor(mask)] for feat in x) + + loss_decode = self._shared_decode_head.loss( + modal_x, modal_ds, self.train_cfg) + losses.update(add_prefix( + loss_decode, f'decode_{dataset_name}')) + + modal_loss = self._sum_loss_dict(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._shared_auxiliary_head.loss( + modal_x, modal_ds, self.train_cfg) + losses.update(add_prefix( + loss_aux, f'aux_{dataset_name}')) + modal_loss = modal_loss + self._sum_loss_dict(loss_aux) + + modal_losses[dataset_name] = modal_loss + else: + loss_decode = self._shared_decode_head.loss( + x, data_samples, self.train_cfg) + losses.update(add_prefix(loss_decode, 'decode')) + + if self.with_auxiliary_head: + loss_aux = self._shared_auxiliary_head.loss( + x, data_samples, self.train_cfg) + losses.update(add_prefix(loss_aux, 'aux')) + + def _loss_separate_mode(self, x, data_samples, ds_names, + losses, modal_losses): + """Compute loss in separate decode head mode.""" + for dataset_name in sorted(set(ds_names)): + if dataset_name not in self.decode_heads: + continue + + mask = [n == dataset_name for n in ds_names] + if not any(mask): + continue + + modal_ds = [ds for ds, m in zip(data_samples, mask) if m] + modal_x = tuple( + feat[torch.tensor(mask)] for feat in x) + + decode_head = self.decode_heads[dataset_name] + loss_decode = decode_head.loss( + modal_x, modal_ds, self.train_cfg) + losses.update(add_prefix( + loss_decode, f'decode_{dataset_name}')) + + modal_loss = self._sum_loss_dict(loss_decode) + + if self.with_auxiliary_head: + aux_head = self.auxiliary_heads[dataset_name] + loss_aux = aux_head.loss( + modal_x, modal_ds, self.train_cfg) + losses.update(add_prefix( + loss_aux, f'aux_{dataset_name}')) + modal_loss = modal_loss + self._sum_loss_dict(loss_aux) + + modal_losses[dataset_name] = modal_loss + + def extract_feat(self, inputs, modal_types=None): + """Extract features, supporting multi-modal input. + + Args: + inputs: List[Tensor] or Tensor + modal_types: List[str] or None + + Returns: + x, moe_balance_loss, expert_features + """ + # Check if backbone supports modal_types + if modal_types is not None and hasattr(self.backbone, 'forward'): + import inspect + sig = inspect.signature(self.backbone.forward) + if 'modal_types' in sig.parameters: + backbone_output = self.backbone( + inputs, modal_types=modal_types) + else: + backbone_output = self.backbone(inputs) + else: + backbone_output = self.backbone(inputs) + + moe_balance_loss = None + expert_features = None + + if isinstance(backbone_output, tuple): + if len(backbone_output) == 3: + x, moe_balance_loss, expert_features = backbone_output + elif len(backbone_output) == 2: + if (isinstance(backbone_output[1], torch.Tensor) + and backbone_output[1].dim() == 0): + x, moe_balance_loss = backbone_output + else: + x = backbone_output + else: + x = backbone_output + else: + x = backbone_output + if not isinstance(x, tuple): + x = (x,) + + if self.with_neck: + x = self.neck(x) + + return x, moe_balance_loss, expert_features + + def encode_decode(self, inputs, batch_img_metas: List[dict]): + """Encode images and decode into segmentation map.""" + modal_types = [ + meta.get('modal_type', 'rgb') for meta in batch_img_metas + ] + + if isinstance(inputs, torch.Tensor): + imgs_list = list(inputs) + else: + imgs_list = inputs + + x, _, _ = self.extract_feat(imgs_list, modal_types=modal_types) + + # Select appropriate decode head + if self.decoder_mode == 'shared': + decode_head = self._shared_decode_head + else: + dataset_name = self._get_dataset_name_from_metas( + batch_img_metas) + if dataset_name in self.decode_heads: + decode_head = self.decode_heads[dataset_name] + else: + decode_head = self.decode_head + + seg_logits = decode_head.predict( + x, batch_img_metas, self.test_cfg) + + return seg_logits + + def slide_inference(self, inputs, batch_img_metas: List[dict]): + """Inference by sliding-window with overlap. + + Overrides base class to handle list inputs from multimodal pipeline. + When inputs have different channel sizes (e.g. mixed modalities), + each sample is processed individually to avoid stack errors. + """ + # Handle list inputs - check if channels are uniform + if isinstance(inputs, (list, tuple)): + channel_sizes = [t.shape[0] for t in inputs] + if len(set(channel_sizes)) > 1: + # Mixed channels: process each sample individually + return self._slide_inference_per_sample( + inputs, batch_img_metas) + else: + inputs = torch.stack(inputs, dim=0) + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = inputs.size() + out_channels = self.out_channels + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img)) + count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = inputs[:, :, y1:y2, x1:x2] + batch_img_metas[0]['img_shape'] = crop_img.shape[2:] + crop_seg_logit = self.encode_decode(crop_img, batch_img_metas) + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + seg_logits = preds / count_mat + return seg_logits + + def _slide_inference_per_sample(self, inputs, batch_img_metas): + """Slide inference processing each sample individually. + + Used when batch contains mixed modalities with different channels. + """ + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + out_channels = self.out_channels + + seg_logits_list = [] + for i, (img, img_meta) in enumerate(zip(inputs, batch_img_metas)): + # img shape: (C, H, W) -> (1, C, H, W) + img = img.unsqueeze(0) + _, _, h_img, w_img = img.size() + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((1, out_channels, h_img, w_img)) + count_mat = img.new_zeros((1, 1, h_img, w_img)) + meta_list = [img_meta] + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + meta_list[0]['img_shape'] = crop_img.shape[2:] + crop_seg_logit = self.encode_decode( + crop_img, meta_list) + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), + int(y1), int(preds.shape[2] - y2))) + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + seg_logits_list.append(preds / count_mat) + + return torch.cat(seg_logits_list, dim=0) + + def predict(self, inputs, data_samples: OptSampleList = None): + """Predict results from inputs and data samples.""" + if data_samples is not None: + batch_img_metas = [ + data_sample.metainfo for data_sample in data_samples + ] + else: + batch_img_metas = [ + dict( + ori_shape=inputs.shape[2:], + img_shape=inputs.shape[2:], + pad_shape=inputs.shape[2:], + padding_size=[0, 0, 0, 0]) + ] * inputs.shape[0] + + seg_logits = self.inference(inputs, batch_img_metas) + + return self.postprocess_result(seg_logits, data_samples) + + def _forward(self, inputs, data_samples: OptSampleList = None): + """Network forward process (tensor mode).""" + modal_types = None + if data_samples is not None: + modal_types = [ + ds.metainfo.get('modal_type', 'rgb') + for ds in data_samples + ] + + if isinstance(inputs, torch.Tensor): + imgs_list = list(inputs) + else: + imgs_list = inputs + + x, _, _ = self.extract_feat(imgs_list, modal_types=modal_types) + return self.decode_head.forward(x) + + def _get_dataset_name_from_metas(self, batch_img_metas): + if batch_img_metas and len(batch_img_metas) > 0: + first_meta = batch_img_metas[0] + if 'dataset_name' in first_meta: + return first_meta['dataset_name'] + elif 'modal_type' in first_meta: + return first_meta['modal_type'] + return self.dataset_names[0] + + @staticmethod + def _sum_loss_dict(loss_dict): + total_loss = None + for loss_name, loss_value in loss_dict.items(): + if 'loss' not in loss_name: + continue + if isinstance(loss_value, torch.Tensor): + current = loss_value + elif isinstance(loss_value, list): + current = sum(loss for loss in loss_value) + else: + raise TypeError( + f'{loss_name} is not a tensor or list of tensors') + total_loss = (current if total_loss is None + else total_loss + current) + if total_loss is None: + total_loss = torch.tensor(0.0) + return total_loss + + def _apply_equal_reweight(self, modal_losses): + if not modal_losses: + return {} + n_modals = len(modal_losses) + total = sum(modal_losses.values()) / n_modals + reweight_losses = { + 'equal_reweighted_total_loss': total, + } + for name in modal_losses: + reweight_losses[f'equal_weight_{name}'] = torch.tensor( + 1.0 / n_modals, device=total.device) + return reweight_losses + + def _apply_uncertainty_reweight(self, modal_losses): + loss_sum = None + reweight_losses = {} + + for idx, dataset_name in enumerate(self.dataset_names): + if dataset_name not in modal_losses: + continue + loss = modal_losses[dataset_name] + sigma_sq = self.mtl_sigma[idx] ** 2 + self.mtl_sigma_eps + weighted = 0.5 / sigma_sq * loss + torch.log1p(sigma_sq) + loss_sum = (weighted if loss_sum is None + else loss_sum + weighted) + reweight_losses[f'mtl_sigma_{dataset_name}'] = ( + sigma_sq.detach()) + + if loss_sum is None: + return {} + + reweight_losses['reweighted_total_losses'] = loss_sum + return reweight_losses diff --git a/scripts/collect_results_to_excel.py b/scripts/collect_results_to_excel.py new file mode 100644 index 0000000000..aae23d666f --- /dev/null +++ b/scripts/collect_results_to_excel.py @@ -0,0 +1,536 @@ +""" +Collect experiment results from test logs and export to Excel. + +Parses metrics_log.txt / test_log.txt from all Table 2/3/4 experiments +and the Full Model, then generates one Excel file per table with +per-modality metrics (aAcc, mIoU, mAcc, mDice, mFscore, mPrecision, mRecall) +plus per-class breakdowns. + +Usage: + python scripts/collect_results_to_excel.py + python scripts/collect_results_to_excel.py --work-root work_dirs/paper_experiments + python scripts/collect_results_to_excel.py --tables table2 table3 table4 + +Output: + work_dirs/paper_experiments/Table2_results.xlsx + work_dirs/paper_experiments/Table3_results.xlsx + work_dirs/paper_experiments/Table4_results.xlsx +""" + +import argparse +import os +import re +import glob +from collections import OrderedDict + +try: + import openpyxl + from openpyxl.styles import Font, Alignment, PatternFill, Border, Side + from openpyxl.utils import get_column_letter + HAS_OPENPYXL = True +except ImportError: + HAS_OPENPYXL = False + +try: + import pandas as pd + HAS_PANDAS = True +except ImportError: + HAS_PANDAS = False + + +# ============================================================================ +# Log Parsing +# ============================================================================ + +def parse_summary_metrics(log_text): + """Parse summary metrics from mmengine log output. + + Looks for lines like: + 04/02 02:10:00 - mmengine - INFO - Iter(test) ... aAcc: 82.45 mIoU: 65.32 ... + or the final dict output. + """ + metrics = {} + + # Pattern 1: mmengine log line with metrics + # e.g., "aAcc: 82.45 mIoU: 65.32 mAcc: 72.18 mDice: 75.00 mFscore: 73.50 mPrecision: 71.20 mRecall: 75.80" + metric_pattern = r'(aAcc|mIoU|mAcc|mDice|mFscore|mPrecision|mRecall)\s*:\s*([\d.]+)' + matches = re.findall(metric_pattern, log_text) + for key, val in matches: + # Take the last occurrence (in case of multiple test runs) + metrics[key] = float(val) + + return metrics + + +def parse_per_class_table(log_text): + """Parse per-class PrettyTable from log. + + Looks for: + per class results: + +-------+-------+-------+... + | Class | IoU | Acc |... + +-------+-------+-------+... + | cls1 | 45.23 | 67.89 |... + | cls2 | 52.11 | 71.34 |... + +-------+-------+-------+... + """ + per_class = {} + + # Find the table after "per class results:" + marker = 'per class results:' + idx = log_text.rfind(marker) # Use last occurrence + if idx == -1: + return per_class + + table_text = log_text[idx:] + lines = table_text.split('\n') + + header = None + for line in lines: + line = line.strip() + if not line or line.startswith('+'): + continue + if line.startswith('|'): + cells = [c.strip() for c in line.split('|') if c.strip()] + if header is None: + header = cells + else: + if len(cells) == len(header): + class_name = cells[0] + per_class[class_name] = {} + for i, h in enumerate(header[1:], 1): + try: + per_class[class_name][h] = float(cells[i]) + except ValueError: + per_class[class_name][h] = cells[i] + + return per_class + + +def parse_log_file(log_path): + """Parse a single log file and return metrics dict.""" + if not os.path.exists(log_path): + return None, None + + with open(log_path, 'r') as f: + text = f.read() + + summary = parse_summary_metrics(text) + per_class = parse_per_class_table(text) + + return summary, per_class + + +def find_log_file(work_dir, modal, prefer_metrics=True): + """Find the log file for a given experiment and modality.""" + candidates = [] + if prefer_metrics: + candidates.append(os.path.join(work_dir, f'metrics_{modal}', 'metrics_log.txt')) + candidates.append(os.path.join(work_dir, f'test_{modal}', 'test_log.txt')) + # Also try metrics log as fallback + candidates.append(os.path.join(work_dir, f'metrics_{modal}', 'metrics_log.txt')) + + for path in candidates: + if os.path.exists(path): + return path + + # Last resort: search for any log in the directory + for subdir_pattern in [f'metrics_{modal}', f'test_{modal}']: + subdir = os.path.join(work_dir, subdir_pattern) + if os.path.isdir(subdir): + # Try to find mmengine log files + log_files = glob.glob(os.path.join(subdir, '*.log')) + if log_files: + return sorted(log_files)[-1] + + return None + + +# ============================================================================ +# Experiment Definitions +# ============================================================================ + +TABLE2_EXPERIMENTS = OrderedDict([ + ('Full Model', { + 'work_dir': 'work_dirs/floodnet/SwinmoeB/655', + 'modals': ['sar', 'rgb', 'GF'], + 'desc': 'Swin-B + MoE (E=8, K=3)', + }), + ('w/o MoE', { + 'work_dir': 'work_dirs/paper_experiments/table2/no_moe', + 'modals': ['sar', 'rgb', 'GF'], + 'desc': 'Remove MoE, use standard FFN', + }), + ('w/o ModalSpecificStem', { + 'work_dir': 'work_dirs/paper_experiments/table2/no_modal_specific_stem', + 'modals': ['sar', 'rgb', 'GF'], + 'desc': 'Shared patch embedding', + }), + ('w/o Modal Bias', { + 'work_dir': 'work_dirs/paper_experiments/table2/no_modal_bias', + 'modals': ['sar', 'rgb', 'GF'], + 'desc': 'No modal bias in gating', + }), + ('w/o Shared Experts', { + 'work_dir': 'work_dirs/paper_experiments/table2/no_shared_experts', + 'modals': ['sar', 'rgb', 'GF'], + 'desc': 'No shared experts', + }), + ('w/o Separate Decoder', { + 'work_dir': 'work_dirs/paper_experiments/table2/shared_decoder', + 'modals': ['sar', 'rgb', 'GF'], + 'desc': 'Shared decoder head', + }), +]) + +TABLE3_EXPERIMENTS = OrderedDict([ + ('E=6, K=1', { + 'work_dir': 'work_dirs/paper_experiments/table3/e6_k1', + 'modals': ['sar', 'rgb', 'GF'], + }), + ('E=6, K=2', { + 'work_dir': 'work_dirs/paper_experiments/table3/e6_k2', + 'modals': ['sar', 'rgb', 'GF'], + }), + ('E=6, K=3', { + 'work_dir': 'work_dirs/paper_experiments/table3/e6_k3', + 'modals': ['sar', 'rgb', 'GF'], + }), + ('E=8, K=1', { + 'work_dir': 'work_dirs/paper_experiments/table3/e8_k1', + 'modals': ['sar', 'rgb', 'GF'], + }), + ('E=8, K=2', { + 'work_dir': 'work_dirs/paper_experiments/table3/e8_k2', + 'modals': ['sar', 'rgb', 'GF'], + }), + ('E=8, K=3 (Full)', { + 'work_dir': 'work_dirs/floodnet/SwinmoeB/655', + 'modals': ['sar', 'rgb', 'GF'], + }), +]) + +TABLE4_EXPERIMENTS = OrderedDict([ + ('SAR-only', { + 'work_dir': 'work_dirs/paper_experiments/table4/sar_only', + 'modals': ['sar'], + }), + ('RGB-only', { + 'work_dir': 'work_dirs/paper_experiments/table4/rgb_only', + 'modals': ['rgb'], + }), + ('GF-only', { + 'work_dir': 'work_dirs/paper_experiments/table4/gf_only', + 'modals': ['GF'], + }), + ('Multi-Modal (Full)', { + 'work_dir': 'work_dirs/floodnet/SwinmoeB/655', + 'modals': ['sar', 'rgb', 'GF'], + }), +]) + +SUMMARY_METRICS = ['aAcc', 'mIoU', 'mAcc', 'mDice', 'mFscore', 'mPrecision', 'mRecall'] +MODAL_DISPLAY = {'sar': 'SAR', 'rgb': 'RGB', 'GF': 'GaoFen'} + + +# ============================================================================ +# Excel Generation +# ============================================================================ + +def collect_table_data(experiments, work_root=''): + """Collect all metrics for a table's experiments.""" + results = OrderedDict() + + for exp_name, exp_info in experiments.items(): + work_dir = exp_info['work_dir'] + if work_root and not os.path.isabs(work_dir): + # Don't prepend work_root if work_dir is already an absolute-like path + # that doesn't start with work_root + if not work_dir.startswith('work_dirs/paper_experiments'): + pass # Keep original path (e.g., Full Model) + # work_dir stays as is + + results[exp_name] = {} + for modal in exp_info['modals']: + log_path = find_log_file(work_dir, modal) + if log_path: + summary, per_class = parse_log_file(log_path) + results[exp_name][modal] = { + 'summary': summary or {}, + 'per_class': per_class or {}, + 'log_path': log_path, + } + found_metrics = list((summary or {}).keys()) + print(f" [OK] {exp_name} / {MODAL_DISPLAY.get(modal, modal)}: " + f"{log_path} ({len(found_metrics)} metrics)") + else: + results[exp_name][modal] = { + 'summary': {}, + 'per_class': {}, + 'log_path': None, + } + print(f" [--] {exp_name} / {MODAL_DISPLAY.get(modal, modal)}: " + f"no log found in {work_dir}") + + return results + + +def write_excel(results, experiments, output_path, table_name): + """Write results to Excel with openpyxl for better formatting.""" + wb = openpyxl.Workbook() + + # ---- Style definitions ---- + header_font = Font(bold=True, size=11) + header_fill = PatternFill(start_color='4472C4', end_color='4472C4', fill_type='solid') + header_font_white = Font(bold=True, size=11, color='FFFFFF') + modal_fill = { + 'sar': PatternFill(start_color='E2EFDA', end_color='E2EFDA', fill_type='solid'), + 'rgb': PatternFill(start_color='D6E4F0', end_color='D6E4F0', fill_type='solid'), + 'GF': PatternFill(start_color='FCE4D6', end_color='FCE4D6', fill_type='solid'), + } + exp_fill = PatternFill(start_color='F2F2F2', end_color='F2F2F2', fill_type='solid') + best_font = Font(bold=True, color='C00000') + thin_border = Border( + left=Side(style='thin'), right=Side(style='thin'), + top=Side(style='thin'), bottom=Side(style='thin')) + center_align = Alignment(horizontal='center', vertical='center') + + # ==================== Sheet 1: Summary Table ==================== + ws = wb.active + ws.title = 'Summary' + + # Determine all modals used + all_modals = [] + for exp_info in experiments.values(): + for m in exp_info['modals']: + if m not in all_modals: + all_modals.append(m) + + # Header row 1: Experiment | SAR (spanning) | RGB (spanning) | GF (spanning) + row = 1 + ws.cell(row=row, column=1, value='Experiment').font = header_font_white + ws.cell(row=row, column=1).fill = header_fill + ws.cell(row=row, column=1).border = thin_border + + col = 2 + for modal in all_modals: + modal_name = MODAL_DISPLAY.get(modal, modal) + start_col = col + for metric in SUMMARY_METRICS: + ws.cell(row=row + 1, column=col, value=metric).font = header_font + ws.cell(row=row + 1, column=col).fill = modal_fill.get(modal, exp_fill) + ws.cell(row=row + 1, column=col).border = thin_border + ws.cell(row=row + 1, column=col).alignment = center_align + col += 1 + end_col = col - 1 + # Merge header for modal name + ws.merge_cells(start_row=row, start_column=start_col, + end_row=row, end_column=end_col) + cell = ws.cell(row=row, column=start_col, value=modal_name) + cell.font = header_font_white + cell.fill = header_fill + cell.alignment = center_align + cell.border = thin_border + + ws.merge_cells(start_row=row, start_column=1, end_row=row + 1, end_column=1) + + # Data rows + data_row = row + 2 + # Track best values per (modal, metric) for highlighting + metric_values = {(m, met): [] for m in all_modals for met in SUMMARY_METRICS} + + for exp_name in results: + for modal in all_modals: + if modal in results[exp_name]: + summary = results[exp_name][modal]['summary'] + for met in SUMMARY_METRICS: + val = summary.get(met) + if val is not None: + metric_values[(modal, met)].append((exp_name, val)) + + for exp_name, exp_data in results.items(): + ws.cell(row=data_row, column=1, value=exp_name).font = Font(bold=True) + ws.cell(row=data_row, column=1).border = thin_border + + col = 2 + for modal in all_modals: + if modal in exp_data and exp_data[modal]['summary']: + summary = exp_data[modal]['summary'] + for metric in SUMMARY_METRICS: + val = summary.get(metric) + cell = ws.cell(row=data_row, column=col) + if val is not None: + cell.value = val + cell.number_format = '0.00' + # Bold best value + best_vals = metric_values.get((modal, metric), []) + if best_vals: + best_val = max(v for _, v in best_vals) + if val == best_val and len(best_vals) > 1: + cell.font = best_font + else: + cell.value = '-' + cell.border = thin_border + cell.alignment = center_align + col += 1 + else: + for _ in SUMMARY_METRICS: + cell = ws.cell(row=data_row, column=col, value='-') + cell.border = thin_border + cell.alignment = center_align + col += 1 + + data_row += 1 + + # Auto-width + for col_idx in range(1, col + 1): + ws.column_dimensions[get_column_letter(col_idx)].width = 13 + ws.column_dimensions['A'].width = 25 + + # ==================== Sheet 2: Per-Class Details ==================== + ws2 = wb.create_sheet('Per-Class Details') + row = 1 + + for exp_name, exp_data in results.items(): + for modal in all_modals: + if modal not in exp_data or not exp_data[modal]['per_class']: + continue + + per_class = exp_data[modal]['per_class'] + modal_name = MODAL_DISPLAY.get(modal, modal) + + # Section header + ws2.cell(row=row, column=1, + value=f'{exp_name} - {modal_name}').font = Font(bold=True, size=12) + ws2.cell(row=row, column=1).fill = modal_fill.get(modal, exp_fill) + row += 1 + + # Get metric columns from first class + first_class = list(per_class.keys())[0] + metric_cols = list(per_class[first_class].keys()) + + # Header + ws2.cell(row=row, column=1, value='Class').font = header_font + ws2.cell(row=row, column=1).border = thin_border + for j, met in enumerate(metric_cols, 2): + ws2.cell(row=row, column=j, value=met).font = header_font + ws2.cell(row=row, column=j).border = thin_border + ws2.cell(row=row, column=j).alignment = center_align + row += 1 + + # Data + for cls_name, cls_metrics in per_class.items(): + ws2.cell(row=row, column=1, value=cls_name).border = thin_border + for j, met in enumerate(metric_cols, 2): + cell = ws2.cell(row=row, column=j) + cell.value = cls_metrics.get(met, '-') + cell.border = thin_border + cell.alignment = center_align + if isinstance(cell.value, float): + cell.number_format = '0.00' + row += 1 + + row += 1 # Blank row between sections + + # Auto-width sheet 2 + ws2.column_dimensions['A'].width = 20 + for col_idx in range(2, 15): + ws2.column_dimensions[get_column_letter(col_idx)].width = 13 + + # Save + wb.save(output_path) + print(f'\n => Saved: {output_path}') + + +def write_csv_fallback(results, experiments, output_path, table_name): + """Fallback: write CSV if openpyxl not available.""" + import csv + + all_modals = [] + for exp_info in experiments.values(): + for m in exp_info['modals']: + if m not in all_modals: + all_modals.append(m) + + csv_path = output_path.replace('.xlsx', '.csv') + + with open(csv_path, 'w', newline='') as f: + writer = csv.writer(f) + + # Header + header = ['Experiment'] + for modal in all_modals: + modal_name = MODAL_DISPLAY.get(modal, modal) + for metric in SUMMARY_METRICS: + header.append(f'{modal_name}_{metric}') + writer.writerow(header) + + # Data + for exp_name, exp_data in results.items(): + row = [exp_name] + for modal in all_modals: + if modal in exp_data and exp_data[modal]['summary']: + summary = exp_data[modal]['summary'] + for metric in SUMMARY_METRICS: + val = summary.get(metric) + row.append(f'{val:.2f}' if val is not None else '-') + else: + row.extend(['-'] * len(SUMMARY_METRICS)) + writer.writerow(row) + + print(f'\n => Saved (CSV fallback): {csv_path}') + + +# ============================================================================ +# Main +# ============================================================================ + +def main(): + parser = argparse.ArgumentParser( + description='Collect experiment results into Excel tables') + parser.add_argument('--work-root', default='.', + help='Root directory (default: current dir)') + parser.add_argument('--output-dir', default=None, + help='Output directory (default: work_dirs/paper_experiments/)') + parser.add_argument('--tables', nargs='+', + default=['table2', 'table3', 'table4'], + help='Which tables to collect (default: all)') + args = parser.parse_args() + + os.chdir(args.work_root) + + output_dir = args.output_dir or 'work_dirs/paper_experiments' + os.makedirs(output_dir, exist_ok=True) + + write_fn = write_excel if HAS_OPENPYXL else write_csv_fallback + if not HAS_OPENPYXL: + print('[WARN] openpyxl not installed. Will output CSV instead of Excel.') + print(' Install with: pip install openpyxl') + + table_configs = { + 'table2': ('Table2_ComponentAblation', TABLE2_EXPERIMENTS), + 'table3': ('Table3_MoE_Hyperparameters', TABLE3_EXPERIMENTS), + 'table4': ('Table4_SingleModal_vs_MultiModal', TABLE4_EXPERIMENTS), + } + + for table_key in args.tables: + if table_key not in table_configs: + print(f'[WARN] Unknown table: {table_key}, skipping') + continue + + table_name, experiments = table_configs[table_key] + print(f'\n{"="*60}') + print(f'Collecting {table_name}') + print(f'{"="*60}') + + results = collect_table_data(experiments) + + ext = '.xlsx' if HAS_OPENPYXL else '.csv' + output_path = os.path.join(output_dir, f'{table_name}{ext}') + write_fn(results, experiments, output_path, table_name) + + print(f'\nDone! Results saved to {output_dir}/') + + +if __name__ == '__main__': + main() diff --git a/scripts/run_all_experiments.sh b/scripts/run_all_experiments.sh new file mode 100755 index 0000000000..70aba534f5 --- /dev/null +++ b/scripts/run_all_experiments.sh @@ -0,0 +1,321 @@ +#!/bin/bash +# ============================================================================ +# FloodNet Paper Experiments - Complete Run Script +# Multi-Modal Flood Segmentation with Swin-Base + MoE +# ============================================================================ +# +# NOTE: Full Model (Swin-B + MoE, E=8 K=3) is already trained and tested. +# This script only runs ablation/hyperparameter/single-modal experiments. +# +# Usage: +# bash scripts/run_all_experiments.sh table2 # Component Ablation +# bash scripts/run_all_experiments.sh table3 # MoE Hyperparameter Study +# bash scripts/run_all_experiments.sh table4 # Single-Modal vs Multi-Modal +# bash scripts/run_all_experiments.sh all # Run all tables +# +# Environment Variables: +# GPU_IDS GPU device IDs (default: 0) +# +# Example: +# GPU_IDS=0 bash scripts/run_all_experiments.sh table2 +# +# Each experiment uses --seed 42 for reproducibility. +# After training, each experiment runs per-modality testing automatically. +# ============================================================================ + +set -e + +SEED=42 +GPU_IDS=${GPU_IDS:-0} +CONFIG_DIR="configs/floodnet" +ABLATION_DIR="${CONFIG_DIR}/ablations" +WORK_ROOT="work_dirs/paper_experiments" +RESULTS_LOG="${WORK_ROOT}/results_summary.txt" + +GROUP=${1:-"all"} + +mkdir -p "${WORK_ROOT}" + +# ============================================================================ +# Helper Functions +# ============================================================================ + +run_train() { + local config=$1 + local work_dir=$2 + local desc=$3 + + echo "============================================================" + echo "[TRAIN] ${desc}" + echo "[CFG] ${config}" + echo "[DIR] ${work_dir}" + echo "============================================================" + + CUDA_VISIBLE_DEVICES=${GPU_IDS} python tools/train.py \ + ${config} \ + --work-dir ${work_dir} \ + --cfg-options randomness.seed=${SEED} + + echo "[TRAIN DONE] ${desc}" + echo "" +} + +find_best_ckpt() { + local work_dir=$1 + local best_ckpt=$(ls ${work_dir}/best_mIoU_*.pth 2>/dev/null | head -1) + if [ -z "$best_ckpt" ]; then + best_ckpt=$(ls ${work_dir}/epoch_*.pth 2>/dev/null | sort -V | tail -1) + fi + echo "$best_ckpt" +} + +run_test_modal() { + local config=$1 + local checkpoint=$2 + local work_dir=$3 + local modal=$4 + local desc=$5 + + local test_work_dir="${work_dir}/test_${modal}" + mkdir -p "${test_work_dir}" + + echo "------------------------------------------------------------" + echo "[TEST] ${desc} | Modality: ${modal}" + echo "[CKPT] ${checkpoint}" + echo "------------------------------------------------------------" + + CUDA_VISIBLE_DEVICES=${GPU_IDS} python tools/test.py \ + ${config} \ + ${checkpoint} \ + --work-dir ${test_work_dir} \ + --cfg-options \ + test_dataloader.dataset.filter_modality="${modal}" \ + 2>&1 | tee "${test_work_dir}/test_log.txt" + + echo "[${desc}] modal=${modal} -> see ${test_work_dir}/test_log.txt" >> "${RESULTS_LOG}" + echo "" +} + +run_test_all_modals() { + local config=$1 + local checkpoint=$2 + local work_dir=$3 + local desc=$4 + + echo "============================================================" + echo "[TEST ALL MODALS] ${desc}" + echo "============================================================" + + run_test_modal "${config}" "${checkpoint}" "${work_dir}" "sar" "${desc}" + run_test_modal "${config}" "${checkpoint}" "${work_dir}" "rgb" "${desc}" + run_test_modal "${config}" "${checkpoint}" "${work_dir}" "GF" "${desc}" +} + +train_and_test_all_modals() { + local config=$1 + local work_dir=$2 + local desc=$3 + + run_train "${config}" "${work_dir}" "${desc}" + + local ckpt=$(find_best_ckpt "${work_dir}") + if [ -z "$ckpt" ]; then + echo "[ERROR] No checkpoint found in ${work_dir} after training" + return 1 + fi + + run_test_all_modals "${config}" "${ckpt}" "${work_dir}" "${desc}" +} + +train_and_test_single_modal() { + local config=$1 + local work_dir=$2 + local modal=$3 + local desc=$4 + + run_train "${config}" "${work_dir}" "${desc}" + + local ckpt=$(find_best_ckpt "${work_dir}") + if [ -z "$ckpt" ]; then + echo "[ERROR] No checkpoint found in ${work_dir} after training" + return 1 + fi + + run_test_modal "${config}" "${ckpt}" "${work_dir}" "${modal}" "${desc}" +} + +# ============================================================================ +# TABLE 2: Component Ablation Study +# ============================================================================ +# Full Model (a) is already trained and tested — NOT included here. +# Only the ablation variants (b)-(f) are trained and tested. +# +# (b) w/o MoE — train + test SAR/RGB/GF +# (c) w/o ModalSpecificStem— train + test SAR/RGB/GF +# (d) w/o Modal Bias — train + test SAR/RGB/GF +# (e) w/o Shared Experts — train + test SAR/RGB/GF +# (f) w/o Separate Decoder — train + test SAR/RGB/GF +# ============================================================================ +if [[ "$GROUP" == "all" || "$GROUP" == "table2" ]]; then + echo "" + echo "################################################################" + echo "# TABLE 2: Component Ablation Study #" + echo "# (Full Model result is already available) #" + echo "################################################################" + echo "" + + # (b) w/o MoE + train_and_test_all_modals \ + "${ABLATION_DIR}/ablation_no_moe.py" \ + "${WORK_ROOT}/table2/no_moe" \ + "Table2(b) w/o MoE" + + # (c) w/o ModalSpecificStem + train_and_test_all_modals \ + "${ABLATION_DIR}/ablation_no_modal_specific_stem.py" \ + "${WORK_ROOT}/table2/no_modal_specific_stem" \ + "Table2(c) w/o ModalSpecificStem" + + # (d) w/o Modal Bias + train_and_test_all_modals \ + "${ABLATION_DIR}/ablation_no_modal_bias.py" \ + "${WORK_ROOT}/table2/no_modal_bias" \ + "Table2(d) w/o Modal Bias" + + # (e) w/o Shared Experts + train_and_test_all_modals \ + "${ABLATION_DIR}/ablation_no_shared_experts.py" \ + "${WORK_ROOT}/table2/no_shared_experts" \ + "Table2(e) w/o Shared Experts" + + # (f) w/o Separate Decoder (use shared decoder) + train_and_test_all_modals \ + "${ABLATION_DIR}/ablation_shared_decoder.py" \ + "${WORK_ROOT}/table2/shared_decoder" \ + "Table2(f) w/o Separate Decoder" + + echo "" + echo "[TABLE 2 COMPLETE] Results in ${WORK_ROOT}/table2/" + echo "" +fi + +# ============================================================================ +# TABLE 3: MoE Hyperparameter Study +# ============================================================================ +# Grid: num_experts={6, 8} x top_k={1, 2, 3} +# (8, 3) = Full Model — already done, NOT included here. +# +# Each variant tested on SAR / RGB / GF separately. +# ============================================================================ +if [[ "$GROUP" == "all" || "$GROUP" == "table3" ]]; then + echo "" + echo "################################################################" + echo "# TABLE 3: MoE Hyperparameter Study #" + echo "# (E=8 K=3 Full Model result is already available) #" + echo "################################################################" + echo "" + + # E=6, K=1 + train_and_test_all_modals \ + "${ABLATION_DIR}/ablation_e6_k1.py" \ + "${WORK_ROOT}/table3/e6_k1" \ + "Table3 E=6 K=1" + + # E=6, K=2 + train_and_test_all_modals \ + "${ABLATION_DIR}/ablation_e6_k2.py" \ + "${WORK_ROOT}/table3/e6_k2" \ + "Table3 E=6 K=2" + + # E=6, K=3 + train_and_test_all_modals \ + "${ABLATION_DIR}/ablation_e6_k3.py" \ + "${WORK_ROOT}/table3/e6_k3" \ + "Table3 E=6 K=3" + + # E=8, K=1 + train_and_test_all_modals \ + "${ABLATION_DIR}/ablation_e8_k1.py" \ + "${WORK_ROOT}/table3/e8_k1" \ + "Table3 E=8 K=1" + + # E=8, K=2 + train_and_test_all_modals \ + "${ABLATION_DIR}/ablation_e8_k2.py" \ + "${WORK_ROOT}/table3/e8_k2" \ + "Table3 E=8 K=2" + + # E=8, K=3 = Full Model — SKIP (already trained and tested) + echo "[SKIP] E=8 K=3 = Full Model (already trained and tested)" + + echo "" + echo "[TABLE 3 COMPLETE] Results in ${WORK_ROOT}/table3/" + echo "" +fi + +# ============================================================================ +# TABLE 4: Single-Modal vs Multi-Modal Training +# ============================================================================ +# SAR-only → train + test SAR +# RGB-only → train + test RGB +# GF-only → train + test GF +# Multi-modal (Full Model) → SKIP (already trained and tested) +# ============================================================================ +if [[ "$GROUP" == "all" || "$GROUP" == "table4" ]]; then + echo "" + echo "################################################################" + echo "# TABLE 4: Single-Modal vs Multi-Modal #" + echo "# (Multi-modal Full Model result is already available) #" + echo "################################################################" + echo "" + + # SAR-only → test SAR + train_and_test_single_modal \ + "${CONFIG_DIR}/multimodal_floodnet_sar_only_swinbase_moe_config.py" \ + "${WORK_ROOT}/table4/sar_only" \ + "sar" \ + "Table4 SAR-only" + + # RGB-only → test RGB + train_and_test_single_modal \ + "${ABLATION_DIR}/ablation_rgb_only.py" \ + "${WORK_ROOT}/table4/rgb_only" \ + "rgb" \ + "Table4 RGB-only" + + # GF-only → test GF + train_and_test_single_modal \ + "${ABLATION_DIR}/ablation_gf_only.py" \ + "${WORK_ROOT}/table4/gf_only" \ + "GF" \ + "Table4 GF-only" + + # Multi-modal = Full Model — SKIP + echo "[SKIP] Multi-modal = Full Model (already trained and tested)" + + echo "" + echo "[TABLE 4 COMPLETE] Results in ${WORK_ROOT}/table4/" + echo "" +fi + +# ============================================================================ +# Summary +# ============================================================================ +echo "" +echo "============================================================" +echo "EXPERIMENTS COMPLETED" +echo "Results log: ${RESULTS_LOG}" +echo "============================================================" +echo "" +echo "Result directories:" +if [[ "$GROUP" == "all" || "$GROUP" == "table2" ]]; then + echo " Table 2 (Ablation): ${WORK_ROOT}/table2/" +fi +if [[ "$GROUP" == "all" || "$GROUP" == "table3" ]]; then + echo " Table 3 (MoE Hyper): ${WORK_ROOT}/table3/" +fi +if [[ "$GROUP" == "all" || "$GROUP" == "table4" ]]; then + echo " Table 4 (Single-Modal): ${WORK_ROOT}/table4/" +fi +echo "" +echo "Per-modality test logs: /test_{sar,rgb,GF}/test_log.txt" diff --git a/scripts/test_all_metrics.sh b/scripts/test_all_metrics.sh new file mode 100755 index 0000000000..fb133466cd --- /dev/null +++ b/scripts/test_all_metrics.sh @@ -0,0 +1,324 @@ +#!/bin/bash +# ============================================================================ +# Re-test All Table 2/3/4 Experiments with Expanded Metrics +# Metrics: mIoU, mDice, mFscore (Precision, Recall, F1), aAcc (OA) +# +# Usage: +# bash scripts/test_all_metrics.sh table2 +# bash scripts/test_all_metrics.sh table3 +# bash scripts/test_all_metrics.sh table4 +# bash scripts/test_all_metrics.sh full # Full Model only +# bash scripts/test_all_metrics.sh all # All tables + Full Model +# +# Environment Variables: +# GPU_IDS GPU device IDs (default: 0) +# +# Example: +# GPU_IDS=0 bash scripts/test_all_metrics.sh table2 +# ============================================================================ + +set -e + +GPU_IDS=${GPU_IDS:-0} +CONFIG_DIR="configs/floodnet" +ABLATION_DIR="${CONFIG_DIR}/ablations" +WORK_ROOT="work_dirs/paper_experiments" +METRICS_LOG="${WORK_ROOT}/metrics_summary.txt" + +GROUP=${1:-"all"} + +mkdir -p "${WORK_ROOT}" +echo "========== Metrics Test Run: $(date) ==========" >> "${METRICS_LOG}" + +# ============================================================================ +# Helper Functions +# ============================================================================ + +find_best_ckpt() { + local work_dir=$1 + local best_ckpt=$(ls ${work_dir}/best_mIoU_*.pth 2>/dev/null | head -1) + if [ -z "$best_ckpt" ]; then + best_ckpt=$(ls ${work_dir}/epoch_*.pth 2>/dev/null | sort -V | tail -1) + fi + echo "$best_ckpt" +} + +run_test_metrics() { + local config=$1 + local checkpoint=$2 + local work_dir=$3 + local modal=$4 + local desc=$5 + + local test_work_dir="${work_dir}/metrics_${modal}" + mkdir -p "${test_work_dir}" + + echo "------------------------------------------------------------" + echo "[TEST] ${desc} | Modality: ${modal}" + echo "[CKPT] ${checkpoint}" + echo "[OUT] ${test_work_dir}" + echo "------------------------------------------------------------" + + CUDA_VISIBLE_DEVICES=${GPU_IDS} python tools/test.py \ + ${config} \ + ${checkpoint} \ + --work-dir ${test_work_dir} \ + --cfg-options \ + test_dataloader.dataset.filter_modality="${modal}" \ + "test_evaluator.iou_metrics=['mIoU','mDice','mFscore']" \ + 2>&1 | tee "${test_work_dir}/metrics_log.txt" + + echo "[${desc}] modal=${modal} -> ${test_work_dir}/metrics_log.txt" >> "${METRICS_LOG}" + echo "" +} + +run_test_all_modals_metrics() { + local config=$1 + local checkpoint=$2 + local work_dir=$3 + local desc=$4 + + echo "============================================================" + echo "[METRICS TEST] ${desc}" + echo "============================================================" + + run_test_metrics "${config}" "${checkpoint}" "${work_dir}" "sar" "${desc}" + run_test_metrics "${config}" "${checkpoint}" "${work_dir}" "rgb" "${desc}" + run_test_metrics "${config}" "${checkpoint}" "${work_dir}" "GF" "${desc}" +} + +# ============================================================================ +# Full Model +# ============================================================================ + +run_full_model() { + echo "" + echo "############################################################" + echo "# Full Model (Swin-B + MoE, E=8 K=3)" + echo "############################################################" + + local config="${CONFIG_DIR}/multimodal_floodnet_sar_boost_swinbase_moe_config.py" + local work_dir="work_dirs/floodnet/SwinmoeB/655" + local ckpt="${work_dir}/best_mIoU_epoch_100.pth" + + if [ ! -f "$ckpt" ]; then + echo "[WARN] Full Model checkpoint not found: ${ckpt}" + echo "[WARN] Trying to find best checkpoint in ${work_dir}..." + ckpt=$(find_best_ckpt "${work_dir}") + fi + + if [ -z "$ckpt" ] || [ ! -f "$ckpt" ]; then + echo "[ERROR] No checkpoint found for Full Model. Skipping." + return 0 + fi + + run_test_all_modals_metrics "${config}" "${ckpt}" "${work_dir}" "Full Model" +} + +# ============================================================================ +# Table 2: Component Ablation +# ============================================================================ + +run_table2() { + echo "" + echo "############################################################" + echo "# Table 2: Component Ablation Study" + echo "############################################################" + + local base_config="${CONFIG_DIR}/multimodal_floodnet_sar_boost_swinbase_moe_config.py" + + declare -A T2_CONFIGS + T2_CONFIGS=( + ["no_moe"]="${ABLATION_DIR}/ablation_no_moe.py" + ["no_modal_specific_stem"]="${ABLATION_DIR}/ablation_no_modal_specific_stem.py" + ["no_modal_bias"]="${ABLATION_DIR}/ablation_no_modal_bias.py" + ["no_shared_experts"]="${ABLATION_DIR}/ablation_no_shared_experts.py" + ["shared_decoder"]="${ABLATION_DIR}/ablation_shared_decoder.py" + ) + + declare -A T2_DESCS + T2_DESCS=( + ["no_moe"]="Table2: w/o MoE" + ["no_modal_specific_stem"]="Table2: w/o ModalSpecificStem" + ["no_modal_bias"]="Table2: w/o Modal Bias" + ["no_shared_experts"]="Table2: w/o Shared Experts" + ["shared_decoder"]="Table2: w/o Separate Decoder" + ) + + for key in no_moe no_modal_specific_stem no_modal_bias no_shared_experts shared_decoder; do + local config="${T2_CONFIGS[$key]}" + local work_dir="${WORK_ROOT}/table2/${key}" + local desc="${T2_DESCS[$key]}" + + local ckpt=$(find_best_ckpt "${work_dir}") + if [ -z "$ckpt" ]; then + echo "[ERROR] No checkpoint found for ${desc} in ${work_dir}. Skipping." + continue + fi + + run_test_all_modals_metrics "${config}" "${ckpt}" "${work_dir}" "${desc}" + done +} + +# ============================================================================ +# Table 3: MoE Hyperparameter Study +# ============================================================================ + +run_table3() { + echo "" + echo "############################################################" + echo "# Table 3: MoE Hyperparameter Study" + echo "############################################################" + + declare -A T3_CONFIGS + T3_CONFIGS=( + ["e6_k1"]="${ABLATION_DIR}/ablation_e6_k1.py" + ["e6_k2"]="${ABLATION_DIR}/ablation_e6_k2.py" + ["e6_k3"]="${ABLATION_DIR}/ablation_e6_k3.py" + ["e8_k1"]="${ABLATION_DIR}/ablation_e8_k1.py" + ["e8_k2"]="${ABLATION_DIR}/ablation_e8_k2.py" + ) + + declare -A T3_DESCS + T3_DESCS=( + ["e6_k1"]="Table3: E=6 K=1" + ["e6_k2"]="Table3: E=6 K=2" + ["e6_k3"]="Table3: E=6 K=3" + ["e8_k1"]="Table3: E=8 K=1" + ["e8_k2"]="Table3: E=8 K=2" + ) + + for key in e6_k1 e6_k2 e6_k3 e8_k1 e8_k2; do + local config="${T3_CONFIGS[$key]}" + local work_dir="${WORK_ROOT}/table3/${key}" + local desc="${T3_DESCS[$key]}" + + local ckpt=$(find_best_ckpt "${work_dir}") + if [ -z "$ckpt" ]; then + echo "[ERROR] No checkpoint found for ${desc} in ${work_dir}. Skipping." + continue + fi + + run_test_all_modals_metrics "${config}" "${ckpt}" "${work_dir}" "${desc}" + done +} + +# ============================================================================ +# Table 4: Single-Modal vs Multi-Modal +# ============================================================================ + +run_table4() { + echo "" + echo "############################################################" + echo "# Table 4: Single-Modal vs Multi-Modal" + echo "############################################################" + + # SAR-only: test on SAR + local sar_config="${CONFIG_DIR}/multimodal_floodnet_sar_only_swinbase_moe_config.py" + local sar_dir="${WORK_ROOT}/table4/sar_only" + local sar_ckpt=$(find_best_ckpt "${sar_dir}") + if [ -n "$sar_ckpt" ]; then + local test_dir="${sar_dir}/metrics_sar" + mkdir -p "${test_dir}" + echo "------------------------------------------------------------" + echo "[TEST] Table4: SAR-Only | Modality: sar" + echo "[CKPT] ${sar_ckpt}" + echo "------------------------------------------------------------" + CUDA_VISIBLE_DEVICES=${GPU_IDS} python tools/test.py \ + ${sar_config} ${sar_ckpt} \ + --work-dir ${test_dir} \ + --cfg-options \ + "test_evaluator.iou_metrics=['mIoU','mDice','mFscore']" \ + 2>&1 | tee "${test_dir}/metrics_log.txt" + echo "[Table4: SAR-Only] modal=sar -> ${test_dir}/metrics_log.txt" >> "${METRICS_LOG}" + else + echo "[ERROR] No checkpoint found for SAR-Only in ${sar_dir}. Skipping." + fi + + # RGB-only: test on RGB + local rgb_config="${ABLATION_DIR}/ablation_rgb_only.py" + local rgb_dir="${WORK_ROOT}/table4/rgb_only" + local rgb_ckpt=$(find_best_ckpt "${rgb_dir}") + if [ -n "$rgb_ckpt" ]; then + local test_dir="${rgb_dir}/metrics_rgb" + mkdir -p "${test_dir}" + echo "------------------------------------------------------------" + echo "[TEST] Table4: RGB-Only | Modality: rgb" + echo "[CKPT] ${rgb_ckpt}" + echo "------------------------------------------------------------" + CUDA_VISIBLE_DEVICES=${GPU_IDS} python tools/test.py \ + ${rgb_config} ${rgb_ckpt} \ + --work-dir ${test_dir} \ + --cfg-options \ + "test_evaluator.iou_metrics=['mIoU','mDice','mFscore']" \ + 2>&1 | tee "${test_dir}/metrics_log.txt" + echo "[Table4: RGB-Only] modal=rgb -> ${test_dir}/metrics_log.txt" >> "${METRICS_LOG}" + else + echo "[ERROR] No checkpoint found for RGB-Only in ${rgb_dir}. Skipping." + fi + + # GF-only: test on GF + local gf_config="${ABLATION_DIR}/ablation_gf_only.py" + local gf_dir="${WORK_ROOT}/table4/gf_only" + local gf_ckpt=$(find_best_ckpt "${gf_dir}") + if [ -n "$gf_ckpt" ]; then + local test_dir="${gf_dir}/metrics_GF" + mkdir -p "${test_dir}" + echo "------------------------------------------------------------" + echo "[TEST] Table4: GF-Only | Modality: GF" + echo "[CKPT] ${gf_ckpt}" + echo "------------------------------------------------------------" + CUDA_VISIBLE_DEVICES=${GPU_IDS} python tools/test.py \ + ${gf_config} ${gf_ckpt} \ + --work-dir ${test_dir} \ + --cfg-options \ + "test_evaluator.iou_metrics=['mIoU','mDice','mFscore']" \ + 2>&1 | tee "${test_dir}/metrics_log.txt" + echo "[Table4: GF-Only] modal=GF -> ${test_dir}/metrics_log.txt" >> "${METRICS_LOG}" + else + echo "[ERROR] No checkpoint found for GF-Only in ${gf_dir}. Skipping." + fi +} + +# ============================================================================ +# Main +# ============================================================================ + +echo "============================================================" +echo " Expanded Metrics Testing" +echo " Metrics: mIoU, mDice, mFscore (Precision/Recall/F1), aAcc" +echo " Group: ${GROUP}" +echo " GPU: ${GPU_IDS}" +echo "============================================================" + +case "${GROUP}" in + full) + run_full_model + ;; + table2) + run_table2 + ;; + table3) + run_table3 + ;; + table4) + run_table4 + ;; + all) + run_full_model + run_table2 + run_table3 + run_table4 + ;; + *) + echo "Unknown group: ${GROUP}" + echo "Usage: bash scripts/test_all_metrics.sh {full|table2|table3|table4|all}" + exit 1 + ;; +esac + +echo "" +echo "============================================================" +echo " All metrics tests completed!" +echo " Results log: ${METRICS_LOG}" +echo "============================================================" diff --git a/tools/analysis_tools/benchmark_multimodal.py b/tools/analysis_tools/benchmark_multimodal.py new file mode 100644 index 0000000000..e607e5fcb7 --- /dev/null +++ b/tools/analysis_tools/benchmark_multimodal.py @@ -0,0 +1,368 @@ +""" +Compute Inference FPS and FLOPs for Multi-Modal Swin-MoE model. + +Usage: + python tools/analysis_tools/benchmark_multimodal.py \ + configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py \ + work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth \ + --shape 256 256 \ + --repeat-times 3 \ + --num-iters 200 +""" + +import argparse +import time + +import numpy as np +import torch +import torch.nn as nn +from mmengine import Config +from mmengine.model.utils import revert_sync_batchnorm +from mmengine.registry import init_default_scope +from mmengine.runner import load_checkpoint + +from mmseg.registry import MODELS +from mmseg.structures import SegDataSample + +try: + from fvcore.nn import FlopCountAnalysis + HAS_FVCORE = True +except ImportError: + HAS_FVCORE = False + + +MODAL_CHANNELS = { + 'sar': 8, + 'rgb': 3, + 'GF': 5, +} + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Benchmark Multi-Modal Segmentor (FPS + FLOPs)') + parser.add_argument('config', help='config file path') + parser.add_argument('checkpoint', help='checkpoint file path') + parser.add_argument('--shape', type=int, nargs=2, default=[256, 256], + help='input H W (default: 256 256)') + parser.add_argument('--repeat-times', type=int, default=3, + help='number of FPS measurement runs') + parser.add_argument('--num-iters', type=int, default=200, + help='iterations per FPS run') + parser.add_argument('--num-warmup', type=int, default=10, + help='warmup iterations before timing') + parser.add_argument('--modals', nargs='+', default=['sar', 'rgb', 'GF'], + help='modalities to benchmark') + return parser.parse_args() + + +def build_model(cfg, checkpoint_path): + """Build model and load checkpoint.""" + cfg.model.train_cfg = None + model = MODELS.build(cfg.model) + + load_checkpoint(model, checkpoint_path, map_location='cpu') + + if torch.cuda.is_available(): + model = model.cuda() + model = revert_sync_batchnorm(model) + model.eval() + return model + + +def make_dummy_input(modal, shape, device='cuda'): + """Create a dummy input mimicking the data preprocessor output.""" + h, w = shape + channels = MODAL_CHANNELS[modal] + img = torch.randn(channels, h, w, device=device) + + data_sample = SegDataSample() + data_sample.set_metainfo(dict( + img_shape=(h, w), + ori_shape=(h, w), + pad_shape=(h, w), + scale_factor=(1.0, 1.0), + flip=False, + flip_direction=None, + modal_type=modal, + actual_channels=channels, + dataset_name=modal, + reduce_zero_label=False, + )) + + return img, data_sample + + +def measure_fps(model, modal, shape, num_iters=200, num_warmup=10): + """Measure inference FPS for a single modality.""" + device = next(model.parameters()).device + img, data_sample = make_dummy_input(modal, shape, device) + + # Warmup + with torch.no_grad(): + for _ in range(num_warmup): + model([img], [data_sample], mode='predict') + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + # Timed iterations + start = time.perf_counter() + with torch.no_grad(): + for _ in range(num_iters): + model([img], [data_sample], mode='predict') + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + elapsed = time.perf_counter() - start + fps = num_iters / elapsed + return fps + + +class BackboneDecoderWrapper(nn.Module): + """Wrapper that runs backbone + decode_head forward only. + + Avoids predict/inference path so FLOPs tools (fvcore/thop) + don't encounter SegDataSample or slide_inference logic. + """ + + def __init__(self, model, modal): + super().__init__() + self.backbone = model.backbone + self.modal = modal + # Get the correct decode head + if hasattr(model, 'decode_heads') and modal in model.decode_heads: + self.decode_head = model.decode_heads[modal] + elif hasattr(model, '_shared_decode_head'): + self.decode_head = model._shared_decode_head + else: + self.decode_head = None + + def forward(self, x): + """x: (1, C, H, W) tensor for a single modality.""" + imgs_list = [x[0]] # list of (C, H, W) + modal_types = [self.modal] + features, _, _ = self.backbone(imgs_list, modal_types) + # features is a tuple of multi-scale tensors + if self.decode_head is not None: + out = self.decode_head(features) + return out + + +def measure_flops_manual(model, modal, shape): + """Measure FLOPs by manually counting ops in backbone + decoder. + + Uses a forward hook approach to count multiply-accumulate operations. + """ + device = next(model.parameters()).device + h, w = shape + channels = MODAL_CHANNELS[modal] + img = torch.randn(1, channels, h, w, device=device) + + wrapper = BackboneDecoderWrapper(model, modal) + wrapper.eval() + + total_flops = 0 + + def count_conv2d(m, inp, out): + nonlocal total_flops + x = inp[0] + batch = x.shape[0] + out_h, out_w = out.shape[2], out.shape[3] + kernel_ops = m.kernel_size[0] * m.kernel_size[1] * (m.in_channels // m.groups) + total_flops += batch * m.out_channels * out_h * out_w * kernel_ops + + def count_linear(m, inp, out): + nonlocal total_flops + x = inp[0] + batch_size = x.numel() // m.in_features + total_flops += batch_size * m.in_features * m.out_features + + def count_layernorm(m, inp, out): + nonlocal total_flops + total_flops += inp[0].numel() * 2 # mean + variance + + def count_gelu(m, inp, out): + nonlocal total_flops + total_flops += inp[0].numel() * 4 # approximate + + def count_bn(m, inp, out): + nonlocal total_flops + total_flops += inp[0].numel() * 2 + + hooks = [] + for m in wrapper.modules(): + if isinstance(m, nn.Conv2d): + hooks.append(m.register_forward_hook(count_conv2d)) + elif isinstance(m, nn.Linear): + hooks.append(m.register_forward_hook(count_linear)) + elif isinstance(m, nn.LayerNorm): + hooks.append(m.register_forward_hook(count_layernorm)) + elif isinstance(m, nn.GELU): + hooks.append(m.register_forward_hook(count_gelu)) + elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)): + hooks.append(m.register_forward_hook(count_bn)) + + with torch.no_grad(): + try: + wrapper(img) + except Exception as e: + print(f' [manual WARN] forward failed: {e}') + for h in hooks: + h.remove() + return None + + for h in hooks: + h.remove() + + return total_flops + + +def measure_flops_fvcore(model, modal, shape): + """Measure FLOPs using fvcore on backbone + decoder wrapper.""" + device = next(model.parameters()).device + h, w = shape + channels = MODAL_CHANNELS[modal] + img = torch.randn(1, channels, h, w, device=device) + + wrapper = BackboneDecoderWrapper(model, modal) + wrapper.eval() + + try: + flop_analysis = FlopCountAnalysis(wrapper, (img,)) + flop_analysis.unsupported_ops_warnings(False) + flop_analysis.uncalled_modules_warnings(False) + flops = flop_analysis.total() + return flops + except Exception as e: + print(f' [fvcore WARN] {e}') + return None + + +def format_flops(flops): + """Format FLOPs to human readable string.""" + if flops is None: + return 'N/A' + if flops >= 1e12: + return f'{flops / 1e12:.2f} TFLOPs' + elif flops >= 1e9: + return f'{flops / 1e9:.2f} GFLOPs' + elif flops >= 1e6: + return f'{flops / 1e6:.2f} MFLOPs' + else: + return f'{flops:.0f} FLOPs' + + +def format_params(params): + """Format parameter count.""" + if params >= 1e6: + return f'{params / 1e6:.2f} M' + elif params >= 1e3: + return f'{params / 1e3:.2f} K' + else: + return f'{params:.0f}' + + +def main(): + args = parse_args() + + cfg = Config.fromfile(args.config) + init_default_scope(cfg.get('default_scope', 'mmseg')) + + print('=' * 60) + print('Multi-Modal Model Benchmark') + print(f'Config: {args.config}') + print(f'Checkpoint: {args.checkpoint}') + print(f'Input size: {args.shape[0]} x {args.shape[1]}') + print(f'Modalities: {args.modals}') + print('=' * 60) + print() + + # Build model + print('Building model...') + model = build_model(cfg, args.checkpoint) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f'Total params: {format_params(total_params)}') + print(f'Trainable params: {format_params(trainable_params)}') + print() + + # ======================== FLOPs ======================== + print('=' * 60) + print('FLOPs Measurement (backbone + decode_head)') + print('=' * 60) + + flops_results = {} + for modal in args.modals: + ch = MODAL_CHANNELS[modal] + print(f'\n[{modal.upper()}] input: ({ch}, {args.shape[0]}, {args.shape[1]})') + + flops_val = None + + # Try fvcore first (wraps backbone+decoder only, no SegDataSample) + if HAS_FVCORE: + print(' Method: fvcore') + flops_val = measure_flops_fvcore(model, modal, args.shape) + if flops_val is not None: + print(f' FLOPs: {format_flops(flops_val)}') + + # Fallback: manual hook-based counting + if flops_val is None: + print(' Method: manual (hook-based)') + flops_val = measure_flops_manual(model, modal, args.shape) + if flops_val is not None: + print(f' FLOPs: {format_flops(flops_val)}') + else: + print(' FLOPs: N/A') + + flops_results[modal] = flops_val + + # ======================== FPS ======================== + print() + print('=' * 60) + print('FPS Measurement (full predict pipeline)') + print(f' Warmup: {args.num_warmup} iters') + print(f' Timed: {args.num_iters} iters x {args.repeat_times} runs') + print('=' * 60) + + fps_results = {} + for modal in args.modals: + ch = MODAL_CHANNELS[modal] + print(f'\n[{modal.upper()}] input: ({ch}, {args.shape[0]}, {args.shape[1]})') + + fps_list = [] + for run_idx in range(args.repeat_times): + fps = measure_fps(model, modal, args.shape, + num_iters=args.num_iters, + num_warmup=args.num_warmup) + fps_list.append(fps) + print(f' Run {run_idx + 1}: {fps:.2f} img/s') + + avg_fps = np.mean(fps_list) + std_fps = np.std(fps_list) + print(f' >> Average: {avg_fps:.2f} +/- {std_fps:.2f} img/s') + fps_results[modal] = (avg_fps, std_fps) + + # ======================== Summary ======================== + print() + print('=' * 60) + print('Summary') + print('=' * 60) + print(f'{"Modality":<10} {"Channels":<10} {"FLOPs":<18} {"FPS (avg)":>15}') + print('-' * 55) + for modal in args.modals: + ch = MODAL_CHANNELS[modal] + flops_str = format_flops(flops_results.get(modal)) + fps_avg, fps_std = fps_results.get(modal, (0, 0)) + print(f'{modal:<10} {ch:<10} {flops_str:<18} {fps_avg:>10.2f} img/s') + print('-' * 55) + print(f'Params: {format_params(total_params)}') + print(f'Input: {args.shape[0]} x {args.shape[1]}') + print() + + +if __name__ == '__main__': + main() diff --git a/tools/analysis_tools/get_flops.py b/tools/analysis_tools/get_flops.py index 78a73988d4..694e55f79c 100644 --- a/tools/analysis_tools/get_flops.py +++ b/tools/analysis_tools/get_flops.py @@ -62,8 +62,11 @@ def inference(args: argparse.Namespace, logger: MMLogger) -> dict: input_shape = (3, args.shape[0], args.shape[0]) elif len(args.shape) == 2: input_shape = (3, ) + tuple(args.shape) + elif len(args.shape) == 3: # 新增 + input_shape = tuple(args.shape) else: raise ValueError('invalid input shape') + print(f"DEBUG: Full input_shape = {input_shape}") # 添加这行 result = {} model: BaseSegmentor = MODELS.build(cfg.model) diff --git a/tools/analysis_tools/visualize_expert_routing.py b/tools/analysis_tools/visualize_expert_routing.py new file mode 100644 index 0000000000..bc31e2dfa1 --- /dev/null +++ b/tools/analysis_tools/visualize_expert_routing.py @@ -0,0 +1,696 @@ +""" +Figure 4: Expert Routing Analysis — Publication-quality visualization. + +Generates three sub-figures: + (a) Expert activation probability heatmap per modality per stage + (b) Learned Modal Bias matrix visualization + (c) Spatial expert assignment map for a single image + +Usage: + python tools/analysis_tools/visualize_expert_routing.py \ + configs/floodnet/multimodal_floodnet_sar_boost_swinbase_moe_config.py \ + work_dirs/floodnet/SwinmoeB/655/best_mIoU_epoch_100.pth \ + --data-root ../floodnet/data/mixed_dataset/ \ + --output-dir work_dirs/figures/expert_routing \ + --num-samples 50 +""" + +import argparse +import os +import os.path as osp +from collections import defaultdict + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +import numpy as np +import torch +import torch.nn.functional as F +from matplotlib.colors import LinearSegmentedColormap +from mmengine import Config +from mmengine.model.utils import revert_sync_batchnorm +from mmengine.registry import init_default_scope +from mmengine.runner import Runner, load_checkpoint + +from mmseg.registry import MODELS + +# ======================== Publication Style ======================== +plt.rcParams.update({ + 'font.family': 'serif', + 'font.serif': ['Times New Roman', 'DejaVu Serif'], + 'font.size': 10, + 'axes.labelsize': 11, + 'axes.titlesize': 12, + 'xtick.labelsize': 9, + 'ytick.labelsize': 9, + 'legend.fontsize': 9, + 'figure.dpi': 300, + 'savefig.dpi': 300, + 'savefig.bbox': 'tight', + 'savefig.pad_inches': 0.05, + 'axes.linewidth': 0.8, + 'axes.grid': False, +}) + +MODAL_DISPLAY = { + 'sar': 'UrbanSARflood', + 'rgb': 'FloodNet', + 'GF': 'GF-Floodnet', +} + +MODAL_CHANNELS = { + 'sar': 8, + 'rgb': 3, + 'GF': 5, +} + +MODAL_COLORS = { + 'sar': '#2196F3', + 'rgb': '#4CAF50', + 'GF': '#FF9800', +} + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Visualize Expert Routing Patterns') + parser.add_argument('config', help='config file path') + parser.add_argument('checkpoint', help='checkpoint file path') + parser.add_argument('--data-root', type=str, + default='../floodnet/data/mixed_dataset/', + help='data root for test set') + parser.add_argument('--output-dir', type=str, + default='work_dirs/figures/expert_routing', + help='output directory for figures') + parser.add_argument('--num-samples', type=int, default=50, + help='number of test samples per modality for stats') + parser.add_argument('--spatial-image-idx', type=int, default=0, + help='index of image for spatial assignment map') + return parser.parse_args() + + +# ======================== Model Building ======================== + +def build_model(cfg, checkpoint_path): + cfg.model.train_cfg = None + model = MODELS.build(cfg.model) + load_checkpoint(model, checkpoint_path, map_location='cpu') + if torch.cuda.is_available(): + model = model.cuda() + model = revert_sync_batchnorm(model) + model.eval() + return model + + +# ======================== Hook-based Gate Extraction ======================== + +class GateHookManager: + """Register hooks on all MoE gating layers to capture routing weights.""" + + def __init__(self, model): + self.model = model + self.hooks = [] + self.gate_records = [] # list of (stage, block, gates_tensor) + self._register_hooks() + + def _register_hooks(self): + backbone = self._get_backbone(self.model) + for stage_idx, stage_dict in enumerate(backbone.stages): + blocks = stage_dict['blocks'] + for block_idx, block in enumerate(blocks): + if block.use_moe: + moe_layer = block.mlp + hook = moe_layer.register_forward_hook( + self._make_hook(stage_idx, block_idx)) + self.hooks.append(hook) + + def _get_backbone(self, model): + if hasattr(model, 'module'): + model = model.module + if hasattr(model, 'backbone'): + return model.backbone + return model + + def _make_hook(self, stage_idx, block_idx): + def hook_fn(module, input_args, output): + # Re-compute gates (eval mode, no noise) + x = input_args[0] + modal_types = input_args[1] if len(input_args) > 1 else None + + with torch.no_grad(): + x_pooled = x.mean(dim=1) + clean_logits = module.gating(x_pooled, modal_types) + top_logits, top_indices = clean_logits.topk( + min(module.top_k, module.num_experts), dim=-1) + top_k_gates = F.softmax(top_logits, dim=-1) + zeros = torch.zeros_like(clean_logits) + gates = zeros.scatter(-1, top_indices, top_k_gates) + + self.gate_records.append({ + 'stage': stage_idx, + 'block': block_idx, + 'gates': gates.detach().cpu(), + 'modal_types': list(modal_types) if modal_types else None, + }) + + return hook_fn + + def clear(self): + self.gate_records = [] + + def remove_hooks(self): + for h in self.hooks: + h.remove() + self.hooks = [] + + +# ======================== Data Collection ======================== + +def collect_routing_stats(model, cfg, hook_mgr, data_root, num_samples): + """Run inference on real test data and collect per-modality routing stats. + + Loads actual test images from the dataset to capture true data-driven + routing patterns (feature-dependent + modal_bias), not just modal_bias. + Falls back to random noise if no images are found for a modality. + """ + from mmseg.structures import SegDataSample + from mmengine.dataset import Compose + import glob + + device = next(model.parameters()).device + stats = defaultdict(lambda: defaultdict(list)) + + # Build a minimal inference pipeline: load image, resize, normalize, pack. + # Skip LoadAnnotations (no labels needed for routing analysis). + # Use 256x256 (matching training crop) for efficient single-pass inference. + pipeline = Compose([ + dict(type='LoadMultiModalImageFromFile', to_float32=True), + dict(type='Resize', scale=(256, 256), keep_ratio=False), + dict(type='MultiModalNormalize'), + dict(type='PackMultiModalSegInputs', + meta_keys=('img_path', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'modal_type', 'actual_channels', 'dataset_name', + 'img_norm_cfg', 'reduce_zero_label')), + ]) + + # Dataset directory + data_prefix = cfg.test_dataloader.dataset.get( + 'data_prefix', dict(img_path='test/images')) + img_dir = osp.join(data_root, data_prefix['img_path']) + + for modal in ['sar', 'rgb', 'GF']: + hook_mgr.clear() + ch = MODAL_CHANNELS[modal] + + # Find images for this modality by filename pattern + all_files = sorted(glob.glob(osp.join(img_dir, '*'))) + modal_key = modal.lower() + modal_files = [f for f in all_files + if modal_key in osp.basename(f).lower()] + + if modal_files: + n = min(num_samples, len(modal_files)) + print(f' [{modal}] Using {n} real test images ' + f'(found {len(modal_files)} total)') + + count = 0 + for img_path in modal_files[:num_samples]: + fname = osp.basename(img_path) + + data = dict( + img_path=img_path, + modal_type=modal, + actual_channels=ch, + dataset_name=modal, + reduce_zero_label=False, + ori_filename=fname, + flip=False, + flip_direction=None, + ) + + try: + data = pipeline(data) + except Exception as e: + if count == 0: + print(f' [DEBUG] Pipeline error: {e}') + print(f' [DEBUG] img_path: {img_path}') + continue + + # Pipeline outputs: {'inputs': Tensor, 'data_samples': SegDataSample} + img = data['inputs'].float().to(device) + data_sample = data['data_samples'] + + with torch.no_grad(): + model([img], [data_sample], mode='predict') + count += 1 + + print(f' [{modal}] Successfully processed {count} images') + + else: + print(f' [WARN] No images for {modal} in {img_dir}, ' + f'using random noise') + for _ in range(num_samples): + img = torch.randn(ch, 256, 256, device=device) + ds = SegDataSample() + ds.set_metainfo(dict( + img_shape=(256, 256), ori_shape=(256, 256), + pad_shape=(256, 256), scale_factor=(1.0, 1.0), + flip=False, flip_direction=None, + modal_type=modal, actual_channels=ch, + dataset_name=modal, reduce_zero_label=False, + )) + with torch.no_grad(): + model([img], [ds], mode='predict') + + # Aggregate per (stage, block) + for record in hook_mgr.gate_records: + key = (record['stage'], record['block']) + stats[modal][key].append(record['gates']) + + return stats + + +def collect_spatial_gates(model, hook_mgr, modal='sar', shape=(256, 256)): + """Collect per-token spatial gating for one image.""" + from mmseg.structures import SegDataSample + + device = next(model.parameters()).device + ch = MODAL_CHANNELS[modal] + h, w = shape + + # We need per-token gates, not per-sample gates. + # Modify hook to capture per-token routing. + backbone = hook_mgr._get_backbone(model) + spatial_records = [] + + hooks = [] + for stage_idx, stage_dict in enumerate(backbone.stages): + for block_idx, block in enumerate(stage_dict['blocks']): + if block.use_moe: + moe_layer = block.mlp + + def make_spatial_hook(s_idx, b_idx): + def hook_fn(module, input_args, output): + x = input_args[0] # [B, N, C] + modal_types = (input_args[1] + if len(input_args) > 1 else None) + + with torch.no_grad(): + # Per-sample gating (same as normal) + x_pooled = x.mean(dim=1) + logits = module.gating(x_pooled, modal_types) + top_logits, top_indices = logits.topk( + min(module.top_k, module.num_experts), dim=-1) + top_k_gates = F.softmax(top_logits, dim=-1) + + # The dominant expert for this sample + dominant_expert = top_indices[0, 0].item() + gate_weights = torch.zeros(module.num_experts) + gate_weights.scatter_( + 0, top_indices[0], + top_k_gates[0].cpu()) + + spatial_records.append({ + 'stage': s_idx, + 'block': b_idx, + 'dominant_expert': dominant_expert, + 'gate_weights': gate_weights, + }) + return hook_fn + + h_handle = moe_layer.register_forward_hook( + make_spatial_hook(stage_idx, block_idx)) + hooks.append(h_handle) + + img = torch.randn(ch, h, w, device=device) + data_sample = SegDataSample() + data_sample.set_metainfo(dict( + img_shape=(h, w), ori_shape=(h, w), pad_shape=(h, w), + scale_factor=(1.0, 1.0), flip=False, flip_direction=None, + modal_type=modal, actual_channels=ch, + dataset_name=modal, reduce_zero_label=False, + )) + + with torch.no_grad(): + model([img], [data_sample], mode='predict') + + for h_handle in hooks: + h_handle.remove() + + return spatial_records + + +# ======================== Figure (a): Activation Heatmap ======================== + +def plot_activation_heatmap(stats, num_experts, output_dir): + """Plot expert activation probability heatmap per modality per stage.""" + + # Aggregate: for each modal and each (stage, block), compute mean gate + modals = ['sar', 'rgb', 'GF'] + + # Collect all unique (stage, block) keys sorted + all_keys = set() + for modal in modals: + all_keys.update(stats[modal].keys()) + all_keys = sorted(all_keys) + + # Group by stage + stage_keys = defaultdict(list) + for s, b in all_keys: + stage_keys[s].append((s, b)) + + stages_with_moe = sorted(stage_keys.keys()) + num_stages = len(stages_with_moe) + + fig, axes = plt.subplots( + 1, num_stages, + figsize=(3.2 * num_stages + 0.8, 2.8), + gridspec_kw={'wspace': 0.35} + ) + if num_stages == 1: + axes = [axes] + + cmap = plt.cm.YlOrRd + + for ax_idx, stage in enumerate(stages_with_moe): + keys = stage_keys[stage] + # Build matrix: [num_modals, num_experts], averaged over all blocks + heatmap = np.zeros((len(modals), num_experts)) + + for m_idx, modal in enumerate(modals): + all_gates = [] + for key in keys: + if key in stats[modal]: + gate_list = stats[modal][key] + # Each is [B, num_experts], concat and mean + stacked = torch.cat(gate_list, dim=0) # [N, E] + all_gates.append(stacked) + if all_gates: + combined = torch.cat(all_gates, dim=0) + # Activation probability = fraction of times expert is selected + activation_prob = (combined > 0).float().mean(dim=0).numpy() + heatmap[m_idx] = activation_prob + + im = axes[ax_idx].imshow( + heatmap, cmap=cmap, aspect='auto', vmin=0, vmax=1.0) + + axes[ax_idx].set_xticks(range(num_experts)) + axes[ax_idx].set_xticklabels( + [f'E{i}' for i in range(num_experts)], fontsize=8) + axes[ax_idx].set_yticks(range(len(modals))) + if ax_idx == 0: + axes[ax_idx].set_yticklabels( + [MODAL_DISPLAY[m] for m in modals], fontsize=9) + else: + axes[ax_idx].set_yticklabels([]) + axes[ax_idx].set_xlabel('Expert Index', fontsize=9) + axes[ax_idx].set_title(f'Stage {stage}', fontsize=11, fontweight='bold') + + # Annotate cells + for i in range(len(modals)): + for j in range(num_experts): + val = heatmap[i, j] + color = 'white' if val > 0.5 else 'black' + axes[ax_idx].text( + j, i, f'{val:.2f}', + ha='center', va='center', fontsize=7, color=color) + + # Colorbar + cbar = fig.colorbar( + im, ax=axes, shrink=0.85, aspect=25, pad=0.03) + cbar.set_label('Activation Probability', fontsize=9) + + fig.savefig( + osp.join(output_dir, 'fig4a_expert_activation_heatmap.pdf'), + format='pdf') + fig.savefig( + osp.join(output_dir, 'fig4a_expert_activation_heatmap.png'), + format='png') + plt.close(fig) + print(f'[Saved] fig4a_expert_activation_heatmap.pdf/png') + + +# ======================== Figure (b): Modal Bias ======================== + +def plot_modal_bias(model, output_dir): + """Visualize the learned modal_bias parameters from all MoE layers.""" + + backbone = (model.module.backbone + if hasattr(model, 'module') else model.backbone) + + bias_per_stage = defaultdict(list) + modal_names = None + + for stage_idx, stage_dict in enumerate(backbone.stages): + for block_idx, block in enumerate(stage_dict['blocks']): + if block.use_moe and hasattr(block.mlp, 'gating'): + gate = block.mlp.gating + if (hasattr(gate, 'modal_bias') + and gate.modal_bias is not None): + bias = gate.modal_bias.detach().cpu().numpy() + bias_per_stage[stage_idx].append(bias) + if modal_names is None: + modal_names = gate.modal_name_to_idx + + if not bias_per_stage: + print('[WARN] No modal_bias parameters found.') + return + + # Sort modal names by index + sorted_modals = sorted(modal_names.items(), key=lambda x: x[1]) + modal_order = [m[0] for m in sorted_modals] + + stages = sorted(bias_per_stage.keys()) + num_stages = len(stages) + num_experts = bias_per_stage[stages[0]][0].shape[1] + + fig, axes = plt.subplots( + 1, num_stages, + figsize=(3.2 * num_stages + 0.8, 2.8), + gridspec_kw={'wspace': 0.35} + ) + if num_stages == 1: + axes = [axes] + + # Diverging colormap centered at 0 + cmap = plt.cm.RdBu_r + all_biases = np.concatenate( + [np.stack(v).mean(axis=0) for v in bias_per_stage.values()]) + vmax = max(abs(all_biases.min()), abs(all_biases.max())) + if vmax < 1e-6: + vmax = 1.0 + + for ax_idx, stage in enumerate(stages): + # Average across all blocks in this stage + avg_bias = np.stack(bias_per_stage[stage]).mean(axis=0) + # avg_bias shape: [num_modals, num_experts] + + im = axes[ax_idx].imshow( + avg_bias, cmap=cmap, aspect='auto', vmin=-vmax, vmax=vmax) + + axes[ax_idx].set_xticks(range(num_experts)) + axes[ax_idx].set_xticklabels( + [f'E{i}' for i in range(num_experts)], fontsize=8) + axes[ax_idx].set_yticks(range(len(modal_order))) + if ax_idx == 0: + axes[ax_idx].set_yticklabels( + [MODAL_DISPLAY.get(m, m) for m in modal_order], fontsize=9) + else: + axes[ax_idx].set_yticklabels([]) + axes[ax_idx].set_xlabel('Expert Index', fontsize=9) + axes[ax_idx].set_title(f'Stage {stage}', fontsize=11, fontweight='bold') + + # Annotate + for i in range(avg_bias.shape[0]): + for j in range(avg_bias.shape[1]): + val = avg_bias[i, j] + color = 'white' if abs(val) > vmax * 0.6 else 'black' + axes[ax_idx].text( + j, i, f'{val:.2f}', + ha='center', va='center', fontsize=7, color=color) + + cbar = fig.colorbar(im, ax=axes, shrink=0.85, aspect=25, pad=0.03) + cbar.set_label('Bias Value', fontsize=9) + + fig.savefig( + osp.join(output_dir, 'fig4b_modal_bias_matrix.pdf'), format='pdf') + fig.savefig( + osp.join(output_dir, 'fig4b_modal_bias_matrix.png'), format='png') + plt.close(fig) + print(f'[Saved] fig4b_modal_bias_matrix.pdf/png') + + +# ======================== Figure (c): Expert Assignment Bar ======================== + +def plot_expert_assignment_comparison(stats, num_experts, output_dir): + """Grouped bar chart: expert selection frequency per modality. + + Aggregated across all MoE layers, showing which experts are preferred + by each modality. More intuitive than heatmap for presentations. + """ + modals = ['sar', 'rgb', 'GF'] + + # Aggregate across all stages/blocks + freq = {} + for modal in modals: + all_gates = [] + for key, gate_list in stats[modal].items(): + stacked = torch.cat(gate_list, dim=0) + all_gates.append(stacked) + if all_gates: + combined = torch.cat(all_gates, dim=0) + # Selection frequency + freq[modal] = (combined > 0).float().mean(dim=0).numpy() + else: + freq[modal] = np.zeros(num_experts) + + x = np.arange(num_experts) + width = 0.25 + fig, ax = plt.subplots(figsize=(5.5, 3.2)) + + for i, modal in enumerate(modals): + ax.bar( + x + (i - 1) * width, freq[modal], width, + label=MODAL_DISPLAY[modal], + color=MODAL_COLORS[modal], edgecolor='white', linewidth=0.5) + + ax.set_xlabel('Expert Index') + ax.set_ylabel('Selection Frequency') + ax.set_xticks(x) + ax.set_xticklabels([f'E{i}' for i in range(num_experts)]) + ax.legend(frameon=True, edgecolor='gray', fancybox=False) + ax.set_ylim(0, 1.0) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + + fig.savefig( + osp.join(output_dir, 'fig4c_expert_selection_frequency.pdf'), + format='pdf') + fig.savefig( + osp.join(output_dir, 'fig4c_expert_selection_frequency.png'), + format='png') + plt.close(fig) + print(f'[Saved] fig4c_expert_selection_frequency.pdf/png') + + +# ======================== Figure (d): Gate Weight Distribution ======================== + +def plot_gate_weight_distribution(stats, num_experts, output_dir): + """Violin/box plot of gate weight distribution per modality per expert.""" + modals = ['sar', 'rgb', 'GF'] + + fig, axes = plt.subplots( + 1, len(modals), figsize=(4.0 * len(modals), 3.0), + sharey=True, gridspec_kw={'wspace': 0.1}) + + for m_idx, modal in enumerate(modals): + all_gates = [] + for key, gate_list in stats[modal].items(): + stacked = torch.cat(gate_list, dim=0) + all_gates.append(stacked) + + if not all_gates: + continue + + combined = torch.cat(all_gates, dim=0).numpy() # [N, E] + + # Only keep nonzero weights for box plot + data_per_expert = [] + for e in range(num_experts): + weights = combined[:, e] + nonzero = weights[weights > 0] + data_per_expert.append(nonzero if len(nonzero) > 0 + else np.array([0.0])) + + bp = axes[m_idx].boxplot( + data_per_expert, patch_artist=True, widths=0.6, + showfliers=False, medianprops=dict(color='black', linewidth=1.2)) + + color = MODAL_COLORS[modal] + for patch in bp['boxes']: + patch.set_facecolor(color) + patch.set_alpha(0.7) + + axes[m_idx].set_xlabel('Expert Index') + axes[m_idx].set_xticklabels( + [f'E{i}' for i in range(num_experts)], fontsize=8) + axes[m_idx].set_title(MODAL_DISPLAY[modal], fontsize=11, + fontweight='bold', color=color) + axes[m_idx].spines['top'].set_visible(False) + axes[m_idx].spines['right'].set_visible(False) + + axes[0].set_ylabel('Gate Weight (non-zero)') + + fig.savefig( + osp.join(output_dir, 'fig4d_gate_weight_distribution.pdf'), + format='pdf') + fig.savefig( + osp.join(output_dir, 'fig4d_gate_weight_distribution.png'), + format='png') + plt.close(fig) + print(f'[Saved] fig4d_gate_weight_distribution.pdf/png') + + +# ======================== Main ======================== + +def main(): + args = parse_args() + os.makedirs(args.output_dir, exist_ok=True) + + cfg = Config.fromfile(args.config) + init_default_scope(cfg.get('default_scope', 'mmseg')) + + print('=' * 60) + print('Expert Routing Visualization') + print(f'Config: {args.config}') + print(f'Checkpoint: {args.checkpoint}') + print(f'Output: {args.output_dir}') + print(f'Samples: {args.num_samples} per modality') + print('=' * 60) + + # Use 'whole' mode to avoid slide_inference complexity for routing analysis + cfg.model.test_cfg = dict(mode='whole') + + # Build model + print('\nBuilding model...') + model = build_model(cfg, args.checkpoint) + + # Read num_experts from config + num_experts = cfg.model.backbone.get('num_experts', 8) + + # ---- Figure (b): Modal Bias (no inference needed) ---- + print('\n--- Figure (b): Modal Bias Matrix ---') + plot_modal_bias(model, args.output_dir) + + # ---- Collect routing stats via hooks ---- + print('\n--- Collecting routing statistics ---') + hook_mgr = GateHookManager(model) + stats = collect_routing_stats( + model, cfg, hook_mgr, args.data_root, args.num_samples) + hook_mgr.remove_hooks() + + # ---- Figure (a): Activation Heatmap ---- + print('\n--- Figure (a): Expert Activation Heatmap ---') + plot_activation_heatmap(stats, num_experts, args.output_dir) + + # ---- Figure (c): Expert Selection Frequency Bar ---- + print('\n--- Figure (c): Expert Selection Frequency ---') + plot_expert_assignment_comparison(stats, num_experts, args.output_dir) + + # ---- Figure (d): Gate Weight Distribution ---- + print('\n--- Figure (d): Gate Weight Distribution ---') + plot_gate_weight_distribution(stats, num_experts, args.output_dir) + + print('\n' + '=' * 60) + print(f'All figures saved to: {args.output_dir}/') + print(' fig4a_expert_activation_heatmap.pdf/png') + print(' fig4b_modal_bias_matrix.pdf/png') + print(' fig4c_expert_selection_frequency.pdf/png') + print(' fig4d_gate_weight_distribution.pdf/png') + print('=' * 60) + + +if __name__ == '__main__': + main() diff --git a/tools/compute_sen1floods11_stats.py b/tools/compute_sen1floods11_stats.py new file mode 100644 index 0000000000..d007490216 --- /dev/null +++ b/tools/compute_sen1floods11_stats.py @@ -0,0 +1,192 @@ +"""Compute channel-wise mean / std for Sen1Floods11 S1Hand or S2Hand. + +The tool walks ``//*.tif`` (default: S1Hand or S2Hand) +and ignores pixels that are flagged as nodata in the matching +``LabelHand`` file (label value == -1). Pass the resulting mean / std +arrays into ``MultiModalNormalize.NORM_CONFIGS`` under the ``s1`` / ``s2`` +entries in ``mmseg/datasets/transforms/multimodal_pipelines.py``. + +Usage:: + + python tools/compute_sen1floods11_stats.py \ + --data-root data/Sen1Floods11 --modality s1 + + python tools/compute_sen1floods11_stats.py \ + --data-root data/Sen1Floods11 --modality s2 --split splits/train.txt +""" +import argparse +import os +import os.path as osp +from typing import List + +import numpy as np + +try: + import tifffile +except ImportError as e: + raise SystemExit( + 'tifffile is required. Install with `pip install tifffile`.') from e + + +MODALITIES = { + 's1': { + 'subdir': 'S1Hand', + 'suffix': '_S1Hand.tif', + 'channels': 2, + }, + 's2': { + 'subdir': 'S2Hand', + 'suffix': '_S2Hand.tif', + 'channels': 13, + }, +} + +LABEL_SUBDIR = 'LabelHand' +LABEL_SUFFIX = '_LabelHand.tif' + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument('--data-root', required=True, + help='Sen1Floods11 root containing S1Hand/S2Hand/LabelHand') + p.add_argument('--modality', choices=list(MODALITIES), required=True) + p.add_argument('--split', default=None, + help='Optional txt file listing sample base-names ' + '(one per line). Paths are resolved against ' + 'data-root.') + p.add_argument('--max-samples', type=int, default=None, + help='Optional cap on the number of images to scan.') + p.add_argument('--mask-nodata', action='store_true', default=True, + help='Exclude pixels where LabelHand == -1 ' + '(default: True).') + p.add_argument('--no-mask-nodata', dest='mask_nodata', + action='store_false') + return p.parse_args() + + +def load_image_list(data_root: str, modality_cfg: dict, + split: str) -> List[str]: + img_dir = osp.join(data_root, modality_cfg['subdir']) + suffix = modality_cfg['suffix'] + + if split: + split_path = split if osp.isabs(split) else osp.join(data_root, split) + if not osp.isfile(split_path): + raise FileNotFoundError(f'split file not found: {split_path}') + bases = [] + with open(split_path, 'r') as f: + for line in f: + base = line.strip() + if not base: + continue + for s in (suffix, LABEL_SUFFIX): + if base.endswith(s): + base = base[:-len(s)] + break + bases.append(base) + img_names = [b + suffix for b in bases] + else: + img_names = sorted(f for f in os.listdir(img_dir) + if f.endswith(suffix)) + + return [osp.join(img_dir, n) for n in img_names] + + +def label_path_for(img_path: str, modality_cfg: dict, data_root: str) -> str: + base = osp.basename(img_path)[:-len(modality_cfg['suffix'])] + return osp.join(data_root, LABEL_SUBDIR, base + LABEL_SUFFIX) + + +def _hwc(img: np.ndarray) -> np.ndarray: + """Ensure (H, W, C).""" + if img.ndim == 2: + return img[:, :, None] + if img.ndim == 3 and img.shape[0] < img.shape[-1]: + return np.transpose(img, (1, 2, 0)) + return img + + +def compute_stats(img_paths, data_root, modality_cfg, mask_nodata): + num_channels = modality_cfg['channels'] + + # Running sums for per-channel mean / std using Welford-style + # accumulators (numerically stable and streaming). + total_count = np.zeros(num_channels, dtype=np.float64) + total_sum = np.zeros(num_channels, dtype=np.float64) + total_sqsum = np.zeros(num_channels, dtype=np.float64) + + for i, img_path in enumerate(img_paths): + img = tifffile.imread(img_path) + img = _hwc(img).astype(np.float64) + if img.shape[-1] != num_channels: + print(f'[warn] skip {img_path}: got {img.shape[-1]} ch ' + f'(expected {num_channels})') + continue + + valid_mask = None + if mask_nodata: + lbl_path = label_path_for(img_path, modality_cfg, data_root) + if osp.isfile(lbl_path): + lbl = tifffile.imread(lbl_path) + lbl = np.squeeze(lbl) + valid_mask = (lbl != -1) + # also drop NaN / Inf pixels that crop up in SAR dB + finite_mask = np.all(np.isfinite(img), axis=-1) + if valid_mask is None: + valid_mask = finite_mask + else: + valid_mask &= finite_mask + + if not valid_mask.any(): + continue + + valid = img[valid_mask] # (N, C) + total_count += valid.shape[0] + total_sum += valid.sum(axis=0) + total_sqsum += (valid ** 2).sum(axis=0) + + if (i + 1) % 50 == 0: + print(f' processed {i + 1}/{len(img_paths)} images') + + # Avoid divide-by-zero + total_count = np.where(total_count == 0, 1, total_count) + mean = total_sum / total_count + var = total_sqsum / total_count - mean ** 2 + var = np.clip(var, a_min=0.0, a_max=None) + std = np.sqrt(var) + return mean, std, total_count + + +def main(): + args = parse_args() + modality_cfg = MODALITIES[args.modality] + + img_paths = load_image_list(args.data_root, modality_cfg, args.split) + if args.max_samples: + img_paths = img_paths[:args.max_samples] + + print(f'Found {len(img_paths)} images for modality={args.modality}') + if not img_paths: + raise SystemExit('No images matched - check --data-root / --split') + + mean, std, count = compute_stats( + img_paths, args.data_root, modality_cfg, args.mask_nodata) + + print() + print('=' * 60) + print(f'Sen1Floods11 {args.modality} statistics ' + f'(valid pixels: {int(count[0])})') + print('=' * 60) + print(f'mean = {mean.tolist()}') + print(f'std = {std.tolist()}') + print() + print('Paste into NORM_CONFIGS in ' + 'mmseg/datasets/transforms/multimodal_pipelines.py:') + print(f" '{args.modality}': {{") + print(f" 'mean': {mean.tolist()},") + print(f" 'std': {std.tolist()},") + print(' },') + + +if __name__ == '__main__': + main() diff --git a/tools/predict_large_tif.py b/tools/predict_large_tif.py new file mode 100644 index 0000000000..2ecc2f5125 --- /dev/null +++ b/tools/predict_large_tif.py @@ -0,0 +1,408 @@ +""" +Large TIF Inference for Multi-Modal Segmentor. + +Reads a large GeoTIFF (e.g. 20803x36986), tiles it into patches, +runs inference with the multi-modal model, and stitches results +back into a full-size GeoTIFF. Flood pixels are colored red (255,0,0), +non-flood pixels are black (0,0,0). + +Usage: + python tools/predict_large_tif.py \ + configs/floodnet/finetune_single_modal.py \ + work_dirs/generalization/LY-train-station/best_mIoU_epoch_30.pth \ + --input data/luoyuan/result.tif \ + --output data/luoyuan/prediction.tif \ + --tile-size 512 \ + --overlap 64 \ + --modal rgb \ + --bands 0 1 2 \ + --batch-size 4 +""" + +import argparse +import math +import os +import time + +import numpy as np +import torch +from mmengine import Config +from mmengine.model.utils import revert_sync_batchnorm +from mmengine.registry import init_default_scope +from mmengine.runner import load_checkpoint + +from mmseg.registry import MODELS +from mmseg.structures import SegDataSample + +try: + from osgeo import gdal + HAS_GDAL = True +except ImportError: + HAS_GDAL = False + +try: + import rasterio + HAS_RASTERIO = True +except ImportError: + HAS_RASTERIO = False + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Large TIF inference with tile stitching') + parser.add_argument('config', help='config file path') + parser.add_argument('checkpoint', help='checkpoint file path') + parser.add_argument('--input', required=True, help='input TIF file') + parser.add_argument('--output', default=None, + help='output TIF file (default: input_pred.tif)') + parser.add_argument('--tile-size', type=int, default=512, + help='tile size for inference (default: 512)') + parser.add_argument('--overlap', type=int, default=64, + help='overlap between tiles (default: 64)') + parser.add_argument('--batch-size', type=int, default=4, + help='batch size for inference (default: 4)') + parser.add_argument('--modal', default='rgb', + help='modality type for model (default: rgb)') + parser.add_argument('--bands', type=int, nargs='+', default=[0, 1, 2], + help='band indices to read (0-indexed, default: 0 1 2)') + parser.add_argument('--flood-class', type=int, default=1, + help='class index for flood (default: 1)') + parser.add_argument('--device', default='cuda:0', + help='device for inference (default: cuda:0)') + return parser.parse_args() + + +def build_model(cfg, checkpoint_path, device): + """Build model and load checkpoint.""" + cfg.model.train_cfg = None + model = MODELS.build(cfg.model) + load_checkpoint(model, checkpoint_path, map_location='cpu') + model.to(device) + model = revert_sync_batchnorm(model) + model.eval() + return model + + +def read_tif_info(tif_path): + """Read TIF metadata.""" + if HAS_RASTERIO: + with rasterio.open(tif_path) as src: + return { + 'height': src.height, + 'width': src.width, + 'bands': src.count, + 'dtype': src.dtypes[0], + 'crs': src.crs, + 'transform': src.transform, + 'profile': src.profile.copy(), + } + elif HAS_GDAL: + ds = gdal.Open(tif_path, gdal.GA_ReadOnly) + return { + 'height': ds.RasterYSize, + 'width': ds.RasterXSize, + 'bands': ds.RasterCount, + 'geo_transform': ds.GetGeoTransform(), + 'projection': ds.GetProjection(), + } + else: + raise ImportError('Neither rasterio nor GDAL available. ' + 'Install with: pip install rasterio') + + +def read_tile_rasterio(tif_path, x, y, w, h, band_indices): + """Read a tile from TIF using rasterio.""" + with rasterio.open(tif_path) as src: + window = rasterio.windows.Window(x, y, w, h) + # rasterio bands are 1-indexed + bands = [b + 1 for b in band_indices] + data = src.read(bands, window=window) # (C, H, W) + return data.astype(np.float32) + + +def read_tile_gdal(tif_path, x, y, w, h, band_indices): + """Read a tile from TIF using GDAL.""" + ds = gdal.Open(tif_path, gdal.GA_ReadOnly) + data = [] + for bi in band_indices: + band = ds.GetRasterBand(bi + 1) # GDAL is 1-indexed + arr = band.ReadAsArray(x, y, w, h) + data.append(arr.astype(np.float32)) + return np.stack(data, axis=0) # (C, H, W) + + +def generate_tiles(img_h, img_w, tile_size, overlap): + """Generate tile coordinates with overlap.""" + stride = tile_size - overlap + tiles = [] + + for y in range(0, img_h, stride): + for x in range(0, img_w, stride): + # Clamp to image boundary + x1 = min(x, img_w - tile_size) + y1 = min(y, img_h - tile_size) + x1 = max(x1, 0) + y1 = max(y1, 0) + w = min(tile_size, img_w - x1) + h = min(tile_size, img_h - y1) + tiles.append((x1, y1, w, h)) + + # Deduplicate (edge tiles may repeat) + seen = set() + unique_tiles = [] + for t in tiles: + if t not in seen: + seen.add(t) + unique_tiles.append(t) + + return unique_tiles + + +def normalize_tile(tile_data, modal='rgb'): + """Normalize tile using the same mean/std as training pipeline. + + Must match MultiModalNormalize in multimodal_pipelines.py exactly. + tile_data: (C, H, W) float32, raw pixel values (e.g. 0-255 for RGB). + """ + NORM_CONFIGS = { + 'rgb': { + 'mean': [123.675, 116.28, 103.53], + 'std': [58.395, 57.12, 57.375], + }, + 'sar': { + 'mean': [0.23651549, 0.31761484, 0.18514981, 0.26901252, + -14.57879175, -8.6098158, -14.2907338, -8.33534564], + 'std': [0.16280619, 0.20849304, 0.14008107, 0.19767644, + 4.07141682, 3.94773216, 4.21006244, 4.05494136], + }, + 'GF': { + 'mean': [432.02181, 315.92948, 246.468659, + 310.61462, 360.267789], + 'std': [97.73313111900238, 85.78646917160748, + 95.78015824658593, + 124.84677067613467, 251.73965882246978], + } + } + + channels = tile_data.shape[0] + if modal in NORM_CONFIGS: + mean = np.array(NORM_CONFIGS[modal]['mean'][:channels], + dtype=np.float32) + std = np.array(NORM_CONFIGS[modal]['std'][:channels], + dtype=np.float32) + else: + mean = np.array([128.0] * channels, dtype=np.float32) + std = np.array([50.0] * channels, dtype=np.float32) + + # tile_data shape: (C, H, W), normalize per channel + for c in range(channels): + tile_data[c] = (tile_data[c] - mean[c]) / std[c] + + return tile_data + + +def predict_batch(model, batch_imgs, modal, device): + """Run inference on a batch of tile images. + + Args: + model: segmentor model + batch_imgs: list of (C, H, W) numpy arrays + modal: modality string + device: torch device + + Returns: + list of (H, W) numpy prediction masks + """ + imgs = [] + data_samples = [] + + for img_np in batch_imgs: + img_tensor = torch.from_numpy(img_np).float().to(device) + imgs.append(img_tensor) + + ds = SegDataSample() + h, w = img_np.shape[1], img_np.shape[2] + ds.set_metainfo(dict( + img_shape=(h, w), + ori_shape=(h, w), + pad_shape=(h, w), + scale_factor=(1.0, 1.0), + flip=False, + flip_direction=None, + modal_type=modal, + actual_channels=img_np.shape[0], + dataset_name=modal, + reduce_zero_label=False, + )) + data_samples.append(ds) + + with torch.no_grad(): + results = model(imgs, data_samples, mode='predict') + + preds = [] + for r in results: + pred = r.pred_sem_seg.data.cpu().numpy()[0] # (H, W) + preds.append(pred) + + return preds + + +def main(): + args = parse_args() + + if not HAS_RASTERIO and not HAS_GDAL: + raise ImportError('Install rasterio or GDAL: pip install rasterio') + + read_tile = read_tile_rasterio if HAS_RASTERIO else read_tile_gdal + + # Output path + if args.output is None: + base, ext = os.path.splitext(args.input) + args.output = f'{base}_pred{ext}' + + # Load config and model + cfg = Config.fromfile(args.config) + init_default_scope(cfg.get('default_scope', 'mmseg')) + + # Override test_cfg to use 'whole' mode (we handle tiling ourselves) + cfg.model.test_cfg = dict(mode='whole') + + print('=' * 60) + print('Large TIF Inference') + print(f'Input: {args.input}') + print(f'Output: {args.output}') + print(f'Tile size: {args.tile_size}') + print(f'Overlap: {args.overlap}') + print(f'Modality: {args.modal}') + print(f'Bands: {args.bands}') + print(f'Batch size: {args.batch_size}') + print('=' * 60) + + # Read TIF info + info = read_tif_info(args.input) + img_h, img_w = info['height'], info['width'] + print(f'\nImage size: {img_w} x {img_h} (W x H)') + print(f'Bands: {info["bands"]}') + + # Build model + print('\nBuilding model...') + model = build_model(cfg, args.checkpoint, args.device) + print('Model loaded.') + + # Generate tiles + tiles = generate_tiles(img_h, img_w, args.tile_size, args.overlap) + num_tiles = len(tiles) + num_batches = math.ceil(num_tiles / args.batch_size) + print(f'\nTotal tiles: {num_tiles}') + print(f'Batches: {num_batches}') + + # Verify normalization + print(f'\nNormalization check (modal={args.modal}):') + test_tile = read_tile(args.input, 0, 0, + min(args.tile_size, img_w), + min(args.tile_size, img_h), + args.bands) + print(f' Raw pixel range: [{test_tile.min():.2f}, {test_tile.max():.2f}]') + test_normed = normalize_tile(test_tile.copy(), modal=args.modal) + print(f' After normalize: [{test_normed.min():.2f}, {test_normed.max():.2f}]') + del test_tile, test_normed + + # Allocate output: prediction mask + count for overlap voting + pred_sum = np.zeros((img_h, img_w), dtype=np.float32) + count_map = np.zeros((img_h, img_w), dtype=np.float32) + + # Inference + print('\nStarting inference...') + t_start = time.time() + + for batch_idx in range(num_batches): + start_i = batch_idx * args.batch_size + end_i = min(start_i + args.batch_size, num_tiles) + batch_tiles = tiles[start_i:end_i] + + # Read tiles + batch_imgs = [] + for (x, y, w, h) in batch_tiles: + tile_data = read_tile(args.input, x, y, w, h, args.bands) + # Handle edge tiles smaller than tile_size: pad + if h < args.tile_size or w < args.tile_size: + padded = np.zeros( + (len(args.bands), args.tile_size, args.tile_size), + dtype=np.float32) + padded[:, :h, :w] = tile_data + tile_data = padded + tile_data = normalize_tile(tile_data, modal=args.modal) + batch_imgs.append(tile_data) + + # Predict + preds = predict_batch(model, batch_imgs, args.modal, args.device) + + # Stitch predictions (vote-based: accumulate class labels) + for (x, y, w, h), pred in zip(batch_tiles, preds): + pred_sum[y:y+h, x:x+w] += pred[:h, :w].astype(np.float32) + count_map[y:y+h, x:x+w] += 1.0 + + # Progress + done = end_i + elapsed = time.time() - t_start + eta = elapsed / done * (num_tiles - done) if done > 0 else 0 + print(f'\r [{done}/{num_tiles}] ' + f'{done/num_tiles*100:.1f}% | ' + f'Elapsed: {elapsed:.0f}s | ETA: {eta:.0f}s', + end='', flush=True) + + print() + + # Average overlapping predictions + count_map = np.maximum(count_map, 1.0) + pred_avg = pred_sum / count_map + pred_mask = (pred_avg >= 0.5).astype(np.uint8) # threshold + + # Count statistics + flood_pixels = (pred_mask == 1).sum() + total_pixels = pred_mask.size + print(f'\nFlood pixels: {flood_pixels:,} ' + f'({flood_pixels/total_pixels*100:.2f}%)') + print(f'Non-flood pixels: {total_pixels - flood_pixels:,}') + + # Write output TIF (3-band RGB: flood=red, non-flood=black) + print(f'\nWriting output: {args.output}') + os.makedirs(os.path.dirname(args.output) or '.', exist_ok=True) + + if HAS_RASTERIO: + profile = info['profile'].copy() + profile.update( + count=3, + dtype='uint8', + compress='lzw', + ) + with rasterio.open(args.output, 'w', **profile) as dst: + # Red channel: 255 for flood + red = (pred_mask * 255).astype(np.uint8) + # Green and Blue: 0 + green = np.zeros_like(red) + blue = np.zeros_like(red) + dst.write(red, 1) + dst.write(green, 2) + dst.write(blue, 3) + elif HAS_GDAL: + driver = gdal.GetDriverByName('GTiff') + out_ds = driver.Create( + args.output, img_w, img_h, 3, gdal.GDT_Byte, + options=['COMPRESS=LZW']) + out_ds.SetGeoTransform(info['geo_transform']) + out_ds.SetProjection(info['projection']) + red = (pred_mask * 255).astype(np.uint8) + out_ds.GetRasterBand(1).WriteArray(red) + out_ds.GetRasterBand(2).WriteArray(np.zeros_like(red)) + out_ds.GetRasterBand(3).WriteArray(np.zeros_like(red)) + out_ds.FlushCache() + out_ds = None + + elapsed_total = time.time() - t_start + print(f'\nDone! Total time: {elapsed_total:.1f}s') + print(f'Output: {args.output}') + + +if __name__ == '__main__': + main() diff --git a/tools/remap_pred_colors.py b/tools/remap_pred_colors.py new file mode 100644 index 0000000000..6f5e6a30c8 --- /dev/null +++ b/tools/remap_pred_colors.py @@ -0,0 +1,97 @@ +""" +Remap prediction image colors. + +Converts the two-color palette output from SegVisualizationHook: + - Black [ 0, 0, 0] (Background / Nodata) -> #7c7c7c [124, 124, 124] + - Red [255, 0, 0] (Flood) -> #000bc5 [ 0, 11, 197] + +Nodata pixels are treated the same as Background (both render as black +after SegVisualizationHook._mask_nodata, then both remap to #7c7c7c). + +Usage: + # In-place: remap all PNGs under a directory + python tools/remap_pred_colors.py --src vis_pred/vis_data/vis_image/ + + # Save to a new directory + python tools/remap_pred_colors.py \ + --src vis_pred/vis_data/vis_image/ \ + --dst vis_pred/vis_data/vis_image_remapped/ + + # Process a single file + python tools/remap_pred_colors.py --src path/to/image.png +""" + +import argparse +import sys +from pathlib import Path + +import numpy as np +from PIL import Image + +COLOR_MAP = { + (0, 0, 0): (124, 124, 124), # Background / Nodata -> #7c7c7c + (255, 0, 0): (0, 11, 197), # Flood -> #000bc5 +} + + +def remap_image(img_array: np.ndarray) -> np.ndarray: + out = img_array.copy() + for src_rgb, dst_rgb in COLOR_MAP.items(): + mask = ( + (img_array[:, :, 0] == src_rgb[0]) & + (img_array[:, :, 1] == src_rgb[1]) & + (img_array[:, :, 2] == src_rgb[2]) + ) + out[mask] = dst_rgb + return out + + +def process_file(src_path: Path, dst_path: Path) -> None: + img = Image.open(src_path).convert('RGB') + arr = np.array(img, dtype=np.uint8) + arr_out = remap_image(arr) + dst_path.parent.mkdir(parents=True, exist_ok=True) + Image.fromarray(arr_out).save(dst_path) + + +def collect_png(root: Path): + return sorted(root.rglob('*.png')) + + +def main(): + parser = argparse.ArgumentParser( + description='Remap prediction image colors ' + '(black→#7c7c7c, red→#000bc5)') + parser.add_argument('--src', required=True, + help='Source directory (or single PNG).') + parser.add_argument('--dst', default=None, + help='Destination directory. Default: in-place.') + args = parser.parse_args() + + src = Path(args.src) + if not src.exists(): + print(f'ERROR: --src "{src}" does not exist.', file=sys.stderr) + sys.exit(1) + + if src.is_file(): + pairs = [(src, Path(args.dst) if args.dst else src)] + else: + files = collect_png(src) + if not files: + print(f'No PNG files found under "{src}".', file=sys.stderr) + sys.exit(0) + if args.dst is None: + pairs = [(f, f) for f in files] + else: + dst_root = Path(args.dst) + pairs = [(f, dst_root / f.relative_to(src)) for f in files] + + for i, (src_f, dst_f) in enumerate(pairs, 1): + process_file(src_f, dst_f) + print(f'[{i}/{len(pairs)}] {src_f} -> {dst_f}') + + print(f'\nDone. {len(pairs)} image(s) remapped.') + + +if __name__ == '__main__': + main() diff --git a/tools/setup_sen1floods11.py b/tools/setup_sen1floods11.py new file mode 100644 index 0000000000..129e816f0d --- /dev/null +++ b/tools/setup_sen1floods11.py @@ -0,0 +1,337 @@ +""" +One-shot Sen1Floods11 setup for fine-tuning. + +This script bridges the gap between a freshly-downloaded Sen1Floods11 +folder and a training run of +``configs/floodnet/finetune_sen1floods11_{s1,s2}.py``. + +Expected on-disk layout:: + + data/Sen1Floods11/ + S1Hand/_S1Hand.tif # 2-band SAR (VV, VH in dB) + S2Hand/_S2Hand.tif # 13-band Sentinel-2 MSI + LabelHand/_LabelHand.tif # 1-band label, -1=nodata, 0/1 classes + +What the script does (all four steps are independent and can be +re-run): + + 1. Scan ``/LabelHand`` for every available base-name. + 2. Write deterministic train/val/test splits to + ``/splits/{train,val,test}.txt`` based on an MD5 hash + of the base-name (so the same tile always lands in the same + split on re-run). + 3. Compute per-channel mean / std for S1Hand (2 ch) and S2Hand + (13 ch) *using only training-split tiles*, with NaN / Inf pixels + and label == -1 pixels masked out. This is important because + the shipped NORM_CONFIGS values in + ``mmseg/datasets/transforms/multimodal_pipelines.py`` were + computed on a different split and may not match your copy of + the data. + 4. Print the NORM_CONFIGS snippet so you can paste it into + ``multimodal_pipelines.py``. + +Usage:: + + # full setup (splits + stats for both modalities) + python tools/setup_sen1floods11.py --data-root data/Sen1Floods11 + + # only regenerate splits (e.g. after adding more tiles) + python tools/setup_sen1floods11.py --data-root data/Sen1Floods11 \\ + --skip-stats + + # only recompute stats (splits already exist) + python tools/setup_sen1floods11.py --data-root data/Sen1Floods11 \\ + --skip-splits + + # only one modality + python tools/setup_sen1floods11.py --data-root data/Sen1Floods11 \\ + --modalities s1 + +The split ratio defaults to 70 / 15 / 15 train / val / test. Pass +``--train-ratio`` / ``--val-ratio`` to change it. +""" +import argparse +import hashlib +import os +import os.path as osp +from typing import List, Tuple + +import numpy as np + +try: + import tifffile +except ImportError as e: + raise SystemExit( + 'tifffile is required. Install with `pip install tifffile`.') from e + + +# --------------------------------------------------------------------------- +# Layout constants - keep in sync with Sen1Floods11Dataset.MODAL_CONFIG and +# multimodal_pipelines.MultiModalNormalize.NORM_CONFIGS. +# --------------------------------------------------------------------------- +LABEL_SUBDIR = 'LabelHand' +LABEL_SUFFIX = '_LabelHand.tif' + +MODALITIES = { + 's1': { + 'subdir': 'S1Hand', + 'suffix': '_S1Hand.tif', + 'channels': 2, + }, + 's2': { + 'subdir': 'S2Hand', + 'suffix': '_S2Hand.tif', + 'channels': 13, + }, +} + + +def parse_args(): + p = argparse.ArgumentParser( + description='Sen1Floods11 setup: splits + normalization stats.') + p.add_argument('--data-root', required=True, + help='Sen1Floods11 root with S1Hand/S2Hand/LabelHand.') + p.add_argument('--train-ratio', type=float, default=0.70, + help='Train fraction (default: 0.70).') + p.add_argument('--val-ratio', type=float, default=0.15, + help='Val fraction (default: 0.15). ' + 'Test fraction = 1 - train - val.') + p.add_argument('--modalities', nargs='+', default=['s1', 's2'], + choices=list(MODALITIES), + help='Which modalities to compute stats for.') + p.add_argument('--skip-splits', action='store_true', + help='Reuse existing splits/{train,val,test}.txt.') + p.add_argument('--skip-stats', action='store_true', + help='Only generate splits; do not compute stats.') + p.add_argument('--seed', type=int, default=42, + help='Salt prepended to base-names when hashing. ' + 'Change to get a different split assignment.') + return p.parse_args() + + +# --------------------------------------------------------------------------- +# Step 1: scan +# --------------------------------------------------------------------------- +def scan_basenames(data_root: str) -> List[str]: + label_dir = osp.join(data_root, LABEL_SUBDIR) + if not osp.isdir(label_dir): + raise SystemExit(f'LabelHand dir not found: {label_dir}') + + bases = [] + for fn in sorted(os.listdir(label_dir)): + if fn.endswith(LABEL_SUFFIX): + bases.append(fn[:-len(LABEL_SUFFIX)]) + + if not bases: + raise SystemExit( + f'No *{LABEL_SUFFIX} files found in {label_dir}. ' + 'Check --data-root.') + return bases + + +# --------------------------------------------------------------------------- +# Step 2: splits +# --------------------------------------------------------------------------- +def assign_split(base: str, seed: int, + train_ratio: float, val_ratio: float) -> str: + """Hash-based deterministic split assignment. + + Using a hash (instead of shuffling + slicing) means the assignment + is stable if the set of base-names grows - a tile that was in + ``train`` before stays in ``train``. + """ + h = hashlib.md5(f'{seed}:{base}'.encode('utf-8')).hexdigest() + # first 8 hex chars -> 32-bit uint in [0, 2^32), normalized to [0, 1) + v = int(h[:8], 16) / 0x100000000 + if v < train_ratio: + return 'train' + if v < train_ratio + val_ratio: + return 'val' + return 'test' + + +def write_splits(data_root: str, bases: List[str], + train_ratio: float, val_ratio: float, seed: int + ) -> Tuple[List[str], List[str], List[str]]: + if train_ratio <= 0 or val_ratio < 0 or train_ratio + val_ratio >= 1.0: + raise SystemExit( + 'Invalid split ratios: need train_ratio > 0, val_ratio >= 0, ' + 'train + val < 1.') + + splits = {'train': [], 'val': [], 'test': []} + for base in bases: + splits[assign_split(base, seed, train_ratio, val_ratio)].append(base) + + splits_dir = osp.join(data_root, 'splits') + os.makedirs(splits_dir, exist_ok=True) + for name in ('train', 'val', 'test'): + items = sorted(splits[name]) + out_path = osp.join(splits_dir, f'{name}.txt') + with open(out_path, 'w') as f: + if items: + f.write('\n'.join(items) + '\n') + print(f'[splits] {name:5s}: {len(items):4d} -> {out_path}') + + total = sum(len(v) for v in splits.values()) + print(f'[splits] total: {total}') + return splits['train'], splits['val'], splits['test'] + + +def load_existing_splits(data_root: str + ) -> Tuple[List[str], List[str], List[str]]: + def read(name: str) -> List[str]: + p = osp.join(data_root, 'splits', f'{name}.txt') + if not osp.isfile(p): + raise SystemExit( + f'--skip-splits was set but {p} is missing. ' + 'Run without --skip-splits first.') + with open(p) as f: + return [ln.strip() for ln in f if ln.strip()] + return read('train'), read('val'), read('test') + + +# --------------------------------------------------------------------------- +# Step 3: stats +# --------------------------------------------------------------------------- +def _hwc(img: np.ndarray) -> np.ndarray: + """Normalize raster shape to (H, W, C).""" + if img.ndim == 2: + return img[:, :, None] + # tifffile typically returns (C, H, W) for multi-band TIFFs - + # transpose when the leading axis is the smallest. + if img.ndim == 3 and img.shape[0] < img.shape[-1]: + return np.transpose(img, (1, 2, 0)) + return img + + +def compute_stats_for_bases(data_root: str, modality: str, + bases: List[str] + ) -> Tuple[np.ndarray, np.ndarray, int]: + modal_cfg = MODALITIES[modality] + img_dir = osp.join(data_root, modal_cfg['subdir']) + label_dir = osp.join(data_root, LABEL_SUBDIR) + num_channels = modal_cfg['channels'] + + if not osp.isdir(img_dir): + raise SystemExit( + f'[{modality}] image dir not found: {img_dir}') + + # Welford-style streaming accumulators (float64 for precision). + total_count = 0 + total_sum = np.zeros(num_channels, dtype=np.float64) + total_sqsum = np.zeros(num_channels, dtype=np.float64) + + n_missing_img = 0 + n_missing_label = 0 + n_channel_mismatch = 0 + + for i, base in enumerate(bases): + img_path = osp.join(img_dir, base + modal_cfg['suffix']) + lbl_path = osp.join(label_dir, base + LABEL_SUFFIX) + + if not osp.isfile(img_path): + n_missing_img += 1 + continue + + img = _hwc(tifffile.imread(img_path)).astype(np.float64) + if img.shape[-1] != num_channels: + print(f' [warn] skip {base}: got {img.shape[-1]} channels ' + f'(expected {num_channels})') + n_channel_mismatch += 1 + continue + + # Valid = finite everywhere AND not label nodata. + valid_mask = np.all(np.isfinite(img), axis=-1) + if osp.isfile(lbl_path): + lbl = np.squeeze(tifffile.imread(lbl_path)) + if lbl.shape == img.shape[:2]: + valid_mask &= (lbl != -1) + else: + n_missing_label += 1 + + if not valid_mask.any(): + continue + + valid = img[valid_mask] # (N, C) + total_count += int(valid.shape[0]) + total_sum += valid.sum(axis=0) + total_sqsum += (valid ** 2).sum(axis=0) + + if (i + 1) % 50 == 0: + print(f' [{modality}] processed {i + 1}/{len(bases)} tiles, ' + f'valid pixels so far: {total_count}') + + if total_count == 0: + raise SystemExit( + f'[{modality}] no valid pixels across {len(bases)} tiles ' + f'(missing imgs: {n_missing_img}, ' + f'channel mismatches: {n_channel_mismatch}).') + + mean = total_sum / total_count + var = np.clip(total_sqsum / total_count - mean ** 2, 0.0, None) + std = np.sqrt(var) + + print(f' [{modality}] done. valid pixels={total_count}, ' + f'missing imgs={n_missing_img}, ' + f'missing labels={n_missing_label}, ' + f'channel mismatches={n_channel_mismatch}') + return mean, std, total_count + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- +def main(): + args = parse_args() + data_root = args.data_root.rstrip('/') + + if not osp.isdir(data_root): + raise SystemExit(f'--data-root not found: {data_root}') + + # Step 1: scan ----------------------------------------------------- + bases = scan_basenames(data_root) + print(f'Found {len(bases)} tiles under {data_root}/{LABEL_SUBDIR}/') + + # Step 2: splits --------------------------------------------------- + if args.skip_splits: + train_bases, val_bases, test_bases = load_existing_splits(data_root) + print(f'[splits] reusing existing splits: ' + f'{len(train_bases)}/{len(val_bases)}/{len(test_bases)}') + else: + train_bases, val_bases, test_bases = write_splits( + data_root, bases, + train_ratio=args.train_ratio, + val_ratio=args.val_ratio, + seed=args.seed, + ) + + # Step 3 & 4: stats ------------------------------------------------ + if args.skip_stats: + print('\n[stats] skipped (--skip-stats)') + return + + all_stats = {} + for modality in args.modalities: + print() + print(f'[stats] computing {modality} mean/std over ' + f'{len(train_bases)} train tiles (NaN/Inf/label==-1 masked)') + mean, std, count = compute_stats_for_bases( + data_root, modality, train_bases) + all_stats[modality] = (mean, std, count) + + print() + print('=' * 72) + print('Paste the following into NORM_CONFIGS in') + print(' mmseg/datasets/transforms/multimodal_pipelines.py') + print('(replacing the existing \'s1\' / \'s2\' entries)') + print('=' * 72) + for modality in args.modalities: + mean, std, count = all_stats[modality] + print(f' \'{modality}\': {{ # {count} valid train pixels') + print(f' \'mean\': {mean.tolist()},') + print(f' \'std\': {std.tolist()},') + print(' },') + + +if __name__ == '__main__': + main() diff --git a/tools/test.py b/tools/test.py index 0d7f39b3a8..36c35d90f0 100644 --- a/tools/test.py +++ b/tools/test.py @@ -6,6 +6,11 @@ from mmengine.config import Config, DictAction from mmengine.runner import Runner +from mmseg.utils import register_all_modules + +# register all modules in mmseg into the registries +register_all_modules() + # TODO: support fuse_conv_bn, visualization, and format_only def parse_args(): diff --git a/tools/test_full_metrics.py b/tools/test_full_metrics.py new file mode 100644 index 0000000000..09a9299d75 --- /dev/null +++ b/tools/test_full_metrics.py @@ -0,0 +1,143 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Test script with full metrics: Precision, Recall, OA, F1 Score, mIoU. + +This script extends the default test.py by automatically configuring the +evaluator to compute all five metrics: + - mIoU (Mean Intersection over Union) + - mPrecision (Mean Precision) + - mRecall (Mean Recall) + - mFscore (F1 Score) + - aAcc (Overall Accuracy, OA) + +Usage: + python tools/test_full_metrics.py [options] + +Example: + python tools/test_full_metrics.py \\ + ./configs/deeplabv3plus/Deeplabv3+UAVflood.py \\ + work_dirs/SAR/Deeplabv3+/best_mIoU_epoch_100.pth \\ + --work-dir ./Result/SAR/Deeplabv3+ \\ + --show-dir ./Result/SAR/Deeplabv3+/vis \\ + --cfg-options visualizer.alpha=1.0 +""" +import argparse +import os +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.runner import Runner + + +def parse_args(): + parser = argparse.ArgumentParser( + description='MMSeg test with full metrics ' + '(Precision, Recall, OA, F1 Score, mIoU)') + parser.add_argument('config', help='train config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument( + '--work-dir', + help='if specified, the evaluation metric results will be dumped ' + 'into the directory as json') + parser.add_argument( + '--out', + type=str, + help='The directory to save output prediction for offline evaluation') + parser.add_argument( + '--show', action='store_true', help='show prediction results') + parser.add_argument( + '--show-dir', + help='directory where painted images will be saved. ' + 'If specified, it will be automatically saved ' + 'to the work_dir/timestamp/show_dir') + parser.add_argument( + '--wait-time', type=float, default=2, help='the interval of show (s)') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument( + '--tta', action='store_true', help='Test time augmentation') + parser.add_argument('--local_rank', '--local-rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + return args + + +def trigger_visualization_hook(cfg, args): + default_hooks = cfg.default_hooks + if 'visualization' in default_hooks: + visualization_hook = default_hooks['visualization'] + visualization_hook['draw'] = True + if args.show: + visualization_hook['show'] = True + visualization_hook['wait_time'] = args.wait_time + if args.show_dir: + visualizer = cfg.visualizer + visualizer['save_dir'] = args.show_dir + else: + raise RuntimeError( + 'VisualizationHook must be included in default_hooks. ' + 'refer to usage ' + '"visualization=dict(type=\'VisualizationHook\')"') + + return cfg + + +def main(): + args = parse_args() + + # load config + cfg = Config.fromfile(args.config) + cfg.launcher = args.launcher + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + + cfg.load_from = args.checkpoint + + if args.show or args.show_dir: + cfg = trigger_visualization_hook(cfg, args) + + if args.tta: + cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline + cfg.tta_model.module = cfg.model + cfg.model = cfg.tta_model + + # add output_dir in metric + if args.out is not None: + cfg.test_evaluator['output_dir'] = args.out + cfg.test_evaluator['keep_results'] = True + + # Override test_evaluator to compute full metrics: + # mIoU, mFscore (which includes Precision, Recall, F1), and aAcc (OA) + cfg.test_evaluator['type'] = 'IoUMetric' + cfg.test_evaluator['iou_metrics'] = ['mIoU', 'mFscore'] + + # build the runner from config + runner = Runner.from_cfg(cfg) + + # start testing + runner.test() + + +if __name__ == '__main__': + main() diff --git a/tools/train.py b/tools/train.py index 10fdaa1874..08ce040c08 100644 --- a/tools/train.py +++ b/tools/train.py @@ -9,6 +9,7 @@ from mmengine.runner import Runner from mmseg.registry import RUNNERS +from mmseg.utils import register_all_modules def parse_args(): @@ -54,6 +55,9 @@ def parse_args(): def main(): args = parse_args() + # register all modules + register_all_modules(init_default_scope=True) + # load config cfg = Config.fromfile(args.config) cfg.launcher = args.launcher