Skip to content

Commit 7bfee3f

Browse files
committed
Fix dereferencing in Function and AsyncFunction
We were accidentally stripping out all components in the schemas due to the Input/Output being referenced by PredictionRequest and PredictionResponse and those in turn being referenced by the various `path` entries. This commit ensures we retain Input and Output and cleans up the object to only contain relevant fields.
1 parent b76ff18 commit 7bfee3f

File tree

2 files changed

+101
-1
lines changed

2 files changed

+101
-1
lines changed

replicate/use.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,17 @@ def _resolve_ref(obj: Any) -> Any:
190190

191191
result = _resolve_ref(dereferenced)
192192

193-
# Filter out any references that have now been referenced.
193+
# Remove "paths" as these aren't relevant to models.
194+
result["paths"] = {}
195+
196+
# Retain Input and Output schemas as these are important.
197+
dereferenced_refs.discard("Input")
198+
dereferenced_refs.discard("Output")
199+
200+
dereferenced_refs.discard("TrainingInput")
201+
dereferenced_refs.discard("TrainingOutput")
202+
203+
# Filter out any remaining references that have been inlined.
194204
result["components"]["schemas"] = {
195205
k: v
196206
for k, v in result["components"]["schemas"].items()

tests/test_use.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,82 @@ def create_mock_version(version_overrides=None, version_id="xyz123"):
7878
"required": ["prompt"],
7979
},
8080
"Output": {"type": "string", "title": "Output"},
81+
"PredictionResponse": {
82+
"type": "object",
83+
"title": "PredictionResponse",
84+
"properties": {
85+
"id": {"type": "string", "title": "Id"},
86+
"logs": {"type": "string", "title": "Logs", "default": ""},
87+
"error": {"type": "string", "title": "Error"},
88+
"input": {"$ref": "#/components/schemas/Input"},
89+
"output": {"$ref": "#/components/schemas/Output"},
90+
"status": {"$ref": "#/components/schemas/Status"},
91+
"metrics": {"type": "object", "title": "Metrics"},
92+
"version": {"type": "string", "title": "Version"},
93+
"created_at": {
94+
"type": "string",
95+
"title": "Created At",
96+
"format": "date-time",
97+
},
98+
"started_at": {
99+
"type": "string",
100+
"title": "Started At",
101+
"format": "date-time",
102+
},
103+
"completed_at": {
104+
"type": "string",
105+
"title": "Completed At",
106+
"format": "date-time",
107+
},
108+
},
109+
},
110+
"PredictionRequest": {
111+
"type": "object",
112+
"title": "PredictionRequest",
113+
"properties": {
114+
"id": {"type": "string", "title": "Id"},
115+
"input": {"$ref": "#/components/schemas/Input"},
116+
"webhook": {
117+
"type": "string",
118+
"title": "Webhook",
119+
"format": "uri",
120+
"maxLength": 65536,
121+
"minLength": 1,
122+
},
123+
"created_at": {
124+
"type": "string",
125+
"title": "Created At",
126+
"format": "date-time",
127+
},
128+
"output_file_prefix": {
129+
"type": "string",
130+
"title": "Output File Prefix",
131+
},
132+
"webhook_events_filter": {
133+
"type": "array",
134+
"items": {"$ref": "#/components/schemas/WebhookEvent"},
135+
"default": ["start", "output", "logs", "completed"],
136+
},
137+
},
138+
},
139+
"Status": {
140+
"enum": [
141+
"starting",
142+
"processing",
143+
"succeeded",
144+
"canceled",
145+
"failed",
146+
],
147+
"type": "string",
148+
"title": "Status",
149+
"description": "An enumeration.",
150+
},
151+
"WebhookEvent": {
152+
"enum": ["start", "output", "logs", "completed"],
153+
"type": "string",
154+
"title": "WebhookEvent",
155+
"description": "An enumeration.",
156+
},
81157
}
82158
},
83159
},
@@ -345,6 +421,7 @@ async def test_use_function_openapi_schema_dereferenced(client_mode):
345421
"openapi_schema": {
346422
"components": {
347423
"schemas": {
424+
"Extra": {"type": "object"},
348425
"Output": {"$ref": "#/components/schemas/ModelOutput"},
349426
"ModelOutput": {
350427
"type": "object",
@@ -374,6 +451,12 @@ async def test_use_function_openapi_schema_dereferenced(client_mode):
374451
else:
375452
schema = hotdog_detector.openapi_schema()
376453

454+
assert schema["components"]["schemas"]["Extra"] == {"type": "object"}
455+
assert schema["components"]["schemas"]["Input"] == {
456+
"type": "object",
457+
"properties": {"prompt": {"type": "string", "title": "Prompt"}},
458+
"required": ["prompt"],
459+
}
377460
assert schema["components"]["schemas"]["Output"] == {
378461
"type": "object",
379462
"properties": {
@@ -386,7 +469,14 @@ async def test_use_function_openapi_schema_dereferenced(client_mode):
386469
},
387470
}
388471

472+
# Assert everything else is stripped out
473+
assert schema["paths"] == {}
474+
475+
assert "PredictionRequest" not in schema["components"]["schemas"]
476+
assert "PredictionResponse" not in schema["components"]["schemas"]
389477
assert "ModelOutput" not in schema["components"]["schemas"]
478+
assert "Status" not in schema["components"]["schemas"]
479+
assert "WebhookEvent" not in schema["components"]["schemas"]
390480

391481

392482
@pytest.mark.asyncio

0 commit comments

Comments
 (0)