diff --git a/include/pulsar/EncryptionContext.h b/include/pulsar/EncryptionContext.h new file mode 100644 index 00000000..ac7ebf91 --- /dev/null +++ b/include/pulsar/EncryptionContext.h @@ -0,0 +1,113 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#pragma once + +#include +#include +#include +#include + +#include "CompressionType.h" +#include "defines.h" + +namespace pulsar { + +namespace proto { +class MessageMetadata; +} + +struct PULSAR_PUBLIC EncryptionKey { + std::string key; + std::string value; + std::unordered_map metadata; + + EncryptionKey(const std::string& key, const std::string& value, + const decltype(EncryptionKey::metadata)& metadata) + : key(key), value(value), metadata(metadata) {} +}; + +/** + * It contains encryption and compression information in it using which application can decrypt consumed + * message with encrypted-payload. + */ +class PULSAR_PUBLIC EncryptionContext { + public: + using KeysType = std::vector; + + /** + * @return the map of encryption keys used for the message + */ + const KeysType& keys() const noexcept { return keys_; } + + /** + * @return the encryption parameter used for the message + */ + const std::string& param() const noexcept { return param_; } + + /** + * @return the encryption algorithm used for the message + */ + const std::string& algorithm() const noexcept { return algorithm_; } + + /** + * @return the compression type used for the message + */ + CompressionType compressionType() const noexcept { return compressionType_; } + + /** + * @return the uncompressed message size if the message is compressed, 0 otherwise + */ + uint32_t uncompressedMessageSize() const noexcept { return uncompressedMessageSize_; } + + /** + * @return the batch size if the message is part of a batch, -1 otherwise + */ + int32_t batchSize() const noexcept { return batchSize_; } + + /** + * When the `ConsumerConfiguration#getCryptoFailureAction` is set to `CONSUME`, the message will still be + * returned even if the decryption failed. This method is provided to let users know whether the + * decryption failed. + * + * @return whether the decryption failed + */ + bool isDecryptionFailed() const noexcept { return isDecryptionFailed_; } + + /** + * This constructor is public to allow in-place construction via std::optional + * (e.g., `std::optional(std::in_place, metadata, false)`), + * but should not be used directly in application code. + */ + EncryptionContext(const proto::MessageMetadata&, bool); + + private: + KeysType keys_; + std::string param_; + std::string algorithm_; + CompressionType compressionType_{CompressionNone}; + uint32_t uncompressedMessageSize_{0}; + int32_t batchSize_{-1}; + bool isDecryptionFailed_{false}; + + void setDecryptionFailed(bool failed) noexcept { isDecryptionFailed_ = failed; } + + friend class ConsumerImpl; +}; + +} // namespace pulsar diff --git a/include/pulsar/Message.h b/include/pulsar/Message.h index ea4c4ab4..f52879e8 100644 --- a/include/pulsar/Message.h +++ b/include/pulsar/Message.h @@ -19,10 +19,12 @@ #ifndef MESSAGE_HPP_ #define MESSAGE_HPP_ +#include #include #include #include +#include #include #include "KeyValue.h" @@ -202,6 +204,12 @@ class PULSAR_PUBLIC Message { */ const std::string& getProducerName() const noexcept; + /** + * @return the optional encryption context that is present when the message is encrypted, the pointer is + * valid as the Message instance is alive + */ + std::optional getEncryptionContext() const; + bool operator==(const Message& msg) const; protected: diff --git a/lib/Commands.cc b/lib/Commands.cc index 3c687c0a..30f5bf1a 100644 --- a/lib/Commands.cc +++ b/lib/Commands.cc @@ -930,6 +930,7 @@ Message Commands::deSerializeSingleMessageInBatch(Message& batchedMessage, int32 batchedMessage.impl_->metadata, payload, metadata, batchedMessage.impl_->topicName_); singleMessage.impl_->cnx_ = batchedMessage.impl_->cnx_; + singleMessage.impl_->encryptionContext_ = batchedMessage.impl_->encryptionContext_; return singleMessage; } diff --git a/lib/ConsumerImpl.cc b/lib/ConsumerImpl.cc index 4781e966..430b8512 100644 --- a/lib/ConsumerImpl.cc +++ b/lib/ConsumerImpl.cc @@ -19,6 +19,7 @@ #include "ConsumerImpl.h" #include +#include #include #include @@ -549,24 +550,27 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto:: proto::MessageMetadata& metadata, SharedBuffer& payload) { LOG_DEBUG(getName() << "Received Message -- Size: " << payload.readableBytes()); - if (!decryptMessageIfNeeded(cnx, msg, metadata, payload)) { - // Message was discarded or not consumed due to decryption failure - return; - } - if (!isChecksumValid) { // Message discarded for checksum error discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_ChecksumMismatch); return; } - auto redeliveryCount = msg.redelivery_count(); - const bool isMessageUndecryptable = - metadata.encryption_keys_size() > 0 && !config_.getCryptoKeyReader().get() && - config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME; + auto encryptionContext = metadata.encryption_keys_size() > 0 + ? optional(std::in_place, metadata, false) + : std::nullopt; + const auto decryptionResult = decryptMessageIfNeeded(cnx, msg, encryptionContext, payload); + if (decryptionResult == DecryptionResult::FAILED) { + // Message was discarded or not consumed due to decryption failure + return; + } else if (decryptionResult == DecryptionResult::CONSUME_ENCRYPTED && encryptionContext.has_value()) { + // Message is encrypted, but we let the application consume it as-is + encryptionContext->setDecryptionFailed(true); + } + auto redeliveryCount = msg.redelivery_count(); const bool isChunkedMessage = metadata.num_chunks_from_msg() > 1; - if (!isMessageUndecryptable && !isChunkedMessage) { + if (decryptionResult == DecryptionResult::SUCCESS && !isChunkedMessage) { if (!uncompressMessageIfNeeded(cnx, msg.message_id(), metadata, payload, true)) { // Message was discarded on decompression error return; @@ -590,6 +594,7 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto:: m.impl_->cnx_ = cnx.get(); m.impl_->setTopicName(getTopicPtr()); m.impl_->setRedeliveryCount(msg.redelivery_count()); + m.impl_->encryptionContext_ = std::move(encryptionContext); if (metadata.has_schema_version()) { m.impl_->setSchemaVersion(metadata.schema_version()); @@ -610,7 +615,7 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto:: return; } - if (metadata.has_num_messages_in_batch()) { + if (metadata.has_num_messages_in_batch() && decryptionResult == DecryptionResult::SUCCESS) { BitSet::Data words(msg.ack_set_size()); for (int i = 0; i < words.size(); i++) { words[i] = msg.ack_set(i); @@ -812,17 +817,18 @@ uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch(const ClientConnection return batchSize - skippedMessages; } -bool ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg, - const proto::MessageMetadata& metadata, SharedBuffer& payload) { - if (!metadata.encryption_keys_size()) { - return true; +auto ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg, + const optional& context, SharedBuffer& payload) + -> DecryptionResult { + if (!context.has_value()) { + return DecryptionResult::SUCCESS; } // If KeyReader is not configured throw exception based on config param if (!config_.isEncryptionEnabled()) { if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME) { LOG_WARN(getName() << "CryptoKeyReader is not implemented. Consuming encrypted message."); - return true; + return DecryptionResult::CONSUME_ENCRYPTED; } else if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::DISCARD) { LOG_WARN(getName() << "Skipping decryption since CryptoKeyReader is not implemented and config " "is set to discard"); @@ -833,20 +839,20 @@ bool ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const auto messageId = MessageIdBuilder::from(msg.message_id()).build(); unAckedMessageTrackerPtr_->add(messageId); } - return false; + return DecryptionResult::FAILED; } SharedBuffer decryptedPayload; - if (msgCrypto_->decrypt(metadata, payload, config_.getCryptoKeyReader(), decryptedPayload)) { + if (msgCrypto_->decrypt(*context, payload, config_.getCryptoKeyReader(), decryptedPayload)) { payload = decryptedPayload; - return true; + return DecryptionResult::SUCCESS; } if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME) { // Note, batch message will fail to consume even if config is set to consume LOG_WARN( getName() << "Decryption failed. Consuming encrypted message since config is set to consume."); - return true; + return DecryptionResult::CONSUME_ENCRYPTED; } else if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::DISCARD) { LOG_WARN(getName() << "Discarding message since decryption failed and config is set to discard"); discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_DecryptionError); @@ -855,7 +861,7 @@ bool ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const auto messageId = MessageIdBuilder::from(msg.message_id()).build(); unAckedMessageTrackerPtr_->add(messageId); } - return false; + return DecryptionResult::FAILED; } bool ConsumerImpl::uncompressMessageIfNeeded(const ClientConnectionPtr& cnx, diff --git a/lib/ConsumerImpl.h b/lib/ConsumerImpl.h index c1df0804..63eb51d6 100644 --- a/lib/ConsumerImpl.h +++ b/lib/ConsumerImpl.h @@ -195,8 +195,15 @@ class ConsumerImpl : public ConsumerImplBase { bool isPriorEntryIndex(int64_t idx); void brokerConsumerStatsListener(Result, BrokerConsumerStatsImpl, const BrokerConsumerStatsCallback&); - bool decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg, - const proto::MessageMetadata& metadata, SharedBuffer& payload); + enum class DecryptionResult : uint8_t + { + SUCCESS, + CONSUME_ENCRYPTED, + FAILED + }; + DecryptionResult decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg, + const optional& context, + SharedBuffer& payload); // TODO - Convert these functions to lambda when we move to C++11 Result receiveHelper(Message& msg); diff --git a/lib/EncryptionContext.cc b/lib/EncryptionContext.cc new file mode 100644 index 00000000..5376f062 --- /dev/null +++ b/lib/EncryptionContext.cc @@ -0,0 +1,48 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "PulsarApi.pb.h" + +namespace pulsar { + +static EncryptionContext::KeysType encryptedKeysFromMetadata(const proto::MessageMetadata& msgMetadata) { + EncryptionContext::KeysType keys; + for (auto&& key : msgMetadata.encryption_keys()) { + decltype(EncryptionKey::metadata) metadata; + for (int i = 0; i < key.metadata_size(); i++) { + const auto& entry = key.metadata(i); + metadata[entry.key()] = entry.value(); + } + keys.emplace_back(key.key(), key.value(), std::move(metadata)); + } + return keys; +} + +EncryptionContext::EncryptionContext(const proto::MessageMetadata& msgMetadata, bool isDecryptionFailed) + + : keys_(encryptedKeysFromMetadata(msgMetadata)), + param_(msgMetadata.encryption_param()), + algorithm_(msgMetadata.encryption_algo()), + compressionType_(static_cast(msgMetadata.compression())), + uncompressedMessageSize_(msgMetadata.uncompressed_size()), + batchSize_(msgMetadata.has_num_messages_in_batch() ? msgMetadata.num_messages_in_batch() : -1), + isDecryptionFailed_(isDecryptionFailed) {} + +} // namespace pulsar diff --git a/lib/Message.cc b/lib/Message.cc index 1e26b521..9505565b 100644 --- a/lib/Message.cc +++ b/lib/Message.cc @@ -220,6 +220,13 @@ const std::string& Message::getProducerName() const noexcept { return impl_->metadata.producer_name(); } +std::optional Message::getEncryptionContext() const { + if (!impl_ || !impl_->encryptionContext_.has_value()) { + return std::nullopt; + } + return &impl_->encryptionContext_.value(); +} + bool Message::operator==(const Message& msg) const { return getMessageId() == msg.getMessageId(); } KeyValue Message::getKeyValueData() const { return KeyValue(impl_->keyValuePtr); } diff --git a/lib/MessageCrypto.cc b/lib/MessageCrypto.cc index b06ff652..daa492ea 100644 --- a/lib/MessageCrypto.cc +++ b/lib/MessageCrypto.cc @@ -394,13 +394,13 @@ bool MessageCrypto::encrypt(const std::set& encKeys, const CryptoKe return true; } -bool MessageCrypto::decryptDataKey(const proto::EncryptionKeys& encKeys, const CryptoKeyReader& keyReader) { - const auto& keyName = encKeys.key(); - const auto& encryptedDataKey = encKeys.value(); - const auto& encKeyMeta = encKeys.metadata(); +bool MessageCrypto::decryptDataKey(const EncryptionKey& encKeys, const CryptoKeyReader& keyReader) { + const auto& keyName = encKeys.key; + const auto& encryptedDataKey = encKeys.value; + const auto& encKeyMeta = encKeys.metadata; StringMap keyMeta; for (auto iter = encKeyMeta.begin(); iter != encKeyMeta.end(); iter++) { - keyMeta[iter->key()] = iter->value(); + keyMeta[iter->first] = iter->second; } // Read the private key info using callback @@ -451,11 +451,10 @@ bool MessageCrypto::decryptDataKey(const proto::EncryptionKeys& encKeys, const C return true; } -bool MessageCrypto::decryptData(const std::string& dataKeySecret, const proto::MessageMetadata& msgMetadata, +bool MessageCrypto::decryptData(const std::string& dataKeySecret, const EncryptionContext& context, SharedBuffer& payload, SharedBuffer& decryptedPayload) { // unpack iv and encrypted data - msgMetadata.encryption_param().copy(reinterpret_cast(iv_.get()), - msgMetadata.encryption_param().size()); + context.param().copy(reinterpret_cast(iv_.get()), context.param().size()); EVP_CIPHER_CTX* cipherCtx = NULL; decryptedPayload = SharedBuffer::allocate(payload.readableBytes() + EVP_MAX_BLOCK_LENGTH + tagLen_); @@ -518,15 +517,14 @@ bool MessageCrypto::decryptData(const std::string& dataKeySecret, const proto::M return true; } -bool MessageCrypto::getKeyAndDecryptData(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, +bool MessageCrypto::getKeyAndDecryptData(const EncryptionContext& context, SharedBuffer& payload, SharedBuffer& decryptedPayload) { SharedBuffer decryptedData; bool dataDecrypted = false; - for (auto iter = msgMetadata.encryption_keys().begin(); iter != msgMetadata.encryption_keys().end(); - iter++) { - const std::string& keyName = iter->key(); - const std::string& encDataKey = iter->value(); + for (auto&& kv : context.keys()) { + const std::string& keyName = kv.key; + const std::string& encDataKey = kv.value; unsigned char keyDigest[EVP_MAX_MD_SIZE]; unsigned int digestLen = 0; getDigest(keyName, encDataKey.c_str(), encDataKey.size(), keyDigest, digestLen); @@ -539,7 +537,7 @@ bool MessageCrypto::getKeyAndDecryptData(const proto::MessageMetadata& msgMetada // retruns a different key, decryption fails. At this point, we would // call decryptDataKey to refresh the cache and come here again to decrypt. auto dataKeyEntry = dataKeyCacheIter->second; - if (decryptData(dataKeyEntry.first, msgMetadata, payload, decryptedPayload)) { + if (decryptData(dataKeyEntry.first, context, payload, decryptedPayload)) { dataDecrypted = true; break; } @@ -552,17 +550,16 @@ bool MessageCrypto::getKeyAndDecryptData(const proto::MessageMetadata& msgMetada return dataDecrypted; } -bool MessageCrypto::decrypt(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, +bool MessageCrypto::decrypt(const EncryptionContext& context, SharedBuffer& payload, const CryptoKeyReaderPtr& keyReader, SharedBuffer& decryptedPayload) { // Attempt to decrypt using the existing key - if (getKeyAndDecryptData(msgMetadata, payload, decryptedPayload)) { + if (getKeyAndDecryptData(context, payload, decryptedPayload)) { return true; } // Either first time, or decryption failed. Attempt to regenerate data key bool isDataKeyDecrypted = false; - for (int index = 0; index < msgMetadata.encryption_keys_size(); index++) { - const proto::EncryptionKeys& encKeys = msgMetadata.encryption_keys(index); + for (auto&& encKeys : context.keys()) { if (decryptDataKey(encKeys, *keyReader)) { isDataKeyDecrypted = true; break; @@ -574,7 +571,7 @@ bool MessageCrypto::decrypt(const proto::MessageMetadata& msgMetadata, SharedBuf return false; } - return getKeyAndDecryptData(msgMetadata, payload, decryptedPayload); + return getKeyAndDecryptData(context, payload, decryptedPayload); } } /* namespace pulsar */ diff --git a/lib/MessageCrypto.h b/lib/MessageCrypto.h index cd07bf55..4052066d 100644 --- a/lib/MessageCrypto.h +++ b/lib/MessageCrypto.h @@ -26,10 +26,10 @@ #include #include #include +#include #include #include -#include #include #include #include @@ -90,15 +90,15 @@ class MessageCrypto { /* * Decrypt the payload using the data key. Keys used to encrypt data key can be retrieved from msgMetadata * - * @param msgMetadata Message Metadata + * @param context the encryption context * @param payload Message which needs to be decrypted * @param keyReader KeyReader implementation to retrieve key value * @param decryptedPayload Contains decrypted payload if success * * @return true if success */ - bool decrypt(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, - const CryptoKeyReaderPtr& keyReader, SharedBuffer& decryptedPayload); + bool decrypt(const EncryptionContext& context, SharedBuffer& payload, const CryptoKeyReaderPtr& keyReader, + SharedBuffer& decryptedPayload); private: typedef std::unique_lock Lock; @@ -137,10 +137,10 @@ class MessageCrypto { Result addPublicKeyCipher(const std::string& keyName, const CryptoKeyReaderPtr& keyReader); - bool decryptDataKey(const proto::EncryptionKeys& encKeys, const CryptoKeyReader& keyReader); - bool decryptData(const std::string& dataKeySecret, const proto::MessageMetadata& msgMetadata, + bool decryptDataKey(const EncryptionKey& encKeys, const CryptoKeyReader& keyReader); + bool decryptData(const std::string& dataKeySecret, const EncryptionContext& context, SharedBuffer& payload, SharedBuffer& decPayload); - bool getKeyAndDecryptData(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, + bool getKeyAndDecryptData(const EncryptionContext& context, SharedBuffer& payload, SharedBuffer& decryptedPayload); std::string stringToHex(const std::string& inputStr, size_t len); std::string stringToHex(const char* inputStr, size_t len); diff --git a/lib/MessageImpl.h b/lib/MessageImpl.h index 6467b359..a234ca45 100644 --- a/lib/MessageImpl.h +++ b/lib/MessageImpl.h @@ -22,9 +22,12 @@ #include #include +#include + #include "KeyValueImpl.h" #include "PulsarApi.pb.h" #include "SharedBuffer.h" +#include "pulsar/EncryptionContext.h" using namespace pulsar; namespace pulsar { @@ -48,6 +51,7 @@ class MessageImpl { bool hasSchemaVersion_; const std::string* schemaVersion_; std::weak_ptr consumerPtr_; + std::optional encryptionContext_; const std::string& getPartitionKey() const; bool hasPartitionKey() const; diff --git a/tests/BasicEndToEndTest.cc b/tests/BasicEndToEndTest.cc index c9a8faa9..9a02df0c 100644 --- a/tests/BasicEndToEndTest.cc +++ b/tests/BasicEndToEndTest.cc @@ -1465,6 +1465,10 @@ TEST(BasicEndToEndTest, testRSAEncryption) { expected << msgContent << msgNum; ASSERT_EQ(expected.str(), msgReceived.getDataAsString()); ASSERT_EQ(ResultOk, consumer.acknowledge(msgReceived)); + auto context = msgReceived.getEncryptionContext(); + ASSERT_TRUE(context.has_value()); + ASSERT_EQ(context.value()->keys().size(), 1); + ASSERT_EQ(context.value()->keys()[0].key, "client-rsa.pem"); } ASSERT_EQ(ResultOk, consumer.unsubscribe()); diff --git a/tests/EncryptionTest.cc b/tests/EncryptionTest.cc new file mode 100644 index 00000000..ff5cb98e --- /dev/null +++ b/tests/EncryptionTest.cc @@ -0,0 +1,145 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +#include +#include + +#include "lib/CompressionCodec.h" +#include "lib/MessageCrypto.h" +#include "lib/SharedBuffer.h" + +static std::string lookupUrl = "pulsar://localhost:6650"; + +using namespace pulsar; + +static CryptoKeyReaderPtr getDefaultCryptoKeyReader() { + return std::make_shared(TEST_CONF_DIR "/public-key.client-rsa.pem", + TEST_CONF_DIR "/private-key.client-rsa.pem"); +} + +static std::vector decryptValue(const char* data, size_t length, + std::optional context) { + if (!context.has_value()) { + return {std::string(data, length)}; + } + if (!context.value()->isDecryptionFailed()) { + return {std::string(data, length)}; + } + + MessageCrypto crypto{"test", false}; + SharedBuffer decryptedPayload; + auto originalPayload = SharedBuffer::copy(data, length); + if (!crypto.decrypt(*context.value(), originalPayload, getDefaultCryptoKeyReader(), decryptedPayload)) { + throw std::runtime_error("Decryption failed"); + } + + SharedBuffer uncompressedPayload; + if (!CompressionCodecProvider::getCodec(context.value()->compressionType()) + .decode(decryptedPayload, context.value()->uncompressedMessageSize(), uncompressedPayload)) { + throw std::runtime_error("Decompression failed"); + } + + std::vector values; + if (auto batchSize = context.value()->batchSize(); batchSize > 0) { + MessageBatch batch; + for (auto&& msg : batch.parseFrom(uncompressedPayload, batchSize).messages()) { + values.emplace_back(msg.getDataAsString()); + } + } else { + // non-batched message + values.emplace_back(uncompressedPayload.data(), uncompressedPayload.readableBytes()); + } + return values; +} + +static void testDecryption(Client& client, const std::string& topic, bool withDecryption, + int numMessageReceived) { + ProducerConfiguration producerConf; + producerConf.setCompressionType(CompressionLZ4); + producerConf.addEncryptionKey("client-rsa.pem"); + producerConf.setCryptoKeyReader(getDefaultCryptoKeyReader()); + + Producer producer; + ASSERT_EQ(ResultOk, client.createProducer(topic, producerConf, producer)); + + std::vector sentValues; + auto send = [&producer, &sentValues](const std::string& value) { + Message msg = MessageBuilder().setContent(value).build(); + producer.sendAsync(msg, nullptr); + sentValues.emplace_back(value); + }; + + for (int i = 0; i < 5; i++) { + send("msg-" + std::to_string(i)); + } + producer.flush(); + send("last-msg"); + producer.flush(); + + ASSERT_EQ(ResultOk, client.createProducer(topic, producer)); + send("unencrypted-msg"); + producer.flush(); + producer.close(); + + ConsumerConfiguration consumerConf; + consumerConf.setSubscriptionInitialPosition(InitialPositionEarliest); + if (withDecryption) { + consumerConf.setCryptoKeyReader(getDefaultCryptoKeyReader()); + } else { + consumerConf.setCryptoFailureAction(ConsumerCryptoFailureAction::CONSUME); + } + Consumer consumer; + ASSERT_EQ(ResultOk, client.subscribe(topic, "sub", consumerConf, consumer)); + + std::vector values; + for (int i = 0; i < numMessageReceived; i++) { + Message msg; + ASSERT_EQ(ResultOk, consumer.receive(msg, 3000)); + if (i < numMessageReceived - 1) { + ASSERT_TRUE(msg.getEncryptionContext().has_value()); + } else { + ASSERT_FALSE(msg.getEncryptionContext().has_value()); + } + for (auto&& value : decryptValue(static_cast(msg.getData()), msg.getLength(), + msg.getEncryptionContext())) { + values.emplace_back(value); + } + } + ASSERT_EQ(values, sentValues); + consumer.close(); +} + +TEST(EncryptionTests, testDecryptionSuccess) { + Client client{lookupUrl}; + std::string topic = "test-decryption-success-" + std::to_string(time(nullptr)); + testDecryption(client, topic, true, 7); + client.close(); +} + +TEST(EncryptionTests, testDecryptionFailure) { + Client client{lookupUrl}; + std::string topic = "test-decryption-failure-" + std::to_string(time(nullptr)); + // The 1st batch that has 5 messages cannot be decrypted, so they can be received only once + testDecryption(client, topic, false, 3); + client.close(); +} diff --git a/win-examples/CMakeLists.txt b/win-examples/CMakeLists.txt index 3998c43a..c8d74b60 100644 --- a/win-examples/CMakeLists.txt +++ b/win-examples/CMakeLists.txt @@ -20,6 +20,7 @@ cmake_minimum_required(VERSION 3.4) project(pulsar-cpp-win-examples) +set(CMAKE_CXX_STANDARD 17) find_path(PULSAR_INCLUDES NAMES "pulsar/Client.h") if (PULSAR_INCLUDES) message(STATUS "PULSAR_INCLUDES: " ${PULSAR_INCLUDES})