4
4
5
5
from unittest .mock import MagicMock , patch
6
6
7
+ import numpy as np
7
8
import pytest
8
9
import torch
9
10
@@ -350,6 +351,24 @@ def test_score_threshold(self):
350
351
out = ranker .run (query = "test" , documents = documents )
351
352
assert len (out ["documents" ]) == 1
352
353
354
+ def test_scores_cast_to_python_float_when_numpy_scalars_returned (self ):
355
+ mock_cross_encoder = MagicMock ()
356
+ ranker = SentenceTransformersSimilarityRanker (model = "model" )
357
+ ranker ._cross_encoder = mock_cross_encoder
358
+
359
+ # Simulate backend returning numpy scalar types
360
+ mock_cross_encoder .rank .return_value = [
361
+ {"score" : np .float32 (0.123 ), "corpus_id" : 0 },
362
+ {"score" : np .float64 (0.456 ), "corpus_id" : 1 },
363
+ ]
364
+
365
+ documents = [Document (content = "doc 0" ), Document (content = "doc 1" )]
366
+ out = ranker .run (query = "test" , documents = documents )
367
+
368
+ assert len (out ["documents" ]) == 2
369
+ for d in out ["documents" ]:
370
+ assert isinstance (d .score , float )
371
+
353
372
@pytest .mark .integration
354
373
@pytest .mark .slow
355
374
def test_run (self ):
@@ -373,6 +392,9 @@ def test_run(self):
373
392
assert docs_after [1 ].score == pytest .approx (sorted_scores [1 ], abs = 1e-6 )
374
393
assert docs_after [2 ].score == pytest .approx (sorted_scores [2 ], abs = 1e-6 )
375
394
395
+ for doc in docs_after :
396
+ assert isinstance (doc .score , float )
397
+
376
398
@pytest .mark .integration
377
399
@pytest .mark .slow
378
400
def test_run_top_k (self ):
@@ -393,6 +415,9 @@ def test_run_top_k(self):
393
415
sorted_scores = sorted ([doc .score for doc in docs_after ], reverse = True )
394
416
assert [doc .score for doc in docs_after ] == sorted_scores
395
417
418
+ for doc in docs_after :
419
+ assert isinstance (doc .score , float )
420
+
396
421
@pytest .mark .integration
397
422
@pytest .mark .slow
398
423
def test_run_single_document (self ):
@@ -403,3 +428,4 @@ def test_run_single_document(self):
403
428
docs_after = output ["documents" ]
404
429
405
430
assert len (docs_after ) == 1
431
+ assert isinstance (docs_after [0 ].score , float )
0 commit comments