@@ -135,49 +135,47 @@ def test_get(self):
135
135
policy_map = DTypePolicyMap ()
136
136
# Policy for an exact layer path
137
137
policy_map ["model/encoder/layer_0/dense" ] = bfloat16_policy
138
- # Policy that could be partially matched
138
+ # Policy for a layer that is also a prefix of another layer's name
139
139
policy_map ["model/encoder/attention/query" ] = int8_policy
140
- # Regex policies for entire scopes
141
- policy_map ["model/decoder" ] = float32_policy
142
- policy_map ["model/decoder/attention" ] = float16_policy
140
+ # Regex policies for entire scopes MUST include wildcards
141
+ policy_map ["model/decoder/.* " ] = float32_policy
142
+ policy_map ["model/decoder/attention/.* " ] = float16_policy
143
143
144
144
# 2. Test exact match
145
- # An exact key should always return the correct policy.
146
145
self .assertEqual (
147
146
policy_map ["model/encoder/layer_0/dense" ], bfloat16_policy
148
147
)
149
148
self .assertEqual (
150
149
policy_map ["model/encoder/attention/query" ], int8_policy
151
150
)
152
151
153
- # 3. Test successful regex fallback (component-wise match)
154
- # "model/decoder" should match "model/decoder/layer_0" because
155
- # it's a full component prefix.
152
+ # 3. Test successful regex fallback (explicit wildcard)
153
+ # "model/decoder/.*" should match its children.
156
154
self .assertEqual (policy_map ["model/decoder/layer_0" ], float32_policy )
157
155
158
- # 4. Test prevention of partial regex match
159
- # "model/encoder/attention/query" should NOT match
160
- # "model/encoder/attention/query_norm"
161
- # as it's not followed by a '/' or the end of the string.
156
+ # 4. Test that partial matches are ignored
157
+ # The exact key "model/encoder/attention/query" should not match
158
+ # "model/encoder/attention/query_norm" without a wildcard.
162
159
self .assertEqual (
163
160
policy_map ["model/encoder/attention/query_norm" ],
164
161
policy_map .default_policy ,
165
162
)
163
+ # A plain key "model/decoder" will not match "model/decoder/layer_0"
164
+ policy_map ["model/decoder" ] = bfloat16_policy # Add exact key
165
+ self .assertEqual (
166
+ policy_map ["model/decoder/layer_0" ], float32_policy
167
+ ) # Still matches the more general regex
168
+ self .assertEqual (policy_map ["model/decoder" ], bfloat16_policy )
166
169
167
170
# 5. Test no match
168
- # A key with no exact or valid regex match should return the default.
169
171
self .assertEqual (
170
172
policy_map ["model/embedding" ], policy_map .default_policy
171
173
)
172
- self .assertEqual (
173
- policy_map ["prefix/model/decoder/attention" ],
174
- policy_map .default_policy ,
175
- )
176
174
177
175
# 6. Test multiple regex matches causing a ValueError
178
- # The path "model/decoder/attention/output" matches two regex keys :
179
- # - "model/decoder"
180
- # - "model/decoder/attention"
176
+ # "model/decoder/attention/output" matches two regex patterns :
177
+ # - "model/decoder/.* "
178
+ # - "model/decoder/attention/.* "
181
179
with self .assertRaisesRegex (
182
180
ValueError ,
183
181
"Path 'model/decoder/attention/output' matches multiple "
0 commit comments