@@ -1529,9 +1529,11 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
15291529 )
15301530 self .model_args = model_args
15311531
1532- def init_weights (self , buffer_device : torch .device | None = None ) -> None :
1533- _init_weights_tok_embeddings (self )
1534- _init_weights_layers (self , buffer_device )
1532+ def init_weights (
1533+ self , buffer_device : torch .device | None = None , seed : int | None = None
1534+ ) -> None :
1535+ _init_weights_tok_embeddings (self , seed )
1536+ _init_weights_layers (self , buffer_device , seed )
15351537 _init_weights_norm_and_output (self )
15361538
15371539 def forward (
@@ -1585,8 +1587,10 @@ def forward(self, h):
15851587 h = layer (h , self .freqs_cis )
15861588 return h
15871589
1588- def init_weights (self , buffer_device : torch .device | None = None ) -> None :
1589- _init_weights_layers (self , buffer_device )
1590+ def init_weights (
1591+ self , buffer_device : torch .device | None = None , seed : int | None = None
1592+ ) -> None :
1593+ _init_weights_layers (self , buffer_device , seed )
15901594
15911595
15921596class DeepSeekV3Stage0 (DeepSeekV3StageI ):
@@ -1600,9 +1604,11 @@ def forward(self, tokens):
16001604 # torch.Size([1024, 1024, 2048])
16011605 return super ().forward (h )
16021606
1603- def init_weights (self , buffer_device : torch .device | None = None ) -> None :
1604- _init_weights_tok_embeddings (self )
1605- super ().init_weights (buffer_device = buffer_device )
1607+ def init_weights (
1608+ self , buffer_device : torch .device | None = None , seed : int | None = None
1609+ ) -> None :
1610+ _init_weights_tok_embeddings (self , seed )
1611+ super ().init_weights (buffer_device , seed )
16061612
16071613
16081614class DeepSeekV3StageN (DeepSeekV3StageI ):
@@ -1618,8 +1624,10 @@ def forward(self, h):
16181624 output = self .output (h ) if self .output is not None else h
16191625 return output
16201626
1621- def init_weights (self , buffer_device : torch .device | None = None ) -> None :
1622- super ().init_weights (buffer_device = buffer_device )
1627+ def init_weights (
1628+ self , buffer_device : torch .device | None = None , seed : int | None = None
1629+ ) -> None :
1630+ super ().init_weights (buffer_device , seed )
16231631 _init_weights_norm_and_output (self )
16241632
16251633
@@ -1628,23 +1636,30 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:
16281636######################
16291637
16301638
1631- def _init_weights_tok_embeddings (self : Union [DeepSeekV3Model , DeepSeekV3Stage0 ]):
1639+ def _init_weights_tok_embeddings (
1640+ self : Union [DeepSeekV3Model , DeepSeekV3Stage0 ], seed : int | None = None
1641+ ):
1642+ if seed is not None :
1643+ torch .manual_seed (seed )
16321644 if self .tok_embeddings is not None :
16331645 nn .init .normal_ (self .tok_embeddings .weight )
16341646
16351647
16361648def _init_weights_layers (
16371649 self : Union [DeepSeekV3Model , DeepSeekV3StageI ],
16381650 buffer_device : torch .device | None ,
1651+ seed : int | None = None ,
16391652):
16401653 if buffer_device is None :
16411654 buffer_device = self .freqs_cis .device # type: ignore[assignment]
16421655 with torch .device (buffer_device ): # type: ignore[arg-type]
16431656 self .freqs_cis = precompute_freqs_cis (self .model_args )
1644- for layer in self .layers .values ():
1657+ for i , layer in enumerate (self .layers .values ()):
1658+ if seed is not None :
1659+ torch .manual_seed (seed )
16451660 if layer is not None :
16461661 assert isinstance (layer , TransformerBlock )
1647- layer .init_weights (buffer_device = buffer_device ) # type: ignore[arg-type]
1662+ layer .init_weights (buffer_device ) # type: ignore[arg-type]
16481663
16491664
16501665def _init_weights_norm_and_output (self : Union [DeepSeekV3Model , DeepSeekV3StageN ]):
0 commit comments