Skip to content

Commit 6e544d0

Browse files
sitatecaron
andauthored
Fix url key of createFileOutput options for streaming (#350)
Fix bug with broken `FileOutput` object being created as part of the streaming API as the wrong value for URL was passed to the constructor. This also extends the `replicate.stream()` interface to accept a `useFileOutput` configuration object. Lastly, we now use the stream URL itself to detect if we should convert URL objects into `FileOutput`. --------- Co-authored-by: Aron Carroll <[email protected]>
1 parent 0eac811 commit 6e544d0

File tree

5 files changed

+121
-13
lines changed

5 files changed

+121
-13
lines changed

biome.json

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
{
22
"$schema": "https://biomejs.dev/schemas/1.0.0/schema.json",
33
"files": {
4-
"ignore": [
5-
".wrangler",
6-
"node_modules",
7-
"vendor/*"
8-
]
4+
"ignore": [".wrangler", "node_modules", "vendor/*"]
95
},
106
"formatter": {
117
"indentStyle": "space",

index.d.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ declare module "replicate" {
208208
webhook?: string;
209209
webhook_events_filter?: WebhookEventType[];
210210
signal?: AbortSignal;
211+
useFileOutput?: boolean;
211212
}
212213
): AsyncGenerator<ServerSentEvent>;
213214

index.js

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,12 @@ class Replicate {
315315
* @yields {ServerSentEvent} Each streamed event from the prediction
316316
*/
317317
async *stream(ref, options) {
318-
const { wait, signal, ...data } = options;
318+
const {
319+
wait,
320+
signal,
321+
useFileOutput = this.useFileOutput,
322+
...data
323+
} = options;
319324

320325
const identifier = ModelVersionIdentifier.parse(ref);
321326

@@ -338,7 +343,10 @@ class Replicate {
338343
const stream = createReadableStream({
339344
url: prediction.urls.stream,
340345
fetch: this.fetch,
341-
...(signal ? { options: { signal } } : {}),
346+
options: {
347+
useFileOutput,
348+
...(signal ? { signal } : {}),
349+
},
342350
});
343351

344352
yield* streamAsyncIterator(stream);

index.test.ts

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,8 +1905,12 @@ describe("Replicate client", () => {
19051905
// Continue with tests for other methods
19061906

19071907
describe("createReadableStream", () => {
1908-
function createStream(body: string | ReadableStream, status = 200) {
1909-
const streamEndpoint = "https://stream.replicate.com/fake_stream";
1908+
function createStream(
1909+
body: string | ReadableStream,
1910+
status = 200,
1911+
streamEndpoint = "https://stream.replicate.com/fake_stream",
1912+
options: { useFileOutput?: boolean } = {}
1913+
) {
19101914
const fetch = jest.fn((url) => {
19111915
if (url !== streamEndpoint) {
19121916
throw new Error(`Unmocked call to fetch() with url: ${url}`);
@@ -1916,6 +1920,7 @@ describe("Replicate client", () => {
19161920
return createReadableStream({
19171921
url: streamEndpoint,
19181922
fetch: fetch as any,
1923+
options,
19191924
});
19201925
}
19211926

@@ -2192,5 +2197,95 @@ describe("Replicate client", () => {
21922197
);
21932198
expect(await iterator.next()).toEqual({ done: true });
21942199
});
2200+
2201+
describe("file streams", () => {
2202+
test("emits FileOutput objects", async () => {
2203+
const stream = createStream(
2204+
`
2205+
event: output
2206+
id: EVENT_1
2207+
data: data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==
2208+
2209+
event: output
2210+
id: EVENT_2
2211+
data: https://delivery.replicate.com/my_file.png
2212+
2213+
event: done
2214+
id: EVENT_3
2215+
data: {}
2216+
2217+
`.replace(/^[ ]+/gm, ""),
2218+
200,
2219+
"https://stream.replicate.com/v1/files/abcd"
2220+
);
2221+
2222+
const iterator = stream[Symbol.asyncIterator]();
2223+
const { value: event1 } = await iterator.next();
2224+
expect(event1.data).toBeInstanceOf(ReadableStream);
2225+
expect(event1.data.url().href).toEqual(
2226+
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
2227+
);
2228+
2229+
const { value: event2 } = await iterator.next();
2230+
expect(event2.data).toBeInstanceOf(ReadableStream);
2231+
expect(event2.data.url().href).toEqual(
2232+
"https://delivery.replicate.com/my_file.png"
2233+
);
2234+
2235+
expect(await iterator.next()).toEqual({
2236+
done: false,
2237+
value: { event: "done", id: "EVENT_3", data: "{}" },
2238+
});
2239+
2240+
expect(await iterator.next()).toEqual({ done: true });
2241+
});
2242+
2243+
test("emits strings when useFileOutput is false", async () => {
2244+
const stream = createStream(
2245+
`
2246+
event: output
2247+
id: EVENT_1
2248+
data: data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==
2249+
2250+
event: output
2251+
id: EVENT_2
2252+
data: https://delivery.replicate.com/my_file.png
2253+
2254+
event: done
2255+
id: EVENT_3
2256+
data: {}
2257+
2258+
`.replace(/^[ ]+/gm, ""),
2259+
200,
2260+
"https://stream.replicate.com/v1/files/abcd",
2261+
{ useFileOutput: false }
2262+
);
2263+
2264+
const iterator = stream[Symbol.asyncIterator]();
2265+
2266+
expect(await iterator.next()).toEqual({
2267+
done: false,
2268+
value: {
2269+
event: "output",
2270+
id: "EVENT_1",
2271+
data: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
2272+
},
2273+
});
2274+
expect(await iterator.next()).toEqual({
2275+
done: false,
2276+
value: {
2277+
event: "output",
2278+
id: "EVENT_2",
2279+
data: "https://delivery.replicate.com/my_file.png",
2280+
},
2281+
});
2282+
expect(await iterator.next()).toEqual({
2283+
done: false,
2284+
value: { event: "done", id: "EVENT_3", data: "{}" },
2285+
});
2286+
2287+
expect(await iterator.next()).toEqual({ done: true });
2288+
});
2289+
});
21952290
});
21962291
});

lib/stream.js

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class ServerSentEvent {
5353
*/
5454
function createReadableStream({ url, fetch, options = {} }) {
5555
const { useFileOutput = true, headers = {}, ...initOptions } = options;
56+
const shouldProcessFileOutput = useFileOutput && isFileStream(url);
5657

5758
return new ReadableStream({
5859
async start(controller) {
@@ -89,11 +90,11 @@ function createReadableStream({ url, fetch, options = {} }) {
8990

9091
let data = event.data;
9192
if (
92-
useFileOutput &&
93-
typeof data === "string" &&
94-
(data.startsWith("https:") || data.startsWith("data:"))
93+
event.event === "output" &&
94+
shouldProcessFileOutput &&
95+
typeof data === "string"
9596
) {
96-
data = createFileOutput({ data, fetch });
97+
data = createFileOutput({ url: data, fetch });
9798
}
9899
controller.enqueue(new ServerSentEvent(event.event, data, event.id));
99100

@@ -169,6 +170,13 @@ function createFileOutput({ url, fetch }) {
169170
});
170171
}
171172

173+
function isFileStream(url) {
174+
try {
175+
return new URL(url).pathname.startsWith("/v1/files/");
176+
} catch {}
177+
return false;
178+
}
179+
172180
module.exports = {
173181
createFileOutput,
174182
createReadableStream,

0 commit comments

Comments
 (0)