diff --git a/tutorials/radiant-mlhub-on-demand-training-data.ipynb b/tutorials/radiant-mlhub-on-demand-training-data.ipynb new file mode 100644 index 00000000..13ab8a02 --- /dev/null +++ b/tutorials/radiant-mlhub-on-demand-training-data.ipynb @@ -0,0 +1,947 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a7cf916c-0991-4af6-889f-a1ce86696d46", + "metadata": {}, + "source": [ + "## On Demand Training Data from Radiant MLHub and Planetary Computer\n", + "\n", + "Radiant MLHub Logo" + ] + }, + { + "cell_type": "markdown", + "id": "010b5b89-32c4-4f6b-81fb-f41d782d251f", + "metadata": {}, + "source": [ + "In this tutorial, we will walk through the process of requesting on-demand training data from the [Planetary Computer Data Catalog](https://planetarycomputer.microsoft.com/catalog) to pair with the [BigEarthNet](https://mlhub.earth/data/bigearthnet_v1) dataset downloaded from Radiant MLHub. This is an important workflow for someone in the geospatial community who wants to train an ML model on a data source outside of a prepackaged dataset, such as those found on MLHub. They can start with any dataset containing source imagery and label collections in STAC, obtain a random sample to work with, fetch source images from a different collection or satellite product, and then reproject and crop those images to match the spatial and temporal extent of the original dataset.\n", + "\n", + "**NOTE:** because the workflow documented below uses libraries like `pystac_client` and `stackstac`, the datasets queried need to be organized into STAC Collections." + ] + }, + { + "cell_type": "markdown", + "id": "f130b365-6fff-4d2a-86a9-39085ab13886", + "metadata": {}, + "source": [ + "Let's start by importing the Python libraries we'll use in this notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e144460-5549-4ab4-ba98-10a1a7ebd236", + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import tempfile\n", + "from pathlib import Path\n", + "import os\n", + "import json\n", + "from glob import glob\n", + "import requests\n", + "from typing import List, Tuple, Dict, Any\n", + "from datetime import datetime as dt\n", + "from datetime import timedelta as td\n", + "\n", + "import planetary_computer\n", + "import pystac_client\n", + "from pystac import ItemCollection, Item, Asset\n", + "import dask\n", + "import dask.distributed # provides the distributed Client used later in the notebook\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from stackstac import stack\n", + "from geopandas import GeoDataFrame\n", + "import rasterio as rio\n", + "from rasterio.plot import show\n", + "import rioxarray\n", + "from xarray import DataArray\n", + "from shapely.geometry import shape\n", + "from shapely.geometry import Polygon\n", + "from pyproj import CRS" + ] + }, + { + "cell_type": "markdown", + "id": "3a5bb87b-9d9c-4140-bac8-3e95f146c029", + "metadata": {}, + "source": [ + "### Define global variables" + ] + }, + { + "cell_type": "markdown", + "id": "2de2f657-3229-4037-bf9c-541c503cc269", + "metadata": {}, + "source": [ + "We also need to define some initial global variables to get our workflow started, e.g. a temporary working directory to download and write data to, the STAC API endpoints, the names of the Collections, and other variables like the RGB band names for those collections. These are flexible and can be adjusted to your individual needs."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3bf07a1-3a4a-4207-a449-7be766fa7e36", + "metadata": {}, + "outputs": [], + "source": [ + "# Temporary working directory on local machine or PC instance\n", + "TMP_DIR = tempfile.gettempdir()\n", + "\n", + "# API endpoints for MLHub and Planetary Computer catalogs\n", + "MLHUB_API_URL = \"https://api.radiant.earth/mlhub/v1\"\n", + "MSPC_API_URL = \"https://planetarycomputer.microsoft.com/api/stac/v1\"\n", + "\n", + "# Names of Collections that will be queried against using pystac_client\n", + "BIGEARTHNET_SOURCE_COLLECTION = \"bigearthnet_v1_source\" # sentinel-2 source imagery\n", + "BIGEARTHNET_LABEL_COLLECTION = \"bigearthnet_v1_labels\" # geojson classification labels\n", + "PLANETARY_COMPUTER_LANDSAT_8 = \"landsat-8-c2-l2\" # landsat 8 source imagery on PC\n", + "OUTPUT_DIR = \"landsat_8_source\"\n", + "\n", + "# Default variables that will be used in the API queries\n", + "BIGEARTHNET_TIME_RANGE = \"2017-06-01/2018-05-31\" # full date range for BigEarthNet\n", + "LABEL_CRS = CRS(\"EPSG:4326\")\n", + "DATE_BUFFER = 60\n", + "LANDSAT_8_RGB_BANDS = [\"SR_B4\", \"SR_B3\", \"SR_B2\"] # names of RGB bands from PC Landsat 8\n", + "BIGEARTHNET_RGB_BANDS = [\"B04\", \"B03\", \"B02\"] # names of RGB bands from BigEarthNet (Sentinel-2)\n", + "\n", + "# Bounding boxes used for the demonstration searches below\n", + "LUXEMBOURG_AOI = [6.06, 49.58, 6.21, 49.66] # aoi around Luxembourg\n", + "SPAIN_AOI = [-9.73, 35.84, 3.43, 43.87] # aoi covering the Iberian Peninsula" + ] + }, + { + "cell_type": "markdown", + "id": "5d5e31b3-5cde-4a7b-af43-dc194b06d0a0", + "metadata": {}, + "source": [ + "### Authentication with Radiant MLHub" + ] + }, + { + "cell_type": "markdown", + "id": "5f11e821-b98b-4df1-a26c-826c9bdbec50", + "metadata": {}, + "source": [ + "Programmatic access to the Radiant MLHub API using the `pystac_client` library requires both the API endpoint and an API key. You can obtain an API key for free by registering an account on [mlhub.earth](https://mlhub.earth/). Once logged in, the key can be found under `Settings & API Key` in the drop-down menu." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4c9dd60-3abc-464d-af25-4b23c0d2783b", + "metadata": {}, + "outputs": [], + "source": [ + "MLHUB_API_KEY = getpass.getpass(prompt=\"MLHub API Key: \")" + ] + }, + { + "cell_type": "markdown", + "id": "5c77d029-191e-42a5-8250-bc451b80f247", + "metadata": {}, + "source": [ + "### Configure API connection to Radiant MLHub" + ] + }, + { + "cell_type": "markdown", + "id": "b0480635-eef5-4e3f-847a-5060c409ae4f", + "metadata": {}, + "source": [ + "This makes a connection to the Radiant MLHub Data Catalog using the API endpoint URL and the API key from your account." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf79c301-76df-4158-bf97-3da53552e143", + "metadata": {}, + "outputs": [], + "source": [ + "mlhub_catalog = pystac_client.Client.open(\n", + " url=MLHUB_API_URL, parameters={\"key\": MLHUB_API_KEY}, ignore_conformance=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "637734bc-af77-4c81-ab26-98e808de6415", + "metadata": {}, + "source": [ + "### Fetch label Items from BigEarthNet over Luxembourg" + ] + }, + { + "cell_type": "markdown", + "id": "e9f19ce4-57fd-43d0-9263-7a5edd106ee8", + "metadata": {}, + "source": [ + "We will now use the `search` function from the API client to get label Items over Luxembourg as a simple use case."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc6e40fb-9e96-4041-bb21-119539875caa", + "metadata": {}, + "outputs": [], + "source": [ + "origin_label_items = mlhub_catalog.search(\n", + " collections=BIGEARTHNET_LABEL_COLLECTION,\n", + " bbox=LUXEMBOURG_AOI,\n", + " datetime=BIGEARTHNET_TIME_RANGE,\n", + " max_items=100\n", + ").get_all_items()" + ] + }, + { + "cell_type": "markdown", + "id": "e9d121a9-e54b-4039-9cef-04277962c2ca", + "metadata": {}, + "source": [ + "This is a helper function that simply displays the geometries of the Items in an ItemCollection overlaid on a map of the region." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f203e31-5b09-4f2e-bf27-9a3ef8e3fc4d", + "metadata": {}, + "outputs": [], + "source": [ + "def explore_search_extent(items: ItemCollection) -> Any:\n", + " \"\"\"Extracts geometry from ItemCollection to display polygons on a map.\n", + "\n", + " Args:\n", + " items: ItemCollection of Items retrieved from pystac_client search\n", + "\n", + " Returns:\n", + " interactive map produced by calling .explore() on the GeoDataFrame\n", + " \"\"\"\n", + " item_feature_collection = items.to_dict()\n", + " geom_df = GeoDataFrame.from_features(item_feature_collection).set_crs(4326)\n", + " print(geom_df.bounds)\n", + " return geom_df[[\"geometry\", \"datetime\"]].explore(\n", + " column=\"datetime\", style_kwds={\"fillOpacity\": 0.2}, cmap=\"viridis\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "ab923855-3c20-4f10-af00-d4175a54fdd4", + "metadata": {}, + "source": [ + "Here are the BigEarthNet chips with their bounding boxes that matched the spatial parameters for the city of Luxembourg and surrounding areas." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b86bf5d8-8dc8-491d-b3cd-1ea1263774ca", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "explore_search_extent(origin_label_items)" + ] + }, + { + "cell_type": "markdown", + "id": "13c6e607-c79b-4678-a483-9bacb0b3b1df", + "metadata": {}, + "source": [ + "### Download BigEarthNet Source Items from Radiant MLHub" + ] + }, + { + "cell_type": "markdown", + "id": "2e131160-50bb-487f-8fcb-96d37ce80167", + "metadata": {}, + "source": [ + "We could certainly use the method above to query all label and source Items directly from our connection to the Radiant MLHub API endpoint. However, on very large collections such as BigEarthNet, pagination becomes a bottleneck in obtaining and resolving STAC Items.\n", + "\n", + "Querying the entire Collection of roughly 600,000 Items would take almost an hour, depending on your connection speed, which means it could take a few hours to download all Items in the Catalog.\n", + "\n", + "One alternative is to download the `.tar.gz` archives of the collections directly from the Radiant MLHub dataset detail page. The file size of the labels archive is not large, only 165 MB. However, because it contains over half a million objects, it takes a long time to decompress the entire download.\n", + "\n", + "Therefore, we will showcase this workflow by paginating over the source Item Collection to fetch the first 5,000 Items available (which represents only about 1% of the entire collection).\n",
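+ "\n",
+ "As an aside, if you do go the archive-download route, the snippet below is a purely illustrative sketch (it assumes a locally downloaded file named `bigearthnet_v1_labels.tar.gz` and an output folder of your choosing) showing how Python's built-in `tarfile` module can stream out just a few members without decompressing the entire archive:\n",
+ "\n",
+ "```python\n",
+ "import tarfile\n",
+ "from itertools import islice\n",
+ "\n",
+ "# hypothetical path to the archive downloaded from the dataset detail page\n",
+ "archive_path = \"bigearthnet_v1_labels.tar.gz\"\n",
+ "\n",
+ "with tarfile.open(archive_path, \"r:gz\") as tf:\n",
+ "    # iterate over members lazily and extract only the first five,\n",
+ "    # rather than unpacking all ~500,000 objects at once\n",
+ "    for member in islice(tf, 5):\n",
+ "        tf.extract(member, path=\"bigearthnet_sample\")\n",
+ "        print(member.name)\n",
+ "```"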
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b257cc4f-d77c-422b-b5ca-e81f768e32d5", + "metadata": {}, + "outputs": [], + "source": [ + "bigearthnet_source_search = mlhub_catalog.search(\n", + " collections=BIGEARTHNET_SOURCE_COLLECTION,\n", + " bbox=SPAIN_AOI,\n", + " # limit=100, # limit of items per page\n", + " max_items=5000 # total Item recall\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f729e7ea-689f-40ee-9048-4c37f3f2e859", + "metadata": {}, + "source": [ + "It should take less than a minute to fetch all the STAC Items for the 5,000-Item sample we've queried." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d026603-2fc8-434b-9701-72aab86a6cd6", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "bigearthnet_source_items = bigearthnet_source_search.get_all_items()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01d88ffe-f530-46d6-87d9-4337c1ec0202", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "explore_search_extent(bigearthnet_source_items)" + ] + }, + { + "cell_type": "markdown", + "id": "77109ba6-f0a1-426e-8884-39a5fac2795f", + "metadata": {}, + "source": [ + "We can see from this map that the source Items fetched are concentrated in Portugal. This is merely a consequence of the fact that we fetched the first 5,000 source Items returned by the Catalog API for a bounding box search over Spain. Had we downloaded the entire Catalog locally, or run an unfiltered search, we could fetch a random sample that is more representative of the entire dataset." + ] + }, + { + "cell_type": "markdown", + "id": "3f09ed9c-66c5-4f4e-8230-1c62e71e2db3", + "metadata": {}, + "source": [ + "### Configure API connection to Planetary Computer" + ] + }, + { + "cell_type": "markdown", + "id": "9a320f3a-f6dd-487d-8cef-34b47f327f1d", + "metadata": {}, + "source": [ + "This makes a connection to the Planetary Computer Data Catalog using the API endpoint URL." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38280716-e130-4de4-9699-2c0bcbfc056d", + "metadata": {}, + "outputs": [], + "source": [ + "mspc_catalog = pystac_client.Client.open(MSPC_API_URL)" + ] + }, + { + "cell_type": "markdown", + "id": "a5e0cbba-32b8-47ad-9913-e7ba2a939922", + "metadata": { + "tags": [] + }, + "source": [ + "### Fetch Landsat 8 scenes based on source Item bbox and datetime" + ] + }, + { + "cell_type": "markdown", + "id": "69ad7ef1-143e-47f0-b6b9-26c73e5cc65d", + "metadata": {}, + "source": [ + "Since the BigEarthNet dataset from MLHub has a 1-to-1 pairing of source and label Items, we can safely assume the first source Item is the appropriate match for our label." + ] + }, + { + "cell_type": "markdown", + "id": "39f35d02-d2de-4be5-9317-c80214502c88", + "metadata": {}, + "source": [ + "We will now use the API client with the helper functions below to fetch the best Landsat 8 match for the sampled Item. These will consider only the scenes that completely contain the label and pick the one with minimal cloud cover."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1eddda32-d50d-413d-add8-dedaa2f9a067", + "metadata": {}, + "outputs": [], + "source": [ + "def temporal_buffer(item_datetime: str, date_delta: int) -> str:\n", + " \"\"\"Takes a datetime string and returns a buffer around that date\n", + "\n", + " Args:\n", + " item_datetime: string of the datetime property from an Item\n", + " date_delta: integer for days to add before and after a date\n", + "\n", + " Returns:\n", + " a string range representing the full date buffer\n", + " \"\"\"\n", + " delta = td(days=date_delta)\n", + " item_dt = dt.strptime(item_datetime, \"%Y-%m-%dT%H:%M:%SZ\")\n", + "\n", + " dt_start = item_dt - delta\n", + " dt_start_str = dt_start.strftime(\"%Y-%m-%d\")\n", + "\n", + " dt_end = item_dt + delta\n", + " dt_end_str = dt_end.strftime(\"%Y-%m-%d\")\n", + "\n", + " return f\"{dt_start_str}/{dt_end_str}\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f77e4ec0-a1b8-490c-b732-4c10d89b06ea", + "metadata": {}, + "outputs": [], + "source": [ + "def min_cloud_cover_scene(label_geom: Polygon, search_items: ItemCollection) -> Item:\n", + " \"\"\"Finds the Item with minimal cloud cover from an ItemCollection\n", + "\n", + " Args:\n", + " label_geom: Polygon geometry to ensure label completely within scene\n", + " search_items: ItemCollection of the Items found from pystac_client search\n", + "\n", + " Returns:\n", + " Item with minimal cloud cover that completely contains the label, or None if no such Item exists\n", + " \"\"\"\n", + " min_cc = np.inf\n", + " min_cc_item = None\n", + " for item in search_items:\n", + " item_geom = shape(item.geometry)\n", + " item_cc = item.properties[\"eo:cloud_cover\"]\n", + " if item_cc < min_cc and label_geom.within(item_geom):\n", + " min_cc = item_cc\n", + " min_cc_item = item\n", + " return min_cc_item" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3b23fa5-d588-42d4-b913-fab7819e7ea3", + "metadata": {}, + "outputs": [], + "source": [ + "def get_landsat_8_match(bbox: List[float], geometry: Dict[str, Any], datetime: str) -> Item:\n", + " \"\"\"Finds the best Landsat 8 match using source Item datetime and bounding box.\n", + "\n", + " Args:\n", + " bbox: bounding box of the STAC source Item\n", + " geometry: GeoJSON geometry of the STAC source Item\n", + " datetime: datetime of the STAC source Item\n", + "\n", + " Returns:\n", + " best_l8_match: matching Landsat 8 source Item\n", + " \"\"\"\n", + "\n", + " # search PC Catalog for L8 Items\n", + " l8_items = mspc_catalog.search(\n", + " collections=PLANETARY_COMPUTER_LANDSAT_8,\n", + " bbox=bbox,\n", + " datetime=temporal_buffer(datetime, DATE_BUFFER),\n", + " ).get_all_items()\n", + "\n", + " # filter to best L8 Item match\n", + " signed_l8_items = planetary_computer.sign(l8_items)\n", + " best_l8_match = min_cloud_cover_scene(\n", + " shape(geometry), \n", + " signed_l8_items\n", + " )\n", + "\n", + " return best_l8_match" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "784d731a-4e04-41ba-99e1-476d828aa65e", + "metadata": {}, + "outputs": [], + "source": [ + "sample_source_item = bigearthnet_source_items[np.random.randint(0, len(bigearthnet_source_items))]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53d3ab14-86c7-42a8-be96-8156f5e74d64", + "metadata": {}, + "outputs": [], + "source": [ + "best_l8_match = get_landsat_8_match(\n", + " sample_source_item.bbox,\n", + " sample_source_item.geometry,\n", + " sample_source_item.properties['datetime']\n", + ")" + ] + }, + { + "cell_type": "code", +
"execution_count": null, + "id": "53556cf0-375b-4044-9262-d16de707a13d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "if best_l8_match:\n", + " print(best_l8_match.id)\n", + " print(best_l8_match.bbox)\n", + " print(best_l8_match.geometry)\n", + " print(best_l8_match.properties)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f3b5d88-7f3c-4e74-8e68-2408ba762b0f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "explore_search_extent(ItemCollection([best_l8_match]))" + ] + }, + { + "cell_type": "markdown", + "id": "66c13807-50a3-4a8c-8a89-33002eafabe6", + "metadata": {}, + "source": [ + "If everything worked correctly, the geographic scope of the Landsat 8 scene should encompass a much larger surface area than the Sentinel-2 source and label chips. From here we need to crop the image down and make sure the chips from both products match." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "258a1d9e-55aa-4781-a0f5-e77ace273240", + "metadata": {}, + "outputs": [], + "source": [ + "def get_redirect_url(asset: Asset) -> str:\n", + " \"\"\"Returns the direct URL to an asset.\n", + "\n", + " Args:\n", + " asset: Asset object from an Item\n", + "\n", + " Returns:\n", + " string response URL direct to Asset\n", + " \"\"\"\n", + " response = requests.get(asset.href, allow_redirects=True)\n", + " if response.status_code == 200:\n", + " return response.url\n", + " return None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40a5cf10-0c5b-4e42-856a-51c0987f5523", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_rgb_chip(rgb_stack: DataArray, norm: int) -> None:\n", + " img_arr = rgb_stack[0].to_numpy().squeeze()\n", + " fig, ax = plt.subplots(figsize=(7,7))\n", + " show(img_arr/norm, ax=ax)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e4e977d-eb14-4aa9-85df-adbe84b1732d", + "metadata": {}, + "outputs": [], + "source": [ + "s2_stack = stack(\n", + " items=ItemCollection([sample_source_item]),\n", + " assets=BIGEARTHNET_RGB_BANDS,\n", + " epsg=rio.open(get_redirect_url(sample_source_item.assets[\"B02\"])).crs.to_epsg(),\n", + " resolution=10,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "fc36fee6-59ea-4554-82e0-20291d6d3004", + "metadata": {}, + "source": [ + "The `stackstac.stack` method returns a DataArray object with width and height for longitude and latitude, and a third dimension for the RGB bands." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "287f9b13-82a3-4a5d-82f2-9ab5066c6be1", + "metadata": {}, + "outputs": [], + "source": [ + "s2_stack" + ] + }, + { + "cell_type": "markdown", + "id": "1f3bfb6c-494c-47f1-8350-75b90faf9a5b", + "metadata": {}, + "source": [ + "This is the true color image representation of the Sentinel-2 chip we fetched RGB assets for." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7726df0-742e-48fe-b371-0a6844932702", + "metadata": {}, + "outputs": [], + "source": [ + "plot_rgb_chip(s2_stack, 4000)" + ] + }, + { + "cell_type": "markdown", + "id": "4a1edab0-bdd6-4347-b23e-58ed08a38cf1", + "metadata": {}, + "source": [ + "Here are the RGB bands all ploted on a subplot together in a row." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee95ad09-34ff-4be1-a5a6-7b870578768e", + "metadata": {}, + "outputs": [], + "source": [ + "s2_stack[0].plot(col=\"band\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc2326b7-d1cc-4474-82b6-95f7da05f897", + "metadata": {}, + "outputs": [], + "source": [ + "l8_original = stack(\n", + " items=ItemCollection([best_l8_match]), assets=LANDSAT_8_RGB_BANDS, resolution=10\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa276eef-73ea-4834-a6a1-7cbd205f3d46", + "metadata": {}, + "outputs": [], + "source": [ + "l8_original" + ] + }, + { + "cell_type": "markdown", + "id": "bd20472b-be7e-481a-b1ed-3bbcff044c7f", + "metadata": {}, + "source": [ + "As we can see from the metadata of the DataArray above, the Landsat 8 scene has a significantly larger geographic footprint, `~20,000 x ~20,000 pixels`, compared to `120 x 120 pixels` for the Sentinel-2 chips that were prepared for the dataset. We need to crop/mask the Landsat 8 images down so they represent the same geographical footprint.\n", + "\n", + "Luckily, the `bounds_latlon` parameter of `stackstac.stack` makes it easy to crop the image to this size automatically for all bands/assets requested." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0045f54d-5076-4a14-9292-3dbdde339cdc", + "metadata": {}, + "outputs": [], + "source": [ + "l8_cropped = stack(\n", + " items=ItemCollection([best_l8_match]),\n", + " assets=LANDSAT_8_RGB_BANDS,\n", + " bounds_latlon=sample_source_item.bbox,\n", + " resolution=10,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f0600e6-3bb8-4da9-90b2-ee4ec45ea227", + "metadata": {}, + "outputs": [], + "source": [ + "l8_cropped" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "355ee719-0365-40e6-969f-4193ca2a59ec", + "metadata": {}, + "outputs": [], + "source": [ + "l8_cropped[0].data.compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95507bc8-521e-43d8-93d6-06ab62b78fa8", + "metadata": {}, + "outputs": [], + "source": [ + "plot_rgb_chip(l8_cropped, 23000)" + ] + }, + { + "cell_type": "markdown", + "id": "bc46c85b-daa9-44df-9be7-62dfa1234b25", + "metadata": {}, + "source": [ + "Now we have a cropped Landsat 8 chip that spatially and temporally matches our Sentinel-2 source imagery and label sample from the BigEarthNet dataset. The first observation is that the Landsat 8 image appears blurry compared to Sentinel-2. This is because the Sentinel-2 RGB bands have a 10m resolution, while the same bands for Landsat 8 have a 30m resolution." + ] + }, + { + "cell_type": "markdown", + "id": "f3b09697-b6ab-4026-bfcb-2f5214b03f5c", + "metadata": {}, + "source": [ + "### Scale the workflow using Dask Delayed" + ] + }, + { + "cell_type": "markdown", + "id": "cb650d80-bf1a-4aad-8f8b-08a612e28aae", + "metadata": {}, + "source": [ + "We will now use Dask to optimize processing the Landsat 8 scenes by parallelizing the workflow with a delayed computation graph. The Dask client schedules and runs the delayed computations and gathers the results. With parallel processing, we can speed up the runtime of our image processing workflow by roughly 10-20x.\n",
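+ "\n",
+ "If you have not used Dask Delayed before, the pattern in miniature looks like the toy sketch below (purely illustrative and not part of the workflow itself): wrapping a function with `dask.delayed` builds a task graph instead of executing immediately, and nothing runs until `dask.compute` is called.\n",
+ "\n",
+ "```python\n",
+ "import dask\n",
+ "\n",
+ "@dask.delayed\n",
+ "def double(x):\n",
+ "    # stand-in for any expensive, independent piece of work\n",
+ "    return 2 * x\n",
+ "\n",
+ "# building the graph is cheap; no computation has happened yet\n",
+ "lazy_results = [double(i) for i in range(4)]\n",
+ "\n",
+ "# compute() executes the graph, in parallel where possible\n",
+ "print(dask.compute(*lazy_results))  # (0, 2, 4, 6)\n",
+ "```"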
+ ] + }, + { + "cell_type": "markdown", + "id": "5368c39f-94a1-41ba-acc0-fd18c8dc1c18", + "metadata": {}, + "source": [ + "These are some helper functions that we will use to encapsulate the process of creating the cropped Landsat 8 chips and writing them to disk in parallel using Dask Delayed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c924acb1-092e-4f86-b73f-5b56ccdebe27", + "metadata": {}, + "outputs": [], + "source": [ + "def create_landsat_8_chip(source_item: Dict[str, Any]) -> DataArray:\n", + " \"\"\"Creates a Landsat 8 chip matching a BigEarthNet source Item.\n", + "\n", + " Args:\n", + " source_item: JSON/dictionary representation of source Item\n", + "\n", + " Returns:\n", + " Landsat 8 DataArray that has been cropped to the Sentinel-2 bbox\n", + " \"\"\"\n", + "\n", + " # fetch the Landsat 8 scene that best matches the label\n", + " l8_match = get_landsat_8_match(\n", + " source_item['bbox'],\n", + " source_item['geometry'],\n", + " source_item['properties']['datetime']\n", + " )\n", + "\n", + " # crop L8 match to S2 dims and read image data\n", + " l8_stack = stack(\n", + " items=ItemCollection([l8_match]),\n", + " assets=LANDSAT_8_RGB_BANDS,\n", + " bounds_latlon=source_item['bbox'],\n", + " resolution=10,\n", + " )\n", + "\n", + " return l8_stack" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02b974fe-0c6f-40a3-ad63-2f4753e0236b", + "metadata": {}, + "outputs": [], + "source": [ + "def write_tif_bands(l8_array: DataArray, l8_item_id: str) -> None:\n", + " \"\"\"Writes a GeoTIFF for each band in the Landsat 8 DataArray\n", + "\n", + " Args:\n", + " l8_array: the DataArray object created from the BigEarthNet source Item\n", + " l8_item_id: ID of the Landsat 8 Item, used to name the output directory and files\n", + " \"\"\"\n", + " # write cropped L8 DataArray to a tiff file for each band\n", + " for _band in LANDSAT_8_RGB_BANDS:\n", + " l8_band_img = l8_array.sel(band=_band)\n", + " l8_band_filename = os.path.join(\n", + " TMP_DIR, OUTPUT_DIR, l8_item_id, f\"{l8_item_id}_{_band}.tiff\"\n", + " )\n", + " Path(os.path.split(l8_band_filename)[0]).mkdir(parents=True, exist_ok=True)\n", + " l8_band_img[0].rio.to_raster(l8_band_filename)" + ] + }, + { + "cell_type": "markdown", + "id": "d5a97950-5b08-4e79-ab15-3736697d0584", + "metadata": {}, + "source": [ + "This sets the stage for the Dask task scheduler: below we start a distributed client and then map the source Items to the `create_landsat_8_chip` function. Nothing in the task graph will actually be executed until `.compute()` is run." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29531759-6d19-4010-8401-eb947a32c515", + "metadata": {}, + "outputs": [], + "source": [ + "client = dask.distributed.Client() # you can configure Dask client parameters here\n", + "client" + ] + }, + { + "cell_type": "markdown", + "id": "6a30a2f9-a176-436b-a132-01903ab72fd3", + "metadata": {}, + "source": [ + "One quirk of writing out the DataArray objects returned from `stackstac.stack()` is that the workers will throw an error saying the DataArrays don't have the method `rio.to_raster()` (the `.rio` accessor is registered by the `rioxarray` library). Normally we could solve this problem by explicitly importing `rioxarray`, but here we also need to import the module onto each worker in the client cluster.
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16e0b91b-0d73-4335-8ba9-00f75850d409", + "metadata": {}, + "outputs": [], + "source": [ + "import importlib\n", + "client.run(lambda: importlib.import_module(\"rioxarray\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3c67a19-442d-484c-8643-93c4d00c60a0", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%%time\n", + "chunk_size = 125\n", + "\n", + "for i in range(0, len(bigearthnet_source_items[0:500]), chunk_size):\n", + " future_pool = []\n", + " item_chunk=bigearthnet_source_items[i:i+chunk_size]\n", + " for source_item in item_chunk:\n", + " item_dict = dask.delayed(Item.to_dict)(source_item)\n", + " l8_xarray = dask.delayed(create_landsat_8_chip)(item_dict)\n", + " image_writer = dask.delayed(write_tif_bands)(l8_xarray, item_dict['id'])\n", + " future_pool.append(image_writer)\n", + " future_pool = dask.persist(*future_pool)\n", + " dask.compute(*future_pool)" + ] + }, + { + "cell_type": "markdown", + "id": "2bde6ce7-9d7a-4114-88ba-56e4a4bea247", + "metadata": {}, + "source": [ + "Now that our parallelized workflow has completed, let's confirm that folders with images were written to disk." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec64a009-9a87-4fb2-bf5b-fbc41a3f8021", + "metadata": {}, + "outputs": [], + "source": [ + "landsat_chip_dir = os.path.join(TMP_DIR, OUTPUT_DIR)\n", + "len(os.listdir(landsat_chip_dir))" + ] + }, + { + "cell_type": "markdown", + "id": "e0884796-bfab-4905-aa75-2f8dd43a5a13", + "metadata": {}, + "source": [ + "We can also open one of the new Landsat 8 chips to inspect what it looks like." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "beb3d5a8-2677-43dc-867a-153eb1e087f9", + "metadata": {}, + "outputs": [], + "source": [ + "landsat_images = glob(f\"{landsat_chip_dir}/**/*.tiff\", recursive=True)\n", + "first_l8_img = rioxarray.open_rasterio(landsat_images[0])\n", + "first_l8_img.plot()" + ] + }, + { + "cell_type": "markdown", + "id": "37581811-0f21-4838-9541-db84688032f6", + "metadata": {}, + "source": [ + "Lastly, we will shutdown the Dask client to cleanup cluster resources." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d643476-917b-484b-a041-8f3c94d12c06", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "client.shutdown()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}