maikebauer/CME_ML

Solar Transient Recognition Using Deep Learning (STRUDL)

This repository contains the code for the components of the STRUDL model. The Python packages used to generate the results in the STRUDL paper are listed in requirements.txt.

The main file is model_3dconv_onec_slide.py, which runs the model training. All model settings are defined in the sample_config.yaml file. Because sample_config.yaml only serves as an example, it is recommended to copy this file and rename the copy to config.yaml. Each keyword is explained in the Configuration section.

Running STRUDL

After installing the required packages, you can run STRUDL from the CME_ML directory using the following command:

python model_3dconv_onec_slide.py

This command will use all the settings defined in the config.yaml file and run the base model as described in the STRUDL paper. The results will be saved in a new directory within CME_ML/Model_Train/, named run_ddmmyyyy_HHMMSS_model_3donv_onec_slide/, where ddmmyyyy_HHMMSS corresponds to the start time of the model run. The trained model will be saved as a .pth file, and the config.yaml file used for the run will be copied into the results folder for reference.
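As an illustration of the naming scheme described above, the following minimal sketch builds a results folder name from a start time. The function name run_directory_name and its suffix parameter are hypothetical and not taken from the repository; the example suffix mirrors the folder names listed in the Evaluation Configuration section.

```python
from datetime import datetime


def run_directory_name(start: datetime, suffix: str) -> str:
    """Build a folder name of the form run_ddmmyyyy_HHMMSS_<suffix>.

    Hypothetical helper for illustration; the repository may construct
    the name differently.
    """
    # %d%m%Y -> ddmmyyyy, %H%M%S -> HHMMSS
    timestamp = start.strftime("%d%m%Y_%H%M%S")
    return f"run_{timestamp}_{suffix}"


# A run started on 27 June 2025 at 15:03:21
print(run_directory_name(datetime(2025, 6, 27, 15, 3, 21), "model_cnn3d"))
# prints run_27062025_150321_model_cnn3d
```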

Evaluating STRUDL

All evaluation settings are defined in the sample_config_evaluation.yaml file. As with the training config, it is recommended to copy this file and rename the copy to config_evaluation.yaml. All options are explained in the Evaluation Configuration section. To reproduce the evaluation results shown in the STRUDL paper, run:

python run_evaluation_total.py

Configuration

dataset:

data_path: "/path/to/data/" (Path to folder in which .pkl files of ST-A HI-1 running difference images in full 1024x1024 resolution are saved)

annotation_path: "/path/to/annotation.json" (Path to .json file containing the annotations for the ST-A HI-1 images)

height: 128 (desired height for ST-A HI-1 images used for training/validation/testing)

width: 128 (desired width for ST-A HI-1 images used for training/validation/testing)

win_size: 16 (desired length of image series)

stride: 2 (desired stride when splitting image series into sequences)

quick_run: false (uses only the first few image sequences as input to check if code is running without errors)
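To show how win_size and stride interact, here is a minimal sketch of a sliding-window split over image indices. It assumes that sequences are consecutive windows of win_size images whose start index advances by stride, and that incomplete windows at the end are dropped. The function split_into_sequences is hypothetical and not part of the repository.

```python
from typing import List


def split_into_sequences(n_images: int, win_size: int = 16, stride: int = 2) -> List[List[int]]:
    """Return index windows of length win_size, each shifted by stride.

    Assumption for illustration: windows that would run past the end of
    the image series are discarded.
    """
    return [
        list(range(start, start + win_size))
        for start in range(0, n_images - win_size + 1, stride)
    ]


sequences = split_into_sequences(n_images=20)
print(len(sequences))          # prints 3 (windows start at 0, 2 and 4)
print(sequences[1][0], sequences[1][-1])  # prints 2 17
```

With a stride smaller than win_size, neighbouring sequences overlap, so most images appear in several sequences.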

model:

model_path: "/path/to/code/" (path to where this repository is saved)

name: "cnn3d" (name of backbone to be used, it is recommended to stick to cnn3d)

model_parameters:

  • input_channels: 1
  • output_channels: 1
  • final_layer: Sigmoid (final layer of model)

final_layer: Sigmoid

device: cuda (which device to run the model on)

seed: 42 (seed to use for sequence generation if pre-determined sequences are not used)

optimizer:

name: Adam (name of optimizer to use)

optimizer_parameters:

  • lr: 1.0e-05
  • weight_decay: 0.0

scheduler:

use_scheduler: false

name: ReduceLROnPlateau (if use_scheduler is true, which scheduler to use)

scheduler_parameters:

  • mode: min
  • factor: 0.1
  • patience: 10

loss:

name: AsymmetricUnifiedFocalLoss (name of loss function to use)

loss_parameters:

  • weight: 0.5
  • delta: 0.6
  • gamma: 0.5

train:

batch_size: 4

num_workers: 2

epochs: 150

include_potential: true (include CMEs flagged as potential CMEs)

include_potential_gt: true (include CMEs flagged as potential CMEs in the ground truth)

data_parallel: true (train the model on parallel GPUs)

shuffle: true (shuffle the training dataset each epoch)

binary_gt: true (read in ground truth as binary array)

threshold_iou: 0.9 (IoU threshold to use for computing training IoU after each epoch)

cross_validation:

  • use_cross_validation: true

  • fold_file: "/path/to/folds_dictionary.npy" (path of .npy file containing definition of folds to use for cross-validation)

  • fold_definition: (which folds to use for what set)

    • train:

      • fold_1

      • fold_2

      • fold_3

      • fold_4

      • fold_5

      • fold_6

      • fold_7

      • fold_8

    • test:

      • fold_9
    • val:

      • fold_10
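The fold assignment above can be sketched as follows. The structure of the fold file is an assumption here: the folds dictionary below is a made-up stand-in that maps each fold name to a list of event IDs, and both helper functions are hypothetical.

```python
# Fold assignment as listed in the configuration above
fold_definition = {
    "train": [f"fold_{i}" for i in range(1, 9)],
    "test": ["fold_9"],
    "val": ["fold_10"],
}

# Made-up stand-in for the contents of the fold file (assumption):
# fold name -> list of event IDs
folds = {f"fold_{i}": [10 * i, 10 * i + 1] for i in range(1, 11)}


def check_disjoint(definition):
    """Raise ValueError if a fold is assigned to more than one set."""
    seen = set()
    for split, names in definition.items():
        overlap = seen & set(names)
        if overlap:
            raise ValueError(f"{sorted(overlap)} reused in '{split}'")
        seen |= set(names)


def collect_split(folds, definition, split):
    """Gather all event IDs belonging to one of train/test/val."""
    return [event for name in definition[split] for event in folds[name]]


check_disjoint(fold_definition)
print(len(collect_split(folds, fold_definition, "train")))  # prints 16
print(collect_split(folds, fold_definition, "val"))         # prints [100, 101]
```

Checking that no fold is shared between the sets guards against leakage between training, validation and test data.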

data_augmentation: (which data augmentations to use and which probability to apply them with)

  • name: ToTensor

  • name: RandomHorizontalFlip

    • p: 0.5
  • name: RandomVerticalFlip

    • p: 0.5

load_checkpoint:

  • load_model: False (whether to re-start training from a previous checkpoint)

  • load_optimizer: False

  • checkpoint_path: "/path/to/checkpoint.pth"

  • updated_lr: 1.0e-05

evaluate: (parameters used only in evaluation step)

include_potential: true

include_potential_gt: true

shuffle: false

binary_gt: true

test: (currently not used, parameters used only in testing step)

model_path: "/path/to/model.pth"

include_potential: true

include_potential_gt: true

shuffle: false

binary_gt: true

threshold_iou: 0.9
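The threshold_iou keyword can be read in more than one way. The sketch below assumes one plausible reading: predicted probabilities are binarized at the threshold before the intersection over union (IoU) with the binary ground truth is computed. The function binary_iou is hypothetical and may not match what the repository does.

```python
def binary_iou(prediction, ground_truth, threshold=0.9):
    """IoU between a thresholded prediction and a binary ground truth.

    Both inputs are 2D lists of equal shape. Assumption: a pixel counts
    as predicted when its probability is >= threshold.
    """
    intersection = 0
    union = 0
    for pred_row, gt_row in zip(prediction, ground_truth):
        for prob, gt in zip(pred_row, gt_row):
            pred = prob >= threshold
            intersection += pred and bool(gt)
            union += pred or bool(gt)
    # Assumed convention: two empty masks agree perfectly
    return intersection / union if union else 1.0


prediction = [[0.95, 0.92],
              [0.30, 0.99]]
ground_truth = [[1, 0],
                [0, 1]]
# 2 pixels overlap, 3 pixels are set in either mask
print(binary_iou(prediction, ground_truth))
```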

Evaluation Configuration

mode: test (choose either 'val' for validation, or 'test' for testing)

method: (which methods to use for aggregating the predicted segmentation masks)

  • median
  • mean
  • max
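The three aggregation methods can be illustrated with a small pure-Python sketch. It assumes that several predicted masks exist for the same image (for example from overlapping sequences) and that they are combined pixel by pixel. The function aggregate_masks and the sample values are hypothetical.

```python
from statistics import mean, median

# Method names as used in the configuration above
AGGREGATORS = {"median": median, "mean": mean, "max": max}


def aggregate_masks(masks, method):
    """Combine equally sized 2D masks pixel by pixel with the given method."""
    reduce = AGGREGATORS[method]
    n_rows, n_cols = len(masks[0]), len(masks[0][0])
    return [
        [reduce([mask[r][c] for mask in masks]) for c in range(n_cols)]
        for r in range(n_rows)
    ]


# Three made-up predictions for the same 1x2 image
predictions = [
    [[0.2, 0.9]],
    [[0.4, 0.7]],
    [[0.9, 0.8]],
]
print(aggregate_masks(predictions, "median"))  # prints [[0.4, 0.8]]
print(aggregate_masks(predictions, "max"))     # prints [[0.9, 0.9]]
```

The median is less sensitive to a single outlying prediction than the mean, while the maximum keeps any pixel that at least one prediction marked strongly.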

mdls_event_based: (names of folders containing .pth files to evaluate as part of event-based tracking)

  • run_27062025_150321_model_cnn3d
  • run_27062025_150708_model_cnn3d
  • run_27062025_171628_model_cnn3d
  • run_30062025_105242_model_cnn3d
  • run_27062025_154802_model_cnn3d

mdls_operational: run_25062025_120013_model_cnn3d (name of folder containing .pth file to evaluate for continuous tracking)

paths:

  • ml_path: /path/to/Model_Train/ (path to the Model_Train folder)

  • data_paths: (path to running difference images for continuous tracking)

    • /path/to/2009_rdifs/
    • /path/to/2011_rdifs/
  • rdif_path: /path/to/event_based_rdifs/ (path to running difference images for event-based tracking)

  • annotation_path: /path/to/instances_default.json (path to .json file containing annotations)

  • helcats_path: /path/to/helcats/HCME_WP3_V06_TE_PROFILES/ (path to HELCATS WP3, aka HIGeoCAT, profiles. Can be downloaded here)

  • corrected_helcats_path: /path/to/time_corrected/helcats/HCME_WP3_V06_TE_PROFILES_CORRECTED_CSV/ (path to corrected versions of HELCATS WP3 profiles. Delete this keyword to have the program create the corrected files)

  • fits_path: /path/to/STEREO_A/fits/ (path to ST-A HI-1 .fits files, with or without data reduction applied)

  • wp2_path: /path/to/helcats/HCME_WP2_V06.json (path to HELCATS WP2, aka HICAT, profiles. Can be downloaded here)

time_pairs: (start and end times for continuous evaluation)

  • start:

    • '2009_01_01'
    • '2011_01_01'
  • end:

    • '2009_12_31'
    • '2011_12_31'
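The date strings above follow a year_month_day pattern. The following minimal sketch parses them into intervals, assuming that the i-th start date belongs to the i-th end date. The function parse_time_pairs is hypothetical and not part of the repository.

```python
from datetime import datetime

time_pairs = {
    "start": ["2009_01_01", "2011_01_01"],
    "end": ["2009_12_31", "2011_12_31"],
}


def parse_time_pairs(pairs, fmt="%Y_%m_%d"):
    """Return a list of (start, end) datetimes, paired by position."""
    starts = [datetime.strptime(s, fmt) for s in pairs["start"]]
    ends = [datetime.strptime(e, fmt) for e in pairs["end"]]
    if len(starts) != len(ends):
        raise ValueError("start and end need the same number of entries")
    intervals = list(zip(starts, ends))
    for start, end in intervals:
        if end < start:
            raise ValueError(f"end {end.date()} lies before start {start.date()}")
    return intervals


for start, end in parse_time_pairs(time_pairs):
    print(start.date(), "to", end.date())
# prints 2009-01-01 to 2009-12-31
# prints 2011-01-01 to 2011-12-31
```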

get_segmentation_masks: True (get segmentation masks from .pth file. Set to True when a model is evaluated for the first time)

plotting: True

dates_plotting_operational: (years and months to be plotted for operational evaluation)

  • '2009':

    • '01'
    • '02'
  • '2011':

    • '01'
    • '02'
