@@ -2314,6 +2314,7 @@ class GAN(BaseModel):
    def __init__(
        self,
        latent_dim: int = 100,
+        n_classes: int | None = None,
        gradient_clip_threshold: float = 0.1,
        enable_padding: bool = False,
        padding_size: int = 32,
@@ -2327,6 +2328,7 @@ def __init__(
        super().__init__(gradient_clip_threshold, enable_padding, padding_size, random_state)

        self.latent_dim = latent_dim
+        self.n_classes = n_classes
        self.generator = None
        self.discriminator = None
        self.generator_optimizer = None
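Enabling `n_classes` widens both networks' inputs: the generator consumes `latent_dim + n_classes` features (noise plus a one-hot label), and the discriminator consumes the sample plus the same one-hot label. A minimal sketch of that sizing arithmetic, with illustrative values for `latent_dim`, `n_classes`, and `n_features`:

```python
import numpy as np

latent_dim, n_classes, n_features = 100, 10, 784  # e.g. flattened 28x28 images

gen_input_dim = latent_dim + n_classes    # 110: noise + one-hot label
disc_input_dim = n_features + n_classes   # 794: sample + one-hot label

rng = np.random.default_rng(0)
noise = rng.normal(0, 1, (16, latent_dim))
one_hot = np.eye(n_classes)[rng.integers(0, n_classes, 16)]
gen_input = np.concatenate([noise, one_hot], axis=1)
assert gen_input.shape == (16, gen_input_dim)
```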
@@ -2445,10 +2447,26 @@ def backward_pass(self, error: np.ndarray):

        self.generator.backward_pass(error)

-    def _generate_latent_points(self, n_samples: int) -> np.ndarray:
+    def _generate_latent_points(self, n_samples: int, labels: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray | None]:
        rng = np.random.default_rng(self.random_state)
        latent_points = rng.normal(0, 1, (n_samples, self.latent_dim))
-        return latent_points
+
+        if self.n_classes is not None:
+            if labels is None:
+                labels = rng.integers(0, self.n_classes, n_samples)
+            elif labels.ndim == 2 and labels.shape[1] == self.n_classes:
+                # labels are already one-hot encoded: append them as-is
+                return np.concatenate([latent_points, labels], axis=1), labels
+
+            one_hot_labels = np.zeros((n_samples, self.n_classes))
+            if labels.ndim == 1:
+                one_hot_labels[np.arange(n_samples), labels] = 1
+            else:
+                one_hot_labels = labels
+
+            latent_points = np.concatenate([latent_points, one_hot_labels], axis=1)
+            return latent_points, one_hot_labels
+
+        return latent_points, None

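The new return contract is easiest to read off an example. A quick sketch, assuming the class as defined here with `latent_dim=100` and `n_classes=10`:

```python
import numpy as np

cgan = GAN(latent_dim=100, n_classes=10)

# integer labels are one-hot encoded and appended to the noise
z, y = cgan._generate_latent_points(4, labels=np.array([0, 1, 2, 3]))
# z.shape == (4, 110): 100 noise dims + 10 one-hot dims
# y.shape == (4, 10): the one-hot encoding that was appended

# without conditioning the second element is simply None
z, y = GAN(latent_dim=100)._generate_latent_points(4)
# z.shape == (4, 100), y is None
```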
    def _apply_spectral_norm(self, model: 'Sequential'):
        if not self.use_spectral_norm:
@@ -2495,7 +2513,8 @@ def _ensure_initialized(self, input_data: np.ndarray):
    def train_on_batch(
        self,
        real_samples: np.ndarray,
-        batch_size: int,
+        labels: np.ndarray | None = None,
+        batch_size: int = 32,
        n_critic: int = 1
    ) -> tuple[float, float]:
        rng = np.random.default_rng(self.random_state)
@@ -2504,60 +2523,71 @@ def train_on_batch(
        for _ in range(n_critic):
            idx = rng.choice(len(real_samples), batch_size, replace=False)
            real_batch = real_samples[idx]
+            batch_labels = labels[idx] if labels is not None else None

-            noise = rng.standard_normal(size=(batch_size, self.latent_dim))
+            noise, gen_labels = self._generate_latent_points(batch_size, batch_labels)
            fake_batch = self.generator.forward_pass(noise, training=False)

            combined_batch = np.concatenate([real_batch, fake_batch])
            combined_labels = np.zeros((2 * batch_size, 1))
            combined_labels[:batch_size] = 1.0

+            if self.n_classes is not None:
+                # gen_labels is the one-hot encoding of batch_labels (or of
+                # random classes when none were supplied), so it conditions
+                # both the real and the fake half of the combined batch
+                combined_cond = np.concatenate([gen_labels, gen_labels])
+                combined_batch = np.concatenate([combined_batch, combined_cond], axis=1)
+
            self.discriminator.y_true = combined_labels
-            predictions = self.discriminator.forward_pass(
-                combined_batch, training=True)
+            predictions = self.discriminator.forward_pass(combined_batch, training=True)
            d_loss = self.discriminator_loss(combined_labels, predictions)
-            d_grad = self.discriminator_loss.derivative(
-                combined_labels, predictions)
+            d_grad = self.discriminator_loss.derivative(combined_labels, predictions)
            self.discriminator.backward_pass(d_grad)

            d_loss_total += d_loss

        d_loss_avg = d_loss_total / n_critic

-        noise = rng.standard_normal(size=(batch_size, self.latent_dim))
+        noise, gen_labels = self._generate_latent_points(batch_size)
        fake_samples = self.generator.forward_pass(noise, training=True)

+        if self.n_classes is not None:
+            fake_samples_with_cond = np.concatenate([fake_samples, gen_labels], axis=1)
+        else:
+            fake_samples_with_cond = fake_samples
+
        target_labels = np.ones((batch_size, 1))

-        disc_predictions = self.discriminator.forward_pass(
-            fake_samples, training=False)
+        disc_predictions = self.discriminator.forward_pass(fake_samples_with_cond, training=False)
        self.discriminator.y_true = target_labels
        g_loss = self.generator_loss(target_labels, disc_predictions)
-        g_grad = self.generator_loss.derivative(
-            target_labels, disc_predictions)
+        g_grad = self.generator_loss.derivative(target_labels, disc_predictions)

        d_grad = self.discriminator.backward_pass(g_grad, compute_only=True)
+        if self.n_classes is not None:
+            # drop the gradient w.r.t. the appended label columns; the
+            # generator never produced them
+            d_grad = d_grad[:, :-self.n_classes]
        self.generator.backward_pass(d_grad, gan=True)

        return d_loss_avg, g_loss
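A shape trace makes the conditional flow concrete. The values below are illustrative (`latent_dim=100`, `n_classes=10`, 784-feature samples, `batch_size=32`):

```python
# discriminator phase, per critic step:
#   real_batch              (32, 784)
#   noise                   (32, 110)  -> fake_batch (32, 784)
#   combined_batch          (64, 784)  -> (64, 794) after appending one-hots
#   combined_labels         (64, 1)    first 32 entries are 1.0 (real)
#
# generator phase:
#   fake_samples_with_cond  (32, 794)  -> disc_predictions (32, 1)
#   d_grad                  (32, 794)  -> sliced to (32, 784) before
#                                         generator.backward_pass
```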
    def fit(
-        self,
-        x_train: np.ndarray,
-        epochs: int = 100,
-        batch_size: int | None = None,
-        n_critic: int = 5,
-        verbose: bool = True,
-        metrics: list | None = None,
-        random_state: int | None = None,
-        validation_data: tuple | None = None,
-        validation_split: float | None = None,
-        callbacks: list = [],
-        plot_generated: bool = False,
-        plot_interval: int = 1,
-        fixed_noise: np.ndarray | None = None,
-        n_gen_samples: int | None = None
-    ) -> dict:
+        self,
+        x_train: np.ndarray,
+        y_train: np.ndarray | None = None,
+        epochs: int = 100,
+        batch_size: int | None = None,
+        n_critic: int = 5,
+        verbose: bool = True,
+        metrics: list | None = None,
+        random_state: int | None = None,
+        validation_data: tuple | None = None,
+        validation_split: float | None = None,
+        callbacks: list | None = None,
+        plot_generated: bool = False,
+        plot_interval: int = 1,
+        fixed_noise: np.ndarray | None = None,
+        fixed_labels: np.ndarray | None = None,
+        n_gen_samples: int | None = None
+    ) -> dict:

        history = History({
            'discriminator_loss': [],
@@ -2574,6 +2604,7 @@ def fit(
            validation_data = (x_val, None)

        x_train = np.array(x_train) if not isinstance(x_train, np.ndarray) else x_train
+        y_train = np.array(y_train) if y_train is not None and not isinstance(y_train, np.ndarray) else y_train

        if metrics is not None:
            metrics = [Metric(m) for m in metrics]
@@ -2586,7 +2617,7 @@

        if plot_generated and fixed_noise is None:
            rng = np.random.default_rng(self.random_state)
-            fixed_noise = rng.standard_normal(size=(64, self.latent_dim))
+            fixed_noise = rng.normal(0, 1, (80, self.latent_dim))

        callbacks = callbacks if callbacks is not None else []

@@ -2613,8 +2644,8 @@
                callback.on_epoch_begin(epoch, epoch_logs)

            start_time = time.time()
-            x_train_shuffled = shuffle(
-                x_train,
+            x_train_shuffled, y_train_shuffled = shuffle(
+                x_train, y_train,
                random_state=random_state if random_state is not None else self.random_state
            )
            d_error = 0
@@ -2631,6 +2662,7 @@
            for j in range(0, x_train.shape[0], batch_size):
                batch_index = j // batch_size
                x_batch = x_train_shuffled[j:j + batch_size]
+                y_batch = y_train_shuffled[j:j + batch_size] if y_train is not None else None

                batch_logs = {
                    'batch': batch_index,
@@ -2642,13 +2674,13 @@
                    callback.on_batch_begin(batch_index, batch_logs)

                d_loss, g_loss = self.train_on_batch(
-                    x_batch, min(batch_size, len(x_batch)), n_critic)
+                    x_batch, y_batch, min(batch_size, len(x_batch)), n_critic)
                d_error += d_loss
                g_error += g_loss

                batch_metrics = {}
                if metrics is not None:
-                    noise = self._generate_latent_points(len(x_batch))
+                    noise, _ = self._generate_latent_points(len(x_batch), y_batch)
                    generated_samples = self.forward_pass(noise, training=False)
                    for metric in metrics:
                        metric_value = metric(generated_samples, x_batch)
@@ -2688,11 +2720,11 @@ def fit(
                    metric_values[k] /= num_batches

        else:
-            d_error, g_error = self.train_on_batch(x_train, len(x_train), n_critic)
+            d_error, g_error = self.train_on_batch(x_train, y_train, len(x_train), n_critic)

            if metrics is not None:
-                noise = self._generate_latent_points(
-                    len(x_train) if n_gen_samples is None else n_gen_samples)
+                noise, _ = self._generate_latent_points(
+                    len(x_train) if n_gen_samples is None else n_gen_samples, y_train)
                generated_samples = self.forward_pass(noise, training=False)

                for metric in metrics:
@@ -2765,21 +2797,37 @@ def fit(
        return history

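End to end, a conditional training run might look like the sketch below. The data is a random stand-in, and the generator/discriminator construction is a placeholder for whatever builder the surrounding library provides:

```python
import numpy as np

rng = np.random.default_rng(0)
x_train = rng.random((1024, 784))     # stand-in for real samples
y_train = rng.integers(0, 10, 1024)   # integer class labels

gan = GAN(latent_dim=100, n_classes=10)
# ... build and compile gan.generator / gan.discriminator here ...
history = gan.fit(x_train, y_train, epochs=10, batch_size=64,
                  n_critic=1, plot_generated=True)
```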
    def _plot_samples(self, noise: np.ndarray, epoch: int):
-        generated = self.forward_pass(noise, training=False)
+        if self.n_classes is not None:
+            samples_per_class = noise.shape[0] // self.n_classes
+            labels = np.repeat(np.arange(self.n_classes), samples_per_class)
+            one_hot_labels = np.zeros((len(labels), self.n_classes))
+            one_hot_labels[np.arange(len(labels)), labels] = 1
+            latent_points = np.concatenate([noise, one_hot_labels], axis=1)
+        else:
+            latent_points = noise

+        generated = self.generator.forward_pass(latent_points, training=False)
+
        height, width = self.image_dimensions
-        sample = generated[0].reshape(height, width)
-
-        plt.figure(figsize=(4, 4))
-        plt.imshow(sample, cmap='gray_r', interpolation='nearest')
+        n_rows = samples_per_class if self.n_classes else 8
+        n_cols = self.n_classes if self.n_classes else 8
+        figure = np.zeros((height * n_rows, width * n_cols))
+
+        for i in range(n_rows):
+            for j in range(n_cols):
+                # samples are grouped by class, so column j shows class j
+                sample_idx = j * n_rows + i if self.n_classes else i * n_cols + j
+                sample = generated[sample_idx].reshape(height, width)
+                figure[i * height:(i + 1) * height, j * width:(j + 1) * width] = sample
+
+        plt.figure(figsize=(10, 8))
+        plt.imshow(figure, cmap='gray_r', interpolation='nearest')
        plt.axis('off')
-        plt.tight_layout()
-
-        plt.savefig(f'video{str(epoch).zfill(2)}.png')
+        plt.tight_layout(pad=0)
+        plt.savefig(f'video{str(epoch).zfill(2)}.png', bbox_inches='tight', pad_inches=0)
        plt.close()

-    def predict(self, n_samples: int, temperature: float = 1.0) -> np.ndarray:
-        latent_points = self._generate_latent_points(n_samples)
+    def predict(self, n_samples: int, labels: np.ndarray | None = None, temperature: float = 1.0) -> np.ndarray:
+        latent_points, _ = self._generate_latent_points(n_samples, labels)
        return self.generator.predict(latent_points, temperature)

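Class-specific sampling then reduces to passing labels to `predict`; continuing the training sketch above, ten samples of class 3:

```python
labels = np.full(10, 3)                    # ten copies of class 3
samples = gan.predict(10, labels=labels)   # all conditioned on class 3
```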
    def evaluate(