@@ -35,29 +35,54 @@ def get_config(self):
35
35
However, it is also possible to set a regex as the key. See the docstring of
36
36
`get` for more details.
37
37
38
- See below for a usage example. You can define the naming schema
39
- of the `DTypePolicy`, and then retrieve the corresponding `DTypePolicy`
40
- instance.
41
-
42
- ```python
43
- dtype_policy_map = DTypePolicyMap()
44
- dtype_policy_map["layer/dense_0"] = DTypePolicy("bfloat16")
45
- dtype_policy_map["layer/dense_1"] = QuantizedDTypePolicy("int8", "bfloat16")
46
-
47
- policy_0 = dtype_policy_map["layer/dense_0"]
48
- policy_1 = dtype_policy_map["layer/dense_1"]
49
- policy_2 = dtype_policy_map["layer/dense_2"] # No hit
50
- assert policy_0 == DTypePolicy("bfloat16")
51
- assert policy_1 == QuantizedDTypePolicy("int8", "bfloat16")
52
- assert policy_2 == keras.config.dtype_policy()
53
- ```
54
-
55
38
Args:
56
39
default_policy: An optional `DTypePolicy` instance specifying the
57
40
default dtype policy. If not specified, the value will default to
58
41
`keras.config.dtype_policy()`.
59
42
policy_map: An optional dict that maps string to `DTypePolicy`
60
43
instances. Defaults to `None`
44
+
45
+ Example:
46
+
47
+ ```python
48
+ >>> from keras.src import dtype_policies
49
+ >>> bfloat16 = dtype_policies.DTypePolicy("bfloat16")
50
+ >>> float16 = dtype_policies.DTypePolicy("float16")
51
+ >>> float32 = dtype_policies.DTypePolicy("float32")
52
+ >>> policy_map = DTypePolicyMap(default_policy=float32)
53
+
54
+ # Set policies using an exact path and a regex pattern.
55
+ # Note: "decoder" will only match the exact path, not its children.
56
+ >>> policy_map["encoder/layer_0/dense"] = bfloat16
57
+ >>> policy_map["encoder/.*"] = float16
58
+ >>> policy_map["decoder"] = bfloat16
59
+
60
+ # 1. An exact match is found and returned directly.
61
+ >>> policy_map["encoder/layer_0/dense"].name
62
+ 'bfloat16'
63
+
64
+ # 2. A regex match is found for a child layer.
65
+ # It matches the "encoder/.*" pattern.
66
+ >>> policy_map["encoder/attention/query"].name
67
+ 'float16'
68
+
69
+ # 3. No implicit prefix matching occurs.
70
+ # "decoder/attention" does not match the key "decoder".
71
+ # The default policy is returned.
72
+ >>> policy_map["decoder/attention"].name
73
+ 'float32'
74
+
75
+ # 4. A ValueError is raised if a path matches multiple patterns.
76
+ >>> policy_map["encoder/attention/.*"] = bfloat16
77
+ # "encoder/attention/query" now matches two patterns:
78
+ # - "encoder/.*"
79
+ # - "encoder/attention/.*"
80
+ >>> try:
81
+ ... policy_map["encoder/attention/query"]
82
+ ... except ValueError as e:
83
+ ... print(e)
84
+ Path 'encoder/attention/query' matches multiple dtype policy ..
85
+ ```
61
86
"""
62
87
63
88
def __init__ (self , default_policy = None , policy_map = None ):
@@ -100,24 +125,79 @@ def quantization_mode(self):
100
125
def __getitem__ (self , key ):
101
126
"""Retrieves the corresponding `DTypePolicy` by the string key.
102
127
103
- When there isn't an exact match, all the existing keys in the map
104
- will be treated as a regex and map against the input key again. When
105
- there are multiple matches for the regex, an `ValueError` will be
106
- raised. Returns `self.default_policy` if there isn't any match found.
128
+ This method first attempts an exact key match. If no exact match is
129
+ found, it treats all keys in the map as regular expression patterns
130
+ and uses `re.fullmatch` to find a policy.
131
+
132
+ For example, to apply a policy to all sublayers of an `encoder` block,
133
+ the key should be explicitly set to `"encoder/.*"`. A key of
134
+ `"encoder"` will only match the layer with that exact path.
107
135
108
136
Args:
109
- key: String key to query a `DTypePolicy`.
137
+ key: str. The key to query for a `DTypePolicy`.
110
138
111
139
Returns:
112
- Corresponding `DTypePolicy` based on the query.
140
+ The corresponding `DTypePolicy`. If no match is found, this method
141
+ returns `self.default_policy`.
142
+
143
+ Raises:
144
+ ValueError: If the `key` matches more than one regex pattern in the
145
+ map.
146
+
147
+ Example:
148
+
149
+ ```python
150
+ >>> from keras.src import dtype_policies
151
+ >>> bfloat16 = dtype_policies.DTypePolicy("bfloat16")
152
+ >>> float16 = dtype_policies.DTypePolicy("float16")
153
+ >>> float32 = dtype_policies.DTypePolicy("float32")
154
+ >>> policy_map = DTypePolicyMap(default_policy=float32)
155
+
156
+ # Set policies using an exact path and a regex pattern.
157
+ # Note: "decoder" will only match the exact path, not its children.
158
+ >>> policy_map["encoder/layer_0/dense"] = bfloat16
159
+ >>> policy_map["encoder/.*"] = float16
160
+ >>> policy_map["decoder"] = bfloat16
161
+
162
+ # 1. An exact match is found and returned directly.
163
+ >>> policy_map["encoder/layer_0/dense"].name
164
+ 'bfloat16'
165
+
166
+ # 2. A regex match is found for a child layer.
167
+ # It matches the "encoder/.*" pattern.
168
+ >>> policy_map["encoder/attention/query"].name
169
+ 'float16'
170
+
171
+ # 3. No implicit prefix matching occurs.
172
+ # "decoder/attention" does not match the key "decoder".
173
+ # The default policy is returned.
174
+ >>> policy_map["decoder/attention"].name
175
+ 'float32'
176
+
177
+ # 4. A ValueError is raised if a path matches multiple patterns.
178
+ >>> policy_map["encoder/attention/.*"] = bfloat16
179
+ # "encoder/attention/query" now matches two patterns:
180
+ # - "encoder/.*"
181
+ # - "encoder/attention/.*"
182
+ >>> try:
183
+ ... policy_map["encoder/attention/query"]
184
+ ... except ValueError as e:
185
+ ... print(e)
186
+ Path 'encoder/attention/query' matches multiple dtype policy ..
187
+ ```
113
188
"""
189
+ # 1. Check for an exact match.
114
190
if key in self ._policy_map :
115
191
return self ._policy_map [key ]
116
192
117
- matching_keys = []
118
- for k in self ._policy_map :
119
- if re .search (k , key ):
120
- matching_keys .append (k )
193
+ # 2. Fallback to a full regex match.
194
+ matching_keys = [
195
+ pattern
196
+ for pattern in self ._policy_map
197
+ if re .fullmatch (pattern , key )
198
+ ]
199
+
200
+ # 3. Handle cases based on the number of matches found.
121
201
if len (matching_keys ) > 1 :
122
202
raise ValueError (
123
203
f"Path '{ key } ' matches multiple dtype policy "
@@ -127,6 +207,8 @@ def __getitem__(self, key):
127
207
)
128
208
elif len (matching_keys ) == 1 :
129
209
return self ._policy_map [matching_keys [0 ]]
210
+
211
+ # 4. If there were no matches, return the default.
130
212
return self .default_policy
131
213
132
214
def __setitem__ (self , key , policy ):
0 commit comments