Expand tests to cover supported dtypes#29
Conversation
src/metrax/metrics_test.py
Outdated
| ).astype(np.float32) | ||
| OUTPUT_PREDS = np.random.uniform(size=(BATCHES, BATCH_SIZE)).astype(np.float32) | ||
| OUTPUT_PREDS = np.random.uniform(size=(BATCHES, BATCH_SIZE)) | ||
| OUTPUT_PREDS_F16 = OUTPUT_PREDS.astype(np.float16) |
There was a problem hiding this comment.
should we use jnp as much as we can? WDYT?
There was a problem hiding this comment.
SG, done. I think they are effectively the same in the end (I think JAX auto-promotes any np dtypes and arrays to jnp), but good to have consistency.
|
metrics_test.py was modified heavily in order to migrate from sklearn to keras. Do you want to sync your workspace and add the new parameters? |
51e56d1 to
9d64341
Compare
|
This pull request sets up GitHub code scanning for this repository. Once the scans have completed and the checks have passed, the analysis results for this pull request branch will appear on this overview. Once you merge this pull request, the 'Security' tab will show more code scanning analysis results (for example, for the default branch). Depending on your configuration and choice of analysis tool, future pull requests will be annotated with code scanning analysis results. For more information about GitHub code scanning, check out the documentation. |
9d64341 to
4a84234
Compare
I forgot to sync my fork 🤡 - redid the PR |
Includes float32, float16, and bfloat16
Fixes #22