Skip to content

Commit 914aa78

Browse files
committed
fix: generate_image vae
1 parent 3d2ca56 commit 914aa78

File tree

1 file changed

+7
-12
lines changed

1 file changed

+7
-12
lines changed

neuralnetlib/models.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1541,22 +1541,17 @@ def fit(self, x_train: np.ndarray,
15411541
def generate_image(self, x_train: np.ndarray, n_samples: int = 10, seed: int | None = None, n_examples: int = 1000) -> np.ndarray:
15421542
_ = self.forward_pass(x_train[:n_examples])
15431543

1544-
latent_mean_stats = np.mean(self.latent_mean, axis=0)
1545-
latent_std_stats = np.exp(0.5 * np.mean(self.latent_log_var, axis=0))
1544+
if not self.variational:
1545+
raise ValueError("generate_image requires variational=True")
15461546

1547-
latent_mean_repeated = np.tile(latent_mean_stats, (n_samples, 1))
1548-
latent_std_repeated = np.tile(latent_std_stats, (n_samples, 1))
1547+
mu = np.mean(self.latent_mean, axis=0)
1548+
sigma = np.exp(0.5 * np.mean(self.latent_log_var, axis=0))
15491549

1550-
rng = np.random.default_rng(
1551-
seed if seed is not None else self.random_state)
1550+
rng = np.random.default_rng(seed if seed is not None else self.random_state)
15521551
noise = rng.standard_normal(size=(n_samples, self.latent_dim))
1552+
z = mu[None, :] + noise * sigma[None, :]
15531553

1554-
latent_samples = np.concatenate([
1555-
latent_mean_repeated + noise *
1556-
latent_std_repeated, np.zeros((n_samples, self.latent_dim))
1557-
], axis=1)
1558-
1559-
generated = latent_samples
1554+
generated = z
15601555
for layer in self.decoder_layers:
15611556
if isinstance(layer, (Dropout, LSTM, Bidirectional, GRU)):
15621557
generated = layer.forward_pass(generated, training=False)

0 commit comments

Comments
 (0)