 And on: https://github.com/CasperHogenboom/WGAN_financial_time-series
 """
 from tqdm import trange
-from numpy import array, vstack
+from numpy import array, vstack, hstack
 from numpy.random import normal
 from typing import List

-from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, make_ndarray, make_tensor_proto, tile, constant
+from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, tile
 from tensorflow import data as tfdata
 from tensorflow.keras import Model, Sequential
 from tensorflow.keras.optimizers import Adam
...
 from ydata_synthetic.synthesizers.gan import BaseModel
 from ydata_synthetic.synthesizers import TrainParameters
 from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty
-from ydata_synthetic.synthesizers.timeseries import TimeSeriesDataProcessor

 class TSCWGAN(BaseModel):

@@ -148,31 +147,35 @@ def get_batch_data(self, data, n_windows= None):
                                 .shuffle(buffer_size=n_windows)
                                 .batch(self.batch_size).repeat())

-    def sample(self, cond_array: array, n_samples: int, inverse_transform: bool = True):
-        """Provided that cond_array is passed, produce n_samples for each condition vector in cond_array.
-        The returned samples per condition will always be a multiple of batch_size and equal or bigger than n_samples.
-
-        Arguments:
-            cond_array (numpy array): Array with the set of conditions for the sampling process.
-            n_samples (int): Number of samples to be taken for each condition in cond_array.
-            inverse_transform (bool): """
-        assert len(cond_array.shape) == 2, "Condition array should be two-dimensional. N_conditions x cond_dim"
-        assert cond_array.shape[1] == self.cond_dim, \
-            f"The condition sequences should have a {self.cond_dim} length."
-        steps = n_samples // self.batch_size + 1
+    def sample(self, condition: array, n_samples: int = 100, seq_len: int = 24):
+        """For a given condition, produce n_samples of length seq_len.
+        The samples are returned in the original data format (any preprocessing transformation is inverted).
+
+        Args:
+            condition (numpy.array): Condition for the generated samples; expected shape is (1, cond_dim).
+            n_samples (int): Minimum number of generated samples (the returned count is always a multiple of batch_size).
+            seq_len (int): Length of the generated samples.
+
+        Returns:
+            data (numpy.array): An array of synthetic data of shape [>= n_samples, seq_len];
+            the first dimension is always a multiple of batch_size."""
+        assert len(condition.shape) == 2, "Condition array should be two-dimensional."
+        assert condition.shape[1] == self.cond_dim, \
+            f"The condition sequence should have length {self.cond_dim}."
+        batches = n_samples // self.batch_size + 1
+        ar_steps = seq_len // self.data_dim + 1
         data = []
         z_dist = self.get_batch_noise()
-        for condition in cond_array:
+        for batch in trange(batches, desc=f'Synthetic data generation'):
+            data_ = []
             cond_seq = convert_to_tensor(condition, float32)
-            cond_seq = tile(cond_seq, multiples=[self.batch_size, 1])
-            for step in trange(steps, desc=f'Synthetic data generation'):
-                gen_input = concat([cond_seq, next(z_dist)], axis=1)
+            gen_input = concat([tile(cond_seq, multiples=[self.batch_size, 1]), next(z_dist)], axis=1)
+            for step in range(ar_steps):
                 records = self.generator(gen_input, training=False)
-                data.append(records)
-        data = array(vstack(data))
-        if inverse_transform:
-            return self.processor.inverse_transform(data)
-        return data
+                gen_input = concat([records[:, -self.cond_dim:], next(z_dist)], axis=1)
+                data_.append(records)
+            data_ = hstack(data_)[:, :seq_len]
+            data.append(data_)
+        return self.processor.inverse_transform(array(vstack(data)))


 class Generator(Model):
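For reference, a minimal usage sketch of the new autoregressive sample signature (not part of the commit; synth is assumed to be a trained TSCWGAN instance and train_data a preprocessed 2-D numpy array, both hypothetical names). The method tiles the single condition row across the batch, runs seq_len // data_dim + 1 generator passes, feeding the last cond_dim outputs of each pass back in as the next condition, then truncates to seq_len and applies the processor's inverse transform:

    # Hypothetical setup: synth is a trained TSCWGAN, train_data a preprocessed 2-D array.
    condition = train_data[0, :synth.cond_dim].reshape(1, -1)   # one condition row, shape (1, cond_dim)

    # At least 100 sequences of length 24; the row count is rounded up to a multiple of batch_size.
    synthetic = synth.sample(condition, n_samples=100, seq_len=24)
    print(synthetic.shape)   # expected: (multiple of batch_size, 24)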