1717
1818from mock import MagicMock , patch
1919import pytest
20- from sagemaker_containers . beta . framework import runner
20+ from sagemaker_training import runner
2121import tensorflow as tf
2222
2323from sagemaker_tensorflow_container import training
@@ -81,30 +81,30 @@ def test_is_host_master():
8181 assert training ._is_host_master (HOST_LIST , 'somehost' ) is False
8282
8383
84- @patch ('sagemaker_containers.beta.framework .entry_point.run' )
84+ @patch ('sagemaker_training .entry_point.run' )
8585def test_single_machine (run_module , single_machine_training_env ):
8686 training .train (single_machine_training_env , MODEL_DIR_CMD_LIST )
8787 run_module .assert_called_with (MODULE_DIR , MODULE_NAME , MODEL_DIR_CMD_LIST ,
8888 single_machine_training_env .to_env_vars (),
89- runner = runner .ProcessRunnerType )
89+ runner_type = runner .ProcessRunnerType )
9090
9191
92- @patch ('sagemaker_containers.beta.framework .entry_point.run' )
92+ @patch ('sagemaker_training .entry_point.run' )
9393def test_train_horovod (run_module , single_machine_training_env ):
9494 single_machine_training_env .additional_framework_parameters ['sagemaker_mpi_enabled' ] = True
9595
9696 training .train (single_machine_training_env , MODEL_DIR_CMD_LIST )
9797 run_module .assert_called_with (MODULE_DIR , MODULE_NAME , MODEL_DIR_CMD_LIST ,
9898 single_machine_training_env .to_env_vars (),
99- runner = runner .MPIRunnerType )
99+ runner_type = runner .MPIRunnerType )
100100
101101
102102@pytest .mark .skip_on_pipeline
103103@pytest .mark .skipif (sys .version_info .major != 3 ,
104104 reason = "Skip this for python 2 because of dict key order mismatch" )
105105@patch ('tensorflow.train.ClusterSpec' )
106106@patch ('tensorflow.distribute.Server' )
107- @patch ('sagemaker_containers.beta.framework .entry_point.run' )
107+ @patch ('sagemaker_training .entry_point.run' )
108108@patch ('multiprocessing.Process' , lambda target : target ())
109109@patch ('time.sleep' , MagicMock ())
110110def test_train_distributed_master (run , tf_server , cluster_spec , distributed_training_env ):
@@ -135,7 +135,7 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
135135 reason = "Skip this for python 2 because of dict key order mismatch" )
136136@patch ('tensorflow.train.ClusterSpec' )
137137@patch ('tensorflow.distribute.Server' )
138- @patch ('sagemaker_containers.beta.framework .entry_point.run' )
138+ @patch ('sagemaker_training .entry_point.run' )
139139@patch ('multiprocessing.Process' , lambda target : target ())
140140@patch ('time.sleep' , MagicMock ())
141141def test_train_distributed_worker (run , tf_server , cluster_spec , distributed_training_env ):
@@ -163,15 +163,15 @@ def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_trai
163163 {'TF_CONFIG' : tf_config })
164164
165165
166- @patch ('sagemaker_containers.beta.framework .entry_point.run' )
166+ @patch ('sagemaker_training .entry_point.run' )
167167def test_train_distributed_no_ps (run , distributed_training_env ):
168168 distributed_training_env .additional_framework_parameters [
169169 training .SAGEMAKER_PARAMETER_SERVER_ENABLED ] = False
170170 distributed_training_env .current_host = HOST2
171171 training .train (distributed_training_env , MODEL_DIR_CMD_LIST )
172172
173173 run .assert_called_with (MODULE_DIR , MODULE_NAME , MODEL_DIR_CMD_LIST ,
174- distributed_training_env .to_env_vars (), runner = runner .ProcessRunnerType )
174+ distributed_training_env .to_env_vars (), runner_type = runner .ProcessRunnerType )
175175
176176
177177def test_build_tf_config ():
@@ -241,8 +241,8 @@ def test_log_model_missing_warning_correct(logger):
241241@patch ('sagemaker_tensorflow_container.training.logger' )
242242@patch ('sagemaker_tensorflow_container.training.train' )
243243@patch ('logging.Logger.setLevel' )
244- @patch ('sagemaker_containers.beta.framework.training_env ' )
245- @patch ('sagemaker_containers.beta.framework.env .read_hyperparameters' , return_value = {})
244+ @patch ('sagemaker_training.environment.Environment ' )
245+ @patch ('sagemaker_training.environment .read_hyperparameters' , return_value = {})
246246@patch ('sagemaker_tensorflow_container.s3_utils.configure' )
247247def test_main (configure_s3_env , read_hyperparameters , training_env ,
248248 set_level , train , logger , single_machine_training_env ):
@@ -258,8 +258,8 @@ def test_main(configure_s3_env, read_hyperparameters, training_env,
258258@patch ('sagemaker_tensorflow_container.training.logger' )
259259@patch ('sagemaker_tensorflow_container.training.train' )
260260@patch ('logging.Logger.setLevel' )
261- @patch ('sagemaker_containers.beta.framework.training_env ' )
262- @patch ('sagemaker_containers.beta.framework.env .read_hyperparameters' , return_value = {'model_dir' : MODEL_DIR })
261+ @patch ('sagemaker_training.environment.Environment ' )
262+ @patch ('sagemaker_training.environment .read_hyperparameters' , return_value = {'model_dir' : MODEL_DIR })
263263@patch ('sagemaker_tensorflow_container.s3_utils.configure' )
264264def test_main_simple_training_model_dir (configure_s3_env , read_hyperparameters , training_env ,
265265 set_level , train , logger , single_machine_training_env ):
@@ -272,9 +272,9 @@ def test_main_simple_training_model_dir(configure_s3_env, read_hyperparameters,
272272@patch ('sagemaker_tensorflow_container.training.logger' )
273273@patch ('sagemaker_tensorflow_container.training.train' )
274274@patch ('logging.Logger.setLevel' )
275- @patch ('sagemaker_containers.beta.framework.training_env ' )
276- @patch ('sagemaker_containers.beta.framework.env .read_hyperparameters' , return_value = {'model_dir' : MODEL_DIR ,
277- '_tuning_objective_metric' : 'auc' })
275+ @patch ('sagemaker_training.environment.Environment ' )
276+ @patch ('sagemaker_training.environment .read_hyperparameters' , return_value = {'model_dir' : MODEL_DIR ,
277+ '_tuning_objective_metric' : 'auc' })
278278@patch ('sagemaker_tensorflow_container.s3_utils.configure' )
279279def test_main_tuning_model_dir (configure_s3_env , read_hyperparameters , training_env ,
280280 set_level , train , logger , single_machine_training_env ):
@@ -288,9 +288,9 @@ def test_main_tuning_model_dir(configure_s3_env, read_hyperparameters, training_
288288@patch ('sagemaker_tensorflow_container.training.logger' )
289289@patch ('sagemaker_tensorflow_container.training.train' )
290290@patch ('logging.Logger.setLevel' )
291- @patch ('sagemaker_containers.beta.framework.training_env ' )
292- @patch ('sagemaker_containers.beta.framework.env .read_hyperparameters' , return_value = {'model_dir' : '/opt/ml/model' ,
293- '_tuning_objective_metric' : 'auc' })
291+ @patch ('sagemaker_training.environment.Environment ' )
292+ @patch ('sagemaker_training.environment .read_hyperparameters' , return_value = {'model_dir' : '/opt/ml/model' ,
293+ '_tuning_objective_metric' : 'auc' })
294294@patch ('sagemaker_tensorflow_container.s3_utils.configure' )
295295def test_main_tuning_mpi_model_dir (configure_s3_env , read_hyperparameters , training_env ,
296296 set_level , train , logger , single_machine_training_env ):
0 commit comments