
Commit efa508a

Author: Francisco Santos (committed)

revert TS data processor integration

1 parent 611cb67 commit efa508a

File tree

1 file changed: +6 -12 lines
  • src/ydata_synthetic/synthesizers/timeseries/tscwgan


src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py

Lines changed: 6 additions & 12 deletions
@@ -6,7 +6,6 @@
 from tqdm import trange
 from numpy import array, vstack, hstack
 from numpy.random import normal
-from typing import List
 
 from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, tile
 from tensorflow import data as tfdata
@@ -24,9 +23,9 @@ class TSCWGAN(BaseModel):
 
     def __init__(self, model_parameters, gradient_penalty_weight=10):
         """Create a base TSCWGAN."""
-        super().__init__(model_parameters)
         self.gradient_penalty_weight = gradient_penalty_weight
         self.cond_dim = model_parameters.condition
+        super().__init__(model_parameters)
 
     def define_gan(self):
         self.generator = Generator(self.batch_size). \
@@ -45,18 +44,14 @@ def define_gan(self):
         score = concat([cond, gen], axis=1)
         score = self.critic(score)
 
-    def train(self, data, train_arguments: TrainParameters, num_cols: List[str], cat_cols: List[str],
-              preprocess: bool = True):
-        super().train(data, num_cols, cat_cols, preprocess)
-
-        processed_data = self.processor.transform(data)
-        real_batches = self.get_batch_data(processed_data)
+    def train(self, data, train_arguments: TrainParameters):
+        real_batches = self.get_batch_data(data)
         noise_batches = self.get_batch_noise()
 
         for epoch in trange(train_arguments.epochs):
             for i in range(train_arguments.critic_iter):
                 real_batch = next(real_batches)
-                noise_batch = next(noise_batches)[:len(real_batch)]  # Truncate noise tensor to real data shape
+                noise_batch = next(noise_batches)[:len(real_batch)]  # Truncate the noise tensor to the shape of the real data tensor
 
                 c_loss = self.update_critic(real_batch, noise_batch)
 
@@ -149,10 +144,9 @@ def get_batch_data(self, data, n_windows= None):
 
     def sample(self, condition: array, n_samples: int = 100, seq_len: int = 24):
         """For a given condition, produce n_samples of length seq_len.
-        The samples are returned in the original data format (any preprocessing transformation is inverted).
 
         Args:
-            condition (numpy.array): Condition for the generated samples, must have the same length .
+            condition (numpy.array): Condition for the generated samples, must have the same length.
             n_samples (int): Minimum number of generated samples (returns always a multiple of batch_size).
             seq_len (int): Length of the generated samples.
 
@@ -175,7 +169,7 @@ def sample(self, condition: array, n_samples: int = 100, seq_len: int = 24):
                 data_.append(records)
             data_ = hstack(data_)[:, :seq_len]
             data.append(data_)
-        return self.processor.inverse_transform(array(vstack(data)))
+        return array(vstack(data))
 
 
 class Generator(Model):
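Because sample no longer calls self.processor.inverse_transform, generated samples come back in the same (scaled) space the model was trained on, and undoing any caller-side scaling is now the caller's job. Continuing the illustrative sketch above: cond is a placeholder condition array, and the reshape assumes the univariate scaler from that sketch.

samples = synth.sample(condition=cond, n_samples=100, seq_len=24)                   # returned in the scaled training space
restored = scaler.inverse_transform(samples.reshape(-1, 1)).reshape(samples.shape)  # caller-side inverse of the scaling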
