Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ declare module "replicate" {
response: Response;
}

export interface FileOutput extends ReadableStream {
blob(): Promise<Blob>;
url(): URL;
toString(): string;
}

export interface Account {
type: "user" | "organization";
username: string;
Expand Down Expand Up @@ -137,6 +143,7 @@ declare module "replicate" {
init?: RequestInit
) => Promise<Response>;
fileEncodingStrategy?: FileEncodingStrategy;
useFileOutput?: boolean;
});

auth: string;
Expand Down
17 changes: 15 additions & 2 deletions index.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
const ApiError = require("./lib/error");
const ModelVersionIdentifier = require("./lib/identifier");
const { createReadableStream } = require("./lib/stream");
const { createReadableStream, createFileOutput } = require("./lib/stream");
const {
transform,
withAutomaticRetries,
validateWebhook,
parseProgressFromLogs,
Expand Down Expand Up @@ -47,6 +48,7 @@ class Replicate {
* @param {string} options.userAgent - Identifier of your app
* @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1
* @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch`
* @param {boolean} [options.useFileOutput] - Set to `true` to return `FileOutput` objects from `run` instead of URLs, defaults to false.
* @param {"default" | "upload" | "data-uri"} [options.fileEncodingStrategy] - Determines the file encoding strategy to use
*/
constructor(options = {}) {
Expand All @@ -58,6 +60,7 @@ class Replicate {
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";
this.fetch = options.fetch || globalThis.fetch;
this.fileEncodingStrategy = options.fileEncodingStrategy ?? "default";
this.useFileOutput = options.useFileOutput ?? false;

this.accounts = {
current: accounts.current.bind(this),
Expand Down Expand Up @@ -196,7 +199,17 @@ class Replicate {
throw new Error(`Prediction failed: ${prediction.error}`);
}

return prediction.output;
return transform(prediction.output, (value) => {
if (
typeof value === "string" &&
(value.startsWith("https:") || value.startsWith("data:"))
) {
return this.useFileOutput
? createFileOutput({ url: value, fetch: this.fetch })
: value;
}
return value;
});
}

/**
Expand Down
199 changes: 198 additions & 1 deletion index.test.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import { expect, jest, test } from "@jest/globals";
import Replicate, {
ApiError,
FileOutput,
Model,
Prediction,
validateWebhook,
parseProgressFromLogs,
} from "replicate";
import nock from "nock";
import { Readable } from "node:stream";
import { createReadableStream } from "./lib/stream";

let client: Replicate;
Expand Down Expand Up @@ -1562,6 +1562,203 @@ describe("Replicate client", () => {

scope.done();
});

test("returns FileOutput for URLs when useFileOutput is true", async () => {
client = new Replicate({ auth: "foo", useFileOutput: true });

nock(BASE_URL)
.post("/predictions")
.reply(201, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "starting",
logs: null,
})
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
.reply(200, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "processing",
logs: [].join("\n"),
})
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
.reply(200, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "processing",
logs: [].join("\n"),
})
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
.reply(200, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "succeeded",
output: "https://example.com",
logs: [].join("\n"),
});

nock("https://example.com")
.get("/")
.reply(200, "hello world", { "Content-Type": "text/plain" });

const output = (await client.run(
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
{
input: { text: "Hello, world!" },
}
)) as FileOutput;

expect(output).toBeInstanceOf(ReadableStream);
expect(output.url()).toEqual(new URL("https://example.com"));

const blob = await output.blob();
expect(blob.type).toEqual("text/plain");
expect(blob.arrayBuffer()).toEqual(
new Blob(["Hello, world!"]).arrayBuffer()
);
});

test("returns FileOutput for URLs when useFileOutput is true - acts like string", async () => {
client = new Replicate({ auth: "foo", useFileOutput: true });

nock(BASE_URL)
.post("/predictions")
.reply(201, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "starting",
logs: null,
})
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
.reply(200, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "processing",
logs: [].join("\n"),
})
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
.reply(200, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "processing",
logs: [].join("\n"),
})
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
.reply(200, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "succeeded",
output: "https://example.com",
logs: [].join("\n"),
});

nock("https://example.com")
.get("/")
.reply(200, "hello world", { "Content-Type": "text/plain" });

const output = (await client.run(
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
{
input: { text: "Hello, world!" },
}
)) as unknown as string;

expect(fetch(output).then((r) => r.text())).resolves.toEqual(
"hello world"
);
});

test("returns FileOutput for URLs when useFileOutput is true - array output", async () => {
client = new Replicate({ auth: "foo", useFileOutput: true });

nock(BASE_URL)
.post("/predictions")
.reply(201, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "starting",
logs: null,
})
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
.reply(200, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "processing",
logs: [].join("\n"),
})
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
.reply(200, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "processing",
logs: [].join("\n"),
})
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
.reply(200, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "succeeded",
output: ["https://example.com"],
logs: [].join("\n"),
});

nock("https://example.com")
.get("/")
.reply(200, "hello world", { "Content-Type": "text/plain" });

const [output] = (await client.run(
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
{
input: { text: "Hello, world!" },
}
)) as FileOutput[];

expect(output).toBeInstanceOf(ReadableStream);
expect(output.url()).toEqual(new URL("https://example.com"));

const blob = await output.blob();
expect(blob.type).toEqual("text/plain");
expect(blob.arrayBuffer()).toEqual(
new Blob(["Hello, world!"]).arrayBuffer()
);
});

test("returns FileOutput for URLs when useFileOutput is true - data uri", async () => {
client = new Replicate({ auth: "foo", useFileOutput: true });

nock(BASE_URL)
.post("/predictions")
.reply(201, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "starting",
logs: null,
})
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
.reply(200, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "processing",
logs: [].join("\n"),
})
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
.reply(200, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "processing",
logs: [].join("\n"),
})
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
.reply(200, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "succeeded",
output: "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==",
logs: [].join("\n"),
});

const output = (await client.run(
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
{
input: { text: "Hello, world!" },
}
)) as FileOutput;

expect(output).toBeInstanceOf(ReadableStream);
expect(output.url()).toEqual(
new URL("data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==")
);

const blob = await output.blob();
expect(blob.type).toEqual("text/plain");
expect(blob.arrayBuffer()).toEqual(
new Blob(["Hello, world!"]).arrayBuffer()
);
});
});

describe("webhooks.default.secret.get", () => {
Expand Down
54 changes: 54 additions & 0 deletions lib/stream.js
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,61 @@ function createReadableStream({ url, fetch, options = {} }) {
});
}

function createFileOutput({ url, fetch }) {
let type = "application/octet-stream";

class FileOutput extends ReadableStream {
async blob() {
const chunks = [];
for await (const chunk of this) {
chunks.push(chunk);
}
return new Blob(chunks, { type });
}

url() {
return new URL(url);
}

toString() {
return url;
}
}

return new FileOutput({
async start(controller) {
const response = await fetch(url);

if (!response.ok) {
const text = await response.text();
const request = new Request(url, init);
controller.error(
new ApiError(
`Request to ${url} failed with status ${response.status}: ${text}`,
request,
response
)
);
}

if (response.headers.get("Content-Type")) {
type = response.headers.get("Content-Type");
}

try {
for await (const chunk of streamAsyncIterator(response.body)) {
controller.enqueue(chunk);
}
controller.close();
} catch (err) {
controller.error(err);
}
},
});
}

module.exports = {
createFileOutput,
createReadableStream,
ServerSentEvent,
};
1 change: 1 addition & 0 deletions lib/util.js
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ async function* streamAsyncIterator(stream) {
}

module.exports = {
transform,
transformFileInputs,
validateWebhook,
withAutomaticRetries,
Expand Down
Loading