@@ -1541,22 +1541,17 @@ def fit(self, x_train: np.ndarray,
15411541 def generate_image (self , x_train : np .ndarray , n_samples : int = 10 , seed : int | None = None , n_examples : int = 1000 ) -> np .ndarray :
15421542 _ = self .forward_pass (x_train [:n_examples ])
15431543
1544- latent_mean_stats = np . mean ( self .latent_mean , axis = 0 )
1545- latent_std_stats = np . exp ( 0.5 * np . mean ( self . latent_log_var , axis = 0 ) )
1544+ if not self .variational :
1545+ raise ValueError ( "generate_image requires variational=True" )
15461546
1547- latent_mean_repeated = np .tile ( latent_mean_stats , ( n_samples , 1 ) )
1548- latent_std_repeated = np .tile ( latent_std_stats , ( n_samples , 1 ))
1547+ mu = np .mean ( self . latent_mean , axis = 0 )
1548+ sigma = np .exp ( 0.5 * np . mean ( self . latent_log_var , axis = 0 ))
15491549
1550- rng = np .random .default_rng (
1551- seed if seed is not None else self .random_state )
1550+ rng = np .random .default_rng (seed if seed is not None else self .random_state )
15521551 noise = rng .standard_normal (size = (n_samples , self .latent_dim ))
1552+ z = mu [None , :] + noise * sigma [None , :]
15531553
1554- latent_samples = np .concatenate ([
1555- latent_mean_repeated + noise *
1556- latent_std_repeated , np .zeros ((n_samples , self .latent_dim ))
1557- ], axis = 1 )
1558-
1559- generated = latent_samples
1554+ generated = z
15601555 for layer in self .decoder_layers :
15611556 if isinstance (layer , (Dropout , LSTM , Bidirectional , GRU )):
15621557 generated = layer .forward_pass (generated , training = False )
0 commit comments