From 31bf7676c330a6d04f641e809e95044c4c25f139 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 11 Sep 2024 11:39:27 +0100 Subject: [PATCH 1/5] Add new FileOutput type to stream.js This takes a url and creates a `ReadableStream` that has two additional methods `url()` and `blob()` for easily working with remote URLs or base64 encoded data. --- lib/stream.js | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/lib/stream.js b/lib/stream.js index 2e0bbde..875e020 100644 --- a/lib/stream.js +++ b/lib/stream.js @@ -98,7 +98,61 @@ function createReadableStream({ url, fetch, options = {} }) { }); } +function createFileOutput({ url, fetch }) { + let type = "application/octet-stream"; + + class FileOutput extends ReadableStream { + async blob() { + const chunks = []; + for await (const chunk of this) { + chunks.push(chunk); + } + return new Blob(chunks, { type }); + } + + url() { + return new URL(url); + } + + toString() { + return url; + } + } + + return new FileOutput({ + async start(controller) { + const response = await fetch(url); + + if (!response.ok) { + const text = await response.text(); + const request = new Request(url, init); + controller.error( + new ApiError( + `Request to ${url} failed with status ${response.status}: ${text}`, + request, + response + ) + ); + } + + if (response.headers.get("Content-Type")) { + type = response.headers.get("Content-Type"); + } + + try { + for await (const chunk of streamAsyncIterator(response.body)) { + controller.enqueue(chunk); + } + controller.close(); + } catch (err) { + controller.error(err); + } + }, + }); +} + module.exports = { + createFileOutput, createReadableStream, ServerSentEvent, }; From 16dae4dec1f9aea0bc3598dd04eafba6b68b849f Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 11 Sep 2024 11:39:48 +0100 Subject: [PATCH 2/5] Export `transform` function from utils.js --- lib/util.js | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/util.js b/lib/util.js index daecfd1..27afd46 100644 --- a/lib/util.js +++ b/lib/util.js @@ -452,6 +452,7 @@ async function* streamAsyncIterator(stream) { } module.exports = { + transform, transformFileInputs, validateWebhook, withAutomaticRetries, From 0d133ca0e65da1f50a0503d74d0e6394047a1d5a Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 11 Sep 2024 11:42:42 +0100 Subject: [PATCH 3/5] Return FileOutput from `run()` function This is currently behind the `useFileOutput` flag provided to the Replicate constructor. This allows us to test the feature before rolling it out more widely. When enabled any URLs or data-uris will be converted into a FileOutput type. This is essentially a `ReadableStream` that has two additional methods `url()` to return the underlying URL and `blob()` which will return a `Blob()` object with the file data loaded into memory. The intention here is to make it easier to work with file outputs and allows us to optimize the delivery of file assets to the client in future iterations. --- index.d.ts | 7 ++ index.js | 17 ++++- index.test.ts | 199 +++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 220 insertions(+), 3 deletions(-) diff --git a/index.d.ts b/index.d.ts index 25cd51a..4801654 100644 --- a/index.d.ts +++ b/index.d.ts @@ -8,6 +8,12 @@ declare module "replicate" { response: Response; } + export interface FileOutput extends ReadableStream { + blob(): Promise; + url(): URL; + toString(): string; + } + export interface Account { type: "user" | "organization"; username: string; @@ -137,6 +143,7 @@ declare module "replicate" { init?: RequestInit ) => Promise; fileEncodingStrategy?: FileEncodingStrategy; + useFileOutput?: boolean; }); auth: string; diff --git a/index.js b/index.js index 6b35db6..4b45d25 100644 --- a/index.js +++ b/index.js @@ -1,7 +1,8 @@ const ApiError = require("./lib/error"); const ModelVersionIdentifier = require("./lib/identifier"); -const { createReadableStream } = require("./lib/stream"); +const { createReadableStream, createFileOutput } = require("./lib/stream"); const { + transform, withAutomaticRetries, validateWebhook, parseProgressFromLogs, @@ -47,6 +48,7 @@ class Replicate { * @param {string} options.userAgent - Identifier of your app * @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1 * @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch` + * @param {boolean} [options.useFileOutput] - Set to `true` to return `FileOutput` objects from `run` instead of URLs, defaults to false. * @param {"default" | "upload" | "data-uri"} [options.fileEncodingStrategy] - Determines the file encoding strategy to use */ constructor(options = {}) { @@ -58,6 +60,7 @@ class Replicate { this.baseUrl = options.baseUrl || "https://api.replicate.com/v1"; this.fetch = options.fetch || globalThis.fetch; this.fileEncodingStrategy = options.fileEncodingStrategy ?? "default"; + this.useFileOutput = options.useFileOutput ?? false; this.accounts = { current: accounts.current.bind(this), @@ -196,7 +199,17 @@ class Replicate { throw new Error(`Prediction failed: ${prediction.error}`); } - return prediction.output; + return transform(prediction.output, (value) => { + if ( + typeof value === "string" && + (value.startsWith("https:") || value.startsWith("data:")) + ) { + return this.useFileOutput + ? createFileOutput({ url: value, fetch: this.fetch }) + : value; + } + return value; + }); } /** diff --git a/index.test.ts b/index.test.ts index 7f9fcf2..5ca3e54 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1,13 +1,13 @@ import { expect, jest, test } from "@jest/globals"; import Replicate, { ApiError, + FileOutput, Model, Prediction, validateWebhook, parseProgressFromLogs, } from "replicate"; import nock from "nock"; -import { Readable } from "node:stream"; import { createReadableStream } from "./lib/stream"; let client: Replicate; @@ -1562,6 +1562,203 @@ describe("Replicate client", () => { scope.done(); }); + + test("returns FileOutput for URLs when useFileOutput is true", async () => { + client = new Replicate({ auth: "foo", useFileOutput: true }); + + nock(BASE_URL) + .post("/predictions") + .reply(201, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "starting", + logs: null, + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: [].join("\n"), + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: [].join("\n"), + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "succeeded", + output: "https://example.com", + logs: [].join("\n"), + }); + + nock("https://example.com") + .get("/") + .reply(200, "hello world", { "Content-Type": "text/plain" }); + + const output = (await client.run( + "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { text: "Hello, world!" }, + } + )) as FileOutput; + + expect(output).toBeInstanceOf(ReadableStream); + expect(output.url()).toEqual(new URL("https://example.com")); + + const blob = await output.blob(); + expect(blob.type).toEqual("text/plain"); + expect(blob.arrayBuffer()).toEqual( + new Blob(["Hello, world!"]).arrayBuffer() + ); + }); + + test("returns FileOutput for URLs when useFileOutput is true - acts like string", async () => { + client = new Replicate({ auth: "foo", useFileOutput: true }); + + nock(BASE_URL) + .post("/predictions") + .reply(201, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "starting", + logs: null, + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: [].join("\n"), + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: [].join("\n"), + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "succeeded", + output: "https://example.com", + logs: [].join("\n"), + }); + + nock("https://example.com") + .get("/") + .reply(200, "hello world", { "Content-Type": "text/plain" }); + + const output = (await client.run( + "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { text: "Hello, world!" }, + } + )) as unknown as string; + + expect(fetch(output).then((r) => r.text())).resolves.toEqual( + "hello world" + ); + }); + + test("returns FileOutput for URLs when useFileOutput is true - array output", async () => { + client = new Replicate({ auth: "foo", useFileOutput: true }); + + nock(BASE_URL) + .post("/predictions") + .reply(201, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "starting", + logs: null, + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: [].join("\n"), + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: [].join("\n"), + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "succeeded", + output: ["https://example.com"], + logs: [].join("\n"), + }); + + nock("https://example.com") + .get("/") + .reply(200, "hello world", { "Content-Type": "text/plain" }); + + const [output] = (await client.run( + "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { text: "Hello, world!" }, + } + )) as FileOutput[]; + + expect(output).toBeInstanceOf(ReadableStream); + expect(output.url()).toEqual(new URL("https://example.com")); + + const blob = await output.blob(); + expect(blob.type).toEqual("text/plain"); + expect(blob.arrayBuffer()).toEqual( + new Blob(["Hello, world!"]).arrayBuffer() + ); + }); + + test("returns FileOutput for URLs when useFileOutput is true - data uri", async () => { + client = new Replicate({ auth: "foo", useFileOutput: true }); + + nock(BASE_URL) + .post("/predictions") + .reply(201, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "starting", + logs: null, + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: [].join("\n"), + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: [].join("\n"), + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "succeeded", + output: "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==", + logs: [].join("\n"), + }); + + const output = (await client.run( + "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { text: "Hello, world!" }, + } + )) as FileOutput; + + expect(output).toBeInstanceOf(ReadableStream); + expect(output.url()).toEqual( + new URL("data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==") + ); + + const blob = await output.blob(); + expect(blob.type).toEqual("text/plain"); + expect(blob.arrayBuffer()).toEqual( + new Blob(["Hello, world!"]).arrayBuffer() + ); + }); }); describe("webhooks.default.secret.get", () => { From 5659b6c4ee033bf0198e136f7d2f4766982c09ba Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 11 Sep 2024 11:45:47 -0700 Subject: [PATCH 4/5] Add documentation comments to createFileOutput declaration --- lib/stream.js | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lib/stream.js b/lib/stream.js index 875e020..2f72e2c 100644 --- a/lib/stream.js +++ b/lib/stream.js @@ -98,6 +98,15 @@ function createReadableStream({ url, fetch, options = {} }) { }); } +/** + * Create a new readable stream for an output file + * created by running a Replicate model. + * + * @param {object} config + * @param {string} config.url The URL to connect to. + * @param {typeof fetch} [config.fetch] The URL to connect to. + * @returns {ReadableStream} + */ function createFileOutput({ url, fetch }) { let type = "application/octet-stream"; From 4bb5b2884b17c1c1b509b7ac709dcf1672d5edd1 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 11 Sep 2024 11:48:39 -0700 Subject: [PATCH 5/5] Replace use of ?? with || --- index.js | 4 ++-- lib/util.js | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/index.js b/index.js index 4b45d25..f2c3e1e 100644 --- a/index.js +++ b/index.js @@ -59,8 +59,8 @@ class Replicate { options.userAgent || `replicate-javascript/${packageJSON.version}`; this.baseUrl = options.baseUrl || "https://api.replicate.com/v1"; this.fetch = options.fetch || globalThis.fetch; - this.fileEncodingStrategy = options.fileEncodingStrategy ?? "default"; - this.useFileOutput = options.useFileOutput ?? false; + this.fileEncodingStrategy = options.fileEncodingStrategy || "default"; + this.useFileOutput = options.useFileOutput || false; this.accounts = { current: accounts.current.bind(this), diff --git a/lib/util.js b/lib/util.js index 27afd46..bd3c31e 100644 --- a/lib/util.js +++ b/lib/util.js @@ -318,7 +318,7 @@ async function transformFileInputsToBase64EncodedDataURIs(inputs) { } const data = bytesToBase64(buffer); - mime = mime ?? "application/octet-stream"; + mime = mime || "application/octet-stream"; return `data:${mime};base64,${data}`; });