Add library specific handling of tensors in logits processors #1498
Conversation
Force-pushed from 6304811 to 6217d08
I was originally skeptical of the tensor handler approach but it's growing on me. I would use the name …
🎉
Force-pushed from 6217d08 to 8762761
I changed the name from …
Force-pushed from 8762761 to 733eeed
Force-pushed from 71dc231 to 6b250b4
outlines/models/llamacpp.py (Outdated)

- def __init__(self, model: "Llama"):
+ def __init__(self, model: "Llama", tensor_library_name: Optional[str] = None):
Isn't tensor_library_name for llama-cpp-python necessarily numpy?
outlines/models/mlxlm.py (Outdated)

  def __init__(
      self,
      model: "nn.Module",
      tokenizer: "PreTrainedTokenizer",
+     tensor_library_name: Optional[str] = None,
Isn't it necessarily mlx?
outlines/models/transformers.py (Outdated)

@@ -163,11 +163,13 @@ def format_output_type(self, output_type):

  class Transformers(Model):
We may be able to do this automatically. It looks like all models implemented with JAX inherit from FlaxPreTrainedModel.
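A minimal sketch of the automatic detection this comment suggests; the helper name infer_tensor_library_name is hypothetical, not code from this PR:

```python
# Hypothetical helper: infer the tensor library from the model's base class.
def infer_tensor_library_name(model) -> str:
    from transformers import FlaxPreTrainedModel

    # JAX/Flax models all inherit from FlaxPreTrainedModel
    if isinstance(model, FlaxPreTrainedModel):
        return "jax"
    # PyTorch models inherit from PreTrainedModel, the default case
    return "torch"
```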
outlines/models/vllm.py (Outdated)

@@ -139,8 +144,8 @@ def generate_stream(self, model_input, output_type, **inference_kwargs):
      )

- def from_vllm(model: "LLM") -> VLLM:
-     return VLLM(model)
+ def from_vllm(model: "LLM", tensor_library_name: Optional[str] = None) -> VLLM:
Isn't it necessarily torch? I don't think vLLM supports any other backend.
@@ -40,7 +39,6 @@ dependencies = [
      "typing_extensions",
      "iso3166",
      "airportsdata",
-     "torch",
🥳
Looks good overall. The main comment I have is that I don't think any library other than transformers supports more than one backend, and even for transformers there may be a way to determine the backend automatically.
Force-pushed from 6b250b4 to 407eddf
Yes, you're right. I completely removed the possibility of specifying the tensor library to use when initializing a model, as there's no case in which it's necessary. As a result I could simplify the logic of the …
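A minimal sketch of what hardcoding the tensor library per model looks like; the attribute name matches the PR, but the constructor bodies are illustrative only:

```python
# Illustrative only: each local model pins the one backend its library uses,
# so no tensor_library_name parameter is exposed to the user.
class LlamaCpp:
    def __init__(self, model):
        self.model = model
        self.tensor_library_name = "numpy"  # llama-cpp-python returns numpy arrays

class MLXLM:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.tensor_library_name = "mlx"  # mlx-lm only works with mlx arrays

class VLLM:
    def __init__(self, model):
        self.model = model
        self.tensor_library_name = "torch"  # vLLM only supports torch
```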
Force-pushed from 407eddf to 2aaa146
outlines/models/transformers.py (Outdated)

@@ -185,11 +185,18 @@ def __init__(
          the `transformers` API for tokenizers.
          """
+         from transformers import FlaxPreTrainedModel
It looks like this requires flax to be installed :/ Let's see if we can do this another way.
Can't we just add flax to the test dependencies? We already have jax there.
It will still fail whenever users try to initialise a transformers model. We could use try/except for the checks (also add TFPreTrainedModel).
Ah yes, right
I've added a TensorAdapter for tensorflow.
Force-pushed from 2aaa146 to 1d00c9a
Force-pushed from 3318a42 to 0951993
Force-pushed from 60f96bb to ae46cff
Do we have tests that initialize a …
Force-pushed from ae46cff to ed9b572
Yes. The issue was that there are lines in the …
This PR addresses issue #1445.

The PR creates classes of type TensorHandler for the various tensor libraries. These classes are used by the logits processors to avoid having to manipulate the input_ids/logits tensors themselves. This solution means we do not require users to install all tensor libraries supported by outlines, and it lets us handle tensors in their native type without turning them into torch tensors. We also modify the local models to carry an attribute tensor_library_name that tells the logits processor which TensorHandler implementation to use. Thanks to these changes, we remove the mandatory dependencies on torch and numpy.