|
1 | 1 | # SPDX-License-Identifier: BSD-3-Clause
|
2 | 2 | # Copyright (c) 2024 ScicatProject contributors (https://github.com/ScicatProject)
|
3 | 3 | import logging
|
| 4 | +from collections.abc import Generator |
4 | 5 |
|
5 | 6 | from confluent_kafka import Consumer
|
| 7 | +from streaming_data_types import deserialise_wrdn |
| 8 | +from streaming_data_types.finished_writing_wrdn import ( |
| 9 | + FILE_IDENTIFIER as WRDN_FILE_IDENTIFIER, |
| 10 | +) |
| 11 | +from streaming_data_types.finished_writing_wrdn import WritingFinished |
6 | 12 |
|
7 | 13 | from scicat_configuration import kafkaOptions
|
8 | 14 |
|
@@ -66,3 +72,58 @@ def validate_consumer(consumer: Consumer, logger: logging.Logger) -> bool:
|
66 | 72 | else:
|
67 | 73 | logger.info("Kafka consumer successfully instantiated")
|
68 | 74 | return True
|
| 75 | + |
| 76 | + |
| 77 | +def _validate_data_type(message_content: bytes, logger: logging.Logger) -> bool: |
| 78 | + logger.info("Data type: %s", (data_type := message_content[4:8])) |
| 79 | + if data_type == WRDN_FILE_IDENTIFIER: |
| 80 | + logger.info("WRDN message received.") |
| 81 | + return True |
| 82 | + else: |
| 83 | + logger.error("Unexpected data type: %s", data_type) |
| 84 | + return False |
| 85 | + |
| 86 | + |
| 87 | +def _filter_error_encountered( |
| 88 | + wrdn_content: WritingFinished, logger: logging.Logger |
| 89 | +) -> WritingFinished | None: |
| 90 | + """Filter out messages with the ``error_encountered`` flag set to True.""" |
| 91 | + if wrdn_content.error_encountered: |
| 92 | + logger.error( |
| 93 | + "``error_encountered`` flag True. " |
| 94 | + "Unable to deserialize message. Skipping the message." |
| 95 | + ) |
| 96 | + return wrdn_content |
| 97 | + else: |
| 98 | + return None |
| 99 | + |
| 100 | + |
| 101 | +def _deserialise_wrdn( |
| 102 | + message_content: bytes, logger: logging.Logger |
| 103 | +) -> WritingFinished | None: |
| 104 | + if _validate_data_type(message_content, logger): |
| 105 | + logger.info("Deserialising WRDN message") |
| 106 | + wrdn_content: WritingFinished = deserialise_wrdn(message_content) |
| 107 | + logger.info("Deserialised WRDN message: %.5000s", wrdn_content) |
| 108 | + return _filter_error_encountered(wrdn_content, logger) |
| 109 | + |
| 110 | + |
| 111 | +def wrdn_messages( |
| 112 | + consumer: Consumer, logger: logging.Logger |
| 113 | +) -> Generator[WritingFinished | None, None, None]: |
| 114 | + """Wait for a WRDN message and yield it. |
| 115 | +
|
| 116 | + Yield ``None`` if no message is received or an error is encountered. |
| 117 | + """ |
| 118 | + while True: |
| 119 | + # The decision to proceed or stop will be done by the caller. |
| 120 | + message = consumer.poll(timeout=1.0) |
| 121 | + if message is None: |
| 122 | + logger.info("Received no messages") |
| 123 | + yield None |
| 124 | + elif message.error(): |
| 125 | + logger.error("Consumer error: %s", message.error()) |
| 126 | + yield None |
| 127 | + else: |
| 128 | + logger.info("Received message.") |
| 129 | + yield _deserialise_wrdn(message.value(), logger) |
0 commit comments