Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 9 additions & 158 deletions keras/src/backend/jax/excluded_tpu_tests.txt
Original file line number Diff line number Diff line change
@@ -1,53 +1,3 @@
AdditiveAttentionTest::test_attention_correctness
AttentionTest::test_attention_calculate_scores_with_scale
AttentionTest::test_attention_correctness
CircleTest::test_correctness
CircleTest::test_correctness_weighted
CircleTest::test_mean_with_sample_weight_reduction
CircleTest::test_no_reduction
CircleTest::test_sum_reduction
ConvBasicTest::test_enable_lora_with_alpha
ConvCorrectnessTest::test_conv1d0
ConvCorrectnessTest::test_conv1d1
ConvCorrectnessTest::test_conv1d2
ConvCorrectnessTest::test_conv1d3
ConvCorrectnessTest::test_conv1d4
ConvCorrectnessTest::test_conv2d0
ConvCorrectnessTest::test_conv2d1
ConvCorrectnessTest::test_conv2d2
ConvCorrectnessTest::test_conv2d3
ConvCorrectnessTest::test_conv2d4
ConvCorrectnessTest::test_conv2d5
ConvCorrectnessTest::test_conv3d0
ConvCorrectnessTest::test_conv3d1
ConvCorrectnessTest::test_conv3d2
ConvCorrectnessTest::test_conv3d3
ConvCorrectnessTest::test_conv3d4
ConvLSTM1DTest::test_correctness
ConvLSTM1DTest::test_correctness
ConvLSTM2DTest::test_correctness
ConvLSTMCellTest::test_correctness
ConvLSTMTest::test_correctness
ConvTransposeCorrectnessTest::test_conv1d_transpose0
ConvTransposeCorrectnessTest::test_conv1d_transpose1
ConvTransposeCorrectnessTest::test_conv1d_transpose2
ConvTransposeCorrectnessTest::test_conv2d_transpose0
ConvTransposeCorrectnessTest::test_conv2d_transpose1
ConvTransposeCorrectnessTest::test_conv2d_transpose2
ConvTransposeCorrectnessTest::test_conv2d_transpose3
ConvTransposeCorrectnessTest::test_conv3d_transpose0
ConvTransposeCorrectnessTest::test_conv3d_transpose1
ConvTransposeCorrectnessTest::test_conv3d_transpose2
CTCTest::test_correctness
DenseTest::test_dense_sparse
DepthwiseConvCorrectnessTest::test_depthwise_conv1d0
DepthwiseConvCorrectnessTest::test_depthwise_conv1d1
DepthwiseConvCorrectnessTest::test_depthwise_conv1d2
DepthwiseConvCorrectnessTest::test_depthwise_conv2d0
DepthwiseConvCorrectnessTest::test_depthwise_conv2d1
DepthwiseConvCorrectnessTest::test_depthwise_conv2d2
EinsumDenseTest::test_enable_lora_with_alpha
EmbeddingTest::test_enable_lora_with_alpha
ExportArchiveTest::test_jax_endpoint_registration_tf_function
ExportArchiveTest::test_jax_multi_unknown_endpoint_registration
ExportArchiveTest::test_layer_export
Expand All @@ -71,10 +21,18 @@ ExportArchiveTest::test_track_multiple_layers
ExportONNXTest::test_export_with_input_names
ExportONNXTest::test_export_with_opset_version_18
ExportONNXTest::test_export_with_opset_version_none
ExportONNXTest::test_model_with_input_structure_array
ExportONNXTest::test_model_with_input_structure_dict
ExportONNXTest::test_model_with_input_structure_tuple
ExportONNXTest::test_model_with_multiple_inputs
ExportONNXTest::test_standard_model_export_functional
ExportONNXTest::test_standard_model_export_lstm
ExportONNXTest::test_standard_model_export_sequential
ExportONNXTest::test_standard_model_export_subclass
ExportOpenVINOTest::test_model_with_input_structure_array
ExportOpenVINOTest::test_model_with_input_structure_dict
ExportOpenVINOTest::test_model_with_input_structure_tuple
ExportOpenVINOTest::test_model_with_multiple_inputs
ExportOpenVINOTest::test_standard_model_export_functional
ExportOpenVINOTest::test_standard_model_export_sequential
ExportOpenVINOTest::test_standard_model_export_subclass
Expand Down Expand Up @@ -118,117 +76,10 @@ ExportSavedModelTest::test_model_with_tf_data_layer_subclass
ExportSavedModelTest::test_standard_model_export_functional
ExportSavedModelTest::test_standard_model_export_sequential
ExportSavedModelTest::test_standard_model_export_subclass
GRUTest::test_correctness0
GRUTest::test_correctness1
GRUTest::test_legacy_implementation_argument
GRUTest::test_masking
GRUTest::test_pass_initial_state
GRUTest::test_pass_return_state
GRUTest::test_statefulness
ImageOpsCorrectnessTest::test_affine_transform_bilinear_constant
ImageOpsCorrectnessTest::test_affine_transform_bilinear_mirror
ImageOpsCorrectnessTest::test_affine_transform_bilinear_nearest
ImageOpsCorrectnessTest::test_affine_transform_bilinear_reflect
ImageOpsCorrectnessTest::test_affine_transform_bilinear_wrap
LinalgOpsCorrectnessTest::test_cholesky_inverse_lower
LinalgOpsCorrectnessTest::test_cholesky_inverse_upper
LinalgOpsCorrectnessTest::test_eig
LinalgOpsCorrectnessTest::test_svd
LSTMTest::test_correctness0
LSTMTest::test_correctness1
LSTMTest::test_masking
LSTMTest::test_pass_initial_state
LSTMTest::test_statefulness
MathOpsCorrectnessTest::test_extract_sequences
MergingLayersTest::test_correctness_dynamic_dot_3d
MergingLayersTest::test_correctness_static_dot_3d
MuonTest::test_Newton_Schulz
NNOpsCorrectnessTest::test_conv_2d0
NNOpsCorrectnessTest::test_conv_2d1
NNOpsCorrectnessTest::test_conv_2d2
NNOpsCorrectnessTest::test_conv_2d3
NNOpsCorrectnessTest::test_conv_2d4
NNOpsCorrectnessTest::test_conv_2d5
NNOpsCorrectnessTest::test_conv_3d0
NNOpsCorrectnessTest::test_conv_3d1
NNOpsCorrectnessTest::test_conv_3d10
NNOpsCorrectnessTest::test_conv_3d11
NNOpsCorrectnessTest::test_conv_3d2
NNOpsCorrectnessTest::test_conv_3d3
NNOpsCorrectnessTest::test_conv_3d4
NNOpsCorrectnessTest::test_conv_3d5
NNOpsCorrectnessTest::test_conv_3d6
NNOpsCorrectnessTest::test_conv_3d7
NNOpsCorrectnessTest::test_conv_3d8
NNOpsCorrectnessTest::test_conv_3d9
NNOpsCorrectnessTest::test_ctc_loss
NNOpsCorrectnessTest::test_depthwise_conv_2d0
NNOpsCorrectnessTest::test_depthwise_conv_2d1
NNOpsCorrectnessTest::test_depthwise_conv_2d10
NNOpsCorrectnessTest::test_depthwise_conv_2d11
NNOpsCorrectnessTest::test_depthwise_conv_2d2
NNOpsCorrectnessTest::test_depthwise_conv_2d3
NNOpsCorrectnessTest::test_depthwise_conv_2d4
NNOpsCorrectnessTest::test_depthwise_conv_2d5
NNOpsCorrectnessTest::test_depthwise_conv_2d6
NNOpsCorrectnessTest::test_depthwise_conv_2d7
NNOpsCorrectnessTest::test_depthwise_conv_2d8
NNOpsCorrectnessTest::test_depthwise_conv_2d9
NNOpsCorrectnessTest::test_separable_conv_2d0
NNOpsCorrectnessTest::test_separable_conv_2d1
NNOpsCorrectnessTest::test_separable_conv_2d2
NNOpsCorrectnessTest::test_separable_conv_2d3
NNOpsCorrectnessTest::test_separable_conv_2d4
NNOpsCorrectnessTest::test_separable_conv_2d5
NNOpsCorrectnessTest::test_separable_conv_2d6
NNOpsCorrectnessTest::test_separable_conv_2d7
NumpyOneInputOpsDynamicShapeTest::test_argmax_negative_zero
NumpyOneInputOpsDynamicShapeTest::test_argmin_negative_zero
NumpyTwoInputOpsCorrectnessTest::test_logspace
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank2_float32_false_false
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank2_float64_false_false
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float16_false_false
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float32_false_false
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float64_false_false
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float16_false_false
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float32_false_false
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float64_false_false
RandomGaussianBlurTest::test_random_erasing_basic
RandomPerspectiveTest::test_random_perspective_bounding_boxes_with_large_scale
RandomPerspectiveTest::test_random_perspective_bounding_boxes_with_small_scale
RandomZoomTest::test_random_zoom_out_correctness
RegularizersTest::test_orthogonal_regularizer
RNNTest::test_go_backwards
SeparableConvCorrectnessTest::test_separable_conv1d0
SeparableConvCorrectnessTest::test_separable_conv1d1
SeparableConvCorrectnessTest::test_separable_conv1d2
SeparableConvCorrectnessTest::test_separable_conv2d0
SeparableConvCorrectnessTest::test_separable_conv2d1
SeparableConvCorrectnessTest::test_separable_conv2d2
SimpleRNNTest::test_correctness
SimpleRNNTest::test_correctness
SimpleRNNTest::test_masking
SimpleRNNTest::test_masking
SimpleRNNTest::test_pass_initial_state
SimpleRNNTest::test_pass_initial_state
SimpleRNNTest::test_return_state
SimpleRNNTest::test_statefulness
SimpleRNNTest::test_statefulness
StackedRNNTest::test_correctness_single_state_stack
StackedRNNTest::test_correctness_two_states_stack
StackedRNNTest::test_statefullness_single_state_stack
StackedRNNTest::test_statefullness_two_states_stack
TestFitLRSchedulesFlow::test_fit_lr_correctness
TestJaxLayer::test_flax_layer_training_independent_bound_method
TestJaxLayer::test_flax_layer_training_rng_state_no_method
TestJaxLayer::test_flax_layer_training_rng_unbound_method
TestJaxLayer::test_flax_layer_training_rng_unbound_method_dtype_policy
TestJaxLayer::test_jax_layer_training_independent
TestJaxLayer::test_jax_layer_training_state
TestJaxLayer::test_jax_layer_training_state_dtype_policy
TestSpectrogram::test_spectrogram_error
TestTrainer::test_loss_weights
TestTrainer::test_nested_inputs
TestTrainer::test_on_batch_methods_eager
TestTrainer::test_on_batch_methods_graph_fn
TestTrainer::test_on_batch_methods_jit
TestJaxLayer::test_jax_layer_training_state_dtype_policy
10 changes: 8 additions & 2 deletions keras/src/layers/attention/additive_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,16 @@ def test_attention_correctness(self):
return_attention_scores=True,
)
self.assertAllClose(
output, [[[1.727, 2.727], [2.272, 3.272]]], atol=1e-3
output,
[[[1.727, 2.727], [2.272, 3.272]]],
atol=1e-3,
tpu_atol=1e-2,
)
self.assertAllClose(
scores, [[[0.636, 0.363], [0.363, 0.636]]], atol=1e-3
scores,
[[[0.636, 0.363], [0.363, 0.636]]],
atol=1e-3,
tpu_atol=1e-2,
)

def test_attention_with_mask(self):
Expand Down
28 changes: 23 additions & 5 deletions keras/src/layers/attention/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,18 @@ def test_attention_correctness(self):
return_attention_scores=True,
)
self.assertAllClose(
output, [[[2.462, 3.462], [1.538, 2.538]]], atol=1e-3
output,
[[[2.462, 3.462], [1.538, 2.538]]],
atol=1e-3,
tpu_atol=1e-2,
tpu_rtol=1e-2,
)
self.assertAllClose(
scores, [[[0.269, 0.731], [0.731, 0.269]]], atol=1e-3
scores,
[[[0.269, 0.731], [0.731, 0.269]]],
atol=1e-3,
tpu_atol=1e-2,
tpu_rtol=1e-2,
)

# Concat.
Expand All @@ -66,10 +74,18 @@ def test_attention_correctness(self):
return_attention_scores=True,
)
self.assertAllClose(
output, [[[1.727, 2.727], [2.272, 3.272]]], atol=1e-3
output,
[[[1.727, 2.727], [2.272, 3.272]]],
atol=1e-3,
tpu_atol=1e-2,
tpu_rtol=1e-2,
)
self.assertAllClose(
scores, [[[0.636, 0.363], [0.363, 0.636]]], atol=1e-3
scores,
[[[0.636, 0.363], [0.363, 0.636]]],
atol=1e-3,
tpu_atol=1e-2,
tpu_rtol=1e-2,
)

def test_attention_with_mask(self):
Expand Down Expand Up @@ -149,7 +165,9 @@ def test_attention_calculate_scores_with_scale(self):
expected_scores = np.matmul(query, key.transpose((0, 2, 1)))
expected_scores *= layer.scale.numpy()
actual_scores = layer._calculate_scores(query, key)
self.assertAllClose(actual_scores, expected_scores)
self.assertAllClose(
actual_scores, expected_scores, tpu_atol=1e-2, tpu_rtol=1e-2
)

def test_attention_calculate_score_mask_no_causal_no_vmask(self):
scores = np.random.random((2, 3, 4))
Expand Down
17 changes: 13 additions & 4 deletions keras/src/layers/convolutional/conv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,12 @@ def test_enable_lora_with_alpha(self):

# Compare the effective kernel computed via the property.
actual_effective_kernel = ops.convert_to_numpy(layer.kernel)
self.assertAllClose(actual_effective_kernel, expected_effective_kernel)
self.assertAllClose(
actual_effective_kernel,
expected_effective_kernel,
tpu_atol=1e-3,
tpu_rtol=1e-3,
)

@pytest.mark.requires_trainable_backend
def test_lora_rank_argument(self):
Expand Down Expand Up @@ -891,7 +896,7 @@ def test_conv1d(
dilation_rate=dilation_rate,
groups=groups,
)
self.assertAllClose(outputs, expected)
self.assertAllClose(outputs, expected, tpu_atol=1e-1, tpu_rtol=1e-1)

@parameterized.parameters(
{
Expand Down Expand Up @@ -989,7 +994,9 @@ def test_conv2d(
dilation_rate=dilation_rate,
groups=groups,
)
self.assertAllClose(outputs, expected, rtol=5e-4)
self.assertAllClose(
outputs, expected, rtol=5e-4, tpu_atol=1e-1, tpu_rtol=1e-1
)

@parameterized.parameters(
{
Expand Down Expand Up @@ -1078,7 +1085,9 @@ def test_conv3d(
dilation_rate=dilation_rate,
groups=groups,
)
self.assertAllClose(outputs, expected, rtol=1e-3)
self.assertAllClose(
outputs, expected, rtol=1e-3, tpu_atol=1e-1, tpu_rtol=1e-1
)

def test_conv_constraints(self):
layer = layers.Conv2D(
Expand Down
12 changes: 9 additions & 3 deletions keras/src/layers/convolutional/conv_transpose_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,9 @@ def test_conv1d_transpose(
data_format,
dilation_rate,
)
self.assertAllClose(outputs, expected, atol=1e-5)
self.assertAllClose(
outputs, expected, atol=1e-5, tpu_atol=1e-1, tpu_rtol=1e-1
)

@parameterized.parameters(
{
Expand Down Expand Up @@ -696,7 +698,9 @@ def test_conv2d_transpose(
data_format,
dilation_rate,
)
self.assertAllClose(outputs, expected, atol=1e-5)
self.assertAllClose(
outputs, expected, atol=1e-5, tpu_atol=1e-1, tpu_rtol=1e-1
)

@parameterized.parameters(
{
Expand Down Expand Up @@ -767,7 +771,9 @@ def test_conv3d_transpose(
data_format,
dilation_rate,
)
self.assertAllClose(outputs, expected, atol=1e-5)
self.assertAllClose(
outputs, expected, atol=1e-5, tpu_atol=1e-1, tpu_rtol=1e-1
)

@parameterized.product(
kernel_size=list(range(1, 5)),
Expand Down
10 changes: 7 additions & 3 deletions keras/src/layers/convolutional/depthwise_conv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def test_depthwise_conv1d(
data_format=data_format,
dilation_rate=dilation_rate,
)
self.assertAllClose(outputs, expected)
self.assertAllClose(outputs, expected, tpu_atol=1e-2, tpu_rtol=1e-2)

@parameterized.parameters(
{
Expand Down Expand Up @@ -465,5 +465,9 @@ def test_depthwise_conv2d(
data_format=data_format,
dilation_rate=dilation_rate,
)
self.assertAllClose(outputs.shape, expected.shape)
self.assertAllClose(outputs, expected, atol=1e-5)
self.assertAllClose(
outputs.shape, expected.shape, tpu_atol=1e-2, tpu_rtol=1e-2
)
self.assertAllClose(
outputs, expected, atol=1e-5, tpu_atol=1e-1, tpu_rtol=1e-1
)
18 changes: 16 additions & 2 deletions keras/src/layers/convolutional/separable_conv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,14 @@ def test_separable_conv1d(
)

self.assertAllClose(outputs.shape, expected.shape)
self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)
self.assertAllClose(
outputs,
expected,
rtol=1e-5,
atol=1e-5,
tpu_atol=1e-1,
tpu_rtol=1e-1,
)

@parameterized.parameters(
{
Expand Down Expand Up @@ -381,4 +388,11 @@ def test_separable_conv2d(
)

self.assertAllClose(outputs.shape, expected.shape)
self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)
self.assertAllClose(
outputs,
expected,
rtol=1e-5,
atol=1e-5,
tpu_atol=1e-1,
tpu_rtol=1e-1,
)
Loading
Loading