66from tqdm import trange
77from numpy import array , vstack , hstack
88from numpy .random import normal
9- from typing import List
109
1110from tensorflow import concat , float32 , convert_to_tensor , reshape , GradientTape , reduce_mean , tile
1211from tensorflow import data as tfdata
@@ -24,9 +23,9 @@ class TSCWGAN(BaseModel):
2423
2524 def __init__ (self , model_parameters , gradient_penalty_weight = 10 ):
2625 """Create a base TSCWGAN."""
27- super ().__init__ (model_parameters )
2826 self .gradient_penalty_weight = gradient_penalty_weight
2927 self .cond_dim = model_parameters .condition
28+ super ().__init__ (model_parameters )
3029
3130 def define_gan (self ):
3231 self .generator = Generator (self .batch_size ). \
@@ -45,18 +44,14 @@ def define_gan(self):
4544 score = concat ([cond , gen ], axis = 1 )
4645 score = self .critic (score )
4746
48- def train (self , data , train_arguments : TrainParameters , num_cols : List [str ], cat_cols : List [str ],
49- preprocess : bool = True ):
50- super ().train (data , num_cols , cat_cols , preprocess )
51-
52- processed_data = self .processor .transform (data )
53- real_batches = self .get_batch_data (processed_data )
47+ def train (self , data , train_arguments : TrainParameters ):
48+ real_batches = self .get_batch_data (data )
5449 noise_batches = self .get_batch_noise ()
5550
5651 for epoch in trange (train_arguments .epochs ):
5752 for i in range (train_arguments .critic_iter ):
5853 real_batch = next (real_batches )
59- noise_batch = next (noise_batches )[:len (real_batch )] # Truncate noise tensor to real data shape
54+ noise_batch = next (noise_batches )[:len (real_batch )] # Truncate the noise tensor in the shape of the real data tensor
6055
6156 c_loss = self .update_critic (real_batch , noise_batch )
6257
@@ -149,10 +144,9 @@ def get_batch_data(self, data, n_windows= None):
149144
150145 def sample (self , condition : array , n_samples : int = 100 , seq_len : int = 24 ):
151146 """For a given condition, produce n_samples of length seq_len.
152- The samples are returned in the original data format (any preprocessing transformation is inverted).
153147
154148 Args:
155- condition (numpy.array): Condition for the generated samples, must have the same length .
149+ condition (numpy.array): Condition for the generated samples, must have the same length.
156150 n_samples (int): Minimum number of generated samples (returns always a multiple of batch_size).
157151 seq_len (int): Length of the generated samples.
158152
@@ -175,7 +169,7 @@ def sample(self, condition: array, n_samples: int = 100, seq_len: int = 24):
175169 data_ .append (records )
176170 data_ = hstack (data_ )[:, :seq_len ]
177171 data .append (data_ )
178- return self . processor . inverse_transform ( array (vstack (data ) ))
172+ return array (vstack (data ))
179173
180174
181175class Generator (Model ):
0 commit comments