22 | 22 | from opentelemetry.sdk.trace import TracerProvider |
23 | 23 | from opentelemetry.sdk.trace.export import BatchSpanProcessor |
24 | 24 | from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter |
| 25 | +import openlit |
| 26 | +from transformers import AutoTokenizer |
25 | 27 |
|
26 | 28 | set_verbose(True) |
27 | 29 |
|
|
31 | 33 | otlp_endpoint = os.environ.get("OTLP_ENDPOINT", False) |
32 | 34 |
|
33 | 35 | # Initialize OpenTelemetry |
34 | | -trace.set_tracer_provider(TracerProvider()) |
35 | | -tracer = trace.get_tracer(__name__) |
| 36 | +if not isinstance(trace.get_tracer_provider(), TracerProvider): |
| 37 | + tracer_provider = TracerProvider() |
| 38 | + trace.set_tracer_provider(tracer_provider) |
| 39 | + |
| 40 | + # Set up OTLP exporter and span processor |
| 41 | + if not otlp_endpoint: |
| 42 | + logging.warning("No OTLP endpoint provided - Telemetry data will not be collected.") |
| 43 | + else: |
| 44 | + otlp_exporter = OTLPSpanExporter() |
| 45 | + span_processor = BatchSpanProcessor(otlp_exporter) |
| 46 | + tracer_provider.add_span_processor(span_processor) |
| 47 | + |
| 48 | + openlit.init( |
| 49 | + otlp_endpoint=otlp_endpoint, |
| 50 | + application_name=os.environ.get("OTEL_SERVICE_NAME", "chatqna"), |
| 51 | + environment=os.environ.get("OTEL_SERVICE_ENV", "chatqna"), |
| 52 | + ) |
36 | 53 |
|
37 | | -if otlp_endpoint: |
38 | | - otlp_exporter = OTLPSpanExporter() |
39 | | - span_processor = BatchSpanProcessor(otlp_exporter) |
40 | | - trace.get_tracer_provider().add_span_processor(span_processor) |
| 54 | + logging.info(f"Tracing enabled: OpenTelemetry configured using OTLP endpoint at {otlp_endpoint}") |
41 | 55 |
|
42 | 56 | PG_CONNECTION_STRING = os.getenv("PG_CONNECTION_STRING") |
43 | 57 | MODEL_NAME = os.getenv("EMBEDDING_MODEL","BAAI/bge-small-en-v1.5") |
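
For reference, a minimal sketch of the same guarded tracer setup with the endpoint passed to the exporter explicitly; note that the bare `OTLPSpanExporter()` above falls back to the `OTEL_EXPORTER_OTLP_*` environment variables rather than the custom `OTLP_ENDPOINT` value. The `configure_tracing` helper and the `/v1/traces` path are illustrative assumptions, not part of this change:

```python
# Sketch only: wire OTLP_ENDPOINT into the exporter directly instead of
# relying on the OTEL_EXPORTER_OTLP_* environment variables.
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter

def configure_tracing(otlp_endpoint):
    # Skip re-initialisation if a provider is already installed,
    # mirroring the isinstance() guard in the diff above.
    if isinstance(trace.get_tracer_provider(), TracerProvider):
        return
    provider = TracerProvider()
    trace.set_tracer_provider(provider)
    if otlp_endpoint:
        # "/v1/traces" is the conventional OTLP/HTTP traces path (assumption).
        exporter = OTLPSpanExporter(endpoint=f"{otlp_endpoint}/v1/traces")
        provider.add_span_processor(BatchSpanProcessor(exporter))
```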
|
88 | 102 | prompt = ChatPromptTemplate.from_template(template) |
89 | 103 |
|
90 | 104 | ENDPOINT_URL = os.getenv("ENDPOINT_URL", "http://localhost:8080") |
| 105 | + |
| 106 | +# Check which LLM inference backend is being used |
| 107 | +LLM_BACKEND = None |
| 108 | +if "ovms" in ENDPOINT_URL.lower(): |
| 109 | + LLM_BACKEND = "ovms" |
| 110 | +elif "text-generation" in ENDPOINT_URL.lower(): |
| 111 | + LLM_BACKEND = "text-generation" |
| 112 | +elif "vllm" in ENDPOINT_URL.lower(): |
| 113 | + LLM_BACKEND = "vllm" |
| 114 | +else: |
| 115 | + LLM_BACKEND = "unknown" |
| 116 | + |
| 117 | +logging.info(f"Using LLM inference backend: {LLM_BACKEND}") |
91 | 118 | LLM_MODEL = os.getenv("LLM_MODEL", "Intel/neural-chat-7b-v3-3") |
92 | 119 | RERANKER_ENDPOINT = os.getenv("RERANKER_ENDPOINT", "http://localhost:9090/rerank") |
93 | 120 | callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()] |
94 | | - |
95 | | - |
96 | | -model = EGAIModelServing( |
97 | | - openai_api_key="EMPTY", |
98 | | - openai_api_base="{}".format(ENDPOINT_URL), |
99 | | - model_name=LLM_MODEL, |
100 | | - top_p=0.99, |
101 | | - temperature=0.01, |
102 | | - streaming=True, |
103 | | - callbacks=callbacks, |
104 | | -) |
105 | | - |
106 | | -re_ranker = CustomReranker(reranking_endpoint=RERANKER_ENDPOINT) |
107 | | -re_ranker_lambda = RunnableLambda(re_ranker.rerank) |
108 | | - |
109 | | -# RAG Chain |
110 | | -chain = ( |
111 | | - RunnableParallel({"context": retriever, "question": RunnablePassthrough()}) |
112 | | - | re_ranker_lambda |
113 | | - | prompt |
114 | | - | model |
115 | | - | StrOutputParser() |
116 | | -) |
117 | | - |
118 | | - |
119 | | -async def process_chunks(question_text): |
| 121 | +tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL) |
| 122 | + |
| 123 | +async def process_chunks(question_text, max_tokens): |
| 124 | + if LLM_BACKEND in ["vllm", "unknown"]: |
| 125 | + seed_value = None |
| 126 | + else: |
| 127 | + seed_value = int(os.getenv("SEED", 42)) |
| 128 | + tokens = tokenizer.tokenize(str(prompt)) |
| 129 | + num_tokens = len(tokens) |
| 130 | + logging.info(f"Prompt tokens for model {LLM_MODEL}: {num_tokens}") |
| 131 | + output_tokens = max_tokens - num_tokens |
| 132 | + logging.info(f"Output tokens for model {LLM_MODEL}: {output_tokens}") |
| 133 | + model = EGAIModelServing( |
| 134 | + openai_api_key="EMPTY", |
| 135 | + openai_api_base=ENDPOINT_URL, |
| 136 | + model_name=LLM_MODEL, |
| 137 | + top_p=0.99, |
| 138 | + temperature=0.01, |
| 139 | + streaming=True, |
| 140 | + callbacks=callbacks, |
| 141 | + seed=seed_value, |
| 142 | + max_tokens=max_tokens, |
| 143 | + stop=["\n\n"] |
| 144 | + ) |
| 145 | + |
| 146 | + re_ranker = CustomReranker(reranking_endpoint=RERANKER_ENDPOINT) |
| 147 | + re_ranker_lambda = RunnableLambda(re_ranker.rerank) |
| 148 | + |
| 149 | + # RAG Chain |
| 150 | + chain = ( |
| 151 | + RunnableParallel({"context": retriever, "question": RunnablePassthrough()}) |
| 152 | + | re_ranker_lambda |
| 153 | + | prompt |
| 154 | + | model |
| 155 | + | StrOutputParser() |
| 156 | + ) |
| 157 | + # Run the chain with the question text |
120 | 158 | async for log in chain.astream(question_text): |
121 | 159 | yield f"data: {log}\n\n" |
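
A hypothetical caller for the refactored generator, assuming a FastAPI-style server; the route, request model, and the 1024-token default are illustrative, not part of this change:

```python
# Sketch only: streaming the SSE-formatted chunks back to the client.
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()

class ChatRequest(BaseModel):
    question: str
    max_tokens: int = 1024  # assumed total budget for prompt + completion

@app.post("/v1/chatqna")
async def chatqna(req: ChatRequest):
    # process_chunks already yields "data: ...\n\n" lines, so it can be
    # returned directly as a server-sent event stream.
    return StreamingResponse(
        process_chunks(req.question, req.max_tokens),
        media_type="text/event-stream",
    )
```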