diff --git a/facto/inputgen/argtuple/gen.py b/facto/inputgen/argtuple/gen.py index 811bb03..ac180cf 100644 --- a/facto/inputgen/argtuple/gen.py +++ b/facto/inputgen/argtuple/gen.py @@ -83,3 +83,51 @@ def gen( engine = MetaArgTupleEngine(self._modified_spec, out=out) for meta_tuple in engine.gen(valid=valid): yield self.gen_tuple(meta_tuple, out=out) + + def gen_errors( + self, op, *, valid: bool = True, out: bool = False + ) -> Generator[ + Tuple[List[Any], OrderedDict[str, Any], OrderedDict[str, Any]], Any, Any + ]: + """ + Generate input tuples and yield only those that don't behave as expected. + + This function takes the same signature as gen() but with an additional + op parameter. It filters inputs based on whether they behave as expected: + - When valid=True: yields inputs that should be valid but DO error + - When valid=False: yields inputs that should be invalid but DON'T error + + Args: + op: The operation/function to test the inputs against + valid: Whether to generate valid or invalid inputs (same as gen()) + out: Whether to include output arguments (same as gen()) + + Yields: + Tuples of (posargs, inkwargs, outargs) that don't behave as expected + """ + for posargs, inkwargs, outargs in self.gen(valid=valid, out=out): + try: + # Try to execute the operation with the generated inputs + if out: + # If there are output arguments, include them in the call + op(*posargs, **inkwargs, **outargs) + else: + # Otherwise, just call with positional and keyword arguments + op(*posargs, **inkwargs) + + # If execution succeeds: + if valid: + # When valid=True, we expect success, so this is NOT a bug + continue + else: + # When valid=False, we expect failure, so success IS a bug + yield posargs, inkwargs, outargs + + except Exception: + # If execution fails: + if valid: + # When valid=True, we expect success, so failure IS a bug + yield posargs, inkwargs, outargs + else: + # When valid=False, we expect failure, so this is NOT a bug + continue diff --git a/facto/utils.py b/facto/utils.py new file mode 100644 index 0000000..e3bf146 --- /dev/null +++ b/facto/utils.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def get_op_overload(op_name: str): + """ + Get the torch operation overload from an operation name. + + Args: + op_name: Operation name in the format "op_base.overload" (e.g., "add.Tensor") + + Returns: + The torch operation overload (e.g., torch.ops.aten.add.Tensor) + + Raises: + AttributeError: If the operation is not found + ValueError: If the operation name format is invalid + """ + if "." not in op_name: + raise ValueError( + f"Operation name '{op_name}' must contain a '.' to separate base and overload" + ) + + parts = op_name.split(".") + if len(parts) != 2: + raise ValueError( + f"Operation name '{op_name}' must be in format 'op_base.overload'" + ) + + op_base, overload = parts + + # Get the operation from torch.ops.aten + if not hasattr(torch.ops.aten, op_base): + raise AttributeError(f"Operation base '{op_base}' not found in torch.ops.aten") + + op_obj = getattr(torch.ops.aten, op_base) + + if not hasattr(op_obj, overload): + raise AttributeError( + f"Overload '{overload}' not found for operation '{op_base}'" + ) + + return getattr(op_obj, overload) diff --git a/test/specdb/test_specdb.py b/test/specdb/test_specdb.py new file mode 100644 index 0000000..e697f5e --- /dev/null +++ b/test/specdb/test_specdb.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from facto.inputgen.argtuple.gen import ArgumentTupleGenerator +from facto.specdb.db import SpecDictDB +from facto.utils import get_op_overload + + +class TestSpecDBOperations(unittest.TestCase): + """Test class for validating all specs in SpecDB using gen_errors.""" + + def test_all_ops(self): + """ + Test all ops in SpecDB. + + This test iterates through all operations in SpecDB and calls + ArgumentTupleGenerator.gen_errors with valid=True, out=False + for each operation. Each operation is tested as a subtest. + """ + # Get all operation names from SpecDB + op_names = list(SpecDictDB.keys()) + + skip_ops = [ + "_native_batch_norm_legit_no_training.default", + "addmm.default", + "arange.default", + "arange.start_step", + "constant_pad_nd.default", + "reflection_pad1d.default", + "reflection_pad2d.default", + "reflection_pad3d.default", + "replication_pad1d.default", + "replication_pad2d.default", + "replication_pad3d.default", + "split_with_sizes_copy.default", + ] + + for op_name in op_names: + if op_name in skip_ops: + continue + with self.subTest(op=op_name): + try: + # Get the spec and operation + spec = SpecDictDB[op_name] + op = get_op_overload(op_name) + generator = ArgumentTupleGenerator(spec) + except Exception as e: + # If we can't resolve the operation or there's another issue, + # fail this subtest with a descriptive message + self.fail(f"Failed to test operation {op_name}: {e}") + + try: + errors = list(generator.gen_errors(op, valid=True, out=False)) + except Exception as e: + self.fail(f"Failed while testing operation {op_name}: {e}") + + if len(errors) > 0: + self.fail( + f"Found {len(errors)} errors for {op_name} with valid=True, out=False" + ) + + +if __name__ == "__main__": + unittest.main()