Skip to content

Commit f6472fb

Browse files
switch to re.fullmatch for more explicit behavior spec
1 parent d9748bf commit f6472fb

File tree

2 files changed

+30
-38
lines changed

2 files changed

+30
-38
lines changed

keras/src/dtype_policies/dtype_policy_map.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -98,25 +98,22 @@ def quantization_mode(self):
9898
return self.default_policy.quantization_mode
9999

100100
def __getitem__(self, key):
101-
"""Retrieves a `DTypePolicy` by its key, with a regex fallback.
101+
"""Retrieves a `DTypePolicy` by treating map keys as regex patterns.
102102
103-
This method first attempts an exact key match.
104-
If no exact match is found, it treats the keys in the map as regular
105-
expression patterns. A pattern is considered a match only if it aligns
106-
with the start of the input `key` and covers a complete path component
107-
(i.e., it is followed by a `/` or the end of the string).
103+
This method first attempts an exact key match for performance. If no
104+
exact match is found, it treats all keys in the map as regular
105+
expression patterns and uses `re.fullmatch` to find a policy.
108106
109-
For example, a map key "encoder" will match the path
110-
"encoder/attention", but it will not match "encoder_v2". This
111-
component-wise matching prevents incorrect policy assignments to
112-
similarly named layers.
107+
For example, to apply a policy to all sublayers of an `encoder` block,
108+
the key should be explicitly set to `"encoder/.*"`. A key of
109+
`"encoder"` will only match the layer with that exact path.
113110
114111
Args:
115-
key: The string key to query for a `DTypePolicy`.
112+
key: str. The key to query for a `DTypePolicy`.
116113
117114
Returns:
118-
The corresponding `DTypePolicy`. If no valid match is found
119-
(either exact or regex), this method returns `self.default_policy`.
115+
The corresponding `DTypePolicy`. If no match is found, this method
116+
returns `self.default_policy`.
120117
121118
Raises:
122119
ValueError: If the `key` matches more than one regex pattern in the
@@ -126,14 +123,11 @@ def __getitem__(self, key):
126123
if key in self._policy_map:
127124
return self._policy_map[key]
128125

129-
# 2. If no exact match is found, fallback to a regex match.
130-
# Check for a match that covers a full path component.
131-
# The pattern must match from the start of the `key` and be
132-
# followed by either a '/' or the end of the string.
126+
# 2. Fall back to a full regex match.
133127
matching_keys = [
134128
pattern
135129
for pattern in self._policy_map
136-
if re.match(f"{pattern}(/|$)", key)
130+
if re.fullmatch(pattern, key)
137131
]
138132

139133
# 3. Handle cases based on the number of matches found.

keras/src/dtype_policies/dtype_policy_map_test.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -135,49 +135,47 @@ def test_get(self):
135135
policy_map = DTypePolicyMap()
136136
# Policy for an exact layer path
137137
policy_map["model/encoder/layer_0/dense"] = bfloat16_policy
138-
# Policy that could be partially matched
138+
# Policy for a layer whose path is also a prefix of another layer's path
139139
policy_map["model/encoder/attention/query"] = int8_policy
140-
# Regex policies for entire scopes
141-
policy_map["model/decoder"] = float32_policy
142-
policy_map["model/decoder/attention"] = float16_policy
140+
# Regex policies for entire scopes MUST include wildcards
141+
policy_map["model/decoder/.*"] = float32_policy
142+
policy_map["model/decoder/attention/.*"] = float16_policy
143143

144144
# 2. Test exact match
145-
# An exact key should always return the correct policy.
146145
self.assertEqual(
147146
policy_map["model/encoder/layer_0/dense"], bfloat16_policy
148147
)
149148
self.assertEqual(
150149
policy_map["model/encoder/attention/query"], int8_policy
151150
)
152151

153-
# 3. Test successful regex fallback (component-wise match)
154-
# "model/decoder" should match "model/decoder/layer_0" because
155-
# it's a full component prefix.
152+
# 3. Test successful regex fallback (explicit wildcard)
153+
# "model/decoder/.*" should match its children.
156154
self.assertEqual(policy_map["model/decoder/layer_0"], float32_policy)
157155

158-
# 4. Test prevention of partial regex match
159-
# "model/encoder/attention/query" should NOT match
160-
# "model/encoder/attention/query_norm"
161-
# as it's not followed by a '/' or the end of the string.
156+
# 4. Test that partial matches are ignored
157+
# The exact key "model/encoder/attention/query" should not match
158+
# "model/encoder/attention/query_norm" without a wildcard.
162159
self.assertEqual(
163160
policy_map["model/encoder/attention/query_norm"],
164161
policy_map.default_policy,
165162
)
163+
# A plain key "model/decoder" will not match "model/decoder/layer_0"
164+
policy_map["model/decoder"] = bfloat16_policy # Add exact key
165+
self.assertEqual(
166+
policy_map["model/decoder/layer_0"], float32_policy
167+
) # Still matches the more general regex
168+
self.assertEqual(policy_map["model/decoder"], bfloat16_policy)
166169

167170
# 5. Test no match
168-
# A key with no exact or valid regex match should return the default.
169171
self.assertEqual(
170172
policy_map["model/embedding"], policy_map.default_policy
171173
)
172-
self.assertEqual(
173-
policy_map["prefix/model/decoder/attention"],
174-
policy_map.default_policy,
175-
)
176174

177175
# 6. Test multiple regex matches causing a ValueError
178-
# The path "model/decoder/attention/output" matches two regex keys:
179-
# - "model/decoder"
180-
# - "model/decoder/attention"
176+
# "model/decoder/attention/output" matches two regex patterns:
177+
# - "model/decoder/.*"
178+
# - "model/decoder/attention/.*"
181179
with self.assertRaisesRegex(
182180
ValueError,
183181
"Path 'model/decoder/attention/output' matches multiple "

0 commit comments

Comments
 (0)