diff --git a/src/passes/Unsubtyping.cpp b/src/passes/Unsubtyping.cpp index 729dd18d488..6f7a71fb295 100644 --- a/src/passes/Unsubtyping.cpp +++ b/src/passes/Unsubtyping.cpp @@ -14,18 +14,28 @@ * limitations under the License. */ +#define UNSUBTYPING_DEBUG 0 + +#include + +#if !UNSUBTYPING_DEBUG #include +#include +#endif #include "ir/subtype-exprs.h" -#include "ir/subtypes.h" #include "ir/type-updating.h" #include "ir/utils.h" #include "pass.h" -#include "support/unique_deferring_queue.h" +#include "support/index.h" #include "wasm-traversal.h" #include "wasm-type.h" #include "wasm.h" +#if UNSUBTYPING_DEBUG +#include "support/insert_ordered.h" +#endif + // Compute and use the minimal subtype relation required to maintain module // validity and behavior. This minimal relation will be a subset of the original // subtype relation. Start by walking the IR and collecting pairs of types that @@ -96,287 +106,494 @@ // // Starting with the initial subtype relation determined by walking the IR, // repeatedly search for new subtypings by analyzing type definitions and casts -// in lock step until we reach a fixed point. This is the minimal subtype -// relation that preserves module validity and behavior that can be found -// without a more precise analysis of types that might flow into each cast. +// until we reach a fixed point. This is the minimal subtype relation that +// preserves module validity and behavior that can be found without a more +// precise analysis of types that might flow into each cast. namespace wasm { namespace { -struct Unsubtyping - : WalkerPass< - ControlFlowWalker>> { - // The new set of supertype relations. - std::unordered_map supertypes; +#if UNSUBTYPING_DEBUG +template using Map = InsertOrderedMap; +template using Set = InsertOrderedSet; +#else +template using Map = std::unordered_map; +template using Set = std::unordered_set; +#endif +// A tree (or rather a forest) of types with the ability to query and set +// supertypes in constant time and efficiently iterate over supertypes and +// subtypes. +struct TypeTree { + struct Node { + // The type represented by this node. + HeapType type; + // The index of the parent (supertype) in the list of nodes. Set to the + // index of this node if there is no parent. + Index parent; + // The index of this node in the parent's list of children, if any, enabling + // O(1) updates. + Index indexInParent = 0; + // The indices of the children (subtypes) in the list of nodes. + std::vector children; - // Map from cast source types to their destinations. - std::unordered_map> castTypes; + Node(HeapType type, Index index) : type(type), parent(index) {} + }; + + std::vector nodes; + Map indices; + + void setSupertype(HeapType sub, HeapType super) { + auto subIndex = getNode(sub); + auto superIndex = getNode(super); + auto& childNode = nodes[subIndex]; + auto& parentNode = nodes[superIndex]; + // Remove sub from its old supertype if necessary. + if (auto oldParentIndex = childNode.parent; oldParentIndex != subIndex) { + auto& oldParentNode = nodes[oldParentIndex]; + // Move sub to the back of its parent's children and then pop it. + auto& children = oldParentNode.children; + assert(children[childNode.indexInParent] == subIndex); + auto& swappedNode = nodes[children.back()]; + assert(swappedNode.indexInParent == children.size() - 1); + // Swap the indices in the parent's child vector. + std::swap(children[childNode.indexInParent], children.back()); + // Swap the indices in the children. + std::swap(childNode.indexInParent, swappedNode.indexInParent); + children.pop_back(); + } + childNode.parent = superIndex; + childNode.indexInParent = parentNode.children.size(); + parentNode.children.push_back(subIndex); + } + + std::optional getSupertype(HeapType type) { + auto index = getNode(type); + auto parentIndex = nodes[index].parent; + if (parentIndex == index) { + return std::nullopt; + } + return nodes[parentIndex].type; + } + + struct SupertypeIterator { + using value_type = const HeapType; + using difference_type = std::ptrdiff_t; + using reference = const HeapType&; + using pointer = const HeapType*; + using iterator_category = std::input_iterator_tag; - // The set of subtypes that need to have their type definitions analyzed to - // transitively find other subtype relations they depend on. We add to it - // every time we find a new subtype relationship we need to keep. - UniqueDeferredQueue work; + TypeTree* parent; + std::optional index; + + bool operator==(const SupertypeIterator& other) { + return index == other.index; + } + bool operator!=(const SupertypeIterator& other) { + return !(*this == other); + } + const HeapType& operator*() const { return parent->nodes[*index].type; } + const HeapType* operator->() const { return &*(*this); } + SupertypeIterator& operator++() { + auto parentIndex = parent->nodes[*index].parent; + if (parentIndex == *index) { + index = std::nullopt; + } else { + index = parentIndex; + } + return *this; + } + SupertypeIterator operator++(int) { + auto it = *this; + ++(*this); + return it; + } + }; + + struct Supertypes { + TypeTree* parent; + Index index; + SupertypeIterator begin() { return {parent, index}; } + SupertypeIterator end() { return {parent, std::nullopt}; } + }; + + Supertypes supertypes(HeapType type) { return {this, getNode(type)}; } + + struct SubtypeIterator { + using value_type = const HeapType; + using difference_type = std::ptrdiff_t; + using reference = const HeapType&; + using pointer = const HeapType*; + using iterator_category = std::input_iterator_tag; + + TypeTree* parent; + + // DFS stack of (node index, child index) pairs. + std::vector> stack; + + bool operator==(const SubtypeIterator& other) { + return stack == other.stack; + } + bool operator!=(const SubtypeIterator& other) { return !(*this == other); } + const HeapType& operator*() const { + return parent->nodes[stack.back().first].type; + } + const HeapType* operator->() const { return &*(*this); } + SubtypeIterator& operator++() { + while (true) { + if (stack.empty()) { + return *this; + } + auto& [index, childIndex] = stack.back(); + if (childIndex == parent->nodes[index].children.size()) { + stack.pop_back(); + continue; + } + break; + } + auto& [index, childIndex] = stack.back(); + auto child = parent->nodes[index].children[childIndex]; + ++childIndex; + stack.push_back({child, 0u}); + return *this; + } + SubtypeIterator operator++(int) { + auto it = *this; + ++(*this); + return it; + } + }; + + struct Subtypes { + TypeTree* parent; + Index index; + SubtypeIterator begin() { return {parent, {std::make_pair(index, 0u)}}; } + SubtypeIterator end() { return {parent, {}}; } + }; + + Subtypes subtypes(HeapType type) { return {this, getNode(type)}; } + +private: + Index getNode(HeapType type) { + auto [it, inserted] = indices.insert({type, nodes.size()}); + if (inserted) { + nodes.emplace_back(type, nodes.size()); + } + return it->second; + } +}; + +struct Unsubtyping : Pass { + // (sub, super) pairs that we have discovered but not yet processed. + std::vector> work; + + // Record the type tree with supertype and subtype relations in such a way + // that we can add new supertype relationships in constant time. + TypeTree types; + + // Map from cast source types to their destinations. + Map> casts; void run(Module* wasm) override { if (!wasm->features.hasGC()) { return; } + + // Initialize the subtype relation based on what is immediately required to + // keep the code and public types valid. analyzePublicTypes(*wasm); - walkModule(wasm); - analyzeTransitiveDependencies(); - optimizeTypes(*wasm); + analyzeModule(*wasm); + + // Find further subtypings and iterate to a fixed point. + while (!work.empty()) { + auto [sub, super] = work.back(); + work.pop_back(); + process(sub, super); + } + + rewriteTypes(*wasm); + // Cast types may be refinable if their source and target types are no // longer related. TODO: Experiment with running this only after checking // whether it is necessary. ReFinalize().run(getPassRunner(), wasm); } - // Note that sub must remain a subtype of super. - void noteSubtype(HeapType sub, HeapType super) { - if (sub == super || sub.isBottom() || super.isBottom()) { + size_t noteCount = 0; + void note(HeapType sub, HeapType super) { + // Bottom types are uninteresting, but other basic heap types can be + // interesting because of their interactions with casts. + if (sub == super || sub.isBottom()) { return; } + ++noteCount; - auto [it, inserted] = supertypes.insert({sub, super}); - if (inserted) { - work.push(sub); - // TODO: Incrementally check all subtypes (inclusive) of sub against super - // and all its supertypes if we have already analyzed casts. - return; - } - // We already had a recorded supertype. The new supertype might be deeper, - // shallower, or identical to the old supertype. - auto oldSuper = it->second; - if (super == oldSuper) { - return; - } - // There are two different supertypes, but each type can only have a single - // direct subtype so the supertype chain cannot fork and one of the - // supertypes must be a supertype of the other. Recursively record that - // relationship as well. - if (HeapType::isSubType(super, oldSuper)) { - // sub <: super <: oldSuper - it->second = super; - work.push(sub); - // TODO: Incrementally check all subtypes (inclusive) of sub against super - // if we have already analyzed casts. - noteSubtype(super, oldSuper); - } else { - // sub <: oldSuper <: super - noteSubtype(oldSuper, super); - } + work.push_back({sub, super}); } - void noteSubtype(Type sub, Type super) { + void note(Type sub, Type super) { if (sub.isTuple()) { assert(super.isTuple() && sub.size() == super.size()); for (size_t i = 0, size = sub.size(); i < size; ++i) { - noteSubtype(sub[i], super[i]); + note(sub[i], super[i]); } return; } if (!sub.isRef() || !super.isRef()) { return; } - noteSubtype(sub.getHeapType(), super.getHeapType()); + note(sub.getHeapType(), super.getHeapType()); } - // Note a subtyping where one or both sides are expressions. - void noteSubtype(Expression* sub, Type super) { - noteSubtype(sub->type, super); - } - void noteSubtype(Type sub, Expression* super) { - noteSubtype(sub, super->type); - } - void noteSubtype(Expression* sub, Expression* super) { - noteSubtype(sub->type, super->type); + void analyzePublicTypes(Module& wasm) { + // We cannot change supertypes for anything public. + for (auto type : ModuleUtils::getPublicHeapTypes(wasm)) { + if (auto super = type.getDeclaredSuperType()) { + note(type, *super); + } + } } - void noteNonFlowSubtype(Expression* sub, Type super) { - // This expression's type must be a subtype of |super|, but the value does - // not flow anywhere - this is a static constraint. As the value does not - // flow, it cannot reach anywhere else, which means we need this in order to - // validate but it does not interact with casts. Given that, if super is a - // basic type then we can simply ignore this: we only remove subtyping - // between user types, so subtyping wrt basic types is unchanged, and so - // this constraint will never be a problem. - // - // This is sort of a hack because in general to be precise we should not - // just consider basic types here - in general, we should note for each - // constraint whether it is a flow-based one or not, and only take the - // flow-based ones into account when looking at the impact of casts. - // However, in practice this is enough as the only non-trivial case of - // |noteNonFlowSubtype| is for RefEq, which uses a basic type (eqref). Other - // cases of non-flow subtyping end up trivial, e.g., the target of a - // CallRef is compared to itself (and we ignore constraints of A :> A). - // However, if we change how |noteNonFlowSubtype| is used in - // SubtypingDiscoverer then we may need to generalize this. - if (super.isRef() && super.getHeapType().isBasic()) { - return; - } + void analyzeModule(Module& wasm) { + struct Info { + // (source, target) pairs for casts. + Set> casts; - // Otherwise, we must take this into account. - noteSubtype(sub, super); - } + // Observed (sub, super) subtype constraints. + Set> subtypings; + }; - void noteCast(HeapType src, HeapType dest) { - if (src == dest || dest.isBottom()) { - return; + struct Collector + : ControlFlowWalker> { + Info& info; + Collector(Info& info) : info(info) {} + void noteSubtype(Type sub, Type super) { + if (sub.isTuple()) { + assert(super.isTuple() && sub.size() == super.size()); + for (size_t i = 0, size = sub.size(); i < size; ++i) { + noteSubtype(sub[i], super[i]); + } + return; + } + if (!sub.isRef() || !super.isRef()) { + return; + } + noteSubtype(sub.getHeapType(), super.getHeapType()); + } + void noteSubtype(HeapType sub, HeapType super) { + if (sub == super || sub.isBottom()) { + return; + } + info.subtypings.insert({sub, super}); + } + void noteSubtype(Type sub, Expression* super) { + noteSubtype(sub, super->type); + } + void noteSubtype(Expression* sub, Type super) { + noteSubtype(sub->type, super); + } + void noteSubtype(Expression* sub, Expression* super) { + noteSubtype(sub->type, super->type); + } + void noteNonFlowSubtype(Expression* sub, Type super) { + // This expression's type must be a subtype of |super|, but the value + // does not flow anywhere - this is a static constraint. As the value + // does not flow, it cannot reach anywhere else, which means we need + // this in order to validate but it does not interact with casts. Given + // that, if super is a basic type then we can simply ignore this: we + // only remove subtyping between user types, so subtyping wrt basic + // types is unchanged, and so this constraint will never be a problem. + // + // This is sort of a hack because in general to be precise we should not + // just consider basic types here - in general, we should note for each + // constraint whether it is a flow-based one or not, and only take the + // flow-based ones into account when looking at the impact of casts. + // However, in practice this is enough as the only non-trivial case of + // |noteNonFlowSubtype| is for RefEq, which uses a basic type (eqref). + // Other cases of non-flow subtyping end up trivial, e.g., the target of + // a CallRef is compared to itself (and we ignore constraints of A :> + // A). However, if we change how |noteNonFlowSubtype| is used in + // SubtypingDiscoverer then we may need to generalize this. + if (super.isRef() && super.getHeapType().isBasic()) { + return; + } + + // Otherwise, we must take this into account. + noteSubtype(sub, super); + } + void noteCast(HeapType src, HeapType dst) { + // Casts to self and casts that must fail because they have incompatible + // types are uninteresting. + if (dst == src) { + return; + } + if (HeapType::isSubType(dst, src)) { + info.casts.insert({src, dst}); + return; + } + if (HeapType::isSubType(src, dst)) { + // This is an upcast that will always succeed, but only if we ensure + // src <: dst. + info.subtypings.insert({src, dst}); + } + } + void noteCast(Expression* src, Type dst) { + if (src->type.isRef() && dst.isRef()) { + noteCast(src->type.getHeapType(), dst.getHeapType()); + } + } + void noteCast(Expression* src, Expression* dst) { + if (src->type.isRef() && dst->type.isRef()) { + noteCast(src->type.getHeapType(), dst->type.getHeapType()); + } + } + }; + + // Collect subtyping constraints and casts from functions in parallel. + ModuleUtils::ParallelFunctionAnalysis analysis( + wasm, [&](Function* func, Info& info) { + if (!func->imported()) { + Collector(info).walkFunctionInModule(func, &wasm); + } + }); + + Info collectedInfo; + for (auto& [_, info] : analysis.map) { + collectedInfo.casts.insert(info.casts.begin(), info.casts.end()); + collectedInfo.subtypings.insert(info.subtypings.begin(), + info.subtypings.end()); } - assert(HeapType::isSubType(dest, src)); - castTypes[src].insert(dest); - } - void noteCast(Type src, Type dest) { - assert(!src.isTuple() && !dest.isTuple()); - if (src == Type::unreachable) { - return; + // Collect constraints from module-level code as well. + Collector collector(collectedInfo); + collector.walkModuleCode(&wasm); + collector.setModule(&wasm); + for (auto& global : wasm.globals) { + collector.visitGlobal(global.get()); + } + for (auto& segment : wasm.elementSegments) { + collector.visitElementSegment(segment.get()); } - assert(src.isRef() && dest.isRef()); - noteCast(src.getHeapType(), dest.getHeapType()); - } - // Note a cast where one or both sides are expressions. - void noteCast(Expression* src, Type dest) { noteCast(src->type, dest); } - void noteCast(Expression* src, Expression* dest) { - noteCast(src->type, dest->type); + // Prepare the collected information for the upcoming processing loop. + for (auto& [sub, super] : collectedInfo.subtypings) { + note(sub, super); + } + for (auto [src, dst] : collectedInfo.casts) { + casts[src].push_back(dst); + } } - void analyzePublicTypes(Module& wasm) { - // We cannot change supertypes for anything public. - for (auto type : ModuleUtils::getPublicHeapTypes(wasm)) { - if (auto super = type.getDeclaredSuperType()) { - noteSubtype(type, *super); + size_t processCount = 0; + void process(HeapType sub, HeapType super) { + ++processCount; + auto oldSuper = types.getSupertype(sub); + if (oldSuper) { + // We already had a recorded supertype. The new supertype might be + // deeper,shallower, or equal to the old supertype. We must recursively + // note the relationship between the old and new supertypes. + if (super == *oldSuper) { + // Nothing new to do here. + return; + } + if (HeapType::isSubType(*oldSuper, super)) { + // sub <: oldSuper <: super + note(*oldSuper, super); + // We already handled sub <: oldSuper, so we're done. + return; } + // sub <: super <: oldSuper + // Eagerly process super <: oldSuper first. This ensures that sub and + // super will already be in the same tree when we process them below, so + // when we process casts we will know that we only need to process up to + // oldSuper. + process(super, *oldSuper); } + + types.setSupertype(sub, super); + + // We have a new supertype. Find the implied subtypings from the type + // definitions and casts. + processDefinitions(sub, super); + processCasts(sub, super, oldSuper); } - void analyzeTransitiveDependencies() { - // While we have found new subtypings and have not reached a fixed point... - while (!work.empty()) { - // Subtype relationships that we are keeping might depend on other subtype - // relationships that we are not yet planning to keep. Transitively find - // all the relationships we need to keep all our type definitions valid. - while (!work.empty()) { - auto type = work.pop(); - auto super = supertypes.at(type); - if (super.isBasic()) { - continue; - } - switch (type.getKind()) { - case HeapTypeKind::Func: { - auto sig = type.getSignature(); - auto superSig = super.getSignature(); - noteSubtype(superSig.params, sig.params); - noteSubtype(sig.results, superSig.results); - break; - } - case HeapTypeKind::Struct: { - const auto& fields = type.getStruct().fields; - const auto& superFields = super.getStruct().fields; - for (size_t i = 0, size = superFields.size(); i < size; ++i) { - noteSubtype(fields[i].type, superFields[i].type); - } - break; - } - case HeapTypeKind::Array: { - auto elem = type.getArray().element; - noteSubtype(elem.type, super.getArray().element.type); - break; - } - case HeapTypeKind::Cont: - WASM_UNREACHABLE("TODO: cont"); - case HeapTypeKind::Basic: - WASM_UNREACHABLE("unexpected kind"); - } - if (auto desc = type.getDescriptorType()) { - if (auto superDesc = super.getDescriptorType()) { - noteSubtype(*desc, *superDesc); - } + void processDefinitions(HeapType sub, HeapType super) { + if (super.isBasic()) { + return; + } + switch (sub.getKind()) { + case HeapTypeKind::Func: { + auto sig = sub.getSignature(); + auto superSig = super.getSignature(); + note(superSig.params, sig.params); + note(sig.results, superSig.results); + break; + } + case HeapTypeKind::Struct: { + const auto& fields = sub.getStruct().fields; + const auto& superFields = super.getStruct().fields; + for (size_t i = 0, size = superFields.size(); i < size; ++i) { + note(fields[i].type, superFields[i].type); } + break; + } + case HeapTypeKind::Array: { + auto elem = sub.getArray().element; + note(elem.type, super.getArray().element.type); + break; + } + case HeapTypeKind::Cont: + WASM_UNREACHABLE("TODO: cont"); + case HeapTypeKind::Basic: + WASM_UNREACHABLE("unexpected kind"); + } + if (auto desc = sub.getDescriptorType()) { + if (auto superDesc = super.getDescriptorType()) { + note(*desc, *superDesc); } - - // Analyze all casts at once. - // TODO: This is expensive. Analyze casts incrementally after we - // initially analyze them. - analyzeCasts(); } } - void analyzeCasts() { - // For each cast (src, dest) pair, any type that remains a subtype of src - // (meaning its values can inhabit locations typed src) and that was - // originally a subtype of dest (meaning its values would have passed the - // cast) should remain a subtype of dest so that its values continue to pass - // the cast. - // - // For every type, walk up its new supertype chain to find cast sources and - // compare against their associated cast destinations. - for (auto it = supertypes.begin(); it != supertypes.end(); ++it) { - auto type = it->first; - for (auto srcIt = it; srcIt != supertypes.end(); - srcIt = supertypes.find(srcIt->second)) { - auto src = srcIt->second; - auto destsIt = castTypes.find(src); - if (destsIt == castTypes.end()) { - continue; + void + processCasts(HeapType sub, HeapType super, std::optional oldSuper) { + // We are either attaching the one tree rooted at `type` under a new + // supertype in another tree, or we are reparenting `type` below a + // descendent of `oldSuper` in the same tree. In the former case, we must + // evaluate `type` and all its subtypes against all its new supertypes and + // their cast destinations. In the latter case, `type` and all its subtypes + // must have already been evaluated against `oldSuper` and its supertypes, + // so we only need to additionally evaluate them against supertypes up to + // `oldSuper`. + for (auto type : types.subtypes(sub)) { + for (auto src : types.supertypes(super)) { + if (oldSuper && src == *oldSuper) { + break; } - for (auto dest : destsIt->second) { - if (HeapType::isSubType(type, dest)) { - noteSubtype(type, dest); + for (auto dst : casts[src]) { + if (HeapType::isSubType(type, dst)) { + note(type, dst); } } } } } - void optimizeTypes(Module& wasm) { + void rewriteTypes(Module& wasm) { struct Rewriter : GlobalTypeRewriter { Unsubtyping& parent; Rewriter(Unsubtyping& parent, Module& wasm) : GlobalTypeRewriter(wasm), parent(parent) {} std::optional getDeclaredSuperType(HeapType type) override { - if (auto it = parent.supertypes.find(type); - it != parent.supertypes.end() && !it->second.isBasic()) { - return it->second; + if (auto super = parent.types.getSupertype(type); + super && !super->isBasic()) { + return *super; } return std::nullopt; } }; Rewriter(*this, wasm).update(); } - - void doWalkModule(Module* wasm) { - // Visit the functions in parallel, filling in `supertypes` and `castTypes` - // on separate instances which will later be merged. - ModuleUtils::ParallelFunctionAnalysis analysis( - *wasm, [&](Function* func, Unsubtyping& unsubtyping) { - if (!func->imported()) { - unsubtyping.walkFunctionInModule(func, wasm); - } - }); - // Collect the results from the functions. - for (auto& [_, unsubtyping] : analysis.map) { - for (auto [sub, super] : unsubtyping.supertypes) { - noteSubtype(sub, super); - } - for (auto& [src, dests] : unsubtyping.castTypes) { - for (auto dest : dests) { - noteCast(src, dest); - } - } - } - // Collect constraints from top-level items. - for (auto& global : wasm->globals) { - visitGlobal(global.get()); - } - for (auto& seg : wasm->elementSegments) { - visitElementSegment(seg.get()); - } - // Visit the rest of the code that is not in functions. - walkModuleCode(wasm); - } }; } // anonymous namespace diff --git a/test/lit/passes/unsubtyping.wast b/test/lit/passes/unsubtyping.wast index 971fe63da3c..dc365943f82 100644 --- a/test/lit/passes/unsubtyping.wast +++ b/test/lit/passes/unsubtyping.wast @@ -1850,3 +1850,25 @@ ) ) ) + +;; Regression test for assertion failure on incorrect updating of type tree +;; state. +(module + (rec + ;; CHECK: (rec + ;; CHECK-NEXT: (type $0 (sub (struct))) + (type $0 (sub (struct))) + ;; CHECK: (type $1 (sub $0 (struct (field (ref null $0))))) + (type $1 (sub $0 (struct (field (ref null $0))))) + ;; CHECK: (type $2 (sub $1 (struct (field (ref null $3))))) + (type $2 (sub $1 (struct (field (ref null $3))))) + ;; CHECK: (type $3 (sub $0 (struct))) + (type $3 (sub $0 (struct))) + ) + ;; CHECK: (global $g (ref struct) (struct.new_default $2)) + (global $g (ref struct) (struct.new_default $2)) + ;; CHECK: (global $g2 (ref null $1) (ref.null none)) + (global $g2 (ref null $1) (ref.null none)) + ;; CHECK: (export "" (global $g2)) + (export "" (global $g2)) +)