44 *
55 * SPDX-License-Identifier: BSD-3-Clause
66 *
7+ * @author Oliver Cassidy <oliver.cassidy@amd.com>
8+ *
79 * @brief ONNX Where stream operator with multidirectional broadcasting.
810 *
911 * @description
10- * The three input tensors are consumed once per frame into local word
11- * memories. The output tensor is then emitted in row-major folded order.
12- * This frame-buffered schedule supports full ONNX multidirectional
13- * broadcasting, including reuse across non-contiguous output positions.
12+ * This module implements the ONNX expression:
13+ *
14+ * OUT = COND ? X : Y
15+ *
16+ * after applying ONNX multidirectional broadcasting across COND, X and Y.
17+ * Each input stream carries one complete tensor frame folded by its own
18+ * innermost dimension. All three frames are first buffered, then output
19+ * words are read in row-major folded order and selected lane by lane.
20+ *
21+ * COND stream ---> C frame buffer ---\
22+ * X stream ------> X frame buffer ----+--> registered read --> select --> OUT stream
23+ * Y stream ------> Y frame buffer ---/
24+ *
25+ * The frame-buffered schedule is required for broadcast reuse across
26+ * non-contiguous output positions. The read data and selected output are
27+ * registered so the memory output does not feed the AXI/stream output
28+ * combinatorially.
1429 ***************************************************************************/
1530
1631`default_nettype none
1732
18- module where_broadcast # (
33+ module where # (
1934 int unsigned DATA_WIDTH = 32 ,
2035 int unsigned PE = 1 ,
2136 int unsigned NDIMS = 2 ,
@@ -27,6 +42,7 @@ module where_broadcast #(
2742 parameter int unsigned COND_SHAPE [COND_NDIMS ] = '{ default: 1 } ,
2843 parameter int unsigned X_SHAPE [X_NDIMS ] = '{ default: 1 } ,
2944 parameter int unsigned Y_SHAPE [Y_NDIMS ] = '{ default: 1 } ,
45+ parameter RAM_STYLE = " auto" ,
3046
3147 localparam int unsigned OUTER_DIMS = (NDIMS > 1 )? NDIMS - 1 : 1 ,
3248 localparam int unsigned COND_PE = (COND_SHAPE [COND_NDIMS - 1 ] == 1 )? 1 : PE ,
@@ -58,7 +74,7 @@ module where_broadcast #(
5874 input wire logic ordy
5975);
6076
// Word and index types used throughout the module.
//  - outer_idx_t : one 32-bit counter per outer (non-innermost) output dimension.
//  - cond_word_t : one condition bit per lane of a COND stream word.
//  - x_word_t / y_word_t : one DATA_WIDTH element per lane of an X / Y stream word.
typedef logic [31:0] outer_idx_t[OUTER_DIMS];
typedef logic [COND_PE-1:0] cond_word_t;
typedef logic [X_PE-1:0][DATA_WIDTH-1:0] x_word_t;
typedef logic [Y_PE-1:0][DATA_WIDTH-1:0] y_word_t;
@@ -179,6 +195,15 @@ module where_broadcast #(
// Number of stream words held per frame buffer.
localparam int unsigned COND_WORDS = cond_word_count();
localparam int unsigned X_WORDS = x_word_count();
localparam int unsigned Y_WORDS = y_word_count();

// Counter/address widths. $clog2(1) is 0, so single-entry cases are
// clamped to one bit to keep the packed ranges below legal.
localparam int unsigned COND_ADDR_WIDTH = (COND_WORDS > 1)? $clog2(COND_WORDS) : 1;
localparam int unsigned X_ADDR_WIDTH = (X_WORDS > 1)? $clog2(X_WORDS) : 1;
localparam int unsigned Y_ADDR_WIDTH = (Y_WORDS > 1)? $clog2(Y_WORDS) : 1;
localparam int unsigned OUT_FOLD_WIDTH = (OUT_FOLDS > 1)? $clog2(OUT_FOLDS) : 1;

// Address types for the three frame buffers and the output fold counter.
typedef logic [COND_ADDR_WIDTH-1:0] cond_addr_t;
typedef logic [X_ADDR_WIDTH-1:0] x_addr_t;
typedef logic [Y_ADDR_WIDTH-1:0] y_addr_t;
typedef logic [OUT_FOLD_WIDTH-1:0] out_fold_t;
182207
183208 initial begin
184209 automatic int unsigned max_dim;
@@ -257,37 +282,74 @@ module where_broadcast #(
257282 end
258283 end
259284
// =======================================================================
// Frame Input Buffers
//  One word memory per input stream; the implementation style is passed
//  through from the RAM_STYLE module parameter.
(* RAM_STYLE = RAM_STYLE *)
cond_word_t Cmem[COND_WORDS];
(* RAM_STYLE = RAM_STYLE *)
x_word_t Xmem[X_WORDS];
(* RAM_STYLE = RAM_STYLE *)
y_word_t Ymem[Y_WORDS];

// Write pointers into the three buffers.
cond_addr_t CWr = 0;
x_addr_t XWr = 0;
y_addr_t YWr = 0;

// Set once the last word of the respective frame has been written.
logic CLoaded = 0;
logic XLoaded = 0;
logic YLoaded = 0;

// Output pipeline state: read addresses being issued, read stage occupied,
// output stage occupied.
logic Reading = 0;
logic ReadValid = 0;
logic OValid = 0;

// Inputs are accepted only while no output activity is pending and the
// respective frame has not yet been fully captured.
uwire frame_busy = Reading || ReadValid || OValid;
assign crdy = !frame_busy && !CLoaded;
assign xrdy = !frame_busy && !XLoaded;
assign yrdy = !frame_busy && !YLoaded;

// Handshake completions on the three inputs and on the output.
uwire c_fire = cvld && crdy;
uwire x_fire = xvld && xrdy;
uwire y_fire = yvld && yrdy;
uwire output_fire = OValid && ordy;

// "Loaded" including a final word arriving in the current cycle, so that
// reading can be started without waiting for the flag register to update.
uwire c_loaded_now = CLoaded || (c_fire && CWr == COND_WORDS - 1);
uwire x_loaded_now = XLoaded || (x_fire && XWr == X_WORDS - 1);
uwire y_loaded_now = YLoaded || (y_fire && YWr == Y_WORDS - 1);
uwire start_reading = !frame_busy && c_loaded_now && x_loaded_now && y_loaded_now;

// An output word is handed over while neither read issue nor the read
// stage holds further work.
uwire frame_done = output_fire && !Reading && !ReadValid;
320+
// Input capture: stores each accepted COND/X/Y word at its write pointer.
// The pointer advances until it reaches the last word of the frame and
// then holds; the matching Loaded flag is set by that last write.
// Pointers and flags are cleared on reset and on frame_done.
always_ff @(posedge clk) begin
    if (rst || frame_done) begin
        CWr <= 0;
        XWr <= 0;
        YWr <= 0;
        CLoaded <= 0;
        XLoaded <= 0;
        YLoaded <= 0;
    end
    else begin
        if (c_fire) begin
            Cmem[CWr] <= cdat;
            CLoaded <= (CWr == COND_WORDS - 1);
            if (CWr != COND_WORDS - 1) CWr <= CWr + 1;
        end
        if (x_fire) begin
            Xmem[XWr] <= xdat;
            XLoaded <= (XWr == X_WORDS - 1);
            if (XWr != X_WORDS - 1) XWr <= XWr + 1;
        end
        if (y_fire) begin
            Ymem[YWr] <= ydat;
            YLoaded <= (YWr == Y_WORDS - 1);
            if (YWr != Y_WORDS - 1) YWr <= YWr + 1;
        end
    end
end
286348
// =======================================================================
// Output Indexing
//  OutIdx holds one counter per outer output dimension, OutFold the
//  position within the folded innermost dimension.
outer_idx_t OutIdx = '{ default: 0 };
out_fold_t OutFold = 0;

// Position flags: last fold of the current row, and last index in every
// outer dimension (out_last_outer is driven procedurally).
uwire out_last_fold = (OutFold == OUT_FOLDS - 1);
logic out_last_outer;
@@ -297,93 +359,108 @@ module where_broadcast #(
297359 out_last_outer & = (OutIdx[i] == OUT_SHAPE [i]- 1 );
298360 end
// The current index addresses the final output word of the frame.
uwire out_last = out_last_fold && out_last_outer;

// Backpressure chain, evaluated from the output backwards: a stage may
// take new data when it is empty or when its successor takes its content.
uwire output_ready = !OValid || ordy;
uwire read_ready = !ReadValid || output_ready;
// A read is issued when reading is active and the read stage can accept it.
uwire read_issue = Reading && read_ready;
301366
// Reading flag: set by start_reading, cleared once the read for the last
// output word has been issued. Reset and frame_done force it low.
always_ff @(posedge clk) begin
    if (rst || frame_done) begin
        Reading <= 0;
    end
    else begin
        if (start_reading)
            Reading <= 1;
        else if (read_issue && out_last)
            Reading <= 0;
    end
end
378+
// Output index counter: advances by one folded word per issued read.
// OutFold counts through the folds of the innermost dimension; when it
// wraps, the outer indices are incremented with carry, starting at the
// innermost outer dimension (NDIMS-2) and moving towards dimension 0.
// The index is zeroed on reset, on frame_done and on start_reading, and
// is held after the read of the last word has been issued.
always_ff @(posedge clk) begin
    if (rst || frame_done || start_reading) begin
        OutIdx <= '{ default: 0 };
        OutFold <= 0;
    end
    else if (read_issue && !out_last) begin
        if (out_last_fold) begin
            // Blocking helper: stays set while every dimension visited so
            // far has wrapped around.
            automatic bit carry = 1;
            OutFold <= 0;
            for (int i = int'(NDIMS)-2; i >= 0; i--) begin
                if (carry) begin
                    if (OutIdx[i] == OUT_SHAPE[i]-1) begin
                        OutIdx[i] <= 0;
                    end
                    else begin
                        OutIdx[i] <= OutIdx[i] + 1;
                        carry = 0;
                    end
                end
            end
        end
        else
            OutFold <= OutFold + 1;
    end
end
366404
// =======================================================================
// Registered Broadcast Reads
//  Per-buffer read addresses computed from the current output index and
//  truncated to the respective buffer address width.
uwire cond_addr_t c_addr = cond_addr_t'(cond_word_addr(OutIdx, OutFold));
uwire x_addr_t x_addr = x_addr_t'(x_word_addr(OutIdx, OutFold));
uwire y_addr_t y_addr = y_addr_t'(y_word_addr(OutIdx, OutFold));

// Read data registers; don't-care until a read has been captured.
cond_word_t CWord = 'x;
x_word_t XWord = 'x;
y_word_t YWord = 'x;
375414
// Read stage: whenever the stage may advance (read_ready), it captures
// the three addressed buffer words if a read is issued and otherwise
// marks itself empty with don't-care data. The stage holds its content
// while read_ready is low. Reset and frame_done empty it.
always_ff @(posedge clk) begin
    if (rst || frame_done) begin
        ReadValid <= 0;
        CWord <= 'x;
        XWord <= 'x;
        YWord <= 'x;
    end
    else if (read_ready) begin
        ReadValid <= read_issue;
        if (read_issue) begin
            CWord <= Cmem[c_addr];
            XWord <= Xmem[x_addr];
            YWord <= Ymem[y_addr];
        end
        else begin
            CWord <= 'x;
            XWord <= 'x;
            YWord <= 'x;
        end
    end
end
436+
// =======================================================================
// Broadcast Selection
//  Per output lane: OUT = COND ? X : Y on the registered read words.
//  An operand whose innermost dimension is 1 supplies its element 0 to
//  every lane; otherwise each lane uses its own element.
out_word_t selected;
for (genvar lane = 0; lane < PE; lane++) begin : genSelect
    uwire c = (COND_SHAPE[COND_NDIMS-1] == 1)? CWord[0] : CWord[lane];
    uwire [DATA_WIDTH-1:0] x = (X_SHAPE[X_NDIMS-1] == 1)? XWord[0] : XWord[lane];
    uwire [DATA_WIDTH-1:0] y = (Y_SHAPE[Y_NDIMS-1] == 1)? YWord[0] : YWord[lane];
    assign selected[lane] = c? x : y;
end : genSelect
383446
// Output stage: registers the selected word so that odat/ovld are driven
// from flip-flops. The stage takes the read stage's content whenever it
// is empty or its current word is accepted (output_ready), and holds
// otherwise. Reset and frame_done empty it.
out_word_t ODat = 'x;

always_ff @(posedge clk) begin
    if (rst || frame_done) begin
        OValid <= 0;
        ODat <= 'x;
    end
    else if (output_ready) begin
        OValid <= ReadValid;
        if (ReadValid) ODat <= selected;
        else ODat <= 'x;
    end
end

assign odat = ODat;
assign ovld = OValid;
386463
387- endmodule : where_broadcast
464+ endmodule : where
388465
389466`default_nettype wire
0 commit comments