71 changes: 66 additions & 5 deletions src/distilabel/models/llms/mlx.py
@@ -28,10 +28,15 @@
validate_call,
)

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import LLM
from distilabel.models.llms.utils import compute_tokens, prepare_output
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.typing import GenerateOutput, StandardInput
from distilabel.typing import (
GenerateOutput,
OutlinesStructuredOutputType,
StandardInput,
)

if TYPE_CHECKING:
import mlx.nn as nn
@@ -46,6 +51,8 @@ class MlxLLM(LLM, MagpieChatTemplateMixin):
tokenizer_config: the tokenizer configuration.
mlx_model_config: the MLX model configuration.
adapter_path: the path to the adapter.
structured_output: a dictionary containing the structured output configuration or, if more
fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to `None`.
use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
template. Defaults to `False`.
magpie_pre_query_template: the pre-query template to be applied to the prompt or
@@ -69,15 +76,42 @@ class MlxLLM(LLM, MagpieChatTemplateMixin):
# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```

Generate structured data:

```python
from distilabel.models.llms import MlxLLM
from pydantic import BaseModel

class User(BaseModel):
first_name: str
last_name: str

llm = MlxLLM(
path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
structured_output={"format": "json", "schema": User},
)

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for John Smith"}]])
```
"""

path_or_hf_repo: str
tokenizer_config: Dict[str, Any] = Field(default_factory=dict)
mlx_model_config: Dict[str, Any] = Field(default_factory=dict)
adapter_path: Optional[str] = None
structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
default=None,
description="The structured output format to use across all the generations.",
)

_model: Optional["nn.Module"] = PrivateAttr(None)
_tokenizer: Optional["TokenizerWrapper"] = PrivateAttr(None)
_logits_processor: Union[Callable, None] = PrivateAttr(default=None)
_mlx_generate: Optional[Callable] = PrivateAttr(None)
_make_sampler: Optional[Callable] = PrivateAttr(None)

@@ -193,22 +227,29 @@ def generate( # type: ignore
min_tokens_to_keep=min_tokens_to_keep,
top_k=top_k,
)
structured_output = None
result = []
for input in inputs:
structured_output = None
if isinstance(input, tuple):
input, structured_output = input
elif self.structured_output:
structured_output = self.structured_output

output: List[str] = []
for _ in range(num_generations):
if structured_output: # will raise a NotImplementedError
self._prepare_structured_output(structured_output)
configured_processors = list(logits_processors or [])
if structured_output:
structured_processors = self._prepare_structured_output(
structured_output
)
configured_processors.append(structured_processors)

prompt = self.prepare_input(input)
generation = self._mlx_generate( # type: ignore
prompt=prompt,
model=self._model,
tokenizer=self._tokenizer,
logits_processors=logits_processors,
logits_processors=configured_processors,
max_tokens=max_tokens,
sampler=sampler,
max_kv_size=max_kv_size,
@@ -236,3 +277,23 @@ def generate( # type: ignore
)
)
return result

def _prepare_structured_output(
self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> Union[Callable, List[Callable]]:
"""Creates the appropriate function to filter tokens to generate structured outputs.

Args:
structured_output: the configuration dict to prepare the structured output.

Returns:
The callable that will be used to guide the generation of the model.
"""
from distilabel.steps.tasks.structured_outputs.outlines import (
prepare_guided_output,
)

result = prepare_guided_output(structured_output, "mlx", self)
if schema := result.get("schema"):
self.structured_output["schema"] = schema
return result["processor"]
9 changes: 7 additions & 2 deletions src/distilabel/steps/tasks/structured_outputs/outlines.py
@@ -35,10 +35,11 @@
from llama_cpp import Llama # noqa
from transformers import Pipeline # noqa
from vllm import LLM as _vLLM # noqa
import mlx.nn as nn # noqa

from distilabel.typing import OutlinesStructuredOutputType # noqa

Frameworks = Literal["transformers", "llamacpp"]
Frameworks = Literal["transformers", "llamacpp", "mlx"]


def _check_outlines_available() -> None:
@@ -59,7 +60,7 @@ def model_to_schema(schema: Type[BaseModel]) -> Dict[str, Any]:


def _create_outlines_model(
llm: Union["Pipeline", "Llama"],
llm: Union["Pipeline", "Llama", "nn.Module"],
framework: Frameworks,
) -> Any:
"""Create an outlines model wrapper for the given framework.
@@ -86,6 +87,10 @@
from outlines import from_llamacpp

return from_llamacpp(llm)
elif framework == "mlx":
from outlines import from_mlxlm

return from_mlxlm(llm._model, llm._tokenizer)


def prepare_guided_output(
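
For reference, a hedged standalone sketch of what the new `"mlx"` branch wires up: loading a model with `mlx_lm.load` and handing the resulting module and tokenizer wrapper to Outlines, the same pair `MlxLLM` stores in `_model` / `_tokenizer`. The model repository is a placeholder, and the exact call that turns this wrapper into a logits processor lives in `prepare_guided_output`, which the diff only shows partially.

```python
# Hedged sketch: standalone equivalent of the new "mlx" branch above. Assumes an
# Outlines release that exposes from_mlxlm and a machine with mlx-lm installed;
# the model repo is a placeholder.
from mlx_lm import load
from outlines import from_mlxlm

model, tokenizer = load("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")

# Mirrors from_mlxlm(llm._model, llm._tokenizer) in _create_outlines_model:
# MlxLLM keeps the raw mlx.nn.Module and TokenizerWrapper as private attributes
# and hands them to Outlines here.
outlines_model = from_mlxlm(model, tokenizer)

# prepare_guided_output() then builds a schema-constrained generator from this
# wrapper and returns its logits processor, which MlxLLM appends to the
# processors passed to mlx_lm's generate().
```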
46 changes: 45 additions & 1 deletion tests/unit/models/llms/test_mlx.py
@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import platform
from typing import Any, Dict, Generator

import pytest
from pydantic import BaseModel

from distilabel.models.llms.mlx import MlxLLM

@@ -63,6 +64,47 @@ def test_generate(self, llm: MlxLLM) -> None:
assert "input_tokens" in statistics
assert "output_tokens" in statistics

def test_structured_generation_json(self, llm: MlxLLM) -> None:
class User(BaseModel):
first_name: str
last_name: str

llm.structured_output = {"format": "json", "schema": User.model_json_schema()}

responses = llm.generate(
inputs=[
[{"role": "user", "content": "Create a user profile for John Smith"}],
],
num_generations=1,
)

assert len(responses) == 1
assert "generations" in responses[0]
assert "statistics" in responses[0]
generations = responses[0]["generations"]
assert len(generations) == 1

# Clean and parse the generation
for generation in generations:
# Remove the <|im_end|> tokens and clean up the string
cleaned_json = generation.replace("<|im_end|>", "").strip()
try:
user_data = json.loads(cleaned_json)
parsed_user = User(**user_data)
assert isinstance(parsed_user, User)
assert parsed_user.first_name == "John"
assert parsed_user.last_name == "Smith"
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}")
print(f"Raw generation: {cleaned_json}")
raise
except ValueError as e:
print(f"Validation error: {e}")
raise
statistics = responses[0]["statistics"]
assert "input_tokens" in statistics
assert "output_tokens" in statistics

@pytest.mark.parametrize(
"structured_output, dump",
[
@@ -74,6 +116,7 @@ def test_generate(self, llm: MlxLLM) -> None:
"structured_output": None,
"adapter_path": None,
"jobs_ids": None,
"mlx_model_config": {},
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"magpie_pre_query_template": None,
@@ -102,6 +145,7 @@ def test_generate(self, llm: MlxLLM) -> None:
},
"adapter_path": None,
"jobs_ids": None,
"mlx_model_config": {},
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {