Skip to content

Commit 2b28007

Browse files
committed
rtl and precommit update
1 parent 050b3c3 commit 2b28007

5 files changed

Lines changed: 186 additions & 105 deletions

File tree

finn-rtllib/where/hdl/where.sv

Lines changed: 162 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,33 @@
44
*
55
* SPDX-License-Identifier: BSD-3-Clause
66
*
7+
* @author Oliver Cassidy <oliver.cassidy@amd.com>
8+
*
79
* @brief ONNX Where stream operator with multidirectional broadcasting.
810
*
911
* @description
10-
* The three input tensors are consumed once per frame into local word
11-
* memories. The output tensor is then emitted in row-major folded order.
12-
* This frame-buffered schedule supports full ONNX multidirectional
13-
* broadcasting, including reuse across non-contiguous output positions.
12+
* This module implements the ONNX expression:
13+
*
14+
* OUT = COND ? X : Y
15+
*
16+
* after applying ONNX multidirectional broadcasting across COND, X and Y.
17+
* Each input stream carries one complete tensor frame folded by its own
18+
* innermost dimension. All three frames are first buffered, then output
19+
* words are read in row-major folded order and selected lane by lane.
20+
*
21+
* COND stream ---> C frame buffer ---\
22+
* X stream ------> X frame buffer ----+--> registered read --> select --> OUT stream
23+
* Y stream ------> Y frame buffer ---/
24+
*
25+
* The frame-buffered schedule is required for broadcast reuse across
26+
* non-contiguous output positions. The read data and selected output are
27+
* registered so the memory output does not feed the AXI/stream output
28+
* combinatorially.
1429
***************************************************************************/
1530

1631
`default_nettype none
1732

18-
module where_broadcast #(
33+
module where #(
1934
int unsigned DATA_WIDTH = 32,
2035
int unsigned PE = 1,
2136
int unsigned NDIMS = 2,
@@ -27,6 +42,7 @@ module where_broadcast #(
2742
parameter int unsigned COND_SHAPE[COND_NDIMS] = '{ default: 1 },
2843
parameter int unsigned X_SHAPE[X_NDIMS] = '{ default: 1 },
2944
parameter int unsigned Y_SHAPE[Y_NDIMS] = '{ default: 1 },
45+
parameter RAM_STYLE = "auto",
3046

3147
localparam int unsigned OUTER_DIMS = (NDIMS > 1)? NDIMS-1 : 1,
3248
localparam int unsigned COND_PE = (COND_SHAPE[COND_NDIMS-1] == 1)? 1 : PE,
@@ -58,7 +74,7 @@ module where_broadcast #(
5874
input wire logic ordy
5975
);
6076

61-
typedef int unsigned outer_idx_t[OUTER_DIMS];
77+
typedef logic [31:0] outer_idx_t[OUTER_DIMS];
6278
typedef logic [COND_PE-1:0] cond_word_t;
6379
typedef logic [X_PE-1:0][DATA_WIDTH-1:0] x_word_t;
6480
typedef logic [Y_PE-1:0][DATA_WIDTH-1:0] y_word_t;
@@ -179,6 +195,15 @@ module where_broadcast #(
179195
localparam int unsigned COND_WORDS = cond_word_count();
180196
localparam int unsigned X_WORDS = x_word_count();
181197
localparam int unsigned Y_WORDS = y_word_count();
198+
localparam int unsigned COND_ADDR_WIDTH = (COND_WORDS > 1)? $clog2(COND_WORDS) : 1;
199+
localparam int unsigned X_ADDR_WIDTH = (X_WORDS > 1)? $clog2(X_WORDS) : 1;
200+
localparam int unsigned Y_ADDR_WIDTH = (Y_WORDS > 1)? $clog2(Y_WORDS) : 1;
201+
localparam int unsigned OUT_FOLD_WIDTH = (OUT_FOLDS > 1)? $clog2(OUT_FOLDS) : 1;
202+
203+
typedef logic [COND_ADDR_WIDTH-1:0] cond_addr_t;
204+
typedef logic [X_ADDR_WIDTH-1:0] x_addr_t;
205+
typedef logic [Y_ADDR_WIDTH-1:0] y_addr_t;
206+
typedef logic [OUT_FOLD_WIDTH-1:0] out_fold_t;
182207

183208
initial begin
184209
automatic int unsigned max_dim;
@@ -257,37 +282,74 @@ module where_broadcast #(
257282
end
258283
end
259284

260-
//------------------------------------------------------------------------
285+
//=======================================================================
261286
// Frame Input Buffers
287+
(* RAM_STYLE = RAM_STYLE *)
262288
cond_word_t Cmem[COND_WORDS];
289+
(* RAM_STYLE = RAM_STYLE *)
263290
x_word_t Xmem[X_WORDS];
291+
(* RAM_STYLE = RAM_STYLE *)
264292
y_word_t Ymem[Y_WORDS];
265293

266-
int unsigned CWr = 0;
267-
int unsigned XWr = 0;
268-
int unsigned YWr = 0;
294+
cond_addr_t CWr = 0;
295+
x_addr_t XWr = 0;
296+
y_addr_t YWr = 0;
269297
logic CLoaded = 0;
270298
logic XLoaded = 0;
271299
logic YLoaded = 0;
272-
logic Emit = 0;
300+
logic Reading = 0;
301+
logic ReadValid = 0;
302+
logic OValid = 0;
273303

274-
assign crdy = !Emit && !CLoaded;
275-
assign xrdy = !Emit && !XLoaded;
276-
assign yrdy = !Emit && !YLoaded;
304+
uwire frame_busy = Reading || ReadValid || OValid;
305+
assign crdy = !frame_busy && !CLoaded;
306+
assign xrdy = !frame_busy && !XLoaded;
307+
assign yrdy = !frame_busy && !YLoaded;
277308

278309
uwire c_fire = cvld && crdy;
279310
uwire x_fire = xvld && xrdy;
280311
uwire y_fire = yvld && yrdy;
281-
uwire emit_fire = Emit && ordy;
312+
uwire output_fire = OValid && ordy;
282313

283314
uwire c_loaded_now = CLoaded || (c_fire && CWr == COND_WORDS-1);
284315
uwire x_loaded_now = XLoaded || (x_fire && XWr == X_WORDS-1);
285316
uwire y_loaded_now = YLoaded || (y_fire && YWr == Y_WORDS-1);
317+
uwire start_reading = !frame_busy && c_loaded_now && x_loaded_now && y_loaded_now;
318+
319+
uwire frame_done = output_fire && !Reading && !ReadValid;
320+
321+
always_ff @(posedge clk) begin
322+
if(rst || frame_done) begin
323+
CWr <= 0;
324+
XWr <= 0;
325+
YWr <= 0;
326+
CLoaded <= 0;
327+
XLoaded <= 0;
328+
YLoaded <= 0;
329+
end
330+
else begin
331+
if(c_fire) begin
332+
Cmem[CWr] <= cdat;
333+
CLoaded <= (CWr == COND_WORDS-1);
334+
if(CWr != COND_WORDS-1) CWr <= CWr + 1;
335+
end
336+
if(x_fire) begin
337+
Xmem[XWr] <= xdat;
338+
XLoaded <= (XWr == X_WORDS-1);
339+
if(XWr != X_WORDS-1) XWr <= XWr + 1;
340+
end
341+
if(y_fire) begin
342+
Ymem[YWr] <= ydat;
343+
YLoaded <= (YWr == Y_WORDS-1);
344+
if(YWr != Y_WORDS-1) YWr <= YWr + 1;
345+
end
346+
end
347+
end
286348

287-
//------------------------------------------------------------------------
349+
//=======================================================================
288350
// Output Indexing
289351
outer_idx_t OutIdx = '{ default: 0 };
290-
int unsigned OutFold = 0;
352+
out_fold_t OutFold = 0;
291353

292354
uwire out_last_fold = (OutFold == OUT_FOLDS-1);
293355
logic out_last_outer;
@@ -297,93 +359,108 @@ module where_broadcast #(
297359
out_last_outer &= (OutIdx[i] == OUT_SHAPE[i]-1);
298360
end
299361
uwire out_last = out_last_fold && out_last_outer;
300-
uwire frame_done = emit_fire && out_last;
362+
363+
uwire output_ready = !OValid || ordy;
364+
uwire read_ready = !ReadValid || output_ready;
365+
uwire read_issue = Reading && read_ready;
301366

302367
always_ff @(posedge clk) begin
303-
if(rst) begin
304-
CWr <= 0;
305-
XWr <= 0;
306-
YWr <= 0;
307-
CLoaded <= 0;
308-
XLoaded <= 0;
309-
YLoaded <= 0;
310-
Emit <= 0;
368+
if(rst || frame_done) begin
369+
Reading <= 0;
370+
end
371+
else begin
372+
if(start_reading)
373+
Reading <= 1;
374+
else if(read_issue && out_last)
375+
Reading <= 0;
376+
end
377+
end
378+
379+
always_ff @(posedge clk) begin
380+
if(rst || frame_done || start_reading) begin
311381
OutIdx <= '{ default: 0 };
312382
OutFold <= 0;
313383
end
314-
else begin
315-
if(frame_done) begin
316-
CWr <= 0;
317-
XWr <= 0;
318-
YWr <= 0;
319-
CLoaded <= 0;
320-
XLoaded <= 0;
321-
YLoaded <= 0;
322-
Emit <= 0;
323-
OutIdx <= '{ default: 0 };
384+
else if(read_issue && !out_last) begin
385+
if(out_last_fold) begin
386+
automatic bit carry = 1;
324387
OutFold <= 0;
325-
end
326-
else begin
327-
if(c_fire) begin
328-
Cmem[CWr] <= cdat;
329-
CLoaded <= (CWr == COND_WORDS-1);
330-
if(CWr != COND_WORDS-1) CWr <= CWr + 1;
331-
end
332-
if(x_fire) begin
333-
Xmem[XWr] <= xdat;
334-
XLoaded <= (XWr == X_WORDS-1);
335-
if(XWr != X_WORDS-1) XWr <= XWr + 1;
336-
end
337-
if(y_fire) begin
338-
Ymem[YWr] <= ydat;
339-
YLoaded <= (YWr == Y_WORDS-1);
340-
if(YWr != Y_WORDS-1) YWr <= YWr + 1;
341-
end
342-
if(!Emit && c_loaded_now && x_loaded_now && y_loaded_now)
343-
Emit <= 1;
344-
else if(emit_fire) begin
345-
if(out_last_fold) begin
346-
automatic bit carry = 1;
347-
OutFold <= 0;
348-
for(int i = int'(NDIMS)-2; i >= 0; i--) begin
349-
if(carry) begin
350-
if(OutIdx[i] == OUT_SHAPE[i]-1) begin
351-
OutIdx[i] <= 0;
352-
end
353-
else begin
354-
OutIdx[i] <= OutIdx[i] + 1;
355-
carry = 0;
356-
end
357-
end
388+
for(int i = int'(NDIMS)-2; i >= 0; i--) begin
389+
if(carry) begin
390+
if(OutIdx[i] == OUT_SHAPE[i]-1) begin
391+
OutIdx[i] <= 0;
392+
end
393+
else begin
394+
OutIdx[i] <= OutIdx[i] + 1;
395+
carry = 0;
358396
end
359397
end
360-
else
361-
OutFold <= OutFold + 1;
362398
end
363399
end
400+
else
401+
OutFold <= OutFold + 1;
364402
end
365403
end
366404

367-
//------------------------------------------------------------------------
368-
// Broadcast Selection
369-
uwire logic [31:0] c_addr = cond_word_addr(OutIdx, OutFold);
370-
uwire logic [31:0] x_addr = x_word_addr(OutIdx, OutFold);
371-
uwire logic [31:0] y_addr = y_word_addr(OutIdx, OutFold);
372-
uwire cond_word_t c_word = Cmem[c_addr];
373-
uwire x_word_t x_word = Xmem[x_addr];
374-
uwire y_word_t y_word = Ymem[y_addr];
405+
//=======================================================================
406+
// Registered Broadcast Reads
407+
uwire cond_addr_t c_addr = cond_addr_t'(cond_word_addr(OutIdx, OutFold));
408+
uwire x_addr_t x_addr = x_addr_t'(x_word_addr(OutIdx, OutFold));
409+
uwire y_addr_t y_addr = y_addr_t'(y_word_addr(OutIdx, OutFold));
410+
411+
cond_word_t CWord = 'x;
412+
x_word_t XWord = 'x;
413+
y_word_t YWord = 'x;
375414

415+
always_ff @(posedge clk) begin
416+
if(rst || frame_done) begin
417+
ReadValid <= 0;
418+
CWord <= 'x;
419+
XWord <= 'x;
420+
YWord <= 'x;
421+
end
422+
else if(read_ready) begin
423+
ReadValid <= read_issue;
424+
if(read_issue) begin
425+
CWord <= Cmem[c_addr];
426+
XWord <= Xmem[x_addr];
427+
YWord <= Ymem[y_addr];
428+
end
429+
else begin
430+
CWord <= 'x;
431+
XWord <= 'x;
432+
YWord <= 'x;
433+
end
434+
end
435+
end
436+
437+
//=======================================================================
438+
// Broadcast Selection
376439
out_word_t selected;
377440
for(genvar lane = 0; lane < PE; lane++) begin : genSelect
378-
uwire c = (COND_SHAPE[COND_NDIMS-1] == 1)? c_word[0] : c_word[lane];
379-
uwire [DATA_WIDTH-1:0] x = (X_SHAPE[X_NDIMS-1] == 1)? x_word[0] : x_word[lane];
380-
uwire [DATA_WIDTH-1:0] y = (Y_SHAPE[Y_NDIMS-1] == 1)? y_word[0] : y_word[lane];
441+
uwire c = (COND_SHAPE[COND_NDIMS-1] == 1)? CWord[0] : CWord[lane];
442+
uwire [DATA_WIDTH-1:0] x = (X_SHAPE[X_NDIMS-1] == 1)? XWord[0] : XWord[lane];
443+
uwire [DATA_WIDTH-1:0] y = (Y_SHAPE[Y_NDIMS-1] == 1)? YWord[0] : YWord[lane];
381444
assign selected[lane] = c? x : y;
382445
end : genSelect
383446

384-
assign odat = selected;
385-
assign ovld = Emit;
447+
out_word_t ODat = 'x;
448+
449+
always_ff @(posedge clk) begin
450+
if(rst || frame_done) begin
451+
OValid <= 0;
452+
ODat <= 'x;
453+
end
454+
else if(output_ready) begin
455+
OValid <= ReadValid;
456+
if(ReadValid) ODat <= selected;
457+
else ODat <= 'x;
458+
end
459+
end
460+
461+
assign odat = ODat;
462+
assign ovld = OValid;
386463

387-
endmodule : where_broadcast
464+
endmodule : where
388465

389466
`default_nettype wire

finn-rtllib/where/hdl/where_core_template.sv

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ module $TOP_MODULE_NAME$_core #(
6969
end
7070
endgenerate
7171

72-
where_broadcast #(
72+
where #(
7373
.DATA_WIDTH($DATA_WIDTH$),
7474
.PE($PE$),
7575
.NDIMS($NDIMS$),
@@ -79,7 +79,8 @@ module $TOP_MODULE_NAME$_core #(
7979
.OUT_SHAPE($OUT_SHAPE$),
8080
.COND_SHAPE($COND_SHAPE$),
8181
.X_SHAPE($X_SHAPE$),
82-
.Y_SHAPE($Y_SHAPE$)
82+
.Y_SHAPE($Y_SHAPE$),
83+
.RAM_STYLE($RAM_STYLE$)
8384
) impl (
8485
.clk(ap_clk),
8586
.rst(!ap_rst_n),

src/finn/custom_op/fpgadataflow/rtl/where_rtl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def generate_hdl(self, model, fpgapart, clk):
9797
"X_WIDTH": x_width,
9898
"Y_WIDTH": y_width,
9999
"OUT_WIDTH": out_width,
100+
"RAM_STYLE": '"{}"'.format(self.get_nodeattr("ram_style")),
100101
}
101102

102103
for key, value in code_gen_dict.items():

src/finn/custom_op/fpgadataflow/where.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ def get_nodeattr_types(self):
5454
"conditionDataType": ("s", False, "BINARY"),
5555
"inputDataType": ("s", True, ""),
5656
"outputDataType": ("s", False, ""),
57+
"ram_style": (
58+
"s",
59+
False,
60+
"auto",
61+
{"auto", "block", "distributed", "ultra"},
62+
),
5763
"inFIFODepths": ("ints", False, [2, 2, 2]),
5864
"outFIFODepths": ("ints", False, [2]),
5965
}
@@ -76,7 +82,10 @@ def _input_shape(self, ind):
7682
rank = self.get_nodeattr(rank_name)
7783
shape = tuple(self.get_nodeattr(attr_name))
7884
if rank >= 0:
79-
assert len(shape) == rank, "%s length must match %s" % (attr_name, rank_name)
85+
assert len(shape) == rank, "%s length must match %s" % (
86+
attr_name,
87+
rank_name,
88+
)
8089
return shape
8190
if len(shape) != 0:
8291
return shape
@@ -204,7 +213,9 @@ def get_number_output_values(self):
204213
return int(np.prod(self.get_folded_output_shape()[:-1]))
205214

206215
def get_exp_cycles(self):
207-
return self.get_number_output_values()
216+
input_cycles = max(int(np.prod(self.get_folded_input_shape(ind)[:-1])) for ind in range(3))
217+
output_cycles = self.get_number_output_values()
218+
return input_cycles + output_cycles + 4
208219

209220
def execute_node(self, context, graph):
210221
node = self.onnx_node

0 commit comments

Comments
 (0)