From 392ff7ca7539bb0c13ce8546c8b64be771136d2b Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 6 Aug 2025 16:26:34 -0400 Subject: [PATCH 1/2] Test all specs in SpecDB with valid=True, out=False --- facto/inputgen/argtuple/gen.py | 48 +++++++++++++++++++++++ facto/utils.py | 42 ++++++++++++++++++++ test/specdb/test_specdb.py | 70 ++++++++++++++++++++++++++++++++++ 3 files changed, 160 insertions(+) create mode 100644 facto/utils.py create mode 100644 test/specdb/test_specdb.py diff --git a/facto/inputgen/argtuple/gen.py b/facto/inputgen/argtuple/gen.py index 811bb03..ac180cf 100644 --- a/facto/inputgen/argtuple/gen.py +++ b/facto/inputgen/argtuple/gen.py @@ -83,3 +83,51 @@ def gen( engine = MetaArgTupleEngine(self._modified_spec, out=out) for meta_tuple in engine.gen(valid=valid): yield self.gen_tuple(meta_tuple, out=out) + + def gen_errors( + self, op, *, valid: bool = True, out: bool = False + ) -> Generator[ + Tuple[List[Any], OrderedDict[str, Any], OrderedDict[str, Any]], Any, Any + ]: + """ + Generate input tuples and yield only those that don't behave as expected. + + This function takes the same signature as gen() but with an additional + op parameter. It filters inputs based on whether they behave as expected: + - When valid=True: yields inputs that should be valid but DO error + - When valid=False: yields inputs that should be invalid but DON'T error + + Args: + op: The operation/function to test the inputs against + valid: Whether to generate valid or invalid inputs (same as gen()) + out: Whether to include output arguments (same as gen()) + + Yields: + Tuples of (posargs, inkwargs, outargs) that don't behave as expected + """ + for posargs, inkwargs, outargs in self.gen(valid=valid, out=out): + try: + # Try to execute the operation with the generated inputs + if out: + # If there are output arguments, include them in the call + op(*posargs, **inkwargs, **outargs) + else: + # Otherwise, just call with positional and keyword arguments + op(*posargs, **inkwargs) + + # If execution succeeds: + if valid: + # When valid=True, we expect success, so this is NOT a bug + continue + else: + # When valid=False, we expect failure, so success IS a bug + yield posargs, inkwargs, outargs + + except Exception: + # If execution fails: + if valid: + # When valid=True, we expect success, so failure IS a bug + yield posargs, inkwargs, outargs + else: + # When valid=False, we expect failure, so this is NOT a bug + continue diff --git a/facto/utils.py b/facto/utils.py new file mode 100644 index 0000000..661bcae --- /dev/null +++ b/facto/utils.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def get_op_overload(op_name: str): + """ + Get the torch operation overload from an operation name. + + Args: + op_name: Operation name in the format "op_base.overload" (e.g., "add.Tensor") + + Returns: + The torch operation overload (e.g., torch.ops.aten.add.Tensor) + + Raises: + AttributeError: If the operation is not found + ValueError: If the operation name format is invalid + """ + if "." not in op_name: + raise ValueError(f"Operation name '{op_name}' must contain a '.' to separate base and overload") + + parts = op_name.split(".") + if len(parts) != 2: + raise ValueError(f"Operation name '{op_name}' must be in format 'op_base.overload'") + + op_base, overload = parts + + # Get the operation from torch.ops.aten + if not hasattr(torch.ops.aten, op_base): + raise AttributeError(f"Operation base '{op_base}' not found in torch.ops.aten") + + op_obj = getattr(torch.ops.aten, op_base) + + if not hasattr(op_obj, overload): + raise AttributeError(f"Overload '{overload}' not found for operation '{op_base}'") + + return getattr(op_obj, overload) diff --git a/test/specdb/test_specdb.py b/test/specdb/test_specdb.py new file mode 100644 index 0000000..51be3ff --- /dev/null +++ b/test/specdb/test_specdb.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from facto.inputgen.argtuple.gen import ArgumentTupleGenerator +from facto.specdb.db import SpecDictDB +from facto.utils import get_op_overload + + +class TestSpecDBOperations(unittest.TestCase): + """Test class for validating all specs in SpecDB using gen_errors.""" + + def test_all_ops(self): + """ + Test all ops in SpecDB. + + This test iterates through all operations in SpecDB and calls + ArgumentTupleGenerator.gen_errors with valid=True, out=False + for each operation. Each operation is tested as a subtest. + """ + # Get all operation names from SpecDB + op_names = list(SpecDictDB.keys()) + + skip_ops = [ + '_native_batch_norm_legit_no_training.default', + 'addmm.default', + 'arange.default', + 'arange.start_step', + 'constant_pad_nd.default', + 'reflection_pad1d.default', + 'reflection_pad2d.default', + 'reflection_pad3d.default', + 'replication_pad1d.default', + 'replication_pad2d.default', + 'replication_pad3d.default', + 'split_with_sizes_copy.default', + ] + + for op_name in op_names: + if op_name in skip_ops: + continue + with self.subTest(op=op_name): + try: + # Get the spec and operation + spec = SpecDictDB[op_name] + op = get_op_overload(op_name) + generator = ArgumentTupleGenerator(spec) + except Exception as e: + # If we can't resolve the operation or there's another issue, + # fail this subtest with a descriptive message + self.fail(f"Failed to test operation {op_name}: {e}") + + try: + errors = list(generator.gen_errors(op, valid=True, out=False)) + except Exception as e: + self.fail(f"Failed while testing operation {op_name}: {e}") + + if len(errors) > 0: + self.fail(f"Found {len(errors)} errors for {op_name} with valid=True, out=False") + + + + +if __name__ == "__main__": + unittest.main() From f6075260524ec4fc30a8ea2fda15357fdd685988 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 6 Aug 2025 17:53:08 -0400 Subject: [PATCH 2/2] fix linter --- facto/utils.py | 12 +++++++++--- test/specdb/test_specdb.py | 30 +++++++++++++++--------------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/facto/utils.py b/facto/utils.py index 661bcae..e3bf146 100644 --- a/facto/utils.py +++ b/facto/utils.py @@ -22,11 +22,15 @@ def get_op_overload(op_name: str): ValueError: If the operation name format is invalid """ if "." not in op_name: - raise ValueError(f"Operation name '{op_name}' must contain a '.' to separate base and overload") + raise ValueError( + f"Operation name '{op_name}' must contain a '.' to separate base and overload" + ) parts = op_name.split(".") if len(parts) != 2: - raise ValueError(f"Operation name '{op_name}' must be in format 'op_base.overload'") + raise ValueError( + f"Operation name '{op_name}' must be in format 'op_base.overload'" + ) op_base, overload = parts @@ -37,6 +41,8 @@ def get_op_overload(op_name: str): op_obj = getattr(torch.ops.aten, op_base) if not hasattr(op_obj, overload): - raise AttributeError(f"Overload '{overload}' not found for operation '{op_base}'") + raise AttributeError( + f"Overload '{overload}' not found for operation '{op_base}'" + ) return getattr(op_obj, overload) diff --git a/test/specdb/test_specdb.py b/test/specdb/test_specdb.py index 51be3ff..e697f5e 100644 --- a/test/specdb/test_specdb.py +++ b/test/specdb/test_specdb.py @@ -27,18 +27,18 @@ def test_all_ops(self): op_names = list(SpecDictDB.keys()) skip_ops = [ - '_native_batch_norm_legit_no_training.default', - 'addmm.default', - 'arange.default', - 'arange.start_step', - 'constant_pad_nd.default', - 'reflection_pad1d.default', - 'reflection_pad2d.default', - 'reflection_pad3d.default', - 'replication_pad1d.default', - 'replication_pad2d.default', - 'replication_pad3d.default', - 'split_with_sizes_copy.default', + "_native_batch_norm_legit_no_training.default", + "addmm.default", + "arange.default", + "arange.start_step", + "constant_pad_nd.default", + "reflection_pad1d.default", + "reflection_pad2d.default", + "reflection_pad3d.default", + "replication_pad1d.default", + "replication_pad2d.default", + "replication_pad3d.default", + "split_with_sizes_copy.default", ] for op_name in op_names: @@ -61,9 +61,9 @@ def test_all_ops(self): self.fail(f"Failed while testing operation {op_name}: {e}") if len(errors) > 0: - self.fail(f"Found {len(errors)} errors for {op_name} with valid=True, out=False") - - + self.fail( + f"Found {len(errors)} errors for {op_name} with valid=True, out=False" + ) if __name__ == "__main__":