Skip to content

Commit 1279e5e

Browse files
ricardodcpereiraricardodcpereira
authored andcommitted
fix: improvements to the doppelganger model
1 parent ce75895 commit 1279e5e

File tree

10 files changed

+33956
-132
lines changed

10 files changed

+33956
-132
lines changed

data/fcc_mba.csv

Lines changed: 33601 additions & 0 deletions
Large diffs are not rendered by default.

docs/examples/doppelganger_example.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ DoppelGANger is a model that uses a Generative Adversarial Network (GAN) framewo
1010

1111
- 📑 **Paper:** [Using GANs for Sharing Networked Time Series Data: Challenges, Initial Promise, and Open Questions](https://dl.acm.org/doi/pdf/10.1145/3419394.3423643)
1212

13-
Here’s an example of how to synthetize time-series data with DoppelGANger using the [Yahoo Stock Price](https://www.kaggle.com/datasets/arashnic/time-series-forecasting-with-yahoo-stock-price) dataset:
13+
Here’s an example of how to synthetize time-series data with DoppelGANger using the [Measuring Broadband America](https://www.fcc.gov/reports-research/reports/measuring-broadband-america/raw-data-measuring-broadband-america-seventh) dataset:
1414

1515

1616
```python
17-
--8<-- "examples/timeseries/stock_doppelganger.py"
17+
--8<-- "examples/timeseries/mba_doppelganger.py"
1818
```
1919

2020

docs/getting-started/quickstart.md

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,30 @@ The following example showcases how to synthesize the [Yahoo Stock Price](https:
3535
```python
3636
# Import the necessary modules
3737
import pandas as pd
38-
from ydata_synthetic.synthesizers import ModelParameters
39-
from ydata_synthetic.synthesizers.timeseries import TimeGAN
40-
from ydata_synthetic.preprocessing.timeseries.utils import real_data_loading
38+
from ydata_synthetic.synthesizers.timeseries import TimeSeriesSynthesizer
39+
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
4140

42-
# Load and preprocess data
43-
stock_data_df = pd.read_csv("stock_data.csv")
44-
processed_data = real_data_loading(stock_data_df.values, seq_len=24)
45-
46-
# Define model and training parameters
47-
gan_args = ModelParameters(batch_size=128, lr=5e-4, noise_dim=128, layers_dim=128)
48-
synth = TimeGAN(model_parameters=gan_args, hidden_dim=24, seq_len=24, n_seq=6, gamma=1)
41+
# Define model parameters
42+
gan_args = ModelParameters(batch_size=128,
43+
lr=5e-4,
44+
noise_dim=32,
45+
layers_dim=128,
46+
latent_dim=24,
47+
gamma=1)
4948

50-
# Train the generator model
51-
synth.train(data=processed_data, train_steps=50000)
49+
train_args = TrainParameters(epochs=50000,
50+
sequence_length=24,
51+
number_sequences=6)
52+
53+
# Read the data
54+
stock_data = pd.read_csv("stock_data.csv")
55+
56+
# Training the TimeGAN synthesizer
57+
synth = TimeSeriesSynthesizer(modelname='timegan', model_parameters=gan_args)
58+
synth.fit(stock_data, train_args, num_cols=list(stock_data.columns))
5259

53-
# Generate new synthetic data
54-
synth_data = synth.sample(len(stock_data_df))
60+
# Generating new synthetic samples
61+
synth_data = synth.sample(n_samples=500)
5562
```
5663

5764
## Running the Streamlit App
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""
2+
DoppelGANger architecture example file
3+
"""
4+
5+
# Importing necessary libraries
6+
import pandas as pd
7+
from os import path
8+
import matplotlib.pyplot as plt
9+
from ydata_synthetic.synthesizers.timeseries import TimeSeriesSynthesizer
10+
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
11+
12+
# Read the data
13+
mba_data = pd.read_csv("../../data/fcc_mba.csv")
14+
numerical_cols = ["traffic_byte_counter", "ping_loss_rate"]
15+
categorical_cols = [col for col in mba_data.columns if col not in numerical_cols]
16+
17+
# Define model parameters
18+
model_args = ModelParameters(batch_size=100,
19+
lr=0.001,
20+
betas=(0.2, 0.9),
21+
latent_dim=20,
22+
gp_lambda=2,
23+
pac=1)
24+
25+
train_args = TrainParameters(epochs=400, sequence_length=56,
26+
sample_length=8, rounds=1,
27+
measurement_cols=["traffic_byte_counter", "ping_loss_rate"])
28+
29+
# Training the DoppelGANger synthesizer
30+
if path.exists('doppelganger_mba'):
31+
model_dop_gan = TimeSeriesSynthesizer.load('doppelganger_mba')
32+
else:
33+
model_dop_gan = TimeSeriesSynthesizer(modelname='doppelganger', model_parameters=model_args)
34+
model_dop_gan.fit(mba_data, train_args, num_cols=numerical_cols, cat_cols=categorical_cols)
35+
model_dop_gan.save('doppelganger_mba')
36+
37+
# Generate synthetic data
38+
synth_data = model_dop_gan.sample(n_samples=600)
39+
synth_df = pd.concat(synth_data, axis=0)
40+
41+
# Create a plot for each measurement column
42+
plt.figure(figsize=(10, 6))
43+
44+
plt.subplot(2, 1, 1)
45+
plt.plot(mba_data['traffic_byte_counter'].reset_index(drop=True), label='Real Traffic')
46+
plt.plot(synth_df['traffic_byte_counter'].reset_index(drop=True), label='Synthetic Traffic', alpha=0.7)
47+
plt.xlabel('Index')
48+
plt.ylabel('Value')
49+
plt.title('Traffic Comparison')
50+
plt.legend()
51+
plt.grid(True)
52+
53+
plt.subplot(2, 1, 2)
54+
plt.plot(mba_data['ping_loss_rate'].reset_index(drop=True), label='Real Ping')
55+
plt.plot(synth_df['ping_loss_rate'].reset_index(drop=True), label='Synthetic Ping', alpha=0.7)
56+
plt.xlabel('Index')
57+
plt.ylabel('Value')
58+
plt.title('Ping Comparison')
59+
plt.legend()
60+
plt.grid(True)
61+
62+
plt.tight_layout()
63+
plt.show()

examples/timeseries/stock_doppelganger.py

Lines changed: 0 additions & 35 deletions
This file was deleted.

0 commit comments

Comments
 (0)