Skip to content

Conversation

@SpiritSeeker
Copy link

This PR adds support for elementwise functions (elementwise unary operations), allowing for easy extension to all hls math functions from HLS Math Library. The current PR adds support for ReLU, elementwise Exp, and elementwise Erf functions.

Status of tests:
✔️ ReLU - FLOAT32, FLOAT16, INT, FIXED (cppsim and rtlsim)
✔️ Exp and Erf - cppsim + FLOAT32
✔️ Exp - rtlsim: FLOAT32 and FLOAT16

✖️ Exp and Erf - cppsim + FLOAT16 (numpy and simulated results differ significantly)
✖️ Erf - rtlsim: FLOAT32 and FLOAT16 (RTL watchdog timeout error)

Copy link
Collaborator

@preusser preusser left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @SpiritSeeker!
Please, review comments.

@property
def cpp_op(self):
odt_hls_name = self.out_dtype.get_hls_datatype_str()
return "({0} > 0 ? (%s){0} : (%s)0)" % (odt_hls_name, odt_hls_name)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reversed comparison {0} < 0 is easier for most datatypes.

inp_bw = self.inp_dtype.bitwidth()
# The output would be unsigned with same bit-width as input
# if input was unsigned, else one bit less
out_bw = inp_bw - 1 if self.inp_dtype.signed() else inp_bw
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Consider issuing a warning when constructing a ElementwiseReLU node with an unsigned input type.
  • You can only safely strip a bit from the output datatype if the input datatype is narrow, i.e. within [-2^(n-1) + 1 : 2^(n-1) - 1].

odt_hls_name = self.out_dtype.get_hls_datatype_str()
# Explicitly use the overloads, using hls::exp results in minor errors
if self.out_dtype.get_canonical_name() == "FLOAT32":
return "(hls::expf((%s){0}))" % (odt_hls_name)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return "hls::exp(%s({0}))" % (odt_hls_name) should be the only return statement. Rely on function overload selection by the argument type for specialization.

odt_hls_name = self.out_dtype.get_hls_datatype_str()
# Explicitly use the overloads, using hls::erf results in minor errors
if self.out_dtype.get_canonical_name() == "FLOAT32":
return "(hls::erff((%s){0}))" % (odt_hls_name)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return "hls::erf(%s({0}))" % (odt_hls_name) should be the only return statement. Rely on function overload selection by the argument type for specialization.

# Generates C++ code for declaring all streams involved in C++ simulation
# for testing
def strm_decl(self):
# Allways add the output stream to the declarations
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not concise?:

self.code_gen_dict["$STREAMDECLARATIONS$"] = [
            # Note: Assumes stream type aliases to be set in defines
            "OutStream out0_V;",
            "InpStream in0_V;"
        ]

#pragma HLS BIND_STORAGE variable=out type=RAM_S2P impl=LUTRAM
""",
# Perfect loop nest over all folded output dimensions
*[for_loop(dim, size) + " {" for dim, size in enumerate(out_shape)],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be equivalent to a single flat loop using the product over all dimensions as its bound.


# Add HLS interface directives specifying how to create RTL ports for
# the top-level function arguments
self.code_gen_dict["$PRAGMAS$"] += [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fuse into a single compact append for both lines of code.

def get_verilog_top_module_intf_names(self):
# Start collecting interface names in a dictionary starting with clock
# and reset
intf_names = {"clk": ["ap_clk"], "rst": ["ap_rst_n"]}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pick up all the other associations as part of the initialization.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants