Skip to content

Commit 1cc463f

Browse files
Merge branch 'keras-team:master' into openvino-nn-functions
2 parents 0ede12d + 529e162 commit 1cc463f

28 files changed

+933
-58
lines changed

.github/workflows/actions.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,4 +148,4 @@ jobs:
148148
pip uninstall -y keras keras-nightly
149149
pip install -e "." --progress-bar off --upgrade
150150
- name: Run pre-commit
151-
run: pre-commit run --all-files --hook-stage manual
151+
run: pre-commit run --all-files --hook-stage manual

.github/workflows/tpu_tests.yml

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
name: Keras Tests
2+
3+
# TODO: Consider enabling all tests (pytest, applications, etc.) with NNX in the future
4+
# Currently only basic flow tests run with NNX enabled
5+
6+
on:
7+
push:
8+
branches: [ master ]
9+
pull_request_review:
10+
types: [submitted]
11+
release:
12+
types: [created]
13+
14+
permissions:
15+
contents: read
16+
17+
jobs:
18+
19+
test-in-container:
20+
name: Run tests on TPU
21+
runs-on: linux-x86-ct6e-44-1tpu
22+
# Only run on approved PRs, pushes to master, or releases
23+
if: |
24+
github.event_name == 'push' ||
25+
github.event_name == 'release' ||
26+
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved')
27+
28+
strategy:
29+
fail-fast: false
30+
matrix:
31+
backend: [jax]
32+
33+
container:
34+
image: python:3.10-slim
35+
options: --privileged --network host
36+
37+
steps:
38+
- name: Checkout Repository
39+
uses: actions/checkout@v4
40+
41+
- name: Install Dependencies
42+
run: |
43+
pip install --no-cache-dir -r requirements-${{ matrix.backend }}-tpu.txt \
44+
45+
- name: Set Keras Backend
46+
run: echo "KERAS_BACKEND=jax" >> $GITHUB_ENV
47+
48+
- name: Run Verification and Tests
49+
run: |
50+
echo "Successfully running inside the public python container!"
51+
echo "Verifying JAX installation..."
52+
python3 -c "import jax; print(f'JAX backend: {jax.default_backend()}'); print(f'JAX devices : {jax.devices()}')"
53+
54+
pytest keras --ignore keras/src/applications \
55+
--cov=keras \
56+
--cov-config=pyproject.toml

conftest.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,21 @@ def pytest_collection_modifyitems(config, items):
3131
line.strip() for line in openvino_skipped_tests if line.strip()
3232
]
3333

34+
tpu_skipped_tests = []
35+
if backend() == "jax":
36+
try:
37+
with open(
38+
"keras/src/backend/jax/excluded_tpu_tests.txt", "r"
39+
) as file:
40+
tpu_skipped_tests = file.readlines()
41+
# it is necessary to check if stripped line is not empty
42+
# and exclude such lines
43+
tpu_skipped_tests = [
44+
line.strip() for line in tpu_skipped_tests if line.strip()
45+
]
46+
except FileNotFoundError:
47+
pass # File doesn't exist, no tests to skip
48+
3449
requires_trainable_backend = pytest.mark.skipif(
3550
backend() in ["numpy", "openvino"],
3651
reason="Trainer not implemented for NumPy and OpenVINO backend.",
@@ -49,6 +64,14 @@ def pytest_collection_modifyitems(config, items):
4964
"Not supported operation by openvino backend",
5065
)
5166
)
67+
# also, skip concrete tests for TPU when using JAX backend
68+
for skipped_test in tpu_skipped_tests:
69+
if skipped_test in item.nodeid:
70+
item.add_marker(
71+
pytest.mark.skip(
72+
reason="Known TPU test failure",
73+
)
74+
)
5275

5376

5477
def skip_if_backend(given_backend, reason):

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@
179179
from keras.src.ops.numpy import dot as dot
180180
from keras.src.ops.numpy import einsum as einsum
181181
from keras.src.ops.numpy import empty as empty
182+
from keras.src.ops.numpy import empty_like as empty_like
182183
from keras.src.ops.numpy import equal as equal
183184
from keras.src.ops.numpy import exp as exp
184185
from keras.src.ops.numpy import exp2 as exp2

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from keras.src.ops.numpy import dot as dot
6666
from keras.src.ops.numpy import einsum as einsum
6767
from keras.src.ops.numpy import empty as empty
68+
from keras.src.ops.numpy import empty_like as empty_like
6869
from keras.src.ops.numpy import equal as equal
6970
from keras.src.ops.numpy import exp as exp
7071
from keras.src.ops.numpy import exp2 as exp2

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@
179179
from keras.src.ops.numpy import dot as dot
180180
from keras.src.ops.numpy import einsum as einsum
181181
from keras.src.ops.numpy import empty as empty
182+
from keras.src.ops.numpy import empty_like as empty_like
182183
from keras.src.ops.numpy import equal as equal
183184
from keras.src.ops.numpy import exp as exp
184185
from keras.src.ops.numpy import exp2 as exp2

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from keras.src.ops.numpy import dot as dot
6666
from keras.src.ops.numpy import einsum as einsum
6767
from keras.src.ops.numpy import empty as empty
68+
from keras.src.ops.numpy import empty_like as empty_like
6869
from keras.src.ops.numpy import equal as equal
6970
from keras.src.ops.numpy import exp as exp
7071
from keras.src.ops.numpy import exp2 as exp2

keras/src/backend/common/name_scope.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ def __exit__(self, *args, **kwargs):
5858
name_scope_stack = global_state.get_global_attribute(
5959
"name_scope_stack"
6060
)
61-
name_scope_stack.pop()
61+
if name_scope_stack:
62+
name_scope_stack.pop()
6263

6364

6465
def current_path():

keras/src/backend/common/name_scope_test.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import threading
2+
13
from keras.src import testing
4+
from keras.src.backend.common import global_state
25
from keras.src.backend.common.name_scope import current_path
36
from keras.src.backend.common.name_scope import name_scope
47

@@ -46,3 +49,85 @@ def test_override_parent(self):
4649
current_path(), "absolute/path/middle/inner"
4750
)
4851
self.assertEqual(current_path(), "outer")
52+
53+
def test_exit_with_none_stack(self):
54+
"""Test that __exit__ handles None name_scope_stack gracefully."""
55+
# Create a name_scope instance
56+
scope = name_scope("test")
57+
# Enter the scope normally
58+
scope.__enter__()
59+
60+
# Simulate the scenario where global state is cleared
61+
# (e.g., in a different thread)
62+
global_state.set_global_attribute("name_scope_stack", None)
63+
64+
# Exit should not raise an AttributeError
65+
scope.__exit__()
66+
67+
# Clean up: reset the stack
68+
global_state.set_global_attribute("name_scope_stack", [])
69+
70+
def test_exit_with_empty_stack(self):
71+
"""Test that __exit__ handles empty name_scope_stack gracefully."""
72+
# Create a name_scope instance
73+
scope = name_scope("test")
74+
# Enter the scope normally
75+
scope.__enter__()
76+
77+
# Simulate the scenario where the stack is cleared
78+
name_scope_stack = global_state.get_global_attribute("name_scope_stack")
79+
name_scope_stack.clear()
80+
81+
# Exit should not raise an IndexError
82+
scope.__exit__()
83+
84+
# Verify stack is still empty
85+
name_scope_stack = global_state.get_global_attribute(
86+
"name_scope_stack", default=[]
87+
)
88+
self.assertEqual(len(name_scope_stack), 0)
89+
90+
def test_multithreaded_name_scope(self):
91+
"""Test name_scope in multithreaded environment."""
92+
results = []
93+
94+
def thread_function(thread_id):
95+
# Each thread should have its own name_scope_stack
96+
with name_scope(f"thread_{thread_id}"):
97+
path = current_path()
98+
results.append(path)
99+
# Verify we get the expected path
100+
self.assertEqual(path, f"thread_{thread_id}")
101+
102+
# Create and start multiple threads
103+
threads = []
104+
for i in range(5):
105+
thread = threading.Thread(target=thread_function, args=(i,))
106+
threads.append(thread)
107+
thread.start()
108+
109+
# Wait for all threads to complete
110+
for thread in threads:
111+
thread.join()
112+
113+
# Verify all threads executed successfully
114+
self.assertEqual(len(results), 5)
115+
116+
def test_exit_without_pop_on_exit(self):
117+
"""Test that __exit__ respects _pop_on_exit flag."""
118+
# Create a name_scope but don't enter it
119+
scope = name_scope("test")
120+
# _pop_on_exit should be False
121+
self.assertFalse(scope._pop_on_exit)
122+
123+
# Set up a stack manually
124+
global_state.set_global_attribute("name_scope_stack", [scope])
125+
126+
scope.__exit__()
127+
128+
# Verify the stack still contains the scope
129+
name_scope_stack = global_state.get_global_attribute("name_scope_stack")
130+
self.assertEqual(len(name_scope_stack), 1)
131+
132+
# Clean up
133+
global_state.set_global_attribute("name_scope_stack", [])

keras/src/backend/common/variables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def value(self):
276276
return self._maybe_autocast(self._value)
277277

278278
def assign(self, value):
279-
value = self._convert_to_tensor(value, dtype=self.dtype)
279+
value = self._convert_to_tensor(value, dtype=self._dtype)
280280
if not shape_equal(value.shape, self.shape):
281281
raise ValueError(
282282
"The shape of the target variable and "

0 commit comments

Comments
 (0)