Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions tensorflow_addons/optimizers/tests/lookahead_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
# ==============================================================================
"""Tests for Lookahead optimizer."""

import os

import numpy as np
import pytest
import tensorflow as tf
import tempfile

from tensorflow_addons.optimizers import Lookahead
from tensorflow_addons.utils import test_utils
Expand Down Expand Up @@ -186,3 +189,68 @@ def test_serialization():
config = tf.keras.optimizers.serialize(optimizer)
new_optimizer = tf.keras.optimizers.deserialize(config)
assert new_optimizer.get_config() == optimizer.get_config()


def _init_model(optimizer, init_w):
model = tf.keras.models.Sequential()
dense = tf.keras.layers.Dense(input_shape=(3,), units=1)
model.add(dense)
model.compile(Lookahead(optimizer), loss="mse")
dense.set_weights([init_w, np.zeros(1,)])
return model


def assert_same_optimizer_states(optimizer, new_optimizer):
# Remove the iteration variable
weights = []
for weight in optimizer.weights:
if "iter" not in weight.name:
weights.append(weight)
new_weights = []
for weight in new_optimizer.weights:
if "iter" not in weight.name:
new_weights.append(weight)

assert len(weights) == len(new_weights)

weights = sorted(weights, key=lambda w: w.name)
new_weights = sorted(new_weights, key=lambda w: w.name)

for weight, new_weight in zip(weights, new_weights):
assert np.allclose(weight.numpy(), new_weight.numpy(), atol=1e-4)

# Assert recursively
if hasattr(optimizer, "_optimizer"):
assert_same_optimizer_states(optimizer._optimizer, new_optimizer._optimizer)


@pytest.mark.parametrize("optimizer", ["sgd", "adam"])
@pytest.mark.parametrize("weights_only", [False, True])
def test_save_load(optimizer, weights_only):
x = np.random.standard_normal((10000, 3))
w = np.random.standard_normal((3, 1))
y = np.dot(x, w) + np.random.standard_normal((10000, 1)) * 1e-4

init_w = np.random.standard_normal((3, 1))

model = _init_model(optimizer, init_w)
model.fit(x, y, epochs=2, shuffle=False)

with tempfile.TemporaryDirectory() as ckpt_dir:
new_model = _init_model(optimizer, init_w)
new_model.fit(x, y, epochs=1, shuffle=False)

ckpt_path = os.path.join(ckpt_dir, "model.ckpt")
if weights_only:
new_model.save_weights(ckpt_path)
new_model = _init_model(optimizer, init_w)
new_model.load_weights(ckpt_path)
else:
new_model.save(ckpt_path)
new_model = tf.keras.models.load_model(
ckpt_path, custom_objects={"Lookahead": Lookahead,}
)

new_model.fit(x, y, epochs=1, shuffle=False)

assert_same_optimizer_states(model.optimizer, new_model.optimizer)