diff --git a/keras/src/backend/jax/excluded_tpu_tests.txt b/keras/src/backend/jax/excluded_tpu_tests.txt index 13a7b799aca1..57b52b0b311f 100644 --- a/keras/src/backend/jax/excluded_tpu_tests.txt +++ b/keras/src/backend/jax/excluded_tpu_tests.txt @@ -1,53 +1,3 @@ -AdditiveAttentionTest::test_attention_correctness -AttentionTest::test_attention_calculate_scores_with_scale -AttentionTest::test_attention_correctness -CircleTest::test_correctness -CircleTest::test_correctness_weighted -CircleTest::test_mean_with_sample_weight_reduction -CircleTest::test_no_reduction -CircleTest::test_sum_reduction -ConvBasicTest::test_enable_lora_with_alpha -ConvCorrectnessTest::test_conv1d0 -ConvCorrectnessTest::test_conv1d1 -ConvCorrectnessTest::test_conv1d2 -ConvCorrectnessTest::test_conv1d3 -ConvCorrectnessTest::test_conv1d4 -ConvCorrectnessTest::test_conv2d0 -ConvCorrectnessTest::test_conv2d1 -ConvCorrectnessTest::test_conv2d2 -ConvCorrectnessTest::test_conv2d3 -ConvCorrectnessTest::test_conv2d4 -ConvCorrectnessTest::test_conv2d5 -ConvCorrectnessTest::test_conv3d0 -ConvCorrectnessTest::test_conv3d1 -ConvCorrectnessTest::test_conv3d2 -ConvCorrectnessTest::test_conv3d3 -ConvCorrectnessTest::test_conv3d4 -ConvLSTM1DTest::test_correctness -ConvLSTM1DTest::test_correctness -ConvLSTM2DTest::test_correctness -ConvLSTMCellTest::test_correctness -ConvLSTMTest::test_correctness -ConvTransposeCorrectnessTest::test_conv1d_transpose0 -ConvTransposeCorrectnessTest::test_conv1d_transpose1 -ConvTransposeCorrectnessTest::test_conv1d_transpose2 -ConvTransposeCorrectnessTest::test_conv2d_transpose0 -ConvTransposeCorrectnessTest::test_conv2d_transpose1 -ConvTransposeCorrectnessTest::test_conv2d_transpose2 -ConvTransposeCorrectnessTest::test_conv2d_transpose3 -ConvTransposeCorrectnessTest::test_conv3d_transpose0 -ConvTransposeCorrectnessTest::test_conv3d_transpose1 -ConvTransposeCorrectnessTest::test_conv3d_transpose2 -CTCTest::test_correctness -DenseTest::test_dense_sparse -DepthwiseConvCorrectnessTest::test_depthwise_conv1d0 -DepthwiseConvCorrectnessTest::test_depthwise_conv1d1 -DepthwiseConvCorrectnessTest::test_depthwise_conv1d2 -DepthwiseConvCorrectnessTest::test_depthwise_conv2d0 -DepthwiseConvCorrectnessTest::test_depthwise_conv2d1 -DepthwiseConvCorrectnessTest::test_depthwise_conv2d2 -EinsumDenseTest::test_enable_lora_with_alpha -EmbeddingTest::test_enable_lora_with_alpha ExportArchiveTest::test_jax_endpoint_registration_tf_function ExportArchiveTest::test_jax_multi_unknown_endpoint_registration ExportArchiveTest::test_layer_export @@ -71,10 +21,18 @@ ExportArchiveTest::test_track_multiple_layers ExportONNXTest::test_export_with_input_names ExportONNXTest::test_export_with_opset_version_18 ExportONNXTest::test_export_with_opset_version_none +ExportONNXTest::test_model_with_input_structure_array +ExportONNXTest::test_model_with_input_structure_dict +ExportONNXTest::test_model_with_input_structure_tuple +ExportONNXTest::test_model_with_multiple_inputs ExportONNXTest::test_standard_model_export_functional ExportONNXTest::test_standard_model_export_lstm ExportONNXTest::test_standard_model_export_sequential ExportONNXTest::test_standard_model_export_subclass +ExportOpenVINOTest::test_model_with_input_structure_array +ExportOpenVINOTest::test_model_with_input_structure_dict +ExportOpenVINOTest::test_model_with_input_structure_tuple +ExportOpenVINOTest::test_model_with_multiple_inputs ExportOpenVINOTest::test_standard_model_export_functional ExportOpenVINOTest::test_standard_model_export_sequential ExportOpenVINOTest::test_standard_model_export_subclass @@ -118,117 +76,10 @@ ExportSavedModelTest::test_model_with_tf_data_layer_subclass ExportSavedModelTest::test_standard_model_export_functional ExportSavedModelTest::test_standard_model_export_sequential ExportSavedModelTest::test_standard_model_export_subclass -GRUTest::test_correctness0 -GRUTest::test_correctness1 -GRUTest::test_legacy_implementation_argument -GRUTest::test_masking -GRUTest::test_pass_initial_state -GRUTest::test_pass_return_state -GRUTest::test_statefulness -ImageOpsCorrectnessTest::test_affine_transform_bilinear_constant -ImageOpsCorrectnessTest::test_affine_transform_bilinear_mirror -ImageOpsCorrectnessTest::test_affine_transform_bilinear_nearest -ImageOpsCorrectnessTest::test_affine_transform_bilinear_reflect -ImageOpsCorrectnessTest::test_affine_transform_bilinear_wrap -LinalgOpsCorrectnessTest::test_cholesky_inverse_lower -LinalgOpsCorrectnessTest::test_cholesky_inverse_upper -LinalgOpsCorrectnessTest::test_eig -LinalgOpsCorrectnessTest::test_svd -LSTMTest::test_correctness0 -LSTMTest::test_correctness1 -LSTMTest::test_masking -LSTMTest::test_pass_initial_state -LSTMTest::test_statefulness -MathOpsCorrectnessTest::test_extract_sequences -MergingLayersTest::test_correctness_dynamic_dot_3d -MergingLayersTest::test_correctness_static_dot_3d -MuonTest::test_Newton_Schulz -NNOpsCorrectnessTest::test_conv_2d0 -NNOpsCorrectnessTest::test_conv_2d1 -NNOpsCorrectnessTest::test_conv_2d2 -NNOpsCorrectnessTest::test_conv_2d3 -NNOpsCorrectnessTest::test_conv_2d4 -NNOpsCorrectnessTest::test_conv_2d5 -NNOpsCorrectnessTest::test_conv_3d0 -NNOpsCorrectnessTest::test_conv_3d1 -NNOpsCorrectnessTest::test_conv_3d10 -NNOpsCorrectnessTest::test_conv_3d11 -NNOpsCorrectnessTest::test_conv_3d2 -NNOpsCorrectnessTest::test_conv_3d3 -NNOpsCorrectnessTest::test_conv_3d4 -NNOpsCorrectnessTest::test_conv_3d5 -NNOpsCorrectnessTest::test_conv_3d6 -NNOpsCorrectnessTest::test_conv_3d7 -NNOpsCorrectnessTest::test_conv_3d8 -NNOpsCorrectnessTest::test_conv_3d9 -NNOpsCorrectnessTest::test_ctc_loss -NNOpsCorrectnessTest::test_depthwise_conv_2d0 -NNOpsCorrectnessTest::test_depthwise_conv_2d1 -NNOpsCorrectnessTest::test_depthwise_conv_2d10 -NNOpsCorrectnessTest::test_depthwise_conv_2d11 -NNOpsCorrectnessTest::test_depthwise_conv_2d2 -NNOpsCorrectnessTest::test_depthwise_conv_2d3 -NNOpsCorrectnessTest::test_depthwise_conv_2d4 -NNOpsCorrectnessTest::test_depthwise_conv_2d5 -NNOpsCorrectnessTest::test_depthwise_conv_2d6 -NNOpsCorrectnessTest::test_depthwise_conv_2d7 -NNOpsCorrectnessTest::test_depthwise_conv_2d8 -NNOpsCorrectnessTest::test_depthwise_conv_2d9 -NNOpsCorrectnessTest::test_separable_conv_2d0 -NNOpsCorrectnessTest::test_separable_conv_2d1 -NNOpsCorrectnessTest::test_separable_conv_2d2 -NNOpsCorrectnessTest::test_separable_conv_2d3 -NNOpsCorrectnessTest::test_separable_conv_2d4 -NNOpsCorrectnessTest::test_separable_conv_2d5 -NNOpsCorrectnessTest::test_separable_conv_2d6 -NNOpsCorrectnessTest::test_separable_conv_2d7 -NumpyOneInputOpsDynamicShapeTest::test_argmax_negative_zero -NumpyOneInputOpsDynamicShapeTest::test_argmin_negative_zero -NumpyTwoInputOpsCorrectnessTest::test_logspace -NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank2_float32_false_false -NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank2_float64_false_false -NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float16_false_false -NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float32_false_false -NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float64_false_false -NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float16_false_false -NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float32_false_false -NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float64_false_false -RandomGaussianBlurTest::test_random_erasing_basic -RandomPerspectiveTest::test_random_perspective_bounding_boxes_with_large_scale -RandomPerspectiveTest::test_random_perspective_bounding_boxes_with_small_scale -RandomZoomTest::test_random_zoom_out_correctness -RegularizersTest::test_orthogonal_regularizer -RNNTest::test_go_backwards -SeparableConvCorrectnessTest::test_separable_conv1d0 -SeparableConvCorrectnessTest::test_separable_conv1d1 -SeparableConvCorrectnessTest::test_separable_conv1d2 -SeparableConvCorrectnessTest::test_separable_conv2d0 -SeparableConvCorrectnessTest::test_separable_conv2d1 -SeparableConvCorrectnessTest::test_separable_conv2d2 -SimpleRNNTest::test_correctness -SimpleRNNTest::test_correctness -SimpleRNNTest::test_masking -SimpleRNNTest::test_masking -SimpleRNNTest::test_pass_initial_state -SimpleRNNTest::test_pass_initial_state -SimpleRNNTest::test_return_state -SimpleRNNTest::test_statefulness -SimpleRNNTest::test_statefulness -StackedRNNTest::test_correctness_single_state_stack -StackedRNNTest::test_correctness_two_states_stack -StackedRNNTest::test_statefullness_single_state_stack -StackedRNNTest::test_statefullness_two_states_stack -TestFitLRSchedulesFlow::test_fit_lr_correctness TestJaxLayer::test_flax_layer_training_independent_bound_method TestJaxLayer::test_flax_layer_training_rng_state_no_method TestJaxLayer::test_flax_layer_training_rng_unbound_method TestJaxLayer::test_flax_layer_training_rng_unbound_method_dtype_policy TestJaxLayer::test_jax_layer_training_independent TestJaxLayer::test_jax_layer_training_state -TestJaxLayer::test_jax_layer_training_state_dtype_policy -TestSpectrogram::test_spectrogram_error -TestTrainer::test_loss_weights -TestTrainer::test_nested_inputs -TestTrainer::test_on_batch_methods_eager -TestTrainer::test_on_batch_methods_graph_fn -TestTrainer::test_on_batch_methods_jit \ No newline at end of file +TestJaxLayer::test_jax_layer_training_state_dtype_policy \ No newline at end of file diff --git a/keras/src/layers/attention/additive_attention_test.py b/keras/src/layers/attention/additive_attention_test.py index 51092c6c4918..e46f018e9ccc 100644 --- a/keras/src/layers/attention/additive_attention_test.py +++ b/keras/src/layers/attention/additive_attention_test.py @@ -50,10 +50,16 @@ def test_attention_correctness(self): return_attention_scores=True, ) self.assertAllClose( - output, [[[1.727, 2.727], [2.272, 3.272]]], atol=1e-3 + output, + [[[1.727, 2.727], [2.272, 3.272]]], + atol=1e-3, + tpu_atol=1e-2, ) self.assertAllClose( - scores, [[[0.636, 0.363], [0.363, 0.636]]], atol=1e-3 + scores, + [[[0.636, 0.363], [0.363, 0.636]]], + atol=1e-3, + tpu_atol=1e-2, ) def test_attention_with_mask(self): diff --git a/keras/src/layers/attention/attention_test.py b/keras/src/layers/attention/attention_test.py index 805314010996..2dc5e44825ad 100644 --- a/keras/src/layers/attention/attention_test.py +++ b/keras/src/layers/attention/attention_test.py @@ -53,10 +53,18 @@ def test_attention_correctness(self): return_attention_scores=True, ) self.assertAllClose( - output, [[[2.462, 3.462], [1.538, 2.538]]], atol=1e-3 + output, + [[[2.462, 3.462], [1.538, 2.538]]], + atol=1e-3, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) self.assertAllClose( - scores, [[[0.269, 0.731], [0.731, 0.269]]], atol=1e-3 + scores, + [[[0.269, 0.731], [0.731, 0.269]]], + atol=1e-3, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) # Concat. @@ -66,10 +74,18 @@ def test_attention_correctness(self): return_attention_scores=True, ) self.assertAllClose( - output, [[[1.727, 2.727], [2.272, 3.272]]], atol=1e-3 + output, + [[[1.727, 2.727], [2.272, 3.272]]], + atol=1e-3, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) self.assertAllClose( - scores, [[[0.636, 0.363], [0.363, 0.636]]], atol=1e-3 + scores, + [[[0.636, 0.363], [0.363, 0.636]]], + atol=1e-3, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) def test_attention_with_mask(self): @@ -149,7 +165,9 @@ def test_attention_calculate_scores_with_scale(self): expected_scores = np.matmul(query, key.transpose((0, 2, 1))) expected_scores *= layer.scale.numpy() actual_scores = layer._calculate_scores(query, key) - self.assertAllClose(actual_scores, expected_scores) + self.assertAllClose( + actual_scores, expected_scores, tpu_atol=1e-2, tpu_rtol=1e-2 + ) def test_attention_calculate_score_mask_no_causal_no_vmask(self): scores = np.random.random((2, 3, 4)) diff --git a/keras/src/layers/convolutional/conv_test.py b/keras/src/layers/convolutional/conv_test.py index 36a91673c9fe..a80672fa4245 100644 --- a/keras/src/layers/convolutional/conv_test.py +++ b/keras/src/layers/convolutional/conv_test.py @@ -779,7 +779,12 @@ def test_enable_lora_with_alpha(self): # Compare the effective kernel computed via the property. actual_effective_kernel = ops.convert_to_numpy(layer.kernel) - self.assertAllClose(actual_effective_kernel, expected_effective_kernel) + self.assertAllClose( + actual_effective_kernel, + expected_effective_kernel, + tpu_atol=1e-3, + tpu_rtol=1e-3, + ) @pytest.mark.requires_trainable_backend def test_lora_rank_argument(self): @@ -891,7 +896,7 @@ def test_conv1d( dilation_rate=dilation_rate, groups=groups, ) - self.assertAllClose(outputs, expected) + self.assertAllClose(outputs, expected, tpu_atol=1e-1, tpu_rtol=1e-1) @parameterized.parameters( { @@ -989,7 +994,9 @@ def test_conv2d( dilation_rate=dilation_rate, groups=groups, ) - self.assertAllClose(outputs, expected, rtol=5e-4) + self.assertAllClose( + outputs, expected, rtol=5e-4, tpu_atol=1e-1, tpu_rtol=1e-1 + ) @parameterized.parameters( { @@ -1078,7 +1085,9 @@ def test_conv3d( dilation_rate=dilation_rate, groups=groups, ) - self.assertAllClose(outputs, expected, rtol=1e-3) + self.assertAllClose( + outputs, expected, rtol=1e-3, tpu_atol=1e-1, tpu_rtol=1e-1 + ) def test_conv_constraints(self): layer = layers.Conv2D( diff --git a/keras/src/layers/convolutional/conv_transpose_test.py b/keras/src/layers/convolutional/conv_transpose_test.py index 307a9bed9d17..b9bc16f63b0c 100644 --- a/keras/src/layers/convolutional/conv_transpose_test.py +++ b/keras/src/layers/convolutional/conv_transpose_test.py @@ -616,7 +616,9 @@ def test_conv1d_transpose( data_format, dilation_rate, ) - self.assertAllClose(outputs, expected, atol=1e-5) + self.assertAllClose( + outputs, expected, atol=1e-5, tpu_atol=1e-1, tpu_rtol=1e-1 + ) @parameterized.parameters( { @@ -696,7 +698,9 @@ def test_conv2d_transpose( data_format, dilation_rate, ) - self.assertAllClose(outputs, expected, atol=1e-5) + self.assertAllClose( + outputs, expected, atol=1e-5, tpu_atol=1e-1, tpu_rtol=1e-1 + ) @parameterized.parameters( { @@ -767,7 +771,9 @@ def test_conv3d_transpose( data_format, dilation_rate, ) - self.assertAllClose(outputs, expected, atol=1e-5) + self.assertAllClose( + outputs, expected, atol=1e-5, tpu_atol=1e-1, tpu_rtol=1e-1 + ) @parameterized.product( kernel_size=list(range(1, 5)), diff --git a/keras/src/layers/convolutional/depthwise_conv_test.py b/keras/src/layers/convolutional/depthwise_conv_test.py index a81dd69035b2..c5e44e122916 100644 --- a/keras/src/layers/convolutional/depthwise_conv_test.py +++ b/keras/src/layers/convolutional/depthwise_conv_test.py @@ -400,7 +400,7 @@ def test_depthwise_conv1d( data_format=data_format, dilation_rate=dilation_rate, ) - self.assertAllClose(outputs, expected) + self.assertAllClose(outputs, expected, tpu_atol=1e-2, tpu_rtol=1e-2) @parameterized.parameters( { @@ -465,5 +465,9 @@ def test_depthwise_conv2d( data_format=data_format, dilation_rate=dilation_rate, ) - self.assertAllClose(outputs.shape, expected.shape) - self.assertAllClose(outputs, expected, atol=1e-5) + self.assertAllClose( + outputs.shape, expected.shape, tpu_atol=1e-2, tpu_rtol=1e-2 + ) + self.assertAllClose( + outputs, expected, atol=1e-5, tpu_atol=1e-1, tpu_rtol=1e-1 + ) diff --git a/keras/src/layers/convolutional/separable_conv_test.py b/keras/src/layers/convolutional/separable_conv_test.py index a3e600ca4898..9d186f56e40d 100644 --- a/keras/src/layers/convolutional/separable_conv_test.py +++ b/keras/src/layers/convolutional/separable_conv_test.py @@ -294,7 +294,14 @@ def test_separable_conv1d( ) self.assertAllClose(outputs.shape, expected.shape) - self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5) + self.assertAllClose( + outputs, + expected, + rtol=1e-5, + atol=1e-5, + tpu_atol=1e-1, + tpu_rtol=1e-1, + ) @parameterized.parameters( { @@ -381,4 +388,11 @@ def test_separable_conv2d( ) self.assertAllClose(outputs.shape, expected.shape) - self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5) + self.assertAllClose( + outputs, + expected, + rtol=1e-5, + atol=1e-5, + tpu_atol=1e-1, + tpu_rtol=1e-1, + ) diff --git a/keras/src/layers/core/dense_test.py b/keras/src/layers/core/dense_test.py index 9cfbb166a30a..802ca10a1d41 100644 --- a/keras/src/layers/core/dense_test.py +++ b/keras/src/layers/core/dense_test.py @@ -141,7 +141,9 @@ def test_dense_sparse(self): ), layer.bias, ) - self.assertAllClose(outputs, expected_outputs) + self.assertAllClose( + outputs, expected_outputs, tpu_atol=1e-2, tpu_rtol=1e-2 + ) # Verify the gradient is sparse if backend.backend() == "tensorflow": diff --git a/keras/src/layers/core/einsum_dense_test.py b/keras/src/layers/core/einsum_dense_test.py index 51b7d6278f5c..92496f5f9d7a 100644 --- a/keras/src/layers/core/einsum_dense_test.py +++ b/keras/src/layers/core/einsum_dense_test.py @@ -404,7 +404,9 @@ def test_enable_lora_with_alpha(self): # Verify that the effective kernel property returns the expected value. actual_kernel = ops.convert_to_numpy(layer.kernel) - self.assertAllClose(actual_kernel, expected_kernel) + self.assertAllClose( + actual_kernel, expected_kernel, tpu_atol=1e-3, tpu_rtol=1e-3 + ) @pytest.mark.requires_trainable_backend def test_lora_rank_argument(self): diff --git a/keras/src/layers/core/embedding_test.py b/keras/src/layers/core/embedding_test.py index a22cab911caa..68b4ca1d9c15 100644 --- a/keras/src/layers/core/embedding_test.py +++ b/keras/src/layers/core/embedding_test.py @@ -239,7 +239,9 @@ def test_enable_lora_with_alpha(self): # Verify that the effective embeddings match expectation. actual_embeddings = ops.convert_to_numpy(layer.embeddings) - self.assertAllClose(actual_embeddings, expected_embeddings) + self.assertAllClose( + actual_embeddings, expected_embeddings, tpu_atol=1e-3, tpu_rtol=1e-3 + ) @pytest.mark.requires_trainable_backend def test_lora_rank_argument(self): diff --git a/keras/src/layers/merging/merging_test.py b/keras/src/layers/merging/merging_test.py index 977ad9c2cc1d..5f0f04cccf00 100644 --- a/keras/src/layers/merging/merging_test.py +++ b/keras/src/layers/merging/merging_test.py @@ -124,7 +124,7 @@ def test_correctness_static( res = model([x1, x2]) self.assertEqual(res.shape, expected_output_shape) - self.assertAllClose(res, x3, atol=1e-4) + self.assertAllClose(res, x3, atol=1e-4, tpu_atol=1e-2, tpu_rtol=1e-2) self.assertIsNone(layer.compute_mask([input_1, input_2], [None, None])) self.assertIsNone(layer.compute_mask([x1, x2], [None, None])) if not skip_mask_test: @@ -161,7 +161,7 @@ def test_correctness_dynamic( res = model([x1, x2]) self.assertEqual(res.shape, expected_output_shape) - self.assertAllClose(res, x3, atol=1e-4) + self.assertAllClose(res, x3, atol=1e-4, tpu_atol=1e-2, tpu_rtol=1e-2) self.assertIsNone(layer.compute_mask([input_1, input_2], [None, None])) if not skip_mask_test: self.assertTrue( diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur_test.py index 7b69d87d412a..9753b3a4af4f 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur_test.py @@ -76,7 +76,14 @@ def test_random_erasing_basic(self): } output = layer.transform_images(inputs, transformation) - self.assertAllClose(expected_output, output, atol=1e-4, rtol=1e-4) + self.assertAllClose( + expected_output, + output, + atol=1e-4, + rtol=1e-4, + tpu_atol=1e-2, + tpu_rtol=1e-2, + ) def test_tf_data_compatibility(self): data_format = backend.config.image_data_format() diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_perspective_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_perspective_test.py index b29c5a679132..d5fa06d6d0ec 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_perspective_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_perspective_test.py @@ -173,7 +173,12 @@ def test_random_perspective_bounding_boxes( ) self.assertAllClose( - output["boxes"], expected_boxes, atol=1e-3, rtol=1e-3 + output["boxes"], + expected_boxes, + atol=1e-3, + rtol=1e-3, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) @parameterized.named_parameters( diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_zoom_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_zoom_test.py index 96407e960c60..11575534663e 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_zoom_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_zoom_test.py @@ -63,6 +63,8 @@ def test_random_zoom_out_correctness(self): expected_output=expected_output, supports_masking=False, run_training_check=False, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) def test_random_zoom_in_correctness(self): diff --git a/keras/src/layers/preprocessing/normalization.py b/keras/src/layers/preprocessing/normalization.py index 838924782dda..83fc3a275bf4 100644 --- a/keras/src/layers/preprocessing/normalization.py +++ b/keras/src/layers/preprocessing/normalization.py @@ -6,8 +6,8 @@ from keras.src import ops from keras.src.api_export import keras_export from keras.src.layers.preprocessing.data_layer import DataLayer +from keras.src.trainers.data_adapters.py_dataset_adapter import PyDataset from keras.src.utils.module_utils import tensorflow as tf -from keras.utils import PyDataset @keras_export("keras.layers.Normalization") diff --git a/keras/src/layers/preprocessing/normalization_test.py b/keras/src/layers/preprocessing/normalization_test.py index f5409e2128e7..da76caa8b49a 100644 --- a/keras/src/layers/preprocessing/normalization_test.py +++ b/keras/src/layers/preprocessing/normalization_test.py @@ -6,6 +6,7 @@ from keras.src import backend from keras.src import layers from keras.src import testing +from keras.src.trainers.data_adapters.py_dataset_adapter import PyDataset class NormalizationTest(testing.TestCase): @@ -174,7 +175,7 @@ def test_normalization_with_scalar_mean_var(self): def test_adapt_pydataset_compat(self, pydataset_type): import keras - class CustomDataset(keras.utils.PyDataset): + class CustomDataset(PyDataset): def __len__(self): return 100 diff --git a/keras/src/layers/preprocessing/stft_spectrogram_test.py b/keras/src/layers/preprocessing/stft_spectrogram_test.py index a363393d776e..7c306599a380 100644 --- a/keras/src/layers/preprocessing/stft_spectrogram_test.py +++ b/keras/src/layers/preprocessing/stft_spectrogram_test.py @@ -319,7 +319,10 @@ def test_spectrogram_error(self): ]: init_args = dict(zip(names, args)) - tol_kwargs = {"atol": 5e-4, "rtol": 1e-6} + if testing.uses_tpu(): + tol_kwargs = {"atol": 5e-2, "rtol": 1e-3} + else: + tol_kwargs = {"atol": 5e-4, "rtol": 1e-6} init_args["mode"] = "magnitude" y_true, y = self._calc_spectrograms(x, **init_args) diff --git a/keras/src/layers/rnn/bidirectional_test.py b/keras/src/layers/rnn/bidirectional_test.py index aed4127c95ce..e4af38d909a1 100644 --- a/keras/src/layers/rnn/bidirectional_test.py +++ b/keras/src/layers/rnn/bidirectional_test.py @@ -52,6 +52,8 @@ def test_correctness(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.Bidirectional(layer=forward_layer, merge_mode="ave") @@ -59,6 +61,8 @@ def test_correctness(self): self.assertAllClose( np.array([[0.24845785, 0.24845785], [0.6288199, 0.6288199]]), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.Bidirectional(layer=forward_layer, merge_mode=None) @@ -66,10 +70,14 @@ def test_correctness(self): self.assertAllClose( np.array([[0.39687276, 0.39687276], [0.7237238, 0.7237238]]), output1, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) self.assertAllClose( np.array([[0.10004295, 0.10004295], [0.53391594, 0.53391594]]), output2, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) backward_layer = layers.SimpleRNN( @@ -86,6 +94,8 @@ def test_correctness(self): self.assertAllClose( np.array([[0.08374989, 0.08374989], [0.6740834, 0.6740834]]), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) forward_layer = layers.GRU( @@ -113,6 +123,8 @@ def test_correctness(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) def test_statefulness(self): @@ -135,6 +147,8 @@ def test_statefulness(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer.reset_state() layer(sequence) @@ -147,6 +161,8 @@ def test_statefulness(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) def test_pass_initial_state(self): @@ -175,6 +191,8 @@ def test_pass_initial_state(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) def test_masking(self): @@ -196,6 +214,8 @@ def test_masking(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) def test_return_state(self): @@ -217,22 +237,32 @@ def test_return_state(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) self.assertAllClose( np.array([[0.1990008, 0.1990008], [0.52335435, 0.52335435]]), h1, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) self.assertAllClose( np.array([[0.35567185, 0.35567185], [1.0492687, 1.0492687]]), c1, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) self.assertAllClose( np.array([[0.12659755, 0.12659755], [0.44717982, 0.44717982]]), h2, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) self.assertAllClose( np.array([[0.2501858, 0.2501858], [0.941473, 0.941473]]), c2, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) @pytest.mark.requires_trainable_backend diff --git a/keras/src/layers/rnn/conv_lstm1d_test.py b/keras/src/layers/rnn/conv_lstm1d_test.py index b69cbf8b55aa..7024f9908ca1 100644 --- a/keras/src/layers/rnn/conv_lstm1d_test.py +++ b/keras/src/layers/rnn/conv_lstm1d_test.py @@ -74,4 +74,6 @@ def test_correctness(self): self.assertAllClose( expected_output, output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) diff --git a/keras/src/layers/rnn/conv_lstm2d_test.py b/keras/src/layers/rnn/conv_lstm2d_test.py index b3846b64058c..d70ee74646f5 100644 --- a/keras/src/layers/rnn/conv_lstm2d_test.py +++ b/keras/src/layers/rnn/conv_lstm2d_test.py @@ -86,4 +86,6 @@ def test_correctness(self): self.assertAllClose( expected_output, output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) diff --git a/keras/src/layers/rnn/conv_lstm3d_test.py b/keras/src/layers/rnn/conv_lstm3d_test.py index b6c23326539f..b2cde495cecd 100644 --- a/keras/src/layers/rnn/conv_lstm3d_test.py +++ b/keras/src/layers/rnn/conv_lstm3d_test.py @@ -104,4 +104,6 @@ def test_correctness(self): self.assertAllClose( expected_output, output, + tpu_atol=1e-4, + tpu_rtol=1e-4, ) diff --git a/keras/src/layers/rnn/conv_lstm_test.py b/keras/src/layers/rnn/conv_lstm_test.py index e66fed91b62c..1ed49c2ffde5 100644 --- a/keras/src/layers/rnn/conv_lstm_test.py +++ b/keras/src/layers/rnn/conv_lstm_test.py @@ -27,11 +27,11 @@ def test_correctness(self): ) output = layer(x, [s1, s2]) checksum_0 = np.sum(backend.convert_to_numpy(output[0])) - self.assertAllClose(checksum_0, 188.89502) + self.assertAllClose(checksum_0, 188.89502, tpu_atol=1e-4, tpu_rtol=1e-4) checksum_1 = np.sum(backend.convert_to_numpy(output[1][0])) - self.assertAllClose(checksum_1, 188.89502) + self.assertAllClose(checksum_1, 188.89502, tpu_atol=1e-4, tpu_rtol=1e-4) checksum_2 = np.sum(backend.convert_to_numpy(output[1][1])) - self.assertAllClose(checksum_2, 2170.444) + self.assertAllClose(checksum_2, 2170.444, tpu_atol=1e-4, tpu_rtol=1e-4) class ConvLSTMTest(testing.TestCase): @@ -54,4 +54,6 @@ def test_correctness(self): ) output = layer(x, initial_state=[s1, s2]) output = backend.convert_to_numpy(output) - self.assertAllClose(np.sum(output), 119.812454) + self.assertAllClose( + np.sum(output), 119.812454, tpu_atol=1e-3, tpu_rtol=1e-3 + ) diff --git a/keras/src/layers/rnn/gru_test.py b/keras/src/layers/rnn/gru_test.py index 7fc0d6c35b7e..71cde175c0b2 100644 --- a/keras/src/layers/rnn/gru_test.py +++ b/keras/src/layers/rnn/gru_test.py @@ -66,6 +66,8 @@ def test_correctness(self, implementation): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.GRU( @@ -85,6 +87,8 @@ def test_correctness(self, implementation): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.GRU( @@ -104,6 +108,8 @@ def test_correctness(self, implementation): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.GRU( @@ -123,6 +129,8 @@ def test_correctness(self, implementation): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.GRU( @@ -142,6 +150,8 @@ def test_correctness(self, implementation): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) def test_statefulness(self): @@ -163,6 +173,8 @@ def test_statefulness(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer.reset_state() layer(sequence) @@ -175,6 +187,8 @@ def test_statefulness(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) def test_pass_initial_state(self): @@ -190,6 +204,8 @@ def test_pass_initial_state(self): self.assertAllClose( np.array([[0.23774096, 0.33508456], [0.83659905, 1.0227708]]), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.GRU( @@ -203,6 +219,8 @@ def test_pass_initial_state(self): self.assertAllClose( np.array([[0.13486053, 0.23261218], [0.78257304, 0.9691353]]), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) def test_pass_return_state(self): @@ -221,8 +239,15 @@ def test_pass_return_state(self): self.assertAllClose( np.array([[0.23774096, 0.33508456], [0.83659905, 1.0227708]]), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, + ) + self.assertAllClose( + output, + state, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) - self.assertAllClose(output, state) # Test with go_backwards=True layer = layers.GRU( @@ -237,8 +262,15 @@ def test_pass_return_state(self): self.assertAllClose( np.array([[0.13486053, 0.23261218], [0.78257304, 0.9691353]]), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, + ) + self.assertAllClose( + output, + state, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) - self.assertAllClose(output, state) def test_masking(self): sequence = np.arange(24).reshape((2, 4, 3)).astype("float32") @@ -254,6 +286,8 @@ def test_masking(self): self.assertAllClose( np.array([[0.19393763, 0.19393763], [0.30818558, 0.30818558]]), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.GRU( @@ -274,6 +308,8 @@ def test_masking(self): ], ), output[0], + tpu_atol=1e-3, + tpu_rtol=1e-3, ) self.assertAllClose( np.array( @@ -285,6 +321,8 @@ def test_masking(self): ], ), output[1], + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.GRU( @@ -306,6 +344,8 @@ def test_masking(self): ], ), output[0], + tpu_atol=1e-3, + tpu_rtol=1e-3, ) self.assertAllClose( np.array( @@ -317,6 +357,8 @@ def test_masking(self): ], ), output[1], + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.GRU( @@ -330,6 +372,8 @@ def test_masking(self): self.assertAllClose( np.array([[0.11669192, 0.11669192], [0.28380975, 0.28380975]]), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) def test_legacy_implementation_argument(self): @@ -353,4 +397,6 @@ def test_legacy_implementation_argument(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) diff --git a/keras/src/layers/rnn/lstm_test.py b/keras/src/layers/rnn/lstm_test.py index 0486c196e4fc..c538c563bcfd 100644 --- a/keras/src/layers/rnn/lstm_test.py +++ b/keras/src/layers/rnn/lstm_test.py @@ -67,6 +67,8 @@ def test_correctness(self, implementation): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.LSTM( @@ -87,6 +89,8 @@ def test_correctness(self, implementation): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.LSTM( @@ -107,6 +111,8 @@ def test_correctness(self, implementation): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.LSTM( @@ -127,6 +133,8 @@ def test_correctness(self, implementation): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.LSTM( @@ -147,6 +155,8 @@ def test_correctness(self, implementation): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) def test_statefulness(self): @@ -168,6 +178,8 @@ def test_statefulness(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer.reset_state() layer(sequence) @@ -180,6 +192,8 @@ def test_statefulness(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) def test_pass_initial_state(self): @@ -198,6 +212,8 @@ def test_pass_initial_state(self): self.assertAllClose( np.array([[0.20574439, 0.3558822], [0.64930826, 0.66276]]), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.LSTM( @@ -211,6 +227,8 @@ def test_pass_initial_state(self): self.assertAllClose( np.array([[0.13281618, 0.2790356], [0.5839337, 0.5992567]]), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) def test_masking(self): @@ -227,6 +245,8 @@ def test_masking(self): self.assertAllClose( np.array([[0.1524914, 0.1524914], [0.35969394, 0.35969394]]), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.LSTM( @@ -247,6 +267,8 @@ def test_masking(self): ], ), output[0], + tpu_atol=1e-3, + tpu_rtol=1e-3, ) self.assertAllClose( np.array( @@ -258,6 +280,8 @@ def test_masking(self): ], ), output[1], + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.LSTM( @@ -279,6 +303,8 @@ def test_masking(self): ], ), output[0], + tpu_atol=1e-3, + tpu_rtol=1e-3, ) self.assertAllClose( np.array( @@ -290,6 +316,8 @@ def test_masking(self): ], ), output[1], + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.LSTM( @@ -303,4 +331,6 @@ def test_masking(self): self.assertAllClose( np.array([[0.10056866, 0.10056866], [0.31006062, 0.31006062]]), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) diff --git a/keras/src/layers/rnn/rnn_test.py b/keras/src/layers/rnn/rnn_test.py index 6e6a52a5c37a..c2f91e3dc3b8 100644 --- a/keras/src/layers/rnn/rnn_test.py +++ b/keras/src/layers/rnn/rnn_test.py @@ -364,15 +364,28 @@ def test_go_backwards(self): layer = layers.RNN(OneStateRNNCell(2), go_backwards=True) layer(sequence) output = layer(sequence) - self.assertAllClose(np.array([[202.0, 202.0], [538.0, 538.0]]), output) + self.assertAllClose( + np.array([[202.0, 202.0], [538.0, 538.0]]), + output, + tpu_atol=1e-4, + tpu_rtol=1e-4, + ) layer = layers.RNN(OneStateRNNCell(2), stateful=True, return_state=True) layer(sequence) output, state = layer(sequence) self.assertAllClose( - np.array([[954.0, 954.0], [3978.0, 3978.0]]), output + np.array([[954.0, 954.0], [3978.0, 3978.0]]), + output, + tpu_atol=1e-2, + tpu_rtol=1e-2, + ) + self.assertAllClose( + np.array([[954.0, 954.0], [3978.0, 3978.0]]), + state, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) - self.assertAllClose(np.array([[954.0, 954.0], [3978.0, 3978.0]]), state) def test_serialization(self): layer = layers.RNN(TwoStatesRNNCell(2), return_sequences=False) diff --git a/keras/src/layers/rnn/simple_rnn_test.py b/keras/src/layers/rnn/simple_rnn_test.py index 8493bdbee8a8..3f466b860d5a 100644 --- a/keras/src/layers/rnn/simple_rnn_test.py +++ b/keras/src/layers/rnn/simple_rnn_test.py @@ -54,6 +54,8 @@ def test_correctness(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.SimpleRNN( 4, @@ -71,6 +73,8 @@ def test_correctness(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.SimpleRNN( @@ -89,6 +93,8 @@ def test_correctness(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.SimpleRNN( 4, @@ -107,6 +113,8 @@ def test_correctness(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) def test_statefulness(self): @@ -128,6 +136,8 @@ def test_statefulness(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer.reset_state() layer(sequence) @@ -140,6 +150,8 @@ def test_statefulness(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) def test_pass_initial_state(self): @@ -160,6 +172,8 @@ def test_pass_initial_state(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.SimpleRNN( @@ -178,6 +192,8 @@ def test_pass_initial_state(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) def test_masking(self): @@ -199,6 +215,8 @@ def test_masking(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.SimpleRNN( @@ -219,6 +237,8 @@ def test_masking(self): ], ), output[0], + tpu_atol=1e-3, + tpu_rtol=1e-3, ) self.assertAllClose( np.array( @@ -230,6 +250,8 @@ def test_masking(self): ], ), output[1], + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.SimpleRNN( @@ -251,6 +273,8 @@ def test_masking(self): ], ), output[0], + tpu_atol=1e-3, + tpu_rtol=1e-3, ) self.assertAllClose( np.array( @@ -262,6 +286,8 @@ def test_masking(self): ], ), output[1], + tpu_atol=1e-3, + tpu_rtol=1e-3, ) layer = layers.SimpleRNN( @@ -280,4 +306,6 @@ def test_masking(self): ] ), output, + tpu_atol=1e-3, + tpu_rtol=1e-3, ) diff --git a/keras/src/layers/rnn/stacked_rnn_cells_test.py b/keras/src/layers/rnn/stacked_rnn_cells_test.py index 1b87b177f64b..9e0cc620b356 100644 --- a/keras/src/layers/rnn/stacked_rnn_cells_test.py +++ b/keras/src/layers/rnn/stacked_rnn_cells_test.py @@ -138,7 +138,10 @@ def test_correctness_single_state_stack(self): layer = layers.RNN([OneStateRNNCell(3), OneStateRNNCell(2)]) output = layer(sequence) self.assertAllClose( - np.array([[786.0, 786.0], [4386.0, 4386.0]]), output + np.array([[786.0, 786.0], [4386.0, 4386.0]]), + output, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) layer = layers.RNN( @@ -153,6 +156,8 @@ def test_correctness_single_state_stack(self): ] ), output, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) layer = layers.RNN( @@ -160,13 +165,19 @@ def test_correctness_single_state_stack(self): ) output, state_1, state_2 = layer(sequence) self.assertAllClose( - np.array([[786.0, 786.0], [4386.0, 4386.0]]), output + np.array([[786.0, 786.0], [4386.0, 4386.0]]), + output, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) self.assertAllClose( np.array([[158.0, 158.0, 158.0], [782.0, 782.0, 782.0]]), state_1 ) self.assertAllClose( - np.array([[786.0, 786.0], [4386.0, 4386.0]]), state_2 + np.array([[786.0, 786.0], [4386.0, 4386.0]]), + state_2, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) layer = layers.RNN( @@ -183,12 +194,20 @@ def test_correctness_single_state_stack(self): ] ), output, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) self.assertAllClose( - np.array([[158.0, 158.0, 158.0], [782.0, 782.0, 782.0]]), state_1 + np.array([[158.0, 158.0, 158.0], [782.0, 782.0, 782.0]]), + state_1, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) self.assertAllClose( - np.array([[786.0, 786.0], [4386.0, 4386.0]]), state_2 + np.array([[786.0, 786.0], [4386.0, 4386.0]]), + state_2, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) def test_correctness_two_states_stack(self): @@ -196,7 +215,10 @@ def test_correctness_two_states_stack(self): layer = layers.RNN([TwoStatesRNNCell(3), TwoStatesRNNCell(2)]) output = layer(sequence) self.assertAllClose( - np.array([[3144.0, 3144.0], [17544.0, 17544.0]]), output + np.array([[3144.0, 3144.0], [17544.0, 17544.0]]), + output, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) layer = layers.RNN( @@ -211,6 +233,8 @@ def test_correctness_two_states_stack(self): ] ), output, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) layer = layers.RNN( @@ -219,19 +243,34 @@ def test_correctness_two_states_stack(self): output, state_1, state_2 = layer(sequence) self.assertAllClose( - np.array([[3144.0, 3144.0], [17544.0, 17544.0]]), output + np.array([[3144.0, 3144.0], [17544.0, 17544.0]]), + output, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) self.assertAllClose( - np.array([[158.0, 158.0, 158.0], [782.0, 782.0, 782.0]]), state_1[0] + np.array([[158.0, 158.0, 158.0], [782.0, 782.0, 782.0]]), + state_1[0], + tpu_atol=1e-2, + tpu_rtol=1e-2, ) self.assertAllClose( - np.array([[158.0, 158.0, 158.0], [782.0, 782.0, 782.0]]), state_1[1] + np.array([[158.0, 158.0, 158.0], [782.0, 782.0, 782.0]]), + state_1[1], + tpu_atol=1e-2, + tpu_rtol=1e-2, ) self.assertAllClose( - np.array([[1572.0, 1572.0], [8772.0, 8772.0]]), state_2[0] + np.array([[1572.0, 1572.0], [8772.0, 8772.0]]), + state_2[0], + tpu_atol=1e-2, + tpu_rtol=1e-2, ) self.assertAllClose( - np.array([[1572.0, 1572.0], [8772.0, 8772.0]]), state_2[1] + np.array([[1572.0, 1572.0], [8772.0, 8772.0]]), + state_2[1], + tpu_atol=1e-2, + tpu_rtol=1e-2, ) def test_statefullness_single_state_stack(self): @@ -242,7 +281,10 @@ def test_statefullness_single_state_stack(self): layer(sequence) output = layer(sequence) self.assertAllClose( - np.array([[34092.0, 34092.0], [173196.0, 173196.0]]), output + np.array([[34092.0, 34092.0], [173196.0, 173196.0]]), + output, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) def test_statefullness_two_states_stack(self): @@ -253,7 +295,10 @@ def test_statefullness_two_states_stack(self): layer(sequence) output = layer(sequence) self.assertAllClose( - np.array([[136368.0, 136368.0], [692784.0, 692784.0]]), output + np.array([[136368.0, 136368.0], [692784.0, 692784.0]]), + output, + tpu_atol=1e-2, + tpu_rtol=1e-2, ) def test_return_state_stacked_lstm_cell(self): diff --git a/keras/src/losses/losses_test.py b/keras/src/losses/losses_test.py index fe0d557d96c9..160a44e60274 100644 --- a/keras/src/losses/losses_test.py +++ b/keras/src/losses/losses_test.py @@ -1851,7 +1851,7 @@ def test_correctness(self): logits = (np.arange(24).reshape((2, 4, 3)).astype("float32") - 12) / 100 y_true = np.array(([[1, 2, 1, 0], [1, 2, 0, 2]])) output = losses.CTC()(y_true, logits) - self.assertAllClose(output, 2.448645) + self.assertAllClose(output, 2.448645, tpu_atol=1e-3, tpu_rtol=1e-3) def test_dtype_arg(self): logits = (np.arange(24).reshape((2, 4, 3)).astype("float32") - 12) / 100 @@ -1964,6 +1964,7 @@ def test_dtype_arg(self): class CircleTest(testing.TestCase): def setup(self): + super().setUp() self.y_true = np.array([1, 1, 2, 2, 3]) self.y_pred = np.array( [ @@ -1995,11 +1996,11 @@ def test_correctness(self): self.setup() circle_loss = losses.Circle(gamma=80.0, margin=0.4) loss = circle_loss(self.y_true, self.y_pred) - self.assertAlmostEqual(loss, 188.3883) + self.assertAlmostEqual(loss, 188.3883, tpu_decimal=0) circle_loss = losses.Circle(gamma=256, margin=0.25) loss = circle_loss(self.y_true, self.y_pred) - self.assertAlmostEqual(loss, 652.7617) + self.assertAlmostEqual(loss, 652.7617, tpu_decimal=0) loss = losses.circle( self.y_true, @@ -2012,7 +2013,10 @@ def test_correctness(self): ) self.assertAllClose( - loss, (61.5844, 94.3465, 276.9344, 90.9873, 48.8963) + loss, + (61.5844, 94.3465, 276.9344, 90.9873, 48.8963), + tpu_atol=1e-2, + tpu_rtol=1e-2, ) def test_correctness_weighted(self): @@ -2022,7 +2026,7 @@ def test_correctness_weighted(self): loss = circle_loss( self.y_true, self.y_pred, sample_weight=sample_weight ) - self.assertAlmostEqual(loss, 244.91918) + self.assertAlmostEqual(loss, 244.91918, tpu_decimal=0) def test_no_reduction(self): self.setup() @@ -2030,7 +2034,10 @@ def test_no_reduction(self): loss = circle_loss(self.ref_labels, self.ref_embeddings) self.assertAllClose( - loss, [82.9116, 36.7942, 92.4590, 52.6798, 0.0, 0.0] + loss, + [82.9116, 36.7942, 92.4590, 52.6798, 0.0, 0.0], + tpu_atol=1e-2, + tpu_rtol=1e-2, ) def test_sum_reduction(self): @@ -2038,7 +2045,7 @@ def test_sum_reduction(self): circle_loss = losses.Circle(gamma=80.0, margin=0.4, reduction="sum") loss = circle_loss(self.ref_labels, self.ref_embeddings) - self.assertAlmostEqual(loss, 264.845) + self.assertAlmostEqual(loss, 264.845, tpu_decimal=0) def test_mean_with_sample_weight_reduction(self): self.setup() @@ -2049,7 +2056,7 @@ def test_mean_with_sample_weight_reduction(self): loss = circle_loss( self.y_true, self.y_pred, sample_weight=sample_weight ) - self.assertAlmostEqual(loss, 163.27948) + self.assertAlmostEqual(loss, 163.27948, tpu_decimal=0) def test_dtype_arg(self): self.setup() diff --git a/keras/src/ops/image_test.py b/keras/src/ops/image_test.py index a54e4aeb3120..39360273db7d 100644 --- a/keras/src/ops/image_test.py +++ b/keras/src/ops/image_test.py @@ -1431,7 +1431,7 @@ def test_affine_transform(self, interpolation, fill_mode): fill_mode=fill_mode, ) self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) - self.assertAllClose(ref_out, out, atol=1e-2) + self.assertAllClose(ref_out, out, atol=1e-2, tpu_atol=10, tpu_rtol=10) x = np.random.uniform(size=(2, 50, 50, 3)).astype("float32") * 255 transform = np.random.uniform(size=(2, 6)).astype("float32") @@ -1456,7 +1456,7 @@ def test_affine_transform(self, interpolation, fill_mode): axis=0, ) self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) - self.assertAllClose(ref_out, out, atol=1e-2) + self.assertAllClose(ref_out, out, atol=1e-2, tpu_atol=10, tpu_rtol=10) # Test channels_first backend.set_image_data_format("channels_first") @@ -1477,7 +1477,7 @@ def test_affine_transform(self, interpolation, fill_mode): ) ref_out = np.transpose(ref_out, [2, 0, 1]) self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) - self.assertAllClose(ref_out, out, atol=1e-2) + self.assertAllClose(ref_out, out, atol=1e-2, tpu_atol=1, tpu_rtol=1) x = np.random.uniform(size=(2, 3, 50, 50)).astype("float32") * 255 transform = np.random.uniform(size=(2, 6)).astype("float32") @@ -1505,13 +1505,13 @@ def test_affine_transform(self, interpolation, fill_mode): ) ref_out = np.transpose(ref_out, [0, 3, 1, 2]) self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) - self.assertAllClose(ref_out, out, atol=1e-2) + self.assertAllClose(ref_out, out, atol=1e-2, tpu_atol=10, tpu_rtol=10) # Test class out = kimage.AffineTransform( interpolation=interpolation, fill_mode=fill_mode )(x, transform) - self.assertAllClose(ref_out, out, atol=1e-2) + self.assertAllClose(ref_out, out, atol=1e-2, tpu_atol=10, tpu_rtol=10) @parameterized.named_parameters( named_product( diff --git a/keras/src/ops/linalg_test.py b/keras/src/ops/linalg_test.py index 0be61d5bb7f9..31f30f3ba5a4 100644 --- a/keras/src/ops/linalg_test.py +++ b/keras/src/ops/linalg_test.py @@ -399,7 +399,13 @@ def test_cholesky_inverse(self, upper): ) output_inverse = linalg.cholesky_inverse(factor, upper=upper) - self.assertAllClose(output_inverse, expected_inverse, atol=1e-5) + self.assertAllClose( + output_inverse, + expected_inverse, + atol=1e-5, + tpu_atol=1e-2, + tpu_rtol=1e-2, + ) def test_det(self): x = np.random.rand(4, 3, 3) @@ -411,6 +417,8 @@ def test_det(self): linalg.det(x) def test_eig(self): + if testing.uses_tpu(): + self.skipTest("Skipping test with JAX + TPU as it's not supported") x = np.random.rand(2, 3, 3) x = x @ x.transpose((0, 2, 1)) w, v = map(ops.convert_to_numpy, linalg.eig(x)) @@ -586,11 +594,15 @@ def test_svd(self): ..., : s.shape[-1], : ] # High tolerance due to numerical instability - self.assertAllClose(x_reconstructed, x, atol=1e-3) + self.assertAllClose( + x_reconstructed, x, atol=1e-3, tpu_atol=1e-2, tpu_rtol=1e-2 + ) # Test `compute_uv=False` s_no_uv = linalg.svd(x, compute_uv=False) - self.assertAllClose(s_no_uv, s, atol=1e-5, rtol=1e-5) + self.assertAllClose( + s_no_uv, s, atol=1e-5, rtol=1e-5, tpu_atol=1e-2, tpu_rtol=1e-2 + ) @parameterized.named_parameters( ("b_rank_1", 1, None), @@ -608,7 +620,9 @@ def test_lstsq(self, b_rank, rcond): b_symb = backend.KerasTensor((5, 4)) out = linalg.lstsq(a, b, rcond=rcond) ref_out = np.linalg.lstsq(a, b, rcond=rcond)[0] - self.assertAllClose(out, ref_out, atol=1e-5) + self.assertAllClose( + out, ref_out, atol=1e-5, tpu_atol=1e-4, tpu_rtol=1e-4 + ) out_symb = linalg.lstsq(a_symb, b_symb) self.assertEqual(out_symb.shape, out.shape) diff --git a/keras/src/ops/math_test.py b/keras/src/ops/math_test.py index bd5b17290f27..407d7803abe5 100644 --- a/keras/src/ops/math_test.py +++ b/keras/src/ops/math_test.py @@ -692,7 +692,7 @@ def test_extract_sequences(self): for i in range(num_sequences): expected[i] = x[pos : pos + sequence_length] pos += sequence_stride - self.assertAllClose(output, expected) + self.assertAllClose(output, expected, tpu_atol=1e-2, tpu_rtol=1e-2) # Test N-D case. x = np.random.random((4, 8)) @@ -706,7 +706,7 @@ def test_extract_sequences(self): for i in range(num_sequences): expected[:, i] = x[:, pos : pos + sequence_length] pos += sequence_stride - self.assertAllClose(output, expected) + self.assertAllClose(output, expected, tpu_atol=1e-2, tpu_rtol=1e-2) def test_fft(self): real = np.random.random((2, 4, 3)) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index f4718c495337..a09b8d068705 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -1715,7 +1715,7 @@ def test_conv_2d(self, strides, padding): dilation_rate=1, groups=1, ) - self.assertAllClose(outputs, expected) + self.assertAllClose(outputs, expected, tpu_atol=1e-2, tpu_rtol=1e-2) @parameterized.product(strides=(1, 2), dilation_rate=(1, (2, 1))) def test_conv_2d_group_2(self, strides, dilation_rate): @@ -1777,7 +1777,14 @@ def test_conv_3d(self, strides, padding, data_format): dilation_rate=1, groups=1, ) - self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5) + self.assertAllClose( + outputs, + expected, + rtol=1e-5, + atol=1e-5, + tpu_atol=1e-2, + tpu_rtol=1e-2, + ) # Test for tracing error on tensorflow backend. if backend.backend() == "tensorflow": @@ -1790,7 +1797,14 @@ def conv(x): ) outputs = conv(inputs_3d) - self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5) + self.assertAllClose( + outputs, + expected, + rtol=1e-5, + atol=1e-5, + tpu_atol=1e-2, + tpu_rtol=1e-2, + ) @parameterized.product( strides=(1, (1, 1), (2, 2)), @@ -1829,7 +1843,7 @@ def test_depthwise_conv_2d(self, strides, padding, dilation_rate): data_format=backend.config.image_data_format(), dilation_rate=dilation_rate, ) - self.assertAllClose(outputs, expected) + self.assertAllClose(outputs, expected, tpu_atol=1e-2, tpu_rtol=1e-2) @parameterized.product( strides=(1, 2), @@ -1881,7 +1895,7 @@ def test_separable_conv_2d(self, strides, padding, dilation_rate): dilation_rate=dilation_rate, groups=1, ) - self.assertAllClose(outputs, expected) + self.assertAllClose(outputs, expected, tpu_atol=1e-2, tpu_rtol=1e-2) @parameterized.product(padding=("valid", "same")) def test_conv_transpose_1d(self, padding): @@ -2279,7 +2293,12 @@ def test_ctc_loss(self): output_length = np.array([3, 2]) result = knn.ctc_loss(labels, outputs, label_length, output_length) - self.assertAllClose(result, np.array([3.4411672, 1.91680186])) + self.assertAllClose( + result, + np.array([3.4411672, 1.91680186]), + tpu_atol=1e-2, + tpu_rtol=1e-2, + ) def test_ctc_decode(self): inputs = np.array( diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 0bc35c36c9a4..0082c079feaa 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -1259,7 +1259,7 @@ def test_argmax(self): self.assertEqual(knp.argmax(x, keepdims=True).shape, (None, 3, 3)) @pytest.mark.skipif( - keras.config.backend() == "openvino", + keras.config.backend() == "openvino" or testing.uses_tpu(), reason="OpenVINO doesn't support this change", ) def test_argmax_negative_zero(self): @@ -1270,7 +1270,8 @@ def test_argmax_negative_zero(self): @pytest.mark.skipif( keras.config.backend() == "openvino" - or keras.config.backend() == "tensorflow", + or keras.config.backend() == "tensorflow" + or testing.uses_tpu(), reason=""" OpenVINO and TensorFlow don't support this change, TensorFlow behavior for this case is under @@ -2647,7 +2648,14 @@ def test_matmul_sparse(self, dtype, x_shape, y_shape, x_sparse, y_sparse): y = dense_to_sparse(y_np) atol = 0.1 if dtype == "float16" else 1e-4 - self.assertAllClose(knp.matmul(x, y), np.matmul(x_np, y_np), atol=atol) + tpu_atol = 1 if dtype == "float16" else 1e-1 + self.assertAllClose( + knp.matmul(x, y), + np.matmul(x_np, y_np), + atol=atol, + tpu_atol=tpu_atol, + tpu_rtol=tpu_atol, + ) self.assertSparse(knp.matmul(x, y), x_sparse and y_sparse) def test_power(self): @@ -3242,12 +3250,22 @@ def test_logical_or(self): self.assertAllClose(knp.LogicalOr()(True, x), np.logical_or(True, x)) def test_logspace(self): - self.assertAllClose(knp.logspace(0, 10, 5), np.logspace(0, 10, 5)) + self.assertAllClose( + knp.logspace(0, 10, 5), + np.logspace(0, 10, 5), + tpu_atol=1e-4, + tpu_rtol=1e-4, + ) self.assertAllClose( knp.logspace(0, 10, 5, endpoint=False), np.logspace(0, 10, 5, endpoint=False), ) - self.assertAllClose(knp.Logspace(num=5)(0, 10), np.logspace(0, 10, 5)) + self.assertAllClose( + knp.Logspace(num=5)(0, 10), + np.logspace(0, 10, 5), + tpu_atol=1e-4, + tpu_rtol=1e-4, + ) self.assertAllClose( knp.Logspace(num=5, endpoint=False)(0, 10), np.logspace(0, 10, 5, endpoint=False), diff --git a/keras/src/optimizers/muon_test.py b/keras/src/optimizers/muon_test.py index 9ec85d8985ce..bd2de6f54de8 100644 --- a/keras/src/optimizers/muon_test.py +++ b/keras/src/optimizers/muon_test.py @@ -21,7 +21,14 @@ def test_Newton_Schulz(self): tensor_input = ops.array([[0.2499, 0.9105], [0.2655, 0.8824]]) except_output = ops.array([[-0.4422, 0.6457], [0.7285, 0.2968]]) output = optimizer.zeropower_via_newtonschulz5(tensor_input, 5) - self.assertAllClose(output, except_output, rtol=1e-3, atol=1e-3) + self.assertAllClose( + output, + except_output, + rtol=1e-3, + atol=1e-3, + tpu_atol=1e-1, + tpu_rtol=1e-1, + ) def test_adamw_single_step(self): optimizer = Muon() diff --git a/keras/src/optimizers/schedules/learning_rate_schedule_test.py b/keras/src/optimizers/schedules/learning_rate_schedule_test.py index 052db9e93945..72d17a4aac49 100644 --- a/keras/src/optimizers/schedules/learning_rate_schedule_test.py +++ b/keras/src/optimizers/schedules/learning_rate_schedule_test.py @@ -40,6 +40,8 @@ def test_fit_lr_correctness(self): history.history["loss"], [230.79457092285156, 128.30319213867188, 79.33648681640625], rtol=5e-5, + tpu_atol=5e-3, + tpu_rtol=5e-3, ) diff --git a/keras/src/regularizers/regularizers_test.py b/keras/src/regularizers/regularizers_test.py index 36141f54f772..3f61a8f0268b 100644 --- a/keras/src/regularizers/regularizers_test.py +++ b/keras/src/regularizers/regularizers_test.py @@ -55,6 +55,8 @@ def test_orthogonal_regularizer(self): np.abs(np.dot(inputs, np.transpose(inputs)) * (1.0 - np.eye(4))) ) / (4.0 * (4.0 - 1.0) / 2.0), + tpu_atol=1e-4, + tpu_rtol=1e-4, ) def test_get_method(self): diff --git a/keras/src/testing/__init__.py b/keras/src/testing/__init__.py index ae554ff85857..9a221dd141c3 100644 --- a/keras/src/testing/__init__.py +++ b/keras/src/testing/__init__.py @@ -3,3 +3,4 @@ from keras.src.testing.test_case import tensorflow_uses_gpu from keras.src.testing.test_case import torch_uses_gpu from keras.src.testing.test_case import uses_gpu +from keras.src.testing.test_case import uses_tpu diff --git a/keras/src/testing/test_case.py b/keras/src/testing/test_case.py index 1b7ceddfdb78..bd7ba58df115 100644 --- a/keras/src/testing/test_case.py +++ b/keras/src/testing/test_case.py @@ -34,13 +34,49 @@ def setUp(self): clear_session(free_memory=False) if traceback_utils.is_traceback_filtering_enabled(): traceback_utils.disable_traceback_filtering() + self.on_tpu = False + if backend.backend() == "jax": + import jax + + available_devices = jax.devices() + self.on_tpu = any( + d.platform.lower() == "tpu" for d in available_devices + ) + jax.clear_caches() + elif backend.backend() == "tensorflow": + import tensorflow as tf + + try: + resolver = tf.distribute.cluster_resolver.TPUClusterResolver() + tf.config.experimental_connect_to_cluster(resolver) + tf.tpu.experimental.initialize_tpu_system(resolver) + devices = tf.config.list_logical_devices() + tpu_devices = [d for d in devices if "TPU" in d.device_type] + if len(tpu_devices) > 0: + self.on_tpu = True + except (ValueError, RuntimeError): + # No TPU found or initialization failed. + pass def get_temp_dir(self): temp_dir = tempfile.mkdtemp() self.addCleanup(lambda: shutil.rmtree(temp_dir)) return temp_dir - def assertAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None): + def assertAllClose( + self, + x1, + x2, + atol=1e-6, + rtol=1e-6, + tpu_atol=None, + tpu_rtol=None, + msg=None, + ): + if tpu_atol is not None and self.on_tpu: + atol = tpu_atol + if tpu_rtol is not None and self.on_tpu: + rtol = tpu_rtol if not isinstance(x1, np.ndarray): x1 = backend.convert_to_numpy(x1) if not isinstance(x2, np.ndarray): @@ -57,7 +93,9 @@ def assertNotAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None): f"The two values are close at all elements. \n{msg}.\nValues: {x1}" ) - def assertAlmostEqual(self, x1, x2, decimal=3, msg=None): + def assertAlmostEqual(self, x1, x2, decimal=3, tpu_decimal=None, msg=None): + if tpu_decimal is not None and self.on_tpu: + decimal = tpu_decimal msg = msg or "" if not isinstance(x1, np.ndarray): x1 = backend.convert_to_numpy(x1) @@ -195,6 +233,8 @@ def run_layer_test( run_training_check=True, run_mixed_precision_check=True, assert_built_after_instantiation=False, + tpu_atol=None, + tpu_rtol=None, ): """Run basic checks on a layer. @@ -376,7 +416,9 @@ def run_build_asserts(layer): msg="Unexpected number of torch_params", ) - def run_output_asserts(layer, output, eager=False): + def run_output_asserts( + layer, output, eager=False, tpu_atol=None, tpu_rtol=None + ): if expected_output_shape is not None: def verify_shape(expected_shape, x): @@ -422,7 +464,11 @@ def verify_dtype(expected_dtype, x): tree.flatten(expected_output), tree.flatten(output) ): self.assertAllClose( - ref_v, v, msg="Unexpected output value" + ref_v, + v, + msg="Unexpected output value", + tpu_atol=tpu_atol, + tpu_rtol=tpu_rtol, ) if expected_num_losses is not None: self.assertLen(layer.losses, expected_num_losses) @@ -551,7 +597,13 @@ def build(self, *args, **kwargs): output_data = layer(**input_data, **call_kwargs) else: output_data = layer(input_data, **call_kwargs) - run_output_asserts(layer, output_data, eager=True) + run_output_asserts( + layer, + output_data, + eager=True, + tpu_atol=tpu_atol, + tpu_rtol=tpu_rtol, + ) if run_training_check: run_training_step(layer, input_data, output_data) @@ -621,6 +673,17 @@ def uses_gpu(): return False +def uses_tpu(): + # Condition used to skip tests when using the TPU + try: + devices = distribution.list_devices() + if any(d.startswith("tpu") for d in devices): + return True + except AttributeError: + return False + return False + + def uses_cpu(): devices = distribution.list_devices() if any(d.startswith("cpu") for d in devices): diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py index 51833cb55fcc..2d780a0af455 100644 --- a/keras/src/trainers/trainer_test.py +++ b/keras/src/trainers/trainer_test.py @@ -1894,33 +1894,43 @@ def test_on_batch_methods(self, run_eagerly, jit_compile): logs = model.train_on_batch(x, y, return_dict=True) self.assertIsInstance(logs, dict) self.assertEqual(len(logs), 2) - self.assertAlmostEqual(logs["loss"], 15.579) + self.assertAlmostEqual(logs["loss"], 15.579, tpu_decimal=1) logs = model.test_on_batch(x, y) self.assertIsInstance(logs, list) self.assertEqual(len(logs), 2) - self.assertAlmostEqual(logs[0], 15.173) + self.assertAlmostEqual(logs[0], 15.173, tpu_decimal=1) logs = model.test_on_batch(x, y, return_dict=True) self.assertIsInstance(logs, dict) self.assertEqual(len(logs), 2) - self.assertAlmostEqual(logs["loss"], 14.97) + self.assertAlmostEqual(logs["loss"], 14.97, tpu_decimal=1) output = model.predict_on_batch(x) self.assertIsInstance(output, np.ndarray) - self.assertAllClose(output[0], np.array([3.789511, 3.789511, 3.789511])) + self.assertAllClose( + output[0], + np.array([3.789511, 3.789511, 3.789511]), + tpu_atol=1e-2, + tpu_rtol=1e-2, + ) # With sample weights logs = model.train_on_batch(x, y, sw) - self.assertAlmostEqual(logs[0], 14.819) + self.assertAlmostEqual(logs[0], 14.819, tpu_decimal=1) logs = model.test_on_batch(x, y, sw) - self.assertAlmostEqual(logs[0], 14.595) + self.assertAlmostEqual(logs[0], 14.595, tpu_decimal=1) output = model.predict_on_batch(x) - self.assertAllClose(output[0], np.array([3.689468, 3.689468, 3.689468])) + self.assertAllClose( + output[0], + np.array([3.689468, 3.689468, 3.689468]), + tpu_atol=1e-2, + tpu_rtol=1e-2, + ) # With class weights logs = model.train_on_batch(x, y, class_weight={1: 0.3, 0: 0.2}) - self.assertAlmostEqual(logs[0], 12.899) + self.assertAlmostEqual(logs[0], 12.899, tpu_decimal=1) @parameterized.named_parameters( [ @@ -2280,19 +2290,19 @@ def test_nested_inputs(self): history = model.fit( [np.ones((3, 2)), np.ones((3, 3))], np.ones((3, 2)) ).history - self.assertAllClose(history["loss"], 16.0) + self.assertAllClose(history["loss"], 16.0, tpu_atol=1e-4, tpu_rtol=1e-4) train_out = model.train_on_batch( [np.ones((3, 2)), np.ones((3, 3))], np.ones((3, 2)) ) - self.assertAllClose(train_out[0], 15.2200) + self.assertAllClose(train_out[0], 15.2200, tpu_atol=1e-1, tpu_rtol=1e-1) eval_out = model.evaluate( [np.ones((3, 2)), np.ones((3, 3))], np.ones((3, 2)) ) - self.assertAllClose(eval_out[0], 13.0321) + self.assertAllClose(eval_out[0], 13.0321, tpu_atol=1e-2, tpu_rtol=1e-2) eval_out = model.test_on_batch( [np.ones((3, 2)), np.ones((3, 3))], np.ones((3, 2)) ) - self.assertAllClose(eval_out[0], 13.0321) + self.assertAllClose(eval_out[0], 13.0321, tpu_atol=1e-2, tpu_rtol=1e-2) predict_out = model.predict([np.ones((3, 2)), np.ones((3, 3))]) self.assertEqual(predict_out.shape, (3, 2)) predict_out = model.predict_on_batch([np.ones((3, 2)), np.ones((3, 3))]) @@ -2645,6 +2655,7 @@ def test_loss_weights(self): history["loss"], [3.182979, 3.115617, 3.049681], atol=1e-3, + tpu_atol=1e-2, ) # Dict output case. @@ -2677,6 +2688,7 @@ def test_loss_weights(self): history["loss"], [4.778718, 4.694403, 4.611693], atol=1e-3, + tpu_atol=1e-2, ) # List output case. @@ -2702,6 +2714,7 @@ def test_loss_weights(self): history["loss"], [4.778718, 4.694403, 4.611693], atol=1e-3, + tpu_atol=1e-2, ) @pytest.mark.requires_trainable_backend diff --git a/log.log b/log.log index df06bfe8567e..a546a5cd5b3e 100644 --- a/log.log +++ b/log.log @@ -1,12 +1,62 @@ ============================= test session starts ============================== -platform darwin -- Python 3.12.10, pytest-8.4.2, pluggy-1.6.0 -- /Users/wenyiguo/keras/venv/bin/python3.12 +platform linux -- Python 3.11.14, pytest-9.0.1, pluggy-1.6.0 -- /mnt/data/keras/venv/bin/python3 cachedir: .pytest_cache -rootdir: /Users/wenyiguo/keras +rootdir: /mnt/data/keras configfile: pyproject.toml plugins: cov-7.0.0 collecting ... collected 1 item -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_flax_layer_training_independent_bound_method FAILED +keras/src/utils/jax_layer_test.py::TestJaxLayer::test_flax_layer_training_independent_bound_method Model: "FlaxTrainingIndependentModel1" +┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ +┃ Layer (type) ┃ Output Shape ┃ Param # ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ +│ input_layer (InputLayer) │ (None, 28, 28, 1) │ 0 │ +├─────────────────────────────────┼────────────────────────┼───────────────┤ +│ flax_layer (FlaxLayer) │ (None, 10) │ 648,226 │ +└─────────────────────────────────┴────────────────────────┴───────────────┘ + Total params: 648,226 (2.47 MB) + Trainable params: 648,226 (2.47 MB) + Non-trainable params: 0 (0.00 B) +  1/10 ━━━━━━━━━━━━━━━━━━━━ 23s 3s/step - categorical_accuracy: 0.0312 - loss: 2.3293 10/10 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - categorical_accuracy: 0.0813 - loss: 2.3713 + 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step +Model: "FlaxTrainingIndependentModel2" +┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ +┃ Layer (type) ┃ Output Shape ┃ Param # ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ +│ flax_layer_1 (FlaxLayer) │ (None, 10) │ 648,226 │ +└─────────────────────────────────┴────────────────────────┴───────────────┘ + Total params: 648,226 (2.47 MB) + Trainable params: 648,226 (2.47 MB) + Non-trainable params: 0 (0.00 B) + 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step +Model: "FlaxTrainingIndependentModel2" +┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ +┃ Layer (type) ┃ Output Shape ┃ Param # ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ +│ flax_layer_1 (FlaxLayer) │ (None, 10) │ 648,226 │ +└─────────────────────────────────┴────────────────────────┴───────────────┘ + Total params: 1,944,680 (7.42 MB) + Trainable params: 648,226 (2.47 MB) + Non-trainable params: 0 (0.00 B) + Optimizer params: 1,296,454 (4.95 MB) + 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step +Saved artifact at '/mnt/data/tmp/tmpvyxw070h/jax_layer_export'. The following endpoints are available: + +* Endpoint 'serve' + args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='keras_tensor_8') +Output Type: + TensorSpec(shape=(None, 10), dtype=tf.float32, name=None) +Captures: + 126200239376720: TensorSpec(shape=(), dtype=tf.resource, name=None) + 126200239374800: TensorSpec(shape=(), dtype=tf.resource, name=None) + 126200239377872: TensorSpec(shape=(), dtype=tf.resource, name=None) + 126200239374416: TensorSpec(shape=(), dtype=tf.resource, name=None) + 126200239375184: TensorSpec(shape=(), dtype=tf.resource, name=None) + 126200239377104: TensorSpec(shape=(), dtype=tf.resource, name=None) + 126200239378064: TensorSpec(shape=(), dtype=tf.resource, name=None) + 126200239378256: TensorSpec(shape=(), dtype=tf.resource, name=None) + 126200239378640: TensorSpec(shape=(), dtype=tf.resource, name=None) +FAILED =================================== FAILURES =================================== ________ TestJaxLayer.test_flax_layer_training_independent_bound_method ________ @@ -105,58 +155,66 @@ trainable_params = 648226, non_trainable_weights = 0, non_trainable_params = 0 non_trainable_params, ) -keras/src/utils/jax_layer_test.py:488: +keras/src/utils/jax_layer_test.py:505: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -keras/src/utils/jax_layer_test.py:231: in _test_layer - outputs1 = layer1(inputs1) - ^^^^^^^^^^^^^^^ -keras/src/utils/traceback_utils.py:113: in error_handler - return fn(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^ -keras/src/layers/layer.py:866: in __call__ - self._maybe_build(call_spec) -keras/src/layers/layer.py:1477: in _maybe_build - self.build(**shapes_dict) -keras/src/layers/layer.py:231: in build_wrapper - original_build_method(*args, **kwargs) -keras/src/utils/jax_layer.py:510: in build - self._initialize_weights(input_shape) -keras/src/utils/jax_layer.py:497: in _initialize_weights - init_result = self.init_fn(*init_args) - ^^^^^^^^^^^^^^^^^^^^^^^^ +keras/src/utils/jax_layer_test.py:348: in _test_layer + output4 = model4.serve(x_test) + ^^^^^^^^^^^^^^^^^^^^ +venv/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py:153: in error_handler + raise e.with_traceback(filtered_tb) from None _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -rng = {'dropout': None, 'params': None} -inputs = Array([[[[1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.]... [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.]]]], dtype=float32) +op_name = '__inference_restored_function_body_799', num_outputs = 1 +inputs = [>, ...] +attrs = ('executor_type', '', 'config_proto', b'\n\x07\n\x03CPU\x10\x01\n\x07\n\x03GPU\x10\x00\n\x0e\n\nTPU_SYSTEM\x10\x012\x02J\x008\x01\x82\x01\x00\x92\x01\x02J\x00') +ctx = +name = None - def init_without_training(rng, inputs): - return self._variables_to_params_and_state( -> self.module.init( - rng, - inputs, - method=self.method, - ) - ) -E ValueError: First argument passed to an init function should be a ``jax.PRNGKey`` or a dictionary mapping strings to ``jax.PRNGKey``. -E -------------------- -E For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these. + def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None): + """Execute a TensorFlow operation. + + Args: + op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to + execute. + num_outputs: The number of outputs of the operation to fetch. (Explicitly + provided instead of being inferred for performance reasons). + inputs: A list of inputs to the operation. Each entry should be a Tensor, or + a value which can be passed to the Tensor constructor to create one. + attrs: A tuple with alternating string attr names and attr values for this + operation. + ctx: The value of context.context(). + name: Customized name for the operation. + + Returns: + List of output Tensor objects. The list is empty if there are no outputs + + Raises: + An exception on error. + """ + device_name = ctx.device_name + # pylint: disable=protected-access + try: + ctx.ensure_initialized() +> tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name, + inputs, attrs, num_outputs) +E tensorflow.python.framework.errors_impl.NotFoundError: Graph execution error: +E +E Detected at node XlaCallModule defined at (most recent call last): +E +E could not find registered transfer manager for platform Host -- check target linkage +E [[{{node XlaCallModule}}]] [Op:__inference_restored_function_body_799] -keras/src/utils/jax_layer.py:755: ValueError +venv/lib/python3.11/site-packages/tensorflow/python/eager/execute.py:53: NotFoundError +------------------------------ Captured log call ------------------------------- +WARNING absl:function_deserialization.py:672 Importing a function (__inference_internal_grad_fn_298) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. =========================== short test summary info ============================ -FAILED keras/src/utils/jax_layer_test.py::TestJaxLayer::test_flax_layer_training_independent_bound_method - ValueError: First argument passed to an init function should be a ``jax.PRNGKey`` or a dictionary mapping strings to ``jax.PRNGKey``. --------------------- -For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these. -============================== 1 failed in 1.72s =============================== +FAILED keras/src/utils/jax_layer_test.py::TestJaxLayer::test_flax_layer_training_independent_bound_method - tensorflow.python.framework.errors_impl.NotFoundError: Graph execution error: + +Detected at node XlaCallModule defined at (most recent call last): + +could not find registered transfer manager for platform Host -- check target linkage + [[{{node XlaCallModule}}]] [Op:__inference_restored_function_body_799] +============================== 1 failed in 28.43s ==============================