Skip to content

Commit 0c11017

Browse files
committed
docs: add conv example for gan
1 parent 3116a8a commit 0c11017

File tree

2 files changed

+171
-0
lines changed

2 files changed

+171
-0
lines changed
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import random\n",
10+
"import numpy as np\n",
11+
"import matplotlib.pyplot as plt\n",
12+
"import os\n",
13+
"\n",
14+
"from PIL import Image\n",
15+
"from keras.datasets import mnist\n",
16+
"from IPython.display import Image as IPImage\n",
17+
"\n",
18+
"from neuralnetlib.preprocessing import one_hot_encode\n",
19+
"from neuralnetlib.models import Sequential, GAN\n",
20+
"from neuralnetlib.layers import Input, Dense, Conv2D, Reshape, Flatten, UpSampling2D"
21+
]
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": null,
26+
"metadata": {},
27+
"outputs": [],
28+
"source": [
29+
"# Load the MNIST dataset\n",
30+
"(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
31+
"n_classes = np.unique(y_train).shape[0]\n",
32+
"\n",
33+
"# Reshape images to include channel dimension\n",
34+
"x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)\n",
35+
"x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)\n",
36+
"\n",
37+
"# Normalize pixel values\n",
38+
"x_train = x_train.astype('float32') / 255\n",
39+
"x_test = x_test.astype('float32') / 255\n",
40+
"\n",
41+
"# Labels to categorical\n",
42+
"y_train = one_hot_encode(y_train, n_classes)\n",
43+
"y_test = one_hot_encode(y_test, n_classes)"
44+
]
45+
},
46+
{
47+
"cell_type": "code",
48+
"execution_count": null,
49+
"metadata": {},
50+
"outputs": [],
51+
"source": [
52+
"i = random.randint(0, len(x_train) - 1)\n",
53+
"plt.imshow(x_train[i].reshape(28,28), cmap='gray')\n",
54+
"plt.title('Class: ' + str(np.argmax(y_train[i])))\n",
55+
"plt.show()"
56+
]
57+
},
58+
{
59+
"cell_type": "code",
60+
"execution_count": null,
61+
"metadata": {},
62+
"outputs": [],
63+
"source": [
64+
"noise_dim = 32\n",
65+
"\n",
66+
"generator = Sequential()\n",
67+
"generator.add(Input(noise_dim))\n",
68+
"generator.add(Dense(7 * 7 * 128))\n",
69+
"generator.add(Reshape((7, 7, 128)))\n",
70+
"generator.add(UpSampling2D(size=(2, 2))) # 14x14\n",
71+
"generator.add(Conv2D(64, kernel_size=3, padding='same', activation='relu'))\n",
72+
"generator.add(UpSampling2D(size=(2, 2))) # 28x28\n",
73+
"generator.add(Conv2D(32, kernel_size=3, padding='same', activation='relu'))\n",
74+
"generator.add(Conv2D(1, kernel_size=3, padding='same', activation='sigmoid'))"
75+
]
76+
},
77+
{
78+
"cell_type": "code",
79+
"execution_count": null,
80+
"metadata": {},
81+
"outputs": [],
82+
"source": [
83+
"discriminator = Sequential()\n",
84+
"discriminator.add(Input((28, 28, 1)))\n",
85+
"discriminator.add(Conv2D(32, kernel_size=3, strides=2, padding='same', activation='relu')) # 14x14\n",
86+
"discriminator.add(Conv2D(64, kernel_size=3, strides=2, padding='same', activation='relu')) # 7x7\n",
87+
"discriminator.add(Flatten())\n",
88+
"discriminator.add(Dense(128, activation='relu'))\n",
89+
"discriminator.add(Dense(1, activation='sigmoid'))"
90+
]
91+
},
92+
{
93+
"cell_type": "code",
94+
"execution_count": null,
95+
"metadata": {},
96+
"outputs": [],
97+
"source": [
98+
"gan = GAN(latent_dim=noise_dim)\n",
99+
"\n",
100+
"gan.compile(\n",
101+
" generator,\n",
102+
" discriminator,\n",
103+
" generator_optimizer='adam',\n",
104+
" discriminator_optimizer='adam',\n",
105+
" loss_function='bce',\n",
106+
" verbose=True\n",
107+
")"
108+
]
109+
},
110+
{
111+
"cell_type": "code",
112+
"execution_count": null,
113+
"metadata": {},
114+
"outputs": [],
115+
"source": [
116+
"history = gan.fit(x_train,\n",
117+
" epochs=40,\n",
118+
" batch_size=128,\n",
119+
" plot_generated=True,\n",
120+
" ) "
121+
]
122+
},
123+
{
124+
"cell_type": "code",
125+
"execution_count": null,
126+
"metadata": {},
127+
"outputs": [],
128+
"source": [
129+
"image_files = [f for f in os.listdir() if f.endswith('.png') and f.startswith('video')]\n",
130+
"image_files.sort(key=lambda x: int(x.replace('video', '').replace('.png', '')))\n",
131+
"\n",
132+
"images = [Image.open(img) for img in image_files]\n",
133+
"\n",
134+
"if images:\n",
135+
" images[0].save('output.gif', save_all=True, append_images=images[1:], duration=100, loop=0)\n",
136+
"\n",
137+
"print(\"GIF 'output.gif' succesffuly created!\")"
138+
]
139+
},
140+
{
141+
"cell_type": "code",
142+
"execution_count": null,
143+
"metadata": {},
144+
"outputs": [],
145+
"source": [
146+
"IPImage(filename=\"output.gif\")"
147+
]
148+
}
149+
],
150+
"metadata": {
151+
"kernelspec": {
152+
"display_name": "Python 3",
153+
"language": "python",
154+
"name": "python3"
155+
},
156+
"language_info": {
157+
"codemirror_mode": {
158+
"name": "ipython",
159+
"version": 3
160+
},
161+
"file_extension": ".py",
162+
"mimetype": "text/x-python",
163+
"name": "python",
164+
"nbconvert_exporter": "python",
165+
"pygments_lexer": "ipython3",
166+
"version": "3.10.8"
167+
}
168+
},
169+
"nbformat": 4,
170+
"nbformat_minor": 2
171+
}

0 commit comments

Comments
 (0)