diff --git a/src/main/java/com/ibm/crypto/plus/provider/OAEPParameters.java b/src/main/java/com/ibm/crypto/plus/provider/OAEPParameters.java index 4d9efb2e2..0e5f420cf 100644 --- a/src/main/java/com/ibm/crypto/plus/provider/OAEPParameters.java +++ b/src/main/java/com/ibm/crypto/plus/provider/OAEPParameters.java @@ -1,5 +1,5 @@ /* - * Copyright IBM Corp. 2023 + * Copyright IBM Corp. 2023, 2025 * * This code is free software; you can redistribute it and/or modify it * under the terms provided by IBM in the LICENSE file that accompanied @@ -40,6 +40,8 @@ public final class OAEPParameters extends AlgorithmParametersSpi { private static ObjectIdentifier OID_MGF1; private static ObjectIdentifier OID_PSpecified; + private static boolean disableOrderCheck = Boolean.getBoolean("openjceplus.oaep.disableOrderCheck"); + static { try { OID_MGF1 = ObjectIdentifier.of("1.2.840.113549.1.1.8"); @@ -84,45 +86,104 @@ protected void engineInit(AlgorithmParameterSpec paramSpec) } protected void engineInit(byte[] encoded) throws IOException { - DerInputStream der = new DerInputStream(encoded); - mdName = "SHA-1"; - mgfSpec = MGF1ParameterSpec.SHA1; - p = new byte[0]; - DerValue[] datum = der.getSequence(3); - for (int i = 0; i < datum.length; i++) { - DerValue data = datum[i]; - if (data.isContextSpecific((byte) 0x00)) { - // hash algid - mdName = AlgorithmId.parse(data.getData().getDerValue()).getName(); - } else if (data.isContextSpecific((byte) 0x01)) { - // mgf algid - AlgorithmId val = AlgorithmId.parse(data.getData().getDerValue()); - if (!val.getOID().equals((Object) OID_MGF1)) { + if (disableOrderCheck) { + // Deprecated. To be removed in a future release. + // + // Disable check and revert to old behaviour using + // the -Dopenjceplus.oaep.disableOrderCheck flag. + DerInputStream der = new DerInputStream(encoded); + mdName = "SHA-1"; + mgfSpec = MGF1ParameterSpec.SHA1; + p = new byte[0]; + DerValue[] datum = der.getSequence(3); + for (int i = 0; i < datum.length; i++) { + DerValue data = datum[i]; + if (data.isContextSpecific((byte) 0x00)) { + // hash algid + mdName = AlgorithmId.parse(data.getData().getDerValue()).getName(); + } else if (data.isContextSpecific((byte) 0x01)) { + // mgf algid + AlgorithmId val = AlgorithmId.parse(data.getData().getDerValue()); + if (!val.getOID().equals((Object) OID_MGF1)) { + throw new IOException("Only MGF1 mgf is supported"); + } + byte[] encodedParams = val.getEncodedParams(); + if (encodedParams == null) { + throw new IOException("Missing MGF1 parameters"); + } + AlgorithmId params = AlgorithmId.parse(new DerValue(encodedParams)); + String mgfDigestName = params.getName(); + if (mgfDigestName.equals("SHA-1")) { + mgfSpec = MGF1ParameterSpec.SHA1; + } else if (mgfDigestName.equals("SHA-224")) { + mgfSpec = MGF1ParameterSpec.SHA224; + } else if (mgfDigestName.equals("SHA-256")) { + mgfSpec = MGF1ParameterSpec.SHA256; + } else if (mgfDigestName.equals("SHA-384")) { + mgfSpec = MGF1ParameterSpec.SHA384; + } else if (mgfDigestName.equals("SHA-512")) { + mgfSpec = MGF1ParameterSpec.SHA512; + } else { + throw new IOException("Unrecognized message digest algorithm"); + } + } else if (data.isContextSpecific((byte) 0x02)) { + // pSource algid + AlgorithmId val = AlgorithmId.parse(data.getData().getDerValue()); + if (!val.getOID().equals((Object) OID_PSpecified)) { + throw new IOException("Wrong OID for pSpecified"); + } + byte[] encodedParams = val.getEncodedParams(); + if (encodedParams == null) { + throw new IOException("Missing pSpecified label"); + } + + DerInputStream dis = new DerInputStream(encodedParams); + p = dis.getOctetString(); + if (dis.available() != 0) { + throw new IOException("Extra data for pSpecified"); + } + } else { + throw new IOException("Invalid encoded OAEPParameters"); + } + } + } else { + DerInputStream der = DerValue.wrap(encoded).data(); + var sub = der.getOptionalExplicitContextSpecific(0); + if (sub.isPresent()) { + mdName = AlgorithmId.parse(sub.get()).getName(); + } else { + mdName = "SHA-1"; + } + sub = der.getOptionalExplicitContextSpecific(1); + if (sub.isPresent()) { + AlgorithmId val = AlgorithmId.parse(sub.get()); + if (!val.getOID().equals(OID_MGF1)) { throw new IOException("Only MGF1 mgf is supported"); } byte[] encodedParams = val.getEncodedParams(); if (encodedParams == null) { throw new IOException("Missing MGF1 parameters"); } - AlgorithmId params = AlgorithmId.parse(new DerValue(encodedParams)); - String mgfDigestName = params.getName(); - if (mgfDigestName.equals("SHA-1")) { - mgfSpec = MGF1ParameterSpec.SHA1; - } else if (mgfDigestName.equals("SHA-224")) { - mgfSpec = MGF1ParameterSpec.SHA224; - } else if (mgfDigestName.equals("SHA-256")) { - mgfSpec = MGF1ParameterSpec.SHA256; - } else if (mgfDigestName.equals("SHA-384")) { - mgfSpec = MGF1ParameterSpec.SHA384; - } else if (mgfDigestName.equals("SHA-512")) { - mgfSpec = MGF1ParameterSpec.SHA512; - } else { - throw new IOException("Unrecognized message digest algorithm"); - } - } else if (data.isContextSpecific((byte) 0x02)) { - // pSource algid - AlgorithmId val = AlgorithmId.parse(data.getData().getDerValue()); - if (!val.getOID().equals((Object) OID_PSpecified)) { + AlgorithmId params = AlgorithmId.parse( + new DerValue(encodedParams)); + mgfSpec = switch (params.getName()) { + case "SHA-1" -> MGF1ParameterSpec.SHA1; + case "SHA-224" -> MGF1ParameterSpec.SHA224; + case "SHA-256" -> MGF1ParameterSpec.SHA256; + case "SHA-384" -> MGF1ParameterSpec.SHA384; + case "SHA-512" -> MGF1ParameterSpec.SHA512; + case "SHA-512/224" -> MGF1ParameterSpec.SHA512_224; + case "SHA-512/256" -> MGF1ParameterSpec.SHA512_256; + default -> throw new IOException( + "Unrecognized message digest algorithm"); + }; + } else { + mgfSpec = MGF1ParameterSpec.SHA1; + } + sub = der.getOptionalExplicitContextSpecific(2); + if (sub.isPresent()) { + AlgorithmId val = AlgorithmId.parse(sub.get()); + if (!val.getOID().equals(OID_PSpecified)) { throw new IOException("Wrong OID for pSpecified"); } byte[] encodedParams = val.getEncodedParams(); @@ -130,14 +191,11 @@ protected void engineInit(byte[] encoded) throws IOException { throw new IOException("Missing pSpecified label"); } - DerInputStream dis = new DerInputStream(encodedParams); - p = dis.getOctetString(); - if (dis.available() != 0) { - throw new IOException("Extra data for pSpecified"); - } + p = DerValue.wrap(encodedParams).getOctetString(); } else { - throw new IOException("Invalid encoded OAEPParameters"); + p = new byte[0]; } + der.atEnd(); } } diff --git a/src/test/java/ibm/jceplus/junit/base/BaseTestOAEPOrderCheck.java b/src/test/java/ibm/jceplus/junit/base/BaseTestOAEPOrderCheck.java new file mode 100644 index 000000000..04b822404 --- /dev/null +++ b/src/test/java/ibm/jceplus/junit/base/BaseTestOAEPOrderCheck.java @@ -0,0 +1,50 @@ +/* + * Copyright IBM Corp. 2025 + * + * This code is free software; you can redistribute it and/or modify it + * under the terms provided by IBM in the LICENSE file that accompanied + * this code, including the "Classpath" Exception described therein. + */ + +package ibm.jceplus.junit.base; + +import java.io.IOException; +import java.security.AlgorithmParameters; +import java.security.spec.MGF1ParameterSpec; +import java.util.Arrays; +import javax.crypto.spec.OAEPParameterSpec; +import javax.crypto.spec.PSource; +import org.junit.jupiter.api.Test; + +public class BaseTestOAEPOrderCheck extends BaseTestJunit5 { + + @Test + public void testOAEPOrder() throws Exception { + // Do not use default fields + OAEPParameterSpec spec = new OAEPParameterSpec( + "SHA-384", "MGF1", MGF1ParameterSpec.SHA384, + new PSource.PSpecified(new byte[10])); + AlgorithmParameters alg = AlgorithmParameters.getInstance("OAEP", getProviderName()); + alg.init(spec); + byte[] encoded = alg.getEncoded(); + + // Extract the fields inside encoding + // [0] HashAlgorithm + byte[] a0 = Arrays.copyOfRange(encoded, 2, encoded[3] + 4); + // [1] MaskGenAlgorithm + [2] PSourceAlgorithm + byte[] a12 = Arrays.copyOfRange(encoded, 2 + a0.length, encoded.length); + + // and rearrange [1] and [2] before [0] + System.arraycopy(a12, 0, encoded, 2, a12.length); + System.arraycopy(a0, 0, encoded, 2 + a12.length, a0.length); + + AlgorithmParameters alg2 = AlgorithmParameters.getInstance("OAEP", getProviderName()); + try { + alg2.init(encoded); + throw new RuntimeException("Should fail"); + } catch (IOException ioe) { + // expected + ioe.printStackTrace(); + } + } +} diff --git a/src/test/java/ibm/jceplus/junit/openjceplus/TestAll.java b/src/test/java/ibm/jceplus/junit/openjceplus/TestAll.java index 49d885624..8a23b6767 100644 --- a/src/test/java/ibm/jceplus/junit/openjceplus/TestAll.java +++ b/src/test/java/ibm/jceplus/junit/openjceplus/TestAll.java @@ -94,6 +94,7 @@ TestIsAssignableFromOrder.class, TestMD5.class, TestMiniRSAPSS2.class, + TestOAEPOrderCheck.class, TestPBKDF2.class, TestPBKDF2Interop.class, TestPublicMethodsToMakeNonPublic.class, diff --git a/src/test/java/ibm/jceplus/junit/openjceplus/TestOAEPOrderCheck.java b/src/test/java/ibm/jceplus/junit/openjceplus/TestOAEPOrderCheck.java new file mode 100644 index 000000000..eb9bf973f --- /dev/null +++ b/src/test/java/ibm/jceplus/junit/openjceplus/TestOAEPOrderCheck.java @@ -0,0 +1,24 @@ +/* + * Copyright IBM Corp. 2025 + * + * This code is free software; you can redistribute it and/or modify it + * under the terms provided by IBM in the LICENSE file that accompanied + * this code, including the "Classpath" Exception described therein. + */ + +package ibm.jceplus.junit.openjceplus; + +import ibm.jceplus.junit.base.BaseTestOAEPOrderCheck; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.TestInstance.Lifecycle; + +@TestInstance(Lifecycle.PER_CLASS) +public class TestOAEPOrderCheck extends BaseTestOAEPOrderCheck { + + @BeforeAll + public void beforeAll() { + Utils.loadProviderTestSuite(); + setProviderName(Utils.TEST_SUITE_PROVIDER_NAME); + } +} \ No newline at end of file