
Commit fc2c4e0

Merge pull request #375 from allenai/favyen/20251125-examples-update
Update examples to use new dataset config and checkpoint management
2 parents cec4819 + 3f9eda0 commit fc2c4e0

15 files changed: +301 -194 lines changed

README.md

Lines changed: 57 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -79,10 +79,12 @@ directory `/path/to/dataset` and corresponding configuration file at
7979
"bands": ["R", "G", "B"]
8080
}],
8181
"data_source": {
82-
"name": "rslearn.data_sources.gcp_public_data.Sentinel2",
83-
"index_cache_dir": "cache/sentinel2/",
84-
"sort_by": "cloud_cover",
85-
"use_rtree_index": false
82+
"class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
83+
"init_args": {
84+
"index_cache_dir": "cache/sentinel2/",
85+
"sort_by": "cloud_cover",
86+
"use_rtree_index": false
87+
}
8688
}
8789
}
8890
}
@@ -189,8 +191,10 @@ automate this process. Update the dataset `config.json` with a new layer:
189191
}],
190192
"resampling_method": "nearest",
191193
"data_source": {
192-
"name": "rslearn.data_sources.local_files.LocalFiles",
193-
"src_dir": "file:///path/to/world_cover_tifs/"
194+
"class_path": "rslearn.data_sources.local_files.LocalFiles",
195+
"init_args": {
196+
"src_dir": "file:///path/to/world_cover_tifs/"
197+
}
194198
}
195199
}
196200
},
@@ -252,8 +256,7 @@ model:
252256
data:
253257
class_path: rslearn.train.data_module.RslearnDataModule
254258
init_args:
255-
# Replace this with the dataset path.
256-
path: /path/to/dataset/
259+
path: ${DATASET_PATH}
257260
# This defines the layers that should be read for each window.
258261
# The key ("image" / "targets") is what the data will be called in the model,
259262
# while the layers option specifies which layers will be read.
@@ -351,7 +354,9 @@ trainer:
351354
...
352355
- class_path: rslearn.train.prediction_writer.RslearnWriter
353356
init_args:
354-
path: /path/to/dataset/
357+
# We need to include this argument, but it will be overridden with the dataset
358+
# path from data.init_args.path.
359+
path: placeholder
355360
output_layer: output
356361
```
357362
@@ -504,24 +509,43 @@ This will produce PNGs in the vis directory. The visualizations are produced by
504509
SegmentationTask and overriding the visualize function.
505510

506511

507-
### Logging to Weights & Biases
512+
### Checkpoint and Logging Management
513+
514+
Above, we needed to configure the checkpoint directory in the model config (the
515+
`dirpath` option under `lightning.pytorch.callbacks.ModelCheckpoint`), and explicitly
516+
specify the checkpoint path when applying the model. Additionally, metrics are logged
517+
to the local filesystem and are not well organized.
508518

509-
We can log to W&B by setting the logger under trainer in the model configuration file:
519+
We can instead let rslearn automatically manage checkpoints, along with logging to
520+
Weights & Biases. To do so, we add project_name, run_name, and management_dir options
521+
to the model config. The project_name corresponds to the W&B project, and the run_name
522+
corresponds to the W&B run name. The management_dir is a directory to store project data;
523+
rslearn determines a per-project directory at `{management_dir}/{project_name}/{run_name}/`
524+
and uses it to store checkpoints.
510525

511526
```yaml
527+
model:
528+
# ...
529+
data:
530+
# ...
512531
trainer:
513532
# ...
514-
logger:
515-
class_path: lightning.pytorch.loggers.WandbLogger
516-
init_args:
517-
project: land_cover_model
518-
name: version_00
533+
project_name: land_cover_model
534+
run_name: version_00
535+
# This sets the option via the MANAGEMENT_DIR environment variable.
536+
management_dir: ${MANAGEMENT_DIR}
519537
```
520538

521-
Now, runs with this model configuration should show on W&B. For `model fit` runs,
522-
the training and validation loss and accuracy metric will be logged. The accuracy
523-
metric is provided by SegmentationTask, and additional metrics can be enabled by
524-
passing the relevant init_args to the task, e.g. mean IoU and F1:
539+
Now, set the `MANAGEMENT_DIR` environment variable and run `model fit`:
540+
541+
```
542+
export MANAGEMENT_DIR=./project_data
543+
rslearn model fit --config land_cover_model.yaml
544+
```
545+
546+
The training and validation loss and accuracy metric should now be logged to W&B. The
547+
accuracy metric is provided by SegmentationTask, and additional metrics can be enabled
548+
by passing the relevant init_args to the task, e.g. mean IoU and F1:
525549

526550
```yaml
527551
class_path: rslearn.train.tasks.segmentation.SegmentationTask
@@ -532,6 +556,13 @@ passing the relevant init_args to the task, e.g. mean IoU and F1:
532556
enable_f1_metric: true
533557
```
534558

559+
When calling `model test` and `model predict` with management_dir set, rslearn will
560+
automatically load the best checkpoint from the project directory, or raise an error if
561+
no checkpoint exists. This behavior can be overridden with the
562+
`--load_checkpoint_mode` and `--load_checkpoint_required` options (see `--help` for
563+
details). Logging will be enabled during fit but not test/predict, and this can also
564+
be overridden, using `--log_mode`.
565+
535566

536567
### Inputting Multiple Sentinel-2 Images
537568

@@ -554,10 +585,12 @@ query_config section. This can replace the sentinel2 layer:
554585
"bands": ["R", "G", "B"]
555586
}],
556587
"data_source": {
557-
"name": "rslearn.data_sources.gcp_public_data.Sentinel2",
558-
"index_cache_dir": "cache/sentinel2/",
559-
"sort_by": "cloud_cover",
560-
"use_rtree_index": false,
588+
"class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
589+
"init_args": {
590+
"index_cache_dir": "cache/sentinel2/",
591+
"sort_by": "cloud_cover",
592+
"use_rtree_index": false
593+
},
561594
"query_config": {
562595
"max_matches": 3
563596
}
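
Note on the README.md change above: the new "Checkpoint and Logging Management" section says checkpoints are stored under `{management_dir}/{project_name}/{run_name}/`. The snippet below is only a minimal sketch of how that path is composed, so that the layout on disk is easy to predict. The helper name `run_directory` and the `./project_data` fallback are assumptions made for illustration; this is not rslearn's own code.

```python
import os
from pathlib import Path


def run_directory(management_dir: str, project_name: str, run_name: str) -> Path:
    """Compose {management_dir}/{project_name}/{run_name}/ (hypothetical helper)."""
    return Path(management_dir) / project_name / run_name


# The README sets management_dir from the MANAGEMENT_DIR environment variable.
# The fallback value here is an assumption so the sketch runs without it set.
management_dir = os.environ.get("MANAGEMENT_DIR", "./project_data")
print(run_directory(management_dir, "land_cover_model", "version_00"))
```

With `MANAGEMENT_DIR=./project_data`, this points at `project_data/land_cover_model/version_00`, which is where the README says to expect the checkpoints for that run.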

docs/DatasetConfig.md

Lines changed: 39 additions & 30 deletions
@@ -83,10 +83,12 @@ duration of the layers is controlled by the duration of the window's time range.
8383
"bands": ["R", "G", "B"]
8484
}],
8585
"data_source": {
86-
"name": "rslearn.data_sources.gcp_public_data.Sentinel2",
87-
"index_cache_dir": "cache/sentinel2/",
88-
"sort_by": "cloud_cover",
89-
"use_rtree_index": false
86+
"class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
87+
"init_args": {
88+
"index_cache_dir": "cache/sentinel2/",
89+
"sort_by": "cloud_cover",
90+
"use_rtree_index": false
91+
}
9092
},
9193
"alias": "sentinel2"
9294
},
@@ -97,10 +99,12 @@ duration of the layers is controlled by the duration of the window's time range.
9799
"bands": ["R", "G", "B"]
98100
}],
99101
"data_source": {
100-
"name": "rslearn.data_sources.gcp_public_data.Sentinel2",
101-
"index_cache_dir": "cache/sentinel2/",
102-
"sort_by": "cloud_cover",
103-
"use_rtree_index": false,
102+
"class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
103+
"init_args": {
104+
"index_cache_dir": "cache/sentinel2/",
105+
"sort_by": "cloud_cover",
106+
"use_rtree_index": false
107+
},
104108
// The time offset is documented later.
105109
"time_offset": "60d"
106110
},
@@ -297,7 +301,7 @@ The data source specification looks like this:
297301
```jsonc
298302
{
299303
// The class path of the data source.
300-
"name": "rslearn.data_sources.gcp_public_data.Sentinel2",
304+
"class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
301305
// The query configuration specifies how items should be matched to windows. It is
302306
// optional, and the values below are defaults.
303307
"query_config": {
@@ -314,9 +318,12 @@ The data source specification looks like this:
314318
"duration": null,
315319
// The ingest flag is optional, and defaults to true.
316320
"ingest": true,
317-
// Data sources may expose additional configuration options. These would also be
318-
// configured in this section.
319-
// ...
321+
// Data sources may expose additional configuration options, passed via init_args.
322+
// class_path and init_args are handled by jsonargparse to instantiate the data
323+
// source class.
324+
"init_args": {
325+
// ...
326+
}
320327
}
321328
```
322329

@@ -886,29 +893,31 @@ attribute is "IW".
886893
}
887894
],
888895
"data_source": {
889-
"collection_name": "COPERNICUS/S1_GRD",
890-
"dtype": "float32",
891-
"filters": [
892-
[
893-
"transmitterReceiverPolarisation",
896+
"class_path": "rslearn.data_sources.google_earth_engine.GEE",
897+
"init_args": {
898+
"collection_name": "COPERNICUS/S1_GRD",
899+
"dtype": "float32",
900+
"filters": [
901+
[
902+
"transmitterReceiverPolarisation",
903+
[
904+
"VV",
905+
"VH"
906+
]
907+
],
894908
[
895-
"VV",
896-
"VH"
909+
"instrumentMode",
910+
"IW"
897911
]
898912
],
899-
[
900-
"instrumentMode",
901-
"IW"
902-
]
903-
],
904-
"gcs_bucket_name": "YOUR_BUCKET_NAME",
905-
"index_fname": "cache/sentinel1_index",
906-
"name": "rslearn.data_sources.google_earth_engine.GEE",
913+
"gcs_bucket_name": "YOUR_BUCKET_NAME",
914+
"index_fname": "cache/sentinel1_index",
915+
"service_account_credentials": "/etc/credentials/gee_credentials.json",
916+
"service_account_name": "YOUR_SERVICE_ACCOUNT_NAME"
917+
},
907918
"query_config": {
908919
"max_matches": 1
909-
},
910-
"service_account_credentials": "/etc/credentials/gee_credentials.json",
911-
"service_account_name": "YOUR_SERVICE_ACCOUNT_NAME"
920+
}
912921
},
913922
"type": "raster"
914923
}
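
Note on the docs/DatasetConfig.md change above: every hunk applies the same mechanical rewrite, where `name` becomes `class_path` and the data-source-specific options move under `init_args`. For anyone updating an existing `config.json` by hand, the following is a rough sketch of that rewrite. The function name `migrate_data_source` is hypothetical. The set of keys kept at the top level (`query_config`, `time_offset`, `duration`, `ingest`) is an assumption inferred from the hunks shown here and may not be complete.

```python
import json

# Assumption: these keys stay directly under "data_source", as in the hunks above.
# Everything else (except "name") is treated as a constructor argument.
TOP_LEVEL_KEYS = {"query_config", "time_offset", "duration", "ingest"}


def migrate_data_source(old: dict) -> dict:
    """Convert an old-style data_source dict to class_path/init_args form."""
    if "class_path" in old:
        return old  # Already in the new format; leave it untouched.
    new = {"class_path": old["name"], "init_args": {}}
    for key, value in old.items():
        if key == "name":
            continue
        if key in TOP_LEVEL_KEYS:
            new[key] = value
        else:
            new["init_args"][key] = value
    return new


old_config = {
    "name": "rslearn.data_sources.gcp_public_data.Sentinel2",
    "index_cache_dir": "cache/sentinel2/",
    "sort_by": "cloud_cover",
    "use_rtree_index": False,
    "query_config": {"max_matches": 3},
}
print(json.dumps(migrate_data_source(old_config), indent=2))
```

The result should be checked against the data source's documented options before use, since a data source could define a top-level key that is not in the assumed set.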

docs/DatasetFormat.md

Lines changed: 3 additions & 0 deletions
@@ -131,6 +131,9 @@ be merged/mosaicked together to form one raster or vector file for the window. I
131131
are multiple sub-lists, it typically corresponds to multi-temporal data, and each one
132132
will result in a different raster or vector file after the data is materialized.
133133

134+
Materialization will use the first item group in `item_groups` to populate
135+
`layers/LAYER_NAME`, the second to populate `layers/LAYER_NAME.1`, and so on.
136+
134137
For example, consider this query configuration for a data source
135138
(see [DatasetConfig](DatasetConfig.md) for details):
136139

docs/ModelConfig.md

Lines changed: 49 additions & 10 deletions
@@ -30,6 +30,10 @@ data:
3030
# other data related options
3131
trainer:
3232
# Lightning trainer options and callbacks.
33+
# Model management options.
34+
run_name: # ...
35+
project_name: # ...
36+
management_dir: ${MANAGEMENT_DIR}
3337
```
3438
3539
The YAML is parsed by jsonargparse, so each section directly configures a Python class
@@ -693,16 +697,51 @@ trainer:
693697
mode: max
694698
# We also keep the latest checkpoint.
695699
save_last: true
696-
# The logger can be set to log to something other than the local filesystem, like
697-
# W&B.
698-
logger:
699-
class_path: lightning.pytorch.loggers.WandbLogger
700-
init_args:
701-
# This is the W&B project name, and run name.
702-
# You could set entity here as well, otherwise it will use the default based on
703-
# the API key being used.
704-
project: land_cover_model
705-
name: version_00
700+
```
701+
702+
## Model Management Options
703+
704+
rslearn provides functionality to automatically manage checkpoints and logging. Without
705+
it, when running `model test` and `model predict`, the checkpoint needs to be
706+
explicitly specified using `--ckpt_path`.
707+
708+
If enabled, model management will:
709+
1. Adjust the `dirpath` of any `ModelCheckpoint` callbacks to save checkpoints in
710+
a project directory at `{management_dir}/{project_name}/{run_name}/`.
711+
2. If training is restarted, resume from the last checkpoint.
712+
3. During test/predict, automatically load the best checkpoint.
713+
4. Enable W&B logging and save the W&B run ID to the project directory (so it can
714+
be reused when resuming training).
715+
5. Save the model config with the W&B run.
716+
717+
Common options are summarized below:
718+
719+
```yaml
720+
# The management directory. Setting this (default null) enables model management. We
721+
# recommend setting it to ${MANAGEMENT_DIR} so that it can easily be changed in
722+
# different environments.
723+
management_dir: ${MANAGEMENT_DIR}
724+
# The project name; corresponds to the W&B project.
725+
project_name: my_project
726+
# The run name (a name for this experiment); corresponds to the W&B run.
727+
run_name: my_first_experiment
728+
# Optional description that will be added to the W&B run.
729+
run_description: this is my first experiment
730+
# Which checkpoint to load, if any (default 'auto').
731+
# 'none' never loads any checkpoint.
732+
# 'last' loads the most recent checkpoint.
733+
# 'best' loads the best checkpoint.
734+
# 'auto' will use 'last' during fit and 'best' during val/test/predict.
735+
load_checkpoint_mode: auto
736+
# Whether to fail if the expected checkpoint based on load_checkpoint_mode does not exist (default 'auto').
737+
# 'yes' will fail while 'no' won't.
738+
# 'auto' will use 'no' during fit and 'yes' during val/test/predict.
739+
load_checkpoint_required: auto
740+
# Whether to log to W&B (default 'auto').
741+
# 'yes' will enable logging.
742+
# 'no' will disable logging.
743+
# 'auto' will use 'yes' during fit and 'no' during val/test/predict.
744+
log_mode: auto
706745
```
707746

708747
## Using Custom Classes
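
Note on the "Model Management Options" section added to docs/ModelConfig.md above: the three `auto` defaults depend on which command is running. The sketch below restates the rules from the YAML comments as code, which may help when checking what a given command will do. The function name and the stage strings (`fit`, `val`, `test`, `predict`) are assumptions chosen for illustration; rslearn's internal names may differ.

```python
VALID_STAGES = {"fit", "val", "test", "predict"}


def resolve_management_options(
    stage: str,
    load_checkpoint_mode: str = "auto",
    load_checkpoint_required: str = "auto",
    log_mode: str = "auto",
) -> dict:
    """Replace each 'auto' value with the stage-dependent default described above."""
    if stage not in VALID_STAGES:
        raise ValueError(f"unknown stage: {stage}")
    is_fit = stage == "fit"
    if load_checkpoint_mode == "auto":
        # 'last' during fit (to resume training), 'best' otherwise.
        load_checkpoint_mode = "last" if is_fit else "best"
    if load_checkpoint_required == "auto":
        # A missing checkpoint is tolerated during fit but not afterwards.
        load_checkpoint_required = "no" if is_fit else "yes"
    if log_mode == "auto":
        # W&B logging is on during fit and off otherwise.
        log_mode = "yes" if is_fit else "no"
    return {
        "load_checkpoint_mode": load_checkpoint_mode,
        "load_checkpoint_required": load_checkpoint_required,
        "log_mode": log_mode,
    }


print(resolve_management_options("fit"))
print(resolve_management_options("predict"))
```

Explicit values pass through unchanged, so `resolve_management_options("test", log_mode="yes")` keeps logging enabled while the other two options still resolve to their test-time defaults.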
