Skip to content

Commit fe35e2f

Browse files
dplassgitcopybara-github
authored andcommitted
[DSLX Fuzz testing] Parse the #[fuzz_test] domains as a XlsTuple
(instead of an `Expr`) and store it in the `FuzzTestFunction` AST node. This required modifying the parser to be more "reentrant" by passing in an existing module istead of creating a new one. PiperOrigin-RevId: 896603714
1 parent bb45011 commit fe35e2f

File tree

18 files changed

+272
-52
lines changed

18 files changed

+272
-52
lines changed

xls/dslx/fmt/ast_fmt.cc

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3064,32 +3064,35 @@ DocRef Formatter::Format(const ProcAlias& n) {
30643064
}
30653065

30663066
DocRef Formatter::Format(const ModuleMember& n) {
3067-
return absl::visit(Visitor{
3068-
[&](const Function* n) { return Format(*n); },
3069-
[&](const Proc* n) { return Format(*n); },
3070-
[&](const TestFunction* n) { return Format(*n); },
3071-
[&](const TestProc* n) { return Format(*n); },
3072-
[&](const QuickCheck* n) { return Format(*n); },
3073-
[&](const TypeAlias* n) {
3074-
return arena_.MakeConcat(Format(*n), arena_.semi());
3075-
},
3076-
[&](const ProcAlias* n) {
3077-
return arena_.MakeConcat(Format(*n), arena_.semi());
3078-
},
3079-
[&](const StructDef* n) { return Format(*n); },
3080-
[&](const ProcDef* n) { return Format(*n); },
3081-
[&](const Impl* n) { return Format(*n); },
3082-
[&](const Trait* n) { return Format(*n); },
3083-
[&](const ConstantDef* n) { return Format(*n); },
3084-
[&](const EnumDef* n) { return Format(*n); },
3085-
[&](const Import* n) { return Format(*n); },
3086-
[&](const Use* n) { return Format(*n); },
3087-
[&](const ConstAssert* n) {
3088-
return arena_.MakeConcat(Format(*n), arena_.semi());
3089-
},
3090-
[&](const VerbatimNode* n) { return Format(*n); },
3091-
},
3092-
n);
3067+
return absl::visit(
3068+
Visitor{
3069+
[&](const Function* n) { return Format(*n); },
3070+
[&](const Proc* n) { return Format(*n); },
3071+
[&](const TestFunction* n) { return Format(*n); },
3072+
// TODO: davidplass - Add formatting for FuzzTestFunction.
3073+
[&](const FuzzTestFunction* n) { return Format(n->fn()); },
3074+
[&](const TestProc* n) { return Format(*n); },
3075+
[&](const QuickCheck* n) { return Format(*n); },
3076+
[&](const TypeAlias* n) {
3077+
return arena_.MakeConcat(Format(*n), arena_.semi());
3078+
},
3079+
[&](const ProcAlias* n) {
3080+
return arena_.MakeConcat(Format(*n), arena_.semi());
3081+
},
3082+
[&](const StructDef* n) { return Format(*n); },
3083+
[&](const ProcDef* n) { return Format(*n); },
3084+
[&](const Impl* n) { return Format(*n); },
3085+
[&](const Trait* n) { return Format(*n); },
3086+
[&](const ConstantDef* n) { return Format(*n); },
3087+
[&](const EnumDef* n) { return Format(*n); },
3088+
[&](const Import* n) { return Format(*n); },
3089+
[&](const Use* n) { return Format(*n); },
3090+
[&](const ConstAssert* n) {
3091+
return arena_.MakeConcat(Format(*n), arena_.semi());
3092+
},
3093+
[&](const VerbatimNode* n) { return Format(*n); },
3094+
},
3095+
n);
30933096
}
30943097

30953098
// Returns whether the given members are of the given "MemberT" and "grouped" --

xls/dslx/frontend/ast.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,8 @@ std::string_view AstNodeKindToString(AstNodeKind kind) {
311311
return "for";
312312
case AstNodeKind::kFunctionRef:
313313
return "function-ref";
314+
case AstNodeKind::kFuzzTestFunction:
315+
return "fuzz test function";
314316
case AstNodeKind::kStatementBlock:
315317
return "statement-block";
316318
case AstNodeKind::kTrait:
@@ -2419,6 +2421,25 @@ Function::LambdaReturnTypeParametrics() const {
24192421

24202422
TestFunction::~TestFunction() = default;
24212423

2424+
// -- class FuzzTestFunction
2425+
2426+
FuzzTestFunction::~FuzzTestFunction() = default;
2427+
2428+
std::string FuzzTestFunction::ToString() const {
2429+
if (domains_.has_value()) {
2430+
return absl::StrFormat("#[fuzz_test(domains=`%s`)]\n%s",
2431+
(*domains_)->ToString(), fn_.ToString());
2432+
}
2433+
return absl::StrFormat("#[fuzz_test]\n%s", fn_.ToString());
2434+
}
2435+
2436+
std::vector<AstNode*> FuzzTestFunction::GetChildren(bool want_types) const {
2437+
if (domains_.has_value()) {
2438+
return {&fn_, *domains_};
2439+
}
2440+
return {&fn_};
2441+
}
2442+
24222443
// -- class Lambda
24232444

24242445
Lambda::Lambda(Module* owner, Span span, Function* function, bool in_parens)

xls/dslx/frontend/ast.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494
X(ConstantDef) \
9595
X(EnumDef) \
9696
X(Function) \
97+
X(FuzzTestFunction) \
9798
X(Impl) \
9899
X(Import) \
99100
X(Let) \
@@ -3843,6 +3844,8 @@ class Range : public Expr {
38433844
// #[test]
38443845
// fn test_foo() { ... }
38453846
// ```
3847+
// TODO: davidplass - Add an optional parameter Expr to allow re-using as a
3848+
// FuzzTestFunction or QuickCheck.
38463849
class TestFunction : public AstNode {
38473850
public:
38483851
static std::string_view GetDebugTypeName() { return "test function"; }
@@ -3879,6 +3882,64 @@ class TestFunction : public AstNode {
38793882
Function& fn_;
38803883
};
38813884

3885+
// Represents a fuzz test construct.
3886+
//
3887+
// These are specified with an annotation as follows:
3888+
//
3889+
// ```dslx
3890+
// #[fuzz_test]
3891+
// fn test_foo() { ... }
3892+
// ```
3893+
//
3894+
// or with a domains argument:
3895+
//
3896+
// ```dslx
3897+
// #[fuzz_test(domains=`...`)]
3898+
// fn test_foo() { ... }
3899+
// ```
3900+
// TODO: davidplass - Migrate this into TestFunction, after adding an optional
3901+
// parameter to TestFunction to capture the domain tuple.
3902+
class FuzzTestFunction : public AstNode {
3903+
public:
3904+
static std::string_view GetDebugTypeName() { return "fuzz test function"; }
3905+
3906+
FuzzTestFunction(Module* owner, Span span, Function& fn,
3907+
std::optional<XlsTuple*> domains)
3908+
: AstNode(owner),
3909+
span_(std::move(span)),
3910+
fn_(fn),
3911+
domains_(std::move(domains)) {}
3912+
3913+
~FuzzTestFunction() override;
3914+
3915+
AstNodeKind kind() const override { return AstNodeKind::kFuzzTestFunction; }
3916+
NameDef* name_def() const { return fn_.name_def(); }
3917+
3918+
absl::Status Accept(AstNodeVisitor* v) const override {
3919+
return v->HandleFuzzTestFunction(this);
3920+
}
3921+
3922+
std::vector<AstNode*> GetChildren(bool want_types) const override;
3923+
3924+
std::string_view GetNodeTypeName() const override {
3925+
return "FuzzTestFunction";
3926+
}
3927+
std::string ToString() const override;
3928+
3929+
Function& fn() const { return fn_; }
3930+
std::optional<Span> GetSpan() const override { return span(); }
3931+
const Span& span() const { return span_; }
3932+
3933+
const std::string& identifier() const { return fn_.name_def()->identifier(); }
3934+
3935+
const std::optional<XlsTuple*>& domains() const { return domains_; }
3936+
3937+
private:
3938+
const Span span_;
3939+
Function& fn_;
3940+
const std::optional<XlsTuple*> domains_;
3941+
};
3942+
38823943
enum class QuickCheckTestCasesTag {
38833944
kExhaustive,
38843945
kCounted,
@@ -3914,6 +3975,8 @@ class QuickCheckTestCases {
39143975
};
39153976

39163977
// Represents a function to be quick-check'd.
3978+
// TODO: davidplass - Migrate this into TestFunction, since QuickCheck is an
3979+
// attribute which can capture the quick check test cases.
39173980
class QuickCheck : public AstNode {
39183981
public:
39193982
static std::string_view GetDebugTypeName() { return "quickcheck"; }

xls/dslx/frontend/ast_cloner.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,22 @@ class AstCloner : public AstNodeVisitor {
10471047
return absl::OkStatus();
10481048
}
10491049

1050+
absl::Status HandleFuzzTestFunction(const FuzzTestFunction* n) override {
1051+
XLS_RETURN_IF_ERROR(VisitChildren(n));
1052+
1053+
XLS_RETURN_IF_ERROR(ReplaceOrVisit(&n->fn()));
1054+
XLS_ASSIGN_OR_RETURN(Function * new_fn, CastIfNotVerbatim<Function*>(
1055+
old_to_new_.at(&n->fn())));
1056+
std::optional<XlsTuple*> new_domains =
1057+
n->domains().has_value()
1058+
? std::make_optional(absl::down_cast<XlsTuple*>(
1059+
old_to_new_.at(n->domains().value())))
1060+
: std::nullopt;
1061+
old_to_new_[n] =
1062+
module(n)->Make<FuzzTestFunction>(n->span(), *new_fn, new_domains);
1063+
return absl::OkStatus();
1064+
}
1065+
10501066
absl::Status HandleTestProc(const TestProc* n) override {
10511067
XLS_RETURN_IF_ERROR(VisitChildren(n));
10521068

xls/dslx/frontend/ast_node.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ enum class AstNodeKind : uint8_t {
5252
kFormatMacro,
5353
kFunction,
5454
kFunctionRef,
55+
kFuzzTestFunction,
5556
kImpl,
5657
kImport,
5758
kIndex,

xls/dslx/frontend/module.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ std::string_view GetModuleMemberTypeName(const ModuleMember& module_member) {
421421
[](Function*) { return "function"; },
422422
[](Proc*) { return "proc"; },
423423
[](TestFunction*) { return "test-function"; },
424+
[](FuzzTestFunction*) { return "fuzz-test-function"; },
424425
[](TestProc*) { return "test-proc"; },
425426
[](QuickCheck*) { return "quick-check"; },
426427
[](TypeAlias*) { return "type-alias"; },
@@ -443,6 +444,7 @@ bool IsPublic(const ModuleMember& member) {
443444
return absl::visit(Visitor{
444445
[](const auto* m) { return m->is_public(); },
445446
[](const TestFunction* m) { return false; },
447+
[](const FuzzTestFunction* m) { return false; },
446448
[](const TestProc* m) { return false; },
447449
[](const QuickCheck* m) { return false; },
448450
[](const Import* m) { return false; },

xls/dslx/frontend/module.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ using ModuleMember =
4747
std::variant<Function*, Proc*, TestFunction*, TestProc*, QuickCheck*,
4848
TypeAlias*, StructDef*, ProcAlias*, ProcDef*, ConstantDef*,
4949
EnumDef*, Import*, Use*, ConstAssert*, Impl*, Trait*,
50-
VerbatimNode*>;
50+
VerbatimNode*, FuzzTestFunction*>;
5151

5252
// Returns all the NameDefs defined by the given module member.
5353
//

xls/dslx/frontend/parser.cc

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -535,8 +535,9 @@ absl::StatusOr<std::unique_ptr<Module>> Parser::ParseModule(
535535
XLS_ASSIGN_OR_RETURN(Function * fn,
536536
ParseFunction(*module_member_start_pos, is_public,
537537
*bindings, &name_to_fn));
538-
XLS_ASSIGN_OR_RETURN(ModuleMember fn_or_wrapper,
539-
ApplyFunctionAttributes(fn, pending_attributes));
538+
XLS_ASSIGN_OR_RETURN(
539+
ModuleMember fn_or_wrapper,
540+
ApplyFunctionAttributes(fn, pending_attributes, *bindings));
540541
XLS_RETURN_IF_ERROR(
541542
module_->AddTop(fn_or_wrapper, make_collision_error));
542543
break;
@@ -653,11 +654,11 @@ absl::StatusOr<std::unique_ptr<Module>> Parser::ParseModule(
653654
// post-condition.
654655
XLS_RET_CHECK(AtEof());
655656

656-
XLS_RETURN_IF_ERROR(VerifyParentage(module_.get()));
657+
XLS_RETURN_IF_ERROR(VerifyParentage(module_));
657658

658659
module_->set_span(Span(module_start_pos, GetPos()));
659660

660-
auto result = std::move(module_);
661+
auto result = std::move(owned_module_);
661662
module_ = nullptr;
662663
return result;
663664
}
@@ -1008,8 +1009,10 @@ absl::Status Parser::ApplyExternVerilogAttribute(Function* fn,
10081009
}
10091010

10101011
absl::StatusOr<ModuleMember> Parser::ApplyFunctionAttributes(
1011-
Function* fn, std::vector<Attribute*> attributes) {
1012+
Function* fn, std::vector<Attribute*> attributes, Bindings& bindings) {
10121013
bool is_test = false;
1014+
bool is_fuzz_test = false;
1015+
Attribute* fuzz_test_attribute = nullptr;
10131016
std::optional<QuickCheckTestCases> quickcheck_test_cases;
10141017
std::vector<std::string> test_attributes;
10151018

@@ -1035,7 +1038,9 @@ absl::StatusOr<ModuleMember> Parser::ApplyFunctionAttributes(
10351038
is_test = true;
10361039
break;
10371040
case AttributeKind::kFuzzTest:
1041+
is_fuzz_test = true;
10381042
XLS_RETURN_IF_ERROR(ValidateFuzzTestAttribute(*next));
1043+
fuzz_test_attribute = next;
10391044
test_attributes.push_back(next->ToString());
10401045
break;
10411046

@@ -1072,6 +1077,22 @@ absl::StatusOr<ModuleMember> Parser::ApplyFunctionAttributes(
10721077
tf->SetParentage(); // Ensure the function has its parent marked.
10731078
return tf;
10741079
}
1080+
if (is_fuzz_test) {
1081+
Span ft_span(fuzz_test_attribute->GetSpan()->start(), fn->span().limit());
1082+
std::optional<XlsTuple*> domains = std::nullopt;
1083+
for (const AttributeData::Argument& arg : fuzz_test_attribute->args()) {
1084+
if (auto* kv = std::get_if<AttributeData::StringKeyValueArgument>(&arg)) {
1085+
if (kv->first == "domains" && kv->is_backticked) {
1086+
XLS_ASSIGN_OR_RETURN(domains, ParseDomains(kv->second, bindings));
1087+
break;
1088+
}
1089+
}
1090+
}
1091+
FuzzTestFunction* ft =
1092+
module_->Make<FuzzTestFunction>(ft_span, *fn, domains);
1093+
ft->SetParentage(); // Ensure the function has its parent marked.
1094+
return ft;
1095+
}
10751096

10761097
if (quickcheck_test_cases.has_value()) {
10771098
const Span quickcheck_span(attributes[0]->GetSpan()->start(),
@@ -1083,6 +1104,22 @@ absl::StatusOr<ModuleMember> Parser::ApplyFunctionAttributes(
10831104
return fn;
10841105
}
10851106

1107+
absl::StatusOr<XlsTuple*> Parser::ParseDomains(std::string_view domains_str,
1108+
Bindings& bindings) {
1109+
std::string wrapped = absl::StrCat("(", domains_str, ")");
1110+
Scanner domain_scanner(file_table(), scanner().fileno(), wrapped);
1111+
Parser sub_parser(module_, &domain_scanner, parse_fn_stubs_);
1112+
1113+
XLS_ASSIGN_OR_RETURN(Expr * parsed, sub_parser.ParseExpression(bindings));
1114+
1115+
if (parsed->kind() == AstNodeKind::kXlsTuple) {
1116+
return dynamic_cast<XlsTuple*>(parsed);
1117+
}
1118+
// If it's not already a tuple, wrap it in one.
1119+
return module_->Make<XlsTuple>(parsed->span(), std::vector<Expr*>{parsed},
1120+
/*has_trailing_comma=*/false);
1121+
}
1122+
10861123
template <typename T>
10871124
absl::Status Parser::ApplyTypeAttributes(T* node,
10881125
std::vector<Attribute*> attributes) {
@@ -3105,12 +3142,12 @@ absl::StatusOr<Spawn*> Parser::ParseSpawn(Bindings& bindings) {
31053142
absl::StrCat(colon_ref->attr(), ".config"));
31063143

31073144
ColonRef::Subject clone_subject =
3108-
CloneSubject(module_.get(), colon_ref->subject());
3145+
CloneSubject(module_, colon_ref->subject());
31093146
next_ref =
31103147
module_->Make<ColonRef>(colon_ref->span(), clone_subject,
31113148
absl::StrCat(colon_ref->attr(), ".next"));
31123149

3113-
clone_subject = CloneSubject(module_.get(), colon_ref->subject());
3150+
clone_subject = CloneSubject(module_, colon_ref->subject());
31143151
init_ref =
31153152
module_->Make<ColonRef>(colon_ref->span(), clone_subject,
31163153
absl::StrCat(colon_ref->attr(), ".init"));
@@ -3693,11 +3730,10 @@ absl::StatusOr<ModuleMember> Parser::ParseProcLike(const Pos& start_pos,
36933730
"Impl-style procs must use commas to separate members.");
36943731
}
36953732
// Assume this is an impl-style proc and return a `ProcDef` for it.
3696-
ProcDef* proc_def =
3697-
module_->Make<ProcDef>(span, name_def, std::move(parametric_bindings),
3698-
ConvertProcMembersToStructMembers(
3699-
module_.get(), proc_like_body.members),
3700-
is_public);
3733+
ProcDef* proc_def = module_->Make<ProcDef>(
3734+
span, name_def, std::move(parametric_bindings),
3735+
ConvertProcMembersToStructMembers(module_, proc_like_body.members),
3736+
is_public);
37013737
outer_bindings.Add(name_def->identifier(), proc_def);
37023738
return proc_def;
37033739
}

0 commit comments

Comments
 (0)