@@ -79,10 +79,12 @@ directory `/path/to/dataset` and corresponding configuration file at
7979 "bands" : [" R" , " G" , " B" ]
8080 }],
8181 "data_source" : {
82- "name" : " rslearn.data_sources.gcp_public_data.Sentinel2" ,
83- "index_cache_dir" : " cache/sentinel2/" ,
84- "sort_by" : " cloud_cover" ,
85- "use_rtree_index" : false
82+ "class_path" : " rslearn.data_sources.gcp_public_data.Sentinel2" ,
83+ "init_args" : {
84+ "index_cache_dir" : " cache/sentinel2/" ,
85+ "sort_by" : " cloud_cover" ,
86+ "use_rtree_index" : false
87+ }
8688 }
8789 }
8890 }
@@ -189,8 +191,10 @@ automate this process. Update the dataset `config.json` with a new layer:
189191 }],
190192 " resampling_method" : " nearest" ,
191193 " data_source" : {
192- " name" : " rslearn.data_sources.local_files.LocalFiles" ,
193- " src_dir" : " file:///path/to/world_cover_tifs/"
194+ " class_path" : " rslearn.data_sources.local_files.LocalFiles" ,
195+ " init_args" : {
196+ " src_dir" : " file:///path/to/world_cover_tifs/"
197+ }
194198 }
195199 }
196200},
@@ -252,8 +256,7 @@ model:
252256data :
253257 class_path : rslearn.train.data_module.RslearnDataModule
254258 init_args :
255- # Replace this with the dataset path.
256- path : /path/to/dataset/
259+ path : ${DATASET_PATH}
257260 # This defines the layers that should be read for each window.
258261 # The key ("image" / "targets") is what the data will be called in the model,
259262 # while the layers option specifies which layers will be read.
@@ -351,7 +354,9 @@ trainer:
351354 ...
352355 - class_path : rslearn.train.prediction_writer.RslearnWriter
353356 init_args :
354- path : /path/to/dataset/
357+ # We need to include this argument, but it will be overridden with the dataset
358+ # path from data.init_args.path.
359+ path : placeholder
355360 output_layer : output
356361` ` `
357362
@@ -504,24 +509,43 @@ This will produce PNGs in the vis directory. The visualizations are produced by
504509SegmentationTask and overriding the visualize function.
505510
506511
507- # ## Logging to Weights & Biases
512+ # ## Checkpoint and Logging Management
513+
514+ Above, we needed to configure the checkpoint directory in the model config (the
515+ ` dirpath` option under `lightning.pytorch.callbacks.ModelCheckpoint`), and explicitly
516+ specify the checkpoint path when applying the model. Additionally, metrics are logged
517+ to the local filesystem and not well organized.
508518
509- We can log to W&B by setting the logger under trainer in the model configuration file :
519+ We can instead let rslearn automatically manage checkpoints, along with logging to
520+ Weights & Biases. To do so, we add project_name, run_name, and management_dir options
521+ to the model config. The project_name corresponds to the W&B project, and the run name
522+ corresponds to the W&B name. The management_dir is a directory to store project data;
523+ rslearn determines a per-project directory at `{management_dir}/{project_name}/{run_name}/`
524+ and uses it to store checkpoints.
510525
511526` ` ` yaml
527+ model:
528+ # ...
529+ data:
530+ # ...
512531trainer:
513532 # ...
514- logger:
515- class_path: lightning.pytorch.loggers.WandbLogger
516- init_args:
517- project: land_cover_model
518- name: version_00
533+ project_name: land_cover_model
534+ run_name: version_00
535+ # This sets the option via the MANAGEMENT_DIR environment variable.
536+ management_dir: ${MANAGEMENT_DIR}
519537` ` `
520538
521- Now, runs with this model configuration should show on W&B. For `model fit` runs,
522- the training and validation loss and accuracy metric will be logged. The accuracy
523- metric is provided by SegmentationTask, and additional metrics can be enabled by
524- passing the relevant init_args to the task, e.g. mean IoU and F1 :
539+ Now, set the `MANAGEMENT_DIR` environment variable and run `model fit` :
540+
541+ ` ` `
542+ export MANAGEMENT_DIR=./project_data
543+ rslearn model fit --config land_cover_model.yaml
544+ ` ` `
545+
546+ The training and validation loss and accuracy metric should now be logged to W&B. The
547+ accuracy metric is provided by SegmentationTask, and additional metrics can be enabled
548+ by passing the relevant init_args to the task, e.g. mean IoU and F1 :
525549
526550` ` ` yaml
527551 class_path: rslearn.train.tasks.segmentation.SegmentationTask
@@ -532,6 +556,13 @@ passing the relevant init_args to the task, e.g. mean IoU and F1:
532556 enable_f1_metric: true
533557` ` `
534558
559+ When calling `model test` and `model predict` with management_dir set, rslearn will
560+ automatically load the best checkpoint from the project directory, or raise an error if
561+ no existing checkpoint exists. This behavior can be overridden with the
562+ ` --load_checkpoint_mode` and `--load_checkpoint_required` options (see `--help` for
563+ details). Logging will be enabled during fit but not test/predict, and this can also
564+ be overridden, using `--log_mode`.
565+
535566
536567# ## Inputting Multiple Sentinel-2 Images
537568
@@ -554,10 +585,12 @@ query_config section. This can replace the sentinel2 layer:
554585 "bands": ["R", "G", "B"]
555586 }],
556587 "data_source": {
557- "name": "rslearn.data_sources.gcp_public_data.Sentinel2",
558- "index_cache_dir": "cache/sentinel2/",
559- "sort_by": "cloud_cover",
560- "use_rtree_index": false,
588+ "class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
589+ "init_args": {
590+ "index_cache_dir": "cache/sentinel2/",
591+ "sort_by": "cloud_cover",
592+ "use_rtree_index": false
593+ },
561594 "query_config": {
562595 "max_matches": 3
563596 }
0 commit comments