Skip to content

Commit 824c7ea

Browse files
authored
Merge pull request #1 from muxi-ai/fix/event-loop-management
fix: resolve event loop management issues for async/sync interoperability
2 parents ae2d0cd + 649fccc commit 824c7ea

File tree

9 files changed

+395
-80
lines changed

9 files changed

+395
-80
lines changed

onellm/audio.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
with support for fallback models if the primary model fails.
2727
"""
2828

29-
import asyncio
3029
from typing import Any, Dict, IO, List, Optional, Union
3130

3231
from .providers.base import get_provider_with_fallbacks
3332
from .utils.fallback import FallbackConfig
33+
from .utils.async_helpers import run_async
3434

3535
class AudioTranscription:
3636
"""
@@ -113,8 +113,8 @@ def create_sync(
113113
Returns:
114114
Transcription result
115115
"""
116-
# Use asyncio.run to execute the async create method in a new event loop
117-
return asyncio.run(
116+
# Use our safe async runner to execute the async create method
117+
return run_async(
118118
cls.create(
119119
file=file,
120120
model=model,
@@ -204,8 +204,8 @@ def create_sync(
204204
Returns:
205205
Translation result with text in English
206206
"""
207-
# Use asyncio.run to execute the async create method in a new event loop
208-
return asyncio.run(
207+
# Use our safe async runner to execute the async create method
208+
return run_async(
209209
cls.create(
210210
file=file,
211211
model=model,

onellm/chat_completion.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,21 @@
2525
completions from various providers in a manner compatible with OpenAI's API.
2626
"""
2727

28-
import asyncio
2928
import logging
3029
import warnings
3130
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
3231

33-
from .providers.base import get_provider_with_fallbacks
32+
from .providers.base import get_provider_with_fallbacks, parse_model_name
3433
from .models import ChatCompletionResponse, ChatCompletionChunk
3534
from .utils.fallback import FallbackConfig
36-
from .validators import validate_model_name, validate_messages, validate_stream
35+
from .utils.async_helpers import run_async
36+
from .validators import (
37+
validate_model_name,
38+
validate_messages,
39+
validate_stream,
40+
validate_chat_params,
41+
validate_provider_model,
42+
)
3743

3844
class ChatCompletion:
3945
"""Class for creating chat completions with various providers."""
@@ -262,11 +268,20 @@ def create(
262268
validate_model_name(model)
263269
validate_messages(messages)
264270
validate_stream(stream)
271+
272+
# Validate all parameters
273+
validate_chat_params(**kwargs)
274+
275+
# Parse model and validate for provider
276+
provider_name, model_without_prefix = parse_model_name(model)
277+
validate_provider_model(model_without_prefix, provider_name)
265278

266279
# Validate fallback models if provided
267280
if fallback_models:
268281
for i, fallback_model in enumerate(fallback_models):
269282
validate_model_name(fallback_model)
283+
fb_provider, fb_model = parse_model_name(fallback_model)
284+
validate_provider_model(fb_model, fb_provider)
270285

271286
# Process fallback configuration
272287
fb_config = None
@@ -294,24 +309,13 @@ def create(
294309
provider, messages, stream, kwargs
295310
)
296311

297-
# Call the provider's method synchronously
298-
if stream:
299-
# For streaming, we need to use async properly
300-
# Create a new event loop to run the async code
301-
loop = asyncio.new_event_loop()
302-
asyncio.set_event_loop(loop)
303-
return loop.run_until_complete(
304-
provider.create_chat_completion(
305-
messages=messages, model=model_name, stream=stream, **processed_kwargs
306-
)
307-
)
308-
else:
309-
# For non-streaming, we can just run and get the result
310-
return asyncio.run(
311-
provider.create_chat_completion(
312-
messages=messages, model=model_name, stream=stream, **processed_kwargs
313-
)
312+
# Call the provider's method synchronously using our safe async runner
313+
# This handles edge cases like Jupyter notebooks, existing event loops, etc.
314+
return run_async(
315+
provider.create_chat_completion(
316+
messages=messages, model=model_name, stream=stream, **processed_kwargs
314317
)
318+
)
315319

316320
@classmethod
317321
async def acreate(
@@ -359,11 +363,20 @@ async def acreate(
359363
validate_model_name(model)
360364
validate_messages(messages)
361365
validate_stream(stream)
366+
367+
# Validate all parameters
368+
validate_chat_params(**kwargs)
369+
370+
# Parse model and validate for provider
371+
provider_name, model_without_prefix = parse_model_name(model)
372+
validate_provider_model(model_without_prefix, provider_name)
362373

363374
# Validate fallback models if provided
364375
if fallback_models:
365376
for i, fallback_model in enumerate(fallback_models):
366377
validate_model_name(fallback_model)
378+
fb_provider, fb_model = parse_model_name(fallback_model)
379+
validate_provider_model(fb_model, fb_provider)
367380

368381
# Process fallback configuration
369382
fb_config = None

onellm/completion.py

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,19 @@
2525
completions from various providers in a manner compatible with OpenAI's API.
2626
"""
2727

28-
import asyncio
2928
from typing import Any, AsyncGenerator, List, Optional, Union
3029

31-
from .providers.base import get_provider_with_fallbacks
30+
from .providers.base import get_provider_with_fallbacks, parse_model_name
3231
from .models import CompletionResponse
3332
from .utils.fallback import FallbackConfig
33+
from .utils.async_helpers import run_async
34+
from .validators import (
35+
validate_model_name,
36+
validate_prompt,
37+
validate_stream,
38+
validate_completion_params,
39+
validate_provider_model,
40+
)
3441

3542
class Completion:
3643
"""Class for creating text completions with various providers."""
@@ -76,9 +83,24 @@ def create(
7683
... )
7784
>>> print(response.choices[0].text)
7885
"""
79-
# Validate prompt
80-
if not prompt or not prompt.strip():
81-
raise ValueError("Prompt cannot be empty")
86+
# Validate inputs
87+
validate_model_name(model)
88+
validate_prompt(prompt)
89+
validate_stream(stream)
90+
91+
# Validate all parameters
92+
validate_completion_params(**kwargs)
93+
94+
# Parse model and validate for provider
95+
provider_name, model_without_prefix = parse_model_name(model)
96+
validate_provider_model(model_without_prefix, provider_name)
97+
98+
# Validate fallback models if provided
99+
if fallback_models:
100+
for fallback_model in fallback_models:
101+
validate_model_name(fallback_model)
102+
fb_provider, fb_model = parse_model_name(fallback_model)
103+
validate_provider_model(fb_model, fb_provider)
82104

83105
# Process fallback configuration
84106
fb_config = None
@@ -104,25 +126,13 @@ def create(
104126
fallback_config=fb_config,
105127
)
106128

107-
# Call the provider's method synchronously
108-
if stream:
109-
# For streaming, we need to use async properly
110-
# Create a new event loop to run the async code in a synchronous context
111-
loop = asyncio.new_event_loop()
112-
asyncio.set_event_loop(loop)
113-
return loop.run_until_complete(
114-
provider.create_completion(
115-
prompt=prompt, model=model_name, stream=stream, **kwargs
116-
)
117-
)
118-
else:
119-
# For non-streaming, we can just run and get the result
120-
# asyncio.run creates a new event loop, runs the coroutine, and closes the loop
121-
return asyncio.run(
122-
provider.create_completion(
123-
prompt=prompt, model=model_name, stream=stream, **kwargs
124-
)
129+
# Call the provider's method synchronously using our safe async runner
130+
# This handles edge cases like Jupyter notebooks, existing event loops, etc.
131+
return run_async(
132+
provider.create_completion(
133+
prompt=prompt, model=model_name, stream=stream, **kwargs
125134
)
135+
)
126136

127137
@classmethod
128138
async def acreate(
@@ -166,9 +176,24 @@ async def acreate(
166176
... )
167177
>>> print(response.choices[0].text)
168178
"""
169-
# Validate prompt
170-
if not prompt or not prompt.strip():
171-
raise ValueError("Prompt cannot be empty")
179+
# Validate inputs
180+
validate_model_name(model)
181+
validate_prompt(prompt)
182+
validate_stream(stream)
183+
184+
# Validate all parameters
185+
validate_completion_params(**kwargs)
186+
187+
# Parse model and validate for provider
188+
provider_name, model_without_prefix = parse_model_name(model)
189+
validate_provider_model(model_without_prefix, provider_name)
190+
191+
# Validate fallback models if provided
192+
if fallback_models:
193+
for fallback_model in fallback_models:
194+
validate_model_name(fallback_model)
195+
fb_provider, fb_model = parse_model_name(fallback_model)
196+
validate_provider_model(fb_model, fb_provider)
172197

173198
# Process fallback configuration
174199
fb_config = None

onellm/embedding.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@
2525
from various providers in a manner compatible with OpenAI's API.
2626
"""
2727

28-
import asyncio
2928
from typing import List, Optional, Union
3029

3130
from .providers.base import get_provider_with_fallbacks
3231
from .models import EmbeddingResponse
3332
from .utils.fallback import FallbackConfig
33+
from .utils.async_helpers import run_async
3434
from .errors import InvalidRequestError
3535

3636
def validate_embedding_input(input_data: Union[str, List[str]]) -> None:
@@ -109,8 +109,8 @@ def create(
109109
fallback_config=fb_config,
110110
)
111111

112-
# Call the provider's method synchronously by running the async method in an event loop
113-
return asyncio.run(
112+
# Call the provider's method synchronously using our safe async runner
113+
return run_async(
114114
provider.create_embedding(input=input, model=model_name, **kwargs)
115115
)
116116

0 commit comments

Comments (0)