1818import boto3
1919import pytest
2020from sagemaker import LocalSession , Session
21- from sagemaker .tensorflow import TensorFlow
2221
23- from test .integration import NO_P2_REGIONS , NO_P3_REGIONS
22+ from integration import image_utils
23+ from integration import NO_P2_REGIONS , NO_P3_REGIONS
24+
2425
2526logger = logging .getLogger (__name__ )
2627logging .getLogger ('boto' ).setLevel (logging .INFO )
2930logging .getLogger ('auth.py' ).setLevel (logging .INFO )
3031logging .getLogger ('connectionpool.py' ).setLevel (logging .INFO )
3132
32- SCRIPT_PATH = os .path .dirname (os .path .realpath (__file__ ))
33+ DIR_PATH = os .path .dirname (os .path .realpath (__file__ ))
3334
3435
3536def pytest_addoption (parser ):
36- parser .addoption ('--docker-base-name' , default = 'sagemaker-tensorflow-scriptmode' )
37+ parser .addoption ('--build-image' , '-B' , action = 'store_true' )
38+ parser .addoption ('--push-image' , '-P' , action = 'store_true' )
39+ parser .addoption ('--dockerfile-type' , '-T' , choices = ['dlc.cpu' , 'dlc.gpu' , 'tf' ],
40+ default = 'tf' )
41+ parser .addoption ('--dockerfile' , '-D' , default = None )
42+ parser .addoption ('--docker-base-name' , default = 'sagemaker-tensorflow-training' )
3743 parser .addoption ('--tag' , default = None )
3844 parser .addoption ('--region' , default = 'us-west-2' )
39- parser .addoption ('--framework-version' , default = TensorFlow . LATEST_VERSION )
45+ parser .addoption ('--framework-version' , default = '2.1.0' )
4046 parser .addoption ('--processor' , default = 'cpu' , choices = ['cpu' , 'gpu' , 'cpu,gpu' ])
4147 parser .addoption ('--py-version' , default = '3' , choices = ['2' , '3' , '2,3' ])
4248 parser .addoption ('--account-id' , default = '142577830533' )
@@ -48,6 +54,38 @@ def pytest_configure(config):
4854 os .environ ['TEST_PROCESSORS' ] = config .getoption ('--processor' )
4955
5056
57+ @pytest .fixture (scope = 'session' , name = 'dockerfile_type' )
58+ def fixture_dockerfile_type (request ):
59+ return request .config .getoption ('--dockerfile-type' )
60+
61+
62+ @pytest .fixture (scope = 'session' , name = 'dockerfile' )
63+ def fixture_dockerfile (request , dockerfile_type ):
64+ dockerfile = request .config .getoption ('--dockerfile' )
65+ return dockerfile if dockerfile else 'Dockerfile.{}' .format (dockerfile_type )
66+
67+
68+ @pytest .fixture (scope = 'session' , name = 'build_image' , autouse = True )
69+ def fixture_build_image (request , framework_version , dockerfile , image_uri , region ):
70+ build_image = request .config .getoption ('--build-image' )
71+ if build_image :
72+ return image_utils .build_image (framework_version = framework_version ,
73+ dockerfile = dockerfile ,
74+ image_uri = image_uri ,
75+ region = region ,
76+ cwd = os .path .join (DIR_PATH , '..' , '..' ))
77+
78+ return image_uri
79+
80+
81+ @pytest .fixture (scope = 'session' , name = 'push_image' , autouse = True )
82+ def fixture_push_image (request , image_uri , region , account_id ):
83+ push_image = request .config .getoption ('--push-image' )
84+ if push_image :
85+ return image_utils .push_image (image_uri , region , account_id )
86+ return None
87+
88+
5189@pytest .fixture (scope = 'session' )
5290def docker_base_name (request ):
5391 return request .config .getoption ('--docker-base-name' )
@@ -63,7 +101,7 @@ def framework_version(request):
63101 return request .config .getoption ('--framework-version' )
64102
65103
66- @pytest .fixture
104+ @pytest .fixture ( scope = 'session' )
67105def tag (request , framework_version , processor , py_version ):
68106 provided_tag = request .config .getoption ('--tag' )
69107 default_tag = '{}-{}-py{}' .format (framework_version , processor , py_version )
@@ -92,20 +130,6 @@ def instance_type(request, processor):
92130 return provided_instance_type if provided_instance_type is not None else default_instance_type
93131
94132
95- @pytest .fixture ()
96- def py_version ():
97- if 'TEST_PY_VERSIONS' in os .environ :
98- return os .environ ['TEST_PY_VERSIONS' ].split (',' )
99- return None
100-
101-
102- @pytest .fixture ()
103- def processor ():
104- if 'TEST_PROCESSORS' in os .environ :
105- return os .environ ['TEST_PROCESSORS' ].split (',' )
106- return None
107-
108-
109133@pytest .fixture (autouse = True )
110134def skip_by_device_type (request , processor ):
111135 is_gpu = (processor == 'gpu' )
@@ -121,19 +145,27 @@ def skip_gpu_instance_restricted_regions(region, instance_type):
121145 pytest .skip ('Skipping GPU test in region {}' .format (region ))
122146
123147
124- @pytest .fixture
125- def docker_image (docker_base_name , tag ):
126- return '{}:{}' .format (docker_base_name , tag )
127-
128-
129- @pytest .fixture
130- def ecr_image (account_id , docker_base_name , tag , region ):
131- return '{}.dkr.ecr.{}.amazonaws.com/{}:{}' .format (
132- account_id , region , docker_base_name , tag )
133-
134-
135148@pytest .fixture (autouse = True )
136149def skip_py2_containers (request , tag ):
137150 if request .node .get_closest_marker ('skip_py2_containers' ):
138151 if 'py2' in tag :
139152 pytest .skip ('Skipping python2 container with tag {}' .format (tag ))
153+
154+
155+ @pytest .fixture (autouse = True )
156+ def skip_by_dockerfile_type (request , dockerfile_type ):
157+ is_generic = (dockerfile_type == 'tf' )
158+ if request .node .get_closest_marker ('skip_generic' ) and is_generic :
159+ pytest .skip ('Skipping because running generic image without mpi and horovod' )
160+
161+
162+ @pytest .fixture (name = 'docker_registry' , scope = 'session' )
163+ def fixture_docker_registry (account_id , region ):
164+ return '{}.dkr.ecr.{}.amazonaws.com' .format (account_id , region ) if account_id else None
165+
166+
167+ @pytest .fixture (name = 'image_uri' , scope = 'session' )
168+ def fixture_image_uri (docker_registry , docker_base_name , tag ):
169+ if docker_registry :
170+ return '{}/{}:{}' .format (docker_registry , docker_base_name , tag )
171+ return '{}:{}' .format (docker_base_name , tag )
0 commit comments