Skip to content

Commit 95bf220

Browse files
author
Masaru Kimura
committed
Fix torch.jit.ScriptModule.zero_grad.
TorchSharp 0.105.0 doesn't have torch.jit.ScriptModule.zero_grad and incorrectly falls back to torch.nn.Module.zero_grad, which then terminates silently. This is most probably because JITModule is not compatible with NNModule in LibTorchSharp. As reported in pytorch/pytorch#27144, libtorch also doesn't have torch::jit::Module::zero_grad. As a workaround, manually loop over the parameters and zero them out, as the optimizer does.
1 parent b87317e commit 95bf220

File tree

4 files changed

+38
-0
lines changed

4 files changed

+38
-0
lines changed

src/Native/LibTorchSharp/THSJIT.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,23 @@ int THSJIT_Module_is_training(JITModule module)
6868
return (*module)->is_training();
6969
}
7070

71+
void THSJIT_Module_zero_grad(const JITModule module, bool set_to_none)
72+
{
73+
// According to https://github.com/pytorch/pytorch/issues/27144,
74+
// torch::jit::Module has no zero_grad().
75+
// As a workaround, manually loop over the parameters and zero them out like optimizer does;
76+
// https://github.com/pytorch/pytorch/blob/v2.5.1/torch/csrc/api/src/optim/optimizer.cpp#L123
77+
for (auto& p : (*module)->parameters()) {
78+
if (p.mutable_grad().defined()) {
79+
p.mutable_grad().detach_();
80+
if (set_to_none)
81+
p.mutable_grad().reset();
82+
else
83+
p.mutable_grad().zero_();
84+
}
85+
}
86+
}
87+
7188
void THSJIT_Module_train(JITModule module, bool on)
7289
{
7390
(*module)->train(on);

src/Native/LibTorchSharp/THSJIT.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ EXPORT_API(void) THSJIT_Module_invoke(const JITModule module, const char* name,
4444
EXPORT_API(void) THSJIT_CompilationUnit_Invoke(const JITCompilationUnit module, const char* method, const TensorOrScalar* tensorPtrs, const int length, TensorOrScalar* (*allocator)(int32_t idx, size_t length), int8_t* typeCode, int32_t idx);
4545

4646
EXPORT_API(int) THSJIT_Module_is_training(JITModule module);
47+
EXPORT_API(void) THSJIT_Module_zero_grad(const JITModule module, bool set_to_none);
4748
EXPORT_API(void) THSJIT_Module_train(JITModule module, bool on);
4849
EXPORT_API(void) THSJIT_Module_eval(JITModule module);
4950

src/TorchSharp/JIT/ScriptModule.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,23 @@ public override bool training {
143143
}
144144
}
145145

146+
/// <summary>
/// Sets the gradients of all parameters to zero (or to null when
/// <paramref name="set_to_none"/> is true).
/// </summary>
/// <param name="set_to_none">When true, gradients are released instead of being zeroed in place.</param>
public override void zero_grad(bool set_to_none = true)
{
    // Clear the gradients on the native (libtorch) side first.
    THSJIT_Module_zero_grad(handle, set_to_none);
    CheckForErrors();

    // Mirror the change on the managed side so cached grad handles agree.
    foreach (var (_, param) in named_parameters()) {
        using var gradient = param.grad;
        if (gradient is null) continue;

        if (set_to_none)
            param.grad = null;
        else
            gradient.zero_();
    }
}
162+
146163
protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking)
147164
{
148165
if (device.type != DeviceType.CUDA) { device = new Device(device.type, -1); };

src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ internal static partial class NativeMethods
5757
[return: MarshalAs(UnmanagedType.U1)]
5858
internal static extern bool THSJIT_Module_is_training(torch.nn.Module.HType module);
5959

60+
[DllImport("LibTorchSharp")]
61+
internal static extern void THSJIT_Module_zero_grad(torch.nn.Module.HType module, [MarshalAs(UnmanagedType.U1)] bool set_to_none);
62+
6063
[DllImport("LibTorchSharp")]
6164
internal static extern void THSJIT_Module_to_device(torch.nn.Module.HType module, long deviceType, long deviceIndex);
6265

0 commit comments

Comments
 (0)