Skip to content

Commit 312bf4c

Browse files
committed
Add test cases
1 parent 894857e commit 312bf4c

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

test/sparsity/test_sparse_api.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,33 @@ def test_sparse(self, compile):
267267

268268
torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1)
269269

270+
271+
#TODO: Remove this test once the deprecated API has been removed
272+
def test_sparse_deprecated(self):
273+
import sys
274+
import warnings
275+
276+
# We need to clear the cache to force re-importing and trigger the warning again.
277+
modules_to_clear = [
278+
'torchao.dtypes.uintx.block_sparse_layout',
279+
'torchao.dtypes',
280+
]
281+
for mod in modules_to_clear:
282+
if mod in sys.modules:
283+
del sys.modules[mod]
284+
285+
with warnings.catch_warnings(record=True) as w:
286+
warnings.simplefilter("always") # Ensure all warnings are captured
287+
from torchao.dtypes import BlockSparseLayout
288+
self.assertTrue(
289+
any(
290+
issubclass(warning.category, DeprecationWarning)
291+
and "BlockSparseLayout" in str(warning.message)
292+
for warning in w
293+
),
294+
f"Expected deprecation warning for BlockSparseLayout, got: {[str(w.message) for w in w]}"
295+
)
296+
270297

271298
common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse)
272299
common_utils.instantiate_parametrized_tests(TestQuantSemiSparse)

0 commit comments

Comments
 (0)