diff --git a/olmoearth_projects/projects/mozambique_lulc/__init__.py b/olmoearth_projects/projects/mozambique_lulc/__init__.py new file mode 100644 index 0000000..7f55ac2 --- /dev/null +++ b/olmoearth_projects/projects/mozambique_lulc/__init__.py @@ -0,0 +1 @@ +"""Mozambique Land Use Land Cover (LULC) mapping project.""" diff --git a/olmoearth_projects/projects/mozambique_lulc/create_label_raster.py b/olmoearth_projects/projects/mozambique_lulc/create_label_raster.py new file mode 100644 index 0000000..cf983e3 --- /dev/null +++ b/olmoearth_projects/projects/mozambique_lulc/create_label_raster.py @@ -0,0 +1,102 @@ +"""Create label_raster from label. + +If you run this, you will need to update the config.json for the dataset +to include the following entry: + +"label_raster": { + "band_sets": [ + { + "bands": [ + "label" + ], + "dtype": "int32" + } + ], + "type": "raster" + }, +""" + +import argparse +import multiprocessing + +import numpy as np +import tqdm +from rslearn.dataset.dataset import Dataset +from rslearn.dataset.window import Window +from rslearn.utils.raster_format import GeotiffRasterFormat +from rslearn.utils.vector_format import GeojsonVectorFormat +from upath import UPath + +LULC_CLASS_NAMES = [ + "invalid", + "Water", + "Bare Ground", + "Rangeland", + "Flooded Vegetation", + "Trees", + "Cropland", + "Buildings", +] +CROPTYPE_CLASS_NAMES = [ + "invalid", + "corn", + "cassava", + "rice", + "sesame", + "beans", + "millet", + "sorghum", +] +PROPERTY_NAME = "category" +BAND_NAME = "label" + + +def create_label_raster(window: Window) -> None: + """Create label raster for the given window.""" + label_dir = window.get_layer_dir("label") + features = GeojsonVectorFormat().decode_vector( + label_dir, window.projection, window.bounds + ) + class_name = features[0].properties[PROPERTY_NAME] + try: + class_id = LULC_CLASS_NAMES.index(class_name) + except ValueError: + class_id = CROPTYPE_CLASS_NAMES.index(class_name) + + # Draw the class_id in the middle 1x1 of the raster. + raster = np.zeros( + (1, window.bounds[3] - window.bounds[1], window.bounds[2] - window.bounds[0]), + dtype=np.uint8, + ) + raster[:, raster.shape[1] // 2, raster.shape[2] // 2] = class_id + raster_dir = window.get_raster_dir("label_raster", [BAND_NAME]) + GeotiffRasterFormat().encode_raster( + raster_dir, window.projection, window.bounds, raster + ) + window.mark_layer_completed("label_raster") + + +if __name__ == "__main__": + multiprocessing.set_start_method("forkserver") + parser = argparse.ArgumentParser() + parser.add_argument( + "--ds_path", + type=str, + required=True, + help="Path to the dataset", + ) + parser.add_argument( + "--workers", + type=int, + default=64, + help="Number of worker processes to use", + ) + args = parser.parse_args() + + dataset = Dataset(UPath(args.ds_path)) + windows = dataset.load_windows(workers=args.workers, show_progress=True) + p = multiprocessing.Pool(args.workers) + outputs = p.imap_unordered(create_label_raster, windows) + for _ in tqdm.tqdm(outputs, total=len(windows)): + pass + p.close() diff --git a/olmoearth_projects/projects/mozambique_lulc/create_windows_for_lulc.py b/olmoearth_projects/projects/mozambique_lulc/create_windows_for_lulc.py new file mode 100644 index 0000000..4268196 --- /dev/null +++ b/olmoearth_projects/projects/mozambique_lulc/create_windows_for_lulc.py @@ -0,0 +1,334 @@ +"""Create windows for crop type mapping from GPKG files (fixed splits).""" + +import argparse +import multiprocessing +from collections.abc import Iterable +from datetime import UTC, datetime +from pathlib import Path + +import geopandas as gpd +import shapely +import tqdm +from olmoearth_run.runner.tools.data_splitters.spatial_data_splitter import ( + SpatialDataSplitter, +) +from rslearn.const import WGS84_PROJECTION +from rslearn.dataset import Window +from rslearn.utils import Projection, STGeometry, get_utm_ups_crs +from rslearn.utils.feature import Feature +from rslearn.utils.mp import star_imap_unordered +from rslearn.utils.vector_format import GeojsonVectorFormat +from upath import UPath + +WINDOW_RESOLUTION = 10 +LABEL_LAYER = "label" + +CLASS_MAP = { + 0: "Water", + 1: "Bare Ground", + 2: "Rangeland", + 3: "Flooded Vegetation", + 4: "Trees", + 5: "Cropland", + 6: "Buildings", +} + +# reversed; in the crop type gpkg +# the strings are saved and we want to +# get the ints back for the category id +CROP_TYPE_MAP = { + "corn": 0, + "cassava": 1, + "rice": 2, + "sesame": 3, + "beans": 4, + "millet": 5, + "sorghum": 6, +} + + +# Per-province temporal coverage (UTC) +GROUP_TIME = { + "gaza": ( + datetime(2024, 10, 23, tzinfo=UTC), + datetime(2025, 6, 7, tzinfo=UTC), + ), + "manica": ( + datetime(2024, 10, 23, tzinfo=UTC), + datetime(2025, 6, 7, tzinfo=UTC), + ), + "zambezia": ( + datetime(2024, 10, 23, tzinfo=UTC), + datetime(2025, 6, 7, tzinfo=UTC), + ), + # for crop type, we will train a single model + # for all 3 provinces since there are too few labels + # so let's take the union of the ranges. + "crop_type": ( + datetime(2024, 10, 23, tzinfo=UTC), + datetime(2025, 6, 7, tzinfo=UTC), + ), +} + + +def calculate_bounds( + geometry: STGeometry, window_size: int +) -> tuple[int, int, int, int]: + """Calculate the bounds of a window around a geometry. + + Args: + geometry: the geometry to calculate the bounds of. + window_size: the size of the window. + + Copied from + https://github.com/allenai/rslearn_projects/blob/master/rslp/utils/windows.py + """ + if window_size <= 0: + raise ValueError("Window size must be greater than 0") + + if window_size % 2 == 0: + bounds = ( + int(geometry.shp.x) - window_size // 2, + int(geometry.shp.y) - window_size // 2, + int(geometry.shp.x) + window_size // 2, + int(geometry.shp.y) + window_size // 2, + ) + else: + bounds = ( + int(geometry.shp.x) - window_size // 2, + int(geometry.shp.y) - window_size // 2 - 1, + int(geometry.shp.x) + window_size // 2 + 1, + int(geometry.shp.y) + window_size // 2, + ) + + return bounds + + +def process_gpkg(gpkg_path: UPath, crop_type: bool) -> gpd.GeoDataFrame: + """Load a GPKG and ensure lon/lat in WGS84; expect 'fid' and 'class' columns.""" + gdf = gpd.read_file(str(gpkg_path)) + + # Normalize CRS to WGS84 + if gdf.crs is None: + gdf = gdf.set_crs("EPSG:4326", allow_override=True) + else: + gdf = gdf.to_crs("EPSG:4326") + + required_cols = {"crop1" if crop_type else "class", "geometry"} + missing = [c for c in required_cols if c not in gdf.columns] + if missing: + raise ValueError(f"{gpkg_path}: missing required column(s): {missing}") + + return gdf + + +def iter_points( + gdf: gpd.GeoDataFrame, crop_type: bool +) -> Iterable[tuple[int, float, float, int | str]]: + """Yield (fid, latitude, longitude, category) per feature using centroid for polygons.""" + for fid, row in gdf.iterrows(): + geom = row.geometry + if geom is None or geom.is_empty: + continue + if isinstance(geom, shapely.Point): + pt = geom + else: + pt = geom.centroid + lon, lat = float(pt.x), float(pt.y) + # the crop type labels are strings, the lulc labels are ints which + # map to classes + category = row["crop1"] if crop_type else int(row["class"]) + yield fid, lat, lon, category + + +def create_window( + rec: tuple[int, float, float, int | str], + ds_path: UPath, + group_name: str, + split: str, + window_size: int, + start_time: datetime, + end_time: datetime, + crop_type: bool, +) -> None: + """Create a single window and write label layer.""" + fid, latitude, longitude, category_id = rec + if crop_type: + if not isinstance(category_id, str): + raise ValueError(f"{category_id} should be str in the crop-type case.") + category_label = category_id + category_id = CROP_TYPE_MAP[category_label] + else: + if not isinstance(category_id, int): + raise ValueError(f"{category_id} should be int in the non crop-type case.") + category_label = CLASS_MAP.get(category_id, f"Unknown_{category_id}") + + # Geometry/projection + src_point = shapely.Point(longitude, latitude) + src_geometry = STGeometry(WGS84_PROJECTION, src_point, None) + dst_crs = get_utm_ups_crs(longitude, latitude) + dst_projection = Projection(dst_crs, WINDOW_RESOLUTION, -WINDOW_RESOLUTION) + dst_geometry = src_geometry.to_projection(dst_projection) + bounds = calculate_bounds(dst_geometry, window_size) + + # Group = province name; split is taken from file name (train/test) + group = group_name + window_name = f"{fid}_{latitude:.6f}_{longitude:.6f}" + + window = Window( + path=Window.get_window_root(ds_path, group, window_name), + group=group, + name=window_name, + projection=dst_projection, + bounds=bounds, + time_range=(start_time, end_time), + options={ + "split": split, # 'train' or 'test' as provided + "category_id": category_id, + "category": category_label, + "fid": fid, + "source": "gpkg", + }, + ) + + if split == "train": + # split into a train and val set using the spatial data + # splitter, keep the test set as it was originally + splitter = SpatialDataSplitter( + train_prop=0.8, val_prop=0.2, test_prop=0.0, grid_size=32 + ) + split = splitter.choose_split_for_window(window) + window.options["split"] = split + window.save() + + # Label layer (same as before, using window geometry) + feature = Feature( + window.get_geometry(), + { + "category_id": category_id, + "category": category_label, + "fid": fid, + "split": split, + }, + ) + layer_dir = window.get_layer_dir(LABEL_LAYER) + GeojsonVectorFormat().encode_vector(layer_dir, [feature]) + window.mark_layer_completed(LABEL_LAYER) + + +def create_windows_from_gpkg( + gpkg_path: UPath, + ds_path: UPath, + group_name: str, + split: str, + window_size: int, + max_workers: int, + start_time: datetime, + end_time: datetime, + crop_type: bool, +) -> None: + """Create windows from a single GPKG file.""" + gdf = process_gpkg(gpkg_path, crop_type) + records = list(iter_points(gdf, crop_type)) + + jobs = [ + dict( + rec=rec, + ds_path=ds_path, + group_name=group_name, + split=split, + window_size=window_size, + start_time=start_time, + end_time=end_time, + crop_type=crop_type, + ) + for rec in records + ] + + print( + f"[{group_name}:{split}] file={gpkg_path.name} features={len(jobs)} " + f"time={start_time.date()}→{end_time.date()}" + ) + + if max_workers <= 1: + for kw in tqdm.tqdm(jobs): + create_window(**kw) + else: + p = multiprocessing.Pool(max_workers) + outputs = star_imap_unordered(p, create_window, jobs) + for _ in tqdm.tqdm(outputs, total=len(jobs)): + pass + p.close() + + +if __name__ == "__main__": + multiprocessing.set_start_method("forkserver", force=True) + + parser = argparse.ArgumentParser(description="Create windows from GPKG files") + parser.add_argument( + "--gpkg_dir", + type=str, + required=True, + help="Directory containing gaza_[train|test].gpkg, manica_[train|test].gpkg, zambezia_[train|test].gpkg", + ) + parser.add_argument( + "--ds_path", + type=str, + required=True, + help="Path to the dataset root", + ) + parser.add_argument( + "--window_size", + type=int, + default=1, + help="Window size (pixels per side in projected grid)", + ) + parser.add_argument( + "--max_workers", + type=int, + default=32, + help="Worker processes (set 1 for single-process)", + ) + parser.add_argument("--crop_type", action="store_true", default=False) + args = parser.parse_args() + + gpkg_dir = Path(args.gpkg_dir) + ds_path = UPath(args.ds_path) + if not args.crop_type: + expected = [ + ("gaza", "train", gpkg_dir / "gaza_train.gpkg"), + ("gaza", "test", gpkg_dir / "gaza_test.gpkg"), + ("manica", "train", gpkg_dir / "manica_train.gpkg"), + ("manica", "test", gpkg_dir / "manica_test.gpkg"), + ("zambezia", "train", gpkg_dir / "zambezia_train.gpkg"), + ("zambezia", "test", gpkg_dir / "zambezia_test.gpkg"), + ] + else: + expected = [ + ("crop_type", "train", gpkg_dir / "training_gaza_zambezia_manica.gpkg"), + ("crop_type", "test", gpkg_dir / "test_gaza_zambezia_manica.gpkg"), + ] + + # Basic checks + for group_or_province, _, path in expected: + if group_or_province not in GROUP_TIME: + raise ValueError(f"Unknown province or group '{group_or_province}'") + if not path.exists(): + raise FileNotFoundError(f"Missing expected file: {path}") + + # Run per file + for group_or_province, split, path in expected: + start_time, end_time = GROUP_TIME[group_or_province] + create_windows_from_gpkg( + gpkg_path=UPath(path), + ds_path=ds_path, + group_name=group_or_province, # group == province + split=split, # honor provided split + window_size=args.window_size, + max_workers=args.max_workers, + start_time=start_time, + end_time=end_time, + crop_type=args.crop_type, + ) + + print("Done.") diff --git a/olmoearth_projects/projects/mozambique_lulc/points_per_class.py b/olmoearth_projects/projects/mozambique_lulc/points_per_class.py new file mode 100644 index 0000000..4ab513a --- /dev/null +++ b/olmoearth_projects/projects/mozambique_lulc/points_per_class.py @@ -0,0 +1,30 @@ +"""Count how many classes are in each split.""" + +import argparse + +from rslearn.dataset.dataset import Dataset +from upath import UPath + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--ds_path", + type=str, + required=True, + help="Path to the dataset", + ) + args = parser.parse_args() + + dataset = Dataset(UPath(args.ds_path)) + windows = dataset.load_windows(show_progress=True) + output_dict: dict[str, dict[str, int]] = {} + for window in windows: + split = window.options["split"] + category = window.options["category"] + + if category not in output_dict: + output_dict[category] = {"train": 0, "val": 0, "test": 0} + + output_dict[category][split] += 1 + + print(output_dict) diff --git a/olmoearth_run_data/mozambique_lulc/README.md b/olmoearth_run_data/mozambique_lulc/README.md new file mode 100644 index 0000000..790c4f2 --- /dev/null +++ b/olmoearth_run_data/mozambique_lulc/README.md @@ -0,0 +1,67 @@ +# Mozambique LULC and Crop Type Classification + +This project has two main tasks: + 1. Land Use/Land Cover (LULC) and cropland classification + 2. Crop type classification + +The annotations come from field surveys across three provinces in Mozambique: Gaza, Zambezia, and Manica. + +For LULC classification, the train/test splits are: +- Gaza: 2,262 / 970 +- Manica: 1,917 / 822 +- Zambezia: 1,225 / 525 + +### Generating the data +``` +export DATASET_PATH=/weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc/20251113 + +python /weka/dfive-default/gabrielt/olmoearth_projects/olmoearth_projects/projects/mozambique_lulc/create_windows_for_lulc.py --gpkg_dir /weka/dfive-default/yawenz/datasets/mozambique/train_test_samples --ds_path $DATASET_PATH --window_size 32 + +python /weka/dfive-default/gabrielt/olmoearth_projects/olmoearth_projects/projects/mozambique_lulc/create_windows_for_lulc.py --gpkg_dir /weka/dfive-default/yawenz/datasets/mozambique/train_test_samples --ds_path $DATASET_PATH --window_size 32 --crop_type +``` +You will then need to copy a `config.json` into `$DATASET_PATH`. + +The config being used is available in [config.json](config.json). This config requires [rslearn_projects](https://github.com/allenai/rslearn_projects) in your environment. + +Once the config is copied into the dataset root, the following commands can be run: + +``` +rslearn dataset prepare --root $DATASET_PATH --workers 64 --no-use-initial-job --retry-max-attempts 8 --retry-backoff-seconds 60 + +python -m rslp.main common launch_data_materialization_jobs --image yawenzzzz/rslp20251112h --ds_path $DATASET_PATH --clusters+=ai2/neptune-cirrascale --num_jobs 5 +``` + +Within `/weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc` there are two versions of the data: +- `/weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc/20251023`, which only has the train and test split as defined in the gpkg files +- `/weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc/20251113`, which splits the training data into train and val data using a spatial split (introduced in [this commit](https://github.com/allenai/olmoearth_projects/pull/28/commits/1cfb86d40c8e2ccba830eb80410d1248544877c9)). This leads to the following train / val / test splits (with `val_ratio = 0.2`): + - Gaza: 1,802 / 460 / 970 + - Manica: 1,564 / 353 / 822 + - Zambezia: 949 / 276 / 525 + - For crop type mapping, the following train / val / test splits, per class: `'corn': {'train': 917, 'val': 191, 'test': 3709}, 'sesame': {'train': 384, 'val': 0, 'test': 383}, 'beans': {'train': 932, 'val': 224, 'test': 417}, 'rice': {'train': 648, 'val': 512, 'test': 863}, 'millet': {'train': 36, 'val': 0, 'test': 57}, 'cassava': {'train': 685, 'val': 133, 'test': 201}, 'sorghum': {'train': 52, 'val': 0, 'test': 41},` +- `/weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc/20251114` which aligns the dates for all provinces (as in [this commit](https://github.com/allenai/olmoearth_projects/pull/28/commits/07ee7ef22a383b2c71ef6acab3171df8387924bd)). + +Finally - we treat this as a segmentation task, not as a classification task (this makes inference faster, without hurting performance). This means the point labels need to be transformed into rasters: + +``` +python olmoearth_projects/projects/mozambique_lulc/create_label_raster.py --ds_path $DATASET_PATH +``` + +### Finetuning + +Currently, we use [rslearn_projects](github.com/allenai/rslearn_projects) for finetuning, using [rslp_finetuning.yaml](rslp_finetuning.yaml) and [rslp_finetuning_croptype.yaml](rslp_finetuning_croptype.yaml). With `rslean_projects` installed (and access to Beaker), finetuning can then be run with the following command: + +``` +python -m rslp.main olmoearth_pretrain launch_finetune --image_name yawenzzzz/rslp20251112h --config_paths+=olmoearth_run_data/mozambique_lulc/rslp_finetuning.yaml --cluster+=ai2/saturn --rslp_project --experiment_id +``` + +### Testing + +Obtaining test results consisted of the following: +1. Spin up an interactive beaker session with a GPU: `beaker session create --remote --bare --budget ai2/es-platform --cluster ai2/saturn --mount src=weka,ref=dfive-default,dst=/weka/dfive-default --image beaker://yawenzzzz/rslp20251112h --gpus 1` +2. Go to the olmoearth projects folder on weka (to easily `git pull`) changes: `cd /weka/dfive-default/gabrielt/olmoearth_projects` +3. Add the `RSLP_PREFIX` to the environment, `export RSLP_PREFIX=/weka/dfive-default/rslearn-eai` +4. Run testing: `python -m rslp.rslearn_main model test --config olmoearth_run_data/mozambique_lulc/rslp_finetuning.yaml --rslp_experiment --rslp_project --force_log=true --load_best=true --verbose true` + +### Inference + +All inference is done on [OlmoEarth Studio](https://olmoearth.allenai.org/). Polygons around the provinces were manually drawn (within Studio). diff --git a/olmoearth_run_data/mozambique_lulc/config.json b/olmoearth_run_data/mozambique_lulc/config.json new file mode 100644 index 0000000..96d35a8 --- /dev/null +++ b/olmoearth_run_data/mozambique_lulc/config.json @@ -0,0 +1,141 @@ +{ + "layers": { + "label": { + "type": "vector" + }, + "label_raster": { + "band_sets": [ + { + "bands": [ + "label" + ], + "dtype": "int32" + } + ], + "type": "raster" + }, + "output": { + "format": { + "name": "geojson" + }, + "type": "vector" + }, + "sentinel1_ascending": { + "band_sets": [ + { + "bands": [ + "vv", + "vh" + ], + "dtype": "float32" + } + ], + "data_source": { + "cache_dir": "cache/planetary_computer", + "ingest": false, + "name": "rslp.satlas.data_sources.MonthlySentinel1", + "query": { + "sar:instrument_mode": { + "eq": "IW" + }, + "sar:polarizations": { + "eq": [ + "VV", + "VH" + ] + }, + "sat:orbit_state": { + "eq": "ascending" + } + }, + "query_config": { + "max_matches": 6 + } + }, + "type": "raster" + }, + "sentinel1_descending": { + "band_sets": [ + { + "bands": [ + "vv", + "vh" + ], + "dtype": "float32" + } + ], + "data_source": { + "cache_dir": "cache/planetary_computer", + "ingest": false, + "name": "rslp.satlas.data_sources.MonthlySentinel1", + "query": { + "sar:instrument_mode": { + "eq": "IW" + }, + "sar:polarizations": { + "eq": [ + "VV", + "VH" + ] + }, + "sat:orbit_state": { + "eq": "descending" + } + }, + "query_config": { + "max_matches": 6 + } + }, + "type": "raster" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "data_source": { + "cache_dir": "cache/planetary_computer", + "harmonize": true, + "ingest": false, + "max_cloud_cover": 50, + "name": "rslp.satlas.data_sources.MonthlyAzureSentinel2", + "query_config": { + "max_matches": 6 + }, + "sort_by": "eo:cloud_cover" + }, + "type": "raster" + } + }, + "tile_store": { + "name": "file", + "root_dir": "tiles" + } +} diff --git a/olmoearth_run_data/mozambique_lulc/dataset.json b/olmoearth_run_data/mozambique_lulc/dataset.json new file mode 100644 index 0000000..e939fd5 --- /dev/null +++ b/olmoearth_run_data/mozambique_lulc/dataset.json @@ -0,0 +1,72 @@ +{ + "layers": { + "label": { + "band_sets": [ + { + "bands": [ + "category" + ], + "dtype": "int32" + } + ], + "type": "raster" + }, + "output": { + "band_sets": [ + { + "bands": [ + "output" + ], + "dtype": "float32" + } + ], + "type": "raster" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "data_source": { + "cache_dir": "cache/planetary_computer", + "harmonize": true, + "ingest": false, + "name": "rslearn.data_sources.planetary_computer.Sentinel2", + "query_config": { + "max_matches": 12, + "period_duration": "30d", + "space_mode": "PER_PERIOD_MOSAIC" + }, + "sort_by": "eo:cloud_cover" + }, + "type": "raster" + } + } +} diff --git a/olmoearth_run_data/mozambique_lulc/model.yaml b/olmoearth_run_data/mozambique_lulc/model.yaml new file mode 100644 index 0000000..18b1527 --- /dev/null +++ b/olmoearth_run_data/mozambique_lulc/model.yaml @@ -0,0 +1,281 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + checkpoint_path: ${EXTRA_FILES_PATH} + selector: ["encoder"] + patch_size: 1 + decoders: + segment: + - class_path: rslearn.models.upsample.Upsample + init_args: + scale_factor: 1 + - class_path: rslearn.models.conv.Conv + init_args: + in_channels: 768 + out_channels: 8 + kernel_size: 1 + activation: + class_path: torch.nn.Identity + - class_path: rslearn.models.pick_features.PickFeatures + init_args: + indexes: [0] + collapse: true + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + lr: 0.0001 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: ${DATASET_PATH} + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_item_groups: true + load_all_layers: true + label: + data_type: "raster" + layers: ["label_raster"] + bands: ["label"] + is_target: true + dtype: INT32 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + segment: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + num_classes: 8 + zero_is_invalid: true + metric_kwargs: + average: "micro" + other_metrics: + water_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 1 + water_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 1 + bareground_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 2 + bareground_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 2 + rangeland_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 3 + rangeland_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 3 + floodedvegetation_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 4 + floodedvegetation_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 4 + trees_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 5 + trees_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 5 + cropland_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 6 + cropland_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 6 + buildings_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 7 + buildings_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 7 + input_mapping: + segment: + label: "targets" + batch_size: 2 + num_workers: ${NUM_WORKERS} + default_config: + transforms: + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 4 + mode: "center" + image_selectors: ["sentinel2_l2a", "target/crop_type_classification/classes", "target/crop_type_classification/valid"] + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + train_config: + transforms: + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 31 + mode: "center" + image_selectors: ["sentinel2_l2a", "target/segment/classes", "target/segment/valid"] + - class_path: rslearn.train.transforms.crop.Crop + init_args: + crop_size: 16 + image_selectors: ["sentinel2_l2a", "target/segment/classes", "target/segment/valid"] + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["sentinel2_l2a", "target/segment/classes", "target/segment/valid"] + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + groups: ["gaza", "manica", "zambezia"] + tags: + split: "train" + val_config: + patch_size: 16 + groups: ["gaza", "manica", "zambezia"] + tags: + split: "test" + test_config: + patch_size: 16 + groups: ["gaza", "manica", "zambezia"] + tags: + split: "test" + predict_config: + load_all_patches: true + patch_size: 16 + overlap_ratio: 0.25 # 4 / 16 + skip_targets: true + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] +trainer: + max_epochs: 100 + default_root_dir: ${TRAINER_DATA_PATH} + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: ${WANDB_PROJECT} + name: ${WANDB_NAME} + entity: ${WANDB_ENTITY} + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_loss + mode: min + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 2 + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: placeholder + output_layer: ${PREDICTION_OUTPUT_LAYER} + selector: ["segment"] + merger: + class_path: rslearn.train.prediction_writer.RasterMerger + init_args: + padding: 2 diff --git a/olmoearth_run_data/mozambique_lulc/olmoearth_run.yaml b/olmoearth_run_data/mozambique_lulc/olmoearth_run.yaml new file mode 100644 index 0000000..964bec7 --- /dev/null +++ b/olmoearth_run_data/mozambique_lulc/olmoearth_run.yaml @@ -0,0 +1,80 @@ +inference_results_config: + data_type: RASTER + classification_fields: + - property_name: crop_type_classification + band_index: 1 + allowed_values: + - value: 1 + label: Water + color: [136, 33, 233] + - value: 2 + label: Bare Ground + color: [124, 67, 18] + - value: 3 + label: Rangeland + color: [51, 160, 44] + - value: 4 + label: Flooded Vegetation + color: [169, 214, 146] + - value: 5 + label: Trees + color: [240, 159, 28] + - value: 6 + label: Cropland + color: [251, 154, 153] + - value: 7 + label: Buildings + color: [31, 234, 146] + +window_prep: + sampler: + class_path: olmoearth_run.runner.tools.samplers.noop_sampler.NoopSampler + labeled_window_preparer: + class_path: olmoearth_run.runner.tools.labeled_window_preparers.point_to_raster_window_preparer.PointToRasterWindowPreparer + init_args: + window_buffer: 31 # in pixels + window_resolution: 10.0 + dtype: INT32 + nodata_value: 255 + data_splitter: + class_path: olmoearth_run.runner.tools.data_splitters.spatial_data_splitter.SpatialDataSplitter + init_args: + train_prop: 0.75 + val_prop: 0.25 + test_prop: 0.0 + grid_size: 128 # in pixels + label_layer: "label" + group_name: "spatial_split" + split_property: "split" + +partition_strategies: + partition_request_geometry: + class_path: olmoearth_run.runner.tools.partitioners.grid_partitioner.GridPartitioner + init_args: + grid_size: 0.25 + + prepare_window_geometries: + class_path: olmoearth_run.runner.tools.partitioners.grid_partitioner.GridPartitioner + init_args: + grid_size: 1024 + output_projection: + class_path: rslearn.utils.geometry.Projection + init_args: + crs: EPSG:3857 + x_resolution: 10 + y_resolution: -10 + use_utm: true + +postprocessing_strategies: + process_dataset: + class_path: olmoearth_run.runner.tools.postprocessors.combine_geotiff.CombineGeotiff + init_args: + nodata_value: 255 + + process_partition: + class_path: olmoearth_run.runner.tools.postprocessors.combine_geotiff.CombineGeotiff + init_args: + nodata_value: 255 + + process_window: + class_path: olmoearth_run.runner.tools.postprocessors.noop_raster.NoopRaster diff --git a/olmoearth_run_data/mozambique_lulc/olmoearth_run_croptype.yaml b/olmoearth_run_data/mozambique_lulc/olmoearth_run_croptype.yaml new file mode 100644 index 0000000..7f07308 --- /dev/null +++ b/olmoearth_run_data/mozambique_lulc/olmoearth_run_croptype.yaml @@ -0,0 +1,80 @@ +inference_results_config: + data_type: RASTER + classification_fields: + - property_name: crop_type_classification + band_index: 1 + allowed_values: + - value: 1 + label: Corn + color: [136, 33, 233] + - value: 2 + label: Cassava + color: [124, 67, 18] + - value: 3 + label: Rice + color: [51, 160, 44] + - value: 4 + label: Sesame + color: [169, 214, 146] + - value: 5 + label: Beans + color: [240, 159, 28] + - value: 6 + label: Millet + color: [251, 154, 153] + - value: 7 + label: Sorghum + color: [31, 234, 146] + +window_prep: + sampler: + class_path: olmoearth_run.runner.tools.samplers.noop_sampler.NoopSampler + labeled_window_preparer: + class_path: olmoearth_run.runner.tools.labeled_window_preparers.point_to_raster_window_preparer.PointToRasterWindowPreparer + init_args: + window_buffer: 31 # in pixels + window_resolution: 10.0 + dtype: INT32 + nodata_value: 255 + data_splitter: + class_path: olmoearth_run.runner.tools.data_splitters.spatial_data_splitter.SpatialDataSplitter + init_args: + train_prop: 0.75 + val_prop: 0.25 + test_prop: 0.0 + grid_size: 128 # in pixels + label_layer: "label" + group_name: "spatial_split" + split_property: "split" + +partition_strategies: + partition_request_geometry: + class_path: olmoearth_run.runner.tools.partitioners.grid_partitioner.GridPartitioner + init_args: + grid_size: 0.25 + + prepare_window_geometries: + class_path: olmoearth_run.runner.tools.partitioners.grid_partitioner.GridPartitioner + init_args: + grid_size: 1024 + output_projection: + class_path: rslearn.utils.geometry.Projection + init_args: + crs: EPSG:3857 + x_resolution: 10 + y_resolution: -10 + use_utm: true + +postprocessing_strategies: + process_dataset: + class_path: olmoearth_run.runner.tools.postprocessors.combine_geotiff.CombineGeotiff + init_args: + nodata_value: 255 + + process_partition: + class_path: olmoearth_run.runner.tools.postprocessors.combine_geotiff.CombineGeotiff + init_args: + nodata_value: 255 + + process_window: + class_path: olmoearth_run.runner.tools.postprocessors.noop_raster.NoopRaster diff --git a/olmoearth_run_data/mozambique_lulc/prediction_request_geometry.geojson b/olmoearth_run_data/mozambique_lulc/prediction_request_geometry.geojson new file mode 100644 index 0000000..b92201a --- /dev/null +++ b/olmoearth_run_data/mozambique_lulc/prediction_request_geometry.geojson @@ -0,0 +1,39 @@ +{ + "features": [ + { + "geometry": { + "coordinates": [ + [ + [ + 33.66473812697805, + -25.14037394430332 + ], + [ + 33.66473812697805, + -25.145788014726122 + ], + [ + 33.67009152455876, + -25.145788014726122 + ], + [ + 33.67009152455876, + -25.14037394430332 + ], + [ + 33.66473812697805, + -25.14037394430332 + ] + ] + ], + "type": "Polygon" + }, + "properties": { + "oe_end_time": "2025-05-07T00:00:00+00:00", + "oe_start_time": "2024-10-23T00:00:00+00:00" + }, + "type": "Feature" + } + ], + "type": "FeatureCollection" +} diff --git a/olmoearth_run_data/mozambique_lulc/rslp_finetuning.yaml b/olmoearth_run_data/mozambique_lulc/rslp_finetuning.yaml new file mode 100644 index 0000000..0b67b16 --- /dev/null +++ b/olmoearth_run_data/mozambique_lulc/rslp_finetuning.yaml @@ -0,0 +1,267 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + checkpoint_path: /weka/dfive-default/helios/checkpoints/joer/phase2.0_base_lr0.0001_wd0.02/step667200 + selector: + - encoder + patch_size: 1 + random_initialization: false + embedding_size: null + autocast_dtype: bfloat16 + decoders: + segment: + - class_path: rslearn.models.upsample.Upsample + init_args: + scale_factor: 1 + mode: bilinear + - class_path: rslearn.models.conv.Conv + init_args: + in_channels: 768 + out_channels: 8 + kernel_size: 1 + padding: same + stride: 1 + activation: + class_path: torch.nn.Identity + init_args: {} + - class_path: rslearn.models.pick_features.PickFeatures + init_args: + indexes: + - 0 + collapse: true + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + lazy_decode: false + loss_weights: null + trunk: null + lr: 0.0001 + scheduler: + class_path: rslearn.train.scheduler.PlateauScheduler + init_args: + factor: 0.2 + patience: 2 + min_lr: 0 + cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc/20251114 + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_item_groups: true + load_all_layers: true + label: + data_type: "raster" + layers: ["label_raster"] + bands: ["label"] + is_target: true + dtype: INT32 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + segment: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + num_classes: 8 + zero_is_invalid: true + metric_kwargs: + average: "micro" + other_metrics: + water_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 1 + water_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 1 + bareground_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 2 + bareground_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 2 + rangeland_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 3 + rangeland_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 3 + floodedvegetation_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 4 + floodedvegetation_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 4 + trees_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 5 + trees_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 5 + cropland_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 6 + cropland_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 6 + buildings_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 7 + buildings_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 7 + input_mapping: + segment: + label: "targets" + batch_size: 4 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + train_config: + transforms: + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 31 + mode: "center" + image_selectors: ["sentinel2_l2a", "target/segment/classes", "target/segment/valid"] + - class_path: rslearn.train.transforms.crop.Crop + init_args: + crop_size: 16 + image_selectors: ["sentinel2_l2a", "target/segment/classes", "target/segment/valid"] + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["sentinel2_l2a", "target/segment/classes", "target/segment/valid"] + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + groups: ["gaza", "manica", "zambezia"] + tags: + split: "train" + val_config: + patch_size: 16 + groups: ["gaza", "manica", "zambezia"] + tags: + split: "val" + test_config: + patch_size: 16 + groups: ["gaza", "manica", "zambezia"] + tags: + split: "test" +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_segment/accuracy + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 20 + unfreeze_lr_factor: 10 +rslp_project: 2025_09_18_mozambique_lulc +rslp_experiment: mozambique_lulc_helios_base_S2_ts_ws16_ps1_gaza diff --git a/olmoearth_run_data/mozambique_lulc/rslp_finetuning_croptype.yaml b/olmoearth_run_data/mozambique_lulc/rslp_finetuning_croptype.yaml new file mode 100644 index 0000000..0969c84 --- /dev/null +++ b/olmoearth_run_data/mozambique_lulc/rslp_finetuning_croptype.yaml @@ -0,0 +1,267 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + checkpoint_path: /weka/dfive-default/helios/checkpoints/joer/phase2.0_base_lr0.0001_wd0.02/step667200 + selector: + - encoder + patch_size: 1 + random_initialization: false + embedding_size: null + autocast_dtype: bfloat16 + decoders: + segment: + - class_path: rslearn.models.upsample.Upsample + init_args: + scale_factor: 1 + mode: bilinear + - class_path: rslearn.models.conv.Conv + init_args: + in_channels: 768 + out_channels: 8 + kernel_size: 1 + padding: same + stride: 1 + activation: + class_path: torch.nn.Identity + init_args: {} + - class_path: rslearn.models.pick_features.PickFeatures + init_args: + indexes: + - 0 + collapse: true + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + lazy_decode: false + loss_weights: null + trunk: null + lr: 0.0001 + scheduler: + class_path: rslearn.train.scheduler.PlateauScheduler + init_args: + factor: 0.2 + patience: 2 + min_lr: 0 + cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc/20251114 + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_item_groups: true + load_all_layers: true + label: + data_type: "raster" + layers: ["label_raster"] + bands: ["label"] + is_target: true + dtype: INT32 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + segment: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + num_classes: 8 + zero_is_invalid: true + metric_kwargs: + average: "micro" + other_metrics: + corn_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 1 + corn_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 1 + cassava_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 2 + cassava_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 2 + rice_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 3 + rice_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 3 + sesame_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 4 + sesame_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 4 + beans_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 5 + beans_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 5 + millet_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 6 + millet_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 6 + sorghum_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 8 + average: null + class_idx: 7 + sorghum_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 8 + average: null + class_idx: 7 + input_mapping: + segment: + label: "targets" + batch_size: 4 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + train_config: + transforms: + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 31 + mode: "center" + image_selectors: ["sentinel2_l2a", "target/segment/classes", "target/segment/valid"] + - class_path: rslearn.train.transforms.crop.Crop + init_args: + crop_size: 16 + image_selectors: ["sentinel2_l2a", "target/segment/classes", "target/segment/valid"] + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["sentinel2_l2a", "target/segment/classes", "target/segment/valid"] + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + groups: ["crop_type"] + tags: + split: "train" + val_config: + patch_size: 16 + groups: ["crop_type"] + tags: + split: "val" + test_config: + patch_size: 16 + groups: ["crop_type"] + tags: + split: "test" +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_segment/accuracy + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 20 + unfreeze_lr_factor: 10 +rslp_project: 2025_09_18_mozambique_lulc +rslp_experiment: mozambique_lulc_helios_base_S2_ts_ws16_ps1_gaza