Skip to content

Commit

Permalink
Merge pull request #1091 from Xilinx/bigfix/mvu4bit
Browse files Browse the repository at this point in the history
Expose DSP variant for RTL MVU
  • Loading branch information
auphelia authored Jun 12, 2024
2 parents 3d597ac + 10704f7 commit a4bbb59
Show file tree
Hide file tree
Showing 61 changed files with 766 additions and 501 deletions.
157 changes: 109 additions & 48 deletions finn-rtllib/mvu/mvu_4sx4u.sv
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ module mvu_4sx4u #(
int unsigned SIMD,
int unsigned ACCU_WIDTH,

int unsigned VERSION = 1,
int unsigned VERSION = 1, // Version 1 (DSP48E1) *must* commit to NARROW_WEIGHTS
bit SIGNED_ACTIVATIONS = 0,
bit NARROW_WEIGHTS = 0, // Weights from [-7:7] rather than [-8:7]
bit FORCE_BEHAVIORAL = 0
)(
// Global Control
Expand All @@ -62,6 +63,55 @@ module mvu_4sx4u #(
`endif
FORCE_BEHAVIORAL;

//-----------------------------------------------------------------------
// Determine Lane Configuration
initial begin
if(!NARROW_WEIGHTS && (VERSION == 1)) begin
$error("%m: Need NARROW_WEIGHTS for DSP48E1.");
$finish;
end
end

/**
* Lane Slicing
* Assumptions:
* - Internal lane widths differ, at most, by a single bit.
* - The rightmost lane (#0) has the maximum internal width.
* - The leftmost lane (#3) extends into the wide DSP accumulation path and
* is constrained by ACCU_WIDTH rather than the next lane. It doesn't have
* an external high extension.
* - The one but leftmost lane (#2) has the minimum internal width and, hence,
* the macimum external high extension.
*/
typedef int unsigned lane_offset_v[4:0];
function lane_offset_v sliceLanes();
unique case(VERSION)
1: begin
return NARROW_WEIGHTS?
lane_offset_v'{ ACCU_WIDTH+21, 21, 14, 7, 0 } :
lane_offset_v'{ 0, 0, 0, 0, 0 }; // not supported
end
2: begin
return NARROW_WEIGHTS?
lane_offset_v'{ ACCU_WIDTH+23, 23, 16, 8, 0 } :
lane_offset_v'{ ACCU_WIDTH+22, 22, 15, 8, 0 };
end
endcase
endfunction : sliceLanes
localparam lane_offset_v OFFSETS = sliceLanes();

function int unsigned lo_width(input int unsigned i);
return OFFSETS[i+1] - OFFSETS[i];
endfunction : lo_width
function int unsigned hi_width(input int unsigned i);
return 1 + $clog2(2**(ACCU_WIDTH-lo_width(i)-1)+SIMD);
endfunction : hi_width
localparam int unsigned LO_WIDTH_MAX = OFFSETS[1] - OFFSETS[0];
localparam int unsigned HI_WIDTH_MAX = hi_width(2);

localparam int unsigned A_WIDTH = 23 + 2*VERSION; // Width of A datapath

// Compute the count of decendents for all nodes in the reduction trees.
typedef int unsigned leave_load_t[2*SIMD-1];
function leave_load_t init_leave_loads();
automatic leave_load_t res;
Expand All @@ -79,16 +129,14 @@ module mvu_4sx4u #(
assign vld = L[5];

// Stages #1 - #3: DSP Lanes + cross-lane canaries duplicated with SIMD parallelism
localparam int unsigned D[4:0] = '{ ACCU_WIDTH+22, 22, 15, 8, 0 }; // Lane offsets

localparam int unsigned PIPE_COUNT = (PE+3)/4;
for(genvar c = 0; c < PIPE_COUNT; c++) begin : genPipes

localparam int unsigned PE_BEG = 4*c;
localparam int unsigned PE_END = PE < 4*(c+1)? PE : 4*(c+1);
localparam int unsigned PE_REM = 4*(c+1) - PE_END;

uwire [57:0] p3[SIMD];
uwire [47:0] p3[SIMD];
uwire signed [ 1:0] h3[SIMD][3];
for(genvar s = 0; s < SIMD; s++) begin : genSIMD

Expand All @@ -98,10 +146,10 @@ module mvu_4sx4u #(
logic [26:0] dd;
logic [ 1:0] xx[3:1];
if(1) begin : blkVectorize
uwire [3:0] ww[PE_END - PE_BEG];
uwire signed [3:0] ww[PE_END - PE_BEG];
for(genvar pe = 0; pe < PE_END - PE_BEG; pe++) begin
assign ww[pe] = w[PE_BEG + pe][s];
if(pe) begin
if(pe > 0) begin
if(BEHAVIORAL) assign xx[pe + PE_REM] = zero? 0 : ww[pe] * a[s];
`ifndef VERILATOR
else begin
Expand All @@ -123,8 +171,19 @@ module mvu_4sx4u #(
dd = '0;
aa = '0;
for(int unsigned pe = 0; pe < PE_END - PE_BEG; pe++) begin
dd[D[pe + PE_REM]+:3] = ww[pe];
aa[D[pe + PE_REM]+ 3] = ww[pe][3];
automatic int unsigned ofs = OFFSETS[pe + PE_REM];
dd[ofs+:3] = ww[pe];
assert(!NARROW_WEIGHTS || rst || !en || zero || (ww[pe] != -8)) else begin
$warning("%m: Weight of -8 violates NARROW_WEIGHTS commitment.");
end

// The sign of the weights are generally put on the subtracted A port.
// However, when coinciding with the actual sign bit position of the
// multiplier input path, it also goes onto the D input. This prevents
// sign extensions that may happen when a DSP primitive is auto-promoted
// to a newer generation.
if(ofs+3 == A_WIDTH-1) dd[ofs+3] = ww[pe][3];
else aa[ofs+3] = ww[pe][3];
end
end
end : blkVectorize
Expand All @@ -135,14 +194,15 @@ module mvu_4sx4u #(
// rst can be only applied to AD and zero only to B
// with the same effect as zeroing both.
if(BEHAVIORAL) begin : genBehav

// Stage #1: Input Refine
logic signed [17:0] B1 = 0;
always_ff @(posedge clk) begin
if(zero) B1 <= 0;
else if(en) B1 <= bb;
end

logic signed [26:0] AD1 = 0;
logic signed [A_WIDTH-1:0] AD1 = 0;
always_ff @(posedge clk) begin
if(rst) AD1 <= 0;
else if(en) AD1 <= dd - aa;
Expand Down Expand Up @@ -429,14 +489,14 @@ module mvu_4sx4u #(
X1 <= xx;
X2 <= X1;
foreach(X3[i]) begin
X3[i] <= X2[i] + (L[3]? 2'h0 : pp[D[i]+:2]);
X3[i] <= X2[i] + (L[3]? 2'h0 : pp[OFFSETS[i]+:2]);
end
end
end

// Derive actual cross-lane overflows
for(genvar i = 0; i < 3; i++) begin
assign h3[s][i] = pp[D[i+1]+:2] - X3[i+1];
assign h3[s][i] = pp[OFFSETS[i+1]+:2] - X3[i+1];
end
assign p3[s] = pp;

Expand All @@ -445,51 +505,55 @@ module mvu_4sx4u #(
// Stage #4: Cross-SIMD Reduction

// Count leaves reachable from each node
localparam leave_load_t LEAVE_LOAD = SIMD > 1 ? init_leave_loads() : '{ default: 1}; // SIMD=1 requires no adder tree, so zero-ing out, otherwise init_leave_loads ends up in infinite loop
localparam leave_load_t LEAVE_LOAD = SIMD > 1 ? init_leave_loads() : '{ default: 1 }; // SIMD=1 requires no adder tree, so zero-ing out, otherwise init_leave_loads ends up in infinite loop

uwire signed [ACCU_WIDTH-1:0] up4;
uwire signed [$clog2(2**(ACCU_WIDTH-8)+SIMD):0] hi4[3]; // min LO_WIDTH=7
uwire [$clog2(SIMD)+7 :0] lo4[3]; // max LO_WIDTH=8
uwire signed [ HI_WIDTH_MAX-1:0] hi4[3];
uwire [$clog2(SIMD)+LO_WIDTH_MAX-1:0] lo4[3];
for(genvar i = 0; i < 4; i++) begin
localparam int unsigned LO_WIDTH = D[i+1] - D[i];
localparam int unsigned HI_WIDTH = 1 + $clog2(2**(ACCU_WIDTH-LO_WIDTH-1)+SIMD);

// Conclusive high part accumulation
if(i >= PE_REM && i < 3) begin : genHi
// Adder Tree across all SIMD high contributions, each from [-1:1]
uwire signed [2*SIMD-2:0][$clog2(1+SIMD):0] tree;
for(genvar s = 0; s < SIMD; s++) assign tree[SIMD-1+s] = h3[s][i];
for(genvar n = 0; n < SIMD-1; n++) begin
// Sum truncated to actual maximum bit width at this node
uwire signed [$clog2(1+LEAVE_LOAD[n]):0] s = $signed(tree[2*n+1]) + $signed(tree[2*n+2]);
assign tree[n] = s;
end
if(i < 3) begin : genHi
if(i < PE_REM) assign hi4[i] = '0;
else begin
localparam int unsigned HI_WIDTH = hi_width(i);

// Adder Tree across all SIMD high contributions, each from [-1:1]
uwire signed [2*SIMD-2:0][$clog2(1+SIMD):0] tree;
for(genvar s = 0; s < SIMD; s++) assign tree[SIMD-1+s] = h3[s][i];
for(genvar n = 0; n < SIMD-1; n++) begin
// Sum truncated to actual maximum bit width at this node
uwire signed [$clog2(1+LEAVE_LOAD[n]):0] s = $signed(tree[2*n+1]) + $signed(tree[2*n+2]);
assign tree[n] = s;
end

// High Sideband Accumulation
logic signed [HI_WIDTH-1:0] Hi4 = 0;
always_ff @(posedge clk) begin
if(rst) Hi4 <= 0;
else if(en) begin
automatic logic signed [HI_WIDTH:0] h = $signed(L[4]? 0 : Hi4) + $signed(tree[0]);
assert(h[HI_WIDTH] == h[HI_WIDTH-1]) else begin
$error("%m: Accumulation overflow for ACCU_WIDTH=%0d", ACCU_WIDTH);
$stop;
// High Sideband Accumulation
logic signed [HI_WIDTH-1:0] Hi4 = 0;
always_ff @(posedge clk) begin
if(rst) Hi4 <= 0;
else if(en) begin
automatic logic signed [HI_WIDTH:0] h = $signed(L[4]? 0 : Hi4) + $signed(tree[0]);
assert(h[HI_WIDTH] == h[HI_WIDTH-1]) else begin
$error("%m: Accumulation overflow for ACCU_WIDTH=%0d", ACCU_WIDTH);
$stop;
end
Hi4 <= h;
end
Hi4 <= h;
end
assign hi4[i] = Hi4;

end
assign hi4[i] = Hi4;
end : genHi
else if (i < 3) begin : genHiZero
assign hi4[i] = '0;
end : genHiZero

// Conclusive low part accumulation (all unsigned arithmetic)
if(i >= PE_REM) begin : blkLo
if(i < PE_REM) assign lo4[i] = '0;
else begin : genLo
localparam int unsigned LO_WIDTH = lo_width(i);

// Adder Tree across all SIMD low contributions
localparam int unsigned ROOT_WIDTH = $clog2(1 + SIMD*(2**LO_WIDTH-1));
uwire [2*SIMD-2:0][ROOT_WIDTH-1:0] tree;
for(genvar s = 0; s < SIMD; s++) assign tree[SIMD-1+s] = p3[s][D[i]+:LO_WIDTH];
for(genvar s = 0; s < SIMD; s++) assign tree[SIMD-1+s] = p3[s][OFFSETS[i]+:LO_WIDTH];
for(genvar n = 0; n < SIMD-1; n++) begin
// Sum truncated to actual maximum bit width at this node
localparam int unsigned NODE_WIDTH = $clog2(1 + LEAVE_LOAD[n]*(2**LO_WIDTH-1));
Expand All @@ -505,10 +569,7 @@ module mvu_4sx4u #(

if(i == 3) assign up4 = Lo4;
else assign lo4[i] = Lo4;
end : blkLo
else begin : blkLoZero
assign lo4[i] = '0;
end : blkLoZero
end : genLo

end

Expand All @@ -518,9 +579,9 @@ module mvu_4sx4u #(
if(rst) Res5 <= '{ default: 0 };
else if(en) begin
Res5[3] <= up4 - hi4[2];
Res5[2] <= $signed({ hi4[2], {(D[3] - D[2]){1'b0}} }) + $signed({ 1'b0, lo4[2] }) - hi4[1];
Res5[1] <= $signed({ hi4[1], {(D[2] - D[1]){1'b0}} }) + $signed({ 1'b0, lo4[1] }) - hi4[0];
Res5[0] <= $signed({ hi4[0], {(D[1] - D[0]){1'b0}} }) + $signed({ 1'b0, lo4[0] });
Res5[2] <= $signed({ hi4[2], {(lo_width(2)){1'b0}} }) + $signed({ 1'b0, lo4[2] }) - hi4[1];
Res5[1] <= $signed({ hi4[1], {(lo_width(1)){1'b0}} }) + $signed({ 1'b0, lo4[1] }) - hi4[0];
Res5[0] <= $signed({ hi4[0], {(lo_width(0)){1'b0}} }) + $signed({ 1'b0, lo4[0] });
end
end

Expand Down
19 changes: 17 additions & 2 deletions finn-rtllib/mvu/mvu_vvu_axi.sv
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ module mvu_vvu_axi #(
int unsigned ACTIVATION_WIDTH,
int unsigned WEIGHT_WIDTH,
int unsigned ACCU_WIDTH,
bit NARROW_WEIGHTS = 0,
bit SIGNED_ACTIVATIONS = 0,

bit PUMPED_COMPUTE = 0,
Expand Down Expand Up @@ -306,8 +307,22 @@ module mvu_vvu_axi #(
.last(dsp_last), .zero(dsp_zero), .w(dsp_w), .a(dsp_a),
.vld(dsp_vld), .p(dsp_p)
);
"mvu_4sx4u":
mvu_4sx4u #(.PE(PE), .SIMD(DSP_SIMD), .ACCU_WIDTH(ACCU_WIDTH), .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)) core (
"mvu_4sx4u_dsp48e1":
mvu_4sx4u #(
.PE(PE), .SIMD(DSP_SIMD),
.ACCU_WIDTH(ACCU_WIDTH), .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .NARROW_WEIGHTS(NARROW_WEIGHTS),
.VERSION(1), .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)
) core (
.clk(dsp_clk), .rst, .en(dsp_en),
.last(dsp_last), .zero(dsp_zero), .w(dsp_w), .a(dsp_a),
.vld(dsp_vld), .p(dsp_p)
);
"mvu_4sx4u_dsp48e2":
mvu_4sx4u #(
.PE(PE), .SIMD(DSP_SIMD),
.ACCU_WIDTH(ACCU_WIDTH), .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .NARROW_WEIGHTS(NARROW_WEIGHTS),
.VERSION(2), .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)
) core (
.clk(dsp_clk), .rst, .en(dsp_en),
.last(dsp_last), .zero(dsp_zero), .w(dsp_w), .a(dsp_a),
.vld(dsp_vld), .p(dsp_p)
Expand Down
3 changes: 2 additions & 1 deletion finn-rtllib/mvu/mvu_vvu_axi_wrapper.v
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ module $MODULE_NAME_AXI_WRAPPER$ #(
parameter ACTIVATION_WIDTH = $ACTIVATION_WIDTH$,
parameter WEIGHT_WIDTH = $WEIGHT_WIDTH$,
parameter ACCU_WIDTH = $ACCU_WIDTH$,
parameter NARROW_WEIGHTS = $NARROW_WEIGHTS$,
parameter SIGNED_ACTIVATIONS = $SIGNED_ACTIVATIONS$,
parameter SEGMENTLEN = $SEGMENTLEN$,
parameter FORCE_BEHAVIORAL = $FORCE_BEHAVIORAL$,
Expand Down Expand Up @@ -77,7 +78,7 @@ module $MODULE_NAME_AXI_WRAPPER$ #(

mvu_vvu_axi #(
.IS_MVU(IS_MVU), .COMPUTE_CORE(COMPUTE_CORE), .PUMPED_COMPUTE(PUMPED_COMPUTE), .MW(MW), .MH(MH), .PE(PE), .SIMD(SIMD),
.ACTIVATION_WIDTH(ACTIVATION_WIDTH), .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACCU_WIDTH(ACCU_WIDTH),
.ACTIVATION_WIDTH(ACTIVATION_WIDTH), .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACCU_WIDTH(ACCU_WIDTH), .NARROW_WEIGHTS(NARROW_WEIGHTS),
.SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .SEGMENTLEN(SEGMENTLEN), .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)
) inst (
.ap_clk(ap_clk),
Expand Down
Loading

0 comments on commit a4bbb59

Please sign in to comment.