[SYCL] Add SYCL Module splitting. #1

Closed
wants to merge 17 commits into from
Changes from 5 commits
188 changes: 188 additions & 0 deletions llvm/include/llvm/Transforms/Utils/SYCLModuleSplit.h
@@ -0,0 +1,188 @@
//===-------- SYCLModuleSplit.h - module split ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Functionality to split a module into call graphs. A callgraph here is a set
// of entry points with all functions reachable from them via a call. The result

I feel like we may need to elaborate on what "entry point" means here.

One related comment. Do we want to treat only SYCL kernels as 'entry points', or do we also want to consider functions with the SYCL_EXTERNAL attribute as 'entry points'? I think the former option is more viable from an upstreaming POV.

Thanks

Treating SYCL_EXTERNAL functions as entry points comes from "extra" features, like interoperability with other languages (for example, linking to a kernel written in ISPC which calls a function written in SYCL), or support for shared libraries/dynamic linking.

For our initial patch, I think we can most certainly simplify things down to only considering kernels as entry points.
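
To make the two options in this thread concrete, here is a minimal sketch of the entry-point predicate being discussed. `ToyFunction`, its flags, `isEntryPoint`, and `collectEntryPoints` are hypothetical stand-ins and not LLVM APIs; only the flag name `EmitOnlyKernelsAsEntryPoints` is borrowed from the `getDeviceCodeSplitter` declaration in this patch, and its exact semantics there are assumed.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for llvm::Function; not an LLVM type.
struct ToyFunction {
  std::string Name;
  bool IsKernel = false;       // assumed: the function is a SYCL kernel
  bool IsSYCLExternal = false; // assumed: the function is SYCL_EXTERNAL
};

// Returns true if F should be treated as an entry point.
// Kernels always qualify; SYCL_EXTERNAL functions qualify only when the
// kernels-only restriction is switched off.
bool isEntryPoint(const ToyFunction &F, bool EmitOnlyKernelsAsEntryPoints) {
  if (F.IsKernel)
    return true;
  if (EmitOnlyKernelsAsEntryPoints)
    return false;
  return F.IsSYCLExternal;
}

// Collects the names of all entry points, preserving declaration order.
std::vector<std::string>
collectEntryPoints(const std::vector<ToyFunction> &Fns,
                   bool EmitOnlyKernelsAsEntryPoints) {
  std::vector<std::string> Result;
  for (const ToyFunction &F : Fns)
    if (isEntryPoint(F, EmitOnlyKernelsAsEntryPoints))
      Result.push_back(F.Name);
  return Result;
}
```

With the kernels-only option, the set of entry points can only shrink, which is what makes it the simpler starting point for an initial patch.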

// of the split is new modules containing corresponding callgraph.

@asudarsa asudarsa Oct 31, 2024

Here is my take.

This functionality takes as input a fully linked SYCL device module with a set of SYCL device kernels and performs splitting to generate several fully contained device modules. Each newly formed module contains a subset of the original set of SYCL device kernels along with the union of all functions from their respective call graphs. Here, the call graph of a SYCL kernel is the set of all functions reachable from that kernel.
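
The description above can be sketched on a toy call graph. This is only an illustration under stated assumptions: functions are represented by name, `ToyCallGraph`, `reachableFrom`, and `splitContents` are hypothetical helpers, and nothing here reflects how the patch actually walks LLVM IR.

```cpp
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Toy call graph: caller name -> names of directly called functions.
using ToyCallGraph = std::map<std::string, std::vector<std::string>>;

// Returns the kernel plus every function reachable from it via calls.
std::set<std::string> reachableFrom(const ToyCallGraph &CG,
                                    const std::string &Kernel) {
  std::set<std::string> Seen{Kernel};
  std::vector<std::string> Worklist{Kernel};
  while (!Worklist.empty()) {
    std::string Cur = std::move(Worklist.back());
    Worklist.pop_back();
    auto It = CG.find(Cur);
    if (It == CG.end())
      continue; // leaf function: it calls nothing
    for (const std::string &Callee : It->second)
      if (Seen.insert(Callee).second) // enqueue each function only once
        Worklist.push_back(Callee);
  }
  return Seen;
}

// Contents of one split module: the union of the call graphs of its kernels.
std::set<std::string> splitContents(const ToyCallGraph &CG,
                                    const std::vector<std::string> &Kernels) {
  std::set<std::string> Result;
  for (const std::string &K : Kernels) {
    std::set<std::string> R = reachableFrom(CG, K);
    Result.insert(R.begin(), R.end());
  }
  return Result;
}
```

A function shared by two kernels ends up in both split modules when the kernels are placed in different groups, which is what makes each module fully contained.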

//===----------------------------------------------------------------------===//

#ifndef LLVM_SYCL_MODULE_SPLIT_H
#define LLVM_SYCL_MODULE_SPLIT_H

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Function.h"
#include "llvm/Support/Error.h"

#include <memory>
#include <optional>
#include <string>

namespace llvm {

class Function;
class Module;

enum class IRSplitMode {
IRSM_PER_TU, // one module per translation unit
IRSM_PER_KERNEL, // one module per kernel
IRSM_AUTO, // automatically select split mode
IRSM_NONE // no splitting
};

/// \returns IRSplitMode value if \p S is recognized. Otherwise, std::nullopt is
/// returned.
std::optional<IRSplitMode> convertStringToSplitMode(StringRef S);

// A vector that contains all entry point functions in a split module.
using EntryPointSet = SetVector<Function *>;

/// Describes scope covered by each entry in the module-entry points map
/// populated by the groupEntryPointsByScope function.
enum EntryPointsGroupScope {
Scope_PerKernel, // one entry per kernel
Scope_PerModule, // one entry per module
Scope_Global // single entry in the map for all kernels
};

/// Represents a named group of device code entry points - kernels and
/// SYCL_EXTERNAL functions.
struct EntryPointGroup {
// Properties of an entry point (EP) group
struct Properties {
// Scope represented by EPs in a group
EntryPointsGroupScope Scope = Scope_Global;
};

std::string GroupId;
EntryPointSet Functions;
Properties Props;

EntryPointGroup(StringRef GroupId = "") : GroupId(GroupId) {}
EntryPointGroup(StringRef GroupId, EntryPointSet Functions)
: GroupId(GroupId), Functions(std::move(Functions)) {}
EntryPointGroup(StringRef GroupId, EntryPointSet Functions,
const Properties &Props)
: GroupId(GroupId), Functions(std::move(Functions)), Props(Props) {}

Should we be more consistent about by-value vs lvalue-ref vs rvalue-ref when accepting more complex arguments like EntryPointSet? Note that it is a SetVector and we are making a copy here in those constructors.

Owner Author

All EntryPointSet arguments are expected to be constructed by move constructors. SetVector is a container that stores its content on the heap, and its move constructor is "cheap".
In our repository, they have been passed by rvalue reference, which is unconventional in C++. This observation was inspired by tips like the following: https://abseil.io/tips/117. However, there is no direct guidance for function arguments.

I see now that we have a const Properties & argument that is being copied right away. I think we could apply the same principle to this argument as follows:

EntryPointGroup(StringRef GroupId, EntryPointSet Functions, Properties Props)
      : GroupId(GroupId), Functions(std::move(Functions)), Props(std::move(Props)) {}
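
To illustrate the cost model behind the by-value sink parameter discussed here, the following sketch counts copies and moves. `CountingSet`, `ToyEntryGroup`, and the two counters are hypothetical and exist only for this illustration; they are not part of the patch.

```cpp
#include <string>
#include <utility>
#include <vector>

// Global counters, used only to make copies and moves observable.
static int NumCopies = 0;
static int NumMoves = 0;

// Hypothetical payload that records how often it is copied or moved.
struct CountingSet {
  std::vector<int> Data;

  CountingSet() = default;
  explicit CountingSet(std::vector<int> D) : Data(std::move(D)) {}
  CountingSet(const CountingSet &O) : Data(O.Data) { ++NumCopies; }
  CountingSet(CountingSet &&O) noexcept : Data(std::move(O.Data)) {
    ++NumMoves;
  }
};

// Sink parameters are taken by value and then moved into the members.
struct ToyEntryGroup {
  std::string GroupId;
  CountingSet Functions;

  ToyEntryGroup(std::string Id, CountingSet Fns)
      : GroupId(std::move(Id)), Functions(std::move(Fns)) {}
};
```

Passing an rvalue (`ToyEntryGroup("a", std::move(S))`) costs two moves and no copy, while passing an lvalue costs one copy plus one move and leaves the caller's object intact. An rvalue-reference parameter would save one move but would reject lvalues outright.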

};

using EntryPointGroupVec = SmallVector<EntryPointGroup, 0>;

/// Annotates an llvm::Module with information necessary to perform and track
/// result of device code (llvm::Module instances) splitting:
/// - entry points of the module determined e.g. by a module splitter, as well
/// as information about entry point origin (e.g. result of a scoped split)
/// - its properties, such as whether it has specialization constants uses
/// It also provides convenience functions for entry point set transformation
/// between llvm::Function object and string representations.
class ModuleDesc {
std::unique_ptr<Module> M;
EntryPointGroup EntryPoints;

public:
ModuleDesc(std::unique_ptr<Module> M) : M(std::move(M)) {}

ModuleDesc(std::unique_ptr<Module> M, EntryPointGroup EntryPoints)
: M(std::move(M)), EntryPoints(std::move(EntryPoints)) {}

const EntryPointSet &entries() const { return EntryPoints.Functions; }
const EntryPointGroup &getEntryPointGroup() const { return EntryPoints; }
EntryPointSet &entries() { return EntryPoints.Functions; }
Module &getModule() { return *M; }
const Module &getModule() const { return *M; }
std::unique_ptr<Module> releaseModulePtr() { return std::move(M); }

// Cleans up module IR - removes dead globals, debug info etc.
void cleanup();

std::string makeSymbolTable() const;

void dump() const;
};

/// Module split support interface.
/// It gets a module (in the form of a module descriptor, to get additional
/// info) and a collection of entry point groups. Each group specifies a subset
/// of entry points from the input module that should be included in a split
/// module.
class ModuleSplitterBase {
protected:
ModuleDesc Input;
EntryPointGroupVec Groups;

protected:
EntryPointGroup nextGroup() {
assert(hasMoreSplits() && "Reached end of entry point groups list.");
EntryPointGroup Res = std::move(Groups.back());
Groups.pop_back();
return Res;
}

Module &getInputModule() { return Input.getModule(); }

std::unique_ptr<Module> releaseInputModule() {
return Input.releaseModulePtr();
}

public:
ModuleSplitterBase(ModuleDesc MD, EntryPointGroupVec GroupVec)
: Input(std::move(MD)), Groups(std::move(GroupVec)) {
assert(!Groups.empty() && "Entry points groups collection is empty!");
}

virtual ~ModuleSplitterBase() = default;

/// Gets next subsequence of entry points in an input module and provides
/// split submodule containing these entry points and their dependencies.
virtual ModuleDesc nextSplit() = 0;

/// Returns a number of remaining modules, which can be split out using this
/// splitter. The value is reduced by 1 each time nextSplit is called.
size_t remainingSplits() const { return Groups.size(); }

/// Check that there are still submodules to split.
bool hasMoreSplits() const { return remainingSplits() > 0; }
};

std::unique_ptr<ModuleSplitterBase>
getDeviceCodeSplitter(ModuleDesc MD, IRSplitMode Mode, bool IROutputOnly,
bool EmitOnlyKernelsAsEntryPoints);

/// The structure represents a split LLVM Module accompanied by additional
/// information. Split modules are stored on disk due to the high RAM
/// consumption during the whole splitting process.
struct SYCLSplitModule {

Please consider adding a descriptive comment. Thanks

Considering that it is a simple string pair, do we really need to have this custom struct at all?

Owner Author

We are going to add Properties into this structure in follow-up patches.

std::string ModuleFilePath;
std::string Symbols;

SYCLSplitModule() = default;
SYCLSplitModule(const SYCLSplitModule &) = default;
SYCLSplitModule &operator=(const SYCLSplitModule &) = default;
SYCLSplitModule(SYCLSplitModule &&) = default;
SYCLSplitModule &operator=(SYCLSplitModule &&) = default;

SYCLSplitModule(std::string_view File, std::string Symbols)
: ModuleFilePath(File), Symbols(std::move(Symbols)) {}
};

struct ModuleSplitterSettings {

I think a separator comment would be welcome here, to signal that some of this code belongs to the splitting algorithm itself while the rest belongs to the tooling we have (which in turn is mostly used for testing).

IRSplitMode Mode;
bool OutputAssembly = false; // Bitcode or LLVM IR.
StringRef OutputPrefix;
};

/// Parses the string table.
Expected<SmallVector<SYCLSplitModule, 0>>
parseSYCLSplitModulesFromFile(StringRef File);

/// Splits the given module \p M according to the given \p Settings.
Expected<SmallVector<SYCLSplitModule, 0>>
splitSYCLModule(std::unique_ptr<Module> M, ModuleSplitterSettings Settings);

} // namespace llvm

#endif // LLVM_SYCL_MODULE_SPLIT_H
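
The bookkeeping of `ModuleSplitterBase` in the header above (`nextGroup`, `remainingSplits`, `hasMoreSplits`) can be sketched without LLVM types. `ToySplitGroup`, `ToySplitter`, and `drain` are hypothetical names used only for illustration; the sketch shows one consequence of popping from the back of the group vector, namely that splits come out in reverse group order.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for EntryPointGroup: an id plus entry point names.
struct ToySplitGroup {
  std::string GroupId;
  std::vector<std::string> Functions;
};

// Mirrors the group bookkeeping of ModuleSplitterBase, minus the module.
class ToySplitter {
  std::vector<ToySplitGroup> Groups;

public:
  explicit ToySplitter(std::vector<ToySplitGroup> G) : Groups(std::move(G)) {
    assert(!Groups.empty() && "Entry points groups collection is empty!");
  }

  std::size_t remainingSplits() const { return Groups.size(); }
  bool hasMoreSplits() const { return remainingSplits() > 0; }

  // Pops the last group, as nextGroup() does in the header.
  ToySplitGroup nextSplit() {
    assert(hasMoreSplits() && "Reached end of entry point groups list.");
    ToySplitGroup Res = std::move(Groups.back());
    Groups.pop_back();
    return Res;
  }
};

// Drains the splitter and returns the group ids in production order.
std::vector<std::string> drain(ToySplitter &S) {
  std::vector<std::string> Order;
  while (S.hasMoreSplits())
    Order.push_back(S.nextSplit().GroupId);
  return Order;
}
```

A caller of the real interface would presumably follow the same `while (hasMoreSplits()) nextSplit()` pattern; whether the reverse ordering matters for the tooling is not stated in the patch.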
95 changes: 95 additions & 0 deletions llvm/include/llvm/Transforms/Utils/SYCLUtils.h
@@ -0,0 +1,95 @@
//===------------ SYCLUtils.h - SYCL utility functions --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Utility functions for SYCL.
//===----------------------------------------------------------------------===//
#ifndef LLVM_TRANSFORMS_UTILS_SYCLUTILS_H
#define LLVM_TRANSFORMS_UTILS_SYCLUTILS_H

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Operator.h"

#include <functional>
#include <string>
#include <vector>

namespace llvm {

constexpr char ATTR_SYCL_MODULE_ID[] = "sycl-module-id";
constexpr char ATTR_SYCL_OPTLEVEL[] = "sycl-optlevel";

using CallGraphNodeAction = ::std::function<void(Function *)>;
using CallGraphFunctionFilter =
std::function<bool(const Instruction *, const Function *)>;

// Traverses call graph starting from given function up the call chain applying
// given action to each function met on the way. If \c ErrorOnNonCallUse
// parameter is true, then no functions' uses are allowed except calls.
// Otherwise, any function where use of the current one happened is added to the
// call graph as if the use was a call.
// The 'functionFilter' parameter is a callback function that can be used to
// control which functions will be added to a call graph.
//
// The callback is invoked whenever a function being traversed is used
// by some instruction which is not a call to this function (e.g. storing
// function pointer to memory) - the first parameter is the using instruction,
// the second - the function being traversed. The parent function of the
// instruction is added to the call graph depending on whether the callback
// returns 'true' (added) or 'false' (not added).
// Functions which are part of the visited set ('Visited' parameter) are not
// traversed.

void traverseCallgraphUp(
llvm::Function *F, CallGraphNodeAction NodeF,
SmallPtrSetImpl<Function *> &Visited, bool ErrorOnNonCallUse,
const CallGraphFunctionFilter &functionFilter =
[](const Instruction *, const Function *) { return true; });

template <class CallGraphNodeActionF>
void traverseCallgraphUp(
Function *F, CallGraphNodeActionF ActionF,
SmallPtrSetImpl<Function *> &Visited, bool ErrorOnNonCallUse,
const CallGraphFunctionFilter &functionFilter =
[](const Instruction *, const Function *) { return true; }) {
traverseCallgraphUp(F, CallGraphNodeAction(ActionF), Visited,
ErrorOnNonCallUse, functionFilter);
}

template <class CallGraphNodeActionF>
void traverseCallgraphUp(
Function *F, CallGraphNodeActionF ActionF, bool ErrorOnNonCallUse = true,
const CallGraphFunctionFilter &functionFilter =
[](const Instruction *, const Function *) { return true; }) {
SmallPtrSet<Function *, 32> Visited;
traverseCallgraphUp(F, CallGraphNodeAction(ActionF), Visited,
ErrorOnNonCallUse, functionFilter);
}

inline bool isSYCLExternalFunction(const Function *F) {
return F->hasFnAttribute(ATTR_SYCL_MODULE_ID);
}

/// Removes the global variable "llvm.used" and returns true on success.
/// "llvm.used" is a global constant array containing references to kernels
/// available in the module and callable from host code. The elements of
/// the array are ConstantExpr bitcast to i8*.
/// The variable must be removed because a) it has done its job by the time
/// this function is called and b) the references to the kernels callable from
/// host must not have users.
bool removeSYCLKernelsConstRefArray(Module &M);

using SYCLStringTable = std::vector<std::vector<std::string>>;

void writeSYCLStringTable(const SYCLStringTable &Table, raw_ostream &OS);

} // namespace llvm

#endif // LLVM_TRANSFORMS_UTILS_SYCLUTILS_H
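
The upward traversal documented for `traverseCallgraphUp` can be sketched on a toy reverse call graph. This is an illustration under assumptions: functions are plain names, `ToyCallers` and `toyTraverseUp` are hypothetical, the non-call-use handling and the filter callback are left out, and applying the action to the start function itself is assumed rather than taken from the header.

```cpp
#include <functional>
#include <map>
#include <set>
#include <string>
#include <vector>

// Toy reverse call graph: callee name -> names of its direct callers.
using ToyCallers = std::map<std::string, std::vector<std::string>>;

// Walks from Start up the call chain, applying Action to each function once.
// Functions already in Visited are skipped, mirroring the 'Visited' parameter.
void toyTraverseUp(const ToyCallers &Callers, const std::string &Start,
                   const std::function<void(const std::string &)> &Action,
                   std::set<std::string> &Visited) {
  std::vector<std::string> Worklist;
  if (Visited.insert(Start).second)
    Worklist.push_back(Start);
  while (!Worklist.empty()) {
    std::string Cur = Worklist.back();
    Worklist.pop_back();
    Action(Cur); // assumption: the start function itself is also visited
    auto It = Callers.find(Cur);
    if (It == Callers.end())
      continue; // no known callers, e.g. a kernel
    for (const std::string &Caller : It->second)
      if (Visited.insert(Caller).second)
        Worklist.push_back(Caller);
  }
}
```

Sharing one `Visited` set across several calls keeps repeated traversals from revisiting the same functions, which appears to be the purpose of the overload that takes the set by reference.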
2 changes: 2 additions & 0 deletions llvm/lib/Transforms/Utils/CMakeLists.txt
@@ -81,6 +81,8 @@ add_llvm_component_library(LLVMTransformUtils
SizeOpts.cpp
SplitModule.cpp
StripNonLineTableDebugInfo.cpp
SYCLModuleSplit.cpp
SYCLUtils.cpp
SymbolRewriter.cpp
UnifyFunctionExitNodes.cpp
UnifyLoopExits.cpp