|
| 1 | +/* |
| 2 | + * Copyright 2025 WebAssembly Community Group participants |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +// As a POC, only do the backward analysis to find unused parameters, including |
| 18 | +// those that appear to be used because they are forwarded on to another call |
| 19 | +// but are then unused by that call. |
| 20 | +// |
| 21 | +// To match and exceed the power of DAE, we will need to extend this backward |
| 22 | +// analysis to find unused results as well, and also add a forward analysis that |
| 23 | +// propagates constants and types through parameters and results. |
| 24 | + |
| 25 | +#include <memory> |
| 26 | +#include <unordered_map> |
| 27 | +#include <vector> |
| 28 | + |
| 29 | +#include "analysis/lattices/bool.h" |
| 30 | +#include "ir/local-graph.h" |
| 31 | +#include "pass.h" |
| 32 | +#include "support/index.h" |
| 33 | +#include "support/utilities.h" |
| 34 | +#include "wasm-traversal.h" |
| 35 | +#include "wasm.h" |
| 36 | + |
| 37 | +namespace wasm { |
| 38 | + |
| 39 | +namespace { |
| 40 | + |
| 41 | +// Analysis lattice: top/true = used, bot/false = unused. |
| 42 | +using Used = analysis::Bool; |
| 43 | + |
| 44 | +// Function index and parameter index. |
| 45 | +using ParamLoc = std::pair<Index, Index>; |
| 46 | + |
| 47 | +// A set of (source, destination) index pairs for parameters of a caller |
| 48 | +// function being forwarded as arguments to a called function. |
| 49 | +using ForwardedParamSet = std::unordered_set<std::pair<Index, Index>>; |
| 50 | + |
| 51 | +struct FunctionInfo { |
| 52 | + // Analysis results. |
| 53 | + // TODO: Fix Bool to wrap its element in a struct so we can store it directly |
| 54 | + // in a vector without getting the bool overload. |
| 55 | + std::vector<std::tuple<Used::Element>> paramUsages; |
| 56 | + |
| 57 | + // Map callee function names to their forwarded params for direct calls. |
| 58 | + std::unordered_map<Name, ForwardedParamSet> directForwardedParams; |
| 59 | + |
| 60 | + // Map callee types to their forwarded params for indirect calls. |
| 61 | + std::unordered_map<HeapType, ForwardedParamSet> indirectForwardedParams; |
| 62 | + |
| 63 | + // For each parameter of this function, the list of parameters in direct |
| 64 | + // callers that will become used if the parameter in this function turns out |
| 65 | + // to be used. Computed by reversing the directForwardedParams graph. |
| 66 | + std::vector<std::vector<ParamLoc>> callerParams; |
| 67 | + |
| 68 | + // Whether we need to additionally propagate param usage to indirect callers |
| 69 | + // of this function's type. Atomic because it can be set when visiting other |
| 70 | + // functions in parallel. |
| 71 | + std::atomic<bool> referenced = false; |
| 72 | +}; |
| 73 | + |
| 74 | +struct GraphBuilder : public WalkerPass<ExpressionStackWalker<GraphBuilder>> { |
| 75 | + // Analysis lattice. |
| 76 | + const Used& used; |
| 77 | + |
| 78 | + // The function info graph is stored as vectors accessed by function index. |
| 79 | + // Map function names to their indices. |
| 80 | + const std::unordered_map<Name, Index>& funcIndices; |
| 81 | + |
| 82 | + // Vector of analysis info representing the analysis graph we are building. |
| 83 | + // This is populated safely in parallel because the visitor for each function |
| 84 | + // only modifies the entry for that function. |
| 85 | + std::vector<FunctionInfo>& funcInfos; |
| 86 | + |
| 87 | + // The index of the function we are currently walking. |
| 88 | + Index index = -1; |
| 89 | + |
| 90 | + // A use of a parameter local does not necessarily imply the use of the |
| 91 | + // parameter value. We use a local graph to check where parameter values may |
| 92 | + // be used. |
| 93 | + std::optional<LazyLocalGraph> localGraph; |
| 94 | + |
| 95 | + GraphBuilder(const Used& used, |
| 96 | + const std::unordered_map<Name, Index>& funcIndices, |
| 97 | + std::vector<FunctionInfo>& funcInfos) |
| 98 | + : used(used), funcIndices(funcIndices), funcInfos(funcInfos) {} |
| 99 | + |
| 100 | + bool isFunctionParallel() override { return true; } |
| 101 | + bool modifiesBinaryenIR() override { return false; } |
| 102 | + |
| 103 | + std::unique_ptr<Pass> create() override { |
| 104 | + return std::make_unique<GraphBuilder>(used, funcIndices, funcInfos); |
| 105 | + } |
| 106 | + |
| 107 | + void runOnFunction(Module* wasm, Function* func) override { |
| 108 | + assert(index == Index(-1)); |
| 109 | + index = funcIndices.at(func->name); |
| 110 | + assert(index < funcInfos.size()); |
| 111 | + if (func->imported()) { |
| 112 | + // We must assume imported functions use all their parameters. |
| 113 | + auto& usages = funcInfos[index].paramUsages; |
| 114 | + assert(usages.empty()); |
| 115 | + usages.insert(usages.end(), func->getNumParams(), used.getTop()); |
| 116 | + } else { |
| 117 | + localGraph.emplace(func); |
| 118 | + using Super = WalkerPass<ExpressionStackWalker<GraphBuilder>>; |
| 119 | + Super::runOnFunction(wasm, func); |
| 120 | + } |
| 121 | + } |
| 122 | + |
| 123 | + void visitRefFunc(RefFunc* curr) { |
| 124 | + funcInfos[funcIndices.at(curr->func)].referenced = true; |
| 125 | + } |
| 126 | + |
| 127 | + Index getArgIndex(const ExpressionList& operands, Expression* arg) { |
| 128 | + for (Index i = 0; i < operands.size(); ++i) { |
| 129 | + if (operands[i] == arg) { |
| 130 | + return i; |
| 131 | + } |
| 132 | + } |
| 133 | + WASM_UNREACHABLE("expected arg"); |
| 134 | + } |
| 135 | + |
| 136 | + void handleDirectForwardedParam(LocalGet* curr, Call* call) { |
| 137 | + auto argIndex = getArgIndex(call->operands, curr); |
| 138 | + auto& forwarded = funcInfos[index].directForwardedParams[call->target]; |
| 139 | + forwarded.insert({curr->index, argIndex}); |
| 140 | + } |
| 141 | + |
| 142 | + void handleIndirectForwardedParam(LocalGet* curr, |
| 143 | + const ExpressionList& operands, |
| 144 | + HeapType type) { |
| 145 | + auto argIndex = getArgIndex(operands, curr); |
| 146 | + auto& forwarded = funcInfos[index].indirectForwardedParams[type]; |
| 147 | + forwarded.insert({curr->index, argIndex}); |
| 148 | + } |
| 149 | + |
| 150 | + void visitLocalGet(LocalGet* curr) { |
| 151 | + if (curr->index >= getFunction()->getNumParams()) { |
| 152 | + // Not a parameter. |
| 153 | + return; |
| 154 | + } |
| 155 | + |
| 156 | + const auto& sets = localGraph->getSets(curr); |
| 157 | + bool usesParam = std::any_of( |
| 158 | + sets.begin(), sets.end(), [](LocalSet* set) { return set == nullptr; }); |
| 159 | + |
| 160 | + if (!usesParam) { |
| 161 | + // The original parameter value does not reach here. |
| 162 | + return; |
| 163 | + } |
| 164 | + |
| 165 | + auto* parent = getParent(); |
| 166 | + if (auto* call = parent->dynCast<Call>()) { |
| 167 | + handleDirectForwardedParam(curr, call); |
| 168 | + } else if (auto* call = parent->dynCast<CallIndirect>()) { |
| 169 | + handleIndirectForwardedParam(curr, call->operands, call->heapType); |
| 170 | + } else if (auto* call = parent->dynCast<CallRef>()) { |
| 171 | + if (!call->target->type.isSignature()) { |
| 172 | + // The call will never happen, so we don't need to consider it. |
| 173 | + return; |
| 174 | + } |
| 175 | + auto heapType = call->target->type.getHeapType(); |
| 176 | + handleIndirectForwardedParam(curr, call->operands, heapType); |
| 177 | + } else { |
| 178 | + // The parameter value is used by something other than a call. We could |
| 179 | + // check whether the user is a drop, but for simplicity we assume that |
| 180 | + // Vacuum would have already removed such patterns. |
| 181 | + funcInfos[index].paramUsages[curr->index] = used.getTop(); |
| 182 | + } |
| 183 | + } |
| 184 | +}; |
| 185 | + |
| 186 | +struct DAE2 : public Pass { |
| 187 | + // Analysis lattice. |
| 188 | + Used used; |
| 189 | + |
| 190 | + // Map function name to index. |
| 191 | + std::unordered_map<Name, Index> funcIndices; |
| 192 | + |
| 193 | + // The intermediate and final analysis results by function index. |
| 194 | + std::vector<FunctionInfo> funcInfos; |
| 195 | + |
| 196 | + // For each parameter in each indirectly called type, the set of forwarded |
| 197 | + // params in the callers that need to be marked used if a param of a callee of |
| 198 | + // that type is used. |
| 199 | + std::unordered_map<HeapType, std::vector<std::vector<ParamLoc>>> |
| 200 | + indirectCallerParams; |
| 201 | + |
| 202 | + Module* wasm = nullptr; |
| 203 | + |
| 204 | + void run(Module* wasm) override { |
| 205 | + this->wasm = wasm; |
| 206 | + for (auto& func : wasm->functions) { |
| 207 | + funcIndices.insert({func->name, funcIndices.size()}); |
| 208 | + } |
| 209 | + analyzeModule(wasm); |
| 210 | + prepareAnalysis(); |
| 211 | + computeFixedPoint(); |
| 212 | + optimize(); |
| 213 | + } |
| 214 | + |
| 215 | + void analyzeModule(Module* wasm) { |
| 216 | + funcInfos.resize(wasm->functions.size()); |
| 217 | + |
| 218 | + // Analyze functions to find forwarded and used parameters as well as |
| 219 | + // function references. |
| 220 | + GraphBuilder builder(used, funcIndices, funcInfos); |
| 221 | + builder.run(getPassRunner(), wasm); |
| 222 | + |
| 223 | + // Find additional function references at the module level. |
| 224 | + builder.walkModuleCode(wasm); |
| 225 | + |
| 226 | + // Mark parameters of exported functions as used. |
| 227 | + for (auto& export_ : wasm->exports) { |
| 228 | + if (export_->kind == ExternalKind::Function) { |
| 229 | + auto name = *export_->getInternalName(); |
| 230 | + auto& usages = funcInfos[funcIndices.at(name)].paramUsages; |
| 231 | + std::fill(usages.begin(), usages.end(), used.getTop()); |
| 232 | + } |
| 233 | + } |
| 234 | + |
| 235 | + // TODO: Find function types that escape the module beyond exported |
| 236 | + // functions (or just use all public function types as a conservative |
| 237 | + // approximation) and mark parameters of referenced funtions of those types |
| 238 | + // as used. |
| 239 | + } |
| 240 | + |
| 241 | + void prepareAnalysis() { |
| 242 | + // Compute the reverse graph used by the fixed point analysis from the |
| 243 | + // forward graph we have built. |
| 244 | + for (Index i = 0; i < funcInfos.size(); ++i) { |
| 245 | + funcInfos[i].callerParams.resize(funcInfos[i].paramUsages.size()); |
| 246 | + } |
| 247 | + for (Index callerIndex = 0; callerIndex < funcInfos.size(); ++callerIndex) { |
| 248 | + for (auto& [callee, forwarded] : |
| 249 | + funcInfos[callerIndex].directForwardedParams) { |
| 250 | + auto& callerParams = funcInfos[funcIndices.at(callee)].callerParams; |
| 251 | + for (auto& [srcParam, destParam] : forwarded) { |
| 252 | + callerParams[destParam].push_back({callerIndex, srcParam}); |
| 253 | + } |
| 254 | + } |
| 255 | + for (auto& [calleeType, forwarded] : |
| 256 | + funcInfos[callerIndex].indirectForwardedParams) { |
| 257 | + auto& callerParams = indirectCallerParams[calleeType]; |
| 258 | + callerParams.resize(funcInfos[callerIndex].paramUsages.size()); |
| 259 | + for (auto& [srcParam, destParam] : forwarded) { |
| 260 | + callerParams[destParam].push_back({callerIndex, srcParam}); |
| 261 | + } |
| 262 | + } |
| 263 | + } |
| 264 | + } |
| 265 | + |
| 266 | + bool join(ParamLoc loc, const Used::Element& other) { |
| 267 | + auto& elem = std::get<0>(funcInfos[loc.first].paramUsages[loc.second]); |
| 268 | + return used.join(elem, other); |
| 269 | + } |
| 270 | + |
| 271 | + void computeFixedPoint() { |
| 272 | + // List of params from which we may need to propagate usage information. |
| 273 | + // Initialized with all params we have observed to be used in the IR. |
| 274 | + std::vector<ParamLoc> work; |
| 275 | + for (Index i = 0; i < funcInfos.size(); ++i) { |
| 276 | + for (Index j = 0; j < funcInfos[i].paramUsages.size(); ++j) { |
| 277 | + work.push_back({i, j}); |
| 278 | + } |
| 279 | + } |
| 280 | + while (!work.empty()) { |
| 281 | + auto [calleeIndex, calleeParamIndex] = work.back(); |
| 282 | + work.pop_back(); |
| 283 | + |
| 284 | + const auto& elem = |
| 285 | + std::get<0>(funcInfos[calleeIndex].paramUsages[calleeParamIndex]); |
| 286 | + assert(elem && "unexpected unused param"); |
| 287 | + |
| 288 | + // Propagate back to forwarded params of direct callers. |
| 289 | + auto& callerParams = |
| 290 | + funcInfos[calleeIndex].callerParams[calleeParamIndex]; |
| 291 | + for (auto param : callerParams) { |
| 292 | + if (join(param, elem)) { |
| 293 | + work.push_back(param); |
| 294 | + } |
| 295 | + } |
| 296 | + |
| 297 | + if (!funcInfos[calleeIndex].referenced) { |
| 298 | + // Non-referenced functions can only be called directly. |
| 299 | + continue; |
| 300 | + } |
| 301 | + |
| 302 | + // Propagate usage back to forwarded params of the indirect callers of all |
| 303 | + // supertypes of this function's type. |
| 304 | + for (std::optional<HeapType> type = |
| 305 | + wasm->functions[calleeIndex]->type.getHeapType(); |
| 306 | + type; |
| 307 | + type = type->getDeclaredSuperType()) { |
| 308 | + auto it = indirectCallerParams.find(*type); |
| 309 | + if (it == indirectCallerParams.end()) { |
| 310 | + continue; |
| 311 | + } |
| 312 | + auto& callerParams = it->second[calleeParamIndex]; |
| 313 | + for (auto param : callerParams) { |
| 314 | + if (join(param, elem)) { |
| 315 | + work.push_back(param); |
| 316 | + } |
| 317 | + } |
| 318 | + } |
| 319 | + |
| 320 | + // TODO: Propagate usage to all functions of any type in the type tree of |
| 321 | + // this function's type to keep subtyping valid. |
| 322 | + } |
| 323 | + } |
| 324 | + |
| 325 | + void optimize() { |
| 326 | + struct Optimizer : public WalkerPass<PostWalker<Optimizer>> { |
| 327 | + // TODO: Visit functions in parallel, replacing unused parameters with |
| 328 | + // locals. Direct calls should look at their target to determine which |
| 329 | + // operands to remove (being sure to preserve side effects using |
| 330 | + // ChildLocalizer). Indirect calls need to look at the analysis results |
| 331 | + // for the target type (TODO: materialize this, possibly just for the root |
| 332 | + // type for each type tree) to determine what operands to remove. |
| 333 | + }; |
| 334 | + Optimizer{}.run(getPassRunner(), wasm); |
| 335 | + } |
| 336 | +}; |
| 337 | + |
| 338 | +} // anonymous namespace |
| 339 | + |
| 340 | +Pass* createDAE2Pass() { return new DAE2(); } |
| 341 | + |
| 342 | +} // namespace wasm |
0 commit comments