From e119a452d4204fc74b89d22e85abbab1298ae6d3 Mon Sep 17 00:00:00 2001 From: "morten.lund@maskon.no" Date: Thu, 10 Jul 2025 15:22:18 +0200 Subject: [PATCH 1/7] Start to rewrite --- mix.exs | 6 +- native/ortex/Cargo.lock | 758 +++++++++++++++++++++++++++++-------- native/ortex/Cargo.toml | 10 +- native/ortex/src/lib.rs | 12 +- native/ortex/src/model.rs | 20 +- native/ortex/src/tensor.rs | 62 +-- native/ortex/src/utils.rs | 37 +- 7 files changed, 682 insertions(+), 223 deletions(-) diff --git a/mix.exs b/mix.exs index 0ff357e..2d62309 100644 --- a/mix.exs +++ b/mix.exs @@ -34,9 +34,9 @@ defmodule Ortex.MixProject do {:rustler, "~> 0.27"}, {:nx, "~> 0.6"}, {:tokenizers, "~> 0.4", only: :dev}, - {:ex_doc, "0.29.4", only: :dev, runtime: false}, - {:exla, "~> 0.6", only: :dev}, - {:torchx, "~> 0.6", only: :dev} + {:ex_doc, "0.29.4", only: :dev, runtime: false} + # {:exla, "~> 0.6", only: :dev}, + # {:torchx, "~> 0.6", only: :dev} ] end diff --git a/native/ortex/Cargo.lock b/native/ortex/Cargo.lock index cd74a84..cd18047 100644 --- a/native/ortex/Cargo.lock +++ b/native/ortex/Cargo.lock @@ -9,25 +9,68 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] -name = "aho-corasick" -version = "0.7.20" +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "aws-lc-rs" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" +checksum = "93fcc8f365936c834db5514fc45aee5b1202d677e6b40e48468aaaa8183ca8c7" dependencies = [ - "memchr", + "aws-lc-sys", + "zeroize", ] [[package]] -name = "autocfg" -version = "1.1.0" +name = "aws-lc-sys" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +checksum = "61b1d86e7705efe1be1b569bab41d4fa1e14e220b60a160f78de2db687add079" +dependencies = [ + "bindgen", + "cc", + "cmake", + "dunce", + "fs_extra", +] [[package]] name = "base64" -version = "0.21.7" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "base64ct" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" + +[[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags 2.4.2", + "cexpr", + "clang-sys", + "itertools", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn", + "which", +] [[package]] name = "bitflags" @@ -56,13 +99,30 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "bytes" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" + [[package]] name = "cc" -version = "1.0.83" +version = "1.2.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +checksum = "5c1599538de2394445747c8cf7935946e3cc27e9625f889d979bfb2aaf569362" dependencies = [ + "jobserver", "libc", + "shlex", +] + +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", ] [[package]] @@ -71,6 +131,42 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + [[package]] name = "cpufeatures" version = "0.2.12" @@ -105,6 +201,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "pem-rfc7468", + "zeroize", +] + [[package]] name = "digest" version = "0.10.7" @@ -115,6 +221,18 @@ dependencies = [ "crypto-common", ] +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + [[package]] name = "errno" version = "0.3.8" @@ -122,9 +240,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.52.0", ] +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "filetime" version = "0.2.23" @@ -134,7 +258,7 @@ dependencies = [ "cfg-if", "libc", "redox_syscall", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -148,14 +272,32 @@ dependencies = [ ] [[package]] -name = "form_urlencoded" -version = "1.2.1" +name = "fnv" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" dependencies = [ - "percent-encoding", + "foreign-types-shared", ] +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "generic-array" version = "0.14.7" @@ -177,6 +319,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" + [[package]] name = "half" version = "2.3.1" @@ -189,18 +337,67 @@ dependencies = [ [[package]] name = "heck" -version = "0.4.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] -name = "idna" -version = "0.5.0" +name = "home" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "http" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "inventory" +version = "0.3.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab08d7cd2c5897f2c949e5383ea7c7db03fb19130ffcfbf7eda795137ae3cb83" +dependencies = [ + "rustversion", +] + +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "jobserver" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" dependencies = [ - "unicode-bidi", - "unicode-normalization", + "libc", ] [[package]] @@ -209,12 +406,28 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "libc" version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +[[package]] +name = "libloading" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" +dependencies = [ + "cfg-if", + "windows-targets 0.53.2", +] + [[package]] name = "linux-raw-sys" version = "0.4.13" @@ -223,12 +436,9 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "log" -version = "0.4.17" +version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" -dependencies = [ - "cfg-if", -] +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" [[package]] name = "matchers" @@ -254,6 +464,12 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.7.2" @@ -263,6 +479,23 @@ dependencies = [ "adler", ] +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "ndarray" version = "0.16.1" @@ -278,6 +511,16 @@ dependencies = [ "rawpointer", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -322,23 +565,67 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "openssl" +version = "0.10.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8505734d46c8ab1e19a1dce3aef597ad87dcb4c37e7188231769bd6bd51cebf8" +dependencies = [ + "bitflags 2.4.2", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90096e2e47630d78b7d1c20952dc621f957103f8bc2c8359ec81290d75238571" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "ort" -version = "2.0.0-rc.8" +version = "2.0.0-rc.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11826e6118cc42fea0cb2b102f7d006c1bb339cb167f8badb5fb568616438234" +checksum = "1fa7e49bd669d32d7bc2a15ec540a527e7764aec722a45467814005725bcd721" dependencies = [ - "half", "ndarray", "ort-sys", + "smallvec 2.0.0-alpha.10", "tracing", ] [[package]] name = "ort-sys" -version = "2.0.0-rc.8" +version = "2.0.0-rc.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4780a8b8681e653b2bed85c7f0e2c6e8547224c3e983e5ad27bf0457e012407" +checksum = "e2aba9f5c7c479925205799216e7e5d07cc1d4fa76ea8058c60a9a30f6a4e890" dependencies = [ "flate2", "pkg-config", @@ -366,6 +653,15 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -399,20 +695,30 @@ dependencies = [ "portable-atomic", ] +[[package]] +name = "prettyplease" +version = "0.2.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "061c1221631e079b26479d25bbf2275bfe5917ae8419cd7e34f13bfc2aa7539a" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name = "proc-macro2" -version = "1.0.78" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.27" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f4f29d145265ec1c483c7c654450edde0bfe043d3938d6972630663356d9500" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" dependencies = [ "proc-macro2", ] @@ -438,8 +744,6 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" dependencies = [ - "aho-corasick", - "memchr", "regex-syntax", ] @@ -452,6 +756,12 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "regex-lite" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" + [[package]] name = "regex-syntax" version = "0.6.28" @@ -470,9 +780,15 @@ dependencies = [ "libc", "spin", "untrusted", - "windows-sys", + "windows-sys 0.52.0", ] +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustix" version = "0.38.31" @@ -483,73 +799,117 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] name = "rustler" -version = "0.29.1" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0884cb623b9f43d3e2c51f9071c5e96a5acf3e6e6007866812884ff0cb983f1e" +checksum = "e3fe55230a9c379733dd38ee67d4072fa5c558b2e22b76b0e7f924390456e003" dependencies = [ - "lazy_static", + "inventory", + "libloading", + "regex-lite", "rustler_codegen", - "rustler_sys", ] [[package]] name = "rustler_codegen" -version = "0.29.1" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50e277af754f2560cf4c4ebedb68c1a735292fb354505c6133e47ec406e699cf" +checksum = "eb3b8de901ae61418e2036245d28e41ef58080d04f40b68430471ae36a4e84ed" dependencies = [ "heck", + "inventory", "proc-macro2", "quote", "syn", ] -[[package]] -name = "rustler_sys" -version = "2.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff76ba8524729d7c9db2b3e80f2269d1fdef39b5a60624c33fd794797e69b558" -dependencies = [ - "regex", - "unreachable", -] - [[package]] name = "rustls" -version = "0.22.4" +version = "0.23.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" +checksum = "7160e3e10bf4535308537f3c4e1641468cd0e485175d6163087c0393c7d46643" dependencies = [ + "aws-lc-rs", "log", - "ring", + "once_cell", "rustls-pki-types", "rustls-webpki", "subtle", "zeroize", ] +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "rustls-pki-types" -version = "1.3.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "048a63e5b3ac996d78d402940b5fa47973d2d080c6c6fffa1d0f19c4445310b7" +checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +dependencies = [ + "zeroize", +] [[package]] name = "rustls-webpki" -version = "0.102.2" +version = "0.103.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610" +checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" + +[[package]] +name = "schannel" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "security-framework" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "770452e37cad93e0a50d5abc3990d2bc351c36d0328f86cefec2f2fb206eaef6" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "sha2" version = "0.10.8" @@ -570,12 +930,24 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "smallvec" version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" +[[package]] +name = "smallvec" +version = "2.0.0-alpha.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d44cfb396c3caf6fbfd0ab422af02631b69ddd96d2eff0b0f0724f9024051b" + [[package]] name = "socks" version = "0.3.4" @@ -601,9 +973,9 @@ checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" [[package]] name = "syn" -version = "2.0.16" +version = "2.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6f671d4b5ffdb8eadec19c0ae67fe2639df8684bd7bc4b83d986b8db549cf01" +checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" dependencies = [ "proc-macro2", "quote", @@ -622,30 +994,28 @@ dependencies = [ ] [[package]] -name = "thread_local" -version = "1.1.7" +name = "tempfile" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" +checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" dependencies = [ "cfg-if", + "fastrand", "once_cell", + "rustix", + "windows-sys 0.59.0", ] [[package]] -name = "tinyvec" -version = "1.6.0" +name = "thread_local" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" dependencies = [ - "tinyvec_macros", + "cfg-if", + "once_cell", ] -[[package]] -name = "tinyvec_macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" - [[package]] name = "tracing" version = "0.1.37" @@ -689,7 +1059,7 @@ dependencies = [ "once_cell", "regex", "sharded-slab", - "smallvec", + "smallvec 1.13.1", "thread_local", "tracing", "tracing-core", @@ -702,36 +1072,12 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" -[[package]] -name = "unicode-bidi" -version = "0.3.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" - [[package]] name = "unicode-ident" version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" -[[package]] -name = "unicode-normalization" -version = "0.1.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" -dependencies = [ - "tinyvec", -] - -[[package]] -name = "unreachable" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "382810877fe448991dfc7f0dd6e3ae5d58088fd0ea5e35189655f84e6814fa56" -dependencies = [ - "void", -] - [[package]] name = "untrusted" version = "0.9.0" @@ -740,32 +1086,41 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "ureq" -version = "2.9.6" +version = "3.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11f214ce18d8b2cbe84ed3aa6486ed3f5b285cf8d8fbdbce9f3f767a724adc35" +checksum = "9f0fde9bc91026e381155f8c67cb354bcd35260b2f4a29bcc84639f762760c39" dependencies = [ "base64", + "der", "log", - "once_cell", - "rustls", + "native-tls", + "percent-encoding", + "rustls-pemfile", "rustls-pki-types", - "rustls-webpki", "socks", - "url", - "webpki-roots", + "ureq-proto", + "utf-8", + "webpki-root-certs 0.26.11", ] [[package]] -name = "url" -version = "2.5.0" +name = "ureq-proto" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +checksum = "59db78ad1923f2b1be62b6da81fe80b173605ca0d57f85da2e005382adf693f7" dependencies = [ - "form_urlencoded", - "idna", - "percent-encoding", + "base64", + "http", + "httparse", + "log", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "valuable" version = "0.1.0" @@ -773,16 +1128,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" [[package]] -name = "version_check" -version = "0.9.4" +name = "vcpkg" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] -name = "void" -version = "1.0.2" +name = "version_check" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "wasi" @@ -791,14 +1146,35 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] -name = "webpki-roots" -version = "0.26.1" +name = "webpki-root-certs" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75c7f0ef91146ebfb530314f5f1d24528d7f0767efbfd31dce919275413e393e" +dependencies = [ + "webpki-root-certs 1.0.1", +] + +[[package]] +name = "webpki-root-certs" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009" +checksum = "86138b15b2b7d561bc4469e77027b8dd005a43dc502e9031d1f5afc8ce1f280e" dependencies = [ "rustls-pki-types", ] +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix", +] + [[package]] name = "winapi" version = "0.3.9" @@ -827,65 +1203,145 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", ] [[package]] name = "windows-targets" -version = "0.52.0" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c66f69fcc9ce11da9966ddb31a40968cad001c5bedeb5c2b82ede4253ab48aef" +dependencies = [ + "windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + "windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", ] [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.0" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" [[package]] name = "windows_aarch64_msvc" -version = "0.52.0" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" [[package]] name = "windows_i686_gnu" -version = "0.52.0" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" [[package]] name = "windows_i686_msvc" -version = "0.52.0" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" [[package]] name = "windows_x86_64_gnu" -version = "0.52.0" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.0" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" [[package]] name = "windows_x86_64_msvc" -version = "0.52.0" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" [[package]] name = "xattr" diff --git a/native/ortex/Cargo.toml b/native/ortex/Cargo.toml index 4052470..bc5d409 100644 --- a/native/ortex/Cargo.toml +++ b/native/ortex/Cargo.toml @@ -9,14 +9,18 @@ name = "ortex" path = "src/lib.rs" crate-type = ["cdylib"] +[workspace] +resolver = "2" + [dependencies] -rustler = "0.29.0" -ort = { version = "2.0.0-rc.8" } +rustler = "0.36.2" +ort-sys = { version = "=2.0.0-rc.10", default-features = false } +ort = { version = "2.0.0-rc.10" } ndarray = "0.16.1" half = "2.2.1" tracing-subscriber = { version = "0.3", features = [ "env-filter", "fmt" ] } num-traits = "0.2.15" -rustls = "0.22.4" +rustls = "0.23.28" [features] # ONNXRuntime Execution providers diff --git a/native/ortex/src/lib.rs b/native/ortex/src/lib.rs index 0523105..abdab1b 100644 --- a/native/ortex/src/lib.rs +++ b/native/ortex/src/lib.rs @@ -12,7 +12,7 @@ mod utils; use model::OrtexModel; use tensor::OrtexTensor; -use rustler::resource::ResourceArc; +use rustler::ResourceArc; use rustler::types::Binary; use rustler::{Atom, Env, NifResult, Term}; @@ -106,16 +106,6 @@ pub fn concatenate<'a>( rustler::init!( "Elixir.Ortex.Native", - [ - run, - init, - from_binary, - to_binary, - show_session, - slice, - reshape, - concatenate - ], load = |env: Env, _| { rustler::resource!(OrtexModel, env); rustler::resource!(OrtexTensor, env); diff --git a/native/ortex/src/model.rs b/native/ortex/src/model.rs index f5e04b3..62702ca 100644 --- a/native/ortex/src/model.rs +++ b/native/ortex/src/model.rs @@ -13,14 +13,16 @@ use crate::utils::{is_bool_input, map_opt_level}; use std::convert::TryInto; use std::iter::zip; -use ort::{Error, ExecutionProviderDispatch, Session}; -use rustler::resource::ResourceArc; +use ort::execution_providers::ExecutionProviderDispatch; +use ort::session::Session; +use ort::Error; use rustler::Atom; +use rustler::ResourceArc; /// Holds the model state which include onnxruntime session and environment. All /// are threadsafe so this can be called concurrently from the beam. pub struct OrtexModel { - pub session: ort::Session, + pub session: ort::session::Session, } // Since we're only using the session for inference and @@ -63,7 +65,7 @@ pub fn show( for input in model.session.inputs.iter() { let name = input.name.to_string(); let repr = format!("{:#?}", input.input_type); - let dims = Option::<&Vec>::cloned(input.input_type.tensor_dimensions()); + let dims: Option> = input.input_type.tensor_shape().map(|s| s.to_vec()); inputs.push((name, repr, dims)); } @@ -71,7 +73,7 @@ pub fn show( for output in model.session.outputs.iter() { let name = output.name.to_string(); let repr = format!("{:#?}", output.output_type); - let dims = Option::<&Vec>::cloned(output.output_type.tensor_dimensions()); + let dims: Option> = output.output_type.tensor_shape().map(|s| s.to_vec()); outputs.push((name, repr, dims)); } @@ -85,9 +87,9 @@ pub fn run( inputs: Vec>, ) -> Result, Vec, Atom, usize)>, Error> { // Grab the session and run a forward pass with it - let session: &ort::Session = &model.session; + let session: &ort::session::Session = &model.session; - let mut ortified_inputs: Vec = Vec::new(); + let mut ortified_inputs: Vec = Vec::new(); for (elixir_input, onnx_input) in zip(inputs, &session.inputs) { let derefed_input: &OrtexTensor = &elixir_input; @@ -95,10 +97,10 @@ pub fn run( // this assumes that the boolean input isn't huge -- we're cloning it twice; // once below, once in the try_into() let boolified_input: &OrtexTensor = &derefed_input.clone().to_bool(); - let v: ort::SessionInputValue = boolified_input.try_into()?; + let v: ort::session::SessionInputValue = boolified_input.try_into()?; ortified_inputs.push(v); } else { - let v: ort::SessionInputValue = derefed_input.try_into()?; + let v: ort::session::SessionInputValue = derefed_input.try_into()?; ortified_inputs.push(v); } } diff --git a/native/ortex/src/tensor.rs b/native/ortex/src/tensor.rs index d4aecf1..92c6cda 100644 --- a/native/ortex/src/tensor.rs +++ b/native/ortex/src/tensor.rs @@ -2,9 +2,10 @@ use core::convert::TryFrom; use ndarray::prelude::*; use ndarray::{ArrayBase, ArrayView, Data, IxDyn, IxDynImpl, ViewRepr}; -use ort::{DynValue, Error, Value}; -use rustler::resource::ResourceArc; +use ort::value::{DynValue, Value}; +use ort::Error; use rustler::Atom; +use rustler::ResourceArc; use std::convert::TryInto; use crate::constants::ortex_atoms; @@ -230,58 +231,59 @@ where impl TryFrom<&Value> for OrtexTensor { type Error = Error; fn try_from(e: &Value) -> Result { - let dtype: ort::ValueType = e.dtype(); + let dtype: ort::value::ValueType = e.dtype().clone(); let ty = match dtype { - ort::ValueType::Tensor { + ort::value::ValueType::Tensor { ty: t, - dimensions: _, + shape: _, + dimension_symbols: _, } => t, _ => panic!("can't decode non tensor, got {}", dtype), }; let tensor = match ty { - ort::TensorElementType::Bfloat16 => { + ort::tensor::TensorElementType::Bfloat16 => { OrtexTensor::bf16(e.try_extract_tensor::()?.into_owned()) } - ort::TensorElementType::Float16 => { + ort::tensor::TensorElementType::Float16 => { OrtexTensor::f16(e.try_extract_tensor::()?.into_owned()) } - ort::TensorElementType::Float32 => { + ort::tensor::TensorElementType::Float32 => { OrtexTensor::f32(e.try_extract_tensor::()?.into_owned()) } - ort::TensorElementType::Float64 => { - OrtexTensor::f64(e.try_extract_tensor::()?.into_owned()) + ort::tensor::TensorElementType::Float64 => { + OrtexTensor::u8(<(&ort::tensor::Shape, &[f64])>::to_owned()) } - ort::TensorElementType::Uint8 => { - OrtexTensor::u8(e.try_extract_tensor::()?.into_owned()) + ort::tensor::TensorElementType::Uint8 => { + OrtexTensor::u8(<(&ort::tensor::Shape, &[u8])>::to_owned()) } - ort::TensorElementType::Uint16 => { - OrtexTensor::u16(e.try_extract_tensor::()?.into_owned()) + ort::tensor::TensorElementType::Uint16 => { + OrtexTensor::u8(<(&ort::tensor::Shape, &[u16])>::to_owned()) } - ort::TensorElementType::Uint32 => { - OrtexTensor::u32(e.try_extract_tensor::()?.into_owned()) + ort::tensor::TensorElementType::Uint32 => { + OrtexTensor::u8(<(&ort::tensor::Shape, &[u32])>::to_owned()) } - ort::TensorElementType::Uint64 => { - OrtexTensor::u64(e.try_extract_tensor::()?.into_owned()) + ort::tensor::TensorElementType::Uint64 => { + OrtexTensor::u8(<(&ort::tensor::Shape, &[u64])>::to_owned()) } - ort::TensorElementType::Int8 => { - OrtexTensor::s8(e.try_extract_tensor::()?.into_owned()) + ort::tensor::TensorElementType::Int8 => { + OrtexTensor::u8(<(&ort::tensor::Shape, &[i8])>::to_owned()) } - ort::TensorElementType::Int16 => { - OrtexTensor::s16(e.try_extract_tensor::()?.into_owned()) + ort::tensor::TensorElementType::Int16 => { + OrtexTensor::u8(<(&ort::tensor::Shape, &[u16])>::to_owned()) } - ort::TensorElementType::Int32 => { - OrtexTensor::s32(e.try_extract_tensor::()?.into_owned()) + ort::tensor::TensorElementType::Int32 => { + OrtexTensor::u8(<(&ort::tensor::Shape, &[i32])>::to_owned()) } - ort::TensorElementType::Int64 => { - OrtexTensor::s64(e.try_extract_tensor::()?.into_owned()) + ort::tensor::TensorElementType::Int64 => { + OrtexTensor::u8(<(&ort::tensor::Shape, &[i64])>::to_owned()) } - ort::TensorElementType::String => { + ort::tensor::TensorElementType::String => { todo!("Can't return string tensors") } // map the output into u8 space - ort::TensorElementType::Bool => { - let nd_array = e.try_extract_tensor::()?.into_owned(); + ort::tensor::TensorElementType::Bool => { + let nd_array = e.try_extract_tensor::()?.to_owned(); OrtexTensor::u8(nd_array.mapv(|x| x as u8)) } }; @@ -290,7 +292,7 @@ impl TryFrom<&Value> for OrtexTensor { } } -impl TryFrom<&OrtexTensor> for ort::SessionInputValue<'_> { +impl TryFrom<&OrtexTensor> for ort::session::SessionInputValue<'_> { type Error = Error; fn try_from(ort_tensor: &OrtexTensor) -> Result { let r: DynValue = match ort_tensor { diff --git a/native/ortex/src/utils.rs b/native/ortex/src/utils.rs index 8d96117..d3fcb47 100644 --- a/native/ortex/src/utils.rs +++ b/native/ortex/src/utils.rs @@ -7,11 +7,12 @@ use ndarray::{ArrayViewMut, Ix, IxDyn}; use ndarray::ShapeError; -use rustler::resource::ResourceArc; +use rustler::ResourceArc; use rustler::types::Binary; use rustler::{Atom, Env, NifResult}; -use ort::{ExecutionProviderDispatch, GraphOptimizationLevel}; +use ort::execution_providers::ExecutionProviderDispatch; +use ort::session::builder::GraphOptimizationLevel; /// A faster (unsafe) way of creating an Array from an Erlang binary fn initialize_from_raw_ptr(ptr: *const T, shape: &[Ix]) -> ArrayViewMut { @@ -94,15 +95,19 @@ pub fn to_binary<'a>( pub fn map_eps(env: rustler::env::Env, eps: Vec) -> Vec { eps.iter() .map(|e| match &e.to_term(env).atom_to_string().unwrap()[..] { - CPU => ort::CPUExecutionProvider::default().build(), - CUDA => ort::CUDAExecutionProvider::default().build(), - TENSORRT => ort::TensorRTExecutionProvider::default().build(), - ACL => ort::ACLExecutionProvider::default().build(), - ONEDNN => ort::OneDNNExecutionProvider::default().build(), - COREML => ort::CoreMLExecutionProvider::default().build(), - DIRECTML => ort::DirectMLExecutionProvider::default().build(), - ROCM => ort::ROCmExecutionProvider::default().build(), - _ => ort::CPUExecutionProvider::default().build(), + CPU => ort::execution_providers::cpu::CPUExecutionProvider::default().build(), + CUDA => ort::execution_providers::cuda::CUDAExecutionProvider::default().build(), + TENSORRT => { + ort::execution_providers::tensorrt::TensorRTExecutionProvider::default().build() + } + ACL => ort::execution_providers::acl::ACLExecutionProvider::default().build(), + ONEDNN => ort::execution_providers::onednn::OneDNNExecutionProvider::default().build(), + COREML => ort::execution_providers::coreml::CoreMLExecutionProvider::default().build(), + DIRECTML => { + ort::execution_providers::directml::DirectMLExecutionProvider::default().build() + } + ROCM => ort::execution_providers::rocm::ROCmExecutionProvider::default().build(), + _ => ort::execution_providers::cpu::CPUExecutionProvider::default().build(), }) .collect() } @@ -117,11 +122,11 @@ pub fn map_opt_level(opt: i32) -> GraphOptimizationLevel { } } -pub fn is_bool_input(inp: &ort::ValueType) -> bool { +pub fn is_bool_input(inp: &ort::value::ValueType) -> bool { match inp { - ort::ValueType::Tensor { ty, .. } => ty == &ort::TensorElementType::Bool, - ort::ValueType::Map { value, .. } => value == &ort::TensorElementType::Bool, - ort::ValueType::Sequence(boxed_input) => is_bool_input(boxed_input), - ort::ValueType::Optional(boxed_input) => is_bool_input(boxed_input), + ort::value::ValueType::Tensor { ty, .. } => ty == &ort::tensor::TensorElementType::Bool, + ort::value::ValueType::Map { value, .. } => value == &ort::tensor::TensorElementType::Bool, + ort::value::ValueType::Sequence(boxed_input) => is_bool_input(boxed_input), + ort::value::ValueType::Optional(boxed_input) => is_bool_input(boxed_input), } } From 3189db87110c60a403c1ced07982d7261e13dc0c Mon Sep 17 00:00:00 2001 From: "morten.lund@maskon.no" Date: Fri, 25 Jul 2025 16:02:51 +0200 Subject: [PATCH 2/7] More rewrite --- native/ortex/Cargo.lock | 20 ++--- native/ortex/Cargo.toml | 10 +-- native/ortex/src/lib.rs | 4 +- native/ortex/src/model.rs | 27 +++---- native/ortex/src/tensor.rs | 158 +++++++++++++++++++++++++++++-------- 5 files changed, 157 insertions(+), 62 deletions(-) diff --git a/native/ortex/Cargo.lock b/native/ortex/Cargo.lock index cd18047..00647ce 100644 --- a/native/ortex/Cargo.lock +++ b/native/ortex/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "adler" @@ -327,9 +327,9 @@ checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "half" -version = "2.3.1" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ "cfg-if", "crunchy", @@ -552,9 +552,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.15" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", ] @@ -615,6 +615,7 @@ version = "2.0.0-rc.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fa7e49bd669d32d7bc2a15ec540a527e7764aec722a45467814005725bcd721" dependencies = [ + "half", "ndarray", "ort-sys", "smallvec 2.0.0-alpha.10", @@ -642,6 +643,7 @@ dependencies = [ "ndarray", "num-traits", "ort", + "ort-sys", "rustler", "rustls", "tracing-subscriber", @@ -829,9 +831,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.28" +version = "0.23.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7160e3e10bf4535308537f3c4e1641468cd0e485175d6163087c0393c7d46643" +checksum = "2491382039b29b9b11ff08b76ff6c97cf287671dbb74f0be44bda389fffe9bd1" dependencies = [ "aws-lc-rs", "log", @@ -862,9 +864,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.3" +version = "0.103.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" +checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc" dependencies = [ "aws-lc-rs", "ring", diff --git a/native/ortex/Cargo.toml b/native/ortex/Cargo.toml index bc5d409..4ca6e04 100644 --- a/native/ortex/Cargo.toml +++ b/native/ortex/Cargo.toml @@ -15,12 +15,12 @@ resolver = "2" [dependencies] rustler = "0.36.2" ort-sys = { version = "=2.0.0-rc.10", default-features = false } -ort = { version = "2.0.0-rc.10" } +ort = { version = "2.0.0-rc.10", features = ["half"] } ndarray = "0.16.1" -half = "2.2.1" -tracing-subscriber = { version = "0.3", features = [ "env-filter", "fmt" ] } -num-traits = "0.2.15" -rustls = "0.23.28" +half = "2.6.0" +tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } +num-traits = "0.2.19" +rustls = "0.23.29" [features] # ONNXRuntime Execution providers diff --git a/native/ortex/src/lib.rs b/native/ortex/src/lib.rs index abdab1b..8dbcd79 100644 --- a/native/ortex/src/lib.rs +++ b/native/ortex/src/lib.rs @@ -12,8 +12,8 @@ mod utils; use model::OrtexModel; use tensor::OrtexTensor; -use rustler::ResourceArc; use rustler::types::Binary; +use rustler::ResourceArc; use rustler::{Atom, Env, NifResult, Term}; #[rustler::nif(schedule = "DirtyIo")] @@ -106,7 +106,7 @@ pub fn concatenate<'a>( rustler::init!( "Elixir.Ortex.Native", - load = |env: Env, _| { + load = |env: Env, _term: Term| -> bool { rustler::resource!(OrtexModel, env); rustler::resource!(OrtexTensor, env); true diff --git a/native/ortex/src/model.rs b/native/ortex/src/model.rs index 62702ca..febafbe 100644 --- a/native/ortex/src/model.rs +++ b/native/ortex/src/model.rs @@ -85,17 +85,16 @@ pub fn show( pub fn run( model: ResourceArc, inputs: Vec>, -) -> Result, Vec, Atom, usize)>, Error> { - // Grab the session and run a forward pass with it - let session: &ort::session::Session = &model.session; +) -> Result, Vec, Atom, usize)>, Box> { + // Lock the session for mutable access + let mut session = model.session.lock().map_err(|e| Box::new(e) as Box)?; + let session_ref: &mut ort::session::Session = &mut session; let mut ortified_inputs: Vec = Vec::new(); - for (elixir_input, onnx_input) in zip(inputs, &session.inputs) { + for (elixir_input, onnx_input) in zip(inputs, &session_ref.inputs) { let derefed_input: &OrtexTensor = &elixir_input; if is_bool_input(&onnx_input.input_type) { - // this assumes that the boolean input isn't huge -- we're cloning it twice; - // once below, once in the try_into() let boolified_input: &OrtexTensor = &derefed_input.clone().to_bool(); let v: ort::session::SessionInputValue = boolified_input.try_into()?; ortified_inputs.push(v); @@ -105,26 +104,24 @@ pub fn run( } } - // Construct a Vec of ModelOutput enums based on the DynOrtTensor data type - let outputs = session.run(&ortified_inputs[..])?; + let outputs = session_ref.run(&ortified_inputs[..])?; let mut collected_outputs = Vec::new(); - for output_descriptor in &session.outputs { + for output_descriptor in &session_ref.outputs { let output_name: &str = &output_descriptor.name; - let val = outputs.get(output_name).expect( - &format!( + let val = outputs + .get(output_name) + .expect(&format!( "Expected {} to be in the outputs, but didn't find it", output_name - )[..], - ); + )[..]); - // NOTE: try_into impl here will implicitly map bool outputs to u8 outputs let ortextensor: OrtexTensor = val.try_into()?; let shape = ortextensor.shape(); let (dtype, bits) = ortextensor.dtype(); let collected_output = (ResourceArc::new(ortextensor), shape, dtype, bits); - collected_outputs.push(collected_output) + collected_outputs.push(collected_output); } Ok(collected_outputs) diff --git a/native/ortex/src/tensor.rs b/native/ortex/src/tensor.rs index 92c6cda..0e1ff99 100644 --- a/native/ortex/src/tensor.rs +++ b/native/ortex/src/tensor.rs @@ -2,11 +2,11 @@ use core::convert::TryFrom; use ndarray::prelude::*; use ndarray::{ArrayBase, ArrayView, Data, IxDyn, IxDynImpl, ViewRepr}; -use ort::value::{DynValue, Value}; +use ort::value::Value; use ort::Error; use rustler::Atom; use rustler::ResourceArc; -use std::convert::TryInto; +use std::error::Error as StdError; use crate::constants::ortex_atoms; @@ -243,48 +243,144 @@ impl TryFrom<&Value> for OrtexTensor { let tensor = match ty { ort::tensor::TensorElementType::Bfloat16 => { - OrtexTensor::bf16(e.try_extract_tensor::()?.into_owned()) + let (shape, data) = e.try_extract_tensor::()?; + let shape_vec: Vec = shape.iter().map(|&dim| dim as usize).collect(); + let shape = IxDyn(&shape_vec); + let array = Array::from_shape_vec(shape, data.to_vec()) + .map_err(|e| Error::from(Box::new(e) as Box))?; + OrtexTensor::bf16(array) } ort::tensor::TensorElementType::Float16 => { - OrtexTensor::f16(e.try_extract_tensor::()?.into_owned()) + let (shape, data) = e.try_extract_tensor::()?; + let shape_vec: Vec = shape.iter().map(|&dim| dim as usize).collect(); + let shape = IxDyn(&shape_vec); + let array = Array::from_shape_vec(shape, data.to_vec()) + .map_err(|e| Error::from(Box::new(e) as Box))?; + OrtexTensor::f16(array) } ort::tensor::TensorElementType::Float32 => { - OrtexTensor::f32(e.try_extract_tensor::()?.into_owned()) + let (shape, data) = e.try_extract_tensor::()?; + let shape_vec: Vec = shape.iter().map(|&dim| dim as usize).collect(); + let shape = IxDyn(&shape_vec); + let array = Array::from_shape_vec(shape, data.to_vec()) + .map_err(|e| Error::from(Box::new(e) as Box))?; + OrtexTensor::f32(array) } ort::tensor::TensorElementType::Float64 => { - OrtexTensor::u8(<(&ort::tensor::Shape, &[f64])>::to_owned()) + let (shape, data) = e.try_extract_tensor::()?; + let shape_vec: Vec = shape.iter().map(|&dim| dim as usize).collect(); + let shape = IxDyn(&shape_vec); + let array = Array::from_shape_vec(shape, data.to_vec()) + .map_err(|e| Error::from(Box::new(e) as Box))?; + OrtexTensor::f64(array) } ort::tensor::TensorElementType::Uint8 => { - OrtexTensor::u8(<(&ort::tensor::Shape, &[u8])>::to_owned()) + let (shape, data) = e.try_extract_tensor::()?; + let shape_vec: Vec = shape.iter().map(|&dim| dim as usize).collect(); + let shape = IxDyn(&shape_vec); + let array = Array::from_shape_vec(shape, data.to_vec()) + .map_err(|e| Error::from(Box::new(e) as Box))?; + OrtexTensor::u8(array) } ort::tensor::TensorElementType::Uint16 => { - OrtexTensor::u8(<(&ort::tensor::Shape, &[u16])>::to_owned()) + //OrtexTensor::u8(<(&ort::tensor::Shape, &[u16])>::to_owned()) + + let (shape, data) = e.try_extract_tensor::()?; + let shape_vec: Vec = shape.iter().map(|&dim| dim as usize).collect(); + let shape = IxDyn(&shape_vec); + // Convert u16 data to u8 by truncating or clamping + let data_vec: Vec = data.iter().map(|&x| x.min(255) as u8).collect(); + let array = Array::from_shape_vec(shape, data_vec) + .map_err(|e| Error::from(Box::new(e) as Box))?; + OrtexTensor::u8(array) } ort::tensor::TensorElementType::Uint32 => { - OrtexTensor::u8(<(&ort::tensor::Shape, &[u32])>::to_owned()) + //OrtexTensor::u8(<(&ort::tensor::Shape, &[u32])>::to_owned()) + + let (shape, data) = e.try_extract_tensor::()?; + let shape_vec: Vec = shape.iter().map(|&dim| dim as usize).collect(); + let shape = IxDyn(&shape_vec); + // Convert u16 data to u8 by truncating or clamping + let data_vec: Vec = data.iter().map(|&x| x.min(255) as u8).collect(); + let array = Array::from_shape_vec(shape, data_vec) + .map_err(|e| Error::from(Box::new(e) as Box))?; + OrtexTensor::u8(array) } ort::tensor::TensorElementType::Uint64 => { - OrtexTensor::u8(<(&ort::tensor::Shape, &[u64])>::to_owned()) + //OrtexTensor::u8(<(&ort::tensor::Shape, &[u64])>::to_owned()) + + let (shape, data) = e.try_extract_tensor::()?; + let shape_vec: Vec = shape.iter().map(|&dim| dim as usize).collect(); + let shape = IxDyn(&shape_vec); + // Convert u16 data to u8 by truncating or clamping + let data_vec: Vec = data.iter().map(|&x| x.min(255) as u8).collect(); + let array = Array::from_shape_vec(shape, data_vec) + .map_err(|e| Error::from(Box::new(e) as Box))?; + OrtexTensor::u8(array) } ort::tensor::TensorElementType::Int8 => { - OrtexTensor::u8(<(&ort::tensor::Shape, &[i8])>::to_owned()) + //OrtexTensor::u8(<(&ort::tensor::Shape, &[i8])>::to_owned()) + + let (shape, data) = e.try_extract_tensor::()?; + let shape_vec: Vec = shape.iter().map(|&dim| dim as usize).collect(); + let shape = IxDyn(&shape_vec); + // Convert i8 data to u8 (e.g., shift range by adding 128) + let data_vec: Vec = data.iter().map(|&x| (x as i16 + 128) as u8).collect(); + let array = Array::from_shape_vec(shape, data_vec) + .map_err(|e| Error::from(Box::new(e) as Box))?; + OrtexTensor::u8(array) } ort::tensor::TensorElementType::Int16 => { - OrtexTensor::u8(<(&ort::tensor::Shape, &[u16])>::to_owned()) + //OrtexTensor::u8(<(&ort::tensor::Shape, &[u16])>::to_owned()) + + let (shape, data) = e.try_extract_tensor::()?; + let shape_vec: Vec = shape.iter().map(|&dim| dim as usize).collect(); + let shape = IxDyn(&shape_vec); + // Convert i8 data to u8 (e.g., shift range by adding 128) + let data_vec: Vec = data.iter().map(|&x| (x as i16 + 128) as u8).collect(); + let array = Array::from_shape_vec(shape, data_vec) + .map_err(|e| Error::from(Box::new(e) as Box))?; + OrtexTensor::u8(array) } ort::tensor::TensorElementType::Int32 => { - OrtexTensor::u8(<(&ort::tensor::Shape, &[i32])>::to_owned()) + //OrtexTensor::u8(<(&ort::tensor::Shape, &[i16])>::to_owned()) + + let (shape, data) = e.try_extract_tensor::()?; + let shape_vec: Vec = shape.iter().map(|&dim| dim as usize).collect(); + let shape = IxDyn(&shape_vec); + // Convert i8 data to u8 (e.g., shift range by adding 128) + let data_vec: Vec = data.iter().map(|&x| (x as i16 + 128) as u8).collect(); + let array = Array::from_shape_vec(shape, data_vec) + .map_err(|e| Error::from(Box::new(e) as Box))?; + OrtexTensor::u8(array) } ort::tensor::TensorElementType::Int64 => { - OrtexTensor::u8(<(&ort::tensor::Shape, &[i64])>::to_owned()) + //OrtexTensor::u8(<(&ort::tensor::Shape, &[i64])>::to_owned()) + + let (shape, data) = e.try_extract_tensor::()?; + let shape_vec: Vec = shape.iter().map(|&dim| dim as usize).collect(); + let shape = IxDyn(&shape_vec); + // Convert i8 data to u8 (e.g., shift range by adding 128) + let data_vec: Vec = data.iter().map(|&x| (x as i16 + 128) as u8).collect(); + let array = Array::from_shape_vec(shape, data_vec) + .map_err(|e| Error::from(Box::new(e) as Box))?; + OrtexTensor::u8(array) } ort::tensor::TensorElementType::String => { todo!("Can't return string tensors") } // map the output into u8 space ort::tensor::TensorElementType::Bool => { - let nd_array = e.try_extract_tensor::()?.to_owned(); - OrtexTensor::u8(nd_array.mapv(|x| x as u8)) + let (shape, data) = e.try_extract_tensor::()?; + let shape_vec: Vec = shape.iter().map(|&dim| dim as usize).collect(); + let shape = IxDyn(&shape_vec); + let bool_array = Array::from_shape_vec(shape, data.to_vec()) + .map_err(|e| Error::from(Box::new(e) as Box))?; + let u8_array = bool_array.mapv(|x| x as u8); + OrtexTensor::u8(u8_array) + } + _ => { + todo!("Complex types") } }; @@ -295,22 +391,22 @@ impl TryFrom<&Value> for OrtexTensor { impl TryFrom<&OrtexTensor> for ort::session::SessionInputValue<'_> { type Error = Error; fn try_from(ort_tensor: &OrtexTensor) -> Result { - let r: DynValue = match ort_tensor { - OrtexTensor::s8(arr) => arr.to_owned().try_into()?, - OrtexTensor::s16(arr) => arr.clone().try_into()?, - OrtexTensor::s32(arr) => arr.clone().try_into()?, - OrtexTensor::s64(arr) => arr.clone().try_into()?, - OrtexTensor::f16(arr) => arr.clone().try_into()?, - OrtexTensor::f32(arr) => arr.clone().try_into()?, - OrtexTensor::f64(arr) => arr.clone().try_into()?, - OrtexTensor::bf16(arr) => arr.clone().try_into()?, - OrtexTensor::u8(arr) => arr.clone().try_into()?, - OrtexTensor::u16(arr) => arr.clone().try_into()?, - OrtexTensor::u32(arr) => arr.clone().try_into()?, - OrtexTensor::u64(arr) => arr.clone().try_into()?, - OrtexTensor::bool(arr) => arr.clone().try_into()?, + let value: Value = match ort_tensor { + OrtexTensor::s8(arr) => Value::from_array(arr.to_owned())?.into(), + OrtexTensor::s16(arr) => Value::from_array(arr.to_owned())?.into(), + OrtexTensor::s32(arr) => Value::from_array(arr.to_owned())?.into(), + OrtexTensor::s64(arr) => Value::from_array(arr.to_owned())?.into(), + OrtexTensor::f16(arr) => Value::from_array(arr.to_owned())?.into(), + OrtexTensor::f32(arr) => Value::from_array(arr.to_owned())?.into(), + OrtexTensor::f64(arr) => Value::from_array(arr.to_owned())?.into(), + OrtexTensor::bf16(arr) => Value::from_array(arr.to_owned())?.into(), + OrtexTensor::u8(arr) => Value::from_array(arr.to_owned())?.into(), + OrtexTensor::u16(arr) => Value::from_array(arr.to_owned())?.into(), + OrtexTensor::u32(arr) => Value::from_array(arr.to_owned())?.into(), + OrtexTensor::u64(arr) => Value::from_array(arr.to_owned())?.into(), + OrtexTensor::bool(arr) => Value::from_array(arr.to_owned())?.into(), }; - Ok(r.into()) + Ok(ort::session::SessionInputValue::from(value)) } } From de0dae6ea56887328e8f3b50fdc74cc68c44d3b4 Mon Sep 17 00:00:00 2001 From: "morten.lund@maskon.no" Date: Sun, 27 Jul 2025 13:14:58 +0200 Subject: [PATCH 3/7] Try try try --- mix.exs | 8 ++--- mix.lock | 28 ++++++++-------- native/ortex/Cargo.lock | 70 +++++++++++++++++++++++---------------- native/ortex/Cargo.toml | 2 +- native/ortex/src/lib.rs | 10 ++++++ native/ortex/src/model.rs | 35 +++++++++++--------- 6 files changed, 90 insertions(+), 63 deletions(-) diff --git a/mix.exs b/mix.exs index 2d62309..4936d3a 100644 --- a/mix.exs +++ b/mix.exs @@ -31,10 +31,10 @@ defmodule Ortex.MixProject do # Run "mix help deps" to learn about dependencies. defp deps do [ - {:rustler, "~> 0.27"}, - {:nx, "~> 0.6"}, - {:tokenizers, "~> 0.4", only: :dev}, - {:ex_doc, "0.29.4", only: :dev, runtime: false} + {:rustler, "~> 0.33"}, + {:nx, "~> 0.10"}, + {:tokenizers, "~> 0.5", only: :dev}, + {:ex_doc, "~> 0.38", only: :dev, runtime: false} # {:exla, "~> 0.6", only: :dev}, # {:torchx, "~> 0.6", only: :dev} ] diff --git a/mix.lock b/mix.lock index 8b9284d..d5d8101 100644 --- a/mix.lock +++ b/mix.lock @@ -1,28 +1,28 @@ %{ "axon": {:hex, :axon, "0.5.1", "1ae3a2193df45e51fca912158320b2ca87cb7fba4df242bd3ebe245504d0ea1a", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:kino_vega_lite, "~> 0.1.7", [hex: :kino_vega_lite, repo: "hexpm", optional: true]}, {:nx, "~> 0.5.0", [hex: :nx, repo: "hexpm", optional: false]}, {:table_rex, "~> 3.1.1", [hex: :table_rex, repo: "hexpm", optional: true]}], "hexpm", "d36f2a11c34c6c2b458f54df5c71ffdb7ed91c6a9ccd908faba909c84cc6a38e"}, "axon_onnx": {:hex, :axon_onnx, "0.4.0", "7be4b5ac7a44340ec65eb59c24122a8fe2aa8105da33b3321a378b455a6cd9c6", [:mix], [{:axon, "~> 0.5", [hex: :axon, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}, {:protox, "~> 1.6.10", [hex: :protox, repo: "hexpm", optional: false]}], "hexpm", "b98c84e5656caf156ef8998296836349a62bc35598f05cc21eececbbef022d09"}, - "castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"}, + "castore": {:hex, :castore, "1.0.14", "4582dd7d630b48cf5e1ca8d3d42494db51e406b7ba704e81fbd401866366896a", [:mix], [], "hexpm", "7bc1b65249d31701393edaaac18ec8398d8974d52c647b7904d01b964137b9f4"}, "cc_precompiler": {:hex, :cc_precompiler, "0.1.7", "77de20ac77f0e53f20ca82c563520af0237c301a1ec3ab3bc598e8a96c7ee5d9", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "2768b28bf3c2b4f788c995576b39b8cb5d47eb788526d93bd52206c1d8bf4b75"}, - "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, + "complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"}, "decimal": {:hex, :decimal, "2.1.1", "5611dca5d4b2c3dd497dec8f68751f1f1a54755e8ed2a966c2633cf885973ad6", [:mix], [], "hexpm", "53cfe5f497ed0e7771ae1a475575603d77425099ba5faef9394932b35020ffcc"}, "dll_loader_helper": {:hex, :dll_loader_helper, "1.1.0", "e7d015e980942a0d67e306827ec907e7e853a21186bd92bb968d986698591a0f", [:mix], [{:dll_loader_helper_beam, "~> 1.1", [hex: :dll_loader_helper_beam, repo: "hexpm", optional: false]}], "hexpm", "2b6c11ee7bb48f6a132ce8f872202f9e828c019988da1e2d40ad41496195df0c"}, "dll_loader_helper_beam": {:hex, :dll_loader_helper_beam, "1.2.0", "557c43befb8e3b119b718da302adccde3bd855acdb999498a14a2a8d2814b8b9", [:rebar3], [], "hexpm", "a2115d4bf1cca488a7b33f3c648847f64019b32c0382d10286d84dd5c3cbc0e5"}, - "earmark_parser": {:hex, :earmark_parser, "1.4.32", "fa739a0ecfa34493de19426681b23f6814573faee95dfd4b4aafe15a7b5b32c6", [:mix], [], "hexpm", "b8b0dd77d60373e77a3d7e8afa598f325e49e8663a51bcc2b88ef41838cca755"}, + "earmark_parser": {:hex, :earmark_parser, "1.4.44", "f20830dd6b5c77afe2b063777ddbbff09f9759396500cdbe7523efd58d7a339c", [:mix], [], "hexpm", "4778ac752b4701a5599215f7030989c989ffdc4f6df457c5f36938cc2d2a2750"}, "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, "erlex": {:hex, :erlex, "0.2.6", "c7987d15e899c7a2f34f5420d2a2ea0d659682c06ac607572df55a43753aa12e", [:mix], [], "hexpm", "2ed2e25711feb44d52b17d2780eabf998452f6efda104877a3881c2f8c0c0c75"}, - "ex_doc": {:hex, :ex_doc, "0.29.4", "6257ecbb20c7396b1fe5accd55b7b0d23f44b6aa18017b415cb4c2b91d997729", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "2c6699a737ae46cb61e4ed012af931b57b699643b24dabe2400a8168414bc4f5"}, + "ex_doc": {:hex, :ex_doc, "0.38.2", "504d25eef296b4dec3b8e33e810bc8b5344d565998cd83914ffe1b8503737c02", [:mix], [{:earmark_parser, "~> 1.4.44", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "732f2d972e42c116a70802f9898c51b54916e542cc50968ac6980512ec90f42b"}, "exla": {:hex, :exla, "0.6.1", "a4400933a04d018c5fb508c75a080c73c3c1986f6c16a79bbfee93ba22830d4d", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.1", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.5.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "f0e95b0f91a937030cf9fcbe900c9d26933cb31db2a26dfc8569aa239679e6d4"}, - "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, - "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, - "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, - "makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"}, - "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, - "nx": {:hex, :nx, "0.6.2", "f1d137f477b1a6f84f8db638f7a6d5a0f8266caea63c9918aa4583db38ebe1d6", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "ac913b68d53f25f6eb39bddcf2d2cd6ea2e9bcb6f25cf86a79e35d0411ba96ad"}, + "jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"}, + "makeup": {:hex, :makeup, "1.2.1", "e90ac1c65589ef354378def3ba19d401e739ee7ee06fb47f94c687016e3713d1", [:mix], [{:nimble_parsec, "~> 1.4", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "d36484867b0bae0fea568d10131197a4c2e47056a6fbe84922bf6ba71c8d17ce"}, + "makeup_elixir": {:hex, :makeup_elixir, "1.0.1", "e928a4f984e795e41e3abd27bfc09f51db16ab8ba1aebdba2b3a575437efafc2", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "7284900d412a3e5cfd97fdaed4f5ed389b8f2b4cb49efc0eb3bd10e2febf9507"}, + "makeup_erlang": {:hex, :makeup_erlang, "1.0.2", "03e1804074b3aa64d5fad7aa64601ed0fb395337b982d9bcf04029d68d51b6a7", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "af33ff7ef368d5893e4a267933e7744e46ce3cf1f61e2dccf53a111ed3aa3727"}, + "nimble_parsec": {:hex, :nimble_parsec, "1.4.2", "8efba0122db06df95bfaa78f791344a89352ba04baedd3849593bfce4d0dc1c6", [:mix], [], "hexpm", "4b21398942dda052b403bbe1da991ccd03a053668d147d53fb8c4e0efe09c973"}, + "nx": {:hex, :nx, "0.10.0", "128e4a094cb790f663e20e1334b127c1f2a4df54edfb8b13c22757ec33133b4f", [:mix], [{:complex, "~> 0.6", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3db8892c124aeee091df0e6fbf8e5bf1b81f502eb0d4f5ba63e6378ebcae7da4"}, "protox": {:hex, :protox, "1.6.10", "41d0b0c5b9190e7d5e6a2b1a03a09257ead6f3d95e6a0cf8b81430b526126908", [:mix], [{:decimal, "~> 1.9 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}, {:jason, "~> 1.2", [hex: :jason, repo: "hexpm", optional: true]}], "hexpm", "9769fca26ae7abfc5cc61308a1e8d9e2400ff89a799599cee7930d21132832d9"}, - "rustler": {:hex, :rustler, "0.29.1", "880f20ae3027bd7945def6cea767f5257bc926f33ff50c0d5d5a5315883c084d", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "109497d701861bfcd26eb8f5801fe327a8eef304f56a5b63ef61151ff44ac9b6"}, - "rustler_precompiled": {:hex, :rustler_precompiled, "0.7.0", "5d0834fc06dbc76dd1034482f17b1797df0dba9b491cef8bb045fcaca94bcade", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "fdf43a6835f4e4de5bfbc4c019bfb8c46d124bd4635fefa3e20d9a2bbbec1512"}, - "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, - "tokenizers": {:hex, :tokenizers, "0.4.0", "140283ca74a971391ddbd83cd8cbdb9bd03736f37a1b6989b82d245a95e1eb97", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "ef1a9824f5a893cd3b831c0e5b3d72caa250d2ec462035cc6afef6933b13a82e"}, + "rustler": {:hex, :rustler, "0.36.2", "6c2142f912166dfd364017ab2bf61242d4a5a3c88e7b872744642ae004b82501", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.7", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "93832a6dbc1166739a19cd0c25e110e4cf891f16795deb9361dfcae95f6c88fe"}, + "rustler_precompiled": {:hex, :rustler_precompiled, "0.8.3", "4e741024b0b097fe783add06e53ae9a6f23ddc78df1010f215df0c02915ef5a8", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "c23f5f33cb6608542de4d04faf0f0291458c352a4648e4d28d17ee1098cddcc4"}, + "telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"}, + "tokenizers": {:hex, :tokenizers, "0.5.1", "b0975d92b4ee5b18e8f47b5d65b9d5f1e583d9130189b1a2620401af4e7d4b35", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "5f08d97cc7f2ed3d71d370d68120da6d3de010948ccf676c9c0eb591ba4bacc9"}, "toml": {:hex, :toml, "0.7.0", "fbcd773caa937d0c7a02c301a1feea25612720ac3fa1ccb8bfd9d30d822911de", [:mix], [], "hexpm", "0690246a2478c1defd100b0c9b89b4ea280a22be9a7b313a8a058a2408a2fa70"}, "torchx": {:hex, :torchx, "0.6.1", "2a9862ebc4b397f42c51f0fa3f9f4e3451a83df6fba42882f8523cbc925c8ae1", [:make, :mix], [{:dll_loader_helper, "~> 0.1 or ~> 1.0", [hex: :dll_loader_helper, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.1", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "99b3fc73b52d6cfbe5cad8bdd74277ddc99297ce8fc6765b1dabec80681e8d9d"}, "useful": {:hex, :useful, "1.11.0", "b2d89223563c3354fd56f4da75b63f07f52cb32b243289a7f1fcc37869bcf9c2", [:mix], [], "hexpm", "2e5b2a47acc191bfb38e936f5f1bc57dad3b11133e0defe59a32fda10ebafcff"}, diff --git a/native/ortex/Cargo.lock b/native/ortex/Cargo.lock index 00647ce..e8b4089 100644 --- a/native/ortex/Cargo.lock +++ b/native/ortex/Cargo.lock @@ -8,6 +8,15 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "aho-corasick" +version = "0.7.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" +dependencies = [ + "memchr", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -367,15 +376,6 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" -[[package]] -name = "inventory" -version = "0.3.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab08d7cd2c5897f2c949e5383ea7c7db03fb19130ffcfbf7eda795137ae3cb83" -dependencies = [ - "rustversion", -] - [[package]] name = "itertools" version = "0.12.1" @@ -746,6 +746,8 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" dependencies = [ + "aho-corasick", + "memchr", "regex-syntax", ] @@ -758,12 +760,6 @@ dependencies = [ "regex-syntax", ] -[[package]] -name = "regex-lite" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" - [[package]] name = "regex-syntax" version = "0.6.28" @@ -806,29 +802,36 @@ dependencies = [ [[package]] name = "rustler" -version = "0.36.2" +version = "0.33.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3fe55230a9c379733dd38ee67d4072fa5c558b2e22b76b0e7f924390456e003" +checksum = "45d51ae0239c57c3a3e603dd855ace6795078ef33c95c85d397a100ac62ed352" dependencies = [ - "inventory", - "libloading", - "regex-lite", "rustler_codegen", + "rustler_sys", ] [[package]] name = "rustler_codegen" -version = "0.36.2" +version = "0.33.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb3b8de901ae61418e2036245d28e41ef58080d04f40b68430471ae36a4e84ed" +checksum = "27061f1a2150ad64717dca73902678c124b0619b0d06563294df265bc84759e1" dependencies = [ "heck", - "inventory", "proc-macro2", "quote", "syn", ] +[[package]] +name = "rustler_sys" +version = "2.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd0e2c955cfc86ea4680067e1d5e711427b43f7befcb6e23c7807cf3dd90e97" +dependencies = [ + "regex", + "unreachable", +] + [[package]] name = "rustls" version = "0.23.29" @@ -874,12 +877,6 @@ dependencies = [ "untrusted", ] -[[package]] -name = "rustversion" -version = "1.0.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" - [[package]] name = "schannel" version = "0.1.27" @@ -1080,6 +1077,15 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" +[[package]] +name = "unreachable" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "382810877fe448991dfc7f0dd6e3ae5d58088fd0ea5e35189655f84e6814fa56" +dependencies = [ + "void", +] + [[package]] name = "untrusted" version = "0.9.0" @@ -1141,6 +1147,12 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "void" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/native/ortex/Cargo.toml b/native/ortex/Cargo.toml index 4ca6e04..636ff13 100644 --- a/native/ortex/Cargo.toml +++ b/native/ortex/Cargo.toml @@ -13,7 +13,7 @@ crate-type = ["cdylib"] resolver = "2" [dependencies] -rustler = "0.36.2" +rustler = "0.33" ort-sys = { version = "=2.0.0-rc.10", default-features = false } ort = { version = "2.0.0-rc.10", features = ["half"] } ndarray = "0.16.1" diff --git a/native/ortex/src/lib.rs b/native/ortex/src/lib.rs index 8dbcd79..812b252 100644 --- a/native/ortex/src/lib.rs +++ b/native/ortex/src/lib.rs @@ -106,6 +106,16 @@ pub fn concatenate<'a>( rustler::init!( "Elixir.Ortex.Native", + [ + run, + init, + from_binary, + to_binary, + show_session, + slice, + reshape, + concatenate + ], load = |env: Env, _term: Term| -> bool { rustler::resource!(OrtexModel, env); rustler::resource!(OrtexTensor, env); diff --git a/native/ortex/src/model.rs b/native/ortex/src/model.rs index febafbe..543db32 100644 --- a/native/ortex/src/model.rs +++ b/native/ortex/src/model.rs @@ -15,14 +15,17 @@ use std::iter::zip; use ort::execution_providers::ExecutionProviderDispatch; use ort::session::Session; +use ort::util::Mutex; use ort::Error; use rustler::Atom; use rustler::ResourceArc; +use std::error::Error as StdError; +use std::sync::Arc; /// Holds the model state which include onnxruntime session and environment. All /// are threadsafe so this can be called concurrently from the beam. pub struct OrtexModel { - pub session: ort::session::Session, + pub session: Arc>, } // Since we're only using the session for inference and @@ -46,7 +49,9 @@ pub fn init( .with_execution_providers(eps)? .commit_from_file(model_path)?; - let state = OrtexModel { session }; + let state = OrtexModel { + session: Arc::new(Mutex::new(session)), + }; Ok(state) } @@ -60,9 +65,10 @@ pub fn show( Vec<(String, String, Option>)>, ) { let model: &OrtexModel = &*model; + let session = model.session.lock(); let mut inputs = Vec::new(); - for input in model.session.inputs.iter() { + for input in session.inputs.iter() { let name = input.name.to_string(); let repr = format!("{:#?}", input.input_type); let dims: Option> = input.input_type.tensor_shape().map(|s| s.to_vec()); @@ -70,7 +76,7 @@ pub fn show( } let mut outputs = Vec::new(); - for output in model.session.outputs.iter() { + for output in session.outputs.iter() { let name = output.name.to_string(); let repr = format!("{:#?}", output.output_type); let dims: Option> = output.output_type.tensor_shape().map(|s| s.to_vec()); @@ -85,14 +91,12 @@ pub fn show( pub fn run( model: ResourceArc, inputs: Vec>, -) -> Result, Vec, Atom, usize)>, Box> { - // Lock the session for mutable access - let mut session = model.session.lock().map_err(|e| Box::new(e) as Box)?; - let session_ref: &mut ort::session::Session = &mut session; +) -> Result, Vec, Atom, usize)>, Box> { + let mut session = model.session.lock(); let mut ortified_inputs: Vec = Vec::new(); - for (elixir_input, onnx_input) in zip(inputs, &session_ref.inputs) { + for (elixir_input, onnx_input) in zip(inputs, &session.inputs) { let derefed_input: &OrtexTensor = &elixir_input; if is_bool_input(&onnx_input.input_type) { let boolified_input: &OrtexTensor = &derefed_input.clone().to_bool(); @@ -104,17 +108,18 @@ pub fn run( } } - let outputs = session_ref.run(&ortified_inputs[..])?; + let output_descriptors = session.outputs.clone(); + let outputs = session.run(&ortified_inputs[..])?; let mut collected_outputs = Vec::new(); - for output_descriptor in &session_ref.outputs { + for output_descriptor in output_descriptors { let output_name: &str = &output_descriptor.name; - let val = outputs - .get(output_name) - .expect(&format!( + let val = outputs.get(output_name).expect( + &format!( "Expected {} to be in the outputs, but didn't find it", output_name - )[..]); + )[..], + ); let ortextensor: OrtexTensor = val.try_into()?; let shape = ortextensor.shape(); From a4493c374e510c78f5fe79393583b1ade148a968 Mon Sep 17 00:00:00 2001 From: "morten.lund@maskon.no" Date: Mon, 28 Jul 2025 09:38:18 +0200 Subject: [PATCH 4/7] Another try --- native/ortex/src/model.rs | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/native/ortex/src/model.rs b/native/ortex/src/model.rs index 543db32..df3e49e 100644 --- a/native/ortex/src/model.rs +++ b/native/ortex/src/model.rs @@ -15,17 +15,16 @@ use std::iter::zip; use ort::execution_providers::ExecutionProviderDispatch; use ort::session::Session; -use ort::util::Mutex; use ort::Error; use rustler::Atom; use rustler::ResourceArc; use std::error::Error as StdError; -use std::sync::Arc; +use std::sync::Mutex; /// Holds the model state which include onnxruntime session and environment. All /// are threadsafe so this can be called concurrently from the beam. pub struct OrtexModel { - pub session: Arc>, + pub session: Mutex, } // Since we're only using the session for inference and @@ -50,7 +49,7 @@ pub fn init( .commit_from_file(model_path)?; let state = OrtexModel { - session: Arc::new(Mutex::new(session)), + session: session.into(), }; Ok(state) } @@ -64,8 +63,7 @@ pub fn show( Vec<(String, String, Option>)>, Vec<(String, String, Option>)>, ) { - let model: &OrtexModel = &*model; - let session = model.session.lock(); + let session: &mut ort::session::Session = &mut model.session.lock().unwrap(); let mut inputs = Vec::new(); for input in session.inputs.iter() { @@ -92,7 +90,7 @@ pub fn run( model: ResourceArc, inputs: Vec>, ) -> Result, Vec, Atom, usize)>, Box> { - let mut session = model.session.lock(); + let session: &mut ort::session::Session = &mut model.session.lock().unwrap(); let mut ortified_inputs: Vec = Vec::new(); @@ -108,12 +106,10 @@ pub fn run( } } - let output_descriptors = session.outputs.clone(); let outputs = session.run(&ortified_inputs[..])?; let mut collected_outputs = Vec::new(); - for output_descriptor in output_descriptors { - let output_name: &str = &output_descriptor.name; + for output_name in outputs.keys() { let val = outputs.get(output_name).expect( &format!( "Expected {} to be in the outputs, but didn't find it", @@ -124,10 +120,26 @@ pub fn run( let ortextensor: OrtexTensor = val.try_into()?; let shape = ortextensor.shape(); let (dtype, bits) = ortextensor.dtype(); - let collected_output = (ResourceArc::new(ortextensor), shape, dtype, bits); collected_outputs.push(collected_output); } + // for output_descriptor in &session.outputs { + // let output_name: &str = &output_descriptor.name; + // let val = outputs.get(output_name).expect( + // &format!( + // "Expected {} to be in the outputs, but didn't find it", + // output_name + // )[..], + // ); + + // let ortextensor: OrtexTensor = val.try_into()?; + // let shape = ortextensor.shape(); + // let (dtype, bits) = ortextensor.dtype(); + + // let collected_output = (ResourceArc::new(ortextensor), shape, dtype, bits); + // collected_outputs.push(collected_output); + // } + Ok(collected_outputs) } From 2da4c1a3013794722ed15a856e20d519f8fd0c0b Mon Sep 17 00:00:00 2001 From: "morten.lund@maskon.no" Date: Mon, 28 Jul 2025 10:35:39 +0200 Subject: [PATCH 5/7] Make sure EPs are supported --- mix.exs | 2 +- native/ortex/Cargo.lock | 70 ++++++++++++++++---------------------- native/ortex/Cargo.toml | 4 +-- native/ortex/src/lib.rs | 18 +++------- native/ortex/src/model.rs | 2 ++ native/ortex/src/tensor.rs | 5 ++- native/ortex/src/utils.rs | 4 +-- 7 files changed, 45 insertions(+), 60 deletions(-) diff --git a/mix.exs b/mix.exs index 4936d3a..0c44fb7 100644 --- a/mix.exs +++ b/mix.exs @@ -31,7 +31,7 @@ defmodule Ortex.MixProject do # Run "mix help deps" to learn about dependencies. defp deps do [ - {:rustler, "~> 0.33"}, + {:rustler, "~> 0.36.2"}, {:nx, "~> 0.10"}, {:tokenizers, "~> 0.5", only: :dev}, {:ex_doc, "~> 0.38", only: :dev, runtime: false} diff --git a/native/ortex/Cargo.lock b/native/ortex/Cargo.lock index e8b4089..00647ce 100644 --- a/native/ortex/Cargo.lock +++ b/native/ortex/Cargo.lock @@ -8,15 +8,6 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" -[[package]] -name = "aho-corasick" -version = "0.7.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" -dependencies = [ - "memchr", -] - [[package]] name = "autocfg" version = "1.1.0" @@ -376,6 +367,15 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "inventory" +version = "0.3.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab08d7cd2c5897f2c949e5383ea7c7db03fb19130ffcfbf7eda795137ae3cb83" +dependencies = [ + "rustversion", +] + [[package]] name = "itertools" version = "0.12.1" @@ -746,8 +746,6 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" dependencies = [ - "aho-corasick", - "memchr", "regex-syntax", ] @@ -760,6 +758,12 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "regex-lite" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" + [[package]] name = "regex-syntax" version = "0.6.28" @@ -802,36 +806,29 @@ dependencies = [ [[package]] name = "rustler" -version = "0.33.0" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45d51ae0239c57c3a3e603dd855ace6795078ef33c95c85d397a100ac62ed352" +checksum = "e3fe55230a9c379733dd38ee67d4072fa5c558b2e22b76b0e7f924390456e003" dependencies = [ + "inventory", + "libloading", + "regex-lite", "rustler_codegen", - "rustler_sys", ] [[package]] name = "rustler_codegen" -version = "0.33.0" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27061f1a2150ad64717dca73902678c124b0619b0d06563294df265bc84759e1" +checksum = "eb3b8de901ae61418e2036245d28e41ef58080d04f40b68430471ae36a4e84ed" dependencies = [ "heck", + "inventory", "proc-macro2", "quote", "syn", ] -[[package]] -name = "rustler_sys" -version = "2.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddd0e2c955cfc86ea4680067e1d5e711427b43f7befcb6e23c7807cf3dd90e97" -dependencies = [ - "regex", - "unreachable", -] - [[package]] name = "rustls" version = "0.23.29" @@ -877,6 +874,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" + [[package]] name = "schannel" version = "0.1.27" @@ -1077,15 +1080,6 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" -[[package]] -name = "unreachable" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "382810877fe448991dfc7f0dd6e3ae5d58088fd0ea5e35189655f84e6814fa56" -dependencies = [ - "void", -] - [[package]] name = "untrusted" version = "0.9.0" @@ -1147,12 +1141,6 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" -[[package]] -name = "void" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/native/ortex/Cargo.toml b/native/ortex/Cargo.toml index 636ff13..f36eac2 100644 --- a/native/ortex/Cargo.toml +++ b/native/ortex/Cargo.toml @@ -13,9 +13,9 @@ crate-type = ["cdylib"] resolver = "2" [dependencies] -rustler = "0.33" +rustler = "0.36.2" ort-sys = { version = "=2.0.0-rc.10", default-features = false } -ort = { version = "2.0.0-rc.10", features = ["half"] } +ort = { version = "2.0.0-rc.10", features = ["half", "cuda"] } ndarray = "0.16.1" half = "2.6.0" tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } diff --git a/native/ortex/src/lib.rs b/native/ortex/src/lib.rs index 812b252..f9a6fae 100644 --- a/native/ortex/src/lib.rs +++ b/native/ortex/src/lib.rs @@ -104,21 +104,13 @@ pub fn concatenate<'a>( Ok(ResourceArc::new(concatted)) } +pub fn on_load(env: Env) -> bool { + env.register::().is_ok() && env.register::().is_ok() +} + rustler::init!( "Elixir.Ortex.Native", - [ - run, - init, - from_binary, - to_binary, - show_session, - slice, - reshape, - concatenate - ], load = |env: Env, _term: Term| -> bool { - rustler::resource!(OrtexModel, env); - rustler::resource!(OrtexTensor, env); - true + on_load(env) } ); diff --git a/native/ortex/src/model.rs b/native/ortex/src/model.rs index df3e49e..c849d5f 100644 --- a/native/ortex/src/model.rs +++ b/native/ortex/src/model.rs @@ -17,6 +17,7 @@ use ort::execution_providers::ExecutionProviderDispatch; use ort::session::Session; use ort::Error; use rustler::Atom; +use rustler::Resource; use rustler::ResourceArc; use std::error::Error as StdError; use std::sync::Mutex; @@ -26,6 +27,7 @@ use std::sync::Mutex; pub struct OrtexModel { pub session: Mutex, } +impl Resource for OrtexModel {} // Since we're only using the session for inference and // inference is threadsafe, this Sync is safe. Additionally, diff --git a/native/ortex/src/tensor.rs b/native/ortex/src/tensor.rs index 0e1ff99..d96d6fd 100644 --- a/native/ortex/src/tensor.rs +++ b/native/ortex/src/tensor.rs @@ -5,6 +5,7 @@ use ndarray::{ArrayBase, ArrayView, Data, IxDyn, IxDynImpl, ViewRepr}; use ort::value::Value; use ort::Error; use rustler::Atom; +use rustler::Resource; use rustler::ResourceArc; use std::error::Error as StdError; @@ -204,6 +205,8 @@ impl OrtexTensor { } } +impl Resource for OrtexTensor {} + fn slice_array<'a, T, D>( array: &'a Array, slice_specs: &'a Vec<(isize, Option, isize)>, @@ -441,7 +444,7 @@ macro_rules! concatenate { // `typ` is the actual datatype, `ort_tensor_kind` is the OrtexTensor variant ($tensors:expr, $axis:expr, $typ:ty, $ort_tensor_kind:ident) => {{ type ArrayType<'a> = ArrayBase, Dim>; - fn filter(tensor: &OrtexTensor) -> Option { + fn filter<'a>(tensor: &'a OrtexTensor) -> Option> { match tensor { OrtexTensor::$ort_tensor_kind(x) => Some(x.view()), _ => None, diff --git a/native/ortex/src/utils.rs b/native/ortex/src/utils.rs index d3fcb47..8b370a9 100644 --- a/native/ortex/src/utils.rs +++ b/native/ortex/src/utils.rs @@ -15,7 +15,7 @@ use ort::execution_providers::ExecutionProviderDispatch; use ort::session::builder::GraphOptimizationLevel; /// A faster (unsafe) way of creating an Array from an Erlang binary -fn initialize_from_raw_ptr(ptr: *const T, shape: &[Ix]) -> ArrayViewMut { +fn initialize_from_raw_ptr(ptr: *const T, shape: &[Ix]) -> ArrayViewMut<'_, T, IxDyn> { let array = unsafe { ArrayViewMut::from_shape_ptr(shape, ptr as *mut T) }; array } @@ -96,7 +96,7 @@ pub fn map_eps(env: rustler::env::Env, eps: Vec) -> Vec ort::execution_providers::cpu::CPUExecutionProvider::default().build(), - CUDA => ort::execution_providers::cuda::CUDAExecutionProvider::default().build(), + CUDA => ort::execution_providers::cuda::CUDAExecutionProvider::default().build().error_on_failure(), TENSORRT => { ort::execution_providers::tensorrt::TensorRTExecutionProvider::default().build() } From 53ab51d3ce2a5d6333af3bc1bcb0300c1f6d406f Mon Sep 17 00:00:00 2001 From: "morten.lund@maskon.no" Date: Mon, 28 Jul 2025 11:01:13 +0200 Subject: [PATCH 6/7] Do not error if EP is not avaliable, include more eps in cargo.toml --- native/ortex/Cargo.toml | 8 +++++++- native/ortex/src/utils.rs | 4 ++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/native/ortex/Cargo.toml b/native/ortex/Cargo.toml index f36eac2..e228700 100644 --- a/native/ortex/Cargo.toml +++ b/native/ortex/Cargo.toml @@ -15,7 +15,13 @@ resolver = "2" [dependencies] rustler = "0.36.2" ort-sys = { version = "=2.0.0-rc.10", default-features = false } -ort = { version = "2.0.0-rc.10", features = ["half", "cuda"] } +ort = { version = "2.0.0-rc.10", features = [ + "half", + "cuda", + "tensorrt", + "directml", + "coreml" +] } ndarray = "0.16.1" half = "2.6.0" tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } diff --git a/native/ortex/src/utils.rs b/native/ortex/src/utils.rs index 8b370a9..b36662c 100644 --- a/native/ortex/src/utils.rs +++ b/native/ortex/src/utils.rs @@ -7,8 +7,8 @@ use ndarray::{ArrayViewMut, Ix, IxDyn}; use ndarray::ShapeError; -use rustler::ResourceArc; use rustler::types::Binary; +use rustler::ResourceArc; use rustler::{Atom, Env, NifResult}; use ort::execution_providers::ExecutionProviderDispatch; @@ -96,7 +96,7 @@ pub fn map_eps(env: rustler::env::Env, eps: Vec) -> Vec ort::execution_providers::cpu::CPUExecutionProvider::default().build(), - CUDA => ort::execution_providers::cuda::CUDAExecutionProvider::default().build().error_on_failure(), + CUDA => ort::execution_providers::cuda::CUDAExecutionProvider::default().build(), TENSORRT => { ort::execution_providers::tensorrt::TensorRTExecutionProvider::default().build() } From dceff4118c335183e94ef0f821fba6ca6910682c Mon Sep 17 00:00:00 2001 From: "morten.lund@maskon.no" Date: Wed, 6 Aug 2025 11:10:33 +0200 Subject: [PATCH 7/7] Disable unsupported execution providers for now --- native/ortex/Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/native/ortex/Cargo.toml b/native/ortex/Cargo.toml index e228700..b09f355 100644 --- a/native/ortex/Cargo.toml +++ b/native/ortex/Cargo.toml @@ -18,9 +18,9 @@ ort-sys = { version = "=2.0.0-rc.10", default-features = false } ort = { version = "2.0.0-rc.10", features = [ "half", "cuda", - "tensorrt", - "directml", - "coreml" + # "tensorrt", + # "directml", + # "coreml" ] } ndarray = "0.16.1" half = "2.6.0"