Skip to content

Conversation

@edawite
Copy link

@edawite edawite commented Sep 1, 2025

Added unit tests for warmup_cosine_decay_schedule covering edge cases, dtype handling, and JIT consistency to improve reliability of this common schedule. All tests pass locally: 6 passed.

@google-cla
Copy link

google-cla bot commented Sep 1, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@edawite edawite force-pushed the tests-docs/warmup-cosine branch from a603faf to d1fc68e Compare September 1, 2025 17:33
@rdyro
Copy link
Collaborator

rdyro commented Sep 1, 2025

Can you explain the justification for adding these tests? We have some warmup cosine tests here already:

class WarmupCosineDecayTest(chex.TestCase):

Do you want to extend them? Are there any cases you're interested in testing in particular?

@edawite
Copy link
Author

edawite commented Sep 1, 2025

These tests were mainly to cover the edge case where decay_steps == warmup_steps, plus checking dtype handling and JIT/vmap consistency. I can fold them into _schedule_test.py if that’s the better place.

Copy link
Collaborator

@rdyro rdyro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some non-exhaustive, comments, could you compress these tests into fewer separate cases and fold them into existing schedule tests please?

for t in range(warmup_steps, decay_steps + 1):
self.assertLessEqual(float(vs[t]), float(vb[t]) + 1e-12)

def test_accepts_int_steps_and_returns_float_dtype(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this test is necessary

sched_jit = jax.jit(sched)

steps = jnp.arange(decay_steps + 1)
eager = jax.vmap(sched)(steps)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the vmap use here


class WarmupCosineScheduleTest(parameterized.TestCase):

def test_warmup_increases_and_reaches_peak(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you fold this into the next test?

@edawite
Copy link
Author

edawite commented Sep 7, 2025

Folded tests into _schedule_test.py and compressed into two cases as suggested and passed all tests locally.

@rdyro
Copy link
Collaborator

rdyro commented Sep 8, 2025

Are you proposing to remove existing tests?

Could you comment the test by surrounding code blocks with subTest?

@edawite
Copy link
Author

edawite commented Sep 10, 2025

Wrapped cases with subTest

Copy link
Collaborator

@rdyro rdyro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You seem to be removing existing test functions. Please revert this.

You did wrap your new tests in subtests, that's perfect, thank you!

@edawite edawite force-pushed the tests-docs/warmup-cosine branch from c5d65b7 to e706c87 Compare September 12, 2025 01:28
@edawite
Copy link
Author

edawite commented Sep 12, 2025

Restored all original tests unchanged.

Copy link
Collaborator

@rdyro rdyro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think now you have duplicated the existing tests, can you make sure the diff is just the new tests, please?

For example, these golden values are tested the same way in a previous test:

      np.testing.assert_allclose(
          output,
          np.array([
              0.20000004768371582,
              0.4020000100135803,
              1.2100000381469727,
              -1.947500228881836,
              -3.000000238418579,
          ]),
          rtol=1e-6, atol=1e-8,
      )

@edawite edawite force-pushed the tests-docs/warmup-cosine branch from e706c87 to e416db9 Compare September 17, 2025 21:03
@edawite
Copy link
Author

edawite commented Sep 17, 2025

Removed duplicated tests.

@edawite edawite force-pushed the tests-docs/warmup-cosine branch from e416db9 to fd84e38 Compare September 20, 2025 22:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants