Adding Pytest features #34
x = x.mean(dim=-2)
return x #
remove the weird hash here?
noted
## 2. ensure the output is not nan
assert not torch.isnan(res).all()
## 3. ensure the output shape is correct
assert res.shape == (model_cfg['batch_size'], model_cfg['context_window'], embedder.token_embedder.embedding_dim)
these aren't really testing the modelshell, they are testing the individual components...
which is fine, but they should be in separate test scripts, i.e. there should be test_embedder, test_core_model, and test_lm_head files
and these are just testing the forward pass. That is fine for the core model and lm head, but the embedder interface has quite a few methods (padding, truncating, inference vs forward, etc.) that should be checked
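One way such an embedder check could look is sketched below. The method names in `REQUIRED_METHODS` and the `DummyEmbedder` class are hypothetical placeholders, not the interface actually defined in the repo; the real names would come from embedding_models.py.

```python
# Hypothetical method names; the real list would mirror embedding_models.py.
REQUIRED_METHODS = ("forward", "inference", "pad_batch", "truncate")


class DummyEmbedder:
    """Hypothetical embedder used only to exercise the checks below."""

    def forward(self, token_ids):
        return token_ids

    def inference(self, text):
        # Toy tokenization: one id per character.
        return [ord(c) for c in text]

    def pad_batch(self, batch, pad_id=0):
        # Right-pad every sequence to the length of the longest one.
        width = max(len(seq) for seq in batch)
        return [seq + [pad_id] * (width - len(seq)) for seq in batch]

    def truncate(self, seq, max_len):
        return seq[:max_len]


def missing_methods(obj, names=REQUIRED_METHODS):
    """Return the required method names that obj does not implement."""
    return [n for n in names if not callable(getattr(obj, n, None))]


def test_embedder_interface():
    emb = DummyEmbedder()
    # Every required method exists and is callable.
    assert missing_methods(emb) == []
    # Padding makes all sequences the same length.
    padded = emb.pad_batch([[1, 2, 3], [4]])
    assert all(len(seq) == 3 for seq in padded)
    # Truncation keeps only the first max_len ids.
    assert emb.truncate([1, 2, 3, 4], 2) == [1, 2]
```

The same `missing_methods` helper could be parametrized over each concrete embedder class, so a newly added embedder that skips a method fails the test immediately.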
My basic feedback is this:
- the model_shell tests should just test the functions of the model shell interface itself. In particular, there should be 3 functions: test_loglikelihood, test_inference, and test_forward. This should be done with a MockEmbedder, MockTransformer, and MockHead that return (random) torch tensors of the appropriate shapes expected by the model.
- For building the embedder tests, they should make sure that the layer matches the interface defined in embedding_models.py, i.e. that all the necessary functions are implemented.
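A minimal sketch of the mock-based approach, assuming PyTorch is available. `StandInShell` is a hypothetical placeholder for the repo's model shell (whose constructor and methods are not shown in this thread), and the sizes `B, T, H, V` are arbitrary illustrative values.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: batch, sequence length, hidden size, vocabulary size.
B, T, H, V = 2, 8, 16, 32


class MockEmbedder(nn.Module):
    """Returns a random (B, T, H) tensor for (B, T) token ids."""

    def forward(self, token_ids):
        return torch.randn(token_ids.shape[0], token_ids.shape[1], H)


class MockTransformer(nn.Module):
    """Returns a random tensor with the same shape as its input."""

    def forward(self, x):
        return torch.randn_like(x)


class MockHead(nn.Module):
    """Returns random (B, T, V) logits for a (B, T, H) input."""

    def forward(self, x):
        return torch.randn(x.shape[0], x.shape[1], V)


class StandInShell(nn.Module):
    """Hypothetical stand-in for the model shell; swap in the real class."""

    def __init__(self, embedder, core_model, head):
        super().__init__()
        self.embedder = embedder
        self.core_model = core_model
        self.head = head

    def forward(self, token_ids):
        return self.head(self.core_model(self.embedder(token_ids)))


def test_forward():
    shell = StandInShell(MockEmbedder(), MockTransformer(), MockHead())
    logits = shell(torch.randint(0, V, (B, T)))
    # Only the shell's wiring is under test, so the shape check is enough.
    assert logits.shape == (B, T, V)
```

Because the mocks have no trained weights, a failure here points at the shell's own plumbing rather than at any individual component. test_inference and test_loglikelihood would follow the same pattern against the real shell's methods.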
Got it, I'll rework this PR.
…ore_model, lm_head. In addition, also tests the forward, inference and loglikelihood methods of the model_shell and byte_model shell.
Hey @DylanASHillier, I've added pytest tests for the methods of the model_shell (forward, inference, loglikelihood) and byte_model_shell (forward), and split the tests for the various types of embedding_model, core_model and lm_head into separate files. Let me know what needs to be fixed or revised. Meanwhile, I will add tests for the trainer modules next.
please refresh and resubmit
my bad
(Work in progress)
This is a pull request to add pytest tests for areas of our repo. The model_shell class's forward function calls the forward functions of the three objects passed to it: 1) an embedder, 2) a core model, and 3) a model head.
Generally...
...the embedder is a class object that takes in a tensor (B, T) and returns a tensor (B, T, H).
...the core model is a class object that takes in a tensor (B, T, H) and returns a tensor (B, T, H).
...the model head is a class object that takes in a tensor (B, T, H) and returns a tensor (B, T, V).
These tests ensure that the outputs have the correct shape and contain no NaN values.
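The shape contract above can be sketched as a small pytest, assuming PyTorch is available. The three modules here are plain `torch.nn` layers standing in for the repo's embedder, core model and model head, and the sizes are arbitrary; V is assumed to be the vocabulary size.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: batch, sequence length, hidden size, vocabulary size.
B, T, H, V = 2, 8, 16, 32

# Stand-ins for the repo's components (assumptions, not the real classes).
embedder = nn.Embedding(V, H)   # (B, T)    -> (B, T, H)
core_model = nn.Linear(H, H)    # (B, T, H) -> (B, T, H)
model_head = nn.Linear(H, V)    # (B, T, H) -> (B, T, V)


def check_output(out, expected_shape):
    """Assert the tensor has the expected shape and contains no NaN values."""
    assert tuple(out.shape) == expected_shape
    assert not torch.isnan(out).any()


def test_pipeline_shapes():
    token_ids = torch.randint(0, V, (B, T))
    x = embedder(token_ids)
    check_output(x, (B, T, H))
    x = core_model(x)
    check_output(x, (B, T, H))
    logits = model_head(x)
    check_output(logits, (B, T, V))
```

Note that `torch.isnan(out).any()` flags a single NaN element, whereas a check written with `.all()` only fails when every element is NaN.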