Skip to content

Commit 2c146b6

Browse files
committed
[WIP] Rewrite DAE to use a fixed point analysis
DAE can be slow because it performs several rounds of interleaved analysis and optimization. On top of this, the analysis it performs is not as precise as it could be because it never removes parameters from referenced functions and it cannot optimize unused parameters or results that are forwarded through recursive cycles. Start improving both the performance and the power of DAE by creating a new pass, called DAE2 for now. DAE2 performs a single parallel walk of the module to collect information with which it performs a fixed point analysis to find unused parameters, then does a single parallel walk of the module to optimize based on this analysis.
1 parent 6ec238b commit 2c146b6

4 files changed

Lines changed: 345 additions & 0 deletions

File tree

src/passes/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ set(passes_SOURCES
2828
ConstHoisting.cpp
2929
DataFlowOpts.cpp
3030
DeadArgumentElimination.cpp
31+
DeadArgumentElimination2.cpp
3132
DeadCodeElimination.cpp
3233
DeAlign.cpp
3334
DebugLocationPropagation.cpp
Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
/*
2+
* Copyright 2025 WebAssembly Community Group participants
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
// As a POC, only do the backward analysis to find unused parameters, including
18+
// those that appear to be used because they are forwarded on to another call
19+
// but are then unused by that call.
20+
//
21+
// To match and exceed the power of DAE, we will need to extend this backward
22+
// analysis to find unused results as well, and also add a forward analysis that
23+
// propagates constants and types through parameters and results.
24+
25+
#include <algorithm>
#include <atomic>
#include <cassert>
#include <memory>
#include <optional>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "analysis/lattices/bool.h"
#include "ir/local-graph.h"
#include "pass.h"
#include "support/index.h"
#include "support/utilities.h"
#include "wasm-traversal.h"
#include "wasm.h"
36+
37+
namespace wasm {
38+
39+
namespace {
40+
41+
// Analysis lattice: top/true = used, bot/false = unused. One element of this
// lattice is stored per function parameter in FunctionInfo::paramUsages.
42+
using Used = analysis::Bool;
43+
44+
// Identifies one parameter in the module: first = function index, second =
// parameter index within that function. Function indices are the values
// stored in the name-to-index map built in DAE2::run.
45+
using ParamLoc = std::pair<Index, Index>;
46+
47+
// A set of (source, destination) index pairs for parameters of a caller
// function being forwarded as arguments to a called function: first is the
// parameter index in the caller, second is the operand position in the call.
// NOTE(review): std::unordered_set over std::pair needs a std::hash
// specialization that the standard library does not provide; presumably a
// project header supplies it - confirm.
48+
49+
using ForwardedParamSet = std::unordered_set<std::pair<Index, Index>>;
50+
51+
struct FunctionInfo {
52+
// Analysis results.
53+
// TODO: Fix Bool to wrap its element in a struct so we can store it directly
54+
// in a vector without getting the bool overload.
55+
std::vector<std::tuple<Used::Element>> paramUsages;
56+
57+
// Map callee function names to their forwarded params for direct calls.
58+
std::unordered_map<Name, ForwardedParamSet> directForwardedParams;
59+
60+
// Map callee types to their forwarded params for indirect calls.
61+
std::unordered_map<HeapType, ForwardedParamSet> indirectForwardedParams;
62+
63+
// For each parameter of this function, the list of parameters in direct
64+
// callers that will become used if the parameter in this function turns out
65+
// to be used. Computed by reversing the directForwardedParams graph.
66+
std::vector<std::vector<ParamLoc>> callerParams;
67+
68+
// Whether we need to additionally propagate param usage to indirect callers
69+
// of this function's type. Atomic because it can be set when visiting other
70+
// functions in parallel.
71+
std::atomic<bool> referenced = false;
72+
};
73+
74+
// Function-parallel walker that fills in the FunctionInfo entry of the
// function it is run on: which parameters are observably used, which are
// merely forwarded to direct or indirect calls, and which functions have
// references taken to them.
//
// Each parallel instance only writes to the FunctionInfo of its own function,
// with the exception of the atomic `referenced` flag, which may be set on any
// function.
struct GraphBuilder : public WalkerPass<ExpressionStackWalker<GraphBuilder>> {
  using Super = WalkerPass<ExpressionStackWalker<GraphBuilder>>;

  // Analysis lattice.
  const Used& used;

  // The function info graph is stored as vectors accessed by function index.
  // Map function names to their indices.
  const std::unordered_map<Name, Index>& funcIndices;

  // Vector of analysis info representing the analysis graph we are building.
  // This is populated safely in parallel because the visitor for each function
  // only modifies the entry for that function.
  std::vector<FunctionInfo>& funcInfos;

  // The index of the function we are currently walking. Stays at Index(-1)
  // until runOnFunction is called.
  Index index = -1;

  // A use of a parameter local does not necessarily imply the use of the
  // parameter value. We use a local graph to check where parameter values may
  // be used.
  std::optional<LazyLocalGraph> localGraph;

  GraphBuilder(const Used& used,
               const std::unordered_map<Name, Index>& funcIndices,
               std::vector<FunctionInfo>& funcInfos)
    : used(used), funcIndices(funcIndices), funcInfos(funcInfos) {}

  bool isFunctionParallel() override { return true; }
  bool modifiesBinaryenIR() override { return false; }

  std::unique_ptr<Pass> create() override {
    return std::make_unique<GraphBuilder>(used, funcIndices, funcInfos);
  }

  // Initializes the parameter usages of `func` and, for defined functions,
  // walks the body to collect used and forwarded parameters.
  void runOnFunction(Module* wasm, Function* func) override {
    assert(index == Index(-1));
    index = funcIndices.at(func->name);
    assert(index < funcInfos.size());

    auto& usages = funcInfos[index].paramUsages;
    assert(usages.empty());

    if (func->imported()) {
      // We must assume imported functions use all their parameters.
      // NOTE(review): confirm the pass runner actually invokes this method for
      // imported functions; if it skips them this branch never executes.
      usages.insert(usages.end(), func->getNumParams(), used.getTop());
      return;
    }

    // Every parameter starts out unused. The entries must exist before the
    // walk because visitLocalGet writes to them by parameter index.
    usages.insert(usages.end(), func->getNumParams(), used.getBottom());

    localGraph.emplace(func);
    Super::runOnFunction(wasm, func);
  }

  // Taking a reference to a function means it may be called indirectly.
  void visitRefFunc(RefFunc* curr) {
    funcInfos[funcIndices.at(curr->func)].referenced = true;
  }

  // Returns the position of `arg` in `operands`. `arg` must be one of the
  // operands.
  Index getArgIndex(const ExpressionList& operands, Expression* arg) {
    for (Index i = 0; i < operands.size(); ++i) {
      if (operands[i] == arg) {
        return i;
      }
    }
    WASM_UNREACHABLE("expected arg");
  }

  // Records that parameter `curr->index` of the current function is passed as
  // an operand of the direct call `call`.
  void handleDirectForwardedParam(LocalGet* curr, Call* call) {
    auto argIndex = getArgIndex(call->operands, curr);
    auto& forwarded = funcInfos[index].directForwardedParams[call->target];
    forwarded.insert({curr->index, argIndex});
  }

  // Records that parameter `curr->index` of the current function is passed as
  // one of `operands` of an indirect call to a function of type `type`.
  void handleIndirectForwardedParam(LocalGet* curr,
                                    const ExpressionList& operands,
                                    HeapType type) {
    auto argIndex = getArgIndex(operands, curr);
    auto& forwarded = funcInfos[index].indirectForwardedParams[type];
    forwarded.insert({curr->index, argIndex});
  }

  // Marks parameter `param` of the current function as definitely used.
  void markParamUsed(Index param) {
    auto& usages = funcInfos[index].paramUsages;
    assert(param < usages.size());
    usages[param] = used.getTop();
  }

  void visitLocalGet(LocalGet* curr) {
    if (curr->index >= getFunction()->getNumParams()) {
      // Not a parameter.
      return;
    }

    // A null set stands for the value the local had on function entry, which
    // for a parameter is the incoming argument.
    const auto& sets = localGraph->getSets(curr);
    bool usesParam = std::any_of(
      sets.begin(), sets.end(), [](LocalSet* set) { return set == nullptr; });

    if (!usesParam) {
      // The original parameter value does not reach here.
      return;
    }

    auto* parent = getParent();
    if (!parent) {
      // The get is the whole function body, so there is no enclosing
      // expression to inspect. Conservatively treat the value as used.
      markParamUsed(curr->index);
      return;
    }

    if (auto* call = parent->dynCast<Call>()) {
      handleDirectForwardedParam(curr, call);
    } else if (auto* call = parent->dynCast<CallIndirect>()) {
      if (curr == call->target) {
        // The parameter selects the callee rather than being passed to it, so
        // it is not among the operands and is a real use.
        markParamUsed(curr->index);
        return;
      }
      handleIndirectForwardedParam(curr, call->operands, call->heapType);
    } else if (auto* call = parent->dynCast<CallRef>()) {
      if (curr == call->target) {
        // As above: the parameter is the called reference, not an argument.
        markParamUsed(curr->index);
        return;
      }
      if (!call->target->type.isSignature()) {
        // The call will never happen, so we don't need to consider it.
        return;
      }
      auto heapType = call->target->type.getHeapType();
      handleIndirectForwardedParam(curr, call->operands, heapType);
    } else {
      // The parameter value is used by something other than a call. We could
      // check whether the user is a drop, but for simplicity we assume that
      // Vacuum would have already removed such patterns.
      markParamUsed(curr->index);
    }
  }
};
185+
186+
struct DAE2 : public Pass {
187+
// Analysis lattice.
188+
Used used;
189+
190+
// Map function name to index.
191+
std::unordered_map<Name, Index> funcIndices;
192+
193+
// The intermediate and final analysis results by function index.
194+
std::vector<FunctionInfo> funcInfos;
195+
196+
// For each parameter in each indirectly called type, the set of forwarded
197+
// params in the callers that need to be marked used if a param of a callee of
198+
// that type is used.
199+
std::unordered_map<HeapType, std::vector<std::vector<ParamLoc>>>
200+
indirectCallerParams;
201+
202+
Module* wasm = nullptr;
203+
204+
void run(Module* wasm) override {
205+
this->wasm = wasm;
206+
for (auto& func : wasm->functions) {
207+
funcIndices.insert({func->name, funcIndices.size()});
208+
}
209+
analyzeModule(wasm);
210+
prepareAnalysis();
211+
computeFixedPoint();
212+
optimize();
213+
}
214+
215+
void analyzeModule(Module* wasm) {
216+
funcInfos.resize(wasm->functions.size());
217+
218+
// Analyze functions to find forwarded and used parameters as well as
219+
// function references.
220+
GraphBuilder builder(used, funcIndices, funcInfos);
221+
builder.run(getPassRunner(), wasm);
222+
223+
// Find additional function references at the module level.
224+
builder.walkModuleCode(wasm);
225+
226+
// Mark parameters of exported functions as used.
227+
for (auto& export_ : wasm->exports) {
228+
if (export_->kind == ExternalKind::Function) {
229+
auto name = *export_->getInternalName();
230+
auto& usages = funcInfos[funcIndices.at(name)].paramUsages;
231+
std::fill(usages.begin(), usages.end(), used.getTop());
232+
}
233+
}
234+
235+
// TODO: Find function types that escape the module beyond exported
236+
// functions (or just use all public function types as a conservative
237+
// approximation) and mark parameters of referenced funtions of those types
238+
// as used.
239+
}
240+
241+
void prepareAnalysis() {
242+
// Compute the reverse graph used by the fixed point analysis from the
243+
// forward graph we have built.
244+
for (Index i = 0; i < funcInfos.size(); ++i) {
245+
funcInfos[i].callerParams.resize(funcInfos[i].paramUsages.size());
246+
}
247+
for (Index callerIndex = 0; callerIndex < funcInfos.size(); ++callerIndex) {
248+
for (auto& [callee, forwarded] :
249+
funcInfos[callerIndex].directForwardedParams) {
250+
auto& callerParams = funcInfos[funcIndices.at(callee)].callerParams;
251+
for (auto& [srcParam, destParam] : forwarded) {
252+
callerParams[destParam].push_back({callerIndex, srcParam});
253+
}
254+
}
255+
for (auto& [calleeType, forwarded] :
256+
funcInfos[callerIndex].indirectForwardedParams) {
257+
auto& callerParams = indirectCallerParams[calleeType];
258+
callerParams.resize(funcInfos[callerIndex].paramUsages.size());
259+
for (auto& [srcParam, destParam] : forwarded) {
260+
callerParams[destParam].push_back({callerIndex, srcParam});
261+
}
262+
}
263+
}
264+
}
265+
266+
bool join(ParamLoc loc, const Used::Element& other) {
267+
auto& elem = std::get<0>(funcInfos[loc.first].paramUsages[loc.second]);
268+
return used.join(elem, other);
269+
}
270+
271+
void computeFixedPoint() {
272+
// List of params from which we may need to propagate usage information.
273+
// Initialized with all params we have observed to be used in the IR.
274+
std::vector<ParamLoc> work;
275+
for (Index i = 0; i < funcInfos.size(); ++i) {
276+
for (Index j = 0; j < funcInfos[i].paramUsages.size(); ++j) {
277+
work.push_back({i, j});
278+
}
279+
}
280+
while (!work.empty()) {
281+
auto [calleeIndex, calleeParamIndex] = work.back();
282+
work.pop_back();
283+
284+
const auto& elem =
285+
std::get<0>(funcInfos[calleeIndex].paramUsages[calleeParamIndex]);
286+
assert(elem && "unexpected unused param");
287+
288+
// Propagate back to forwarded params of direct callers.
289+
auto& callerParams =
290+
funcInfos[calleeIndex].callerParams[calleeParamIndex];
291+
for (auto param : callerParams) {
292+
if (join(param, elem)) {
293+
work.push_back(param);
294+
}
295+
}
296+
297+
if (!funcInfos[calleeIndex].referenced) {
298+
// Non-referenced functions can only be called directly.
299+
continue;
300+
}
301+
302+
// Propagate usage back to forwarded params of the indirect callers of all
303+
// supertypes of this function's type.
304+
for (std::optional<HeapType> type =
305+
wasm->functions[calleeIndex]->type.getHeapType();
306+
type;
307+
type = type->getDeclaredSuperType()) {
308+
auto it = indirectCallerParams.find(*type);
309+
if (it == indirectCallerParams.end()) {
310+
continue;
311+
}
312+
auto& callerParams = it->second[calleeParamIndex];
313+
for (auto param : callerParams) {
314+
if (join(param, elem)) {
315+
work.push_back(param);
316+
}
317+
}
318+
}
319+
320+
// TODO: Propagate usage to all functions of any type in the type tree of
321+
// this function's type to keep subtyping valid.
322+
}
323+
}
324+
325+
void optimize() {
326+
struct Optimizer : public WalkerPass<PostWalker<Optimizer>> {
327+
// TODO: Visit functions in parallel, replacing unused parameters with
328+
// locals. Direct calls should look at their target to determine which
329+
// operands to remove (being sure to preserve side effects using
330+
// ChildLocalizer). Indirect calls need to look at the analysis results
331+
// for the target type (TODO: materialize this, possibly just for the root
332+
// type for each type tree) to determine what operands to remove.
333+
};
334+
Optimizer{}.run(getPassRunner(), wasm);
335+
}
336+
};
337+
338+
} // anonymous namespace
339+
340+
// Factory for the experimental DAE2 pass. Returns a newly allocated pass
// object; the caller is responsible for it.
Pass* createDAE2Pass() {
  return new DAE2();
}
341+
342+
} // namespace wasm

src/passes/pass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ void PassRegistry::registerPasses() {
105105
"removes arguments to calls in an lto-like manner, and "
106106
"optimizes where we removed",
107107
createDAEOptimizingPass);
108+
registerPass("dae2", "Experimental reimplementation of DAE", createDAE2Pass);
108109
registerPass("abstract-type-refining",
109110
"refine and merge abstract (never-created) types",
110111
createAbstractTypeRefiningPass);

src/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Pass* createConstantFieldPropagationPass();
3535
Pass* createConstantFieldPropagationRefTestPass();
3636
Pass* createDAEPass();
3737
Pass* createDAEOptimizingPass();
38+
Pass* createDAE2Pass();
3839
Pass* createDataFlowOptsPass();
3940
Pass* createDeadCodeEliminationPass();
4041
Pass* createDeInstrumentBranchHintsPass();

0 commit comments

Comments
 (0)