Skip to content
This repository was archived by the owner on Apr 10, 2024. It is now read-only.

Commit 209e29f

Browse files
Merge pull request #50 from tensorflow/image-sample
Image sample
2 parents 95fa6e5 + c50ab39 commit 209e29f

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

lucid/optvis/param/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@
1717
from lucid.optvis.param.lowres import lowres_tensor
1818
from lucid.optvis.param.color import to_valid_rgb
1919
from lucid.optvis.param.spatial import naive, fft_image, laplacian_pyramid
20+
from lucid.optvis.param.random import image_sample

lucid/optvis/param/random.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2018 The Lucid Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
import tensorflow as tf
17+
import numpy as np
18+
19+
from lucid.optvis.param.color import to_valid_rgb
20+
21+
22+
def image_sample(shape, decorrelate=True, sd=None, decay_power=1):
23+
raw_spatial = rand_fft_image(shape, sd=sd, decay_power=decay_power)
24+
return to_valid_rgb(raw_spatial, decorrelate=decorrelate)
25+
26+
def rand_fft_image(shape, sd=None, decay_power=1):
27+
b, h, w, ch = shape
28+
sd = 0.01 if sd is None else sd
29+
30+
imgs = []
31+
for _ in range(b):
32+
freqs = _rfft2d_freqs(h, w)
33+
fh, fw = freqs.shape
34+
spectrum_var = sd*tf.random_normal([2, ch, fh, fw], dtype="float32")
35+
spectrum = tf.complex(spectrum_var[0], spectrum_var[1])
36+
spertum_scale = 1.0 / np.maximum(freqs, 1.0/max(h, w))**decay_power
37+
# Scale the spectrum by the square-root of the number of pixels
38+
# to get a unitary transformation. This allows to use similar
39+
# leanring rates to pixel-wise optimisation.
40+
spertum_scale *= np.sqrt(w*h)
41+
scaled_spectrum = spectrum * spertum_scale
42+
img = tf.spectral.irfft2d(scaled_spectrum)
43+
# in case of odd input dimension we cut off the additional pixel
44+
# we get from irfft2d length computation
45+
img = img[:ch, :h, :w]
46+
img = tf.transpose(img, [1, 2, 0])
47+
imgs.append(img)
48+
return tf.stack(imgs)/4.
49+
50+
def _rfft2d_freqs(h, w):
51+
"""Compute 2d spectrum frequences."""
52+
fy = np.fft.fftfreq(h)[:, None]
53+
# when we have an odd input dimension we need to keep one additional
54+
# frequency and later cut off 1 pixel
55+
if w % 2 == 1:
56+
fx = np.fft.fftfreq(w)[:w//2+2]
57+
else:
58+
fx = np.fft.fftfreq(w)[:w//2+1]
59+
return np.sqrt(fx*fx + fy*fy)

0 commit comments

Comments
 (0)