 import pytorch_pfn_extras as ppe
 
 
+def torch_testing_assert_close(*args, **kwargs):
+    if ppe.requires("1.10.0"):
+        torch.testing.assert_close(*args, **kwargs)
+    else:
+        torch.testing.assert_allclose(*args, **kwargs)
+
+
 class MockRuntime(ppe.runtime.BaseRuntime):
     def __init__(self, device, options):
         super().__init__(device, options)
@@ -361,8 +368,8 @@ def test_train_step(self):
         model = models['main']
         assert input.grad is not None
         # The gradient of a linear layer is its transposed weight
-        torch.testing.assert_allclose(input.grad, model.weight.T)
-        torch.testing.assert_allclose(out, model(input))
+        torch_testing_assert_close(input.grad, model.weight.T)
+        torch_testing_assert_close(out, model(input))
 
     @pytest.mark.parametrize(
         'to_backprop',
@@ -403,11 +410,11 @@ def forward(self, x):
         grad = torch.zeros(1)
         for val in to_backprop:
             grad = grad + getattr(model, f'l{val}').weight.T
-        torch.testing.assert_allclose(input.grad, grad)
+        torch_testing_assert_close(input.grad, grad)
 
         # Check that logic step does not change the value of weight
         for val in original_parameters:
-            torch.testing.assert_allclose(
+            torch_testing_assert_close(
                 original_parameters[val], getattr(model, f'l{val}').weight)
 
     def test_train_step_backward_nograd(self):
@@ -461,7 +468,7 @@ def test_train_step_optimizers(self):
         w_grad = model.weight.grad.clone().detach()
         logic.train_step_optimizers(model, optimizers, 0)
         # Checks that the value was correctly updated
-        torch.testing.assert_allclose(m_weight - w_grad, model.weight.T)
+        torch_testing_assert_close(m_weight - w_grad, model.weight.T)
 
     @pytest.mark.gpu
     def test_grad_scaler(self):
@@ -473,12 +480,12 @@ def test_grad_scaler(self):
         m_weight = model.weight.clone().detach()
         w_grad = model.weight.grad.clone().detach()
         # The gradient of a linear layer is its transposed weight
-        torch.testing.assert_allclose(input.grad, scaler.scale(model.weight.T))
-        torch.testing.assert_allclose(out, model(input))
+        torch_testing_assert_close(input.grad, scaler.scale(model.weight.T))
+        torch_testing_assert_close(out, model(input))
         logic.train_step_optimizers(model, optimizers, 0)
         # Checks that the value was correctly updated and gradients deescaled
         # before the update
-        torch.testing.assert_allclose(
+        torch_testing_assert_close(
             scaler.scale(m_weight) - w_grad, scaler.scale(model.weight.T))
 
     @pytest.mark.gpu
@@ -513,4 +520,4 @@ def test_eval_step(self):
         models = {'main': model}
         models['main'].eval()
         out = logic.eval_step(models, 0, input)
-        torch.testing.assert_allclose(out, model(input))
+        torch_testing_assert_close(out, model(input))
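
For reference, a minimal standalone sketch of how the compatibility helper introduced by this change is meant to be used; the sample tensor below is illustrative only and is not part of the test suite:

import torch
import pytorch_pfn_extras as ppe


def torch_testing_assert_close(*args, **kwargs):
    # torch.testing.assert_close is used on PyTorch >= 1.10; older versions
    # fall back to the deprecated torch.testing.assert_allclose.
    if ppe.requires("1.10.0"):
        torch.testing.assert_close(*args, **kwargs)
    else:
        torch.testing.assert_allclose(*args, **kwargs)


# Illustrative usage: the assertion passes because both tensors are equal.
a = torch.tensor([1.0, 2.0, 3.0])
torch_testing_assert_close(a, a.clone())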