Skip to content

Commit 213840d

Browse files
committed
Use torch.testing.assert_close
1 parent 4dd935f commit 213840d

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

pytorch_pfn_extras/utils/comparer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch.nn
1313
import torch.testing
1414

15+
import pytorch_pfn_extras
1516
from pytorch_pfn_extras import handler as _handler_module
1617
from pytorch_pfn_extras.handler import _logic
1718
from pytorch_pfn_extras.training import _trainer
@@ -191,8 +192,13 @@ def compare_fn(
191192
val1 = val1.cpu().detach()
192193
if isinstance(val2, torch.Tensor):
193194
val2 = val2.cpu().detach()
194-
torch.testing.assert_allclose(
195-
val1, val2, rtol=rtol, atol=atol, equal_nan=equal_nan)
195+
196+
if pytorch_pfn_extras.requires("1.9.0"):
197+
assert_close = torch.testing.assert_close # type: ignore[attr-defined]
198+
else:
199+
assert_close = torch.testing.assert_allclose # type: ignore[assignment]
200+
201+
assert_close(val1, val2, rtol=rtol, atol=atol, equal_nan=equal_nan)
196202

197203
return compare_fn
198204

tests/pytorch_pfn_extras_tests/test_handler.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66
import pytorch_pfn_extras as ppe
77

88

9+
def torch_testing_assert_close(*args, **kwargs):
10+
if ppe.requires("1.10.0"):
11+
torch.testing.assert_close(*args, **kwargs)
12+
else:
13+
torch.testing.assert_allclose(*args, **kwargs)
14+
15+
916
class MockRuntime(ppe.runtime.BaseRuntime):
1017
def __init__(self, device, options):
1118
super().__init__(device, options)
@@ -361,8 +368,8 @@ def test_train_step(self):
361368
model = models['main']
362369
assert input.grad is not None
363370
# The gradient of a linear layer is its transposed weight
364-
torch.testing.assert_allclose(input.grad, model.weight.T)
365-
torch.testing.assert_allclose(out, model(input))
371+
torch_testing_assert_close(input.grad, model.weight.T)
372+
torch_testing_assert_close(out, model(input))
366373

367374
@pytest.mark.parametrize(
368375
'to_backprop',
@@ -403,11 +410,11 @@ def forward(self, x):
403410
grad = torch.zeros(1)
404411
for val in to_backprop:
405412
grad = grad + getattr(model, f'l{val}').weight.T
406-
torch.testing.assert_allclose(input.grad, grad)
413+
torch_testing_assert_close(input.grad, grad)
407414

408415
# Check that logic step does not change the value of weight
409416
for val in original_parameters:
410-
torch.testing.assert_allclose(
417+
torch_testing_assert_close(
411418
original_parameters[val], getattr(model, f'l{val}').weight)
412419

413420
def test_train_step_backward_nograd(self):
@@ -461,7 +468,7 @@ def test_train_step_optimizers(self):
461468
w_grad = model.weight.grad.clone().detach()
462469
logic.train_step_optimizers(model, optimizers, 0)
463470
# Checks that the value was correctly updated
464-
torch.testing.assert_allclose(m_weight - w_grad, model.weight.T)
471+
torch_testing_assert_close(m_weight - w_grad, model.weight.T)
465472

466473
@pytest.mark.gpu
467474
def test_grad_scaler(self):
@@ -473,12 +480,12 @@ def test_grad_scaler(self):
473480
m_weight = model.weight.clone().detach()
474481
w_grad = model.weight.grad.clone().detach()
475482
# The gradient of a linear layer is its transposed weight
476-
torch.testing.assert_allclose(input.grad, scaler.scale(model.weight.T))
477-
torch.testing.assert_allclose(out, model(input))
483+
torch_testing_assert_close(input.grad, scaler.scale(model.weight.T))
484+
torch_testing_assert_close(out, model(input))
478485
logic.train_step_optimizers(model, optimizers, 0)
479486
# Checks that the value was correctly updated and gradients deescaled
480487
# before the update
481-
torch.testing.assert_allclose(
488+
torch_testing_assert_close(
482489
scaler.scale(m_weight) - w_grad, scaler.scale(model.weight.T))
483490

484491
@pytest.mark.gpu
@@ -513,4 +520,4 @@ def test_eval_step(self):
513520
models = {'main': model}
514521
models['main'].eval()
515522
out = logic.eval_step(models, 0, input)
516-
torch.testing.assert_allclose(out, model(input))
523+
torch_testing_assert_close(out, model(input))

0 commit comments

Comments
 (0)