Skip to content

Commit 2f4dc91

Browse files
nadime15charmitropmundkur
committed
Add support for Zvksed extension
Co-authored-by: Charalampos Mitrodimas <[email protected]> Co-authored-by: Prashanth Mundkur <[email protected]>
1 parent 0806ed5 commit 2f4dc91

7 files changed

+185
-2
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ Supported RISC-V ISA features
120120
- Zvbc extension for vector carryless multiplication, v1.0
121121
- Zvkb extension for vector cryptography bit-manipulation, v1.0
122122
- Zvknha and Zvknhb extensions for vector cryptography NIST Suite: Vector SHA-2 Secure Hash, v1.0
123+
- Zvksed extension for vector cryptography ShangMi Suite: SM4 Block Cipher, v1.0
123124
- Machine, Supervisor, and User modes
124125
- Smcntrpmf extension for cycle and instret privilege mode filtering, v1.0
125126
- Sscofpmf extension for Count Overflow and Mode-Based Filtering, v1.0

config/default.json

+3
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,9 @@
180180
"Zvknhb" : {
181181
"supported" : true
182182
},
183+
"Zvksed" : {
184+
"supported" : true
185+
},
183186
"Sscofpmf" : {
184187
"supported" : true
185188
},

model/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ foreach (xlen IN ITEMS 32 64)
8989
"riscv_insts_zvbb.sail"
9090
"riscv_insts_zvbc.sail"
9191
"riscv_insts_zvknhab.sail"
92+
"riscv_insts_zvksed.sail"
9293
# Zimop and Zcmop should be at the end so they can be overridden by earlier extensions
9394
"riscv_insts_zimop.sail"
9495
"riscv_insts_zcmop.sail"

model/riscv_extensions.sail

+3
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ function clause hartSupports(Ext_Zvknha) = config extensions.Zvknha.supported
185185

186186
enum clause extension = Ext_Zvknhb
187187
function clause hartSupports(Ext_Zvknhb) = config extensions.Zvknhb.supported
188+
// ShangMi Suite: SM4 Block Cipher
189+
enum clause extension = Ext_Zvksed
190+
function clause hartSupports(Ext_Zvksed) = config extensions.Zvksed.supported
188191

189192
// Count Overflow and Mode-Based Filtering
190193
enum clause extension = Ext_Sscofpmf

model/riscv_insts_vext_utils.sail

+11
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,17 @@ function write_velem_quad(vd, SEW, input, i) = {
183183
write_single_element(SEW, 4 * i + j, vd, slice(input, j * SEW, SEW));
184184
}
185185

186+
/* Extracts 4 consecutive vector elements starting from index 4*i and returns a vector */
187+
val get_velem_quad_vec : forall 'n 'm 'p, 'n > 0 & 8 <= 'm <= 64 & 'p >= 0 & 4 * 'p + 3 < 'n. (vector('n, bits('m)), int('p)) -> vector(4, bits('m))
188+
function get_velem_quad_vec(v, i) = [ v[4 * i + 3], v[4 * i + 2], v[4 * i + 1], v[4 * i] ]
189+
190+
/* Writes each of the 4 elements from the input vector to the vector register vd, starting at position 4 * i */
191+
val write_velem_quad_vec : forall 'p 'n, 8 <= 'n <= 64 & 'p >= 0. (vregidx, int('n), vector(4, bits('n)), int('p)) -> unit
192+
function write_velem_quad_vec(vd, SEW, input, i) = {
193+
foreach(j from 0 to 3)
194+
write_single_element(SEW, 4 * i + j, vd, input[j]);
195+
}
196+
186197
/* Get the starting element index from csr vtype */
187198
val get_start_element : unit -> result(nat, unit)
188199
function get_start_element() = {

model/riscv_insts_zvksed.sail

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*=======================================================================================*/
2+
/* This Sail RISC-V architecture model, comprising all files and */
3+
/* directories except where otherwise noted is subject the BSD */
4+
/* two-clause license in the LICENSE file. */
5+
/* */
6+
/* SPDX-License-Identifier: BSD-2-Clause */
7+
/*=======================================================================================*/
8+
9+
function clause currentlyEnabled(Ext_Zvksed) = hartSupports(Ext_Zvksed) & currentlyEnabled(Ext_V)
10+
11+
union clause ast = VSM4K_VI : (vregidx, bits(5), vregidx)
12+
13+
mapping clause encdec = VSM4K_VI(vs2, uimm, vd)
14+
<-> 0b1000011 @ encdec_vreg(vs2) @ uimm @ 0b010 @ encdec_vreg(vd) @ 0b1110111
15+
when currentlyEnabled(Ext_Zvksed) & get_sew() == 32 & zvk_check_encdec(128, 4)
16+
17+
mapping clause assembly = VSM4K_VI(vs2, uimm, vd)
18+
<-> "vsm4k.vi" ^ spc() ^ vreg_name(vd) ^ sep() ^ vreg_name(vs2) ^ sep() ^ hex_bits_5(uimm)
19+
20+
function clause execute (VSM4K_VI(vs2, uimm, vd)) = {
21+
let SEW = get_sew();
22+
let LMUL_pow = get_lmul_pow();
23+
let num_elem = get_num_elem(LMUL_pow, SEW);
24+
25+
assert(SEW == 32);
26+
27+
let vs2_val = read_vreg(num_elem, SEW, LMUL_pow, vs2);
28+
29+
let rnd = unsigned(uimm[2..0]);
30+
31+
let eg_len = (unsigned(vl) / 4);
32+
let eg_start = (unsigned(vstart) / 4);
33+
34+
foreach (i from eg_start to (eg_len - 1)) {
35+
assert(i * 4 + 3 < num_elem);
36+
37+
let rk_in : vector(4, bits(32)) = get_velem_quad_vec(vs2_val, i);
38+
var rk_out : vector(4, bits(32)) = vector_init(zeros());
39+
40+
var B = rk_in[1] ^ rk_in[2] ^ rk_in[3] ^ zvk_sm4_sbox(4 * rnd);
41+
var S = zvk_sm4_subword(B);
42+
rk_out[0] = zvk_round_key(rk_in[0], S);
43+
44+
B = rk_in[2] ^ rk_in[3] ^ rk_out[0] ^ zvk_sm4_sbox(4 * rnd + 1);
45+
S = zvk_sm4_subword(B);
46+
rk_out[1] = zvk_round_key(rk_in[1], S);
47+
48+
B = rk_in[3] ^ rk_out[0] ^ rk_out[1] ^ zvk_sm4_sbox(4 * rnd + 2);
49+
S = zvk_sm4_subword(B);
50+
rk_out[2] = zvk_round_key(rk_in[2], S);
51+
52+
B = rk_out[0] ^ rk_out[1] ^ rk_out[2] ^ zvk_sm4_sbox(4 * rnd + 3);
53+
S = zvk_sm4_subword(B);
54+
rk_out[3] = zvk_round_key(rk_in[3], S);
55+
56+
write_velem_quad_vec(vd, SEW, rk_out, i);
57+
};
58+
59+
set_vstart(zeros());
60+
RETIRE_SUCCESS
61+
}
62+
63+
union clause ast = ZVKSM4RTYPE : (zvkfunct6, vregidx, vregidx)
64+
65+
mapping clause encdec = ZVKSM4RTYPE(ZVK_VSM4RVV, vs2, vd)
66+
<-> 0b1010001 @ encdec_vreg(vs2) @ 0b10000 @ 0b010 @ encdec_vreg(vd) @ 0b1110111
67+
when currentlyEnabled(Ext_Zvksed) & get_sew() == 32 & zvk_check_encdec(128, 4)
68+
69+
mapping clause encdec = ZVKSM4RTYPE(ZVK_VSM4RVS, vs2, vd)
70+
<-> 0b1010011 @ encdec_vreg(vs2) @ 0b10000 @ 0b010 @ encdec_vreg(vd) @ 0b1110111
71+
when currentlyEnabled(Ext_Zvksed) & get_sew() == 32 & zvk_check_encdec(128, 4) & zvk_valid_reg_overlap(vs2, vd, get_lmul_pow())
72+
73+
mapping vsm4r_mnemonic : zvkfunct6 <-> string = {
74+
ZVK_VSM4RVV <-> "vsm4r.vv",
75+
ZVK_VSM4RVS <-> "vsm4r.vs",
76+
}
77+
78+
mapping clause assembly = ZVKSM4RTYPE(funct6, vs2, vd)
79+
<-> vsm4r_mnemonic(funct6) ^ spc() ^ vreg_name(vd) ^ sep() ^ vreg_name(vs2)
80+
81+
function clause execute (ZVKSM4RTYPE(funct6, vs2, vd)) = {
82+
let SEW = get_sew();
83+
let LMUL_pow = get_lmul_pow();
84+
let num_elem = get_num_elem(LMUL_pow, SEW);
85+
86+
assert(SEW == 32);
87+
88+
let vs2_val = read_vreg(num_elem, SEW, LMUL_pow, vs2);
89+
let vd_val = read_vreg(num_elem, SEW, LMUL_pow, vd);
90+
91+
let eg_len = (unsigned(vl) / 4);
92+
let eg_start = (unsigned(vstart) / 4);
93+
94+
foreach (i from eg_start to (eg_len - 1)) {
95+
assert(i * 4 + 3 < num_elem);
96+
97+
let rk_in : vector(4, bits(32)) = if funct6 == ZVK_VSM4RVV
98+
then get_velem_quad_vec(vs2_val, i)
99+
else get_velem_quad_vec(vs2_val, 0);
100+
101+
let x_in : vector(4, bits(32)) = get_velem_quad_vec(vd_val, i);
102+
var x_out : vector(4, bits(32)) = vector_init(zeros());
103+
104+
var B = x_in[1] ^ x_in[2] ^ x_in[3] ^ rk_in[0];
105+
var S = zvk_sm4_subword(B);
106+
x_out[0] = zvk_sm4_round(x_in[0], S);
107+
108+
B = x_in[2] ^ x_in[3] ^ x_out[0] ^ rk_in[1];
109+
S = zvk_sm4_subword(B);
110+
x_out[1] = zvk_sm4_round(x_in[1], S);
111+
112+
B = x_in[3] ^ x_out[0] ^ x_out[1] ^ rk_in[2];
113+
S = zvk_sm4_subword(B);
114+
x_out[2] = zvk_sm4_round(x_in[2], S);
115+
116+
B = x_out[0] ^ x_out[1] ^ x_out[2] ^ rk_in[3];
117+
S = zvk_sm4_subword(B);
118+
x_out[3] = zvk_sm4_round(x_in[3], S);
119+
120+
write_velem_quad_vec(vd, SEW, x_out, i);
121+
};
122+
123+
set_vstart(zeros());
124+
RETIRE_SUCCESS
125+
}

model/riscv_zvk_utils.sail

+41-2
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ function zvk_valid_reg_overlap(rs, rd, emul_pow) = {
1616

1717
function zvk_check_encdec(EGW: int, EGS: int) -> bool = (unsigned(vl) % EGS == 0) & (unsigned(vstart) % EGS == 0) & (2 ^ get_lmul_pow() * VLEN) >= EGW
1818

19+
enum zvkfunct6 = {ZVK_VSHA2CH, ZVK_VSHA2CL, ZVK_VSM4RVV, ZVK_VSM4RVS}
20+
1921
/*
2022
* Utility functions for Zvknh[ab]
2123
* ----------------------------------------------------------------------
2224
*/
2325

24-
enum zvkfunct6 = {ZVK_VSHA2CH, ZVK_VSHA2CL}
25-
2626
function zvknhab_check_encdec(vs2: vregidx, vs1: vregidx, vd: vregidx) -> bool = {
2727
let SEW = get_sew();
2828
let LMUL_pow = get_lmul_pow();
@@ -66,3 +66,42 @@ function zvk_ch(x, y, z) = (x & y) ^ (~(x) & z)
6666

6767
val zvk_maj : forall 'n, 'n >= 0. (bits('n), bits('n), bits('n)) -> bits('n)
6868
function zvk_maj(x, y, z) = (x & y) ^ (x & z) ^ (y & z)
69+
70+
/*
71+
* Utility functions for Zvksed
72+
* ----------------------------------------------------------------------
73+
*/
74+
75+
val zvk_round_key : (bits(32), bits(32)) -> bits(32)
76+
function zvk_round_key(X, S) = X ^ (S ^ (S <<< 13) ^ (S <<< 23))
77+
78+
val zvk_sm4_round : (bits(32), bits(32)) -> bits(32)
79+
function zvk_sm4_round(X, S) = X ^ (S ^ (S <<< 2) ^ (S <<< 10) ^ (S <<< 18) ^ (S <<< 24))
80+
81+
// SM4 Constant Key (CK)
82+
let zvksed_ck : vector(32, bits(32)) = [
83+
0x00070E15, 0x1C232A31, 0x383F464D, 0x545B6269,
84+
0x70777E85, 0x8C939AA1, 0xA8AFB6BD, 0xC4CBD2D9,
85+
0xE0E7EEF5, 0xFC030A11, 0x181F262D, 0x343B4249,
86+
0x50575E65, 0x6C737A81, 0x888F969D, 0xA4ABB2B9,
87+
0xC0C7CED5, 0xDCE3EAF1, 0xF8FF060D, 0x141B2229,
88+
0x30373E45, 0x4C535A61, 0x686F767D, 0x848B9299,
89+
0xA0A7AEB5, 0xBCC3CAD1, 0xD8DFE6ED, 0xF4FB0209,
90+
0x10171E25, 0x2C333A41, 0x484F565D, 0x646B7279
91+
]
92+
93+
val zvksed_box_lookup : (bits(5), vector(32, bits(32))) -> bits(32)
94+
function zvksed_box_lookup(x, table) = {
95+
table[31 - unsigned(x)]
96+
}
97+
98+
val zvk_sm4_sbox : (int) -> bits(32)
99+
function zvk_sm4_sbox(x) = zvksed_box_lookup(to_bits(5, x), zvksed_ck)
100+
101+
val zvk_sm4_subword : bits(32) -> bits(32)
102+
function zvk_sm4_subword(x) = {
103+
sm4_sbox(x[31..24]) @
104+
sm4_sbox(x[23..16]) @
105+
sm4_sbox(x[15.. 8]) @
106+
sm4_sbox(x[ 7.. 0])
107+
}

0 commit comments

Comments
 (0)