Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions xls/dslx/frontend/ast_cloner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,60 @@ class AstCloner : public AstNodeVisitor {
absl::flat_hash_map<const AstNode*, AstNode*> old_to_new_;
};

absl::StatusOr<ModuleMember> MakeClonedModuleMember(
const ModuleMember& original,
const absl::flat_hash_map<const AstNode*, AstNode*>& old_to_new) {
ModuleMember new_member;
XLS_RETURN_IF_ERROR(absl::visit(
Visitor{
[&](auto* node) -> absl::Status {
auto it = old_to_new.find(node);
if (it == old_to_new.end()) {
return absl::InternalError("Cloned node missing from map.");
}
AstNode* replacement = it->second;
if (replacement->kind() == AstNodeKind::kVerbatimNode) {
new_member = down_cast<VerbatimNode*>(replacement);
return absl::OkStatus();
}
using NodeT = std::remove_pointer_t<decltype(node)>;
new_member = down_cast<NodeT*>(replacement);
return absl::OkStatus();
},
},
original));
return new_member;
}

bool ShouldSkipMember(
const ModuleMember& member,
const absl::flat_hash_set<const AstNode*>& nodes_to_ignore) {
if (nodes_to_ignore.empty()) {
return false;
}
const AstNode* member_node = ToAstNode(member);
if (nodes_to_ignore.contains(member_node)) {
return true;
}
for (NameDef* name_def : ModuleMemberGetNameDefs(member)) {
if (nodes_to_ignore.contains(static_cast<const AstNode*>(name_def))) {
return true;
}
}
return false;
}

void CollectUseTreeEntries(const UseTreeEntry& entry,
absl::flat_hash_set<const AstNode*>& ignored_nodes) {
ignored_nodes.insert(&entry);
for (AstNode* child : entry.GetChildren(/*want_types=*/false)) {
if (child == nullptr || child->kind() != AstNodeKind::kUseTreeEntry) {
continue;
}
CollectUseTreeEntries(*down_cast<UseTreeEntry*>(child), ignored_nodes);
}
}

} // namespace

std::optional<AstNode*> PreserveTypeDefinitionsReplacer(
Expand Down Expand Up @@ -1335,6 +1389,106 @@ absl::StatusOr<std::unique_ptr<Module>> CloneModule(const Module& module,
return new_module;
}

absl::StatusOr<std::unique_ptr<Module>> CloneModuleIgnoreMembers(
const Module& module, absl::Span<const AstNode* const> members_to_ignore) {
absl::flat_hash_set<const AstNode*> nodes_to_ignore;
nodes_to_ignore.reserve(members_to_ignore.size());
for (const AstNode* node : members_to_ignore) {
if (node != nullptr) {
nodes_to_ignore.insert(node);
}
}

absl::flat_hash_map<const NameDef*, std::string> ignored_name_defs;
absl::flat_hash_set<const AstNode*> ignored_nodes(nodes_to_ignore.begin(),
nodes_to_ignore.end());
for (const ModuleMember& member : module.top()) {
if (!ShouldSkipMember(member, nodes_to_ignore)) {
continue;
}
const AstNode* member_node = ToAstNode(member);
ignored_nodes.insert(member_node);
for (NameDef* def : ModuleMemberGetNameDefs(member)) {
ignored_name_defs.emplace(def, def->identifier());
ignored_nodes.insert(def);
}
if (auto* use = dynamic_cast<const Use*>(member_node)) {
CollectUseTreeEntries(use->root(), ignored_nodes);
}
}

auto new_module = std::make_unique<Module>(module.name(), module.fs_path(),
*module.file_table());
std::optional<Span> attribute_span = module.GetAttributeSpan();
for (const ModuleAttribute& attribute : module.attributes()) {
new_module->AddAttribute(attribute, attribute_span);
}

absl::flat_hash_map<const AstNode*, AstNode*> global_map;

for (const ModuleMember& member : module.top()) {
if (ShouldSkipMember(member, nodes_to_ignore)) {
continue;
}

const AstNode* original_node = ToAstNode(member);
CloneReplacer reuse_existing =
[&](const AstNode* node, Module* target,
const absl::flat_hash_map<const AstNode*, AstNode*>&)
-> absl::StatusOr<std::optional<AstNode*>> {
if (auto cached = global_map.find(node); cached != global_map.end()) {
return cached->second;
}

if (node->kind() == AstNodeKind::kNameRef) {
const auto* name_ref = down_cast<const NameRef*>(node);
if (std::holds_alternative<const NameDef*>(name_ref->name_def())) {
const NameDef* def = std::get<const NameDef*>(name_ref->name_def());
if (auto ignored_it = ignored_name_defs.find(def);
ignored_it != ignored_name_defs.end()) {
return absl::InvalidArgumentError(absl::StrFormat(
"Module member references removed definition '%s'",
ignored_it->second));
}
if (auto def_it = global_map.find(def); def_it != global_map.end()) {
auto* new_def = down_cast<NameDef*>(def_it->second);
AstNode* new_ref =
target->Make<NameRef>(name_ref->span(), name_ref->identifier(),
new_def, name_ref->in_parens());
global_map[node] = new_ref;
return new_ref;
}
}
} else if (node->kind() == AstNodeKind::kTypeRef) {
const auto* type_ref = down_cast<const TypeRef*>(node);
const bool references_ignored =
absl::visit(Visitor{[&](auto* ref) {
return ref != nullptr && ignored_nodes.contains(ref);
}},
type_ref->type_definition());
if (references_ignored) {
return absl::InvalidArgumentError(absl::StrFormat(
"Module member references removed definition '%s'",
type_ref->ToString()));
}
}

return std::optional<AstNode*>{std::nullopt};
};

XLS_ASSIGN_OR_RETURN(auto old_to_new,
CloneAstAndGetAllPairs(original_node, new_module.get(),
std::move(reuse_existing)));
global_map.insert(old_to_new.begin(), old_to_new.end());
XLS_ASSIGN_OR_RETURN(ModuleMember cloned_member,
MakeClonedModuleMember(member, old_to_new));
XLS_RETURN_IF_ERROR(new_module->AddTop(cloned_member,
/*make_collision_error=*/nullptr));
}

return new_module;
}

CloneReplacer ChainCloneReplacers(CloneReplacer first, CloneReplacer second) {
return
[first = std::move(first), second = std::move(second)](
Expand Down
7 changes: 7 additions & 0 deletions xls/dslx/frontend/ast_cloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ CloneAstAndGetAllPairs(const AstNode* root,
absl::StatusOr<std::unique_ptr<Module>> CloneModule(
const Module& module, CloneReplacer replacer = &NoopCloneReplacer);

// Returns a clone of `module` that omits any top-level members whose
// identifiers appear in `member_names_to_ignore`. References to those members
// are not validated or rewritten; the caller is responsible for handling any
// resulting dangling references.
absl::StatusOr<std::unique_ptr<Module>> CloneModuleIgnoreMembers(
const Module& module, absl::Span<const AstNode* const> members_to_ignore);

// Returns a CloneReplacer that runs `first` and then runs `second` on the
// preliminary result, short-circuiting if `first` returns an error.
CloneReplacer ChainCloneReplacers(CloneReplacer first, CloneReplacer second);
Expand Down
Loading