Skip to content

Commit d002f0a

Browse files
committed
Add backend specific TPU tests for DistributedEmbedding.
Under `keras_rs/src/layers/embedding`.
1 parent 5713afe commit d002f0a

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

.github/workflows/actions.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ jobs:
9090
run: python3 -c "import jax; print('JAX devices:', jax.devices())"
9191

9292
- name: Test with pytest
93-
run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py
93+
run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py keras_rs/src/layers/embedding/${{ matrix.backend }}
9494

9595
check_format:
9696
name: Check the code format

keras_rs/src/layers/embedding/jax/distributed_embedding_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def test_call(
338338

339339
# Trigger layer.build(...) to initialize tables.
340340
sample_ids, sample_weights = keras_test_utils.create_random_samples(
341-
feature_configs, ragged=ragged, seed=0
341+
feature_configs, ragged=ragged, seed=0, max_ids_per_sample=10
342342
)
343343
inputs = layer.preprocess(sample_ids, sample_weights)
344344
_ = layer(inputs)

0 commit comments

Comments
 (0)