
Commit 5460a72

feat(GAN): add Conditional GAN (CGAN)
1 parent 8cc464c

File tree

5 files changed: +182 -88 lines changed


cgan_samples_epoch_1.png

340 KB

examples/models-usages/generation/gan-image-generation/gan-mnist-dense.ipynb

Lines changed: 88 additions & 42 deletions
Large diffs are not rendered by default.
1.86 MB

neuralnetlib/models.py

Lines changed: 94 additions & 46 deletions
@@ -2314,6 +2314,7 @@ class GAN(BaseModel):
     def __init__(
             self,
             latent_dim: int = 100,
+            n_classes: int | None = None,
             gradient_clip_threshold: float = 0.1,
             enable_padding: bool = False,
             padding_size: int = 32,
@@ -2327,6 +2328,7 @@ def __init__(
         super().__init__(gradient_clip_threshold, enable_padding, padding_size, random_state)
 
         self.latent_dim = latent_dim
+        self.n_classes = n_classes
        self.generator = None
        self.discriminator = None
        self.generator_optimizer = None
@@ -2445,10 +2447,26 @@ def backward_pass(self, error: np.ndarray):
 
         self.generator.backward_pass(error)
 
-    def _generate_latent_points(self, n_samples: int) -> np.ndarray:
+    def _generate_latent_points(self, n_samples: int, labels: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
         rng = np.random.default_rng(self.random_state)
         latent_points = rng.normal(0, 1, (n_samples, self.latent_dim))
-        return latent_points
+
+        if self.n_classes is not None:
+            if labels is None:
+                labels = rng.integers(0, self.n_classes, n_samples)
+            elif labels.ndim == 2 and labels.shape[1] == self.n_classes:
+                return np.concatenate([latent_points, labels], axis=1), labels
+
+            one_hot_labels = np.zeros((n_samples, self.n_classes))
+            if labels.ndim == 1:
+                one_hot_labels[np.arange(n_samples), labels] = 1
+            else:
+                one_hot_labels = labels
+
+            latent_points = np.concatenate([latent_points, one_hot_labels], axis=1)
+            return latent_points, one_hot_labels
+
+        return latent_points, None
 
     def _apply_spectral_norm(self, model: 'Sequential'):
         if not self.use_spectral_norm:
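
Aside: the conditioning above is plain feature concatenation — the latent vector grows by n_classes one-hot entries, so a conditional generator takes latent_dim + n_classes inputs. A minimal standalone sketch of the same transform (toy sizes, plain NumPy, not neuralnetlib code):

import numpy as np

latent_dim, n_classes, n_samples = 100, 10, 4
rng = np.random.default_rng(42)

# Latent noise plus random integer labels, as the method draws them
# when no labels are supplied.
z = rng.normal(0, 1, (n_samples, latent_dim))
labels = rng.integers(0, n_classes, n_samples)

# One-hot encode and append along the feature axis.
one_hot = np.zeros((n_samples, n_classes))
one_hot[np.arange(n_samples), labels] = 1
conditioned = np.concatenate([z, one_hot], axis=1)

print(conditioned.shape)  # (4, 110) == (n_samples, latent_dim + n_classes)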
@@ -2495,7 +2513,8 @@ def _ensure_initialized(self, input_data: np.ndarray):
     def train_on_batch(
             self,
             real_samples: np.ndarray,
-            batch_size: int,
+            labels: np.ndarray | None = None,
+            batch_size: int = 32,
             n_critic: int = 1
     ) -> tuple[float, float]:
         rng = np.random.default_rng(self.random_state)
@@ -2504,60 +2523,71 @@ def train_on_batch(
         for _ in range(n_critic):
             idx = rng.choice(len(real_samples), batch_size, replace=False)
             real_batch = real_samples[idx]
+            batch_labels = labels[idx] if labels is not None else None
 
-            noise = rng.standard_normal(size=(batch_size, self.latent_dim))
+            noise, gen_labels = self._generate_latent_points(batch_size, batch_labels)
             fake_batch = self.generator.forward_pass(noise, training=False)
 
             combined_batch = np.concatenate([real_batch, fake_batch])
             combined_labels = np.zeros((2 * batch_size, 1))
             combined_labels[:batch_size] = 1.0
 
+            if self.n_classes is not None:
+                if batch_labels is not None:
+                    combined_cond = np.concatenate([batch_labels, gen_labels])
+                    combined_batch = np.concatenate([combined_batch, combined_cond], axis=1)
+
             self.discriminator.y_true = combined_labels
-            predictions = self.discriminator.forward_pass(
-                combined_batch, training=True)
+            predictions = self.discriminator.forward_pass(combined_batch, training=True)
             d_loss = self.discriminator_loss(combined_labels, predictions)
-            d_grad = self.discriminator_loss.derivative(
-                combined_labels, predictions)
+            d_grad = self.discriminator_loss.derivative(combined_labels, predictions)
             self.discriminator.backward_pass(d_grad)
 
             d_loss_total += d_loss
 
         d_loss_avg = d_loss_total / n_critic
 
-        noise = rng.standard_normal(size=(batch_size, self.latent_dim))
+        noise, gen_labels = self._generate_latent_points(batch_size)
         fake_samples = self.generator.forward_pass(noise, training=True)
 
+        if self.n_classes is not None:
+            fake_samples_with_cond = np.concatenate([fake_samples, gen_labels], axis=1)
+        else:
+            fake_samples_with_cond = fake_samples
+
         target_labels = np.ones((batch_size, 1))
 
-        disc_predictions = self.discriminator.forward_pass(
-            fake_samples, training=False)
+        disc_predictions = self.discriminator.forward_pass(fake_samples_with_cond, training=False)
         self.discriminator.y_true = target_labels
         g_loss = self.generator_loss(target_labels, disc_predictions)
-        g_grad = self.generator_loss.derivative(
-            target_labels, disc_predictions)
+        g_grad = self.generator_loss.derivative(target_labels, disc_predictions)
 
         d_grad = self.discriminator.backward_pass(g_grad, compute_only=True)
+        if self.n_classes is not None:
+            d_grad = d_grad[:, :-self.n_classes]
         self.generator.backward_pass(d_grad, gan=True)
 
         return d_loss_avg, g_loss
 
     def fit(
-        self,
-        x_train: np.ndarray,
-        epochs: int = 100,
-        batch_size: int | None = None,
-        n_critic: int = 5,
-        verbose: bool = True,
-        metrics: list | None = None,
-        random_state: int | None = None,
-        validation_data: tuple | None = None,
-        validation_split: float | None = None,
-        callbacks: list = [],
-        plot_generated: bool = False,
-        plot_interval: int = 1,
-        fixed_noise: np.ndarray | None = None,
-        n_gen_samples: int | None = None
-    ) -> dict:
+            self,
+            x_train: np.ndarray,
+            y_train: np.ndarray | None = None,
+            epochs: int = 100,
+            batch_size: int | None = None,
+            n_critic: int = 5,
+            verbose: bool = True,
+            metrics: list | None = None,
+            random_state: int | None = None,
+            validation_data: tuple | None = None,
+            validation_split: float | None = None,
+            callbacks: list = [],
+            plot_generated: bool = False,
+            plot_interval: int = 1,
+            fixed_noise: np.ndarray | None = None,
+            fixed_labels: np.ndarray | None = None,
+            n_gen_samples: int | None = None
+    ) -> dict:
 
         history = History({
             'discriminator_loss': [],
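
Aside: the discriminator is conditioned the same way — class encodings are appended to the samples, so its input widens by n_classes. That is also why the generator update slices the returned gradient with d_grad[:, :-self.n_classes]: the last n_classes columns belong to the label features, which the generator never produced. A shape-level sketch (toy dimensions, plain NumPy, not neuralnetlib code):

import numpy as np

batch_size, img_dim, n_classes = 32, 784, 10
rng = np.random.default_rng(0)

fake_samples = rng.normal(size=(batch_size, img_dim))  # stand-in generator output
one_hot = np.eye(n_classes)[rng.integers(0, n_classes, batch_size)]

# What the conditional discriminator actually sees.
disc_input = np.concatenate([fake_samples, one_hot], axis=1)
print(disc_input.shape)   # (32, 794)

# A gradient w.r.t. that input has the same width, but only the
# first img_dim columns flow back into the generator.
d_grad = rng.normal(size=disc_input.shape)
g_grad = d_grad[:, :-n_classes]
print(g_grad.shape)       # (32, 784)

Note that train_on_batch concatenates batch_labels onto the discriminator input as-is, so labels passed to it should already be one-hot, i.e. of shape (N, n_classes).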
@@ -2574,6 +2604,7 @@ def fit(
             validation_data = (x_val, None)
 
         x_train = np.array(x_train) if not isinstance(x_train, np.ndarray) else x_train
+        y_train = np.array(y_train) if not isinstance(y_train, np.ndarray) else y_train
 
         if metrics is not None:
             metrics = [Metric(m) for m in metrics]
@@ -2586,7 +2617,7 @@ def fit(
 
         if plot_generated and fixed_noise is None:
             rng = np.random.default_rng(self.random_state)
-            fixed_noise = rng.standard_normal(size=(64, self.latent_dim))
+            fixed_noise = rng.normal(0, 1, (80, self.latent_dim))
 
         callbacks = callbacks if callbacks is not None else []
 
@@ -2613,8 +2644,8 @@ def fit(
                 callback.on_epoch_begin(epoch, epoch_logs)
 
             start_time = time.time()
-            x_train_shuffled = shuffle(
-                x_train,
+            x_train_shuffled, y_train_shuffled = shuffle(
+                x_train, y_train,
                 random_state=random_state if random_state is not None else self.random_state
             )
             d_error = 0
@@ -2631,6 +2662,7 @@ def fit(
                 for j in range(0, x_train.shape[0], batch_size):
                     batch_index = j // batch_size
                     x_batch = x_train_shuffled[j:j + batch_size]
+                    y_batch = y_train_shuffled[j:j + batch_size] if y_train is not None else None
 
                     batch_logs = {
                         'batch': batch_index,
@@ -2642,13 +2674,13 @@ def fit(
                         callback.on_batch_begin(batch_index, batch_logs)
 
                     d_loss, g_loss = self.train_on_batch(
-                        x_batch, min(batch_size, len(x_batch)), n_critic)
+                        x_batch, y_batch, min(batch_size, len(x_batch)), n_critic)
                     d_error += d_loss
                     g_error += g_loss
 
                     batch_metrics = {}
                     if metrics is not None:
-                        noise = self._generate_latent_points(len(x_batch))
+                        noise = self._generate_latent_points(len(x_batch), y_batch)
                         generated_samples = self.forward_pass(noise, training=False)
                         for metric in metrics:
                             metric_value = metric(generated_samples, x_batch)
@@ -2688,11 +2720,11 @@ def fit(
                     metric_values[k] /= num_batches
 
             else:
-                d_error, g_error = self.train_on_batch(x_train, len(x_train), n_critic)
+                d_error, g_error = self.train_on_batch(x_train, y_train, len(x_train), n_critic)
 
                 if metrics is not None:
                     noise = self._generate_latent_points(
-                        len(x_train) if n_gen_samples is None else n_gen_samples)
+                        len(x_train) if n_gen_samples is None else n_gen_samples, y_train)
                     generated_samples = self.forward_pass(noise, training=False)
 
                     for metric in metrics:
@@ -2765,21 +2797,37 @@ def fit(
         return history
 
     def _plot_samples(self, noise: np.ndarray, epoch: int):
-        generated = self.forward_pass(noise, training=False)
+        if self.n_classes is not None:
+            samples_per_class = noise.shape[0] // self.n_classes
+            labels = np.repeat(np.arange(self.n_classes), samples_per_class)
+            one_hot_labels = np.zeros((len(labels), self.n_classes))
+            one_hot_labels[np.arange(len(labels)), labels] = 1
+            latent_points = np.concatenate([noise, one_hot_labels], axis=1)
+        else:
+            latent_points = noise
 
+        generated = self.generator.forward_pass(latent_points, training=False)
+
         height, width = self.image_dimensions
-        sample = generated[0].reshape(height, width)
-
-        plt.figure(figsize=(4, 4))
-        plt.imshow(sample, cmap='gray_r', interpolation='nearest')
+        n_rows = samples_per_class if self.n_classes else 8
+        n_cols = self.n_classes if self.n_classes else 8
+        figure = np.zeros((height * n_rows, width * n_cols))
+
+        for i in range(n_rows):
+            for j in range(n_cols):
+                sample_idx = i * n_cols + j
+                sample = generated[sample_idx].reshape(height, width)
+                figure[i * height:(i + 1) * height, j * width:(j + 1) * width] = sample
+
+        plt.figure(figsize=(10, 8))
+        plt.imshow(figure, cmap='gray_r', interpolation='nearest')
         plt.axis('off')
-        plt.tight_layout()
-
-        plt.savefig(f'video{str(epoch).zfill(2)}.png')
+        plt.tight_layout(pad=0)
+        plt.savefig(f'video{str(epoch).zfill(2)}.png', bbox_inches='tight', pad_inches=0)
         plt.close()
 
-    def predict(self, n_samples: int, temperature: float = 1.0) -> np.ndarray:
-        latent_points = self._generate_latent_points(n_samples)
+    def predict(self, n_samples: int, labels: np.ndarray | None = None, temperature: float = 1.0) -> np.ndarray:
+        latent_points, _ = self._generate_latent_points(n_samples, labels)
         return self.generator.predict(latent_points, temperature)
 
     def evaluate(
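
End-to-end, the new parameters thread through as follows — a hedged usage sketch (the generator/discriminator architectures and their compilation are unchanged by this commit and elided here; the data is a random stand-in for MNIST):

import numpy as np
from neuralnetlib.models import GAN

# Stand-in data: 1000 flattened 28x28 images with integer labels.
x_train = np.random.rand(1000, 784).astype(np.float32)
y_int = np.random.randint(0, 10, 1000)
# One-hot labels: train_on_batch concatenates them directly onto the
# discriminator input, so the (N, n_classes) form is the safe choice.
y_train = np.eye(10)[y_int]

gan = GAN(latent_dim=100, n_classes=10, random_state=42)
# ... build and compile the generator and discriminator as for the
#     unconditional GAN (outside the scope of this diff) ...

history = gan.fit(x_train, y_train, epochs=10, batch_size=128,
                  plot_generated=True, plot_interval=1)

# Class-targeted sampling: ten images of the digit 7.
samples = gan.predict(n_samples=10, labels=np.full(10, 7))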
