diff --git a/src/main/java/com/uid2/shared/secure/AzureCCCoreAttestationService.java b/src/main/java/com/uid2/shared/secure/AzureCCCoreAttestationService.java index 417f5e32..8a19e94d 100644 --- a/src/main/java/com/uid2/shared/secure/AzureCCCoreAttestationService.java +++ b/src/main/java/com/uid2/shared/secure/AzureCCCoreAttestationService.java @@ -25,14 +25,18 @@ public class AzureCCCoreAttestationService implements ICoreAttestationService { private final IPolicyValidator policyValidator; - private final String azureCcProtocol; + private final Protocol azureCcProtocol; - public AzureCCCoreAttestationService(String maaServerBaseUrl, String attestationUrl, String azureCcProtocol) { + public AzureCCCoreAttestationService(String maaServerBaseUrl, String attestationUrl, Protocol azureCcProtocol) { this(new MaaTokenSignatureValidator(maaServerBaseUrl), new PolicyValidator(attestationUrl), azureCcProtocol); } + public AzureCCCoreAttestationService(String maaServerBaseUrl, String attestationUrl) { + this(new MaaTokenSignatureValidator(maaServerBaseUrl), new PolicyValidator(attestationUrl), Protocol.AZURE_CC_ACI); + } + // used in UT - protected AzureCCCoreAttestationService(IMaaTokenSignatureValidator tokenSignatureValidator, IPolicyValidator policyValidator, String azureCcProtocol) { + protected AzureCCCoreAttestationService(IMaaTokenSignatureValidator tokenSignatureValidator, IPolicyValidator policyValidator, Protocol azureCcProtocol) { this.tokenSignatureValidator = tokenSignatureValidator; this.policyValidator = policyValidator; this.azureCcProtocol = azureCcProtocol; diff --git a/src/main/java/com/uid2/shared/secure/GcpOidcCoreAttestationService.java b/src/main/java/com/uid2/shared/secure/GcpOidcCoreAttestationService.java index be6afd0f..debc1607 100644 --- a/src/main/java/com/uid2/shared/secure/GcpOidcCoreAttestationService.java +++ b/src/main/java/com/uid2/shared/secure/GcpOidcCoreAttestationService.java @@ -39,10 +39,10 @@ public void attest(byte[] attestationRequest, byte[] publicKey, Handler { @@ -76,8 +73,8 @@ public void testHappyPath(String azureProtocol) throws AttestationException { } @ParameterizedTest - @MethodSource("argumentProvider") - public void testSignatureCheckFailed_ClientError(String azureProtocol) throws AttestationException { + @EnumSource(value = Protocol.class, names = {"AZURE_CC_ACI", "AZURE_CC_AKS"}) + public void testSignatureCheckFailed_ClientError(Protocol azureProtocol) throws AttestationException { var errorStr = "token signature validation failed"; when(alwaysFailTokenValidator.validate(any(), any())).thenThrow(new AttestationClientException(errorStr, AttestationFailure.BAD_PAYLOAD)); var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysPassPolicyValidator, azureProtocol); @@ -90,8 +87,8 @@ public void testSignatureCheckFailed_ClientError(String azureProtocol) throws At } @ParameterizedTest - @MethodSource("argumentProvider") - public void testSignatureCheckFailed_ServerError(String azureProtocol) throws AttestationException { + @EnumSource(value = Protocol.class, names = {"AZURE_CC_ACI", "AZURE_CC_AKS"}) + public void testSignatureCheckFailed_ServerError(Protocol azureProtocol) throws AttestationException { when(alwaysFailTokenValidator.validate(any(), any())).thenThrow(new AttestationException("unknown server error")); var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysPassPolicyValidator, azureProtocol); provider.registerEnclave(ENCLAVE_ID); @@ -102,8 +99,8 @@ public void testSignatureCheckFailed_ServerError(String azureProtocol) throws At } @ParameterizedTest - @MethodSource("argumentProvider") - public void testPolicyCheckSuccess_ClientError(String azureProtocol) throws AttestationException { + @EnumSource(value = Protocol.class, names = {"AZURE_CC_ACI", "AZURE_CC_AKS"}) + public void testPolicyCheckSuccess_ClientError(Protocol azureProtocol) throws AttestationException { var errorStr = "policy validation failed"; when(alwaysFailPolicyValidator.validate(any(), any())).thenThrow(new AttestationClientException(errorStr, AttestationFailure.BAD_PAYLOAD)); var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysFailPolicyValidator, azureProtocol); @@ -116,8 +113,8 @@ public void testPolicyCheckSuccess_ClientError(String azureProtocol) throws Atte } @ParameterizedTest - @MethodSource("argumentProvider") - public void testPolicyCheckFailed_ServerError(String azureProtocol) throws AttestationException { + @EnumSource(value = Protocol.class, names = {"AZURE_CC_ACI", "AZURE_CC_AKS"}) + public void testPolicyCheckFailed_ServerError(Protocol azureProtocol) throws AttestationException { when(alwaysFailPolicyValidator.validate(any(), any())).thenThrow(new AttestationException("unknown server error")); var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysFailPolicyValidator, azureProtocol); provider.registerEnclave(ENCLAVE_ID); @@ -128,8 +125,8 @@ public void testPolicyCheckFailed_ServerError(String azureProtocol) throws Attes } @ParameterizedTest - @MethodSource("argumentProvider") - public void testEnclaveNotRegistered(String azureProtocol) throws AttestationException { + @EnumSource(value = Protocol.class, names = {"AZURE_CC_ACI", "AZURE_CC_AKS"}) + public void testEnclaveNotRegistered(Protocol azureProtocol) throws AttestationException { var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysPassPolicyValidator, azureProtocol); attest(provider, ar -> { assertTrue(ar.succeeded()); @@ -144,11 +141,4 @@ private static void attest(ICoreAttestationService provider, Handler argumentProvider() { - return Stream.of( - Arguments.of(MaaTokenPayload.AZURE_CC_ACI_PROTOCOL), - Arguments.of(MaaTokenPayload.AZURE_CC_AKS_PROTOCOL) - ); - } } diff --git a/src/test/java/com/uid2/shared/secure/azurecc/MaaTokenSignatureValidatorTest.java b/src/test/java/com/uid2/shared/secure/azurecc/MaaTokenSignatureValidatorTest.java index e932d1f5..b14dba46 100644 --- a/src/test/java/com/uid2/shared/secure/azurecc/MaaTokenSignatureValidatorTest.java +++ b/src/test/java/com/uid2/shared/secure/azurecc/MaaTokenSignatureValidatorTest.java @@ -1,6 +1,7 @@ package com.uid2.shared.secure.azurecc; import com.uid2.shared.secure.AttestationException; +import com.uid2.shared.secure.Protocol; import com.uid2.shared.secure.TestClock; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.params.ParameterizedTest; @@ -16,7 +17,7 @@ public class MaaTokenSignatureValidatorTest { @ParameterizedTest @MethodSource("argumentProvider") - public void testPayload(String payloadPath, String protocol) throws Exception { + public void testPayload(String payloadPath, Protocol protocol) throws Exception { // expire at 1695313895 var payload = loadFromJson(payloadPath); var clock = new TestClock(); @@ -41,13 +42,13 @@ public void testE2E() throws AttestationException { var maaToken = ""; var maaServerUrl = "https://sharedeus.eus.attest.azure.net"; var validator = new MaaTokenSignatureValidator(maaServerUrl); - var token = validator.validate(maaToken, MaaTokenPayload.AZURE_CC_ACI_PROTOCOL); + var token = validator.validate(maaToken, Protocol.AZURE_CC_ACI); } static Stream argumentProvider() { return Stream.of( - Arguments.of("/com.uid2.shared/test/secure/azurecc/jwt_payload_aci.json", MaaTokenPayload.AZURE_CC_ACI_PROTOCOL), - Arguments.of("/com.uid2.shared/test/secure/azurecc/jwt_payload_aks.json", MaaTokenPayload.AZURE_CC_AKS_PROTOCOL) + Arguments.of("/com.uid2.shared/test/secure/azurecc/jwt_payload_aci.json", Protocol.AZURE_CC_ACI), + Arguments.of("/com.uid2.shared/test/secure/azurecc/jwt_payload_aks.json", Protocol.AZURE_CC_AKS) ); } } diff --git a/src/test/java/com/uid2/shared/secure/azurecc/MaaTokenUtils.java b/src/test/java/com/uid2/shared/secure/azurecc/MaaTokenUtils.java index 82714b1e..303996a8 100644 --- a/src/test/java/com/uid2/shared/secure/azurecc/MaaTokenUtils.java +++ b/src/test/java/com/uid2/shared/secure/azurecc/MaaTokenUtils.java @@ -4,6 +4,7 @@ import com.google.gson.JsonObject; import com.uid2.shared.Const; import com.uid2.shared.secure.AttestationException; +import com.uid2.shared.secure.Protocol; import java.security.KeyPairGenerator; import java.security.PublicKey; @@ -14,7 +15,7 @@ public class MaaTokenUtils { public static final String MAA_BASE_URL = "https://sharedeus.eus.attest.azure.net"; - public static MaaTokenPayload validateAndParseToken(JsonObject payload, Clock clock, String protocol) throws Exception{ + public static MaaTokenPayload validateAndParseToken(JsonObject payload, Clock clock, Protocol protocol) throws Exception{ var gen = KeyPairGenerator.getInstance(Const.Name.AsymetricEncryptionKeyClass); gen.initialize(2048, new SecureRandom()); var keyPair = gen.generateKeyPair(); diff --git a/src/test/java/com/uid2/shared/secure/azurecc/PolicyValidatorTest.java b/src/test/java/com/uid2/shared/secure/azurecc/PolicyValidatorTest.java index 9e4cde11..e5b4f977 100644 --- a/src/test/java/com/uid2/shared/secure/azurecc/PolicyValidatorTest.java +++ b/src/test/java/com/uid2/shared/secure/azurecc/PolicyValidatorTest.java @@ -3,6 +3,7 @@ import com.uid2.shared.secure.AttestationClientException; import com.uid2.shared.secure.AttestationException; import com.uid2.shared.secure.AttestationFailure; +import com.uid2.shared.secure.Protocol; import org.junit.jupiter.api.Test; import java.nio.ByteBuffer; @@ -97,7 +98,7 @@ private MaaTokenPayload generateBasicPayload() { .vmDebuggable(false) .runtimeData(generateBasicRuntimeData()) .ccePolicyDigest(CCE_POLICY_DIGEST) - .azureProtocol(MaaTokenPayload.AZURE_CC_ACI_PROTOCOL) + .azureProtocol(Protocol.AZURE_CC_ACI) .build(); } @@ -145,7 +146,7 @@ public void testValidationSuccess_AksWithAzureSignedKataccUvm() throws Attestati var aksPayload = generateBasicPayload() .toBuilder() .complianceStatus("azure-signed-katacc-uvm") - .azureProtocol(MaaTokenPayload.AZURE_CC_AKS_PROTOCOL) + .azureProtocol(Protocol.AZURE_CC_AKS) .build(); var enclaveId = validator.validate(aksPayload, PUBLIC_KEY); assertEquals(CCE_POLICY_DIGEST, enclaveId); @@ -157,22 +158,11 @@ public void testValidationFailure_AksWithOtherUvm() { var aksPayload = generateBasicPayload() .toBuilder() .complianceStatus("fake-compliance") - .azureProtocol(MaaTokenPayload.AZURE_CC_AKS_PROTOCOL) + .azureProtocol(Protocol.AZURE_CC_AKS) .build(); Throwable t = assertThrows(AttestationException.class, ()-> validator.validate(aksPayload, PUBLIC_KEY)); assertEquals("Not run in Azure Compliance Utility VM", t.getMessage()); assertEquals(AttestationFailure.BAD_FORMAT, ((AttestationClientException)t).getAttestationFailure()); } - @Test - public void testValidationFailure_InvalidProtocol() { - var validator = new PolicyValidator(ATTESTATION_URL); - var aksPayload = generateBasicPayload() - .toBuilder() - .azureProtocol("fake-protocol") - .build(); - Throwable t = assertThrows(AttestationException.class, ()-> validator.validate(aksPayload, PUBLIC_KEY)); - assertEquals("Azure protocol: fake-protocol not supported", t.getMessage()); - assertEquals(AttestationFailure.INVALID_PROTOCOL, ((AttestationClientException)t).getAttestationFailure()); - } }