119 changes: 39 additions & 80 deletions dipu/tests/python/unittests/test_adamw.py
@@ -6,7 +6,8 @@
onlyOn,
skipOn,
)

# `fused=True` requires all params to be floating-point Tensors on supported devices: ['cuda', 'xpu', 'privateuseone'].
# So we use fused=False results as the reference and compare them with the fused torch_dipu results.
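# Illustrative sketch of the comparison pattern described above (not invoked by
# the tests; assumes a CUDA/DIPU device is available, and the shape, lr and
# tolerances here are made-up placeholders, not the values used below): run the
# same AdamW step with fused=False on CPU and fused=True on the device, then
# check that the parameter updates agree.
def _sketch_fused_vs_unfused_adamw():
    param_cpu = torch.randn(4, 8)
    param_dev = param_cpu.clone().cuda()
    param_cpu.grad = torch.randn_like(param_cpu)
    param_dev.grad = param_cpu.grad.cuda()
    opt_cpu = torch.optim.AdamW([param_cpu], lr=1e-3, fused=False)
    opt_dev = torch.optim.AdamW([param_dev], lr=1e-3, fused=True)
    opt_cpu.step()
    opt_dev.step()
    # The fused device update should closely match the unfused CPU reference.
    assert torch.allclose(param_cpu, param_dev.cpu(), atol=1e-2, rtol=2e-3)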

class TestFusedAdamW(TestCase):
def setUp(self):
@@ -17,73 +18,56 @@ def setUp(self):
self.eps_list = [1e-8, 1e-8, 1e-8, 1e-8]
self.weight_decay_list = [1e-2, 1e-3, 1e-2, 1e-3]
self.amsgrad_list = [False, False, True, True]
self.step_list = [2, 3, 4, 5]

def run_adamw_cpu(
self,
param,
param_grad,
exp_avg,
exp_avg_sq,
max_exp_avg_sq,
lr,
beta1,
beta2,
eps,
step,
weight_decay,
amsgrad,
):
torch.optim._functional.adamw(
[param],
[param_grad],
[exp_avg],
[exp_avg_sq],
[max_exp_avg_sq],
[torch.tensor(float(step))],
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
eps=eps,
maximize=False,
)
return param, exp_avg, exp_avg_sq, max_exp_avg_sq

param.grad = param_grad
optimizer = torch.optim.AdamW(params=[param],
lr=lr,
betas=(beta1, beta2),
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
fused=False)
optimizer.step()
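# Read the updated first/second moments back from the optimizer state so the
# CPU reference exposes the same outputs as the fused path.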
state_index = 0
exp_avg = optimizer.state_dict()["state"][state_index]["exp_avg"]
exp_avg_sq = optimizer.state_dict()["state"][state_index]["exp_avg_sq"]
return param, exp_avg, exp_avg_sq

def run_adamw_dipu(
self,
param,
param_grad,
exp_avg,
exp_avg_sq,
max_exp_avg_sq,
lr,
beta1,
beta2,
eps,
step,
weight_decay,
amsgrad,
):
torch._fused_adamw_(
[param],
[param_grad],
[exp_avg],
[exp_avg_sq],
[max_exp_avg_sq],
[torch.tensor(float(step)).cuda()],
amsgrad=amsgrad,
lr=lr,
beta1=beta1,
beta2=beta2,
weight_decay=weight_decay,
eps=eps,
maximize=False,
grad_scale=None,
found_inf=None,
)
return param, exp_avg, exp_avg_sq, max_exp_avg_sq
param.grad = param_grad
optimizer = torch.optim.AdamW(params=[param],
lr=lr,
betas=(beta1, beta2),
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
fused=True)
optimizer.step()
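# Mirror the CPU helper: fetch the updated moments from the optimizer state.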
state_index = 0
exp_avg = optimizer.state_dict()["state"][state_index]["exp_avg"]
exp_avg_sq = optimizer.state_dict()["state"][state_index]["exp_avg_sq"]
return param, exp_avg, exp_avg_sq

def adamw_(self, dtype_):
for i in range(len(self.weight_shape_list)):
@@ -93,54 +77,37 @@ def adamw_(self, dtype_):
if dtype_ == torch.float16
else weight.cpu()
)
weight_fused_cpu = weight_cpu.clone().detach()
grad = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
grad_cpu = (
grad.cpu().to(torch.float32) if dtype_ == torch.float16 else grad.cpu()
)
m = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
m_cpu = m.cpu().to(torch.float32) if dtype_ == torch.float16 else m.cpu()
v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
v_cpu = v.cpu().to(torch.float32) if dtype_ == torch.float16 else v.cpu()
max_v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
max_v_cpu = (
max_v.cpu().to(torch.float32)
if dtype_ == torch.float16
else max_v.cpu()
)
grad_fused_cpu = grad_cpu.clone().detach()

lr = self.lr_list[i]
beta1 = self.beta1_list[i]
beta2 = self.beta2_list[i]
eps = self.eps_list[i]
weight_decay = self.weight_decay_list[i]
amsgrad = self.amsgrad_list[i]
step = self.step_list[i]

w_new_cpu, m_new_cpu, v_new_cpu, max_v_new_cpu = self.run_adamw_cpu(
w_new_cpu, m_new_cpu, v_new_cpu = self.run_adamw_cpu(
weight_cpu,
grad_cpu,
m_cpu,
v_cpu,
max_v_cpu,
lr,
beta1,
beta2,
eps,
step,
weight_decay,
amsgrad,
)
w_new, m_new, v_new, max_v_new = self.run_adamw_dipu(
w_new, m_new, v_new = self.run_adamw_dipu(
weight,
grad,
m,
v,
max_v,
lr,
beta1,
beta2,
eps,
step,
weight_decay,
amsgrad,
)
@@ -155,8 +122,11 @@ def adamw_(self, dtype_):
),
atol=2e-2 if dtype_ == torch.float16 else 1e-2,
rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
equal_nan=True,
equal_nan=False,
),
)

self.assertTrue(
torch.allclose(
m_new.cpu(),
(
@@ -166,7 +136,7 @@ def adamw_(self, dtype_):
),
atol=2e-2 if dtype_ == torch.float16 else 1e-2,
rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
equal_nan=True,
equal_nan=False,
),
)
self.assertTrue(
@@ -179,18 +149,7 @@ def adamw_(self, dtype_):
),
atol=2e-2 if dtype_ == torch.float16 else 1e-2,
rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
equal_nan=True,
),
torch.allclose(
max_v_new.cpu(),
(
max_v_new_cpu.to(torch.float16)
if dtype_ == torch.float16
else max_v_new_cpu
),
atol=2e-2 if dtype_ == torch.float16 else 1e-2,
rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
equal_nan=True,
equal_nan=False,
),
)
