
Commit 4d8973c

Initial UDF server implementation
1 parent c9b6e69 commit 4d8973c

20 files changed: +795 -1 lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
.venv/
__pycache__/
.pytest_cache/
*.egg-info/

README.md

Lines changed: 41 additions & 1 deletion
@@ -1 +1,41 @@
# ai-server

Stage-aware Databend UDF server providing storage-centric helpers used by the AI
extensions. The initial release ships with:

- `list_stage_files`: enumerate objects inside an external S3 stage via
  [Apache OpenDAL](https://opendal.apache.org/docs/python).
- `read_pdf`: pull down a PDF object from a stage and return its extracted text
  as a `STRING`.
- `read_docx`: fetch Microsoft Word files (`.docx`) and expose their textual
  content as a `STRING`.

## Getting started

```bash
python3 -m venv .venv
source .venv/bin/activate
pip install -e .[dev]
ai-udf-server --port 8815 --metrics-port 9091
```

Databend can now connect to the running Flight server and call the registered
functions. `list_stage_files` expects a `STAGE_LOCATION` argument supplied by
Databend plus a numeric limit. Example:

```sql
SELECT
    list_stage_files(
        @stage_location,
        50
    );
```

To retrieve file contents, call the document readers with an explicit path. The
PDF reader optionally accepts `NULL` for the `max_pages` argument to stream the
entire document:

```sql
SELECT read_pdf(@stage_location, 'inbox/manual.pdf', NULL);
SELECT read_docx(@stage_location, 'reports/summary.docx');
```
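
For quick experiments, the same server can also be started straight from Python rather than through the `ai-udf-server` console script; a minimal sketch, assuming the package has been installed with `pip install -e .` (the port is just an example):

```python
# Programmatic equivalent of `ai-udf-server --port 8815` (metrics disabled).
from ai_server import create_server

server = create_server(host="0.0.0.0", port=8815)
server.serve()  # blocks until the process is interrupted
```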

ai_server/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
"""AI-enhanced Databend UDF server utilities."""

from .server import create_server

__all__ = ["create_server"]

ai_server/main.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
"""Command line entrypoint for the Databend AI UDF server."""

from __future__ import annotations

import argparse
import logging
import signal
import sys
from contextlib import suppress
from typing import Optional

from prometheus_client import start_http_server as start_prometheus_server

from ai_server.server import create_server

logger = logging.getLogger(__name__)


def _parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Databend AI UDF server")
    parser.add_argument("--host", default="0.0.0.0", help="Bind address for the gRPC server")
    parser.add_argument("--port", type=int, default=8815, help="Port for the gRPC server")
    parser.add_argument(
        "--metrics-port",
        type=int,
        default=None,
        help="Port for Prometheus metrics exporter (disabled when omitted).",
    )
    parser.add_argument(
        "--log-level",
        default="INFO",
        help="Python logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL).",
    )
    return parser.parse_args(argv)


def _configure_logging(level: str) -> None:
    logging.basicConfig(
        level=getattr(logging, level.upper(), logging.INFO),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )


def main(argv: Optional[list[str]] = None) -> int:
    args = _parse_args(argv)
    _configure_logging(args.log_level)

    if args.metrics_port is not None:
        start_prometheus_server(args.metrics_port)
        logger.info("Prometheus metrics server started on port %s", args.metrics_port)

    server = create_server(host=args.host, port=args.port, metric_port=args.metrics_port)
    logger.info("Starting Databend AI UDF server on %s:%s", args.host, args.port)

    # Handle shutdown gracefully to ensure we stop serving when receiving termination signals.
    stop_event = getattr(server, "stopped", None)

    def _handle_signal(signum, frame):  # noqa: ANN001 - signature dictated by signal library
        logger.info("Received signal %s, shutting down.", signum)
        with suppress(Exception):
            server.shutdown()
        if stop_event is not None:
            stop_event.set()

    signal.signal(signal.SIGINT, _handle_signal)
    signal.signal(signal.SIGTERM, _handle_signal)

    try:
        server.serve()
    except KeyboardInterrupt:
        logger.info("Interrupted, stopping server.")
    return 0


if __name__ == "__main__":
    sys.exit(main())
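
Because `_parse_args` and `main` both take an explicit `argv` list, the CLI surface can be exercised without touching `sys.argv`; a sketch of how a test might pin down the defaults (the asserted values simply restate the parser defaults above):

```python
# Sketch of parser tests; the expected values restate the defaults defined above.
from ai_server.main import _parse_args


def test_default_arguments():
    args = _parse_args([])
    assert args.host == "0.0.0.0"
    assert args.port == 8815
    assert args.metrics_port is None
    assert args.log_level == "INFO"


def test_metrics_port_override():
    args = _parse_args(["--metrics-port", "9091"])
    assert args.metrics_port == 9091
```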

ai_server/server.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
"""Entrypoint for the Databend AI UDF server."""

from __future__ import annotations

from typing import Optional

from databend_udf import UDFServer

from ai_server.udfs import list_stage_files, read_docx, read_pdf


def create_server(
    host: str = "0.0.0.0", port: int = 8815, metric_port: Optional[int] = None
) -> UDFServer:
    """
    Create a configured UDF server instance.

    Parameters
    ----------
    host:
        Bind address for the Flight server.
    port:
        Bind port for the Flight server.
    metric_port:
        Optional metrics port for Prometheus exporter. When provided the
        databend-udf server will expose metrics via Prometheus.
    """
    location = f"{host}:{port}"
    metric_location = (
        f"{host}:{metric_port}" if metric_port is not None else None
    )
    server = UDFServer(location, metric_location=metric_location)
    server.add_function(list_stage_files)
    server.add_function(read_pdf)
    server.add_function(read_docx)
    return server
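
`create_server` returns a plain `UDFServer`, so callers can register additional functions before serving. A hedged sketch: the `udf` decorator usage below follows upstream databend-udf examples and is an assumption rather than part of this repository, while `add_function` and `serve` are the same calls used above.

```python
# Sketch: extend the pre-wired server with one extra UDF before serving.
# The @udf decorator arguments mirror upstream databend-udf examples (assumption).
from databend_udf import udf

from ai_server.server import create_server


@udf(input_types=["INT", "INT"], result_type="INT")
def gcd(x: int, y: int) -> int:
    # Euclid's algorithm; a placeholder function for the example.
    while y:
        x, y = y, x % y
    return x


server = create_server(port=8815, metric_port=9091)
server.add_function(gcd)
server.serve()
```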

ai_server/stages/operator.py

Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
"""Helpers to translate Databend stage metadata into OpenDAL operators."""

from __future__ import annotations

import hashlib
import json
import logging
import threading
from typing import Any, Dict, Mapping, Tuple

from databend_udf import StageLocation
from opendal import Operator, exceptions as opendal_exceptions

logger = logging.getLogger(__name__)

_OPERATOR_CACHE: Dict[str, Operator] = {}
_CACHE_LOCK = threading.Lock()


class StageConfigurationError(RuntimeError):
    """Raised when an unsupported or invalid stage configuration is encountered."""


def _normalize_bool(value: Any) -> bool:
    if isinstance(value, bool):
        return value
    if value is None:
        return False
    if isinstance(value, (int, float)):
        return bool(value)
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in {"1", "true", "t", "yes", "y", "on"}:
            return True
        if normalized in {"0", "false", "f", "no", "n", "off"}:
            return False
    return bool(value)


def _first_present(storage: Mapping[str, Any], *keys: str) -> Any:
    for key in keys:
        if key in storage:
            value = storage[key]
            if value not in (None, "", {}):
                return value
    return None


def _build_s3_options(storage: Mapping[str, Any]) -> Dict[str, Any]:
    bucket = _first_present(storage, "bucket", "name")
    if not bucket:
        raise StageConfigurationError("S3 stage is missing bucket configuration")

    region = _first_present(storage, "region")
    endpoint = _first_present(storage, "endpoint", "endpoint_url")
    access_key = _first_present(storage, "access_key_id", "aws_key_id")
    secret_key = _first_present(storage, "secret_access_key", "aws_secret_key")
    security_token = _first_present(storage, "security_token", "session_token", "aws_token")
    master_key = _first_present(storage, "master_key")
    root = _first_present(storage, "root")
    role_arn = _first_present(storage, "role_arn", "aws_role_arn")
    external_id = _first_present(storage, "external_id", "aws_external_id")
    virtual_host_style = storage.get("enable_virtual_host_style")
    disable_loader = storage.get("disable_credential_loader")

    options: Dict[str, Any] = {"bucket": bucket}

    if region:
        options["region"] = region
    else:
        # Databend stages may skip region when working with S3 compatible endpoints.
        # OpenDAL requires a region, default to us-east-1 if not provided.
        options["region"] = "us-east-1"

    if endpoint:
        options["endpoint"] = endpoint
    if access_key:
        options["access_key_id"] = access_key
    if secret_key:
        options["secret_access_key"] = secret_key
    if security_token:
        options["security_token"] = security_token
    if master_key:
        options["master_key"] = master_key
    if root:
        options["root"] = root
    if role_arn:
        options["role_arn"] = role_arn
    if external_id:
        options["external_id"] = external_id
    if virtual_host_style is not None:
        options["enable_virtual_host_style"] = _normalize_bool(virtual_host_style)
    if disable_loader is not None:
        options["disable_credential_loader"] = _normalize_bool(disable_loader)

    return options


def _build_memory_options(_: Mapping[str, Any]) -> Dict[str, Any]:
    # Useful for local testing; not a Databend production configuration.
    return {}


_STORAGE_BUILDERS: Dict[str, Any] = {"s3": _build_s3_options, "memory": _build_memory_options}


def _cache_key(stage: StageLocation) -> str:
    payload = {
        "stage_name": stage.stage_name,
        "stage_type": stage.stage_type,
        "storage": stage.storage,
    }
    encoded = json.dumps(payload, sort_keys=True, default=str)
    return hashlib.sha256(encoded.encode("utf-8")).hexdigest()


def _build_operator(stage: StageLocation) -> Operator:
    storage = stage.storage or {}
    storage_type = str(storage.get("type", "")).lower()

    if storage_type not in _STORAGE_BUILDERS:
        raise StageConfigurationError(
            f"Unsupported stage storage type '{storage_type or 'unknown'}'"
        )

    builder = _STORAGE_BUILDERS[storage_type]
    options = builder(storage)
    logger.debug(
        "Creating OpenDAL operator for stage '%s' with backend '%s'",
        stage.stage_name,
        storage_type,
    )
    try:
        return Operator(storage_type, **options)
    except opendal_exceptions.Error as exc:
        raise StageConfigurationError(
            f"Failed to construct operator for stage '{stage.stage_name}': {exc}"
        ) from exc


def get_operator(stage: StageLocation) -> Operator:
    """Return a cached OpenDAL operator for the given stage."""

    cache_key = _cache_key(stage)
    with _CACHE_LOCK:
        operator = _OPERATOR_CACHE.get(cache_key)
        if operator is None:
            operator = _build_operator(stage)
            _OPERATOR_CACHE[cache_key] = operator
        return operator


def clear_operator_cache() -> None:
    """Utility that clears cached operators, primarily for testing."""

    with _CACHE_LOCK:
        _OPERATOR_CACHE.clear()


def resolve_stage_subpath(stage: StageLocation, path: str | None = None) -> str:
    """
    Combine the stage's relative path with a user-provided path.

    The resulting string is relative to the OpenDAL operator's configured root.
    """

    def _normalize(component: str | None) -> Tuple[str, ...]:
        if not component:
            return ()
        parts = []
        for part in component.split("/"):
            chunk = part.strip()
            if not chunk or chunk == ".":
                continue
            if chunk == "..":
                raise ValueError("Stage paths must not contain '..'")
            parts.append(chunk)
        return tuple(parts)

    base_parts = _normalize(stage.relative_path)
    extra_parts = _normalize(path)
    full_parts = base_parts + extra_parts
    if not full_parts:
        return ""
    return "/".join(full_parts)


def as_directory_path(path: str) -> str:
    """Ensure the provided path represents a directory for list operations."""
    if not path:
        return ""
    return path if path.endswith("/") else f"{path}/"
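
The path helpers and the operator cache can be exercised locally through the `memory` backend registered above. A sketch, using `SimpleNamespace` as a stand-in for the `StageLocation` object that Databend normally supplies (its real constructor is not shown in this diff, so the attribute names below simply mirror what the helpers read):

```python
# Illustration only: SimpleNamespace stands in for databend_udf.StageLocation.
from types import SimpleNamespace

from ai_server.stages.operator import (
    as_directory_path,
    get_operator,
    resolve_stage_subpath,
)

stage = SimpleNamespace(
    stage_name="docs_stage",
    stage_type="external",
    storage={"type": "memory"},  # in-memory backend, handy for local tests
    relative_path="inbox",
)

print(resolve_stage_subpath(stage, "manuals/guide.pdf"))  # inbox/manuals/guide.pdf
print(as_directory_path(resolve_stage_subpath(stage)))    # inbox/

op = get_operator(stage)  # built once, then reused from the cache
op.write("inbox/manuals/guide.pdf", b"hello world")
print(bytes(op.read("inbox/manuals/guide.pdf")))          # b'hello world'
```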

ai_server/udfs/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
"""Collection of UDF implementations exposed by the AI server."""

from .stage import list_stage_files
from .files import read_docx, read_pdf

__all__ = ["list_stage_files", "read_pdf", "read_docx"]
