diff --git a/src/passes/GlobalEffects.cpp b/src/passes/GlobalEffects.cpp index ef0977d12fa..39fe4874bb7 100644 --- a/src/passes/GlobalEffects.cpp +++ b/src/passes/GlobalEffects.cpp @@ -22,168 +22,172 @@ #include "ir/effects.h" #include "ir/module-utils.h" #include "pass.h" +#include "support/hash.h" #include "support/unique_deferring_queue.h" #include "wasm.h" namespace wasm { -struct GenerateGlobalEffects : public Pass { - void run(Module* module) override { - // First, we do a scan of each function to see what effects they have, - // including which functions they call directly (so that we can compute - // transitive effects later). - - struct FuncInfo { - // Effects in this function. - std::optional effects; - - // Directly-called functions from this function. - std::unordered_set calledFunctions; - }; - - ModuleUtils::ParallelFunctionAnalysis analysis( - *module, [&](Function* func, FuncInfo& funcInfo) { - if (func->imported()) { - // Imports can do anything, so we need to assume the worst anyhow, - // which is the same as not specifying any effects for them in the - // map (which we do by not setting funcInfo.effects). - return; - } - - // Gather the effects. - funcInfo.effects.emplace(getPassOptions(), *module, func); - - if (funcInfo.effects->calls) { - // There are calls in this function, which we will analyze in detail. - // Clear the |calls| field first, and we'll handle calls of all sorts - // below. - funcInfo.effects->calls = false; - - // Clear throws as well, as we are "forgetting" calls right now, and - // want to forget their throwing effect as well. If we see something - // else that throws, below, then we'll note that there. - funcInfo.effects->throws_ = false; - - struct CallScanner - : public PostWalker> { - Module& wasm; - PassOptions& options; - FuncInfo& funcInfo; - - CallScanner(Module& wasm, PassOptions& options, FuncInfo& funcInfo) - : wasm(wasm), options(options), funcInfo(funcInfo) {} - - void visitExpression(Expression* curr) { - ShallowEffectAnalyzer effects(options, wasm, curr); - if (auto* call = curr->dynCast()) { - // Note the direct call. - funcInfo.calledFunctions.insert(call->target); - } else if (effects.calls) { - // This is an indirect call of some sort, so we must assume the - // worst. To do so, clear the effects, which indicates nothing - // is known (so anything is possible). - // TODO: We could group effects by function type etc. - funcInfo.effects.reset(); - } else { - // No call here, but update throwing if we see it. (Only do so, - // however, if we have effects; if we cleared it - see before - - // then we assume the worst anyhow, and have nothing to update.) - if (effects.throws_ && funcInfo.effects) { - funcInfo.effects->throws_ = true; - } +namespace { + +struct FuncInfo { + // Effects in this function. + std::optional effects; + + // Directly-called functions from this function. + std::unordered_set calledFunctions; +}; + +std::map analyzeFuncs(Module& module, + const PassOptions& passOptions) { + ModuleUtils::ParallelFunctionAnalysis analysis( + module, [&](Function* func, FuncInfo& funcInfo) { + if (func->imported()) { + // Imports can do anything, so we need to assume the worst anyhow, + // which is the same as not specifying any effects for them in the + // map (which we do by not setting funcInfo.effects). + return; + } + + // Gather the effects. + funcInfo.effects.emplace(passOptions, module, func); + + if (funcInfo.effects->calls) { + // There are calls in this function, which we will analyze in detail. + // Clear the |calls| field first, and we'll handle calls of all sorts + // below. + funcInfo.effects->calls = false; + + // Clear throws as well, as we are "forgetting" calls right now, and + // want to forget their throwing effect as well. If we see something + // else that throws, below, then we'll note that there. + funcInfo.effects->throws_ = false; + + struct CallScanner + : public PostWalker> { + Module& wasm; + const PassOptions& options; + FuncInfo& funcInfo; + + CallScanner(Module& wasm, + const PassOptions& options, + FuncInfo& funcInfo) + : wasm(wasm), options(options), funcInfo(funcInfo) {} + + void visitExpression(Expression* curr) { + ShallowEffectAnalyzer effects(options, wasm, curr); + if (auto* call = curr->dynCast()) { + // Note the direct call. + funcInfo.calledFunctions.insert(call->target); + } else if (effects.calls) { + // This is an indirect call of some sort, so we must assume the + // worst. To do so, clear the effects, which indicates nothing + // is known (so anything is possible). + // TODO: We could group effects by function type etc. + funcInfo.effects.reset(); + } else { + // No call here, but update throwing if we see it. (Only do so, + // however, if we have effects; if we cleared it - see before - + // then we assume the worst anyhow, and have nothing to update.) + if (effects.throws_ && funcInfo.effects) { + funcInfo.effects->throws_ = true; } } - }; - CallScanner scanner(*module, getPassOptions(), funcInfo); - scanner.walkFunction(func); - } - }); - - // Compute the transitive closure of effects. To do so, first construct for - // each function a list of the functions that it is called by (so we need to - // propagate its effects to them), and then we'll construct the closure of - // that. - // - // callers[foo] = [func that calls foo, another func that calls foo, ..] - // - std::unordered_map> callers; - - // Our work queue contains info about a new call pair: a call from a caller - // to a called function, that is information we then apply and propagate. - using CallPair = std::pair; // { caller, called } - UniqueDeferredQueue work; - for (auto& [func, info] : analysis.map) { - for (auto& called : info.calledFunctions) { - work.push({func->name, called}); + } + }; + CallScanner scanner(module, passOptions, funcInfo); + scanner.walkFunction(func); } + }); + + return std::move(analysis.map); +} + +// Propagate effects from callees to callers transitively +// e.g. if A -> B -> C (A calls B which calls C) +// Then B inherits effects from C and A inherits effects from both B and C. +void propagateEffects( + const Module& module, + const std::unordered_map>& in, + std::map& funcInfos) { + + std::unordered_set> processed; + std::deque> work; + + for (const auto& [callee, callers] : in) { + for (const auto& caller : callers) { + work.emplace_back(callee, caller); + processed.emplace(callee, caller); + } + } + + auto propagate = [&](Name callee, Name caller) { + auto& callerEffects = funcInfos.at(module.getFunction(caller)).effects; + const auto& calleeEffects = + funcInfos.at(module.getFunction(callee)).effects; + if (!callerEffects) { + return; + } + + if (!calleeEffects) { + callerEffects.reset(); + return; } - // Compute the transitive closure of the call graph, that is, fill out - // |callers| so that it contains the list of all callers - even through a - // chain - of each function. - while (!work.empty()) { - auto [caller, called] = work.pop(); - - // We must not already have an entry for this call (that would imply we - // are doing wasted work). - assert(!callers[called].contains(caller)); - - // Apply the new call information. - callers[called].insert(caller); - - // We just learned that |caller| calls |called|. It also calls - // transitively, which we need to propagate to all places unaware of that - // information yet. - // - // caller => called => called by called - // - auto& calledInfo = analysis.map[module->getFunction(called)]; - for (auto calledByCalled : calledInfo.calledFunctions) { - if (!callers[calledByCalled].contains(caller)) { - work.push({caller, calledByCalled}); - } + callerEffects->mergeIn(*calleeEffects); + }; + + while (!work.empty()) { + auto [callee, caller] = work.back(); + work.pop_back(); + + if (callee == caller) { + auto& callerEffects = funcInfos.at(module.getFunction(caller)).effects; + if (callerEffects) { + callerEffects->trap = true; } } - // Now that we have transitively propagated all static calls, apply that - // information. First, apply infinite recursion: if a function can call - // itself then it might recurse infinitely, which we consider an effect (a - // trap). - for (auto& [func, info] : analysis.map) { - if (callers[func->name].contains(func->name)) { - if (info.effects) { - info.effects->trap = true; - } + // Even if nothing changed, we still need to keep traversing the callers + // to look for a potential cycle which adds a trap affect on the above + // lines. + propagate(callee, caller); + + const auto& callerCallers = in.find(caller); + if (callerCallers == in.end()) { + continue; + } + + for (const Name& callerCaller : callerCallers->second) { + if (processed.contains({callee, callerCaller})) { + continue; } + + processed.emplace(callee, callerCaller); + work.emplace_back(callee, callerCaller); } + } +} + +struct GenerateGlobalEffects : public Pass { + void run(Module* module) override { + std::map funcInfos = + analyzeFuncs(*module, getPassOptions()); - // Next, apply function effects to their callers. - for (auto& [func, info] : analysis.map) { - auto& funcEffects = info.effects; - - for (auto& caller : callers[func->name]) { - auto& callerEffects = analysis.map[module->getFunction(caller)].effects; - if (!callerEffects) { - // Nothing is known for the caller, which is already the worst case. - continue; - } - - if (!funcEffects) { - // Nothing is known for the called function, which means nothing is - // known for the caller either. - callerEffects.reset(); - continue; - } - - // Add func's effects to the caller. - callerEffects->mergeIn(*funcEffects); + // callee : caller + std::unordered_map> callers; + for (const auto& [func, info] : funcInfos) { + for (const auto& callee : info.calledFunctions) { + callers[callee].insert(func->name); } } + propagateEffects(*module, callers, funcInfos); + // Generate the final data, starting from a blank slate where nothing is // known. - for (auto& [func, info] : analysis.map) { + for (auto& [func, info] : funcInfos) { func->effects.reset(); if (!info.effects) { continue; @@ -202,6 +206,8 @@ struct DiscardGlobalEffects : public Pass { } }; +} // namespace + Pass* createGenerateGlobalEffectsPass() { return new GenerateGlobalEffects(); } Pass* createDiscardGlobalEffectsPass() { return new DiscardGlobalEffects(); }