Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,5 @@ myimglist.txt
newcat.jpg

.pexing
*.pex
*.pex*.index
*.json
15 changes: 14 additions & 1 deletion autofaiss/external/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@
logger = logging.getLogger("autofaiss")


class NumpyEncoder(json.JSONEncoder):
"""Special json encoder for numpy types"""

def default(self, o): # pylint: disable=E0202
if isinstance(o, np.integer):
return int(o)
if isinstance(o, np.floating):
return float(o)
if isinstance(o, np.ndarray):
return o.tolist()
return super().default(o)


def check_if_index_needs_training(index_key: str) -> bool:
"""
Function that checks if the index needs to be trained
Expand Down Expand Up @@ -569,6 +582,6 @@ def optimize_and_measure_index(
with fsspec.open(index_path, "wb").open() as f:
faiss.write_index(index, faiss.PyCallbackIOWriter(f.write))
with fsspec.open(index_infos_path, "w").open() as f:
json.dump(metric_infos, f)
json.dump(metric_infos, f, cls=NumpyEncoder)

return metric_infos
16 changes: 8 additions & 8 deletions requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
black==23.12.1
mypy==1.8.0
pylint==3.0.3
pytest-cov==4.1.0
pytest-xdist==3.5.0
pytest==8.0.1
pyspark==3.2.2; python_version < "3.11"
pyspark<3.6.0; python_version >= "3.11"
black>=23.12.1
mypy>=1.8.0
pylint>=3.0.3
pytest-cov>=4.1.0
pytest-xdist>=3.5.0
pytest>=8.0.1
pyspark>=3.2.2; python_version < "3.11"
pyspark<3.6.0; python_version >= "3.11"
16 changes: 7 additions & 9 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
dataclasses>=0.6,<1.0.0; python_version < "3.7"
fire>=0.4.0,<0.6.0
numpy>=1.19.5,<2
pandas>=1.1.5,<3
pyarrow>=6.0.1,<16
tqdm>=4.62.3,<5
faiss-cpu<1.7.3; python_version < "3.7"
faiss-cpu>=1,<2; python_version >= "3.7"
fire>=0.4.0
numpy>=1.19.5
pandas>=1.1.5
pyarrow>=6.0.1
tqdm>=4.62.3
faiss-cpu>=1
fsspec>=2022.1.0
embedding_reader>=1.5.1,<2
embedding_reader>=1.5.1
4 changes: 0 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ def _read_reqs(relpath):
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Intended Audience :: Developers",
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/test_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,12 @@ def test_index_correctness_in_distributed_mode(tmpdir):
)
temporary_indices_folder = os.path.join(tmpdir.strpath, "distributed_autofaiss_indices")
ids_path = os.path.join(tmpdir.strpath, "ids")
index_path = os.path.join(tmpdir.strpath, "distributed_knn.index")
index_infos_path = os.path.join(tmpdir.strpath, "distributed_knn_infos.json")
index, _ = build_index(
embeddings=tmp_dir,
index_path=index_path,
index_infos_path=index_infos_path,
distributed="pyspark",
file_format="parquet",
temporary_indices_folder=temporary_indices_folder,
Expand Down Expand Up @@ -322,8 +326,12 @@ def test_index_correctness_in_distributed_mode(tmpdir):
tmp_dir, _, _, expected_array, _ = build_test_collection_numpy(
tmpdir, min_size=min_size, max_size=max_size, dim=dim, nb_files=nb_files
)
index_path = os.path.join(tmpdir.strpath, "distributed_knn.index")
index_infos_path = os.path.join(tmpdir.strpath, "distributed_knn_infos.json")
index, _ = build_index(
embeddings=tmp_dir,
index_path=index_path,
index_infos_path=index_infos_path,
distributed="pyspark",
file_format="npy",
temporary_indices_folder=temporary_indices_folder,
Expand Down Expand Up @@ -380,8 +388,12 @@ def test_index_correctness_in_distributed_mode_with_multiple_indices(tmpdir):
)
temporary_indices_folder = os.path.join(tmpdir.strpath, "distributed_autofaiss_indices")
ids_path = os.path.join(tmpdir.strpath, "ids")
index_path = os.path.join(tmpdir.strpath, "distributed_knn.index")
index_infos_path = os.path.join(tmpdir.strpath, "distributed_knn_infos.json")
_, index_path2_metric_infos = build_index(
embeddings=tmp_dir,
index_path=index_path,
index_infos_path=index_infos_path,
distributed="pyspark",
file_format="parquet",
temporary_indices_folder=temporary_indices_folder,
Expand Down Expand Up @@ -421,8 +433,12 @@ def test_index_correctness_in_distributed_mode_with_multiple_indices(tmpdir):
)

temporary_indices_folder = os.path.join(tmpdir.strpath, "distributed_autofaiss_indices")
index_path = os.path.join(tmpdir.strpath, "distributed_knn.index")
index_infos_path = os.path.join(tmpdir.strpath, "distributed_knn_infos.json")
_, index_path2_metric_infos = build_index(
embeddings=tmp_dir,
index_path=index_path,
index_infos_path=index_infos_path,
distributed="pyspark",
file_format="npy",
temporary_indices_folder=temporary_indices_folder,
Expand Down
Loading