Skip to content

Commit fe90379

Browse files
Merge pull request #819 from EIFY/torch-init-fix
fix pytorch_default_init()
2 parents 90959e1 + 579a485 commit fe90379

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

algorithmic_efficiency/init_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ def pytorch_default_init(module: nn.Module) -> None:
1313
# Perform lecun_normal initialization.
1414
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
1515
std = math.sqrt(1. / fan_in) / .87962566103423978
16-
nn.init.trunc_normal_(module.weight, std=std)
16+
nn.init.trunc_normal_(module.weight, std=std, a=-2 * std, b=2 * std)
1717
if module.bias is not None:
1818
nn.init.constant_(module.bias, 0.)

0 commit comments

Comments
 (0)