Skip to content

Add support for Zvksh extension #862

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions config/default.json
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@
"Zvknhb": {
"supported": true
},
"Zvksh": {
"supported": true
},
"Sscofpmf": {
"supported": true
},
Expand Down
1 change: 1 addition & 0 deletions model/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ foreach (xlen IN ITEMS 32 64)
"riscv_insts_zvbb.sail"
"riscv_insts_zvbc.sail"
"riscv_insts_zvknhab.sail"
"riscv_insts_zvksh.sail"
# Zimop and Zcmop should be at the end so they can be overridden by earlier extensions
"riscv_insts_zimop.sail"
"riscv_insts_zcmop.sail"
Expand Down
9 changes: 9 additions & 0 deletions model/arithmetic.sail
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,12 @@ $[property]
function cmulr_equivalence(a : bits(16), b : bits(16)) -> bool = {
carryless_mul_reversed(a, b) == carryless_mulr(a, b)
}

val rev8 : forall 'm, 'm >= 0 & mod('m, 8) == 0. (bits('m)) -> bits('m)
function rev8(input) = {
var output : bits('m) = zeros();
foreach (i from 0 to ('m - 8) by 8) {
output[(i + 7)..i] = input[('m - i - 1) .. ('m - i - 8)];
};
output
}
4 changes: 4 additions & 0 deletions model/riscv_extensions.sail
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ function clause hartSupports(Ext_Zvknha) = config extensions.Zvknha.supported

enum clause extension = Ext_Zvknhb
function clause hartSupports(Ext_Zvknhb) = config extensions.Zvknhb.supported
// ShangMi Suite: SM3 Secure Hash
enum clause extension = Ext_Zvksh
function clause hartSupports(Ext_Zvksh) = config extensions.Zvksh.supported


// Count Overflow and Mode-Based Filtering
enum clause extension = Ext_Sscofpmf
Expand Down
22 changes: 21 additions & 1 deletion model/riscv_insts_vext_utils.sail
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ function get_scalar(rs1, SEW) = {
}
}

/* Extracts 4 consecutive vector elements starting from index 4*i */
/* Extracts 4 consecutive vector elements starting from index 4*i and returns a bitvector */
val get_velem_quad : forall 'n 'm 'p, 'n > 0 & 'm > 0 & 'p >= 0 & 4 * 'p + 3 < 'n. (vector('n, bits('m)), int('p)) -> bits(4 * 'm)
function get_velem_quad(v, i) = v[4 * i + 3] @ v[4 * i + 2] @ v[4 * i + 1] @ v[4 * i]

Expand All @@ -191,6 +191,17 @@ function write_velem_quad(vd, SEW, input, i) = {
write_single_element(SEW, 4 * i + j, vd, slice(input, j * SEW, SEW));
}

/* Extracts 8 consecutive vector elements starting from index 8*i and returns a vector */
val get_velem_oct_vec : forall 'n 'm 'p, 'n > 0 & 8 <= 'm <= 64 & 'p >= 0 & 8 * 'p + 7 < 'n. (vector('n, bits('m)), int('p)) -> vector(8, bits('m))
function get_velem_oct_vec(v, i) = [ v[8 * i + 7], v[8 * i + 6], v[8 * i + 5], v[8 * i + 4], v[8 * i + 3], v[8 * i + 2], v[8 * i + 1], v[8 * i] ]

/* Writes each of the 8 elements from the input vector to the vector register vd, starting at position 8 * i */
val write_velem_oct_vec : forall 'p 'n, 8 <= 'n <= 64 & 'p >= 0. (vregidx, int('n), vector(8, bits('n)), int('p)) -> unit
function write_velem_oct_vec(vd, SEW, input, i) = {
foreach(j from 0 to 7)
write_single_element(SEW, 8 * i + j, vd, input[j]);
}

/* Get the starting element index from csr vtype */
val get_start_element : unit -> result(nat, unit)
function get_start_element() = {
Expand Down Expand Up @@ -478,3 +489,12 @@ function count_leadingzeros (sig, len) = {
};
len - idx - 1
}

val vrev8 : forall 'n 'm, 'n >= 0 & 'm >= 0. (vector('n, bits('m * 8))) -> vector('n, bits('m * 8))
function vrev8(input) = {
var output : vector('n, bits('m * 8)) = input;
foreach (i from 0 to ('n - 1)) {
output[i] = rev8(input[i]);
};
output
}
96 changes: 96 additions & 0 deletions model/riscv_insts_zvksh.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*=======================================================================================*/
/* This Sail RISC-V architecture model, comprising all files and */
/* directories except where otherwise noted is subject the BSD */
/* two-clause license in the LICENSE file. */
/* */
/* SPDX-License-Identifier: BSD-2-Clause */
/*=======================================================================================*/

function clause currentlyEnabled(Ext_Zvksh) = hartSupports(Ext_Zvksh) & currentlyEnabled(Ext_V)

union clause ast = VSM3ME_VV : (vregidx, vregidx, vregidx)

mapping clause encdec = VSM3ME_VV(vs2, vs1, vd)
<-> 0b1000001 @ encdec_vreg(vs2) @ encdec_vreg(vs1) @ 0b010 @ encdec_vreg(vd) @ 0b1110111
when currentlyEnabled(Ext_Zvksh) & get_sew() == 32 & zvk_check_encdec(256, 8) & zvk_valid_reg_overlap(vs2, vd, get_lmul_pow())

function clause execute (VSM3ME_VV(vs2, vs1, vd)) = {
let 'SEW = get_sew();
let LMUL_pow = get_lmul_pow();
let num_elem = get_num_elem(LMUL_pow, SEW);

assert(SEW == 32);

let vs2_val = read_vreg(num_elem, SEW, LMUL_pow, vs2);
let vs1_val = read_vreg(num_elem, SEW, LMUL_pow, vs1);

var w : vector(24, bits(32)) = vector_init(zeros());

let eg_len = (unsigned(vl) / 8);
let eg_start = (unsigned(vstart) / 8);

foreach (i from eg_start to (eg_len - 1)) {
assert(i * 8 + 7 < num_elem);

foreach (j from 0 to 7){
w[j] = rev8(vs1_val[i * 8 + j]);
w[j + 8] = rev8(vs2_val[i * 8 + j]);
};

foreach (j from 16 to 23)
w[j] = zvk_sh_w(w[j - 16], w[j - 9], w[j - 3], w[j - 13], w[j - 6]);

write_velem_oct_vec(vd, SEW,
vrev8([w[23], w[22], w[21], w[20], w[19], w[18], w[17], w[16]]), i);
};

set_vstart(zeros());
RETIRE_SUCCESS
}

mapping clause assembly = VSM3ME_VV(vs2, vs1, vd)
<-> "vsm3me.vv" ^ spc() ^ vreg_name(vd) ^ sep() ^ vreg_name(vs2) ^ sep() ^ vreg_name(vs1)

union clause ast = VSM3C_VI : (vregidx, bits(5), vregidx)

mapping clause encdec = VSM3C_VI(vs2, uimm, vd)
<-> 0b1010111 @ encdec_vreg(vs2) @ uimm @ 0b010 @ encdec_vreg(vd) @ 0b1110111
when currentlyEnabled(Ext_Zvksh) & get_sew() == 32 & zvk_check_encdec(256, 8) & zvk_valid_reg_overlap(vs2, vd, get_lmul_pow())

function clause execute (VSM3C_VI(vs2, uimm, vd)) = {
let SEW = get_sew();
let LMUL_pow = get_lmul_pow();
let num_elem = get_num_elem(LMUL_pow, SEW);

assert(SEW == 32);

let vs2_val = read_vreg(num_elem, SEW, LMUL_pow, vs2);
let vd_val = read_vreg(num_elem, SEW, LMUL_pow, vd);

let rnds = unsigned(uimm);

let eg_len = (unsigned(vl) / 8);
let eg_start = (unsigned(vstart) / 8);

foreach (i from eg_start to (eg_len - 1)) {
assert(i * 8 + 7 < num_elem);

let A_H : vector(8, bits(32)) = vrev8(get_velem_oct_vec(vd_val, i));
let w : vector(8, bits(32)) = vrev8(get_velem_oct_vec(vs2_val, i));

let x_0 = w[0] ^ w[4];
let x_1 = w[1] ^ w[5];

let A1_H1 = zvk_sm3_round( A_H, w[0], x_0, 2 * rnds);
let A2_H2 = zvk_sm3_round(A1_H1, w[1], x_1, 2 * rnds + 1);

write_velem_oct_vec(vd, SEW,
vrev8([A1_H1[6], A2_H2[6], A1_H1[4], A2_H2[4], A1_H1[2], A2_H2[2], A1_H1[0], A2_H2[0]]), i);
};

set_vstart(zeros());
RETIRE_SUCCESS
}

mapping clause assembly = VSM3C_VI(vs2, uimm, vd)
<-> "vsm3c.vi" ^ spc() ^ vreg_name(vd) ^ sep() ^ vreg_name(vs2) ^ sep() ^ hex_bits_5(uimm)
51 changes: 51 additions & 0 deletions model/riscv_zvk_utils.sail
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,54 @@ function zvk_ch(x, y, z) = (x & y) ^ (~(x) & z)

val zvk_maj : forall 'n, 'n >= 0. (bits('n), bits('n), bits('n)) -> bits('n)
function zvk_maj(x, y, z) = (x & y) ^ (x & z) ^ (y & z)

/*
* Utility functions for Zvksh
* ----------------------------------------------------------------------
*/

val zvk_p0 : bits(32) -> bits(32)
function zvk_p0(X) = X ^ (X <<< 9) ^ (X <<< 17)

val zvk_p1 : bits(32) -> bits(32)
function zvk_p1(X) = X ^ (X <<< 15) ^ (X <<< 23)

val zvk_sh_w : (bits(32), bits(32), bits(32), bits(32), bits(32)) -> bits(32)
function zvk_sh_w(A, B, C, D, E) = zvk_p1(A ^ B ^ (C <<< 15)) ^ (D <<< 7) ^ E

val zvk_ff1 : (bits(32), bits(32), bits(32)) -> bits(32)
function zvk_ff1(X, Y, Z) = X ^ Y ^ Z

val zvk_ff2 : (bits(32), bits(32), bits(32)) -> bits(32)
function zvk_ff2(X, Y, Z) = (X & Y) | (X & Z) | (Y & Z)

val zvk_ff_j : (bits(32), bits(32), bits(32), int) -> bits(32)
function zvk_ff_j(X, Y, Z, J) = if J <= 15 then zvk_ff1(X, Y, Z) else zvk_ff2(X, Y, Z)

val zvk_gg1 : (bits(32), bits(32), bits(32)) -> bits(32)
function zvk_gg1(X, Y, Z) = X ^ Y ^ Z

val zvk_gg2 : (bits(32), bits(32), bits(32)) -> bits(32)
function zvk_gg2(X, Y, Z) = (X & Y) | (~(X) & Z)

val zvk_gg_j : (bits(32), bits(32), bits(32), int) -> bits(32)
function zvk_gg_j(X, Y, Z, J) = if J <= 15 then zvk_gg1(X, Y, Z) else zvk_gg2(X, Y, Z)

val zvk_t_j : int -> bits(32)
function zvk_t_j(J) = if J <= 15 then 0x79CC4519 else 0x7A879D8A

function zvk_sm3_round(A_H : vector(8, bits(32)), w : bits(32), x : bits(32), j : int) -> vector(8, bits(32)) = {
let t_j = zvk_t_j(j) <<< (j % 32);
let ss1 = ((A_H[0] <<< 12) + A_H[4] + t_j) <<< 7;
let ss2 = ss1 ^ (A_H[0] <<< 12);

let tt1 = zvk_ff_j(A_H[0], A_H[1], A_H[2], j) + A_H[3] + ss2 + x;
let tt2 = zvk_gg_j(A_H[4], A_H[5], A_H[6], j) + A_H[7] + ss1 + w;

let A1 = tt1;
let C1 = A_H[1] <<< 9;
let E1 = zvk_p0(tt2);
let G1 = A_H[5] <<< 19;

[ A_H[6], G1, A_H[4], E1, A_H[2], C1, A_H[0], A1 ]
}
Loading