Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xls/contrib/xlscc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@llvm-project//clang:ast",
],
)
Expand Down
14 changes: 11 additions & 3 deletions xls/contrib/xlscc/continuations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,18 @@ OptimizationContext::GetSourcesSetTreeNodeInfoForFunction(

absl::StatusOr<bool> OptimizationContext::CheckNodeSourcesInSet(
xls::FunctionBase* in_function, xls::Node* node,
absl::flat_hash_set<const xls::Param*> sources_set) {
absl::flat_hash_set<const xls::Param*> sources_set,
bool allow_empty_sources_result) {
// Save lazy node analysis for each function for efficiency
XLS_ASSIGN_OR_RETURN(SourcesSetNodeInfo * info,
GetSourcesSetNodeInfoForFunction(in_function));

ParamSet param_sources = info->GetSingleInfoForNode(node);

// No param sources will return true
if (param_sources.empty()) {
return allow_empty_sources_result;
}

bool all_in_set = true;
for (const xls::Param* param : param_sources) {
CHECK_EQ(param->function_base(), in_function);
Expand Down Expand Up @@ -2090,10 +2094,14 @@ absl::Status Translator::MarkDirectIns(GeneratedFunction& func,
continue;
}

// Don't count things that don't actually use any direct-ins as direct-in,
// for example literals. Allow literal substitution pass to prevent these
// from being stored in state elements.
XLS_ASSIGN_OR_RETURN(
continuation_out.direct_in,
context.CheckNodeSourcesInSet(
slice.function, continuation_out.output_node, direct_in_sources));
slice.function, continuation_out.output_node, direct_in_sources,
/*allow_empty_sources_result=*/false));
}

first_slice = false;
Expand Down
142 changes: 120 additions & 22 deletions xls/contrib/xlscc/generate_fsm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <limits>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

Expand All @@ -32,6 +33,7 @@
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "clang/include/clang/AST/Decl.h"
#include "xls/common/math_util.h"
#include "xls/common/status/status_macros.h"
Expand Down Expand Up @@ -539,6 +541,9 @@ absl::Status NewFSMGenerator::LayoutValuesToSaveForNewFSMStates(
if (key.value->direct_in) {
continue;
}
if (key.value->literal.has_value()) {
continue;
}
state.values_to_save.insert(key.value);
}
}
Expand Down Expand Up @@ -632,6 +637,8 @@ NewFSMGenerator::GenerateNewFSMInvocation(
XLS_ASSIGN_OR_RETURN(layout,
LayoutNewFSM(func, state_element_for_static, body_loc));

absl::flat_hash_map<PhiConditionCacheKey, TrackedBValue> generated_conditions;

const int64_t num_slice_index_bits =
xls::CeilOfLog2(1 + xls_func->slices.size());

Expand Down Expand Up @@ -704,7 +711,7 @@ NewFSMGenerator::GenerateNewFSMInvocation(
XLS_ASSIGN_OR_RETURN(
phi_elements_by_param_node_id,
GeneratePhiConditions(layout, state_element_by_jump_slice_index, pb,
body_loc));
body_loc, generated_conditions));

// The value from the current activation's perspective,
// either outputted from invoke or state element.
Expand Down Expand Up @@ -760,6 +767,30 @@ NewFSMGenerator::GenerateNewFSMInvocation(
TrackedBValue after_activation_transition =
pb.Literal(xls::UBits(0, 1), body_loc);

// Sort by Node ID and StateElement name for determinism.
struct StateElementAndNodeLessThan {
bool operator()(const std::tuple<xls::StateElement*, xls::Node*>& a,
const std::tuple<xls::StateElement*, xls::Node*>& b) const {
const auto& [a_elem, a_node] = a;
const auto& [b_elem, b_node] = b;
if (a_elem->name() != b_elem->name()) {
return a_elem->name() < b_elem->name();
}
return a_node->id() < b_node->id();
}
};

struct NodeIdLessThan {
bool operator()(const xls::Node* a, const xls::Node* b) const {
return a->id() < b->id();
}
};

absl::btree_map<std::tuple<xls::StateElement*, xls::Node*>,
absl::btree_set<xls::Node*, NodeIdLessThan>,
StateElementAndNodeLessThan>
next_value_conditions_by_state_element_and_value;

for (int64_t slice_index = 0; slice_index < func.slices.size();
++slice_index) {
const bool is_last_slice = (slice_index == func.slices.size() - 1);
Expand Down Expand Up @@ -1023,44 +1054,65 @@ NewFSMGenerator::GenerateNewFSMInvocation(
if (state.slice_index != slice_index) {
continue;
}

absl::btree_set<int64_t> jumped_from_slice_indices_this_state;
for (const JumpInfo& jump_info : state.jumped_from_slice_indices) {
jumped_from_slice_indices_this_state.insert(jump_info.from_slice);
}

XLS_ASSIGN_OR_RETURN(
TrackedBValue state_active_condition,
GeneratePhiCondition(from_jump_slice_indices,
jumped_from_slice_indices_this_state,
state_element_by_jump_slice_index, pb,
state.slice_index, body_loc));
GeneratePhiCondition(
from_jump_slice_indices, jumped_from_slice_indices_this_state,
state_element_by_jump_slice_index, pb, state.slice_index,
body_loc, generated_conditions));

TrackedBValue next_value_condition =
pb.And(state_active_condition, jump_condition, body_loc,
/*name=*/GetIRStateName(state));

for (const ContinuationValue* continuation_out : state.values_to_save) {
// Generate next values for state elements
NextStateValue next_value = {
.priority = 0,
.value = value_by_continuation_value.at(continuation_out),
.condition = next_value_condition,
};

xls::StateElement* state_elem =
state_element_by_continuation_value.at(continuation_out)
.node()
->As<xls::StateRead>()
->state_element();

std::tuple<xls::StateElement*, xls::Node*> key = {
state_elem,
value_by_continuation_value.at(continuation_out).node()};

// Generate next values
extra_next_state_values.insert({state_elem, next_value});
next_value_conditions_by_state_element_and_value[key].insert(
next_value_condition.node());
}
}
}
}

for (auto& [key, or_nodes] :
next_value_conditions_by_state_element_and_value) {
xls::StateElement* state_elem = std::get<0>(key);
xls::Node* next_value_node = std::get<1>(key);
std::vector<NATIVE_BVAL> or_bvals;
for (xls::Node* or_node : or_nodes) {
or_bvals.push_back(NATIVE_BVAL(or_node, &pb));
}

TrackedBValue or_bval =
pb.Or(absl::MakeSpan(or_bvals), body_loc,
/*name=*/
absl::StrFormat("%s_v_%s_or_bval", state_elem->name(),
next_value_node->GetName()));

NextStateValue next_value = {
.priority = 0,
.value = TrackedBValue(next_value_node, &pb),
.condition = or_bval,
};
extra_next_state_values.insert({state_elem, next_value});
}

// Set next slice index
const TrackedBValue finished_iteration =
pb.Not(after_activation_transition, body_loc,
Expand Down Expand Up @@ -1095,8 +1147,16 @@ absl::StatusOr<TrackedBValue> NewFSMGenerator::GeneratePhiCondition(
const absl::btree_set<int64_t>& jumped_from_slice_indices_this_state,
const absl::flat_hash_map<int64_t, TrackedBValue>&
state_element_by_jump_slice_index,
xls::ProcBuilder& pb, int64_t slice_index,
const xls::SourceInfo& body_loc) {
xls::ProcBuilder& pb, int64_t slice_index, const xls::SourceInfo& body_loc,
absl::flat_hash_map<PhiConditionCacheKey, TrackedBValue>&
phi_condition_cache) {
PhiConditionCacheKey key = {from_jump_slice_indices,
jumped_from_slice_indices_this_state};

if (phi_condition_cache.contains(key)) {
return phi_condition_cache.at(key);
}

TrackedBValue condition = pb.Literal(xls::UBits(1, 1), body_loc);

// Include all jump slices in each condition
Expand All @@ -1107,6 +1167,7 @@ absl::StatusOr<TrackedBValue> NewFSMGenerator::GeneratePhiCondition(
jumped_from_slice_indices_this_state.contains(from_jump_slice_index)
? 1
: 0;

TrackedBValue condition_part =
pb.Eq(jump_state_element,
pb.Literal(xls::UBits(active_value, 1), body_loc,
Expand All @@ -1122,6 +1183,7 @@ absl::StatusOr<TrackedBValue> NewFSMGenerator::GeneratePhiCondition(
absl::StrJoin(jumped_from_slice_indices_this_state, "_")));
}

phi_condition_cache[key] = condition;
return condition;
}

Expand All @@ -1131,7 +1193,9 @@ NewFSMGenerator::GeneratePhiConditions(
const NewFSMLayout& layout,
const absl::flat_hash_map<int64_t, TrackedBValue>&
state_element_by_jump_slice_index,
xls::ProcBuilder& pb, const xls::SourceInfo& body_loc) {
xls::ProcBuilder& pb, const xls::SourceInfo& body_loc,
absl::flat_hash_map<PhiConditionCacheKey, TrackedBValue>&
phi_condition_cache) {
absl::flat_hash_map<int64_t, std::vector<PhiElement>>
phi_elements_by_param_node_id;

Expand Down Expand Up @@ -1168,10 +1232,10 @@ NewFSMGenerator::GeneratePhiConditions(

XLS_ASSIGN_OR_RETURN(
TrackedBValue condition,
GeneratePhiCondition(from_jump_slice_indices,
jumped_from_slice_indices_this_state,
state_element_by_jump_slice_index, pb,
state->slice_index, body_loc));
GeneratePhiCondition(
from_jump_slice_indices, jumped_from_slice_indices_this_state,
state_element_by_jump_slice_index, pb, state->slice_index,
body_loc, phi_condition_cache));

PhiElement& phi_element = phi_elements.emplace_back();
phi_element.value = state->current_inputs_by_input_param.at(param);
Expand Down Expand Up @@ -1201,13 +1265,47 @@ NewFSMGenerator::GenerateInputValueInContext(
std::vector<TrackedBValue> phi_conditions;
std::vector<TrackedBValue> phi_values;

// Sort by Node ID for determinism.
struct NodeIdLessThan {
bool operator()(const xls::Node* a, const xls::Node* b) const {
return a->id() < b->id();
}
};
struct BValueIdLessThan {
bool operator()(const TrackedBValue& a, const TrackedBValue& b) const {
return a.node()->id() < b.node()->id();
}
};
absl::btree_map<xls::Node*, absl::btree_set<TrackedBValue, BValueIdLessThan>,
NodeIdLessThan>
conditions_by_value_node;

phi_conditions.reserve(phi_elements.size());
phi_values.reserve(phi_elements.size());

for (const PhiElement& phi_element : phi_elements) {
phi_conditions.push_back(phi_element.condition);
XLSCC_CHECK(value_by_continuation_value.contains(phi_element.value),
phi_element.value->output_node->loc());
phi_values.push_back(value_by_continuation_value.at(phi_element.value));

xls::Node* value_node =
value_by_continuation_value.at(phi_element.value).node();
conditions_by_value_node[value_node].insert(phi_element.condition);
}

for (auto& [value_node, or_nodes] : conditions_by_value_node) {
std::vector<NATIVE_BVAL> or_bvals;
or_bvals.reserve(or_nodes.size());
for (const TrackedBValue& or_node : or_nodes) {
or_bvals.push_back(or_node);
}

TrackedBValue or_bval =
pb.Or(absl::MakeSpan(or_bvals), body_loc,
/*name=*/
absl::StrFormat("%s_v_%s_or_bval", param->name(),
value_node->GetName()));
phi_conditions.push_back(or_bval);
phi_values.push_back(TrackedBValue(value_node, &pb));
}

std::reverse(phi_conditions.begin(), phi_conditions.end());
Expand Down
18 changes: 13 additions & 5 deletions xls/contrib/xlscc/generate_fsm.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,19 +145,27 @@ class NewFSMGenerator : public GeneratorBase {
const ContinuationValue* value;
};

typedef std::tuple<absl::btree_set<int64_t>, absl::btree_set<int64_t>>
PhiConditionCacheKey;

absl::StatusOr<absl::flat_hash_map<int64_t, std::vector<PhiElement>>>
GeneratePhiConditions(const NewFSMLayout& layout,
const absl::flat_hash_map<int64_t, TrackedBValue>&
state_element_by_jump_slice_index,
xls::ProcBuilder& pb, const xls::SourceInfo& body_loc);
GeneratePhiConditions(
const NewFSMLayout& layout,
const absl::flat_hash_map<int64_t, TrackedBValue>&
state_element_by_jump_slice_index,
xls::ProcBuilder& pb, const xls::SourceInfo& body_loc,
absl::flat_hash_map<PhiConditionCacheKey, TrackedBValue>&
phi_condition_cache);

absl::StatusOr<TrackedBValue> GeneratePhiCondition(
const absl::btree_set<int64_t>& from_jump_slice_indices,
const absl::btree_set<int64_t>& jumped_from_slice_indices_this_state,
const absl::flat_hash_map<int64_t, TrackedBValue>&
state_element_by_jump_slice_index,
xls::ProcBuilder& pb, int64_t slice_index,
const xls::SourceInfo& body_loc);
const xls::SourceInfo& body_loc,
absl::flat_hash_map<PhiConditionCacheKey, TrackedBValue>&
phi_condition_cache);

absl::StatusOr<std::optional<TrackedBValue>> GenerateInputValueInContext(
const xls::Param* param,
Expand Down
3 changes: 2 additions & 1 deletion xls/contrib/xlscc/translator_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,8 @@ class OptimizationContext {

absl::StatusOr<bool> CheckNodeSourcesInSet(
xls::FunctionBase* in_function, xls::Node* node,
absl::flat_hash_set<const xls::Param*> sources_set);
absl::flat_hash_set<const xls::Param*> sources_set,
bool allow_empty_sources_result = true);

private:
absl::flat_hash_map<xls::FunctionBase*, std::unique_ptr<SourcesSetNodeInfo>>
Expand Down
Loading