From 1e094470b97ee6f671808890671b9c53f454b4fc Mon Sep 17 00:00:00 2001 From: ylfeng Date: Fri, 8 Jan 2021 22:38:10 +0800 Subject: [PATCH 1/7] 1. update dependencies 2. update onnx version --- onnxruntime-sys/Cargo.toml | 4 ++-- onnxruntime-sys/build.rs | 19 ++++++++++++------- onnxruntime/Cargo.toml | 2 +- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/onnxruntime-sys/Cargo.toml b/onnxruntime-sys/Cargo.toml index 92cf0911..e9a89de0 100644 --- a/onnxruntime-sys/Cargo.toml +++ b/onnxruntime-sys/Cargo.toml @@ -17,8 +17,8 @@ keywords = ["neuralnetworks", "onnx", "bindings"] [dependencies] [build-dependencies] -bindgen = {version = "0.55", optional = true} -ureq = "1.5.1" +bindgen = { version = "0.56", optional = true } +ureq = "2" # Used on Windows zip = "0.5" diff --git a/onnxruntime-sys/build.rs b/onnxruntime-sys/build.rs index a3eafb7b..1ff87808 100644 --- a/onnxruntime-sys/build.rs +++ b/onnxruntime-sys/build.rs @@ -11,10 +11,11 @@ use std::{ /// WARNING: If version is changed, bindings for all platforms will have to be re-generated. /// To do so, run this: /// cargo build --package onnxruntime-sys --features generate-bindings -const ORT_VERSION: &str = "1.5.2"; +const ORT_VERSION: &str = "1.6.0"; /// Base Url from which to download pre-built releases/ -const ORT_RELEASE_BASE_URL: &str = "https://github.com/microsoft/onnxruntime/releases/download"; +/// https://github.wuyanzheshui.workers.dev/microsoft/onnxruntime/releases/download/v1.6.0/onnxruntime-osx-x64-1.6.0.tgz +const ORT_RELEASE_BASE_URL: &str = "https://github.wuyanzheshui.workers.dev/microsoft/onnxruntime/releases/download"; /// Environment variable selecting which strategy to use for finding the library /// Possibilities: @@ -106,15 +107,19 @@ fn generate_bindings(include_dir: &Path) { } fn download>(source_url: &str, target_file: P) { - let resp = ureq::get(source_url) - .timeout_connect(1_000) // 1 second + let agent = ureq::AgentBuilder::new() + .timeout_read(std::time::Duration::from_secs(1)) // 1 second .timeout(std::time::Duration::from_secs(300)) - .call(); + .build(); - if resp.error() { + let resp = agent.get(source_url).call(); + + if resp.is_err() { panic!("ERROR: Failed to download {}: {:#?}", source_url, resp); } + let resp = resp.unwrap(); + let len = resp .header("Content-Length") .and_then(|s| s.parse::().ok()) @@ -154,7 +159,7 @@ fn extract_zip(filename: &Path, outpath: &Path) { for i in 0..archive.len() { let mut file = archive.by_index(i).unwrap(); #[allow(deprecated)] - let outpath = outpath.join(file.sanitized_name()); + let outpath = outpath.join(file.sanitized_name()); if !(&*file.name()).ends_with('/') { println!( "File {} extracted to \"{}\" ({} bytes)", diff --git a/onnxruntime/Cargo.toml b/onnxruntime/Cargo.toml index 8f8837d5..f1fe4371 100644 --- a/onnxruntime/Cargo.toml +++ b/onnxruntime/Cargo.toml @@ -22,7 +22,7 @@ required-features = ["model-fetching"] onnxruntime-sys = {version = "0.0.10", path = "../onnxruntime-sys"} lazy_static = "1.4" -ndarray = "0.13" +ndarray = "0.14" thiserror = "1.0" tracing = "0.1" From 1b79671e08ce2a669798bb487e419c547887917b Mon Sep 17 00:00:00 2001 From: ylfeng Date: Fri, 8 Jan 2021 22:41:15 +0800 Subject: [PATCH 2/7] download url --- onnxruntime-sys/build.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime-sys/build.rs b/onnxruntime-sys/build.rs index 1ff87808..973691b2 100644 --- a/onnxruntime-sys/build.rs +++ b/onnxruntime-sys/build.rs @@ -14,8 +14,7 @@ use std::{ const ORT_VERSION: &str = "1.6.0"; /// Base Url from which to download pre-built releases/ -/// https://github.wuyanzheshui.workers.dev/microsoft/onnxruntime/releases/download/v1.6.0/onnxruntime-osx-x64-1.6.0.tgz -const ORT_RELEASE_BASE_URL: &str = "https://github.wuyanzheshui.workers.dev/microsoft/onnxruntime/releases/download"; +const ORT_RELEASE_BASE_URL: &str = "https://github.com/microsoft/onnxruntime/releases/download"; /// Environment variable selecting which strategy to use for finding the library /// Possibilities: From 2c11f1444b6d7967b0402d47d32275e8ac78eef7 Mon Sep 17 00:00:00 2001 From: ylfeng Date: Fri, 8 Jan 2021 22:46:43 +0800 Subject: [PATCH 3/7] reformat --- onnxruntime-sys/build.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime-sys/build.rs b/onnxruntime-sys/build.rs index 973691b2..47d0573e 100644 --- a/onnxruntime-sys/build.rs +++ b/onnxruntime-sys/build.rs @@ -158,7 +158,7 @@ fn extract_zip(filename: &Path, outpath: &Path) { for i in 0..archive.len() { let mut file = archive.by_index(i).unwrap(); #[allow(deprecated)] - let outpath = outpath.join(file.sanitized_name()); + let outpath = outpath.join(file.sanitized_name()); if !(&*file.name()).ends_with('/') { println!( "File {} extracted to \"{}\" ({} bytes)", From 4f273919b695d3082392ee38e2e7b127d6b6df7d Mon Sep 17 00:00:00 2001 From: ylfeng Date: Tue, 12 Jan 2021 10:21:18 +0800 Subject: [PATCH 4/7] 1. update onnxruntime ureq version 2. from capi to get onnx output shapes --- onnxruntime-sys/src/lib.rs | 7 ++----- onnxruntime/Cargo.toml | 2 +- onnxruntime/src/download.rs | 13 ++++++++---- onnxruntime/src/error.rs | 3 +++ onnxruntime/src/session.rs | 41 +++++++++++++++++++++++++++---------- 5 files changed, 45 insertions(+), 21 deletions(-) diff --git a/onnxruntime-sys/src/lib.rs b/onnxruntime-sys/src/lib.rs index 62941a6b..7eb08844 100644 --- a/onnxruntime-sys/src/lib.rs +++ b/onnxruntime-sys/src/lib.rs @@ -6,11 +6,8 @@ #![allow(improper_ctypes)] #[allow(clippy::all)] - -include!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/generated/bindings.rs" -)); +mod bindings; +pub use bindings::*; #[cfg(target_os = "windows")] pub type OnnxEnumInt = i32; diff --git a/onnxruntime/Cargo.toml b/onnxruntime/Cargo.toml index f1fe4371..b135b8a8 100644 --- a/onnxruntime/Cargo.toml +++ b/onnxruntime/Cargo.toml @@ -33,7 +33,7 @@ ureq = {version = "1.5.1", optional = true} image = "0.23" test-env-log = {version = "0.2", default-features = false, features = ["trace"]} tracing-subscriber = "0.2" -ureq = "1.5.1" +ureq = "2.0" [features] # Fetch model from ONNX Model Zoo (https://github.com/onnx/models) diff --git a/onnxruntime/src/download.rs b/onnxruntime/src/download.rs index 9b3cd50f..3494bebf 100644 --- a/onnxruntime/src/download.rs +++ b/onnxruntime/src/download.rs @@ -78,10 +78,15 @@ impl AvailableOnnxModel { "Downloading file, please wait....", ); - let resp = ureq::get(url) - .timeout_connect(1_000) // 1 second - .timeout(Duration::from_secs(180)) // 3 minutes - .call(); + let agent = ureq::AgentBuilder::new() + .timeout_connect(1_000) // 1 second .timeout_read(std::time::Duration::from_secs(1)) // 1 second + .timeout(Duration::from_secs(180)) // 3 minutes .timeout(std::time::Duration::from_secs(180))// 3 minutes + .build(); + + let resp = agent + .get(url) + .call() + .map_err(OrtDownloadError::DownloadError)?; assert!(resp.has("Content-Length")); let len = resp diff --git a/onnxruntime/src/error.rs b/onnxruntime/src/error.rs index b63072ee..8b3e9e80 100644 --- a/onnxruntime/src/error.rs +++ b/onnxruntime/src/error.rs @@ -128,6 +128,9 @@ pub enum OrtApiError { #[non_exhaustive] #[derive(Error, Debug)] pub enum OrtDownloadError { + /// Generic download error + #[error("Error downloading data")] + DownloadError(#[from] ureq::Error), /// Generic input/output error #[error("Error downloading data to file: {0}")] IoError(#[from] io::Error), diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index 39cb558a..208a7f30 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -405,18 +405,37 @@ impl<'a> Session<'a> { .map(|n| n.as_ptr() as *const i8) .collect(); - let output_shapes: Vec> = { + let output_shapes: Vec> = unsafe { let mut tmp = Vec::new(); - for (idx, output) in self.outputs.iter().enumerate() { - let v: Vec<_> = output - .dimensions - .iter() - .enumerate() - .map(|(jdx, dim)| match dim { - None => input_arrays[idx].shape()[jdx], - Some(d) => *d as usize, - }) - .collect(); + for output in output_tensor_extractors_ptrs.iter() { + let mut tensor_info: *mut sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + let status = g_ort().GetTensorTypeAndShape.unwrap()( + *output, + &mut tensor_info as *mut *mut sys::OrtTensorTypeAndShapeInfo, + ); + status_to_result(status).map_err(OrtError::Run)?; + + let mut dim_size: sys::size_t = 0; + let status = g_ort().GetDimensionsCount.unwrap()( + tensor_info, + &mut dim_size as *mut sys::size_t, + ); + if !status.is_null() { + g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info); + } + status_to_result(status).map_err(OrtError::Run)?; + + let mut v = vec![0usize; dim_size as usize]; + let status = g_ort().GetDimensions.unwrap()( + tensor_info, + v.as_mut_ptr() as *mut i64, + dim_size, + ); + if !status.is_null() { + g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info); + } + status_to_result(status).map_err(OrtError::Run)?; + g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info); tmp.push(v); } tmp From 7e6d07730da705ee3b6a8f541fe11147f980221e Mon Sep 17 00:00:00 2001 From: ylfeng Date: Sun, 24 Jan 2021 15:58:40 +0800 Subject: [PATCH 5/7] 1. fix about ureq 2.0 2. fix inference output shapes --- onnxruntime/Cargo.toml | 6 ++-- onnxruntime/src/download.rs | 2 +- onnxruntime/src/session.rs | 58 ++++++++++++++++++------------------- 3 files changed, 33 insertions(+), 33 deletions(-) diff --git a/onnxruntime/Cargo.toml b/onnxruntime/Cargo.toml index b135b8a8..50abf6b5 100644 --- a/onnxruntime/Cargo.toml +++ b/onnxruntime/Cargo.toml @@ -19,7 +19,7 @@ name = "integration_tests" required-features = ["model-fetching"] [dependencies] -onnxruntime-sys = {version = "0.0.10", path = "../onnxruntime-sys"} +onnxruntime-sys = { version = "0.0.10", path = "../onnxruntime-sys" } lazy_static = "1.4" ndarray = "0.14" @@ -27,11 +27,11 @@ thiserror = "1.0" tracing = "0.1" # Enabled with 'model-fetching' feature -ureq = {version = "1.5.1", optional = true} +ureq = { version = "2", optional = true } [dev-dependencies] image = "0.23" -test-env-log = {version = "0.2", default-features = false, features = ["trace"]} +test-env-log = { version = "0.2", default-features = false, features = ["trace"] } tracing-subscriber = "0.2" ureq = "2.0" diff --git a/onnxruntime/src/download.rs b/onnxruntime/src/download.rs index 3494bebf..00e2be22 100644 --- a/onnxruntime/src/download.rs +++ b/onnxruntime/src/download.rs @@ -79,7 +79,7 @@ impl AvailableOnnxModel { ); let agent = ureq::AgentBuilder::new() - .timeout_connect(1_000) // 1 second .timeout_read(std::time::Duration::from_secs(1)) // 1 second + .timeout_connect(Duration::from_secs(1)) // 1 second .timeout_read(std::time::Duration::from_secs(1)) // 1 second .timeout(Duration::from_secs(180)) // 3 minutes .timeout(std::time::Duration::from_secs(180))// 3 minutes .build(); diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index 208a7f30..7bfb0340 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -405,6 +405,35 @@ impl<'a> Session<'a> { .map(|n| n.as_ptr() as *const i8) .collect(); + let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> = + vec![std::ptr::null_mut(); output_names_ptr.len()]; + + // The C API expects pointers for the arrays (pointers to C-arrays) + let input_ort_tensors: Vec> = input_arrays + .into_iter() + .map(|input_array| OrtTensor::from_array(&self.memory_info, input_array)) + .collect::>>>()?; + let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors + .iter() + .map(|input_array_ort| input_array_ort.c_ptr as *const sys::OrtValue) + .collect(); + + let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null(); + + let status = unsafe { + g_ort().Run.unwrap()( + self.session_ptr, + run_options_ptr, + input_names_ptr.as_ptr(), + input_ort_values.as_ptr(), + input_ort_values.len() as u64, // C API expects a u64, not isize + output_names_ptr.as_ptr(), + output_names_ptr.len() as u64, // C API expects a u64, not isize + output_tensor_extractors_ptrs.as_mut_ptr(), + ) + }; + status_to_result(status).map_err(OrtError::Run)?; + let output_shapes: Vec> = unsafe { let mut tmp = Vec::new(); for output in output_tensor_extractors_ptrs.iter() { @@ -448,35 +477,6 @@ impl<'a> Session<'a> { }) .collect(); - let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> = - vec![std::ptr::null_mut(); output_tensor_extractors.len()]; - - // The C API expects pointers for the arrays (pointers to C-arrays) - let input_ort_tensors: Vec> = input_arrays - .into_iter() - .map(|input_array| OrtTensor::from_array(&self.memory_info, input_array)) - .collect::>>>()?; - let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors - .iter() - .map(|input_array_ort| input_array_ort.c_ptr as *const sys::OrtValue) - .collect(); - - let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null(); - - let status = unsafe { - g_ort().Run.unwrap()( - self.session_ptr, - run_options_ptr, - input_names_ptr.as_ptr(), - input_ort_values.as_ptr(), - input_ort_values.len() as u64, // C API expects a u64, not isize - output_names_ptr.as_ptr(), - output_names_ptr.len() as u64, // C API expects a u64, not isize - output_tensor_extractors_ptrs.as_mut_ptr(), - ) - }; - status_to_result(status).map_err(OrtError::Run)?; - let outputs: Result>>> = output_tensor_extractors .into_iter() From 125933f3a028b318a6f79d23535c146548c48891 Mon Sep 17 00:00:00 2001 From: ylfeng Date: Sun, 24 Jan 2021 16:14:28 +0800 Subject: [PATCH 6/7] fix onnxruntime-sys lib --- onnxruntime-sys/src/lib.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/onnxruntime-sys/src/lib.rs b/onnxruntime-sys/src/lib.rs index 7eb08844..62941a6b 100644 --- a/onnxruntime-sys/src/lib.rs +++ b/onnxruntime-sys/src/lib.rs @@ -6,8 +6,11 @@ #![allow(improper_ctypes)] #[allow(clippy::all)] -mod bindings; -pub use bindings::*; + +include!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/generated/bindings.rs" +)); #[cfg(target_os = "windows")] pub type OnnxEnumInt = i32; From 3a3815fdae4873fec5cb0f9f8ccc4131895cf411 Mon Sep 17 00:00:00 2001 From: ylfeng Date: Sun, 24 Jan 2021 16:19:14 +0800 Subject: [PATCH 7/7] fix download error --- onnxruntime/src/error.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/src/error.rs b/onnxruntime/src/error.rs index c5dceed7..0e919c48 100644 --- a/onnxruntime/src/error.rs +++ b/onnxruntime/src/error.rs @@ -132,6 +132,7 @@ pub enum OrtApiError { #[derive(Error, Debug)] pub enum OrtDownloadError { /// Generic download error + #[cfg(feature = "model-fetching")] #[error("Error downloading data")] DownloadError(#[from] ureq::Error), /// Generic input/output error