diff --git a/examples/ort-raw-session/index.ts b/examples/ort-raw-session/index.ts
new file mode 100644
index 000000000..a0b06c414
--- /dev/null
+++ b/examples/ort-raw-session/index.ts
@@ -0,0 +1,119 @@
+/*
+const modelUrl = 'https://huggingface.co/kalleby/hp-to-miles/resolve/main/model.onnx?download=true';
+const modelConfigUrl =
+  'https://huggingface.co/kalleby/hp-to-miles/resolve/main/config.json?download=true';
+
+const model = await Supabase.ai.RawSession.fromUrl(modelUrl);
+const modelConfig = await fetch(modelConfigUrl).then((r) => r.json());
+
+Deno.serve(async (req: Request) => {
+  const params = new URL(req.url).searchParams;
+  const inputValue = parseInt(params.get('value') ?? '0', 10);
+
+  const input = new Supabase.ai.RawTensor('float32', [inputValue], [1, 1])
+    .minMaxNormalize(modelConfig.input.min, modelConfig.input.max);
+
+  const output = await model.run({
+    'dense_dense1_input': input,
+  });
+
+  console.log('output', output);
+
+  const outputTensor = output['dense_Dense4']
+    .minMaxUnnormalize(modelConfig.label.min, modelConfig.label.max);
+
+  return Response.json({ result: outputTensor.data });
+});
+*/
+
+// transformers.js compatible:
+// import { Tensor } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.3.2';
+// const rawTensor = new Supabase.ai.RawTensor('string', urls, [urls.length]);
+// console.log('raw tensor', rawTensor);
+//
+// const tensor = new Tensor(rawTensor);
+// console.log('hf tensor', tensor);
+//
+// hf tensor operations:
+// tensor.min(); tensor.max(); tensor.norm() ...
+
+// const modelUrl =
+//   'https://huggingface.co/pirocheto/phishing-url-detection/resolve/main/model.onnx?download=true';
+
+/*
+const { RawTensor, RawSession } = Supabase.ai;
+
+const model = await RawSession.fromHuggingFace('pirocheto/phishing-url-detection', {
+  path: {
+    template: `{REPO_ID}/resolve/{REVISION}/{MODEL_FILE}?download=true`,
+    modelFile: 'model.onnx',
+  },
+});
+
+console.log('session', model);
+
+Deno.serve(async (_req: Request) => {
+  const urls = [
+    'https://clubedemilhagem.com/home.php',
+    'http://www.medicalnewstoday.com/articles/188939.php',
+    'https://magalu-crediarioluiza.com/Produto_20203/produto.php?sku=1',
+  ];
+
+  const inputs = new RawTensor('string', urls, [urls.length]);
+  console.log('tensor', inputs.data);
+
+  const output = await model.run({ inputs });
+  console.log(output);
+
+  return Response.json({ result: output.probabilities });
+});
+*/
+
+const { RawTensor, RawSession } = Supabase.ai;
+
+const session = await RawSession.fromHuggingFace(
+  "kallebysantos/vehicle-emission",
+  {
+    path: {
+      modelFile: "model.onnx",
+    },
+  },
+);
+
+Deno.serve(async (_req: Request) => {
+  // Sample data; in practice this batch could come from the JSON request body
+  // (see the sketch at the end of this file)
+  const carsBatchInput: Record<string, number>[] = [{
+    "Model_Year": 2021,
+    "Engine_Size": 2.9,
+    "Cylinders": 6,
+    "Fuel_Consumption_in_City": 13.9,
+    "Fuel_Consumption_in_City_Hwy": 10.3,
+    "Fuel_Consumption_comb": 12.3,
+    "Smog_Level": 3,
+  }, {
+    "Model_Year": 2023,
+    "Engine_Size": 2.4,
+    "Cylinders": 4,
+    "Fuel_Consumption_in_City": 9.9,
+    "Fuel_Consumption_in_City_Hwy": 7.0,
+    "Fuel_Consumption_comb": 8.6,
+    "Smog_Level": 3,
+  }];
+
+  // Parse the objects into tensor inputs: one [batch_size, 1] column per model input
+  const inputTensors: Record<string, Supabase.ai.RawTensor<"float32">> = {};
+  session.inputs.forEach((inputKey) => {
+    const values = carsBatchInput.map((item) => item[inputKey]);
+
+    inputTensors[inputKey] = new RawTensor("float32", values, [
+      values.length,
+      1,
+    ]);
+  });
+
+  const { emissions } = await session.run(inputTensors);
+  console.log(emissions);
+  // [ 289.01, 199.53 ]
+
+  return Response.json({ result: emissions });
+});
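+
+// The handler above uses a hard-coded sample batch. A request-driven variant
+// could start with something like (a sketch; assumes the client POSTs a JSON
+// array shaped like `carsBatchInput`):
+//
+//   const carsBatchInput: Record<string, number>[] = await req.json();
+//
+// Illustrative local invocation (function name and port depend on your setup):
+//
+//   curl -s 'http://localhost:9000/ort-raw-session'
+//   // -> { "result": [289.01, 199.53] }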
diff --git a/examples/text-to-audio/index.ts b/examples/text-to-audio/index.ts
new file mode 100644
index 000000000..5014ac170
--- /dev/null
+++ b/examples/text-to-audio/index.ts
@@ -0,0 +1,111 @@
+import { PreTrainedTokenizer } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.3.1";
+
+// import 'phonemize' code from the Kokoro.js repo
+import { phonemize } from "./phonemizer.js";
+
+const { RawTensor, RawSession } = Supabase.ai;
+
+/* NOTE: Reference [original paper](https://arxiv.org/pdf/2306.07691#Model%20Training):
+> All datasets were resampled to 24 kHz to match LibriTTS, and the texts
+> were converted into phonemes using phonemizer
+*/
+const SAMPLE_RATE = 24000; // 24 kHz
+
+/* NOTE: Reference [original paper](https://arxiv.org/pdf/2306.07691#Detailed%20Model%20Architectures):
+> The size of s and c is 256 × 1
+*/
+const STYLE_DIM = 256;
+const MODEL_ID = "onnx-community/Kokoro-82M-ONNX";
+
+// https://huggingface.co/onnx-community/Kokoro-82M-ONNX#samples
+const ALLOWED_VOICES = [
+  "af_bella",
+  "af_nicole",
+  "af_sarah",
+  "af_sky",
+  "am_adam",
+  "am_michael",
+  "bf_emma",
+  "bf_isabella",
+  "bm_george",
+  "bm_lewis",
+];
+
+const session = await RawSession.fromHuggingFace(MODEL_ID);
+
+Deno.serve(async (req) => {
+  const params = new URL(req.url).searchParams;
+  const text = params.get("text") ?? "Hello from Supabase!";
+  const voice = params.get("voice") ?? "af_bella";
+
+  if (!ALLOWED_VOICES.includes(voice)) {
+    return Response.json({
+      error: `invalid voice '${voice}'`,
+      must_be_one_of: ALLOWED_VOICES,
+    }, { status: 400 });
+  }
+
+  const tokenizer = await loadTokenizer();
+  const language = voice.at(0); // 'a'merican | 'b'ritish
+  const phonemes = await phonemize(text, language);
+  const { input_ids } = tokenizer(phonemes, {
+    truncation: true,
+  });
+
+  // Select the voice style based on the number of input tokens
+  const num_tokens = Math.max(
+    input_ids.dims.at(-1) - 2, // without padding
+    0,
+  );
+
+  const voiceStyle = await loadVoiceStyle(voice, num_tokens);
+
+  const { waveform } = await session.run({
+    input_ids,
+    style: voiceStyle,
+    speed: new RawTensor("float32", [1], [1]),
+  });
+
+  // `wav` encoding is done by the rust backend
+  const audio = await waveform.tryEncodeAudio(SAMPLE_RATE);
+
+  return new Response(audio, {
+    headers: {
+      "Content-Type": "audio/wav",
+    },
+  });
+});
+
+async function loadVoiceStyle(voice: string, num_tokens: number) {
+  const voice_url =
+    `https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/voices/${voice}.bin?download=true`;
+
+  console.log("loading voice:", voice_url);
+
+  const voiceBuffer = await fetch(voice_url).then(async (res) =>
+    await res.arrayBuffer()
+  );
+
+  const offset = num_tokens * STYLE_DIM;
+  const voiceData = new Float32Array(voiceBuffer).slice(
+    offset,
+    offset + STYLE_DIM,
+  );
+
+  return new RawTensor("float32", voiceData, [1, STYLE_DIM]);
+}
+
+async function loadTokenizer() {
+  // BUG: `AutoTokenizer.from_pretrained(MODEL_ID)` fails with "invalid 'h', not JSON",
+  // so we fetch the tokenizer assets manually:
+  // const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID);
+
+  const tokenizerData = await fetch(
+    "https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/tokenizer.json?download=true",
+  ).then(async (res) => await res.json());
+
+  const tokenizerConfig = await fetch(
+    "https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/tokenizer_config.json?download=true",
+  ).then(async (res) => await res.json());
+
+  return new PreTrainedTokenizer(tokenizerData, tokenizerConfig);
+}
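+
+// Voice-style lookup, worked example: each voice .bin is a flat f32 buffer,
+// read here as an [N, STYLE_DIM] matrix with one 256-dim style row per input
+// token count. For num_tokens = 12, `loadVoiceStyle` slices elements
+// [12 * 256, 12 * 256 + 256) = [3072, 3328), i.e. row 12, and wraps it as a
+// [1, 256] tensor.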
diff --git a/examples/text-to-audio/phonemizer.js b/examples/text-to-audio/phonemizer.js
new file mode 100644
index 000000000..3a2b8028d
--- /dev/null
+++ b/examples/text-to-audio/phonemizer.js
@@ -0,0 +1,208 @@
+// Source code from https://github.com/hexgrad/kokoro/blob/a09db51873211d76a3e49b058b18873a4e002b81/kokoro.js/src/phonemize.js
+// BUG: Not sure why, but importing it from cdnjs causes a runtime stack overflow
+import { phonemize as espeakng } from "npm:phonemizer@1.2.1";
+
+/**
+ * Helper function to split a string on a regex, but keep the delimiters.
+ * This is required, because the JavaScript `.split()` method does not keep the delimiters,
+ * and wrapping in a capturing group causes issues with existing capturing groups (due to nesting).
+ * @param {string} text The text to split.
+ * @param {RegExp} regex The regex to split on.
+ * @returns {{match: boolean; text: string}[]} The split string.
+ */
+function split(text, regex) {
+  const result = [];
+  let prev = 0;
+  for (const match of text.matchAll(regex)) {
+    const fullMatch = match[0];
+    if (prev < match.index) {
+      result.push({ match: false, text: text.slice(prev, match.index) });
+    }
+    if (fullMatch.length > 0) {
+      result.push({ match: true, text: fullMatch });
+    }
+    prev = match.index + fullMatch.length;
+  }
+  if (prev < text.length) {
+    result.push({ match: false, text: text.slice(prev) });
+  }
+  return result;
+}
+
+/**
+ * Helper function to split numbers into phonetic equivalents
+ * @param {string} match The matched number
+ * @returns {string} The phonetic equivalent
+ */
+function split_num(match) {
+  if (match.includes(".")) {
+    return match;
+  } else if (match.includes(":")) {
+    let [h, m] = match.split(":").map(Number);
+    if (m === 0) {
+      return `${h} o'clock`;
+    } else if (m < 10) {
+      return `${h} oh ${m}`;
+    }
+    return `${h} ${m}`;
+  }
+  let year = parseInt(match.slice(0, 4), 10);
+  if (year < 1100 || year % 1000 < 10) {
+    return match;
+  }
+  let left = match.slice(0, 2);
+  let right = parseInt(match.slice(2, 4), 10);
+  let suffix = match.endsWith("s") ? "s" : "";
+  if (year % 1000 >= 100 && year % 1000 <= 999) {
+    if (right === 0) {
+      return `${left} hundred${suffix}`;
+    } else if (right < 10) {
+      return `${left} oh ${right}${suffix}`;
+    }
+  }
+  return `${left} ${right}${suffix}`;
+}
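+
+// Worked examples (one per branch above):
+//   split_num("5:00")  -> "5 o'clock"
+//   split_num("5:05")  -> "5 oh 5"
+//   split_num("10:30") -> "10 30"
+//   split_num("1900")  -> "19 hundred"
+//   split_num("2022")  -> "20 22"
+//   split_num("2003")  -> "2003"  (year % 1000 < 10, left as-is)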
"" : "s"} and ${d} ${coins}`; +} + +/** + * Helper function to process decimal numbers + * @param {string} match The matched number + * @returns {string} The formatted number + */ +function point_num(match) { + let [a, b] = match.split("."); + return `${a} point ${b.split("").join(" ")}`; +} + +/** + * Normalize text for phonemization + * @param {string} text The text to normalize + * @returns {string} The normalized text + */ +function normalize_text(text) { + return ( + text + // 1. Handle quotes and brackets + .replace(/[‘’]/g, "'") + .replace(/«/g, "“") + .replace(/»/g, "”") + .replace(/[“”]/g, '"') + .replace(/\(/g, "«") + .replace(/\)/g, "»") + // 2. Replace uncommon punctuation marks + .replace(/、/g, ", ") + .replace(/。/g, ". ") + .replace(/!/g, "! ") + .replace(/,/g, ", ") + .replace(/:/g, ": ") + .replace(/;/g, "; ") + .replace(/?/g, "? ") + // 3. Whitespace normalization + .replace(/[^\S \n]/g, " ") + .replace(/ +/, " ") + .replace(/(?<=\n) +(?=\n)/g, "") + // 4. Abbreviations + .replace(/\bD[Rr]\.(?= [A-Z])/g, "Doctor") + .replace(/\b(?:Mr\.|MR\.(?= [A-Z]))/g, "Mister") + .replace(/\b(?:Ms\.|MS\.(?= [A-Z]))/g, "Miss") + .replace(/\b(?:Mrs\.|MRS\.(?= [A-Z]))/g, "Mrs") + .replace(/\betc\.(?! [A-Z])/gi, "etc") + // 5. Normalize casual words + .replace(/\b(y)eah?\b/gi, "$1e'a") + // 5. Handle numbers and currencies + .replace( + /\d*\.\d+|\b\d{4}s?\b|(? m.replace(/\./g, "-")) + .replace(/(?<=[A-Z])\.(?=[A-Z])/gi, "-") + // 8. Strip leading and trailing whitespace + .trim() + ); +} + +/** + * Escapes regular expression special characters from a string by replacing them with their escaped counterparts. + * + * @param {string} string The string to escape. + * @returns {string} The escaped string. + */ +function escapeRegExp(string) { + return string.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"); // $& means the whole matched string +} + +const PUNCTUATION = ';:,.!?¡¿—…"«»“”(){}[]'; +const PUNCTUATION_PATTERN = new RegExp( + `(\\s*[${escapeRegExp(PUNCTUATION)}]+\\s*)+`, + "g", +); + +export async function phonemize(text, language = "a", norm = true) { + // 1. Normalize text + if (norm) { + text = normalize_text(text); + } + + // 2. Split into chunks, to ensure we preserve punctuation + const sections = split(text, PUNCTUATION_PATTERN); + + // 3. Convert each section to phonemes + const lang = language === "a" ? "en-us" : "en"; + const ps = (await Promise.all( + sections.map(async ( + { match, text }, + ) => (match ? text : (await espeakng(text, lang)).join(" "))), + )).join(""); + + // 4. Post-process phonemes + let processed = ps + // https://en.wiktionary.org/wiki/kokoro#English + .replace(/kəkˈoːɹoʊ/g, "kˈoʊkəɹoʊ") + .replace(/kəkˈɔːɹəʊ/g, "kˈəʊkəɹəʊ") + .replace(/ʲ/g, "j") + .replace(/r/g, "ɹ") + .replace(/x/g, "k") + .replace(/ɬ/g, "l") + .replace(/(?<=[a-zɹː])(?=hˈʌndɹɪd)/g, " ") + .replace(/ z(?=[;:,.!?¡¿—…"«»“” ]|$)/g, "z"); + + // 5. 
+
+/**
+ * Escapes regular expression special characters from a string by replacing them with their escaped counterparts.
+ *
+ * @param {string} string The string to escape.
+ * @returns {string} The escaped string.
+ */
+function escapeRegExp(string) {
+  return string.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"); // $& means the whole matched string
+}
+
+const PUNCTUATION = ';:,.!?¡¿—…"«»“”(){}[]';
+const PUNCTUATION_PATTERN = new RegExp(
+  `(\\s*[${escapeRegExp(PUNCTUATION)}]+\\s*)+`,
+  "g",
+);
+
+export async function phonemize(text, language = "a", norm = true) {
+  // 1. Normalize text
+  if (norm) {
+    text = normalize_text(text);
+  }
+
+  // 2. Split into chunks, to ensure we preserve punctuation
+  const sections = split(text, PUNCTUATION_PATTERN);
+
+  // 3. Convert each section to phonemes
+  const lang = language === "a" ? "en-us" : "en";
+  const ps = (await Promise.all(
+    sections.map(async (
+      { match, text },
+    ) => (match ? text : (await espeakng(text, lang)).join(" "))),
+  )).join("");
+
+  // 4. Post-process phonemes
+  let processed = ps
+    // https://en.wiktionary.org/wiki/kokoro#English
+    .replace(/kəkˈoːɹoʊ/g, "kˈoʊkəɹoʊ")
+    .replace(/kəkˈɔːɹəʊ/g, "kˈəʊkəɹəʊ")
+    .replace(/ʲ/g, "j")
+    .replace(/r/g, "ɹ")
+    .replace(/x/g, "k")
+    .replace(/ɬ/g, "l")
+    .replace(/(?<=[a-zɹː])(?=hˈʌndɹɪd)/g, " ")
+    .replace(/ z(?=[;:,.!?¡¿—…"«»“” ]|$)/g, "z");
+
+  // 5. Additional post-processing for American English
+  if (language === "a") {
+    processed = processed.replace(/(?<=nˈaɪn)ti(?!ː)/g, "di");
+  }
+  return processed.trim();
+}
diff --git a/ext/ai/js/ai.js b/ext/ai/js/ai.js
index c2a306927..330095947 100644
--- a/ext/ai/js/ai.js
+++ b/ext/ai/js/ai.js
@@ -1,4 +1,5 @@
 import "ext:ai/onnxruntime/onnx.js";
+import InferenceAPI from "ext:ai/onnxruntime/inference_api.js";
 import EventSourceStream from "ext:ai/util/event_source_stream.mjs";
 
 const core = globalThis.Deno.core;
@@ -258,6 +259,7 @@ const MAIN_WORKER_API = {
 
 const USER_WORKER_API = {
   Session,
+  ...InferenceAPI,
 };
 
 export { MAIN_WORKER_API, USER_WORKER_API };
diff --git a/ext/ai/js/onnxruntime/inference_api.js b/ext/ai/js/onnxruntime/inference_api.js
new file mode 100644
index 000000000..f9b85b003
--- /dev/null
+++ b/ext/ai/js/onnxruntime/inference_api.js
@@ -0,0 +1,125 @@
+const core = globalThis.Deno.core;
+import { InferenceSession, Tensor } from "ext:ai/onnxruntime/onnx.js";
+
+const DEFAULT_HUGGING_FACE_OPTIONS = {
+  hostname: "https://huggingface.co",
+  path: {
+    template: "{REPO_ID}/resolve/{REVISION}/onnx/{MODEL_FILE}?download=true",
+    revision: "main",
+    modelFile: "model_quantized.onnx",
+  },
+};
+
+const DEFAULT_STORAGE_OPTIONS = () => ({
+  hostname: Deno.env.get("SUPABASE_URL"),
+  mode: {
+    authorization: Deno.env.get("SUPABASE_SERVICE_ROLE_KEY"),
+  },
+});
+
+/**
+ * A user-friendly API for the onnx backend
+ */
+class UserInferenceSession {
+  inner;
+
+  id;
+  inputs;
+  outputs;
+
+  constructor(session) {
+    this.inner = session;
+
+    this.id = session.sessionId;
+    this.inputs = session.inputNames;
+    this.outputs = session.outputNames;
+  }
+
+  static async fromUrl(modelUrl, authorization) {
+    if (modelUrl instanceof URL) {
+      modelUrl = modelUrl.toString();
+    }
+
+    const encoder = new TextEncoder();
+    const modelUrlBuffer = encoder.encode(modelUrl);
+    const session = await InferenceSession.fromRequest(
+      modelUrlBuffer,
+      authorization,
+    );
+
+    return new UserInferenceSession(session);
+  }
+
+  static async fromHuggingFace(repoId, opts = {}) {
+    const hostname = opts?.hostname ?? DEFAULT_HUGGING_FACE_OPTIONS.hostname;
+    const pathOpts = {
+      ...DEFAULT_HUGGING_FACE_OPTIONS.path,
+      ...opts?.path,
+    };
+
+    const modelPath = pathOpts.template
+      .replaceAll("{REPO_ID}", repoId)
+      .replaceAll("{REVISION}", pathOpts.revision)
+      .replaceAll("{MODEL_FILE}", pathOpts.modelFile);
+
+    if (!URL.canParse(modelPath, hostname)) {
+      throw Error(
+        `[Invalid URL] Couldn't parse the model path: "${modelPath}"`,
+      );
+    }
+
+    return await UserInferenceSession.fromUrl(new URL(modelPath, hostname));
+  }
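+
+  // Template expansion, worked example (all defaults):
+  //   fromHuggingFace("Supabase/gte-small")
+  //   -> https://huggingface.co/Supabase/gte-small/resolve/main/onnx/model_quantized.onnx?download=true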
+
+  static async fromStorage(modelPath, opts = {}) {
+    const defaultOpts = DEFAULT_STORAGE_OPTIONS();
+    const hostname = opts?.hostname ?? defaultOpts.hostname;
+    const mode = opts?.mode ?? defaultOpts.mode;
+
+    const assetPath = mode === "public"
+      ? `public/${modelPath}`
+      : `authenticated/${modelPath}`;
+
+    const storageUrl = `/storage/v1/object/${assetPath}`;
+
+    if (!URL.canParse(storageUrl, hostname)) {
+      throw Error(
+        `[Invalid URL] Couldn't parse the model path: "${storageUrl}"`,
+      );
+    }
+
+    return await UserInferenceSession.fromUrl(
+      new URL(storageUrl, hostname),
+      mode?.authorization,
+    );
+  }
+
+  async run(inputs) {
+    const outputs = await core.ops.op_ai_ort_run_session(this.id, inputs);
+
+    // Re-wrap the raw outputs as tensors
+    for (const key in outputs) {
+      if (Object.hasOwn(outputs, key)) {
+        const { type, data, dims } = outputs[key];
+
+        outputs[key] = new UserTensor(type, data.buffer, dims);
+      }
+    }
+
+    return outputs;
+  }
+}
+
+class UserTensor extends Tensor {
+  constructor(type, data, dim) {
+    super(type, data, dim);
+  }
+
+  async tryEncodeAudio(sampleRate) {
+    return await core.ops.op_ai_ort_encode_tensor_audio(this.data, sampleRate);
+  }
+}
+
+export default {
+  RawSession: UserInferenceSession,
+  RawTensor: UserTensor,
+};
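+
+// Usage sketch for `fromStorage` (illustrative: assumes a model was uploaded to
+// a private Storage bucket named "models"; with the default mode the request is
+// authorized with the worker's service-role key):
+//
+//   const { RawSession } = Supabase.ai;
+//   const session = await RawSession.fromStorage("models/model.onnx");
+//   // -> fetches `${SUPABASE_URL}/storage/v1/object/authenticated/models/model.onnx`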
diff --git a/ext/ai/js/onnxruntime/onnx.js b/ext/ai/js/onnxruntime/onnx.js
index 2e0e4548a..0f1138b1d 100644
--- a/ext/ai/js/onnxruntime/onnx.js
+++ b/ext/ai/js/onnxruntime/onnx.js
@@ -31,7 +31,7 @@ class TensorProxy {
   }
 }
 
-class Tensor {
+export class Tensor {
   /** @type {DataType} Type of the tensor. */
   type;
 
@@ -67,7 +67,7 @@ class Tensor {
   }
 }
 
-class InferenceSession {
+export class InferenceSession {
   sessionId;
   inputNames;
   outputNames;
@@ -86,6 +86,15 @@ class InferenceSession {
     return new InferenceSession(id, inputs, outputs);
   }
 
+  static async fromRequest(modelUrl, authorization) {
+    const [id, inputs, outputs] = await core.ops.op_ai_ort_init_session(
+      modelUrl,
+      authorization,
+    );
+
+    return new InferenceSession(id, inputs, outputs);
+  }
+
   async run(inputs) {
     const sessionInputs = {};
diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs
index 224b0450f..266e92058 100644
--- a/ext/ai/lib.rs
+++ b/ext/ai/lib.rs
@@ -47,6 +47,7 @@ deno_core::extension!(
     op_ai_try_cleanup_unused_session,
     op_ai_ort_init_session,
     op_ai_ort_run_session,
+    op_ai_ort_encode_tensor_audio,
   ],
   esm_entry_point = "ext:ai/ai.js",
   esm = [
@@ -55,7 +56,8 @@ deno_core::extension!(
     "util/event_stream_parser.mjs",
     "util/event_source_stream.mjs",
     "onnxruntime/onnx.js",
-    "onnxruntime/cache_adapter.js"
+    "onnxruntime/cache_adapter.js",
+    "onnxruntime/inference_api.js"
   ]
 );
 
@@ -116,8 +118,11 @@ async fn init_gte(state: Rc<RefCell<OpState>>) -> Result<(), Error> {
       let handle = handle.clone();
       move || {
         handle.block_on(async move {
-          load_session_from_url(Url::parse(consts::GTE_SMALL_MODEL_URL).unwrap())
-            .await
+          load_session_from_url(
+            Url::parse(consts::GTE_SMALL_MODEL_URL).unwrap(),
+            None,
+          )
+          .await
         })
       }
     })
@@ -141,6 +146,7 @@
     "tokenizer",
     Url::parse(consts::GTE_SMALL_TOKENIZER_URL).unwrap(),
     None,
+    None,
   )
   .map_err(AnyError::from)
   .and_then(|it| {
diff --git a/ext/ai/onnxruntime/mod.rs b/ext/ai/onnxruntime/mod.rs
index b1dd136f8..c37de4f92 100644
--- a/ext/ai/onnxruntime/mod.rs
+++ b/ext/ai/onnxruntime/mod.rs
@@ -20,6 +20,7 @@
 use deno_core::op2;
 use deno_core::JsBuffer;
 use deno_core::JsRuntime;
 use deno_core::OpState;
+use deno_core::ToJsBuffer;
 use deno_core::V8CrossThreadTaskSpawner;
 
 use model::Model;
@@ -29,6 +30,7 @@
 use reqwest::Url;
 use tensor::JsTensor;
 use tensor::ToJsTensor;
 use tokio::sync::oneshot;
+use tokio_util::bytes::BufMut;
 use tracing::debug;
 use tracing::trace;
 
@@ -37,6 +39,8 @@
 pub async fn op_ai_ort_init_session(
   state: Rc<RefCell<OpState>>,
   #[buffer] model_bytes: JsBuffer,
+  // TODO: consider modelling the URL/authorization pair as an enum payload
+  #[string] req_authorization: Option<String>,
 ) -> Result<SessionInfo, Error> {
   let model_bytes = model_bytes.into_parts().to_boxed_slice();
   let model_bytes_or_url = str::from_utf8(&model_bytes)
@@ -46,7 +50,7 @@
   let model = match model_bytes_or_url {
     Ok(model_url) => {
       trace!(kind = "url", url = %model_url);
-      Model::from_url(model_url).await?
+      Model::from_url(model_url, req_authorization).await?
     }
     Err(_) => {
       trace!(kind = "bytes", len = model_bytes.len());
@@ -133,3 +137,44 @@
   rx.await.context("failed to get inference result")?
 }
+
+// REF: https://youtu.be/qqjvB_VxMRM?si=7lnYdgbhOC_K7P6S
+// REF: http://soundfile.sapp.org/doc/WaveFormat/
+#[op2]
+#[serde]
+pub fn op_ai_ort_encode_tensor_audio(
+  #[serde] tensor: JsBuffer,
+  sample_rate: u32,
+) -> Result<ToJsBuffer, Error> {
+  // Borrow the tensor's raw bytes; `extend_from_slice` below does the copy
+  let data_buffer = tensor.iter().as_slice();
+
+  let sample_size = 4; // size of one f32 sample, in bytes
+  // `data_buffer` is already a byte slice, so its length is the data chunk size
+  let data_chunk_size = data_buffer.len() as u32;
+  let total_riff_size = 36 + data_chunk_size; // 36 = header bytes between the RIFF size field and the data chunk
+
+  let mut audio_wav = Vec::new();
+
+  // RIFF HEADER
+  audio_wav.extend_from_slice(b"RIFF");
+  audio_wav.put_u32_le(total_riff_size);
+  audio_wav.extend_from_slice(b"WAVE");
+
+  // FORMAT CHUNK
+  audio_wav.extend_from_slice(b"fmt "); // chunk ID is "fmt" plus a trailing space
+  audio_wav.put_u32_le(16); // format chunk size
+  audio_wav.put_u16_le(3); // audio format 3 = IEEE float
+  audio_wav.put_u16_le(1); // number of channels (mono)
+  audio_wav.put_u32_le(sample_rate);
+  audio_wav.put_u32_le(sample_rate * sample_size); // byte rate
+  audio_wav.put_u16_le(sample_size as u16); // block align
+  audio_wav.put_u16_le(32); // f32 bits per sample
+
+  // DATA CHUNK
+  audio_wav.extend_from_slice(b"data"); // chunk ID
+  audio_wav.put_u32_le(data_chunk_size);
+
+  audio_wav.extend_from_slice(data_buffer);
+
+  Ok(ToJsBuffer::from(audio_wav))
+}
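+
+// Resulting buffer layout (44-byte canonical WAV header, mono IEEE-float):
+//    0..12  "RIFF" | riff size (36 + data bytes) | "WAVE"
+//   12..36  "fmt " | 16 | format = 3 (f32) | channels = 1 | sample rate
+//           | byte rate | block align = 4 | bits per sample = 32
+//   36..44  "data" | data chunk size, followed by the raw f32 samples
+// e.g. for 24 kHz mono f32 audio, byte rate = 24_000 * 4 = 96_000.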
diff --git a/ext/ai/onnxruntime/model.rs b/ext/ai/onnxruntime/model.rs
index f3a17e6e4..a8907b479 100644
--- a/ext/ai/onnxruntime/model.rs
+++ b/ext/ai/onnxruntime/model.rs
@@ -71,8 +71,13 @@ impl Model {
       .map(Self::new)
   }
 
-  pub async fn from_url(model_url: Url) -> Result<Self, Error> {
-    load_session_from_url(model_url).await.map(Self::new)
+  pub async fn from_url(
+    model_url: Url,
+    authorization: Option<String>,
+  ) -> Result<Self, Error> {
+    load_session_from_url(model_url, authorization)
+      .await
+      .map(Self::new)
   }
 
   pub async fn from_bytes(model_bytes: &[u8]) -> Result<Self, Error> {
diff --git a/ext/ai/onnxruntime/session.rs b/ext/ai/onnxruntime/session.rs
index 6205e8550..e407ee948 100644
--- a/ext/ai/onnxruntime/session.rs
+++ b/ext/ai/onnxruntime/session.rs
@@ -155,6 +155,7 @@ pub(crate) async fn load_session_from_bytes(
 #[instrument(level = "debug", fields(%model_url), err)]
 pub(crate) async fn load_session_from_url(
   model_url: Url,
+  authorization: Option<String>,
 ) -> Result<SessionWithId, Error> {
   let session_id = fxhash::hash(model_url.as_str()).to_string();
 
@@ -169,6 +170,7 @@
     "model",
     model_url,
     Some(session_id.to_string()),
+    authorization,
   )
   .await?;
diff --git a/ext/ai/utils.rs b/ext/ai/utils.rs
index f9e188ba3..f57eb0f16 100644
--- a/ext/ai/utils.rs
+++ b/ext/ai/utils.rs
@@ -20,6 +20,7 @@ pub async fn fetch_and_cache_from_url(
   kind: &'static str,
   url: Url,
   cache_id: Option<String>,
+  authorization: Option<String>,
 ) -> Result<PathBuf, Error> {
   let cache_id = cache_id.unwrap_or(fxhash::hash(url.as_str()).to_string());
   let download_dir = std::env::var("EXT_AI_CACHE_DIR")
@@ -91,13 +92,26 @@ pub async fn fetch_and_cache_from_url(
   use reqwest::*;
 
+  let mut headers = header::HeaderMap::new();
+
+  if let Some(authorization) = authorization {
+    let mut authorization =
+      header::HeaderValue::from_str(authorization.as_str())?;
+    authorization.set_sensitive(true);
+
+    headers.insert(header::AUTHORIZATION, authorization);
+  };
+
   let resp = Client::builder()
     .http1_only()
+    .default_headers(headers)
     .build()
     .context("failed to create http client")?
     .get(url.clone())
     .send()
     .await
+    .context("failed to download")?
+    .error_for_status()
     .context("failed to download")?;
 
   let file = tokio::fs::File::create(&filepath)
diff --git a/types/global.d.ts b/types/global.d.ts
index 7810e23a9..48aa5a302 100644
--- a/types/global.d.ts
+++ b/types/global.d.ts
@@ -12,6 +12,14 @@ declare interface WindowEventMap {
   "drain": Event;
 }
 
+type DecoratorType = "tc39" | "typescript" | "typescript_with_metadata";
+
+interface JsxImportBaseConfig {
+  defaultSpecifier?: string | null;
+  module?: string | null;
+  baseUrl?: string | null;
+}
+
 // TODO(Nyannyacha): These two type defs will be provided later.
 
 // deno-lint-ignore no-explicit-any
@@ -179,6 +187,25 @@ declare namespace Supabase {
     signal?: AbortSignal;
   }
 
+  export type TensorDataTypeMap = {
+    float32: Float32Array | number[];
+    float64: Float64Array | number[];
+    string: string[];
+    int8: Int8Array | number[];
+    uint8: Uint8Array | number[];
+    int16: Int16Array | number[];
+    uint16: Uint16Array | number[];
+    int32: Int32Array | number[];
+    uint32: Uint32Array | number[];
+    int64: BigInt64Array | number[];
+    uint64: BigUint64Array | number[];
+    bool: Uint8Array | number[];
+  };
+
+  export type TensorMap = {
+    [key: string]: RawTensor<keyof TensorDataTypeMap>;
+  };
+
   export class Session {
     /**
      * Create a new model session using given model
     */
@@ -198,6 +225,150 @@ declare namespace Supabase {
       modelOptions?: ModelOptions,
     ): unknown;
   }
+
+  /** Provides a user-friendly interface to the low-level *onnx backend API*.
+   * A `RawSession` can execute any *onnx* model, but we only recommend it for `tabular` or *self-made* models,
+   * where you need more control over model execution and pre/post-processing.
+   * Consider a high-level implementation like `@huggingface/transformers.js` for generic tasks like `nlp`, `computer-vision` or `audio`.
+   *
+   * **Example:**
+   * ```typescript
+   * const session = await RawSession.fromHuggingFace('Supabase/gte-small');
+   * // const session = await RawSession.fromUrl("https://example.com/model.onnx");
+   *
+   * // Prepare the input tensors
+   * const inputs = {
+   *   input1: new RawTensor("float32", [1.0, 2.0, 3.0], [3]),
+   *   input2: new RawTensor("float32", [4.0, 5.0, 6.0], [3]),
+   * };
+   *
+   * // Run the model
+   * const outputs = await session.run(inputs);
+   *
+   * console.log(outputs.output1); // Output tensor
+   * ```
+   */
+  export class RawSession {
+    /** The underlying session's ID.
+     * Session IDs are unique per loaded model, so constructing the same session twice yields the same ID.
+     */
+    id: string;
+
+    /** A list of all input keys the model expects. */
+    inputs: string[];
+
+    /** A list of all output keys the model will produce. */
+    outputs: string[];
+
+    /** Loads an ONNX model session from a source URL.
+     * Sessions are loaded once, then stay warm across the worker's requests.
+     */
+    static fromUrl(source: string | URL): Promise<RawSession>;
+
+    /** Loads an ONNX model session from a **HuggingFace** repository.
+     * Sessions are loaded once, then stay warm across the worker's requests.
+     */
+    static fromHuggingFace(repoId: string, opts?: {
+      /**
+       * @default 'https://huggingface.co'
+       */
+      hostname?: string | URL;
+      path?: {
+        /**
+         * @default '{REPO_ID}/resolve/{REVISION}/onnx/{MODEL_FILE}?download=true'
+         */
+        template?: string;
+        /**
+         * @default 'main'
+         */
+        revision?: string;
+        /**
+         * @default 'model_quantized.onnx'
+         */
+        modelFile?: string;
+      };
+    }): Promise<RawSession>;
+
+    /** Loads an ONNX model session from **Storage**.
+     * Sessions are loaded once, then stay warm across the worker's requests.
+     */
+    static fromStorage(modelPath: string, opts?: {
+      /**
+       * @default Deno.env.get("SUPABASE_URL")
+       */
+      hostname?: string | URL;
+      mode?: "public" | {
+        authorization: string;
+      };
+    }): Promise<RawSession>;
+
+    /** Runs the current session with the given inputs.
+     * Use the `inputs` and `outputs` properties to discover the required inputs and the expected results of the model session.
+     *
+     * @param inputs The input tensors required by the model.
+     * @returns The output tensors generated by the model.
+     *
+     * @example
+     * ```typescript
+     * const session = await RawSession.fromUrl("https://example.com/model.onnx");
+     *
+     * // Prepare the input tensors
+     * const inputs = {
+     *   input1: new RawTensor("float32", [1.0, 2.0, 3.0], [3]),
+     *   input2: new RawTensor("float32", [4.0, 5.0, 6.0], [3]),
+     * };
+     *
+     * // Run the model
+     * const outputs = await session.run(inputs);
+     *
+     * console.log(outputs.output1); // Output tensor
+     * ```
+     */
+    run(inputs: TensorMap): Promise<TensorMap>;
+  }
+
+  /** A low-level representation of a model input/output.
+   * Supabase's `RawTensor` is fully compatible with `@huggingface/transformers.js`'s `Tensor`,
+   * so you can use its high-level API to apply common operations like `sum()`, `min()`, `max()`, `normalize()`, etc.
+   *
+   * **Example: Generating embeddings from scratch**
+   * ```typescript
+   * import { Tensor as HFTensor } from "@huggingface/transformers.js";
+   * const { RawTensor, RawSession } = Supabase.ai;
+   *
+   * const session = await RawSession.fromHuggingFace('Supabase/gte-small');
+   *
+   * // Example only; in a real 'feature-extraction' pipeline these tensors come from the tokenizer step.
+   * const inputs = {
+   *   input_ids: new RawTensor('float32', [...], [n, 2]),
+   *   attention_mask: new RawTensor('float32', [...], [n, 2]),
+   *   token_type_ids: new RawTensor('float32', [...], [n, 2])
+   * };
+   *
+   * const { last_hidden_state } = await session.run(inputs);
+   *
+   * // Using `transformers.js` APIs
+   * const hfTensor = HFTensor.mean_pooling(last_hidden_state, inputs.attention_mask).normalize();
+   *
+   * return hfTensor.tolist();
+   * ```
+   */
+  export class RawTensor<T extends keyof TensorDataTypeMap> {
+    /** Type of the tensor. */
+    type: T;
+
+    /** The data stored in the tensor. */
+    data: TensorDataTypeMap[T];
+
+    /** Dimensions of the tensor. */
+    dims: number[];
+
+    /** The total number of elements in the tensor. */
+    size: number;
+
+    constructor(type: T, data: TensorDataTypeMap[T], dims: number[]);
+
+    /** Encodes the tensor's data as a WAV audio buffer (done by the rust backend). */
+    tryEncodeAudio(sampleRate: number): Promise<Uint8Array>;
+  }
 }
}