Skip to content

Commit 5702501

Browse files
Move numpy dense vector to a subclass
1 parent d4187c2 commit 5702501

File tree

6 files changed

+64
-44
lines changed

6 files changed

+64
-44
lines changed

elasticsearch/dsl/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,7 @@
7373
MatchOnlyText,
7474
Murmur3,
7575
Nested,
76+
NumpyDenseVector,
7677
Object,
7778
Passthrough,
7879
Percolator,
@@ -189,6 +190,7 @@
189190
"Murmur3",
190191
"Nested",
191192
"NestedFacet",
193+
"NumpyDenseVector",
192194
"Object",
193195
"Passthrough",
194196
"Percolator",

elasticsearch/dsl/field.py

Lines changed: 17 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1555,8 +1555,6 @@ class DenseVector(Field):
15551555
:arg dynamic:
15561556
:arg fields:
15571557
:arg synthetic_source_keep:
1558-
:arg use_numpy: if set to ``True``, deserialize as a numpy array.
1559-
:arg dtype: The numpy data type to use as a string, when ``use_numpy`` is ``True``. The default is "float32".
15601558
"""
15611559

15621560
name = "dense_vector"
@@ -1589,8 +1587,6 @@ def __init__(
15891587
synthetic_source_keep: Union[
15901588
Literal["none", "arrays", "all"], "DefaultType"
15911589
] = DEFAULT,
1592-
use_numpy: bool = False,
1593-
dtype: str = "float32",
15941590
**kwargs: Any,
15951591
):
15961592
if dims is not DEFAULT:
@@ -1618,20 +1614,31 @@ def __init__(
16181614
self._element_type = kwargs.get("element_type", "float")
16191615
if self._element_type in ["float", "byte"]:
16201616
kwargs["multi"] = True
1621-
self._use_numpy = use_numpy
1622-
self._dtype = dtype
16231617
super().__init__(*args, **kwargs)
16241618

1619+
1620+
class NumpyDenseVector(DenseVector):
1621+
"""A dense vector field that uses numpy arrays.
1622+
1623+
Accepts the same arguments as class ``DenseVector`` plus:
1624+
1625+
:arg dtype: The numpy data type to use for the array. If not given, numpy will select the type based on the data.
1626+
"""
1627+
1628+
def __init__(self, *args: Any, dtype: Optional[type] = None, **kwargs: Any):
1629+
super().__init__(*args, **kwargs)
1630+
self._dtype = dtype
1631+
16251632
def deserialize(self, data: Any) -> Any:
1626-
if self._use_numpy and isinstance(data, list):
1633+
if isinstance(data, list):
16271634
import numpy as np
16281635

1629-
return np.array(data, dtype=getattr(np, self._dtype))
1636+
return np.array(data, dtype=self._dtype)
16301637
return super().deserialize(data)
16311638

16321639
def clean(self, data: Any) -> Any:
1633-
# this method does the same as the one in the parent class, but it
1634-
# avoids comparisons that break when data is a numpy array
1640+
# this method does the same as the one in the parent classes, but it
1641+
# avoids comparisons that do not work for numpy arrays
16351642
if data is not None:
16361643
data = self.deserialize(data)
16371644
if (data is None or len(data) == 0) and self._required:

examples/quotes/backend/quotes.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,8 @@ class Quote(AsyncBaseESModel):
2424
embedding: Annotated[
2525
np.ndarray,
2626
PlainSerializer(lambda v: v.tolist()),
27-
dsl.DenseVector(use_numpy=True)
28-
] = Field(init=False, default_factory=lambda: np.array([]))
27+
dsl.NumpyDenseVector(dtype=np.float32)
28+
] = Field(init=False, default_factory=lambda: np.array([], dtype=np.float32))
2929

3030
class Config:
3131
arbitrary_types_allowed = True

test_elasticsearch/test_dsl/test_integration/_async/test_document.py

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@
4848
Mapping,
4949
MetaField,
5050
Nested,
51+
NumpyDenseVector,
5152
Object,
5253
Q,
5354
RankFeatures,
@@ -866,29 +867,33 @@ class Doc(AsyncDocument):
866867
float_vector: List[float] = mapped_field(DenseVector())
867868
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
868869
bit_vector: List[int] = mapped_field(DenseVector(element_type="bit"))
869-
numpy_float_vector: np.ndarray = mapped_field(DenseVector(use_numpy=True))
870+
numpy_float_vector: np.ndarray = mapped_field(NumpyDenseVector())
870871

871872
class Index:
872873
name = "vectors"
873874

874875
await Doc._index.delete(ignore_unavailable=True)
875876
await Doc.init()
876877

878+
test_float_vector = [1.0, 1.2, 2.3]
879+
test_byte_vector = [12, 23, 34, 45]
880+
test_bit_vector = [18, -43, -112]
881+
877882
doc = Doc(
878-
float_vector=[1.0, 1.2, 2.3],
879-
byte_vector=[12, 23, 34, 45],
880-
bit_vector=[18, -43, -112],
881-
numpy_float_vector=np.array([3.1, 2.25, 1.0]),
883+
float_vector=test_float_vector,
884+
byte_vector=test_byte_vector,
885+
bit_vector=test_bit_vector,
886+
numpy_float_vector=np.array(test_float_vector),
882887
)
883888
await doc.save(refresh=True)
884889

885890
docs = await Doc.search().execute()
886891
assert len(docs) == 1
887-
assert [round(v, 1) for v in docs[0].float_vector] == doc.float_vector
888-
assert docs[0].byte_vector == doc.byte_vector
889-
assert docs[0].bit_vector == doc.bit_vector
892+
assert [round(v, 1) for v in docs[0].float_vector] == test_float_vector
893+
assert docs[0].byte_vector == test_byte_vector
894+
assert docs[0].bit_vector == test_bit_vector
890895
assert type(docs[0].numpy_float_vector) is np.ndarray
891-
assert np.array_equal(docs[0].numpy_float_vector, doc.numpy_float_vector)
896+
assert [round(v, 1) for v in docs[0].numpy_float_vector] == test_float_vector
892897

893898

894899
@pytest.mark.anyio

test_elasticsearch/test_dsl/test_integration/_sync/test_document.py

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,7 @@
4747
Mapping,
4848
MetaField,
4949
Nested,
50+
NumpyDenseVector,
5051
Object,
5152
Q,
5253
RankFeatures,
@@ -854,29 +855,33 @@ class Doc(Document):
854855
float_vector: List[float] = mapped_field(DenseVector())
855856
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
856857
bit_vector: List[int] = mapped_field(DenseVector(element_type="bit"))
857-
numpy_float_vector: np.ndarray = mapped_field(DenseVector(use_numpy=True))
858+
numpy_float_vector: np.ndarray = mapped_field(NumpyDenseVector())
858859

859860
class Index:
860861
name = "vectors"
861862

862863
Doc._index.delete(ignore_unavailable=True)
863864
Doc.init()
864865

866+
test_float_vector = [1.0, 1.2, 2.3]
867+
test_byte_vector = [12, 23, 34, 45]
868+
test_bit_vector = [18, -43, -112]
869+
865870
doc = Doc(
866-
float_vector=[1.0, 1.2, 2.3],
867-
byte_vector=[12, 23, 34, 45],
868-
bit_vector=[18, -43, -112],
869-
numpy_float_vector=np.array([3.1, 2.25, 1.0]),
871+
float_vector=test_float_vector,
872+
byte_vector=test_byte_vector,
873+
bit_vector=test_bit_vector,
874+
numpy_float_vector=np.array(test_float_vector),
870875
)
871876
doc.save(refresh=True)
872877

873878
docs = Doc.search().execute()
874879
assert len(docs) == 1
875-
assert [round(v, 1) for v in docs[0].float_vector] == doc.float_vector
876-
assert docs[0].byte_vector == doc.byte_vector
877-
assert docs[0].bit_vector == doc.bit_vector
880+
assert [round(v, 1) for v in docs[0].float_vector] == test_float_vector
881+
assert docs[0].byte_vector == test_byte_vector
882+
assert docs[0].bit_vector == test_bit_vector
878883
assert type(docs[0].numpy_float_vector) is np.ndarray
879-
assert np.array_equal(docs[0].numpy_float_vector, doc.numpy_float_vector)
884+
assert [round(v, 1) for v in docs[0].numpy_float_vector] == test_float_vector
880885

881886

882887
@pytest.mark.sync

utils/templates/field.py.tpl

Lines changed: 15 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -217,10 +217,6 @@ class {{ k.name }}({{ k.parent }}):
217217
{% endfor %}
218218
{% endfor %}
219219
{% endif %}
220-
{% if k.field == "dense_vector" %}
221-
:arg use_numpy: if set to ``True``, deserialize as a numpy array.
222-
:arg dtype: The numpy data type to use as a string, when ``use_numpy`` is ``True``. The default is "float32".
223-
{% endif %}
224220
"""
225221
name = "{{ k.field }}"
226222
{% if k.coerced %}
@@ -250,10 +246,6 @@ class {{ k.name }}({{ k.parent }}):
250246
{{ arg.name }}: {{ arg.type }} = DEFAULT,
251247
{% endif %}
252248
{% endfor %}
253-
{% if k.field == "dense_vector" %}
254-
use_numpy: bool = False,
255-
dtype: str = "float32",
256-
{% endif %}
257249
**kwargs: Any
258250
):
259251
{% for arg in k.args %}
@@ -424,19 +416,28 @@ class {{ k.name }}({{ k.parent }}):
424416
self._element_type = kwargs.get("element_type", "float")
425417
if self._element_type in ["float", "byte"]:
426418
kwargs["multi"] = True
427-
self._use_numpy = use_numpy
428-
self._dtype = dtype
429419
super().__init__(*args, **kwargs)
430420

421+
class NumpyDenseVector(DenseVector):
422+
"""A dense vector field that uses numpy arrays.
423+
424+
Accepts the same arguments as class ``DenseVector`` plus:
425+
426+
:arg dtype: The numpy data type to use for the array. If not given, numpy will select the type based on the data.
427+
"""
428+
def __init__(self, *args: Any, dtype: Optional[type] = None, **kwargs: Any):
429+
super().__init__(*args, **kwargs)
430+
self._dtype = dtype
431+
431432
def deserialize(self, data: Any) -> Any:
432-
if self._use_numpy and isinstance(data, list):
433+
if isinstance(data, list):
433434
import numpy as np
434-
return np.array(data, dtype=getattr(np, self._dtype))
435+
return np.array(data, dtype=self._dtype)
435436
return super().deserialize(data)
436437

437438
def clean(self, data: Any) -> Any:
438-
# this method does the same as the one in the parent class, but it
439-
# avoids comparisons that break when data is a numpy array
439+
# this method does the same as the one in the parent classes, but it
440+
# avoids comparisons that do not work for numpy arrays
440441
if data is not None:
441442
data = self.deserialize(data)
442443
if (data is None or len(data) == 0) and self._required:

0 commit comments

Comments
 (0)