Skip to content

Commit 60c292d

Browse files
mfisher87bibeputtSubhositRay
committed
Add tiler route and TiTilerServer class
Steal more code from jupytergis-tiler Co-authored-by: bibeputt <[email protected]> Co-authored-by: Subhosit <[email protected]>
1 parent faf07f7 commit 60c292d

File tree

6 files changed

+172
-49
lines changed

6 files changed

+172
-49
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
cog.tif
2+
13
*.bundle.*
24
lib/
35
node_modules/

jupyter_server_titiler/api.py

Lines changed: 152 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,174 @@
11
import uuid
2+
from asyncio import Event, Lock, Task, create_task
3+
from functools import partial
4+
from urllib.parse import urlencode
25

36
import rioxarray
7+
from anycorn import Config, serve
8+
from anyio import connect_tcp, create_task_group
49
from fastapi import FastAPI
510
from starlette.middleware.cors import CORSMiddleware
611
from xarray import DataArray, Dataset
712
from geopandas import GeoDataFrame
813
from rio_tiler.io.xarray import XarrayReader
914
from titiler.core.factory import TilerFactory
15+
from titiler.core.algorithm import algorithms as default_algorithms
16+
from titiler.core.algorithm import Algorithms, BaseAlgorithm
1017
from titiler.core.dependencies import DefaultDependency
1118

12-
def _setup_app() -> FastAPI:
13-
app = FastAPI()
19+
from constants import ENDPOINT_BASE
1420

15-
# Add CORS middleware
16-
app.add_middleware(
17-
CORSMiddleware,
18-
allow_origins=["*"], # Allows all origins (for development - be more specific in production)
19-
allow_credentials=True,
20-
allow_methods=["*"],
21-
allow_headers=["*"],
22-
)
2321

24-
return app
22+
# def _setup_app() -> FastAPI:
23+
# app = FastAPI()
24+
#
25+
# # Add CORS middleware
26+
# app.add_middleware(
27+
# CORSMiddleware,
28+
# allow_origins=["*"], # Allows all origins (for development - be more specific in production)
29+
# allow_credentials=True,
30+
# allow_methods=["*"],
31+
# allow_headers=["*"],
32+
# )
33+
#
34+
# return app
35+
#
36+
#
37+
# def _create_xarray_id_lookup(*args: list[DataArray | Dataset]):
38+
# return {uuid.uuid4(): ds for ds in args}
2539

2640

27-
def _create_xarray_id_lookup(*args: list[DataArray | Dataset]):
28-
return {uuid.uuid4(): ds for ds in args}
41+
class TiTilerServer:
42+
"""Shamelessly stolen from jupytergis-tiler.
2943
44+
https://github.com/geojupyter/jupytergis-tiler/blob/main/src/jupytergis/tiler/gis_document.py
45+
"""
46+
def __init__(self, *args, **kwargs) -> None:
47+
super().__init__(*args, **kwargs)
48+
self._tile_server_task: Task | None = None
49+
self._tile_server_started = Event()
50+
self._tile_server_shutdown = Event()
51+
self._tile_server_lock = Lock()
3052

31-
def explore(*args: list[DataArray | Dataset | GeoDataFrame]):
32-
app = _setup_app()
53+
async def start_tile_server(self):
54+
async with self._tile_server_lock:
55+
if not self._tile_server_started.is_set():
56+
self._tile_server_task = create_task(self._start_tile_server())
57+
await self._tile_server_started.wait()
3358

34-
xarray_objs = tuple(arg for arg in args if isinstance(arg, (DataArray, Dataset)))
35-
if xarray_objs:
36-
xarray_id_lookup = _create_xarray_id_lookup(xarray_objs)
59+
async def add_data_array(
60+
self,
61+
data_array: DataArray,
62+
name: str,
63+
colormap_name: str = "viridis",
64+
rescale: tuple[float, float] | None = None,
65+
scale: int = 1,
66+
opacity: float = 1,
67+
algorithm: BaseAlgorithm | None = None,
68+
**params,
69+
):
70+
await self.start_tile_server()
3771

38-
tiler_factory = TilerFactory(
39-
path_dependency=xarray_id_lookup.get,
72+
_params = {
73+
"server_url": self._tile_server_url,
74+
"scale": str(scale),
75+
"colormap_name": colormap_name,
76+
"reproject": "max",
77+
**params,
78+
}
79+
if rescale is not None:
80+
_params["rescale"] = f"{rescale[0]},{rescale[1]}"
81+
if algorithm is not None:
82+
_params["algorithm"] = "algorithm"
83+
source_id = str(uuid.uuid4())
84+
url = (
85+
f"/{ENDPOINT_BASE}/{source_id}/tiles/WebMercatorQuad/"
86+
+ "{z}/{x}/{y}.png?"
87+
+ urlencode(_params)
88+
)
89+
return url
90+
91+
async def stop_tile_server(self):
92+
async with self._tile_server_lock:
93+
if self._tile_server_started.is_set():
94+
self._tile_server_shutdown.set()
95+
96+
async def _start_tile_server(self):
97+
self._app = FastAPI()
98+
99+
config = Config()
100+
config.bind = "127.0.0.1:0"
101+
102+
async with create_task_group() as tg:
103+
binds = await tg.start(
104+
partial(
105+
serve,
106+
self._app,
107+
config,
108+
shutdown_trigger=self._tile_server_shutdown.wait,
109+
mode="asgi",
110+
)
111+
)
112+
113+
self._tile_server_url = binds[0]
114+
115+
host, _port = binds[0][len("http://") :].split(":")
116+
port = int(_port)
117+
while True:
118+
try:
119+
await connect_tcp(host, port)
120+
except OSError:
121+
pass
122+
else:
123+
self._tile_server_started.set()
124+
break
125+
126+
def _include_tile_server_router(
127+
self,
128+
source_id: str,
129+
data_array: DataArray,
130+
algorithm: BaseAlgorithm | None = None,
131+
):
132+
algorithms = default_algorithms
133+
if algorithm is not None:
134+
algorithms = default_algorithms.register({"algorithm": algorithm})
135+
136+
tiler = TilerFactory(
137+
router_prefix=f"/{source_id}",
40138
reader=XarrayReader,
139+
path_dependency=lambda:data_array,
41140
reader_dependency=DefaultDependency,
141+
process_dependency=algorithms.dependency,
42142
)
43-
app.include_router(tiler_factory.router)
44-
45-
import uvicorn
46-
uvicorn.run(app=app, host="127.0.0.1", port=8080, log_level="info")
47-
# return app, xarray_id_lookup
48-
49-
# Display a widget
50-
# TODO: What if there are multiple widgets?
51-
# TODO: Clean up when widgets clean up
52-
raise NotImplementedError("Only xarray.Dataset and xarray.DataArray are supported for now.")
53-
54-
55-
def test() -> Dataset | DataArray:
56-
ds = rioxarray.open_rasterio(
57-
"https://s2downloads.eox.at/demo/EOxCloudless/2020/rgbnir/s2cloudless2020-16bits_sinlge-file_z0-4.tif"
58-
)
59-
return explore(ds)
143+
self._tile_server_app.include_router(tiler.router, prefix=f"/{source_id}")
144+
145+
146+
# def explore(*args: list[DataArray | Dataset | GeoDataFrame]):
147+
# app = _setup_app()
148+
149+
# xarray_objs = tuple(arg for arg in args if isinstance(arg, (DataArray, Dataset)))
150+
# if xarray_objs:
151+
# xarray_id_lookup = _create_xarray_id_lookup(xarray_objs)
152+
153+
# tiler_factory = TilerFactory(
154+
# path_dependency=xarray_id_lookup.get,
155+
# reader=XarrayReader,
156+
# reader_dependency=DefaultDependency,
157+
# )
158+
# app.include_router(tiler_factory.router)
159+
#
160+
# import uvicorn
161+
# uvicorn.run(app=app, host="127.0.0.1", port=8080, log_level="info")
162+
# # return app, xarray_id_lookup
163+
164+
# # Display a widget
165+
# # TODO: What if there are multiple widgets?
166+
# # TODO: Clean up when widgets clean up
167+
# raise NotImplementedError("Only xarray.Dataset and xarray.DataArray are supported for now.")
168+
169+
170+
# def test() -> Dataset | DataArray:
171+
# ds = rioxarray.open_rasterio(
172+
# "https://s2downloads.eox.at/demo/EOxCloudless/2020/rgbnir/s2cloudless2020-16bits_sinlge-file_z0-4.tif"
173+
# )
174+
# return explore(ds)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ENDPOINT_BASE = "titiler"

jupyter_server_titiler/routes.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
11
import json
22

3-
from jupyter_server.base.handlers import APIHandler
3+
from jupyter_server.base.handlers import JupyterHandler
44
from jupyter_server.utils import url_path_join
55
import tornado
66

7-
class HelloRouteHandler(APIHandler):
8-
# The following decorator should be present on all verb methods (head, get, post,
9-
# patch, put, delete, options) to ensure only authorized user can request the
10-
# Jupyter server
7+
from constants import ENDPOINT_BASE
8+
9+
10+
class TiTilerRouteHandler(JupyterHandler):
11+
"""How does this handler work?"""
12+
1113
@tornado.web.authenticated
12-
def get(self):
13-
self.finish(json.dumps({
14-
"data": "This is /titiler/get-example endpoint!"
15-
}))
14+
async def get(self, path):
15+
params = {key: val[0].decode() for key, val in self.request.arguments.items()}
16+
server_url = params.pop("server_url")
17+
async with httpx.AsyncClient() as client:
18+
r = await client.get(f"{server_url}/{path}", params=params)
19+
self.write(r.content)
1620

1721

1822
def setup_handlers(web_app):
1923
host_pattern = ".*$"
2024

2125
base_url = web_app.settings["base_url"]
22-
route_pattern = url_path_join(base_url, "titiler", "get-example")
23-
handlers = [(route_pattern, RouteHandler)]
26+
titiler_route_pattern = url_path_join(base_url, ENDPOINT_BASE)
27+
handlers = [(titiler_route_pattern, TiTilerRouteHandler)]
2428
web_app.add_handlers(host_pattern, handlers)

jupyter_server_titiler/server.py

Whitespace-only changes.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ classifiers = [
2525
dependencies = [
2626
"jupyter_server>=2.4.0,<3",
2727
"titiler",
28-
"uvicorn",
28+
"anyio",
29+
"anycorn",
2930
"rioxarray",
3031
"geopandas",
3132
]

0 commit comments

Comments
 (0)