Skip to content

Commit 4df9d70

Browse files
Ignore regex fallback for quantized policies in DTypePolicyMap (#21608)
* Ignore regex match for quantized policies in DTypePolicyMap
* only match complete path components when falling back to regex match
* updates docstring
* minor conditional modification
* updates comment
* replace re.search with re.match
* switch to re.fullmatch for more explicit behavior spec
* Added detailed example to docstring
1 parent a009455 commit 4df9d70

File tree

2 files changed

+151
-56
lines changed

2 files changed

+151
-56
lines changed

keras/src/dtype_policies/dtype_policy_map.py

Lines changed: 109 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -35,29 +35,54 @@ def get_config(self):
3535
However, it is also possible to set a regex as the key. See the docstring of
3636
`get` for more details.
3737
38-
See below for a usage example. You can define the naming schema
39-
of the `DTypePolicy`, and then retrieve the corresponding `DTypePolicy`
40-
instance.
41-
42-
```python
43-
dtype_policy_map = DTypePolicyMap()
44-
dtype_policy_map["layer/dense_0"] = DTypePolicy("bfloat16")
45-
dtype_policy_map["layer/dense_1"] = QuantizedDTypePolicy("int8", "bfloat16")
46-
47-
policy_0 = dtype_policy_map["layer/dense_0"]
48-
policy_1 = dtype_policy_map["layer/dense_1"]
49-
policy_2 = dtype_policy_map["layer/dense_2"] # No hit
50-
assert policy_0 == DTypePolicy("bfloat16")
51-
assert policy_1 == QuantizedDTypePolicy("int8", "bfloat16")
52-
assert policy_2 == keras.config.dtype_policy()
53-
```
54-
5538
Args:
5639
default_policy: An optional `DTypePolicy` instance specifying the
5740
default dtype policy. If not specified, the value will default to
5841
`keras.config.dtype_policy()`.
5942
policy_map: An optional dict that maps string to `DTypePolicy`
6043
instances. Defaults to `None`.
44+
45+
Example:
46+
47+
```python
48+
>>> from keras.src import dtype_policies
49+
>>> bfloat16 = dtype_policies.DTypePolicy("bfloat16")
50+
>>> float16 = dtype_policies.DTypePolicy("float16")
51+
>>> float32 = dtype_policies.DTypePolicy("float32")
52+
>>> policy_map = DTypePolicyMap(default_policy=float32)
53+
54+
# Set policies using an exact path and a regex pattern.
55+
# Note: "decoder" will only match the exact path, not its children.
56+
>>> policy_map["encoder/layer_0/dense"] = bfloat16
57+
>>> policy_map["encoder/.*"] = float16
58+
>>> policy_map["decoder"] = bfloat16
59+
60+
# 1. An exact match is found and returned directly.
61+
>>> policy_map["encoder/layer_0/dense"].name
62+
'bfloat16'
63+
64+
# 2. A regex match is found for a child layer.
65+
# It matches the "encoder/.*" pattern.
66+
>>> policy_map["encoder/attention/query"].name
67+
'float16'
68+
69+
# 3. No implicit prefix matching occurs.
70+
# "decoder/attention" does not match the key "decoder".
71+
# The default policy is returned.
72+
>>> policy_map["decoder/attention"].name
73+
'float32'
74+
75+
# 4. A ValueError is raised if a path matches multiple patterns.
76+
>>> policy_map["encoder/attention/.*"] = bfloat16
77+
# "encoder/attention/query" now matches two patterns:
78+
# - "encoder/.*"
79+
# - "encoder/attention/.*"
80+
>>> try:
81+
... policy_map["encoder/attention/query"]
82+
... except ValueError as e:
83+
... print(e)
84+
Path 'encoder/attention/query' matches multiple dtype policy ..
85+
```
6186
"""
6287

6388
def __init__(self, default_policy=None, policy_map=None):
@@ -100,24 +125,79 @@ def quantization_mode(self):
100125
def __getitem__(self, key):
101126
"""Retrieves the corresponding `DTypePolicy` by the string key.
102127
103-
When there isn't an exact match, all the existing keys in the map
104-
will be treated as a regex and map against the input key again. When
105-
there are multiple matches for the regex, an `ValueError` will be
106-
raised. Returns `self.default_policy` if there isn't any match found.
128+
This method first attempts an exact key match. If no exact match is
129+
found, it treats all keys in the map as regular expression patterns
130+
and uses `re.fullmatch` to find a policy.
131+
132+
For example, to apply a policy to all sublayers of an `encoder` block,
133+
the key should be explicitly set to `"encoder/.*"`. A key of
134+
`"encoder"` will only match the layer with that exact path.
107135
108136
Args:
109-
key: String key to query a `DTypePolicy`.
137+
key: str. The key to query for a `DTypePolicy`.
110138
111139
Returns:
112-
Corresponding `DTypePolicy` based on the query.
140+
The corresponding `DTypePolicy`. If no match is found, this method
141+
returns `self.default_policy`.
142+
143+
Raises:
144+
ValueError: If the `key` matches more than one regex pattern in the
145+
map.
146+
147+
Example:
148+
149+
```python
150+
>>> from keras.src import dtype_policies
151+
>>> bfloat16 = dtype_policies.DTypePolicy("bfloat16")
152+
>>> float16 = dtype_policies.DTypePolicy("float16")
153+
>>> float32 = dtype_policies.DTypePolicy("float32")
154+
>>> policy_map = DTypePolicyMap(default_policy=float32)
155+
156+
# Set policies using an exact path and a regex pattern.
157+
# Note: "decoder" will only match the exact path, not its children.
158+
>>> policy_map["encoder/layer_0/dense"] = bfloat16
159+
>>> policy_map["encoder/.*"] = float16
160+
>>> policy_map["decoder"] = bfloat16
161+
162+
# 1. An exact match is found and returned directly.
163+
>>> policy_map["encoder/layer_0/dense"].name
164+
'bfloat16'
165+
166+
# 2. A regex match is found for a child layer.
167+
# It matches the "encoder/.*" pattern.
168+
>>> policy_map["encoder/attention/query"].name
169+
'float16'
170+
171+
# 3. No implicit prefix matching occurs.
172+
# "decoder/attention" does not match the key "decoder".
173+
# The default policy is returned.
174+
>>> policy_map["decoder/attention"].name
175+
'float32'
176+
177+
# 4. A ValueError is raised if a path matches multiple patterns.
178+
>>> policy_map["encoder/attention/.*"] = bfloat16
179+
# "encoder/attention/query" now matches two patterns:
180+
# - "encoder/.*"
181+
# - "encoder/attention/.*"
182+
>>> try:
183+
... policy_map["encoder/attention/query"]
184+
... except ValueError as e:
185+
... print(e)
186+
Path 'encoder/attention/query' matches multiple dtype policy ..
187+
```
113188
"""
189+
# 1. Check for an exact match.
114190
if key in self._policy_map:
115191
return self._policy_map[key]
116192

117-
matching_keys = []
118-
for k in self._policy_map:
119-
if re.search(k, key):
120-
matching_keys.append(k)
193+
# 2. Fallback to a full regex match.
194+
matching_keys = [
195+
pattern
196+
for pattern in self._policy_map
197+
if re.fullmatch(pattern, key)
198+
]
199+
200+
# 3. Handle cases based on the number of matches found.
121201
if len(matching_keys) > 1:
122202
raise ValueError(
123203
f"Path '{key}' matches multiple dtype policy "
@@ -127,6 +207,8 @@ def __getitem__(self, key):
127207
)
128208
elif len(matching_keys) == 1:
129209
return self._policy_map[matching_keys[0]]
210+
211+
# 4. If there were no matches, return the default.
130212
return self.default_policy
131213

132214
def __setitem__(self, key, policy):

keras/src/dtype_policies/dtype_policy_map_test.py

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -124,50 +124,63 @@ def test_add(self):
124124
dtype_policy_map["layer/dense_3"] = 123
125125

126126
def test_get(self):
127-
dtype_policy_map = DTypePolicyMap()
128-
dtype_policy_map["layer/dense_0"] = dtype_policies.DTypePolicy(
129-
"bfloat16"
130-
)
131-
dtype_policy_map["layer/dense_1"] = dtype_policies.QuantizedDTypePolicy(
127+
# 1. Setup
128+
bfloat16_policy = dtype_policies.DTypePolicy("bfloat16")
129+
int8_policy = dtype_policies.QuantizedDTypePolicy(
132130
"int8", "mixed_bfloat16"
133131
)
134-
dtype_policy_map["layer/dense_2"] = (
135-
dtype_policies.QuantizedFloat8DTypePolicy("float8", "mixed_float16")
136-
)
132+
float32_policy = dtype_policies.DTypePolicy("float32")
133+
float16_policy = dtype_policies.DTypePolicy("float16")
137134

135+
policy_map = DTypePolicyMap()
136+
# Policy for an exact layer path
137+
policy_map["model/encoder/layer_0/dense"] = bfloat16_policy
138+
# Policy for a layer that is also a prefix of another layer's name
139+
policy_map["model/encoder/attention/query"] = int8_policy
140+
# Regex policies for entire scopes MUST include wildcards
141+
policy_map["model/decoder/.*"] = float32_policy
142+
policy_map["model/decoder/attention/.*"] = float16_policy
143+
144+
# 2. Test exact match
138145
self.assertEqual(
139-
dtype_policy_map["layer/dense_0"],
140-
dtype_policies.DTypePolicy("bfloat16"),
141-
)
142-
self.assertEqual(
143-
dtype_policy_map["layer/dense_1"],
144-
dtype_policies.QuantizedDTypePolicy("int8", "mixed_bfloat16"),
146+
policy_map["model/encoder/layer_0/dense"], bfloat16_policy
145147
)
146148
self.assertEqual(
147-
dtype_policy_map["layer/dense_2"],
148-
dtype_policies.QuantizedFloat8DTypePolicy(
149-
"float8", "mixed_float16"
150-
),
149+
policy_map["model/encoder/attention/query"], int8_policy
151150
)
152151

153-
self.assertNotEqual(
154-
dtype_policy_map["layer/dense_2"],
155-
dtype_policies.QuantizedFloat8DTypePolicy("float8", "bfloat16"),
152+
# 3. Test successful regex fallback (explicit wildcard)
153+
# "model/decoder/.*" should match its children.
154+
self.assertEqual(policy_map["model/decoder/layer_0"], float32_policy)
155+
156+
# 4. Test that partial matches are ignored
157+
# The exact key "model/encoder/attention/query" should not match
158+
# "model/encoder/attention/query_norm" without a wildcard.
159+
self.assertEqual(
160+
policy_map["model/encoder/attention/query_norm"],
161+
policy_map.default_policy,
156162
)
163+
# A plain key "model/decoder" will not match "model/decoder/layer_0"
164+
policy_map["model/decoder"] = bfloat16_policy # Add exact key
165+
self.assertEqual(policy_map["model/decoder/layer_0"], float32_policy)
166+
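# (Above: "model/decoder/layer_0" still matches the more general regex
# "model/decoder/.*"; below: the exact key "model/decoder" matches itself.)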
167+
self.assertEqual(policy_map["model/decoder"], bfloat16_policy)
157168

158-
# No hit
169+
# 5. Test no match
159170
self.assertEqual(
160-
dtype_policy_map["layer/batch_normalization"],
161-
dtype_policy_map.default_policy,
171+
policy_map["model/embedding"], policy_map.default_policy
162172
)
163173

164-
# It will cause a ValueError in the case of one-to-many.
165-
dtype_policy_map["dense"] = dtype_policies.DTypePolicy("float32")
166-
dtype_policy_map["dense_1"] = dtype_policies.DTypePolicy("float32")
174+
# 6. Test multiple regex matches causing a ValueError
175+
# "model/decoder/attention/output" matches two regex patterns:
176+
# - "model/decoder/.*"
177+
# - "model/decoder/attention/.*"
167178
with self.assertRaisesRegex(
168-
ValueError, "Path 'dense_10' matches multiple dtype policy"
179+
ValueError,
180+
"Path 'model/decoder/attention/output' matches multiple "
181+
"dtype policy",
169182
):
170-
dtype_policy_map["dense_10"]
183+
_ = policy_map["model/decoder/attention/output"]
171184

172185
def test_delete(self):
173186
dtype_policy_map = DTypePolicyMap()

0 commit comments

Comments
 (0)