diff --git a/pulsar/__init__.py b/pulsar/__init__.py index 543cd0d..2375d16 100644 --- a/pulsar/__init__.py +++ b/pulsar/__init__.py @@ -166,6 +166,122 @@ def wrap(cls, msg_id: _pulsar.MessageId): self._msg_id = msg_id return self + +class EncryptionKey: + """ + The key used for encryption. + """ + + def __init__(self, key: _pulsar.EncryptionKey): + """ + Create EncryptionKey instance. + + Parameters + ---------- + key: _pulsar.EncryptionKey + The underlying EncryptionKey instance from the C extension. + """ + self._key = key + + @property + def key(self) -> str: + """ + Returns the key, which is usually the key file's name. + """ + return self._key.key + + @property + def value(self) -> bytes: + """ + Returns the value, which is usually the key bytes used for encryption. + """ + return self._key.value() + + @property + def metadata(self) -> dict: + """ + Returns the metadata associated with the key. + """ + return self._key.metadata + + def __str__(self) -> str: + return f"EncryptionKey(key={self.key}, value_len={len(self.value)}, metadata={self.metadata})" + + def __repr__(self) -> str: + return self.__str__() + + +class EncryptionContext: + """ + It contains encryption and compression information in it using which application can decrypt + consumed message with encrypted-payload. + """ + + def __init__(self, context: _pulsar.EncryptionContext): + """ + Create EncryptionContext instance. + + Parameters + ---------- + context: _pulsar.EncryptionContext + The underlying EncryptionContext instance from the C extension. + """ + self._context = context + + def keys(self) -> List[EncryptionKey]: + """ + Returns all EncryptionKey instances when performing encryption. + """ + keys = self._context.keys() + return [EncryptionKey(key) for key in keys] + + def param(self) -> bytes: + """ + Returns the encryption param bytes. + """ + return self._context.param() + + def algorithm(self) -> str: + """ + Returns the encryption algorithm. + """ + return self._context.algorithm() + + def compression_type(self) -> CompressionType: + """ + Returns the compression type of the message. + """ + return self._context.compression_type() + + def uncompressed_message_size(self) -> int: + """ + Returns the uncompressed message size or 0 if the compression type is NONE. + """ + return self._context.uncompressed_message_size() + + def batch_size(self) -> int: + """ + Returns the number of messages in the batch or -1 if the message is not batched. + """ + return self._context.batch_size() + + def is_decryption_failed(self) -> bool: + """ + Returns whether decryption has failed for this message. + """ + return self._context.is_decryption_failed() + + def __str__(self) -> str: + return f"EncryptionContext(algorithm={self.algorithm()}, " \ + f"compression_type={self.compression_type().name}, " \ + f"uncompressed_message_size={self.uncompressed_message_size()}, " \ + f"is_decryption_failed={self.is_decryption_failed()}, " \ + f"keys=[{', '.join(str(key) for key in self.keys())}])" + + def __repr__(self) -> str: + return self.__str__() + + class Message: """ Message objects are returned by a consumer, either by calling `receive` or @@ -250,6 +366,15 @@ def producer_name(self) -> str: """ return self._message.producer_name() + def encryption_context(self) -> EncryptionContext | None: + """ + Get the encryption context for this message or None if it's not encrypted. + + It should be noted that the result should not be accessed after the current Message instance is deleted. + """ + context = self._message.encryption_context() + return None if context is None else EncryptionContext(context) + @staticmethod def _wrap(_message): self = Message() diff --git a/src/message.cc b/src/message.cc index e18861a..f3247e6 100644 --- a/src/message.cc +++ b/src/message.cc @@ -86,6 +86,20 @@ void export_message(py::module_& m) { }) .def_static("deserialize", &MessageId::deserialize); + class_(m, "EncryptionKey") + .def_readonly("key", &EncryptionKey::key) + .def("value", [](const EncryptionKey& key) { return bytes(key.value); }) + .def_readonly("metadata", &EncryptionKey::metadata); + + class_(m, "EncryptionContext") + .def("keys", &EncryptionContext::keys) + .def("param", [](const EncryptionContext& context) { return bytes(context.param()); }) + .def("algorithm", &EncryptionContext::algorithm, return_value_policy::copy) + .def("compression_type", &EncryptionContext::compressionType) + .def("uncompressed_message_size", &EncryptionContext::uncompressedMessageSize) + .def("batch_size", &EncryptionContext::batchSize) + .def("is_decryption_failed", &EncryptionContext::isDecryptionFailed); + class_(m, "Message") .def(init<>()) .def("properties", &Message::getProperties) @@ -106,7 +120,8 @@ void export_message(py::module_& m) { .def("redelivery_count", &Message::getRedeliveryCount) .def("int_schema_version", &Message::getLongSchemaVersion) .def("schema_version", &Message::getSchemaVersion, return_value_policy::copy) - .def("producer_name", &Message::getProducerName, return_value_policy::copy); + .def("producer_name", &Message::getProducerName, return_value_policy::copy) + .def("encryption_context", &Message::getEncryptionContext, return_value_policy::reference); MessageBatch& (MessageBatch::*MessageBatchParseFromString)(const std::string& payload, uint32_t batchSize) = &MessageBatch::parseFrom; diff --git a/tests/pulsar_test.py b/tests/pulsar_test.py index 3603d84..b7f38ed 100755 --- a/tests/pulsar_test.py +++ b/tests/pulsar_test.py @@ -167,6 +167,7 @@ def test_producer_send(self): consumer.acknowledge(msg) print("receive from {}".format(msg.message_id())) self.assertEqual(msg_id, msg.message_id()) + self.assertIsNone(msg.encryption_context()) client.close() def test_producer_access_mode_exclusive(self): @@ -489,15 +490,36 @@ def test_encryption_failure(self): client = Client(self.serviceUrl) topic = "my-python-test-end-to-end-encryption-failure-" + str(time.time()) producer = client.create_producer( - topic=topic, encryption_key="client-rsa.pem", crypto_key_reader=crypto_key_reader + topic=topic, encryption_key="client-rsa.pem", crypto_key_reader=crypto_key_reader, + compression_type=CompressionType.LZ4 ) producer.send(b"msg-0") + def verify_encryption_context(context: pulsar.EncryptionContext | None, failed: bool, batch_size: int): + if context is None: + self.fail("Encryption context is None") + keys = context.keys() + self.assertEqual(len(keys), 1) + key = keys[0] + self.assertEqual(key.key, "client-rsa.pem") + self.assertGreater(len(key.value), 0) + self.assertEqual(key.metadata, {}) + self.assertGreater(len(context.param()), 0) + self.assertEqual(context.algorithm(), "") + self.assertEqual(context.compression_type(), CompressionType.LZ4) + if batch_size == -1: + self.assertEqual(context.uncompressed_message_size(), len(b"msg-0")) + else: + self.assertGreater(context.uncompressed_message_size(), len(b"msg-0")) + self.assertEqual(context.batch_size(), batch_size) + self.assertEqual(context.is_decryption_failed(), failed) + def verify_next_message(value: bytes): consumer = client.subscribe(topic, subscription, crypto_key_reader=crypto_key_reader) msg = consumer.receive(3000) self.assertEqual(msg.data(), value) + verify_encryption_context(msg.encryption_context(), False, -1) consumer.acknowledge(msg) consumer.close() @@ -520,22 +542,40 @@ def verify_next_message(value: bytes): producer.send(b"msg-2") verify_next_message(b"msg-2") # msg-1 is skipped since the crypto failure action is DISCARD + producer.close() + + # send batched messages + producer = client.create_producer( + topic=topic, + encryption_key="client-rsa.pem", + crypto_key_reader=crypto_key_reader, + compression_type=CompressionType.LZ4, + batching_enabled=True, + ) + producer.send_async(b"msg-3", None) + producer.send_async(b"msg-4", None) + producer.flush() + + def verify_undecrypted_message(msg: pulsar.Message, i: int): + self.assertNotEqual(msg.data(), f"msg-{i}".encode()) + self.assertGreater(len(msg.data()), 5, f"msg.data() is {msg.data()}") + verify_encryption_context(msg.encryption_context(), True, 2 if i >= 3 else -1) # Encrypted messages will be consumed since the crypto failure action is CONSUME + # Only 4 messages can be received because msg-3 and msg-4 are sent in batch and they are delivered + # as a single message when decryption fails. consumer = client.subscribe(topic, 'another-sub', initial_position=InitialPosition.Earliest, crypto_failure_action=pulsar.ConsumerCryptoFailureAction.CONSUME) - for i in range(3): + for i in range(4): msg = consumer.receive(3000) - self.assertNotEqual(msg.data(), f"msg-{i}".encode()) - self.assertTrue(len(msg.data()) > 5, f"msg.data() is {msg.data()}") + verify_undecrypted_message(msg, i) reader = client.create_reader(topic, MessageId.earliest, crypto_failure_action=pulsar.ConsumerCryptoFailureAction.CONSUME) - for i in range(3): + for i in range(4): msg = reader.read_next(3000) - self.assertNotEqual(msg.data(), f"msg-{i}".encode()) - self.assertTrue(len(msg.data()) > 5, f"msg.data() is {msg.data()}") + verify_undecrypted_message(msg, i) client.close()