```python
## Initialising FishLeg ##
opt = FishLeg(
    model_FishLeg,
    aux_loader,
    likelihood,
    lr=lr,
    beta=beta,
    weight_decay=weight_decay,
    aux_lr=aux_lr,
    aux_betas=(0.9, 0.999),
    aux_eps=aux_eps,
    damping=damping,
    update_aux_every=update_aux_every,
    writer=writer,
    method="antithetic",
    method_kwargs={"eps": 1e-4},
    precondition_aux=True,
    aux_log=True
)
```
This call needs to include `device=device`, so it becomes:
```python
## Initialising FishLeg ##
opt = FishLeg(
    model_FishLeg,
    aux_loader,
    likelihood,
    lr=lr,
    beta=beta,
    weight_decay=weight_decay,
    aux_lr=aux_lr,
    aux_betas=(0.9, 0.999),
    aux_eps=aux_eps,
    damping=damping,
    update_aux_every=update_aux_every,
    writer=writer,
    method="antithetic",
    method_kwargs={"eps": 1e-4},
    precondition_aux=True,
    aux_log=True,
    device=device
)
```
This is needed because `FishLeg.__init__()` defaults to the CPU. When the model runs on MPS or CUDA (both supported in the tutorial), the call fails because the tensors end up on separate devices, resulting in this error:

```
Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

or alternatively:

```
Expected all tensors to be on the same device, but found at least two devices, mps and cpu!
```
This fix has been tested on CUDA and CPU; I would appreciate someone checking it on MPS.
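For anyone reproducing this, here is a minimal sketch of the device handling the fix relies on, assuming a plain PyTorch model. `select_device` and `on_device` are hypothetical helpers written for illustration; they are not part of FishLeg, and the tutorial may choose its device differently. The idea is simply that one `device` object is used both to move the model and as the `device=` argument to `FishLeg`.

```python
import torch
import torch.nn as nn


def select_device() -> torch.device:
    """Prefer CUDA, then Apple's MPS backend, else fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def on_device(module: nn.Module, device: torch.device) -> bool:
    """True if every parameter and buffer of `module` is on `device`'s type.

    Compares `.type` rather than the device itself, because
    torch.device("cuda") does not compare equal to torch.device("cuda:0").
    """
    tensors = list(module.parameters()) + list(module.buffers())
    return all(t.device.type == device.type for t in tensors)


device = select_device()
# Hypothetical stand-in for model_FishLeg; any nn.Module behaves the same way.
model = nn.Linear(4, 2).to(device)
print(device.type, on_device(model, device))
```

With `device` chosen this way, the corrected call above passes the same object as `device=device`, so the model and the optimizer agree on where their tensors live instead of the optimizer silently falling back to the CPU default.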