Skip to content

Test all specs in SpecDB for valid inputs #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions facto/inputgen/argtuple/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,51 @@ def gen(
engine = MetaArgTupleEngine(self._modified_spec, out=out)
for meta_tuple in engine.gen(valid=valid):
yield self.gen_tuple(meta_tuple, out=out)

def gen_errors(
self, op, *, valid: bool = True, out: bool = False
) -> Generator[
Tuple[List[Any], OrderedDict[str, Any], OrderedDict[str, Any]], Any, Any
]:
"""
Generate input tuples and yield only those that don't behave as expected.

This function takes the same signature as gen() but with an additional
op parameter. It filters inputs based on whether they behave as expected:
- When valid=True: yields inputs that should be valid but DO error
- When valid=False: yields inputs that should be invalid but DON'T error

Args:
op: The operation/function to test the inputs against
valid: Whether to generate valid or invalid inputs (same as gen())
out: Whether to include output arguments (same as gen())

Yields:
Tuples of (posargs, inkwargs, outargs) that don't behave as expected
"""
for posargs, inkwargs, outargs in self.gen(valid=valid, out=out):
try:
# Try to execute the operation with the generated inputs
if out:
# If there are output arguments, include them in the call
op(*posargs, **inkwargs, **outargs)
else:
# Otherwise, just call with positional and keyword arguments
op(*posargs, **inkwargs)

# If execution succeeds:
if valid:
# When valid=True, we expect success, so this is NOT a bug
continue
else:
# When valid=False, we expect failure, so success IS a bug
yield posargs, inkwargs, outargs

except Exception:
# If execution fails:
if valid:
# When valid=True, we expect success, so failure IS a bug
yield posargs, inkwargs, outargs
else:
# When valid=False, we expect failure, so this is NOT a bug
continue
48 changes: 48 additions & 0 deletions facto/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch


def get_op_overload(op_name: str):
"""
Get the torch operation overload from an operation name.

Args:
op_name: Operation name in the format "op_base.overload" (e.g., "add.Tensor")

Returns:
The torch operation overload (e.g., torch.ops.aten.add.Tensor)

Raises:
AttributeError: If the operation is not found
ValueError: If the operation name format is invalid
"""
if "." not in op_name:
raise ValueError(
f"Operation name '{op_name}' must contain a '.' to separate base and overload"
)

parts = op_name.split(".")
if len(parts) != 2:
raise ValueError(
f"Operation name '{op_name}' must be in format 'op_base.overload'"
)

op_base, overload = parts

# Get the operation from torch.ops.aten
if not hasattr(torch.ops.aten, op_base):
raise AttributeError(f"Operation base '{op_base}' not found in torch.ops.aten")

op_obj = getattr(torch.ops.aten, op_base)

if not hasattr(op_obj, overload):
raise AttributeError(
f"Overload '{overload}' not found for operation '{op_base}'"
)

return getattr(op_obj, overload)
70 changes: 70 additions & 0 deletions test/specdb/test_specdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
from facto.specdb.db import SpecDictDB
from facto.utils import get_op_overload


class TestSpecDBOperations(unittest.TestCase):
"""Test class for validating all specs in SpecDB using gen_errors."""

def test_all_ops(self):
"""
Test all ops in SpecDB.

This test iterates through all operations in SpecDB and calls
ArgumentTupleGenerator.gen_errors with valid=True, out=False
for each operation. Each operation is tested as a subtest.
"""
# Get all operation names from SpecDB
op_names = list(SpecDictDB.keys())

skip_ops = [
"_native_batch_norm_legit_no_training.default",
"addmm.default",
"arange.default",
"arange.start_step",
"constant_pad_nd.default",
"reflection_pad1d.default",
"reflection_pad2d.default",
"reflection_pad3d.default",
"replication_pad1d.default",
"replication_pad2d.default",
"replication_pad3d.default",
"split_with_sizes_copy.default",
]

for op_name in op_names:
if op_name in skip_ops:
continue
with self.subTest(op=op_name):
try:
# Get the spec and operation
spec = SpecDictDB[op_name]
op = get_op_overload(op_name)
generator = ArgumentTupleGenerator(spec)
except Exception as e:
# If we can't resolve the operation or there's another issue,
# fail this subtest with a descriptive message
self.fail(f"Failed to test operation {op_name}: {e}")

try:
errors = list(generator.gen_errors(op, valid=True, out=False))
except Exception as e:
self.fail(f"Failed while testing operation {op_name}: {e}")

if len(errors) > 0:
self.fail(
f"Found {len(errors)} errors for {op_name} with valid=True, out=False"
)


if __name__ == "__main__":
unittest.main()