-
Notifications
You must be signed in to change notification settings - Fork 2
feat(mock-server): add fallback tokenizer support #313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -40,10 +40,15 @@ async def lifespan(_: FastAPI): | |||||||||||||||||||
"""Initialize tokenizers and other startup tasks.""" | ||||||||||||||||||||
logger.info("Server configuration: %s", server_config.model_dump()) | ||||||||||||||||||||
|
||||||||||||||||||||
if server_config.tokenizer_models: | ||||||||||||||||||||
logger.info(f"Pre-loading tokenizer models: {server_config.tokenizer_models}") | ||||||||||||||||||||
tokenizer_service.load_tokenizers(server_config.tokenizer_models) | ||||||||||||||||||||
logger.info("Tokenizer models loaded successfully") | ||||||||||||||||||||
tokenizer_models = [ | ||||||||||||||||||||
*server_config.tokenizer_models, | ||||||||||||||||||||
server_config.fallback_tokenizer, | ||||||||||||||||||||
] | ||||||||||||||||||||
|
||||||||||||||||||||
logger.info(f"Pre-loading tokenizer models: {tokenizer_models}") | ||||||||||||||||||||
tokenizer_service.set_fallback_tokenizer(server_config.fallback_tokenizer) | ||||||||||||||||||||
tokenizer_service.load_tokenizers(tokenizer_models) | ||||||||||||||||||||
logger.info("Tokenizer models loaded successfully") | ||||||||||||||||||||
|
||||||||||||||||||||
yield | ||||||||||||||||||||
|
||||||||||||||||||||
|
@@ -71,6 +76,7 @@ def set_server_config(config: MockServerConfig) -> None: | |||||||||||||||||||
os.environ["MOCK_SERVER_PORT"] = str(config.port) | ||||||||||||||||||||
os.environ["MOCK_SERVER_WORKERS"] = str(config.workers) | ||||||||||||||||||||
os.environ["MOCK_SERVER_ACCESS_LOGS"] = str(config.access_logs) | ||||||||||||||||||||
os.environ["MOCK_SERVER_FALLBACK_TOKENIZER"] = str(config.fallback_tokenizer) | ||||||||||||||||||||
|
||||||||||||||||||||
|
||||||||||||||||||||
def extract_user_prompt(messages: list[ChatMessage]) -> str: | ||||||||||||||||||||
|
@@ -135,7 +141,10 @@ async def configure(request: ConfigureMessage): | |||||||||||||||||||
logger.info(f"Loading tokenizer models: {request.tokenizer_models}") | ||||||||||||||||||||
tokenizer_service.load_tokenizers(request.tokenizer_models) | ||||||||||||||||||||
logger.info("Tokenizer models loaded successfully") | ||||||||||||||||||||
|
||||||||||||||||||||
if request.fallback_tokenizer is not None: | ||||||||||||||||||||
tokenizer_service.load_tokenizers([request.fallback_tokenizer]) | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this always supposed to be called, even if the tokenizer specified exists? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah i need to add a check to see if its already been loaded. Code rabbit brought that up a few tim4es. |
||||||||||||||||||||
tokenizer_service.set_fallback_tokenizer(request.fallback_tokenizer) | ||||||||||||||||||||
logger.info(f"Fallback tokenizer set to {request.fallback_tokenizer}") | ||||||||||||||||||||
Comment on lines
+144
to
+147
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Persist fallback choice back into server_config We load and set the new fallback on the service, but if request.fallback_tokenizer is not None:
tokenizer_service.load_tokenizers([request.fallback_tokenizer])
tokenizer_service.set_fallback_tokenizer(request.fallback_tokenizer)
+ server_config.fallback_tokenizer = request.fallback_tokenizer
logger.info(f"Fallback tokenizer set to {request.fallback_tokenizer}") 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||
return {"status": "configured", "config": server_config.model_dump()} | ||||||||||||||||||||
|
||||||||||||||||||||
|
||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -2,10 +2,17 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||
# SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Tokenizer service for handling different model tokenizers.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
import contextlib | ||||||||||||||||||||||||||||||||||||||||||||||||||||
import io | ||||||||||||||||||||||||||||||||||||||||||||||||||||
import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
from transformers import AutoTokenizer | ||||||||||||||||||||||||||||||||||||||||||||||||||||
from transformers.tokenization_utils import PreTrainedTokenizer | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# Silence tokenizer warning on import and first use | ||||||||||||||||||||||||||||||||||||||||||||||||||||
with ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
contextlib.redirect_stdout(io.StringIO()) as _, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
contextlib.redirect_stderr(io.StringIO()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
from transformers import AutoTokenizer | ||||||||||||||||||||||||||||||||||||||||||||||||||||
from transformers.tokenization_utils import PreTrainedTokenizer | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
logger = logging.getLogger(__name__) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -15,6 +22,7 @@ class TokenizerService: | |||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self._tokenizers: dict[str, PreTrainedTokenizer] = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self._fallback_tokenizer: str | None = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
def load_tokenizers(self, model_names: list[str]) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Pre-load tokenizers for one or more models. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -34,7 +42,11 @@ def load_tokenizers(self, model_names: list[str]) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||
def get_tokenizer(self, model_name: str) -> PreTrainedTokenizer: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Get or create a tokenizer for the specified model.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if model_name not in self._tokenizers: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
raise ValueError(f"No tokenizer loaded for {model_name}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if self._fallback_tokenizer not in self._tokenizers: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
f"No tokenizer loaded for {model_name} or {self._fallback_tokenizer}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
model_name = self._fallback_tokenizer | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
44
to
50
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Restore lazy tokenizer loading before falling back. We now short-circuit to the fallback whenever the requested model key is missing, which means we never even try to lazily load the requested tokenizer anymore. In the current server flows that rely on lazy loading (run without preloading, runtime configure updates, etc.), this silently swaps responses to the fallback tokenizer or just raises when no fallback is configured—a regression from today’s behavior. Please attempt to load the requested tokenizer first and only fall back when that load really fails, while also bootstrapping the fallback if it hasn’t been loaded yet. - if model_name not in self._tokenizers:
- if self._fallback_tokenizer not in self._tokenizers:
- raise ValueError(
- f"No tokenizer loaded for {model_name} or {self._fallback_tokenizer}"
- )
- model_name = self._fallback_tokenizer
+ if model_name not in self._tokenizers:
+ try:
+ logger.info(f"Lazy-loading tokenizer for model: {model_name}")
+ self._tokenizers[model_name] = AutoTokenizer.from_pretrained(
+ model_name, trust_remote_code=True
+ )
+ except Exception as exc:
+ fallback = self._fallback_tokenizer
+ if not fallback:
+ raise ValueError(
+ f"No tokenizer loaded for {model_name}"
+ ) from exc
+ if fallback not in self._tokenizers:
+ logger.info(f"Lazy-loading fallback tokenizer: {fallback}")
+ self._tokenizers[fallback] = AutoTokenizer.from_pretrained(
+ fallback, trust_remote_code=True
+ )
+ model_name = fallback 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
return self._tokenizers[model_name] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -57,6 +69,10 @@ def count_tokens(self, text: str, model_name: str) -> int: | |||||||||||||||||||||||||||||||||||||||||||||||||||
tokenizer = self.get_tokenizer(model_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return len(tokenizer.encode(text, add_special_tokens=False)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
def set_fallback_tokenizer(self, fallback_tokenizer: str) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Set the fallback tokenizer to use if the requested tokenizer is not found.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self._fallback_tokenizer = fallback_tokenizer | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# Global tokenizer service instance | ||||||||||||||||||||||||||||||||||||||||||||||||||||
tokenizer_service = TokenizerService() |
Uh oh!
There was an error while loading. Please reload this page.