|
22 | 22 | # specific language governing permissions and limitations
|
23 | 23 | # under the License.
|
24 | 24 |
|
25 |
| -import contextvars |
26 | 25 | import logging
|
27 | 26 | import time
|
28 | 27 |
|
29 | 28 | import certifi
|
30 | 29 | import urllib3
|
31 | 30 | from urllib3.util.ssl_ import is_ipaddress
|
| 31 | +from osbenchmark.kafka_client import KafkaMessageProducer |
32 | 32 |
|
33 | 33 | from osbenchmark import exceptions, doc_link
|
| 34 | +from osbenchmark.context import RequestContextHolder |
34 | 35 | from osbenchmark.utils import console, convert
|
35 | 36 |
|
36 | 37 |
|
37 |
class RequestContextManager:
    """
    Ensures that request context span the defined scope and allow nesting of request contexts with proper propagation.
    This means that we can span a top-level request context, open sub-request contexts that can be used to measure
    individual timings and still measure the proper total time on the top-level request context.
    """
    def __init__(self, request_context_holder):
        # Holder that owns the underlying contextvar; see RequestContextHolder.
        self.ctx_holder = request_context_holder
        # The per-scope context dict and the contextvars reset token; both are
        # populated on __aenter__ and valid only inside the async-with scope.
        self.ctx = None
        self.token = None

    async def __aenter__(self):
        # Install a fresh, empty context; the token remembers the parent context
        # so __aexit__ can restore it.
        self.ctx, self.token = self.ctx_holder.init_request_context()
        return self

    @property
    def request_start(self):
        # Wire-level start of the first request recorded in this scope.
        # Raises KeyError if no request was recorded — callers are expected to
        # have issued at least one request within the scope.
        return self.ctx["request_start"]

    @property
    def request_end(self):
        # Latest recorded wire-level end that still falls strictly before the
        # client-side end. NOTE(review): ends at or after client_request_end are
        # deliberately excluded — presumably late completions that should not
        # count toward this logical request; confirm against metrics consumers.
        return max((value for value in self.ctx["request_end_list"] if value < self.client_request_end))

    @property
    def client_request_start(self):
        # Earliest client-observed start recorded in this scope.
        return self.ctx["client_request_start"]

    @property
    def client_request_end(self):
        # Most recent client-observed end recorded in this scope.
        return self.ctx["client_request_end"]

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # propagate earliest request start and most recent request end to parent
        # Read all timings BEFORE restoring the parent context: the properties
        # read self.ctx, but the update_* calls below write into whichever
        # context is active, which after restore_context is the parent's.
        client_request_start = self.client_request_start
        client_request_end = self.client_request_end
        request_start = self.request_start
        request_end = self.request_end
        self.ctx_holder.restore_context(self.token)
        # don't attempt to restore these values on the top-level context as they don't exist
        if self.token.old_value != contextvars.Token.MISSING:
            self.ctx_holder.update_request_start(request_start)
            self.ctx_holder.update_request_end(request_end)
            self.ctx_holder.update_client_request_start(client_request_start)
            self.ctx_holder.update_client_request_end(client_request_end)
        self.token = None
        # Returning False propagates any exception raised inside the scope.
        return False
83 |
| - |
84 |
| - |
85 |
class RequestContextHolder:
    """
    Holds request context variables. This class is only meant to be used together with RequestContextManager.
    """
    request_context = contextvars.ContextVar("benchmark_request_context")

    def new_request_context(self):
        """Open a new (possibly nested) request context scope for use with ``async with``."""
        return RequestContextManager(self)

    @classmethod
    def init_request_context(cls):
        """Install an empty context dict and return it together with its reset token."""
        fresh = {}
        return fresh, cls.request_context.set(fresh)

    @classmethod
    def restore_context(cls, token):
        """Re-activate the parent context captured by ``token``."""
        cls.request_context.reset(token)

    @classmethod
    def update_request_start(cls, new_request_start):
        """Record the wire-level request start, keeping only the first one."""
        ctx = cls.request_context.get()
        # this can happen if multiple requests are sent on the wire for one logical request (e.g. scrolls);
        # also require that the client-side start has already been recorded
        if "client_request_start" in ctx and "request_start" not in ctx:
            ctx["request_start"] = new_request_start

    @classmethod
    def update_request_end(cls, new_request_end):
        """Append a wire-level request end; all ends are kept for later filtering."""
        cls.request_context.get().setdefault("request_end_list", []).append(new_request_end)

    @classmethod
    def update_client_request_start(cls, new_client_request_start):
        """Record the client-side request start, keeping only the earliest one."""
        cls.request_context.get().setdefault("client_request_start", new_client_request_start)

    @classmethod
    def update_client_request_end(cls, new_client_request_end):
        """Record the client-side request end; the most recent value always wins."""
        cls.request_context.get()["client_request_end"] = new_client_request_end

    @classmethod
    def on_client_request_start(cls):
        """Timestamp hook: client-side request start."""
        cls.update_client_request_start(time.perf_counter())

    @classmethod
    def on_client_request_end(cls):
        """Timestamp hook: client-side request end."""
        cls.update_client_request_end(time.perf_counter())

    @classmethod
    def on_request_start(cls):
        """Timestamp hook: wire-level request start."""
        cls.update_request_start(time.perf_counter())

    @classmethod
    def on_request_end(cls):
        """Timestamp hook: wire-level request end."""
        cls.update_request_end(time.perf_counter())

    @classmethod
    def return_raw_response(cls):
        """Flag the current context so the raw response body is returned to the caller."""
        cls.request_context.get()["raw_response"] = True
149 |
| - |
150 |
| - |
151 | 38 | class OsClientFactory:
|
152 | 39 | """
|
153 | 40 | Abstracts how the OpenSearch client is created. Intended for testing.
|
@@ -430,3 +317,19 @@ def wait_for_rest_layer(opensearch, max_attempts=40):
|
430 | 317 | logger.warning("Got unexpected status code [%s] on attempt [%s].", e.status_code, attempt)
|
431 | 318 | raise e
|
432 | 319 | return False
|
| 320 | + |
| 321 | + |
class MessageProducerFactory:
    @staticmethod
    async def create(params):
        """
        Creates and returns a message producer based on the ingestion source.
        Currently supports Kafka. Ingestion source should be a dict like:
        {'type': 'kafka', 'param': {'topic': 'test', 'bootstrap-servers': 'localhost:34803'}}
        """
        source = params.get("ingestion-source", {})
        # Kafka is the default when no type is given; comparison is case-insensitive.
        kind = source.get("type", "kafka").lower()
        if kind != "kafka":
            raise ValueError(f"Unsupported ingestion source type: {kind}")
        # NOTE: the full params dict (not just the ingestion source) is forwarded.
        return await KafkaMessageProducer.create(params)
0 commit comments