34
34
35
35
import pytest
36
36
import yaml
37
- from pymilvus import CollectionSchema
38
- from pymilvus import DataType
39
- from pymilvus import FieldSchema
40
- from pymilvus import Function
41
- from pymilvus import FunctionType
42
- from pymilvus import MilvusClient
43
- from pymilvus import RRFRanker
44
- from pymilvus .milvus_client import IndexParams
45
- from testcontainers .core .config import MAX_TRIES as TC_MAX_TRIES
46
- from testcontainers .core .config import testcontainers_config
47
- from testcontainers .core .generic import DbContainer
48
- from testcontainers .milvus import MilvusContainer
49
37
50
38
import apache_beam as beam
51
39
from apache_beam .ml .rag .types import Chunk
54
42
from apache_beam .testing .test_pipeline import TestPipeline
55
43
from apache_beam .testing .util import assert_that
56
44
45
+ # pylint: disable=ungrouped-imports
57
46
try :
47
+ from pymilvus import (
48
+ CollectionSchema ,
49
+ DataType ,
50
+ FieldSchema ,
51
+ Function ,
52
+ FunctionType ,
53
+ MilvusClient ,
54
+ RRFRanker )
55
+ from pymilvus .milvus_client import IndexParams
56
+ from testcontainers .core .config import MAX_TRIES as TC_MAX_TRIES
57
+ from testcontainers .core .config import testcontainers_config
58
+ from testcontainers .core .generic import DbContainer
59
+ from testcontainers .milvus import MilvusContainer
58
60
from apache_beam .transforms .enrichment import Enrichment
59
61
from apache_beam .ml .rag .enrichment .milvus_search import (
60
62
MilvusSearchEnrichmentHandler ,
@@ -467,7 +469,7 @@ def create_user_yaml(service_port: int, max_vector_field_num=5):
467
469
os .remove (path )
468
470
469
471
470
- @pytest .mark .uses_testcontainer
472
+ @pytest .mark .require_docker_in_docker
471
473
@unittest .skipUnless (
472
474
platform .system () == "Linux" ,
473
475
"Test runs only on Linux due to lack of support, as yet, for nested "
@@ -483,22 +485,16 @@ class TestMilvusSearchEnrichment(unittest.TestCase):
483
485
484
486
@classmethod
485
487
def setUpClass (cls ):
486
- try :
487
- cls ._db = MilvusEnrichmentTestHelper .start_db_container (
488
- cls ._version , vector_client_max_retries = 1 , tc_max_retries = 1 )
489
- cls ._connection_params = MilvusConnectionParameters (
490
- uri = cls ._db .uri ,
491
- user = cls ._db .user ,
492
- password = cls ._db .password ,
493
- db_id = cls ._db .id ,
494
- token = cls ._db .token )
495
- cls ._collection_load_params = MilvusCollectionLoadParameters ()
496
- cls ._collection_name = MilvusEnrichmentTestHelper .initialize_db_with_data (
497
- cls ._connection_params )
498
- except Exception as e :
499
- pytest .skip (
500
- f"Skipping all tests in { cls .__name__ } due to DB startup failure: { e } "
501
- )
488
+ cls ._db = MilvusEnrichmentTestHelper .start_db_container (cls ._version )
489
+ cls ._connection_params = MilvusConnectionParameters (
490
+ uri = cls ._db .uri ,
491
+ user = cls ._db .user ,
492
+ password = cls ._db .password ,
493
+ db_id = cls ._db .id ,
494
+ token = cls ._db .token )
495
+ cls ._collection_load_params = MilvusCollectionLoadParameters ()
496
+ cls ._collection_name = MilvusEnrichmentTestHelper .initialize_db_with_data (
497
+ cls ._connection_params )
502
498
503
499
@classmethod
504
500
def tearDownClass (cls ):
@@ -1368,4 +1364,4 @@ def assert_chunks_equivalent(
1368
1364
1369
1365
1370
1366
if __name__ == '__main__' :
1371
- unittest .main ()
1367
+ unittest .main ()
0 commit comments