@@ -176,17 +176,20 @@ def train(env, cmd_args):
176176 env_vars = env .to_env_vars ()
177177
178178 # Setup
179- if parameter_server_enabled :
179+ if env .current_instance_group in env .distribution_instance_groups :
180+ if parameter_server_enabled :
180181
181- tf_config = _build_tf_config_for_ps (hosts = env .hosts , current_host = env .current_host )
182- logger .info ("Running distributed training job with parameter servers" )
182+ tf_config = _build_tf_config_for_ps (hosts = env .distribution_hosts , current_host = env .current_host )
183+ logger .info ("Running distributed training job with parameter servers" )
183184
184- elif multi_worker_mirrored_strategy_enabled :
185+ elif multi_worker_mirrored_strategy_enabled :
185186
186- env_vars ["TF_CONFIG" ] = json .dumps (
187- _build_tf_config_for_mwms (hosts = env .hosts , current_host = env .current_host )
188- )
189- logger .info ("Running distributed training job with multi_worker_mirrored_strategy setup" )
187+ env_vars ["TF_CONFIG" ] = json .dumps (
188+ _build_tf_config_for_mwms (hosts = env .distribution_hosts , current_host = env .current_host )
189+ )
190+ logger .info ("Running distributed training job with multi_worker_mirrored_strategy setup" )
191+
192+ runner_type = runner .ProcessRunnerType
190193
191194 # Run
192195 if parameter_server_enabled :
@@ -200,15 +203,13 @@ def train(env, cmd_args):
200203 _wait_until_master_is_down (env .hosts [0 ])
201204
202205 else :
206+ if env .current_instance_group in env .distribution_instance_groups :
207+ mpi_enabled = env .additional_framework_parameters .get ("sagemaker_mpi_enabled" )
203208
204- mpi_enabled = env .additional_framework_parameters .get ("sagemaker_mpi_enabled" )
205-
206- if mpi_enabled :
207- runner_type = runner .MPIRunnerType
208- elif sagemaker_distributed_dataparallel_enabled :
209- runner_type = runner .SMDataParallelRunnerType
210- else :
211- runner_type = runner .ProcessRunnerType
209+ if mpi_enabled :
210+ runner_type = runner .MPIRunnerType
211+ elif sagemaker_distributed_dataparallel_enabled :
212+ runner_type = runner .SMDataParallelRunnerType
212213
213214 entry_point .run (
214215 uri = env .module_dir ,
0 commit comments