diff --git a/sm4/sm4.go b/sm4/sm4.go index 0e301deb..edce6873 100644 --- a/sm4/sm4.go +++ b/sm4/sm4.go @@ -19,7 +19,6 @@ modified by Jack, 2017 Oct package sm4 import ( - "bytes" "crypto/cipher" "errors" "strconv" @@ -268,8 +267,10 @@ func xor(in, iv []byte) (out []byte) { func pkcs7Padding(src []byte) []byte { padding := BlockSize - len(src)%BlockSize - padtext := bytes.Repeat([]byte{byte(padding)}, padding) - return append(src, padtext...) + for i := 0; i < padding; i++ { + src = append(src, byte(padding)) + } + return src } func pkcs7UnPadding(src []byte) ([]byte, error) { diff --git a/sm4/sm4_test.go b/sm4/sm4_test.go index 6115ba14..d8b1f8a0 100644 --- a/sm4/sm4_test.go +++ b/sm4/sm4_test.go @@ -16,6 +16,7 @@ limitations under the License. package sm4 import ( + "bytes" "fmt" "reflect" "testing" @@ -151,3 +152,21 @@ func testCompare(key1, key2 []byte) bool { } return true } + +func TestPkcs7Padding(t *testing.T) { + src := []byte("0123456789abcdef") + src = pkcs7Padding(src) + + want := []byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 97, 98, 99, 100, 101, 102, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16} + if !bytes.Equal(src, want) { + t.Errorf("want %v, got %v", want, src) + } +} + +func BenchmarkPkcs7Padding(b *testing.B) { + for i := 0; i < b.N; i++ { + src := []byte("0123456789abcdef") + src = pkcs7Padding(src) + _ = src + } +}