@@ -61,7 +61,22 @@ def log(t, eps = 1e-20):
 def entropy(prob, eps = 1e-5):
     return (-prob * log(prob, eps = eps)).sum(dim = -1)
 
+def accum_grad_(t, grad):
+    if exists(t.grad):
+        t.grad.add_(grad)
+    else:
+        t.grad = grad.clone().detach()
+
 def ema_inplace(old, new, decay, weight = None):
+
+    # if old.grad is populated, add it to new and set it to None
+
+    if exists(old.grad):
+        new.add_(old.grad)
+        old.grad = None
+
+    # take care of custom weighting
+
     weight = default(weight, 1.)
 
     if is_tensor(weight):
@@ -71,7 +86,7 @@ def ema_inplace(old, new, decay, weight = None):
         assert weight.ndim == 2 and weight.shape == old.shape[:2]
         weight = append_dims_to(weight, old.ndim)
 
-    old.lerp_(new, (1. - decay) * weight)
+    old.data.lerp_(new, (1. - decay) * weight)
 
 def pack_one(t, pattern):
     packed, ps = pack([t], pattern)
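Taken together, the two helpers above form an accumulate-then-apply cycle: `accum_grad_` stashes per-step statistics on a buffer's otherwise unused `.grad` slot, and the next `ema_inplace` call folds that stash into `new` before the lerp (`ema_inplace` now receives the buffer itself rather than its `.data`, which is why the lerp goes through `old.data`). A minimal standalone sketch of the round trip on plain tensors; the `exists`/`default` helpers are re-defined here on the assumption that they behave as below, and the tensor-valued `weight` branch is omitted for brevity:

```python
import torch

# assumed behavior of the repo's small helpers
def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def accum_grad_(t, grad):
    # stash running statistics on the tensor's .grad slot
    if exists(t.grad):
        t.grad.add_(grad)
    else:
        t.grad = grad.clone().detach()

def ema_inplace(old, new, decay, weight = None):
    # fold any accumulated stash into `new`, then clear it
    if exists(old.grad):
        new.add_(old.grad)
        old.grad = None

    weight = default(weight, 1.)
    old.data.lerp_(new, (1. - decay) * weight)

buffer = torch.zeros(4)

# two accumulation steps (e.g. two micro-batches); no EMA applied yet
accum_grad_(buffer, torch.ones(4))
accum_grad_(buffer, torch.ones(4))

# a single EMA step then consumes the stash: new becomes 2 + (1 + 1) = 4,
# and buffer <- lerp(0, 4, 1 - 0.9) = 0.4
ema_inplace(buffer, torch.full((4,), 2.), decay = 0.9)

assert buffer.grad is None
assert torch.allclose(buffer, torch.full((4,), 0.4))
```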
@@ -511,7 +526,8 @@ def forward(
         mask = None,
         freeze_codebook = False,
         codebook_transform_fn: Callable | None = None,
-        ema_update_weight: Tensor | Callable | None = None
+        ema_update_weight: Tensor | Callable | None = None,
+        accum_ema_update = False
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
@@ -603,12 +619,16 @@ def forward(
             if callable(ema_update_weight):
                 ema_update_weight = ema_update_weight(embed_sum, cluster_size)
 
-            ema_inplace(self.cluster_size.data, cluster_size, self.decay, ema_update_weight)
-            ema_inplace(self.embed_avg.data, embed_sum, self.decay, ema_update_weight)
+            if accum_ema_update:
+                accum_grad_(self.cluster_size, cluster_size)
+                accum_grad_(self.embed_avg, embed_sum)
+            else:
+                ema_inplace(self.cluster_size, cluster_size, self.decay, ema_update_weight)
+                ema_inplace(self.embed_avg, embed_sum, self.decay, ema_update_weight)
 
-            if not self.manual_ema_update:
-                self.update_ema()
-                self.expire_codes_(x)
+                if not self.manual_ema_update:
+                    self.update_ema()
+                    self.expire_codes_(x)
 
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
@@ -743,7 +763,8 @@ def forward(
         mask = None,
         freeze_codebook = False,
         codebook_transform_fn: Callable | None = None,
-        ema_update_weight: Tensor | None = None
+        ema_update_weight: Tensor | None = None,
+        accum_ema_update = False
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
@@ -819,12 +840,17 @@ def forward(
             if callable(ema_update_weight):
                 ema_update_weight = ema_update_weight(embed_sum, bins)
 
-            ema_inplace(self.cluster_size.data, bins, self.decay, ema_update_weight)
-            ema_inplace(self.embed_avg.data, embed_sum, self.decay, ema_update_weight)
+            if accum_ema_update:
+                accum_grad_(self.cluster_size, bins)
+                accum_grad_(self.embed_avg, embed_sum)
+            else:
+
+                ema_inplace(self.cluster_size, bins, self.decay, ema_update_weight)
+                ema_inplace(self.embed_avg, embed_sum, self.decay, ema_update_weight)
 
-            if not self.manual_ema_update:
-                self.update_ema()
-                self.expire_codes_(x)
+                if not self.manual_ema_update:
+                    self.update_ema()
+                    self.expire_codes_(x)
 
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
@@ -1062,7 +1088,8 @@ def forward(
         freeze_codebook = None,
         return_loss_breakdown = False,
         codebook_transform_fn: Callable | None = None,
-        ema_update_weight: Tensor | None = None
+        ema_update_weight: Tensor | None = None,
+        accum_ema_update = False
     ):
         orig_input, input_requires_grad = x, x.requires_grad
 
@@ -1119,7 +1146,8 @@ def forward(
             mask = mask,
             freeze_codebook = freeze_codebook,
             codebook_transform_fn = codebook_transform_fn,
-            ema_update_weight = ema_update_weight
+            ema_update_weight = ema_update_weight,
+            accum_ema_update = accum_ema_update
         )
 
         # quantize
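With the flag threaded through to both codebook classes, `update_ema()` and `expire_codes_()` are deliberately skipped while accumulating (hence the relocated `if not self.manual_ema_update` block above), so the pooled statistics get applied exactly once. A hypothetical end-to-end sketch of the intended usage; the `VectorQuantize` constructor arguments, the `manual_ema_update` kwarg being forwarded through to the codebook, and the `_codebook.update_ema()` access path are assumptions beyond what the diff shows:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

# assumption: manual_ema_update reaches the codebook, disabling the
# automatic EMA so forward() only accumulates statistics
quantizer = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    manual_ema_update = True
)

micro_batches = torch.randn(4, 1, 1024, 256)  # 4 micro-batches of (1, 1024, 256)

for micro_batch in micro_batches:
    # each call stashes its cluster_size / embed_avg statistics via accum_grad_
    quantized, indices, commit_loss = quantizer(micro_batch, accum_ema_update = True)

# one manual EMA step folds everything accumulated above into the codebook
# (assumption: the inner codebook is reachable as `_codebook`)
quantizer._codebook.update_ema()
```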