-
Notifications
You must be signed in to change notification settings - Fork 1k
fix(analyzer): honour language_model_params in BasicLangExtractRecognizer #1943
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,14 +62,25 @@ def __init__( | |
|
|
||
| self.model_id = model_config.get("model_id") | ||
| self.provider = provider_config.get("name") | ||
| self.provider_kwargs = provider_config.get("kwargs", {}) | ||
| self.provider_kwargs = dict(provider_config.get("kwargs", {})) | ||
|
|
||
| # Not ideal, but update _extract_params now that self.config is fully loaded. | ||
| self._extract_params.update(provider_config.get("extract_params", {})) | ||
| self._language_model_params.update( | ||
| provider_config.get("language_model_params", {}) | ||
| ) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| # Surface language_model_params on the ModelConfig itself. | ||
| # langextract.extract() honours `language_model_params` only when | ||
| # `config` is NOT passed (see langextract/extraction.py elif config: | ||
| # branch). Because _get_provider_params() returns a pre-built | ||
| # ModelConfig, values like `timeout` and `num_ctx` would otherwise be | ||
| # silently dropped. Merge them into provider_kwargs so they reach | ||
| # the provider constructor (e.g. OllamaLanguageModel(timeout=...)). | ||
| # `setdefault` ensures explicit `provider.kwargs:` entries always win. | ||
| for key, value in self._language_model_params.items(): | ||
| self.provider_kwargs.setdefault(key, value) | ||
|
|
||
| if not self.provider: | ||
| raise ValueError("Configuration must contain " | ||
| "'langextract.model.provider.name'") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -574,3 +574,68 @@ def test_when_analyze_called_then_params_passed_to_langextract(self, tmp_path): | |
| assert call_kwargs["config"].provider == "ollama" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we add a test where |
||
| assert call_kwargs["config"].provider_kwargs["model_url"] == "http://localhost:11434" | ||
|
|
||
| # Regression: language_model_params must also surface on | ||
| # ModelConfig.provider_kwargs, because langextract.extract() | ||
| # ignores `language_model_params` when `config` is passed and | ||
| # only reads values from ModelConfig.provider_kwargs in that | ||
| # branch. Without this, `timeout` and `num_ctx` are silently | ||
| # dropped and Ollama falls back to its 120s default. | ||
| assert call_kwargs["config"].provider_kwargs["timeout"] == 180 | ||
| assert call_kwargs["config"].provider_kwargs["num_ctx"] == 8192 | ||
|
|
||
| def test_language_model_params_reach_provider_kwargs(self, tmp_path): | ||
| """Regression test: values under provider.language_model_params in the | ||
| yaml must end up on ModelConfig.provider_kwargs so they actually reach | ||
| the provider constructor (e.g. OllamaLanguageModel(timeout=...)). | ||
|
|
||
| Prior to the fix, BasicLangExtractRecognizer only copied | ||
| provider.kwargs onto ModelConfig.provider_kwargs, and | ||
| provider.language_model_params was forwarded to lx.extract() as a | ||
| separate argument — but langextract.extract() ignores that argument | ||
| when config is passed directly, so values like `timeout` and | ||
| `num_ctx` were silently dropped. | ||
| """ | ||
| import yaml | ||
|
|
||
| config = create_test_config() | ||
| config["langextract"]["model"]["provider"]["language_model_params"]["timeout"] = 600 | ||
| config["langextract"]["model"]["provider"]["language_model_params"]["num_ctx"] = 16384 | ||
|
|
||
| config_file = tmp_path / "test_config.yaml" | ||
| with open(config_file, 'w') as f: | ||
| yaml.dump(config, f) | ||
|
|
||
| with patch('presidio_analyzer.llm_utils.langextract_helper.lx', | ||
| return_value=Mock()): | ||
| from presidio_analyzer.predefined_recognizers.third_party.basic_langextract_recognizer import BasicLangExtractRecognizer | ||
| recognizer = BasicLangExtractRecognizer(config_path=str(config_file)) | ||
|
|
||
| provider_kwargs = recognizer._get_provider_params()["config"].provider_kwargs | ||
| assert provider_kwargs["timeout"] == 600 | ||
| assert provider_kwargs["num_ctx"] == 16384 | ||
|
|
||
| def test_provider_kwargs_take_precedence_over_language_model_params(self, tmp_path): | ||
| """Explicit `provider.kwargs:` entries must win over values of the | ||
| same name under `provider.language_model_params:`. This preserves | ||
| backward compatibility for configs that already place timeout in | ||
| `kwargs:` as a workaround. | ||
| """ | ||
| import yaml | ||
|
|
||
| config = create_test_config() | ||
| config["langextract"]["model"]["provider"]["kwargs"]["timeout"] = 900 | ||
| config["langextract"]["model"]["provider"]["language_model_params"]["timeout"] = 60 | ||
|
|
||
| config_file = tmp_path / "test_config.yaml" | ||
| with open(config_file, 'w') as f: | ||
| yaml.dump(config, f) | ||
|
|
||
| with patch('presidio_analyzer.llm_utils.langextract_helper.lx', | ||
| return_value=Mock()): | ||
| from presidio_analyzer.predefined_recognizers.third_party.basic_langextract_recognizer import BasicLangExtractRecognizer | ||
| recognizer = BasicLangExtractRecognizer(config_path=str(config_file)) | ||
|
|
||
| provider_kwargs = recognizer._get_provider_params()["config"].provider_kwargs | ||
| # The explicit kwargs: value wins | ||
| assert provider_kwargs["timeout"] == 900 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`dict(provider_config.get("kwargs", {}))` will raise a `TypeError` if the YAML contains `kwargs: null` (or an empty `kwargs:` key). Consider using `dict(provider_config.get("kwargs") or {})` (or otherwise normalizing/validating the value) so missing/empty kwargs are treated as an empty mapping and the error message remains actionable for users.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree with Copilot.