Tests to warmup_cosine_decay_schedule edge cases #1413
base: main
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
a603faf to d1fc68e
|
Can you explain the justification for adding these tests? We have some warmup cosine tests here already: optax/optax/schedules/_schedule_test.py Line 484 in e5649e7
Do you want to extend them? Are there any cases you're interested in testing in particular? |
|
These tests were mainly to cover the edge case where decay_steps == warmup_steps, plus checking dtype handling and JIT/vmap consistency. I can fold them into _schedule_test.py if that’s the better place. |
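To make the `decay_steps == warmup_steps` concern concrete, here is a minimal pure-Python sketch of a warmup-plus-cosine schedule. It is a hypothetical stand-in (`toy_warmup_cosine` and its guard are assumptions made for illustration), not optax's implementation, and it makes no claim about how optax itself handles this case. It only shows why the boundary deserves a test: the cosine phase has length `decay_steps - warmup_steps`, so a naive progress ratio divides by zero when the two are equal.

```python
import math


def toy_warmup_cosine(init_value, peak_value, warmup_steps, decay_steps,
                      end_value=0.0):
    """Hypothetical stand-in: linear warmup, then cosine decay to end_value."""
    cosine_steps = decay_steps - warmup_steps
    if cosine_steps <= 0:
        # Design choice of this sketch only: reject a zero-length cosine
        # phase, because the progress ratio below would divide by zero.
        raise ValueError("this sketch requires decay_steps > warmup_steps")

    def schedule(step):
        if step < warmup_steps:
            # Linear interpolation from init_value up to peak_value.
            frac = step / warmup_steps
            return init_value + frac * (peak_value - init_value)
        # Fraction of the cosine phase completed, clamped at 1.0.
        progress = min(step - warmup_steps, cosine_steps) / cosine_steps
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return end_value + (peak_value - end_value) * cosine

    return schedule


sched = toy_warmup_cosine(0.0, 1.0, warmup_steps=5, decay_steps=20)
values = [sched(t) for t in range(21)]
```

In this toy version the value rises during warmup, hits the peak at `warmup_steps`, and does not increase afterwards. Those are the kinds of properties an edge-case test can assert without relying on golden values.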
I added some non-exhaustive comments. Could you compress these tests into fewer separate cases and fold them into the existing schedule tests, please?
for t in range(warmup_steps, decay_steps + 1):
  self.assertLessEqual(float(vs[t]), float(vb[t]) + 1e-12)

def test_accepts_int_steps_and_returns_float_dtype(self):
I don't think this test is necessary
sched_jit = jax.jit(sched)

steps = jnp.arange(decay_steps + 1)
eager = jax.vmap(sched)(steps)
I don't understand the vmap use here
class WarmupCosineScheduleTest(parameterized.TestCase):

  def test_warmup_increases_and_reaches_peak(self):
can you fold this into the next test?
|
Folded the tests into _schedule_test.py and compressed them into two cases as suggested. All tests pass locally. |
|
Are you proposing to remove existing tests? Could you label the new test cases by wrapping their code blocks in subTest? |
|
Wrapped cases with subTest |
You seem to be removing existing test functions. Please revert this.
You did wrap your new tests in subtests, that's perfect, thank you!
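For readers unfamiliar with the pattern, the following is a minimal sketch of wrapping related checks in `subTest` so that each block is labelled and reported separately within one test method. It uses the standard-library `unittest.TestCase`. The function `toy_linear_warmup` and the test names are hypothetical stand-ins, not code from this pull request.

```python
import unittest


def toy_linear_warmup(step, warmup_steps=4, peak_value=1.0):
    """Hypothetical stand-in for the schedule under test."""
    return peak_value * min(step, warmup_steps) / warmup_steps


class ToyScheduleTest(unittest.TestCase):

    def test_warmup(self):
        values = [toy_linear_warmup(t) for t in range(6)]

        # Each subTest is reported on its own, so a failure in one block
        # does not hide the result of the other.
        with self.subTest("increases_during_warmup"):
            for prev, cur in zip(values[:4], values[1:5]):
                self.assertLess(prev, cur)

        with self.subTest("holds_peak_after_warmup"):
            self.assertEqual(values[4], 1.0)
            self.assertEqual(values[5], 1.0)


# Run the single test case programmatically and keep the result object.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(ToyScheduleTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Grouping checks this way keeps the number of test functions small while still documenting what each block is meant to verify.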
c5d65b7 to e706c87
|
Restored all original tests unchanged. |
I think you have now duplicated the existing tests. Can you make sure the diff contains only the new tests, please?
For example, these golden values are tested the same way in a previous test:
np.testing.assert_allclose(
    output,
    np.array([
        0.20000004768371582,
        0.4020000100135803,
        1.2100000381469727,
        -1.947500228881836,
        -3.000000238418579,
    ]),
    rtol=1e-6, atol=1e-8,
)
e706c87 to e416db9
|
Removed duplicated tests. |
e416db9 to fd84e38
Added unit tests for warmup_cosine_decay_schedule covering edge cases, dtype handling, and JIT consistency to improve reliability of this common schedule. All tests pass locally: 6 passed.