@@ -141,61 +141,43 @@ const EVP_CIPHER *Crypto::cipher(const std::string &algo)
141141std::vector<uint8_t > Crypto::concatKDF (const std::string &hashAlg, uint32_t keyDataLen,
142142 const std::vector<uint8_t > &z, const std::vector<uint8_t > &otherInfo)
143143{
144- std::vector<uint8_t > key;
145- uint32_t hashLen = SHA384_DIGEST_LENGTH;
146- if (hashAlg == SHA256_MTH) hashLen = SHA256_DIGEST_LENGTH;
147- else if (hashAlg == SHA384_MTH) hashLen = SHA384_DIGEST_LENGTH;
148- else if (hashAlg == SHA512_MTH) hashLen = SHA512_DIGEST_LENGTH;
149- else return key;
150-
151- SHA256_CTX sha256;
152- SHA512_CTX sha512;
153- std::vector<uint8_t > hash (hashLen, 0 );
154- uint8_t intToFourBytes[4 ];
144+ std::vector<uint8_t > key;
145+ const EVP_MD *md {};
146+ if (hashAlg == SHA256_MTH) md = EVP_sha256 ();
147+ else if (hashAlg == SHA384_MTH) md = EVP_sha384 ();
148+ else if (hashAlg == SHA512_MTH) md = EVP_sha512 ();
149+ else {
150+ LOG_WARN (" Usnupported hash algo {}" , hashAlg);
151+ return key;
152+ }
155153
154+ uint32_t hashLen = EVP_MD_get_size (md);
156155 uint32_t reps = keyDataLen / hashLen;
157156 if (keyDataLen % hashLen > 0 )
158157 reps++;
159158
160- for (uint32_t i = 1 ; i <= reps; i++)
161- {
162- intToFourBytes[0 ] = uint8_t (i >> 24 );
163- intToFourBytes[1 ] = uint8_t (i >> 16 );
164- intToFourBytes[2 ] = uint8_t (i >> 8 );
165- intToFourBytes[3 ] = uint8_t (i >> 0 );
166- switch (hashLen)
167- {
168- case SHA256_DIGEST_LENGTH:
169- if (SSL_FAILED (SHA256_Init (&sha256), " SHA256_Init" ) ||
170- SSL_FAILED (SHA256_Update (&sha256, intToFourBytes, 4 ), " SHA256_Update" ) ||
171- SSL_FAILED (SHA256_Update (&sha256, z.data (), z.size ()), " SHA256_Update" ) ||
172- SSL_FAILED (SHA256_Update (&sha256, otherInfo.data (), otherInfo.size ()), " SHA256_Update" ) ||
173- SSL_FAILED (SHA256_Final (hash.data (), &sha256), " SHA256_Final" ))
174- return {};
175- break ;
176- case SHA384_DIGEST_LENGTH:
177- if (SSL_FAILED (SHA384_Init (&sha512), " SHA384_Init" ) ||
178- SSL_FAILED (SHA384_Update (&sha512, intToFourBytes, 4 ), " SHA384_Update" ) ||
179- SSL_FAILED (SHA384_Update (&sha512, z.data (), z.size ()), " SHA384_Update" ) ||
180- SSL_FAILED (SHA384_Update (&sha512, otherInfo.data (), otherInfo.size ()), " SHA384_Update" ) ||
181- SSL_FAILED (SHA384_Final (hash.data (), &sha512), " SHA384_Final" ))
182- return {};
183- break ;
184- case SHA512_DIGEST_LENGTH:
185- if (SSL_FAILED (SHA512_Init (&sha512), " SHA512_Init" ) ||
186- SSL_FAILED (SHA512_Update (&sha512, intToFourBytes, 4 ), " SHA512_Update" ) ||
187- SSL_FAILED (SHA512_Update (&sha512, otherInfo.data (), otherInfo.size ()), " SHA512_Update" ) ||
188- SSL_FAILED (SHA512_Final (hash.data (), &sha512), " SHA512_Update" ))
189- return {};
190- break ;
191- default :
192- LOG_WARN (" Usnupported hash length {}" , hashLen);
193- return key;
194- }
195- key.insert (key.cend (), hash.cbegin (), hash.cend ());
196- }
197- key.resize (size_t (keyDataLen));
198- return key;
159+ auto ctx = make_unique_ptr<EVP_MD_CTX_free>(EVP_MD_CTX_new ());
160+ if (!ctx)
161+ {
162+ LOG_SSL_ERROR (" EVP_MD_CTX_new" );
163+ return key;
164+ }
165+
166+ std::vector<uint8_t > hash (hashLen, 0 );
167+ for (uint32_t i = 1 ; i <= reps; i++)
168+ {
169+ uint8_t intToFourBytes[4 ] { uint8_t (i >> 24 ), uint8_t (i >> 16 ), uint8_t (i >> 8 ), uint8_t (i >> 0 ) };
170+ unsigned int size = hashLen;
171+ if (SSL_FAILED (EVP_DigestInit (ctx.get (), md), " EVP_DigestInit" ) ||
172+ SSL_FAILED (EVP_DigestUpdate (ctx.get (), intToFourBytes, 4 ), " EVP_DigestUpdate" ) ||
173+ SSL_FAILED (EVP_DigestUpdate (ctx.get (), z.data (), z.size ()), " EVP_DigestUpdate" ) ||
174+ SSL_FAILED (EVP_DigestUpdate (ctx.get (), otherInfo.data (), otherInfo.size ()), " EVP_DigestUpdate" ) ||
175+ SSL_FAILED (EVP_DigestFinal (ctx.get (), hash.data (), &size), " EVP_DigestFinal" ))
176+ return {};
177+ key.insert (key.cend (), hash.cbegin (), hash.cend ());
178+ }
179+ key.resize (size_t (keyDataLen));
180+ return key;
199181}
200182
201183std::vector<uint8_t > Crypto::concatKDF (const std::string &hashAlg, uint32_t keyDataLen, const std::vector<uint8_t > &z,
@@ -608,14 +590,109 @@ EncryptionConsumer::close()
608590 if (SSL_FAILED (EVP_CIPHER_CTX_ctrl (ctx.get (), EVP_CTRL_GCM_GET_TAG, int (tag.size ()), tag.data ()), " EVP_CIPHER_CTX_ctrl" ))
609591 return CRYPTO_ERROR;
610592 LOG_DBG (" tag: {}" , toHex (tag));
611- return dst.write (tag.data (), tag.size ());
593+ if (dst.write (tag.data (), tag.size ()) != tag.size ())
594+ return IO_ERROR;
612595 }
613- if (EVP_CIPHER_CTX_flags (ctx.get ()) & EVP_CIPH_FLAG_AEAD_CIPHER)
596+ else if (EVP_CIPHER_CTX_flags (ctx.get ()) & EVP_CIPH_FLAG_AEAD_CIPHER)
614597 {
615598 if (SSL_FAILED (EVP_CIPHER_CTX_ctrl (ctx.get (), EVP_CTRL_AEAD_GET_TAG, int (tag.size ()), tag.data ()), " EVP_CIPHER_CTX_ctrl" ))
616599 return CRYPTO_ERROR;
617600 LOG_DBG (" tag: {}" , toHex (tag));
618- return dst.write (tag.data (), tag.size ());
601+ if (dst.write (tag.data (), tag.size ()) != tag.size ())
602+ return IO_ERROR;
603+ }
604+ return OK;
605+ }
606+
607+ DecryptionSource::DecryptionSource (DataSource &src, const std::string &method, const std::vector<unsigned char > &key)
608+ : DecryptionSource(src, Crypto::cipher(method), key)
609+ {}
610+
611+ DecryptionSource::DecryptionSource (DataSource &src, const EVP_CIPHER *cipher, const std::vector<unsigned char > &key)
612+ : ctx{EVP_CIPHER_CTX_new (), EVP_CIPHER_CTX_free}
613+ , src(src)
614+ {
615+ EVP_CIPHER_CTX_set_flags (ctx.get (), EVP_CIPHER_CTX_FLAG_WRAP_ALLOW);
616+ int ivLen = EVP_CIPHER_iv_length (cipher);
617+ std::vector<unsigned char > iv (ivLen, 0 );
618+ if (auto rv = src.read (iv.data (), ivLen); size_t (rv) != iv.size ())
619+ error = rv < 0 ? rv : IO_ERROR;
620+ else if (SSL_FAILED (EVP_CipherInit_ex (ctx.get (), cipher, nullptr , key.data (), iv.data (), 0 ), " EVP_CipherInit_ex" ))
621+ error = CRYPTO_ERROR;
622+ else if (auto rv = src.read (tag.data (), tag.size ()); size_t (rv) != tag.size ())
623+ error = rv < 0 ? rv : IO_ERROR;
624+ }
625+
626+ result_t DecryptionSource::readAAD (const std::vector<uint8_t > &data)
627+ {
628+ if (error != OK)
629+ return error;
630+ int len = 0 ;
631+ if (SSL_FAILED (EVP_CipherUpdate (ctx.get (), nullptr , &len, data.data (), int (data.size ())), " EVP_CipherUpdate" ))
632+ return CRYPTO_ERROR;
633+ return OK;
634+ }
635+
636+ result_t DecryptionSource::read (unsigned char *dst, size_t size)
637+ {
638+ if (error != OK)
639+ return error;
640+ if (!dst || size == 0 )
641+ return OK;
642+ if (size < tag.size ())
643+ return INPUT_STREAM_ERROR;
644+
645+ auto r = src.read (dst + tag.size (), size - tag.size ());
646+ if (r <= 0 ) {
647+ return r;
648+ }
649+ auto nread = static_cast <size_t >(r);
650+
651+ std::copy (tag.begin (), tag.end (), dst);
652+
653+ if (nread < size - tag.size ()) {
654+ std::copy_n (std::next (dst, nread), tag.size (), tag.begin ());
655+ size = nread;
656+ } else if (auto r = src.read (tag.data (), tag.size ()); r < 0 ) {
657+ return r;
658+ } else if (auto tagSize = static_cast <size_t >(r); tagSize < tag.size ()) {
659+ std::move_backward (tag.begin (), std::next (tag.begin (), tagSize), tag.end ());
660+ size_t more = tag.size () - tagSize;
661+ std::copy_n (std::next (dst, size - more), more, tag.data ());
662+ size -= more;
663+ }
664+
665+ if (int out = 0 ;
666+ SSL_FAILED (EVP_CipherUpdate (ctx.get (), dst, &out, dst, size), " EVP_CipherUpdate" ) ||
667+ size != out) {
668+ return error = CRYPTO_ERROR;
669+ }
670+ return size;
671+ }
672+
673+ result_t DecryptionSource::close ()
674+ {
675+ if (error != OK)
676+ return error;
677+
678+ if (EVP_CIPHER_CTX_mode (ctx.get ()) == EVP_CIPH_GCM_MODE) {
679+ LOG_DBG (" tag: {}" , toHex (tag));
680+ if (SSL_FAILED (EVP_CIPHER_CTX_ctrl (ctx.get (), EVP_CTRL_GCM_SET_TAG, int (tag.size ()), (void *)tag.data ()), " EVP_CIPHER_CTX_ctrl" )) {
681+ return error = CRYPTO_ERROR;
682+ }
683+ }
684+ else if (EVP_CIPHER_CTX_flags (ctx.get ()) & EVP_CIPH_FLAG_AEAD_CIPHER)
685+ {
686+ LOG_DBG (" tag: {}" , toHex (tag));
687+ if (SSL_FAILED (EVP_CIPHER_CTX_ctrl (ctx.get (), EVP_CTRL_AEAD_SET_TAG, int (tag.size ()), (void *)tag.data ()), " EVP_CIPHER_CTX_ctrl" )) {
688+ return error = CRYPTO_ERROR;
689+ }
690+ }
691+
692+ int len = 0 ;
693+ std::vector<uint8_t > buffer (EVP_CIPHER_CTX_block_size (ctx.get ()), 0 );
694+ if (SSL_FAILED (EVP_CipherFinal_ex (ctx.get (), buffer.data (), &len), " EVP_CipherFinal_ex" )) {
695+ return error = CRYPTO_ERROR;
619696 }
620697 return OK;
621698}
0 commit comments