
Commit 611cb67

Author: Francisco Santos (committed)
Auto-regressive timeseries sampling method

1 parent 3f6cbe5 · commit 611cb67

File tree: 2 files changed (+30, -27 lines)


examples/timeseries/tscwgan_example.py (3 additions, 3 deletions)

@@ -1,4 +1,4 @@
-from numpy import squeeze
+from numpy import reshape

 from ydata_synthetic.preprocessing.timeseries import processed_stock
 from ydata_synthetic.synthesizers.timeseries import TSCWGAN
@@ -51,9 +51,9 @@
 #Sampling the data
 #Note that the data returned is not inverse processed.
 cond_index = 100 # Arbitrary sequence for conditioning
-cond_array = squeeze(processed_data[cond_index][:cond_dim], axis=1)
+cond_array = reshape(processed_data[cond_index][:cond_dim], (1,-1))

-data_sample = synth.sample(cond_array, 1000)
+data_sample = synth.sample(cond_array, 1000, 100)

 # Inverting the scaling of the synthetic samples
 data_sample = inverse_transform(data_sample, scaler)
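The reason for swapping squeeze for reshape is the shape contract of the reworked sample(): the condition must now be two-dimensional, of shape (1, cond_dim), instead of a flat vector. The following is a minimal, self-contained sketch of that shape change; the processed_data stand-in and the cond_dim value are hypothetical, chosen only to mirror the example's slicing.

import numpy as np

# Hypothetical stand-ins for the example's processed_data and cond_dim,
# used only to show the shape the new sample() signature expects.
cond_dim = 24
processed_data = [np.random.normal(size=(48, 1)) for _ in range(200)]

cond_index = 100  # arbitrary sequence for conditioning
# Old call: squeeze(..., axis=1) produced a 1-D vector of shape (cond_dim,).
# New call: reshape(..., (1, -1)) keeps it two-dimensional, shape (1, cond_dim),
# which is what the reworked sample() asserts on.
cond_array = np.reshape(processed_data[cond_index][:cond_dim], (1, -1))
print(cond_array.shape)  # (1, 24)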

src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py (27 additions, 24 deletions)

@@ -4,11 +4,11 @@
 And on: https://github.com/CasperHogenboom/WGAN_financial_time-series
 """
 from tqdm import trange
-from numpy import array, vstack
+from numpy import array, vstack, hstack
 from numpy.random import normal
 from typing import List

-from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, make_ndarray, make_tensor_proto, tile, constant
+from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, tile
 from tensorflow import data as tfdata
 from tensorflow.keras import Model, Sequential
 from tensorflow.keras.optimizers import Adam
@@ -17,7 +17,6 @@
 from ydata_synthetic.synthesizers.gan import BaseModel
 from ydata_synthetic.synthesizers import TrainParameters
 from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty
-from ydata_synthetic.synthesizers.timeseries import TimeSeriesDataProcessor

 class TSCWGAN(BaseModel):

@@ -148,31 +147,35 @@ def get_batch_data(self, data, n_windows= None):
                    .shuffle(buffer_size=n_windows)
                    .batch(self.batch_size).repeat())

-    def sample(self, cond_array: array, n_samples: int, inverse_transform: bool = True):
-        """Provided that cond_array is passed, produce n_samples for each condition vector in cond_array.
-        The returned samples per condition will always be a multiple of batch_size and equal or bigger than n_samples.
-
-        Arguments:
-            cond_array (numpy array): Array with the set of conditions for the sampling process.
-            n_samples (int): Number of samples to be taken for each condition in cond_array.
-            inverse_transform (bool): """
-        assert len(cond_array.shape) == 2, "Condition array should be two-dimensional. N_conditions x cond_dim"
-        assert cond_array.shape[1] == self.cond_dim, \
-            f"The condition sequences should have a {self.cond_dim} length."
-        steps = n_samples // self.batch_size + 1
+    def sample(self, condition: array, n_samples: int = 100, seq_len: int = 24):
+        """For a given condition, produce n_samples of length seq_len.
+        The samples are returned in the original data format (any preprocessing transformation is inverted).
+
+        Args:
+            condition (numpy.array): Condition for the generated samples, must have length cond_dim.
+            n_samples (int): Minimum number of generated samples (always returns a multiple of batch_size).
+            seq_len (int): Length of the generated samples.
+
+        Returns:
+            data (numpy.array): An array of data of shape [n_samples, seq_len]"""
+        assert len(condition.shape) == 2, "Condition array should be two-dimensional."
+        assert condition.shape[1] == self.cond_dim, \
+            f"The condition sequence should have {self.cond_dim} length."
+        batches = n_samples // self.batch_size + 1
+        ar_steps = seq_len // self.data_dim + 1
         data = []
         z_dist = self.get_batch_noise()
-        for condition in cond_array:
+        for batch in trange(batches, desc=f'Synthetic data generation'):
+            data_ = []
             cond_seq = convert_to_tensor(condition, float32)
-            cond_seq = tile(cond_seq, multiples=[self.batch_size, 1])
-            for step in trange(steps, desc=f'Synthetic data generation'):
-                gen_input = concat([cond_seq, next(z_dist)], axis=1)
+            gen_input = concat([tile(cond_seq, multiples=[self.batch_size, 1]), next(z_dist)], axis=1)
+            for step in range(ar_steps):
                 records = self.generator(gen_input, training=False)
-                data.append(records)
-        data = array(vstack(data))
-        if inverse_transform:
-            return self.processor.inverse_transform(data)
-        return data
+                gen_input = concat([records[:, -self.cond_dim:], next(z_dist)], axis=1)
+                data_.append(records)
+            data_ = hstack(data_)[:, :seq_len]
+            data.append(data_)
+        return self.processor.inverse_transform(array(vstack(data)))


 class Generator(Model):
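To make the new control flow easier to follow: each generator call emits a window of data_dim values, the last cond_dim of those values become the condition for the next call (with fresh noise appended), and the loop repeats until at least seq_len values exist, which are then trimmed to seq_len. The sketch below illustrates that loop with numpy and a toy generator; the names autoregressive_sample and toy_generator are hypothetical and not part of the library, and this mirrors the idea of the new method rather than reproducing it.

import numpy as np

def autoregressive_sample(generator, condition, noise_dim, data_dim, cond_dim, seq_len, batch_size=4):
    """Illustrative sketch of the autoregressive sampling loop:
    each step conditions the generator on the tail of the previous output."""
    cond = np.tile(condition.reshape(1, -1), (batch_size, 1))   # (batch, cond_dim)
    chunks = []
    for _ in range(seq_len // data_dim + 1):
        z = np.random.normal(size=(batch_size, noise_dim))      # fresh noise every step
        records = generator(np.hstack([cond, z]))                # (batch, data_dim)
        chunks.append(records)
        cond = records[:, -cond_dim:]                            # last cond_dim values feed the next step
    return np.hstack(chunks)[:, :seq_len]                        # trim to the requested length

# Toy generator: any callable mapping (batch, cond_dim + noise_dim) -> (batch, data_dim) works here.
toy_generator = lambda x: np.tanh(x @ np.random.default_rng(0).normal(size=(x.shape[1], 16)))
samples = autoregressive_sample(toy_generator, np.zeros(8), noise_dim=32, data_dim=16, cond_dim=8, seq_len=40)
print(samples.shape)  # (4, 40)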
