Skip to content
4 changes: 4 additions & 0 deletions config/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,21 @@ type CryptoConfiguration struct {
AggregatorPrivateKey string
AggregatorAccountEVM string
AggregatorAccountStarknet string
MessageHashMaxSize int
}

// CryptoConfig sets the crypto configuration
func CryptoConfig() *CryptoConfiguration {

viper.SetDefault("MESSAGE_HASH_MAX_SIZE", 500)

return &CryptoConfiguration{
HDWalletMnemonic: viper.GetString("HD_WALLET_MNEMONIC"),
AggregatorPublicKey: viper.GetString("AGGREGATOR_PUBLIC_KEY"),
AggregatorPrivateKey: viper.GetString("AGGREGATOR_PRIVATE_KEY"),
AggregatorAccountEVM: viper.GetString("AGGREGATOR_ACCOUNT_EVM"),
AggregatorAccountStarknet: viper.GetString("AGGREGATOR_ACCOUNT_STARKNET"),
MessageHashMaxSize: viper.GetInt("MESSAGE_HASH_MAX_SIZE"),
}
}

Expand Down
165 changes: 149 additions & 16 deletions utils/crypto/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"encoding/json"
"encoding/pem"
Expand All @@ -26,9 +27,11 @@ import (
"golang.org/x/crypto/bcrypt"
)

var authConf = config.AuthConfig()
var cryptoConf = config.CryptoConfig()
var serverConf = config.ServerConfig()
var (
authConf = config.AuthConfig()
cryptoConf = config.CryptoConfig()
serverConf = config.ServerConfig()
)

// CheckPasswordHash is a function to compare provided password with the hashed password
func CheckPasswordHash(password, hash string) bool {
Expand Down Expand Up @@ -63,7 +66,6 @@ func EncryptPlain(plaintext []byte) ([]byte, error) {

// DecryptPlain decrypts ciphertext using AES encryption algorithm with Galois Counter Mode
func DecryptPlain(ciphertext []byte) ([]byte, error) {

block, err := aes.NewCipher([]byte(authConf.Secret))
if err != nil {
return nil, err
Expand All @@ -89,7 +91,6 @@ func DecryptPlain(ciphertext []byte) ([]byte, error) {

// EncryptJSON encrypts JSON serializable data using AES encryption algorithm with Galois Counter Mode
func EncryptJSON(data interface{}) ([]byte, error) {

// Encode data to JSON
plaintext, err := json.Marshal(data)
if err != nil {
Expand All @@ -107,7 +108,6 @@ func EncryptJSON(data interface{}) ([]byte, error) {

// DecryptJSON decrypts JSON serializable data using AES encryption algorithm with Galois Counter Mode
func DecryptJSON(ciphertext []byte) (interface{}, error) {

// Decrypt as normal
plaintext, err := DecryptPlain(ciphertext)
if err != nil {
Expand All @@ -121,7 +121,6 @@ func DecryptJSON(ciphertext []byte) (interface{}, error) {
}

return data, nil

}

// PublicKeyEncryptPlain encrypts plaintext using RSA 2048 encryption algorithm
Expand Down Expand Up @@ -155,7 +154,6 @@ func PublicKeyEncryptPlain(plaintext []byte, publicKeyPEM string) ([]byte, error

// PublicKeyEncryptJSON encrypts JSON serializable data using RSA 2048 encryption algorithm
func PublicKeyEncryptJSON(data interface{}, publicKeyPEM string) ([]byte, error) {

// Encode data to JSON
plaintext, err := json.Marshal(data)
if err != nil {
Expand Down Expand Up @@ -189,7 +187,6 @@ func PublicKeyDecryptPlain(ciphertext []byte, privateKeyPEM string) ([]byte, err

// PublicKeyDecryptJSON decrypts JSON serializable data using RSA 2048 encryption algorithm
func PublicKeyDecryptJSON(ciphertext []byte, privateKeyPEM string) (interface{}, error) {

// Decrypt as normal
plaintext, err := PublicKeyDecryptPlain(ciphertext, privateKeyPEM)
if err != nil {
Expand Down Expand Up @@ -252,6 +249,112 @@ func GenerateTronAccountFromIndex(accountIndex int) (wallet *tronWallet.TronWall
return wallet, nil
}

// encryptHybridJSON encrypts JSON data using AES-256-GCM + RSA-2048 with size limit
func encryptHybridJSON(data interface{}, publicKeyPEM string, maxSize int) ([]byte, error) {
// Marshal to JSON
plaintext, err := json.Marshal(data)
if err != nil {
return nil, err
}

// Enforce size limit
if len(plaintext) > maxSize {
return nil, fmt.Errorf("payload too large: %d bytes (max %d)", len(plaintext), maxSize)
}

// Generate random AES-256 key
aesKey := make([]byte, 32)
if _, err := rand.Read(aesKey); err != nil {
return nil, fmt.Errorf("failed to generate AES key: %w", err)
}

// Encrypt plaintext with AES-GCM
block, err := aes.NewCipher(aesKey)
if err != nil {
return nil, err
}

gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}

nonce := make([]byte, gcm.NonceSize()) // gcm.NonceSize() always returns 12 bytes
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %w", err)
}

aesCiphertext := gcm.Seal(nonce, nonce, plaintext, nil)

// Encrypt AES key with RSA
encryptedKey, err := PublicKeyEncryptPlain(aesKey, publicKeyPEM)
if err != nil {
return nil, err
}

// Combine: [key_length(4)][encrypted_key][aes_ciphertext]
result := make([]byte, 4+len(encryptedKey)+len(aesCiphertext))
binary.BigEndian.PutUint32(result[0:4], uint32(len(encryptedKey)))
copy(result[4:], encryptedKey)
copy(result[4+len(encryptedKey):], aesCiphertext)

return result, nil
}

// decryptHybridJSON decrypts hybrid-encrypted JSON data
func decryptHybridJSON(encrypted []byte, privateKeyPEM string) (interface{}, error) {
if len(encrypted) < 4 {
return nil, fmt.Errorf("invalid encrypted data")
}

// Extract encrypted key
keyLen := binary.BigEndian.Uint32(encrypted[0:4])
if len(encrypted) < int(4+keyLen) {
return nil, fmt.Errorf("invalid encrypted data length")
}

encryptedKey := encrypted[4 : 4+keyLen]
aesCiphertext := encrypted[4+keyLen:]

// Decrypt AES key with RSA
aesKey, err := PublicKeyDecryptPlain(encryptedKey, privateKeyPEM)
if err != nil {
return nil, err
}

// Decrypt data with AES-GCM
block, err := aes.NewCipher(aesKey)
if err != nil {
return nil, err
}

gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}

nonceSize := gcm.NonceSize()
if len(aesCiphertext) < nonceSize {
return nil, fmt.Errorf("ciphertext too short")
}

nonce := aesCiphertext[:nonceSize]
ciphertext := aesCiphertext[nonceSize:]

plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, fmt.Errorf("decryption failed: %w", err)
}

// Unmarshal JSON
var data interface{}
if err := json.Unmarshal(plaintext, &data); err != nil {
return nil, err
}

return data, nil
}

// EncryptOrderRecipient encrypts the recipient details using the aggregator's public key
func EncryptOrderRecipient(order *ent.PaymentOrder) (string, error) {
// Generate a cryptographically secure random nonce
Expand All @@ -275,24 +378,55 @@ func EncryptOrderRecipient(order *ent.PaymentOrder) (string, error) {
base64.StdEncoding.EncodeToString(nonce), order.AccountIdentifier, order.AccountName, order.Institution, providerID, order.Memo, order.Metadata,
}

// Encrypt with the public key of the aggregator
messageCipher, err := PublicKeyEncryptJSON(message, cryptoConf.AggregatorPublicKey)
// Encrypt with the public key of the aggregator and enforce max size
messageCipher, err := encryptHybridJSON(message, cryptoConf.AggregatorPublicKey, cryptoConf.MessageHashMaxSize)
if err != nil {
return "", fmt.Errorf("failed to encrypt message: %w", err)
}

return base64.StdEncoding.EncodeToString(messageCipher), nil
}

// isHybridEncrypted checks if data is in hybrid format
func isHybridEncrypted(data []byte) bool {
if len(data) < 4 {
return false
}

keyLen := binary.BigEndian.Uint32(data[0:4])

if keyLen != 256 {
return false
}

if len(data) < int(4+keyLen+28) {
return false
}

return true
}

// decryptOrderRecipientWithFallback attempts to decrypt using hybrid format first, then falls back to legacy RSA format
func decryptOrderRecipientWithFallback(encrypted []byte, privateKeyPEM string) (interface{}, error) {
// Detect format
if isHybridEncrypted(encrypted) {
return decryptHybridJSON(encrypted, privateKeyPEM)
}

// Fallback to old RSA decryption
return PublicKeyDecryptJSON(encrypted, privateKeyPEM)
}

// GetOrderRecipientFromMessageHash decrypts the message hash and returns the order recipient
// Supports both hybrid encryption (new format) and legacy pure RSA encryption (old format) for backward compatibility
func GetOrderRecipientFromMessageHash(messageHash string) (*types.PaymentOrderRecipient, error) {
messageCipher, err := base64.StdEncoding.DecodeString(messageHash)
if err != nil {
return nil, fmt.Errorf("failed to decode message hash: %w", err)
}

// Decrypt with the private key of the aggregator
message, err := PublicKeyDecryptJSON(messageCipher, config.CryptoConfig().AggregatorPrivateKey)
// Decrypt with fallback support for both formats
message, err := decryptOrderRecipientWithFallback(messageCipher, config.CryptoConfig().AggregatorPrivateKey)
if err != nil {
return nil, fmt.Errorf("failed to decrypt message hash: %w", err)
}
Expand All @@ -313,18 +447,17 @@ func GetOrderRecipientFromMessageHash(messageHash string) (*types.PaymentOrderRe
func NormalizeStarknetAddress(address string) string {
// Remove 0x prefix if present
addr := strings.TrimPrefix(address, "0x")

// Starknet addresses should be 64 hex characters (excluding 0x)
// Pad with leading zeros if shorter
if len(addr) < 64 {
addr = strings.Repeat("0", 64-len(addr)) + addr
}

// Add 0x prefix back
return "0x" + addr
}


// generateSecureSeed generates a cryptographically secure random seed
func GenerateSecureSeed() (string, error) {
b := make([]byte, 32)
Expand Down
Loading
Loading