Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 54 additions & 23 deletions internal/local/activate.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"crypto/tls"
"crypto/x509"
"encoding/base64"
Expand Down Expand Up @@ -116,14 +117,24 @@ func (service *ProvisioningService) StartSecureHostBasedConfiguration(certsAndKe
// create leaf certificate hash
var certHashByteArray [64]byte

leafHash := sha256.Sum256(certsAndKeys.certs[0].Raw)
copy(certHashByteArray[:], leafHash[:])

certAlgo, err := utils.CheckCertificateAlgorithmSupported(certsAndKeys.certs[0].SignatureAlgorithm)
if err != nil {
return amt.SecureHBasedResponse{}, utils.ActivationFailedCertHash
}

// Generate hash based on certificate algorithm
switch certAlgo {
case 2: // SHA256
leafHash := sha256.Sum256(certsAndKeys.certs[0].Raw)
copy(certHashByteArray[:], leafHash[:])
case 3: // SHA384
leafHash := sha512.Sum384(certsAndKeys.certs[0].Raw)
copy(certHashByteArray[:], leafHash[:])
default:
// Only SHA-256 and SHA-384 are supported for secure host-based configuration
return amt.SecureHBasedResponse{}, errors.New("unsupported certificate algorithm for activation")
}

// Call StartConfigurationHBased
params := amt.SecureHBasedParameters{
CertHash: certHashByteArray,
Expand All @@ -141,12 +152,12 @@ func (service *ProvisioningService) StartSecureHostBasedConfiguration(certsAndKe
func (service *ProvisioningService) ActivateACM(oldWay bool) error {
if oldWay {
// Extract the provisioning certificate
_, certObject, fingerPrint, err := service.GetProvisioningCertObj()
_, certObject, fingerprints, err := service.GetProvisioningCertObj()
if err != nil {
return err
}
// Check provisioning certificate is accepted by AMT
err = service.CompareCertHashes(fingerPrint)
err = service.CompareCertHashes(fingerprints)
if err != nil {
return err
}
Expand Down Expand Up @@ -272,20 +283,20 @@ func (service *ProvisioningService) CCMCommit(tlsConfig *tls.Config) error {
return nil
}

func (service *ProvisioningService) GetProvisioningCertObj() (CertsAndKeys, ProvisioningCertObj, string, error) {
func (service *ProvisioningService) GetProvisioningCertObj() (CertsAndKeys, ProvisioningCertObj, map[string]string, error) {
config := service.config.ACMSettings

certsAndKeys, err := convertPfxToObject(config.ProvisioningCert, config.ProvisioningCertPwd)
if err != nil {
return certsAndKeys, ProvisioningCertObj{}, "", err
return certsAndKeys, ProvisioningCertObj{}, nil, err
}

result, fingerprint, err := dumpPfx(certsAndKeys)
result, fingerprints, err := dumpPfx(certsAndKeys)
if err != nil {
return certsAndKeys, ProvisioningCertObj{}, "", err
return certsAndKeys, ProvisioningCertObj{}, nil, err
}

return certsAndKeys, result, fingerprint, nil
return certsAndKeys, result, fingerprints, nil
}

func convertPfxToObject(pfxb64, passphrase string) (CertsAndKeys, error) {
Expand Down Expand Up @@ -314,20 +325,20 @@ func convertPfxToObject(pfxb64, passphrase string) (CertsAndKeys, error) {
return pfxOut, nil
}

func dumpPfx(pfxobj CertsAndKeys) (ProvisioningCertObj, string, error) {
func dumpPfx(pfxobj CertsAndKeys) (ProvisioningCertObj, map[string]string, error) {
if len(pfxobj.certs) == 0 {
return ProvisioningCertObj{}, "", utils.ActivationFailedNoCertFound
return ProvisioningCertObj{}, nil, utils.ActivationFailedNoCertFound
}

if len(pfxobj.keys) == 0 {
return ProvisioningCertObj{}, "", utils.ActivationFailedNoPrivKeys
return ProvisioningCertObj{}, nil, utils.ActivationFailedNoPrivKeys
}

var provisioningCertificateObj ProvisioningCertObj

var certificateList []*CertificateObject

var fingerprint string
fingerprints := make(map[string]string)

for _, cert := range pfxobj.certs {
pemBlock := &pem.Block{
Expand All @@ -338,19 +349,23 @@ func dumpPfx(pfxobj CertsAndKeys) (ProvisioningCertObj, string, error) {
pem := utils.CleanPEM(string(pem.EncodeToMemory(pemBlock)))
certificateObject := CertificateObject{pem: pem, subject: cert.Subject.String(), issuer: cert.Issuer.String()}

// Get the fingerprint from the Root certificate
// Get the fingerprints from the Root certificate
if cert.Subject.String() == cert.Issuer.String() {
der := cert.Raw
hash := sha256.Sum256(der)
fingerprint = hex.EncodeToString(hash[:])
// SHA-384 (48 bytes) - Supported by AMT 16.1+
hashSHA384 := sha512.Sum384(der)
fingerprints["SHA384"] = hex.EncodeToString(hashSHA384[:])
// SHA-256 (32 bytes) - Supported by all AMT versions
hashSHA256 := sha256.Sum256(der)
fingerprints["SHA256"] = hex.EncodeToString(hashSHA256[:])
}

// Put all the certificateObjects into a single list
certificateList = append(certificateList, &certificateObject)
}

if fingerprint == "" {
return provisioningCertificateObj, "", utils.ActivationFailedNoRootCertFound
if len(fingerprints) == 0 {
return provisioningCertificateObj, nil, utils.ActivationFailedNoRootCertFound
}

// Add them to the certChain in order
Expand All @@ -364,18 +379,34 @@ func dumpPfx(pfxobj CertsAndKeys) (ProvisioningCertObj, string, error) {
// Add the certificate algorithm
provisioningCertificateObj.certificateAlgorithm = pfxobj.certs[0].SignatureAlgorithm

return provisioningCertificateObj, fingerprint, nil
return provisioningCertificateObj, fingerprints, nil
}

func (service *ProvisioningService) CompareCertHashes(fingerPrint string) error {
// compareCertHashes compares certificate hash with AMT stored hashes
// Uses pre-computed SHA-256 and SHA-384 fingerprints to support different AMT platforms
func (service *ProvisioningService) CompareCertHashes(fingerprints map[string]string) error {
// Get all certificate hashes from AMT
result, err := service.amtCommand.GetCertificateHashes()
if err != nil {
return utils.ActivationFailedGetCertHash
}

// Try to match stored hash with corresponding algorithm
for _, v := range result {
if v.Hash == fingerPrint {
return nil
if v.Algorithm != "" {
// Algorithm specified: match SHA256 with SHA256, SHA384 with SHA384
if computedHash, exists := fingerprints[v.Algorithm]; exists {
if v.Hash == computedHash {
return nil
}
}
} else {
// Algorithm not specified: try all computed fingerprints
for _, computedHash := range fingerprints {
if v.Hash == computedHash {
return nil
}
}
}
}

Expand Down
111 changes: 111 additions & 0 deletions internal/local/activate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,114 @@ func TestDumpPfx(t *testing.T) {
_, _, err = dumpPfx(certsAndKeys)
assert.NotNil(t, err)
}

// Test for StartSecureHostBasedConfiguration with different certificate algorithms
func TestStartSecureHostBasedConfiguration(t *testing.T) {
tests := []struct {
name string
certAlgo x509.SignatureAlgorithm
wantErr bool
}{
{
name: "SHA256 algorithm - should succeed",
certAlgo: x509.SHA256WithRSA,
wantErr: false,
},
{
name: "SHA384 algorithm - should succeed",
certAlgo: x509.SHA384WithRSA,
wantErr: false,
},
{
name: "Unknown algorithm - should fail",
certAlgo: x509.UnknownSignatureAlgorithm,
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := &flags.Flags{}
service := setupService(f)

cert := &x509.Certificate{
SignatureAlgorithm: tt.certAlgo,
Raw: []byte("test-cert-data"),
}

certsAndKeys := CertsAndKeys{
certs: []*x509.Certificate{cert},
keys: []interface{}{"test-key"},
}

_, err := service.StartSecureHostBasedConfiguration(certsAndKeys)

if (err != nil) != tt.wantErr {
t.Errorf("StartSecureHostBasedConfiguration() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

// Test for CompareCertHashes with multi-algorithm support (SHA256 and SHA384)
func TestCompareCertHashes(t *testing.T) {
testCerts := getTestCerts()

tests := []struct {
name string
mockHashes []amt2.CertHashEntry
wantErr bool
}{
{
name: "SHA256 algorithm match - should succeed",
mockHashes: []amt2.CertHashEntry{
{
Hash: testCerts.Root.Fingerprint,
Algorithm: "SHA256",
IsActive: true,
IsDefault: true,
},
},
wantErr: false,
},
{
name: "No matching hash - should fail",
mockHashes: []amt2.CertHashEntry{
{
Hash: "wronghash1234567890abcdef",
Algorithm: "SHA256",
IsActive: true,
IsDefault: true,
},
},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := &flags.Flags{}
f.LocalConfig.ACMSettings.ProvisioningCert = testCerts.Pfxb64
f.LocalConfig.ACMSettings.ProvisioningCertPwd = testCerts.PfxPassword
service := setupService(f)
mockCertHashes = tt.mockHashes

// Parse the PFX and get fingerprints
certsAndKeys, err := convertPfxToObject(testCerts.Pfxb64, testCerts.PfxPassword)
if err != nil {
t.Fatalf("Failed to parse PFX: %v", err)
}

_, fingerprints, err := dumpPfx(certsAndKeys)
if err != nil {
t.Fatalf("Failed to dump PFX: %v", err)
}

err = service.CompareCertHashes(fingerprints)

if (err != nil) != tt.wantErr {
t.Errorf("CompareCertHashes() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
6 changes: 3 additions & 3 deletions pkg/utils/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ func InterpretHashAlgorithm(hashAlgorithm int) (hashSize int, algorithm string)
case 2: // SHA256
hashSize = 32
algorithm = "SHA256"
case 3: // SHA512
hashSize = 64
algorithm = "SHA512"
case 3: // SHA384
hashSize = 48
algorithm = "SHA384"
default:
hashSize = 0
algorithm = "UNKNOWN"
Expand Down
5 changes: 1 addition & 4 deletions pkg/utils/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestInterpretHashAlgorithm(t *testing.T) {
{"Hash0", 0, "MD5", 16},
{"Hash1", 1, "SHA1", 20},
{"Hash2", 2, "SHA256", 32},
{"Hash3", 3, "SHA512", 64},
{"Hash3", 3, "SHA384", 48},
{"Hash4", 4, "UNKNOWN", 0},
}
for _, tt := range tests {
Expand Down Expand Up @@ -176,9 +176,6 @@ func TestCheckCertificateAlgorithmSupported(t *testing.T) {
{"SHA384WithRSA", x509.SHA384WithRSA, 3, false},
{"ECDSAWithSHA384", x509.ECDSAWithSHA384, 3, false},
{"SHA384WithRSAPSS", x509.SHA384WithRSAPSS, 3, false},
{"SHA512WithRSA", x509.SHA512WithRSA, 5, false},
{"ECDSAWithSHA512", x509.ECDSAWithSHA512, 5, false},
{"SHA512WithRSAPSS", x509.SHA512WithRSAPSS, 5, false},

// Unsupported cases
{"UnknownSignatureAlgorithm", x509.UnknownSignatureAlgorithm, 99, true},
Expand Down
Loading