diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e47f959..c29fc1b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -90,9 +90,9 @@ Go into `tests/benchmarks/test_benchmarks.py` and adjust the connection details This implies you're having a running database. Then run the benchmarks with: ```bash -python -m tox -e py312-test -- --benchmark-only --benchmark-autosave --benchmark-group-by=fullname +python -m tox -e py312-test -- --benchmark-only --benchmark-autosave # or to compare the results with the previous run -python -m tox -e py312-test -- --benchmark-only --benchmark-autosave --benchmark-group-by=fullname --benchmark-compare +python -m tox -e py312-test -- --benchmark-only --benchmark-autosave --benchmark-compare ``` ### Changelog Entry diff --git a/bin/target_driver.sh b/bin/target_driver.sh index 4b95d42..191fd93 100755 --- a/bin/target_driver.sh +++ b/bin/target_driver.sh @@ -11,7 +11,8 @@ git fetch origin git checkout "$version" git pull origin "$version" cd .. -cp driver/tests/unit/common/codec/packstream/v1/test_packstream.py tests/v1/from_driver/test_packstream.py +cp driver/tests/unit/common/codec/packstream/v1/test_packstream.py tests/codec/packstream/v1/from_driver/test_packstream.py +cp -r driver/tests/unit/common/vector/* tests/vector/from_driver towncrier create -c "Target driver version ${version}." "+.feature" echo "=== Please rename the changelog file to match the PR number. ===" diff --git a/changelog.d/45.feature.md b/changelog.d/45.feature.md new file mode 100644 index 0000000..5dc48a6 --- /dev/null +++ b/changelog.d/45.feature.md @@ -0,0 +1,3 @@ +Add extension for the `Vector` type. +* Speed up endian conversion (byte flipping). +* Speed up conversion from and to native python types. 
diff --git a/neo4j/_codec/packstream/.keep b/neo4j/.keep similarity index 100% rename from neo4j/_codec/packstream/.keep rename to neo4j/.keep diff --git a/pyproject.toml b/pyproject.toml index 9276377..5f7ad5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,7 @@ build-backend = "maturin" [tool.maturin] features = ["pyo3/extension-module", "pyo3/generate-import-lib"] -module-name = "neo4j._codec.packstream._rust" +module-name = "neo4j._rust" exclude = [ "/.editorconfig", ".gitignore", diff --git a/requirements-dev.txt b/requirements-dev.txt index 6f80a8c..f6717b4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,6 +8,7 @@ isort>=6.0.1 tox>=4.25.0 pytest>=8.3.5 pytest-benchmark>=5.1.0 +pytest-mock>=3.14.1 # for Python driver's TestKit backend freezegun>=1.5.1 diff --git a/src/codec.rs b/src/codec.rs new file mode 100644 index 0000000..276fb44 --- /dev/null +++ b/src/codec.rs @@ -0,0 +1,33 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+mod packstream;
+
+use pyo3::prelude::*;
+
+use crate::register_package;
+
+/// Initialise the `codec` sub-module of the extension.
+///
+/// Marks `m` as not relying on the GIL, registers it as a package under
+/// `name` via [`register_package`], then creates the `packstream` child
+/// module, attaches it to `m`, and initialises it under the dotted name
+/// `{name}.packstream`.
+///
+/// # Errors
+///
+/// Propagates any Python error raised while creating, attaching, or
+/// registering the modules.
+pub(super) fn init_module(m: &Bound<PyModule>, name: &str) -> PyResult<()> {
+    let py = m.py();
+
+    m.gil_used(false)?;
+    register_package(m, name)?;
+
+    let mod_packstream = PyModule::new(py, "packstream")?;
+    m.add_submodule(&mod_packstream)?;
+    packstream::init_module(&mod_packstream, format!("{name}.packstream").as_str())?;
+
+    Ok(())
+}
diff --git a/src/codec/packstream.rs b/src/codec/packstream.rs
new file mode 100644
index 0000000..bcfe97d
--- /dev/null
+++ b/src/codec/packstream.rs
@@ -0,0 +1,104 @@
+// Copyright (c) "Neo4j"
+// Neo4j Sweden AB [https://neo4j.com]
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+mod v1;
+
+use pyo3::basic::CompareOp;
+use pyo3::exceptions::PyValueError;
+use pyo3::prelude::*;
+use pyo3::types::{PyBytes, PyTuple};
+use pyo3::IntoPyObjectExt;
+
+use crate::register_package;
+
+/// Initialise the `packstream` sub-module.
+///
+/// Marks `m` as not relying on the GIL, registers it as a package under
+/// `name`, attaches and initialises the `v1` child module as `{name}.v1`,
+/// and exposes the [`Structure`] class on `m`.
+///
+/// # Errors
+///
+/// Propagates any Python error raised while building the module tree.
+pub(super) fn init_module(m: &Bound<PyModule>, name: &str) -> PyResult<()> {
+    let py = m.py();
+
+    m.gil_used(false)?;
+    register_package(m, name)?;
+
+    let mod_v1 = PyModule::new(py, "v1")?;
+    m.add_submodule(&mod_v1)?;
+    v1::init_module(&mod_v1, format!("{name}.v1").as_str())?;
+
+    m.add_class::<Structure>()?;
+
+    Ok(())
+}
+
+/// A tagged structure: a one-byte tag plus an ordered list of Python
+/// objects as fields.
+#[pyclass]
+#[derive(Debug)]
+pub struct Structure {
+    tag: u8,
+    #[pyo3(get)]
+    fields: Vec<PyObject>,
+}
+
+#[pymethods]
+impl Structure {
+    /// Build a structure from a one-byte `tag` and any number of fields.
+    ///
+    /// Raises `ValueError` when `tag` is not exactly one byte long.
+    #[new]
+    #[pyo3(signature = (tag, *fields))]
+    #[pyo3(text_signature = "(tag, *fields)")]
+    fn new(tag: &[u8], fields: Vec<PyObject>) -> PyResult<Self> {
+        if tag.len() != 1 {
+            return Err(PyErr::new::<PyValueError, _>("tag must be a single byte"));
+        }
+        let tag = tag[0];
+        Ok(Self { tag, fields })
+    }
+
+    /// The tag as a `bytes` object of length one.
+    #[getter(tag)]
+    fn read_tag<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
+        PyBytes::new(py, &[self.tag])
+    }
+
+    /// The fields as a tuple.
+    #[getter(fields)]
+    fn read_fields<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
+        PyTuple::new(py, &self.fields)
+    }
+
+    /// Field-wise equality: same tag, same field count, and every pair of
+    /// fields compares equal using Python's `==`.
+    fn eq(&self, other: &Self, py: Python<'_>) -> PyResult<bool> {
+        if self.tag != other.tag || self.fields.len() != other.fields.len() {
+            return Ok(false);
+        }
+        for (a, b) in self
+            .fields
+            .iter()
+            .map(|e| e.bind(py))
+            .zip(other.fields.iter().map(|e| e.bind(py)))
+        {
+            if !a.eq(b)? {
+                return Ok(false);
+            }
+        }
+        Ok(true)
+    }
+
+    /// Supports `==` and `!=` only; every other comparison yields
+    /// `NotImplemented`.
+    fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyResult<PyObject> {
+        Ok(match op {
+            CompareOp::Eq => self.eq(other, py)?.into_py_any(py)?,
+            CompareOp::Ne => (!self.eq(other, py)?).into_py_any(py)?,
+            _ => py.NotImplemented(),
+        })
+    }
+
+    /// Hash derived from the hashes of all fields and the tag.
+    ///
+    /// Propagates the Python error if any field is unhashable.
+    fn __hash__(&self, py: Python<'_>) -> PyResult<isize> {
+        let mut fields_hash: isize = 0;
+        for field in &self.fields {
+            // Wrapping on purpose: summing arbitrary hashes routinely
+            // exceeds the isize range, and a plain `+=` would panic in
+            // builds with overflow checks enabled.
+            fields_hash = fields_hash.wrapping_add(field.bind(py).hash()?);
+        }
+        Ok(fields_hash.wrapping_add(self.tag.into()))
+    }
+}
diff --git a/src/v1.rs b/src/codec/packstream/v1.rs
similarity index 89%
rename from src/v1.rs
rename to src/codec/packstream/v1.rs
index 674e3fd..327f332 100644
--- a/src/v1.rs
+++ b/src/codec/packstream/v1.rs
@@ -19,6 +19,8 @@ mod unpack;
 use pyo3::prelude::*;
 use pyo3::wrap_pyfunction;
 
+use crate::register_package;
+
 const TINY_STRING: u8 = 0x80;
 const TINY_LIST: u8 = 0x90;
 const TINY_MAP: u8 = 0xA0;
@@ -44,7 +46,10 @@ const BYTES_8: u8 = 0xCC;
 const BYTES_16: u8 = 0xCD;
 const BYTES_32: u8 = 0xCE;
 
-pub(crate) fn register(m: &Bound<PyModule>) -> PyResult<()> {
+pub(crate) fn init_module(m: &Bound<PyModule>, name: &str) -> PyResult<()> {
+    m.gil_used(false)?;
+    register_package(m, name)?;
+
     m.add_function(wrap_pyfunction!(unpack::unpack, m)?)?;
     m.add_function(wrap_pyfunction!(pack::pack, m)?)?;
 
diff --git a/src/v1/pack.rs b/src/codec/packstream/v1/pack.rs
similarity index 99%
rename from src/v1/pack.rs
rename to src/codec/packstream/v1/pack.rs
index 98d709e..5f89b6a 100644
--- a/src/v1/pack.rs
+++ b/src/codec/packstream/v1/pack.rs
@@ -22,12 +22,12 @@ use pyo3::sync::GILOnceCell;
 use pyo3::types::{PyBytes, PyDict, PyString, PyType};
 use pyo3::{intern, IntoPyObjectExt};
 
+use super::super::Structure;
 use super::{
     BYTES_16, BYTES_32, BYTES_8, FALSE, FLOAT_64, INT_16, INT_32, INT_64, INT_8, LIST_16, LIST_32,
     LIST_8, MAP_16, MAP_32, MAP_8, NULL, STRING_16, STRING_32, STRING_8, TINY_LIST, TINY_MAP,
     TINY_STRING, TINY_STRUCT, TRUE,
 };
-use crate::Structure;
 
 #[derive(Debug)]
 struct
TypeMappings { diff --git a/src/v1/unpack.rs b/src/codec/packstream/v1/unpack.rs similarity index 99% rename from src/v1/unpack.rs rename to src/codec/packstream/v1/unpack.rs index a258ed9..d92cbcc 100644 --- a/src/v1/unpack.rs +++ b/src/codec/packstream/v1/unpack.rs @@ -19,12 +19,12 @@ use pyo3::sync::with_critical_section; use pyo3::types::{IntoPyDict, PyByteArray, PyBytes, PyDict, PyList, PyTuple}; use pyo3::{intern, IntoPyObjectExt}; +use super::super::Structure; use super::{ BYTES_16, BYTES_32, BYTES_8, FALSE, FLOAT_64, INT_16, INT_32, INT_64, INT_8, LIST_16, LIST_32, LIST_8, MAP_16, MAP_32, MAP_8, NULL, STRING_16, STRING_32, STRING_8, TINY_LIST, TINY_MAP, TINY_STRING, TINY_STRUCT, TRUE, }; -use crate::Structure; #[pyfunction] #[pyo3(signature = (bytes, idx, hydration_hooks=None))] diff --git a/src/lib.rs b/src/lib.rs index 5030228..8a10621 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,26 +13,23 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-pub mod v1; +mod codec; +mod vector; -use pyo3::basic::CompareOp; -use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use pyo3::types::{PyBytes, PyTuple}; -use pyo3::IntoPyObjectExt; #[pymodule(gil_used = false)] #[pyo3(name = "_rust")] -fn packstream(m: &Bound) -> PyResult<()> { +fn init_module(m: &Bound) -> PyResult<()> { let py = m.py(); - m.add_class::()?; + let mod_codec = PyModule::new(py, "codec")?; + m.add_submodule(&mod_codec)?; + codec::init_module(&mod_codec, "codec")?; - let mod_v1 = PyModule::new(py, "v1")?; - mod_v1.gil_used(false)?; - v1::register(&mod_v1)?; - m.add_submodule(&mod_v1)?; - register_package(&mod_v1, "v1")?; + let mod_vector = PyModule::new(py, "vector")?; + m.add_submodule(&mod_vector)?; + vector::init_module(&mod_vector, "vector")?; Ok(()) } @@ -41,7 +38,7 @@ fn packstream(m: &Bound) -> PyResult<()> { // https://github.com/PyO3/pyo3/issues/1517#issuecomment-808664021 fn register_package(m: &Bound, name: &str) -> PyResult<()> { let py = m.py(); - let module_name = format!("neo4j._codec.packstream._rust.{name}").into_pyobject(py)?; + let module_name = format!("neo4j._rust.{name}").into_pyobject(py)?; py.import("sys")? .getattr("modules")? 
@@ -50,68 +47,3 @@ fn register_package(m: &Bound, name: &str) -> PyResult<()> { Ok(()) } - -#[pyclass] -#[derive(Debug)] -pub struct Structure { - tag: u8, - #[pyo3(get)] - fields: Vec, -} - -#[pymethods] -impl Structure { - #[new] - #[pyo3(signature = (tag, *fields))] - #[pyo3(text_signature = "(tag, *fields)")] - fn new(tag: &[u8], fields: Vec) -> PyResult { - if tag.len() != 1 { - return Err(PyErr::new::("tag must be a single byte")); - } - let tag = tag[0]; - Ok(Self { tag, fields }) - } - - #[getter(tag)] - fn read_tag<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> { - PyBytes::new(py, &[self.tag]) - } - - #[getter(fields)] - fn read_fields<'py>(&self, py: Python<'py>) -> PyResult> { - PyTuple::new(py, &self.fields) - } - - fn eq(&self, other: &Self, py: Python<'_>) -> PyResult { - if self.tag != other.tag || self.fields.len() != other.fields.len() { - return Ok(false); - } - for (a, b) in self - .fields - .iter() - .map(|e| e.bind(py)) - .zip(other.fields.iter().map(|e| e.bind(py))) - { - if !a.eq(b)? { - return Ok(false); - } - } - Ok(true) - } - - fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyResult { - Ok(match op { - CompareOp::Eq => self.eq(other, py)?.into_py_any(py)?, - CompareOp::Ne => (!self.eq(other, py)?).into_py_any(py)?, - _ => py.NotImplemented(), - }) - } - - fn __hash__(&self, py: Python<'_>) -> PyResult { - let mut fields_hash = 0; - for field in &self.fields { - fields_hash += field.bind(py).hash()?; - } - Ok(fields_hash.wrapping_add(self.tag.into())) - } -} diff --git a/src/vector.rs b/src/vector.rs new file mode 100644 index 0000000..ee453a5 --- /dev/null +++ b/src/vector.rs @@ -0,0 +1,41 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+mod native_conversion;
+mod swap_endian;
+
+use crate::register_package;
+use pyo3::prelude::*;
+
+/// Initialise the `vector` sub-module.
+///
+/// Marks `m` as not relying on the GIL, registers it as a package under
+/// `name`, and exposes `swap_endian` plus the `vec_<type>_from_native` /
+/// `vec_<type>_to_native` converters for `f64`, `f32`, `i64`, `i32`, `i16`
+/// and `i8`.
+///
+/// # Errors
+///
+/// Propagates any Python error raised while registering the module or
+/// adding a function to it.
+pub(super) fn init_module(m: &Bound<PyModule>, name: &str) -> PyResult<()> {
+    m.gil_used(false)?;
+    register_package(m, name)?;
+
+    m.add_function(wrap_pyfunction!(swap_endian::swap_endian, m)?)?;
+    m.add_function(wrap_pyfunction!(native_conversion::vec_f64_from_native, m)?)?;
+    m.add_function(wrap_pyfunction!(native_conversion::vec_f64_to_native, m)?)?;
+    m.add_function(wrap_pyfunction!(native_conversion::vec_f32_from_native, m)?)?;
+    m.add_function(wrap_pyfunction!(native_conversion::vec_f32_to_native, m)?)?;
+    m.add_function(wrap_pyfunction!(native_conversion::vec_i64_from_native, m)?)?;
+    m.add_function(wrap_pyfunction!(native_conversion::vec_i64_to_native, m)?)?;
+    m.add_function(wrap_pyfunction!(native_conversion::vec_i32_from_native, m)?)?;
+    m.add_function(wrap_pyfunction!(native_conversion::vec_i32_to_native, m)?)?;
+    m.add_function(wrap_pyfunction!(native_conversion::vec_i16_from_native, m)?)?;
+    m.add_function(wrap_pyfunction!(native_conversion::vec_i16_to_native, m)?)?;
+    m.add_function(wrap_pyfunction!(native_conversion::vec_i8_from_native, m)?)?;
+    m.add_function(wrap_pyfunction!(native_conversion::vec_i8_to_native, m)?)?;
+
+    Ok(())
+}
diff --git a/src/vector/native_conversion.rs b/src/vector/native_conversion.rs
new file mode 100644
index 0000000..278d347
--- /dev/null
+++ b/src/vector/native_conversion.rs
@@ -0,0 +1,332 @@
+// Copyright (c) "Neo4j"
+// Neo4j Sweden AB [https://neo4j.com]
+//
+// Licensed under the Apache License, Version
2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use pyo3::exceptions::{PyOverflowError, PyTypeError}; +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyFloat, PyInt, PyList}; + +// ================= +// ====== F64 ====== +// ================= + +#[pyfunction] +pub(super) fn vec_f64_from_native<'py>(data: Bound<'py, PyAny>) -> PyResult> { + let py = data.py(); + + let data_iter = data.try_iter()?; + let mut bytes = Vec::with_capacity(data_iter.size_hint().0.saturating_mul(size_of::())); + for value in data_iter { + let value = vec_value_as_f64(value?)?; + bytes.extend(&f64::to_be_bytes(value)); + } + Ok(PyBytes::new(py, &bytes)) +} + +fn vec_value_as_f64(value: Bound) -> PyResult { + fn make_error(value: &Bound) -> PyResult { + Err(PyErr::new::(format!( + "Cannot convert value to f64, expected float, got {}.", + value.get_type().name()? + ))) + } + + value + .downcast::() + .or_else(|_| make_error(&value))? 
+ .extract() + .or_else(|_| make_error(&value)) +} + +#[pyfunction] +pub(super) fn vec_f64_to_native<'py>(data: Bound<'py, PyBytes>) -> PyResult> { + const DATA_SIZE: usize = size_of::(); + let py = data.py(); + PyList::new( + py, + data.as_bytes().chunks(DATA_SIZE).map(|chunk| { + let value = f64::from_be_bytes( + chunk + .try_into() + .expect("bytes size is not multiple of type size"), + ); + PyFloat::new(py, value) + }), + ) +} + +// ================= +// ====== F32 ====== +// ================= + +#[pyfunction] +pub(super) fn vec_f32_from_native<'py>(data: Bound<'py, PyAny>) -> PyResult> { + let py = data.py(); + + let data_iter = data.try_iter()?; + let mut bytes = Vec::with_capacity(data_iter.size_hint().0.saturating_mul(size_of::())); + for value in data_iter { + let value = vec_value_as_f32(value?)?; + bytes.extend(&f32::to_be_bytes(value)); + } + Ok(PyBytes::new(py, &bytes)) +} + +fn vec_value_as_f32(value: Bound) -> PyResult { + fn make_error(value: &Bound) -> PyResult { + Err(PyErr::new::(format!( + "Cannot convert value to f32, expected float, got {}.", + value.get_type().name()? + ))) + } + + value + .downcast::() + .or_else(|_| make_error(&value))? 
+ .extract() + .or_else(|_| make_error(&value)) +} + +#[pyfunction] +pub(super) fn vec_f32_to_native<'py>(data: Bound<'py, PyBytes>) -> PyResult> { + const DATA_SIZE: usize = size_of::(); + let py = data.py(); + PyList::new( + py, + data.as_bytes().chunks(DATA_SIZE).map(|chunk| { + let value = f32::from_be_bytes( + chunk + .try_into() + .expect("bytes size is not multiple of type size"), + ); + PyFloat::new(py, value.into()) + }), + ) +} + +// ================= +// ====== I64 ====== +// ================= + +#[pyfunction] +pub(super) fn vec_i64_from_native<'py>(data: Bound<'py, PyAny>) -> PyResult> { + let py = data.py(); + + let data_iter = data.try_iter()?; + let mut bytes = Vec::with_capacity(data_iter.size_hint().0.saturating_mul(size_of::())); + for value in data_iter { + let value = vec_value_as_i64(value?)?; + bytes.extend(&i64::to_be_bytes(value)); + } + Ok(PyBytes::new(py, &bytes)) +} + +fn vec_value_as_i64(value: Bound) -> PyResult { + fn make_error(value: &Bound) -> PyResult { + Err(PyErr::new::(format!( + "Cannot convert value to i64, expected int, got {}.", + value.get_type().name()? + ))) + } + + let py = value.py(); + + let value = value.downcast::().or_else(|_| make_error(&value))?; + if value.lt(PyInt::new(py, i64::MIN))? || value.gt(PyInt::new(py, i64::MAX))? { + return Err(PyErr::new::(format!( + "Value {} is out of range for i64: [-9223372036854775808, 9223372036854775807]", + value.str()? 
+ ))); + } + value.extract().or_else(|_| make_error(value)) +} + +#[pyfunction] +pub(super) fn vec_i64_to_native<'py>(data: Bound<'py, PyBytes>) -> PyResult> { + const DATA_SIZE: usize = size_of::(); + let py = data.py(); + PyList::new( + py, + data.as_bytes().chunks(DATA_SIZE).map(|chunk| { + let value = i64::from_be_bytes( + chunk + .try_into() + .expect("bytes size is not multiple of type size"), + ); + PyInt::new(py, value) + }), + ) +} + +// ================= +// ====== I32 ====== +// ================= + +#[pyfunction] +pub(super) fn vec_i32_from_native<'py>(data: Bound<'py, PyAny>) -> PyResult> { + let py = data.py(); + + let data_iter = data.try_iter()?; + let mut bytes = Vec::with_capacity(data_iter.size_hint().0.saturating_mul(size_of::())); + for value in data_iter { + let value = vec_value_as_i32(value?)?; + bytes.extend(&i32::to_be_bytes(value)); + } + Ok(PyBytes::new(py, &bytes)) +} + +fn vec_value_as_i32(value: Bound) -> PyResult { + fn make_error(value: &Bound) -> PyResult { + Err(PyErr::new::(format!( + "Cannot convert value to i32, expected int, got {}.", + value.get_type().name()? + ))) + } + + let py = value.py(); + + let value = value.downcast::().or_else(|_| make_error(&value))?; + if value.lt(PyInt::new(py, i32::MIN))? || value.gt(PyInt::new(py, i32::MAX))? { + return Err(PyErr::new::(format!( + "Value {} is out of range for i32: [-2147483648, 2147483647]", + value.str()? 
+ ))); + } + value.extract().or_else(|_| make_error(value)) +} + +#[pyfunction] +pub(super) fn vec_i32_to_native<'py>(data: Bound<'py, PyBytes>) -> PyResult> { + const DATA_SIZE: usize = size_of::(); + let py = data.py(); + PyList::new( + py, + data.as_bytes().chunks(DATA_SIZE).map(|chunk| { + let value = i32::from_be_bytes( + chunk + .try_into() + .expect("bytes size is not multiple of type size"), + ); + PyInt::new(py, value) + }), + ) +} + +// ================= +// ====== I16 ====== +// ================= + +#[pyfunction] +pub(super) fn vec_i16_from_native<'py>(data: Bound<'py, PyAny>) -> PyResult> { + let py = data.py(); + + let data_iter = data.try_iter()?; + let mut bytes = Vec::with_capacity(data_iter.size_hint().0.saturating_mul(size_of::())); + for value in data_iter { + let value = vec_value_as_i16(value?)?; + bytes.extend(&i16::to_be_bytes(value)); + } + Ok(PyBytes::new(py, &bytes)) +} + +fn vec_value_as_i16(value: Bound) -> PyResult { + fn make_error(value: &Bound) -> PyResult { + Err(PyErr::new::(format!( + "Cannot convert value to i16, expected int, got {}.", + value.get_type().name()? + ))) + } + + let py = value.py(); + + let value = value.downcast::().or_else(|_| make_error(&value))?; + if value.lt(PyInt::new(py, i16::MIN))? || value.gt(PyInt::new(py, i16::MAX))? { + return Err(PyErr::new::(format!( + "Value {} is out of range for i16: [-32768, 32767]", + value.str()? 
+ ))); + } + value.extract().or_else(|_| make_error(value)) +} + +#[pyfunction] +pub(super) fn vec_i16_to_native<'py>(data: Bound<'py, PyBytes>) -> PyResult> { + const DATA_SIZE: usize = size_of::(); + let py = data.py(); + PyList::new( + py, + data.as_bytes().chunks(DATA_SIZE).map(|chunk| { + let value = i16::from_be_bytes( + chunk + .try_into() + .expect("bytes size is not multiple of type size"), + ); + PyInt::new(py, value) + }), + ) +} + +// ================ +// ====== I8 ====== +// ================ + +#[pyfunction] +pub(super) fn vec_i8_from_native<'py>(data: Bound<'py, PyAny>) -> PyResult> { + let py = data.py(); + + let data_iter = data.try_iter()?; + let mut bytes = Vec::with_capacity(data_iter.size_hint().0.saturating_mul(size_of::())); + for value in data_iter { + let value = vec_value_as_i8(value?)?; + bytes.extend(&i8::to_be_bytes(value)); + } + Ok(PyBytes::new(py, &bytes)) +} + +fn vec_value_as_i8(value: Bound) -> PyResult { + fn make_error(value: &Bound) -> PyResult { + Err(PyErr::new::(format!( + "Cannot convert value to i8, expected int, got {}.", + value.get_type().name()? + ))) + } + + let py = value.py(); + + let value = value.downcast::().or_else(|_| make_error(&value))?; + if value.lt(PyInt::new(py, i8::MIN))? || value.gt(PyInt::new(py, i8::MAX))? { + return Err(PyErr::new::(format!( + "Value {} is out of range for i8: [-128, 127]", + value.str()? 
+ ))); + } + value.extract().or_else(|_| make_error(value)) +} + +#[pyfunction] +pub(super) fn vec_i8_to_native<'py>(data: Bound<'py, PyBytes>) -> PyResult> { + const DATA_SIZE: usize = size_of::(); + let py = data.py(); + PyList::new( + py, + data.as_bytes().chunks(DATA_SIZE).map(|chunk| { + let value = i8::from_be_bytes( + chunk + .try_into() + .expect("bytes size is not multiple of type size"), + ); + PyInt::new(py, value) + }), + ) +} diff --git a/src/vector/swap_endian.rs b/src/vector/swap_endian.rs new file mode 100644 index 0000000..511d652 --- /dev/null +++ b/src/vector/swap_endian.rs @@ -0,0 +1,66 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+use pyo3::exceptions::PyValueError;
+use pyo3::prelude::*;
+use pyo3::types::{PyBytes, PyInt};
+use pyo3::{pyfunction, Bound, PyErr, PyResult};
+
+/// Return a copy of `data` in which every `type_size`-byte element has its
+/// byte order reversed.
+///
+/// Raises `ValueError` when `type_size` is not 2, 4 or 8, or when the length
+/// of `data` is not a multiple of `type_size`.
+#[pyfunction]
+pub(super) fn swap_endian<'py>(
+    type_size: Bound<'py, PyInt>,
+    data: Bound<'py, PyBytes>,
+) -> PyResult<Bound<'py, PyBytes>> {
+    let py = type_size.py();
+
+    let type_size: usize = match type_size.extract::<usize>() {
+        Ok(size @ (2 | 4 | 8)) => size,
+        // Also covers ints that do not fit into usize (e.g. negatives).
+        _ => {
+            return Err(PyErr::new::<PyValueError, _>(format!(
+                "Unsupported type size {type_size}",
+            )))
+        }
+    };
+    let bytes = data.as_bytes();
+    let len = bytes.len();
+    if len % type_size != 0 {
+        return Err(PyErr::new::<PyValueError, _>(
+            "Data length not a multiple of type_size",
+        ));
+    }
+
+    PyBytes::new_with(py, len, |out| {
+        match type_size {
+            2 => swap_n::<2>(bytes, out),
+            4 => swap_n::<4>(bytes, out),
+            8 => swap_n::<8>(bytes, out),
+            // type_size was validated to be 2, 4 or 8 above.
+            _ => unreachable!(),
+        }
+        Ok(())
+    })
+}
+
+/// Copy `src` into `dst`, reversing the bytes of each `N`-byte element.
+///
+/// Panics if the slices differ in length or the length is not a multiple of
+/// `N`.
+#[inline]
+fn swap_n<const N: usize>(src: &[u8], dst: &mut [u8]) {
+    // Doesn't technically need to be a function with a const generic, but this
+    // allows the compiler to optimize the code better.
+    assert_eq!(src.len(), dst.len());
+    assert_eq!(src.len() % N, 0);
+    for (src_chunk, dst_chunk) in src.chunks_exact(N).zip(dst.chunks_exact_mut(N)) {
+        dst_chunk.copy_from_slice(src_chunk);
+        dst_chunk.reverse();
+    }
+}
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..3f96809
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py new file mode 100644 index 0000000..3f96809 --- /dev/null +++ b/tests/benchmarks/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/benchmarks/test_benchmarks.py b/tests/benchmarks/test_macro_benchmarks.py similarity index 100% rename from tests/benchmarks/test_benchmarks.py rename to tests/benchmarks/test_macro_benchmarks.py diff --git a/tests/benchmarks/test_vector_benchmarks.py b/tests/benchmarks/test_vector_benchmarks.py new file mode 100644 index 0000000..c2cd6cc --- /dev/null +++ b/tests/benchmarks/test_vector_benchmarks.py @@ -0,0 +1,68 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from __future__ import annotations + +import pytest + +from ..vector.from_driver.test_vector import ( + _mock_mask_extensions, + _swap_endian, + Vector, +) + + +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +@pytest.mark.parametrize("type_size", (2, 4, 8)) +@pytest.mark.parametrize("length", (1, 100_000)) +def test_bench_swap_endian(benchmark, mocker, ext, type_size, length): + data = bytes(i % 256 for i in range(8 * length)) + _mock_mask_extensions(mocker, ext) + rounds = max(min(1_000_000 // length, 100_000), 100) + + benchmark.pedantic(lambda: _swap_endian(type_size, data), rounds=rounds) + + +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +@pytest.mark.parametrize("dtype", ("i8", "i16", "i32", "i64", "f32", "f64")) +@pytest.mark.parametrize("as_gen", (True, False)) +@pytest.mark.parametrize("length", (1, 1_000)) +def test_bench_from_native(benchmark, mocker, ext, dtype, as_gen, length): + raw_data = bytes(i % 256 for i in range(8 * length)) + data = Vector.from_bytes(raw_data, dtype).to_native() + rounds = max(min(1_000_000 // length, 100_000), 100) + _mock_mask_extensions(mocker, ext) + if as_gen: + + def work(): + Vector.from_native((x for x in data), dtype) + else: + + def work(): + Vector.from_native(data, dtype) + + benchmark.pedantic(work, rounds=rounds) + + +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +@pytest.mark.parametrize("dtype", ("i8", "i16", "i32", "i64", "f32", "f64")) +@pytest.mark.parametrize("length", (1, 1_000)) +def test_bench_to_native(benchmark, mocker, ext, dtype, length): + data = Vector.from_bytes(bytes(i % 256 for i in range(8 * length)), dtype) + rounds = max(min(1_000_000 // length, 100_000), 100) + _mock_mask_extensions(mocker, ext) + + benchmark.pedantic(data.to_native, rounds=rounds) diff --git a/tests/codec/__init__.py b/tests/codec/__init__.py new file mode 100644 index 0000000..3f96809 --- /dev/null +++ b/tests/codec/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# 
Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/codec/packstream/__init__.py b/tests/codec/packstream/__init__.py new file mode 100644 index 0000000..3f96809 --- /dev/null +++ b/tests/codec/packstream/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/codec/packstream/v1/__init__.py b/tests/codec/packstream/v1/__init__.py new file mode 100644 index 0000000..3f96809 --- /dev/null +++ b/tests/codec/packstream/v1/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/codec/packstream/v1/from_driver/__init__.py b/tests/codec/packstream/v1/from_driver/__init__.py new file mode 100644 index 0000000..3f96809 --- /dev/null +++ b/tests/codec/packstream/v1/from_driver/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/v1/from_driver/test_packstream.py b/tests/codec/packstream/v1/from_driver/test_packstream.py similarity index 100% rename from tests/v1/from_driver/test_packstream.py rename to tests/codec/packstream/v1/from_driver/test_packstream.py diff --git a/tests/v1/test_injection.py b/tests/codec/packstream/v1/test_injection.py similarity index 88% rename from tests/v1/test_injection.py rename to tests/codec/packstream/v1/test_injection.py index bbb56d6..6869af9 100644 --- a/tests/v1/test_injection.py +++ b/tests/codec/packstream/v1/test_injection.py @@ -97,21 +97,24 @@ def raise_test_exception(*args, **kwargs): @pytest.mark.parametrize( - ("name", "package_names"), + ("name", "submodule_names"), ( - ("neo4j._codec.packstream._rust.v1", ()), - ("neo4j._codec.packstream._rust", ("v1",)), - ("neo4j._codec.packstream", ("_rust",)), + # packstream v1 + ("neo4j._rust.codec.packstream.v1", ()), + ("neo4j._rust.codec.packstream", ("v1",)), + ("neo4j._rust.codec", ("packstream",)), + ("neo4j._rust", ("codec",)), + ("neo4j", ("_rust",)), ), ) -def test_import_module(name, package_names): +def test_import_module(name, submodule_names): module = importlib.import_module(name) assert module.__name__ == name - for package_name in package_names: - package = getattr(module, package_name) - assert package.__name__ == f"{name}.{package_name}" + for submodule_name in submodule_names: + package = getattr(module, submodule_name) + assert package.__name__ == f"{name}.{submodule_name}" def test_rust_struct_access(): diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..ab6a346 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,52 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections import defaultdict + +import pytest + + +@pytest.hookimpl(wrapper=True) +def pytest_benchmark_group_stats(config, benchmarks, group_by): + outcome = yield + + if group_by != "group": + # not default grouping, so let the user have what they asked for + return outcome + + result = defaultdict(list) + for bench in benchmarks: + param_start = bench["fullname"].rfind("[") + if param_start < 0: + base_name = bench["fullname"] + else: + base_name = bench["fullname"][:param_start] + params = bench.get("params", None) + if params is None: + result[base_name].append(bench) + continue + ext = params.get("ext", None) + if ext is None: + result[base_name].append(bench) + continue + param_keys = sorted(params.keys()) + name_params = "-".join( + str(params[k]) for k in param_keys if k != "ext" + ) + group_name = f"{base_name}[{name_params}]" + result[group_name].append(bench) + + return result.items() diff --git a/tests/vector/__init__.py b/tests/vector/__init__.py new file mode 100644 index 0000000..3f96809 --- /dev/null +++ b/tests/vector/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/vector/from_driver/__init__.py b/tests/vector/from_driver/__init__.py new file mode 100644 index 0000000..3f96809 --- /dev/null +++ b/tests/vector/from_driver/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/vector/from_driver/test_import_vector.py b/tests/vector/from_driver/test_import_vector.py new file mode 100644 index 0000000..8be94f1 --- /dev/null +++ b/tests/vector/from_driver/test_import_vector.py @@ -0,0 +1,73 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import importlib + +import pytest + + +MODULE_PATH = "neo4j.vector" +VECTOR_ATTRIBUTES = ( + # (name, warning) + ("Vector", None), + ("VectorDType", None), + ("VectorEndian", None), +) + + +def _get_module(): + module = importlib.__import__(MODULE_PATH) + for submodule in MODULE_PATH.split(".")[1:]: + module = getattr(module, submodule) + return module + + +@pytest.mark.parametrize(("name", "warning"), VECTOR_ATTRIBUTES) +def test_attribute_import(name, warning): + module = _get_module() + if warning: + with pytest.warns(warning): + getattr(module, name) + else: + getattr(module, name) + + +@pytest.mark.parametrize(("name", "warning"), VECTOR_ATTRIBUTES) +def test_attribute_from_import(name, warning): + if warning: + with pytest.warns(warning): + importlib.__import__(MODULE_PATH, fromlist=(name,)) + else: + importlib.__import__(MODULE_PATH, fromlist=(name,)) + + +def test_all(): + module = _get_module() + + assert sorted(module.__all__) == sorted([i[0] for i in VECTOR_ATTRIBUTES]) + + +def test_dir(): + module = _get_module() + + dir_attrs = (attr for attr in dir(module) if not attr.startswith("_")) + assert sorted(dir_attrs) == sorted([i[0] for i in VECTOR_ATTRIBUTES]) + + +def test_import_star(): + # ignore PT029: purposefully capturing all warnings to then apply further + # checks on them + importlib.__import__(MODULE_PATH, fromlist=("*",)) diff --git a/tests/vector/from_driver/test_vector.py b/tests/vector/from_driver/test_vector.py new file mode 100644 index 0000000..72077ae --- /dev/null +++ b/tests/vector/from_driver/test_vector.py @@ -0,0 +1,809 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import math +import random +import struct +import sys +import typing as t + +import pytest + +from neo4j._optional_deps import ( + np, + pa, +) +from neo4j.vector import ( + _swap_endian, + Vector, +) + + +if t.TYPE_CHECKING: + import numpy + import pyarrow + + +def _max_value_be_bytes(size: t.Literal[1, 2, 4, 8], count: int = 1) -> bytes: + def generator(count_: int) -> t.Iterable[int]: + pack_format = { + 1: ">b", + 2: ">h", + 4: ">i", + 8: ">q", + }[size] + if count_ <= 0: + return + yield from struct.pack(pack_format, 0) + count_ -= 1 + i = 0 + min_value = -(2 ** (size * 8 - 1)) + max_value = 2 ** (size * 8 - 1) - 1 + while True: + if count_ <= 0: + return + yield from struct.pack(pack_format, min_value + i) + count_ -= 1 + if count_ == 0: + return + yield from struct.pack(pack_format, max_value - i) + count_ -= 1 + i += 1 + i %= 2 ** (size * 8) + + return bytes(generator(count)) + + +def _random_value_be_bytes( + size: t.Literal[1, 2, 4, 8], count: int = 1 +) -> bytes: + def generator(count_: int) -> t.Iterable[int]: + pack_format = { + 1: ">B", + 2: ">H", + 4: ">I", + 8: ">Q", + }[size] + while count_ > 0: + yield from struct.pack( + pack_format, random.randint(0, 2 ** (size * 8) - 1) + ) + count_ -= 1 + + return bytes(generator(count)) + + +def _get_type_size(dtype: str) -> t.Literal[1, 2, 4, 8]: + lookup: dict[str, t.Literal[1, 2, 4, 8]] = { + "i8": 1, + "i16": 2, + "i32": 4, + "i64": 8, + "f32": 4, + "f64": 8, + } + return lookup[dtype] + + +def _normalize_float_bytes(dtype: str, data: bytes) -> bytes: + 
if dtype not in {"f32", "f64"}: + raise ValueError(f"Invalid dtype {dtype}") + type_size = _get_type_size(dtype) + pack_format = _dtype_to_pack_format(dtype) + chunks = (data[i : i + type_size] for i in range(0, len(data), type_size)) + return bytes( + b + for chunk in chunks + for b in struct.pack(pack_format, struct.unpack(pack_format, chunk)[0]) + ) + + +def _dtype_to_pack_format(dtype: str) -> str: + return { + "i8": ">b", + "i16": ">h", + "i32": ">i", + "i64": ">q", + "f32": ">f", + "f64": ">d", + }[dtype] + + +def _mock_mask_extensions(mocker, used_ext): + from neo4j.vector import ( + _swap_endian_unchecked_np, + _swap_endian_unchecked_py, + _swap_endian_unchecked_rust, + _VecF32, + _VecF64, + _VecI8, + _VecI16, + _VecI32, + _VecI64, + ) + + vec_types = (_VecF64, _VecF32, _VecI64, _VecI32, _VecI16, _VecI8) + match used_ext: + case "numpy": + if _swap_endian_unchecked_np is None: + pytest.skip("numpy not installed") + mocker.patch( + "neo4j.vector._swap_endian_unchecked", + new=_swap_endian_unchecked_np, + ) + for vec_type in vec_types: + mocker.patch( + f"neo4j.vector.{vec_type.__name__}.from_native", + new=vec_type._from_native_np, + ) + mocker.patch( + f"neo4j.vector.{vec_type.__name__}.to_native", + new=vec_type._to_native_np, + ) + case "rust": + if _swap_endian_unchecked_rust is None: + pytest.skip("rust extensions are not installed") + mocker.patch( + "neo4j.vector._swap_endian_unchecked", + new=_swap_endian_unchecked_rust, + ) + for vec_type in vec_types: + mocker.patch( + f"neo4j.vector.{vec_type.__name__}.from_native", + new=vec_type._from_native_rust, + ) + mocker.patch( + f"neo4j.vector.{vec_type.__name__}.to_native", + new=vec_type._to_native_rust, + ) + case "python": + mocker.patch( + "neo4j.vector._swap_endian_unchecked", + new=_swap_endian_unchecked_py, + ) + for vec_type in vec_types: + mocker.patch( + f"neo4j.vector.{vec_type.__name__}.from_native", + new=vec_type._from_native_py, + ) + mocker.patch( + 
f"neo4j.vector.{vec_type.__name__}.to_native", + new=vec_type._to_native_py, + ) + case _: + raise ValueError(f"Invalid ext value {used_ext}") + + +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +def test_swap_endian(mocker, ext): + data = bytes(range(1, 17)) + _mock_mask_extensions(mocker, ext) + res = _swap_endian(2, data) + assert isinstance(res, bytes) + assert res == bytes( + (2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15) + ) + res = _swap_endian(4, data) + assert isinstance(res, bytes) + assert res == bytes( + (4, 3, 2, 1, 8, 7, 6, 5, 12, 11, 10, 9, 16, 15, 14, 13) + ) + res = _swap_endian(8, data) + assert isinstance(res, bytes) + assert res == bytes( + (8, 7, 6, 5, 4, 3, 2, 1, 16, 15, 14, 13, 12, 11, 10, 9) + ) + + +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +@pytest.mark.parametrize("type_size", (-1, 0, 3, 5, 7, 9, 16, 32)) +def test_swap_endian_unhandled_size(mocker, ext, type_size): + data = bytes(i % 256 for i in range(1, abs(type_size) * 4)) + _mock_mask_extensions(mocker, ext) + + with pytest.raises(ValueError, match=str(type_size)): + _swap_endian(type_size, data) + + +@pytest.mark.parametrize( + ("dtype", "data"), + ( + ("i8", b""), + ("i8", b"\x01"), + ("i8", b"\x01\x02\x03\x04"), + ("i8", _max_value_be_bytes(1, 4096)), + ("i16", b""), + ("i16", b"\x00\x01"), + ("i16", b"\x00\x01\x00\x02"), + ("i16", _max_value_be_bytes(2, 4096)), + ("i32", b""), + ("i32", b"\x00\x00\x00\x01"), + ("i32", b"\x00\x00\x00\x01\x00\x00\x00\x02"), + ("i32", _max_value_be_bytes(4, 4096)), + ("i64", b""), + ("i64", b"\x00\x00\x00\x00\x00\x00\x00\x01"), + ( + "i64", + ( + b"\x00\x00\x00\x00\x00\x00\x00\x01" + b"\x00\x00\x00\x00\x00\x00\x00\x02" + ), + ), + ("i64", _max_value_be_bytes(8, 4096)), + ("f32", b""), + ("f32", _random_value_be_bytes(4, 4096)), + ("f64", b""), + ("f64", _random_value_be_bytes(8, 4096)), + ), +) +@pytest.mark.parametrize("input_endian", (None, "big", "little")) +@pytest.mark.parametrize("as_bytearray", 
(False, True)) +def test_raw_data( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + data: bytes, + input_endian: t.Literal["big", "little"] | None, + as_bytearray: bool, +) -> None: + swapped_data = _swap_endian(_get_type_size(dtype), data) + if input_endian is None: + input_data = bytearray(data) if as_bytearray else data + v = Vector(input_data, dtype) + elif input_endian == "big": + input_data = bytearray(data) if as_bytearray else data + v = Vector(input_data, dtype, byteorder=input_endian) + elif input_endian == "little": + input_data = bytearray(swapped_data) if as_bytearray else swapped_data + v = Vector(input_data, dtype, byteorder=input_endian) + else: + raise ValueError(f"Invalid input_endian {input_endian}") + assert v.dtype == dtype + assert v.raw() == data + assert v.raw(byteorder="big") == data + assert v.raw(byteorder="little") == swapped_data + + +def nan_equals(a: list[object], b: list[object]) -> bool: + if len(a) != len(b): + return False + for i in range(len(a)): + ai = a[i] + bi = b[i] + if ai != bi and not ( + isinstance(ai, float) + and isinstance(bi, float) + and math.isnan(ai) + and math.isnan(bi) + ): + return False + return True + + +@pytest.mark.parametrize("dtype", ("i8", "i16", "i32", "i64", "f32", "f64")) +@pytest.mark.parametrize(("repeat", "size"), ((10_000, 1), (1, 10_000))) +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +def test_from_native_random( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + repeat: int, + size: int, + ext: str, + mocker: t.Any, +) -> None: + _mock_mask_extensions(mocker, ext) + type_size = _get_type_size(dtype) + for _ in range(repeat): + data = _random_value_be_bytes(type_size, size) + values = [ + struct.unpack( + _dtype_to_pack_format(dtype), data[i : i + type_size] + )[0] + for i in range(0, len(data), type_size) + ] + v = Vector.from_native(values, dtype) + expected_raw = data + if dtype.startswith("f"): + expected_raw = _normalize_float_bytes(dtype, data) + 
assert v.raw() == expected_raw + + +SPECIAL_VALUES = ( + # (dtype, value, packed_bytes_be) + # i8 + ("i8", -128, b"\x80"), + ("i8", 0, b"\x00"), + ("i8", 127, b"\x7f"), + # i16 + ("i16", -32768, b"\x80\x00"), + ("i16", 0, b"\x00\x00"), + ("i16", 32767, b"\x7f\xff"), + # i32 + ("i32", -2147483648, b"\x80\x00\x00\x00"), + ("i32", 0, b"\x00\x00\x00\x00"), + ("i32", 2147483647, b"\x7f\xff\xff\xff"), + # i64 + ("i64", -9223372036854775808, b"\x80\x00\x00\x00\x00\x00\x00\x00"), + ("i64", 0, b"\x00\x00\x00\x00\x00\x00\x00\x00"), + ("i64", 9223372036854775807, b"\x7f\xff\xff\xff\xff\xff\xff\xff"), + # f32 + # NaN + ("f32", float("nan"), b"\x7f\xc0\x00\x00"), + ("f32", float("-nan"), b"\xff\xc0\x00\x00"), + ( + "f32", + struct.unpack(">f", b"\x7f\xc0\x00\x11")[0], + b"\x7f\xc0\x00\x11", + ), + ( + "f32", + struct.unpack(">f", b"\x7f\x80\x00\x01")[0], + # Python < 3.14 does not properly preserve all NaN payload + # when calling struct.pack + _normalize_float_bytes("f32", b"\x7f\x80\x00\x01"), + ), + # ±inf + ("f32", float("inf"), b"\x7f\x80\x00\x00"), + ("f32", float("-inf"), b"\xff\x80\x00\x00"), + # ±0.0 + ("f32", 0.0, b"\x00\x00\x00\x00"), + ("f32", -0.0, b"\x80\x00\x00\x00"), + # smallest normal + ( + "f32", + struct.unpack(">f", b"\x00\x80\x00\x00")[0], + b"\x00\x80\x00\x00", + ), + ( + "f32", + struct.unpack(">f", b"\x80\x80\x00\x00")[0], + b"\x80\x80\x00\x00", + ), + # subnormal + ( + "f32", + struct.unpack(">f", b"\x00\x00\x00\x01")[0], + b"\x00\x00\x00\x01", + ), + ( + "f32", + struct.unpack(">f", b"\x80\x00\x00\x01")[0], + b"\x80\x00\x00\x01", + ), + # largest normal + ( + "f32", + struct.unpack(">f", b"\x7f\x7f\xff\xff")[0], + b"\x7f\x7f\xff\xff", + ), + ( + "f32", + struct.unpack(">f", b"\xff\x7f\xff\xff")[0], + b"\xff\x7f\xff\xff", + ), + # f64 + # NaN + ("f64", float("nan"), b"\x7f\xf8\x00\x00\x00\x00\x00\x00"), + ("f64", float("-nan"), b"\xff\xf8\x00\x00\x00\x00\x00\x00"), + ( + "f64", + struct.unpack(">d", b"\x7f\xf8\x00\x00\x00\x00\x00\x11")[0], +
b"\x7f\xf8\x00\x00\x00\x00\x00\x11", + ), + ( + "f64", + struct.unpack(">d", b"\x7f\xf0\x00\x01\x00\x00\x00\x01")[0], + b"\x7f\xf0\x00\x01\x00\x00\x00\x01", + ), + # ±inf + ("f64", float("inf"), b"\x7f\xf0\x00\x00\x00\x00\x00\x00"), + ("f64", float("-inf"), b"\xff\xf0\x00\x00\x00\x00\x00\x00"), + # ±0.0 + ("f64", 0.0, b"\x00\x00\x00\x00\x00\x00\x00\x00"), + ("f64", -0.0, b"\x80\x00\x00\x00\x00\x00\x00\x00"), + # smallest normal + ( + "f64", + struct.unpack(">d", b"\x00\x10\x00\x00\x00\x00\x00\x00")[0], + b"\x00\x10\x00\x00\x00\x00\x00\x00", + ), + ( + "f64", + struct.unpack(">d", b"\x80\x10\x00\x00\x00\x00\x00\x00")[0], + b"\x80\x10\x00\x00\x00\x00\x00\x00", + ), + # subnormal + ( + "f64", + struct.unpack(">d", b"\x00\x00\x00\x00\x00\x00\x00\x01")[0], + b"\x00\x00\x00\x00\x00\x00\x00\x01", + ), + ( + "f64", + struct.unpack(">d", b"\x80\x00\x00\x00\x00\x00\x00\x01")[0], + b"\x80\x00\x00\x00\x00\x00\x00\x01", + ), + # largest normal + ( + "f64", + struct.unpack(">d", b"\x7f\xef\xff\xff\xff\xff\xff\xff")[0], + b"\x7f\xef\xff\xff\xff\xff\xff\xff", + ), + ( + "f64", + struct.unpack(">d", b"\xff\xef\xff\xff\xff\xff\xff\xff")[0], + b"\xff\xef\xff\xff\xff\xff\xff\xff", + ), +) + + +@pytest.mark.parametrize(("dtype", "value", "data_be"), SPECIAL_VALUES) +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +def test_from_native_special_values( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + value: object, + data_be: bytes, + ext: str, + mocker: t.Any, +) -> None: + _mock_mask_extensions(mocker, ext) + if dtype in {"f32", "f64"}: + assert isinstance(value, float) + dtype_f = t.cast(t.Literal["f32", "f64"], dtype) + v = Vector.from_native([value], dtype_f) + elif dtype in {"i8", "i16", "i32", "i64"}: + assert isinstance(value, int) + dtype_i = t.cast(t.Literal["i8", "i16", "i32", "i64"], dtype) + v = Vector.from_native([value], dtype_i) + else: + raise ValueError(f"Invalid dtype {dtype}") + assert v.raw() == data_be + + +@pytest.mark.parametrize( + 
("dtype", "value"), + ( + ("i8", "1"), + ("i8", None), + ("i8", 1.0), + ("i16", "1"), + ("i16", None), + ("i16", 1.0), + ("i32", "1"), + ("i32", None), + ("i32", 1.0), + ("i64", "1"), + ("i64", None), + ("i64", 1.0), + ("f32", "1.0"), + ("f32", None), + ("f32", 1), + ("f64", "1.0"), + ("f64", None), + ("f64", 1), + ), +) +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +def test_from_native_wrong_type( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + value: object, + ext: str, + mocker: t.Any, +) -> None: + _mock_mask_extensions(mocker, ext) + with pytest.raises(TypeError) as exc: + Vector.from_native([value], dtype) # type: ignore + + assert dtype in str(exc.value) + assert str(type(value).__name__) in str(exc.value) + + +@pytest.mark.parametrize( + ("dtype", "value"), + ( + ("i8", -129), + ("i8", 128), + ("i16", -32769), + ("i16", 32768), + ("i32", -2147483649), + ("i32", 2147483648), + ("i64", -9223372036854775809), + ("i64", 9223372036854775808), + ), +) +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +def test_from_native_overflow( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + value: object, + ext: str, + mocker: t.Any, +) -> None: + _mock_mask_extensions(mocker, ext) + with pytest.raises(OverflowError) as exc: + Vector.from_native([value], dtype) # type: ignore + + assert dtype in str(exc.value) + + +def _vector_from_data( + data: bytes, + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little"] | None, +) -> Vector: + match endian: + case None: + return Vector(data, dtype) + case "big": + return Vector(data, dtype, byteorder=endian) + case "little": + type_size = _get_type_size(dtype) + data_le = _swap_endian(type_size, data) + return Vector(data_le, dtype, byteorder=endian) + case _: + raise ValueError(f"Invalid endian {endian}") + + +@pytest.mark.parametrize("dtype", ("i8", "i16", "i32", "i64", "f32", "f64")) +@pytest.mark.parametrize("endian", ("big", 
"little", None)) +@pytest.mark.parametrize(("repeat", "size"), ((10_000, 1), (1, 10_000))) +def test_to_native_random( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little"] | None, + repeat: int, + size: int, +) -> None: + type_size = _get_type_size(dtype) + for _ in range(repeat): + data = _random_value_be_bytes(type_size, size) + expected = [ + struct.unpack( + _dtype_to_pack_format(dtype), data[i : i + type_size] + )[0] + for i in range(0, len(data), type_size) + ] + v = _vector_from_data(data, dtype, endian) + assert nan_equals(v.to_native(), expected) + + +@pytest.mark.parametrize(("dtype", "value", "data_be"), SPECIAL_VALUES) +def test_to_native_special_values( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + value: object, + data_be: bytes, +) -> None: + type_size = _get_type_size(dtype) + pack_format = _dtype_to_pack_format(dtype) + expected = [ + struct.unpack(pack_format, data_be[i : i + type_size])[0] + for i in range(0, len(data_be), type_size) + ] + v = Vector(data_be, dtype) + assert nan_equals(v.to_native(), expected) + + +def _get_numpy_dtype(dtype: str) -> str: + return { + "i8": "i1", + "i16": "i2", + "i32": "i4", + "i64": "i8", + "f32": "f4", + "f64": "f8", + }[dtype] + + +def _get_numpy_array( + data_be: bytes, dtype: str, endian: t.Literal["big", "little", "native"] +) -> numpy.ndarray: + np_type = _get_numpy_dtype(dtype) + type_size = _get_type_size(dtype) + data_in = data_be + match endian: + case "big": + data_in = data_be + np_type = f">{np_type}" + case "little": + data_in = _swap_endian(type_size, data_be) + np_type = f"<{np_type}" + case "native": + if sys.byteorder == "little": + data_in = _swap_endian(type_size, data_be) + np_type = f"={np_type}" + return np.frombuffer(data_in, dtype=np_type) + + +@pytest.mark.skipif(np is None, reason="numpy not installed") +@pytest.mark.parametrize("dtype", ("i8", "i16", "i32", "i64", "f32", "f64")) +@pytest.mark.parametrize("endian", ("big", 
"little", "native")) +@pytest.mark.parametrize(("repeat", "size"), ((10_000, 1), (1, 10_000))) +def test_from_numpy_random( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little", "native"], + repeat: int, + size: int, +) -> None: + type_size = _get_type_size(dtype) + for _ in range(repeat): + data_be = _random_value_be_bytes(type_size, size) + array = _get_numpy_array(data_be, dtype, endian) + v = Vector.from_numpy(array) + assert v.dtype == dtype + assert v.raw() == data_be + assert nan_equals(array.tolist(), v.to_native()) + + +@pytest.mark.skipif(np is None, reason="numpy not installed") +@pytest.mark.parametrize(("dtype", "value", "data_be"), SPECIAL_VALUES) +@pytest.mark.parametrize("endian", ("big", "little", "native")) +def test_from_numpy_special_values( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little", "native"], + value: object, + data_be: bytes, +) -> None: + array = _get_numpy_array(data_be, dtype, endian) + v = Vector.from_numpy(array) + assert v.dtype == dtype + assert v.raw() == data_be + assert nan_equals(array.tolist(), v.to_native()) + + +@pytest.mark.skipif(np is None, reason="numpy not installed") +@pytest.mark.parametrize("dtype", ("i8", "i16", "i32", "i64", "f32", "f64")) +@pytest.mark.parametrize("endian", ("big", "little", None)) +@pytest.mark.parametrize(("repeat", "size"), ((10_000, 1), (1, 10_000))) +def test_to_numpy_random( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little"] | None, + repeat: int, + size: int, +) -> None: + type_size = _get_type_size(dtype) + np_type = _get_numpy_dtype(dtype) + for _ in range(repeat): + data = _random_value_be_bytes(type_size, size) + v = _vector_from_data(data, dtype, endian) + array = v.to_numpy() + assert array.dtype == np.dtype(f">{np_type}") + assert array.size == len(data) // type_size + assert array.tobytes() == data + assert nan_equals(array.tolist(), 
v.to_native()) + + +@pytest.mark.skipif(np is None, reason="numpy not installed") +@pytest.mark.parametrize(("dtype", "value", "data_be"), SPECIAL_VALUES) +@pytest.mark.parametrize("endian", ("big", "little", None)) +def test_to_numpy_special_values( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little"] | None, + value: object, + data_be: bytes, +) -> None: + np_type = _get_numpy_dtype(dtype) + v = _vector_from_data(data_be, dtype, endian) + array = v.to_numpy() + assert array.dtype == np.dtype(f">{np_type}") + assert array.size == 1 + assert array.tobytes() == data_be + assert nan_equals(array.tolist(), v.to_native()) + + +def _get_pyarrow_dtype(dtype: str) -> pyarrow.DataType: + return { + "i8": pa.int8(), + "i16": pa.int16(), + "i32": pa.int32(), + "i64": pa.int64(), + "f32": pa.float32(), + "f64": pa.float64(), + }[dtype] + + +def _get_pyarrow_array(data_be: bytes, dtype: str) -> pyarrow.Array: + type_size = _get_type_size(dtype) + length = len(data_be) // type_size + data_in = data_be + if sys.byteorder == "little": + data_in = _swap_endian(type_size, data_be) + pa_type = _get_pyarrow_dtype(dtype) + buffers = [None, pa.py_buffer(data_in)] + return pa.Array.from_buffers(pa_type, length, buffers, 0) + + +@pytest.mark.skipif(pa is None, reason="pyarrow not installed") +@pytest.mark.parametrize("dtype", ("i8", "i16", "i32", "i64", "f32", "f64")) +@pytest.mark.parametrize("endian", ("big", "little", "native")) +@pytest.mark.parametrize(("repeat", "size"), ((10_000, 1), (1, 10_000))) +def test_from_pyarrow_random( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little", "native"], + repeat: int, + size: int, +) -> None: + type_size = _get_type_size(dtype) + for _ in range(repeat): + data_be = _random_value_be_bytes(type_size, size) + array = _get_pyarrow_array(data_be, dtype) + v = Vector.from_pyarrow(array) + assert v.dtype == dtype + assert v.raw() == data_be + assert 
nan_equals(array.to_pylist(), v.to_native()) + + +@pytest.mark.skipif(pa is None, reason="pyarrow not installed") +@pytest.mark.parametrize(("dtype", "value", "data_be"), SPECIAL_VALUES) +def test_from_pyarrow_special_values( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + value: object, + data_be: bytes, +) -> None: + array = _get_pyarrow_array(data_be, dtype) + v = Vector.from_pyarrow(array) + assert v.dtype == dtype + assert v.raw() == data_be + assert nan_equals(array.to_pylist(), v.to_native()) + + +@pytest.mark.skipif(pa is None, reason="pyarrow not installed") +@pytest.mark.parametrize("dtype", ("i8", "i16", "i32", "i64", "f32", "f64")) +@pytest.mark.parametrize("endian", ("big", "little", None)) +@pytest.mark.parametrize(("repeat", "size"), ((10_000, 1), (1, 10_000))) +def test_to_pyarrow_random( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little"] | None, + repeat: int, + size: int, +) -> None: + type_size = _get_type_size(dtype) + pa_type = _get_pyarrow_dtype(dtype) + for _ in range(repeat): + data_be = _random_value_be_bytes(type_size, size) + data_ne = data_be + if sys.byteorder == "little": + data_ne = _swap_endian(type_size, data_be) + v = _vector_from_data(data_be, dtype, endian) + array = v.to_pyarrow() + assert array.type == pa_type + assert pa.compute.count(array, mode="only_null").as_py() == 0 + buffers = array.buffers() + assert len(buffers) == 2 + assert buffers[0] is None + assert buffers[1].to_pybytes() == data_ne + assert nan_equals(array.tolist(), v.to_native()) + + +@pytest.mark.skipif(pa is None, reason="pyarrow not installed") +@pytest.mark.parametrize(("dtype", "value", "data_be"), SPECIAL_VALUES) +@pytest.mark.parametrize("endian", ("big", "little", None)) +def test_to_pyarrow_special_values( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little"] | None, + value: object, + data_be: bytes, +) -> None: + type_size = 
_get_type_size(dtype) + data_ne = data_be + if sys.byteorder == "little": + data_ne = _swap_endian(type_size, data_be) + pa_type = _get_pyarrow_dtype(dtype) + v = _vector_from_data(data_be, dtype, endian) + array = v.to_pyarrow() + assert array.type == pa_type + assert pa.compute.count(array, mode="only_null").as_py() == 0 + buffers = array.buffers() + assert len(buffers) == 2 + assert buffers[0] is None + assert buffers[1].to_pybytes() == data_ne + assert nan_equals(array.tolist(), v.to_native()) diff --git a/tests/vector/test_injection.py b/tests/vector/test_injection.py new file mode 100644 index 0000000..edcff32 --- /dev/null +++ b/tests/vector/test_injection.py @@ -0,0 +1,112 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest + +import neo4j.vector + + +def test_endian_swap_was_imported(): + swap = neo4j.vector._swap_endian_unchecked_rust + assert swap is not None + assert swap is neo4j._rust.vector.swap_endian + assert neo4j.vector._swap_endian_unchecked is swap + + +def test_endian_swap_was_injected(mocker): + mock = mocker.patch("neo4j.vector._swap_endian_unchecked") + neo4j.vector._swap_endian(2, b"\x01\x02\x03\x04") + mock.assert_called_once_with(2, b"\x01\x02\x03\x04") + + +@pytest.mark.parametrize( + "vec_cls", + ( + neo4j.vector._VecF64, + neo4j.vector._VecF32, + neo4j.vector._VecI64, + neo4j.vector._VecI32, + neo4j.vector._VecI16, + neo4j.vector._VecI8, + ), +) +def test_vec_from_native_was_imported(vec_cls): + vec_rust = neo4j.vector._vec_rust + assert vec_rust is not None + assert vec_cls.from_native == vec_cls._from_native_rust + + +@pytest.mark.parametrize( + ("dtype", "value", "method"), + ( + ("f64", 1.0, "vec_f64_from_native"), + ("f32", 1.0, "vec_f32_from_native"), + ("i64", 1, "vec_i64_from_native"), + ("i32", 1, "vec_i32_from_native"), + ("i16", 1, "vec_i16_from_native"), + ("i8", 1, "vec_i8_from_native"), + ), +) +def test_vec_from_native_was_injected(dtype, value, method, mocker): + mock = mocker.patch("neo4j.vector._vec_rust") + rust_mock = getattr(mock, method) + rust_mock.return_value = b"" + + data = [value] + + neo4j.vector.Vector.from_native(data, dtype) + + getattr(mock, method).assert_called_once_with(data) + + +@pytest.mark.parametrize( + "vec_cls", + ( + neo4j.vector._VecF64, + neo4j.vector._VecF32, + neo4j.vector._VecI64, + neo4j.vector._VecI32, + neo4j.vector._VecI16, + neo4j.vector._VecI8, + ), +) +def test_vec_to_native_was_imported(vec_cls): + vec_rust = neo4j.vector._vec_rust + assert vec_rust is not None + assert vec_cls.to_native == vec_cls._to_native_rust + + +@pytest.mark.parametrize( + ("dtype", "method"), + ( + ("f64", "vec_f64_to_native"), + ("f32", "vec_f32_to_native"), + ("i64", "vec_i64_to_native"), + ("i32", 
"vec_i32_to_native"), + ("i16", "vec_i16_to_native"), + ("i8", "vec_i8_to_native"), + ), +) +def test_vec_to_native_was_injected(dtype, method, mocker): + mock = mocker.patch("neo4j.vector._vec_rust") + + data = bytes(range(8)) + vec = neo4j.vector.Vector.from_bytes(data, dtype) + getattr(mock, method).assert_not_called() + + vec.to_native() + + getattr(mock, method).assert_called_once_with(data)