[WIP] Add RandAugment #154
base: v1.1.0
Conversation
@njit(parallel=True, fastmath=True, inline='always')
def equalize(source, scratch, destination):
    for i in prange(source.shape[-1]):
        scratch[i] = np.bincount(source[..., i].flatten(), minlength=256)
Unfortunate that np.bincount doesn't have an out argument...
A numba version should be pretty fast and relatively easy to implement, no? (It might even be faster, since it would skip the first pass of bincount that checks the min and max values.)
Yeah, good idea. I'll try to add that in the near future.
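For what it's worth, a minimal sketch of what such a numba bincount could look like, filling a preallocated per-channel scratch buffer so that np.bincount's extra pass over the data (and its output allocation) is avoided. The name bincount_into and the (channels, 256) scratch layout are assumptions for illustration, not code from this PR.

import numpy as np
from numba import njit, prange

@njit(parallel=True, fastmath=True)
def bincount_into(source, scratch):
    # source: H x W x C uint8 image; scratch: preallocated (C, 256) histogram buffer.
    for c in prange(source.shape[-1]):
        scratch[c, :] = 0
        channel = source[..., c].flatten()
        for j in range(channel.shape[0]):
            scratch[c, channel[j]] += 1

# The caller would allocate the buffer once and reuse it across images,
# instead of letting np.bincount allocate a fresh array every call.
scratch = np.zeros((3, 256), dtype=np.int64)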
I see that you are using
This is moving very fast 🚅 I was hoping to release v1.0.0 by the end of the week. Would you like it to be part of the release? If yes, could you change the target branch to
Sure, that sounds good. I'll try to finish a working demo soon.
This still needs some testing, but it looks promising. In a brief experiment on CIFAR-10, this RandAugment implementation added +1% test accuracy and cost about 0.1s/epoch. The first epoch takes substantially longer, adding about 20s, presumably because of the extra memory allocation.
I did install this and added it to a training pipeline I'm currently using. Got this error:

I then disabled jit by setting the env variable NUMBA_DISABLE_JIT=1 and got this more helpful error:

Seems like this is because I didn't call it with size set to the image size. I'll try that now. However, in the ImageNet training example the image size is scaled over the epochs. How would this work then?
Thanks for pointing this out. I'll look into automatically adapting the image size. Did it work after fixing the image size in your example?
Yes, it did work and added only a negligible slowdown! Good work!
Great, thanks! Hopefully your experiments work out. It looks like the
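To make the size discussion above concrete, here is a hedged sketch of how the transform might be wired into an ffcv image pipeline with the size kept in sync with the decoder output. The import path and the RandAugment arguments (num_ops, magnitude, size) are assumptions for illustration, not the PR's confirmed API.

from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import RandomResizedCropRGBImageDecoder
from ffcv.transforms import ToTensor, ToTorchImage
from ffcv.transforms import RandAugment  # hypothetical import path for this PR's transform

res = 224  # keep this equal to the decoder output size, per the discussion above
image_pipeline = [
    RandomResizedCropRGBImageDecoder((res, res)),
    RandAugment(num_ops=2, magnitude=9, size=(res, res)),  # assumed arguments
    ToTensor(),
    ToTorchImage(),
]

loader = Loader('train.beton', batch_size=256, num_workers=8,
                order=OrderOption.RANDOM,
                pipelines={'image': image_pipeline})

With the progressive-resizing schedule in the ImageNet example, res changes over epochs, so the size passed to the transform would presumably need to follow the current resolution (or be adapted automatically, as discussed above).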
This is awesome! When will it be merged into master / a release?
Thanks! I'm not sure -- the plan was to merge it with release v1.0.0 (#160), but as far as I know, development on that release has slowed down for the time being. |
Hi @ashertrockman! It seems I lost track of this PR a while ago - do you think it's feasible to merge into v1.1.0?
Yeah, I think it should be fine to merge.
state_allocation = operation.declare_shared_memory(state)

if next_state.device.type != 'cuda' and isinstance(operation,
if next_state.device != 'cuda' and isinstance(operation,
I think as of v1.0 the device will be a torch.device in which case we would want next_state.device.type?
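For reference, a small plain-PyTorch illustration (not PR code) of the torch.device attributes the comment above refers to:

import torch

dev = torch.device('cuda:0')
# dev is a torch.device, not a string; .type holds the plain device kind
# ('cuda', 'cpu', ...) and .index holds the ordinal.
print(dev.type)            # 'cuda'
print(dev.index)           # 0
print(dev.type != 'cuda')  # False, so the guard correctly skips CUDA states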
Great work on this @ashertrockman! I have had a successful training run with this fork (RandAugment) with ImageNet-1k val acc > 80 with a ViT-B/16 model. I think it would be valuable to merge this (and other similar augments like ColorJitter, Grayscale, 3Aug, etc.) because these are essential for any ViT runs.
Glad to hear!
By the way, if you're training ViTs, allow me to shamelessly promote my research: https://arxiv.org/abs/2305.09828 |
Here's a draft PR to make visible our efforts to add a fast implementation of RandAugment [1] to ffcv.
We currently have committed the following transforms:
Ideally, it would be nice to have tests to ensure that our transforms are similar to some baseline (I've currently chosen PyTorch's torchvision.transforms.functional as this baseline); a rough sketch of such a test appears after the reference below. Now that the transforms have been implemented, there are a few more things:
- njit to Compiler.compile
- np.bincount replacement

[1] https://arxiv.org/abs/1909.13719
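As a rough sketch of what such a baseline test could look like: ffcv_equalize below is a hypothetical stand-in for this PR's equalize transform, and the off-by-one tolerance is an assumption to absorb rounding differences, so this is illustrative rather than the PR's actual test.

import numpy as np
import torch
import torchvision.transforms.functional as F

def test_equalize_matches_torchvision():
    rng = np.random.default_rng(0)
    image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)

    # Baseline: torchvision's equalize expects a uint8 CHW tensor.
    expected = F.equalize(torch.from_numpy(image).permute(2, 0, 1))
    expected = expected.permute(1, 2, 0).numpy()

    # ffcv_equalize stands in for the PR's equalize transform (hypothetical name).
    actual = ffcv_equalize(image)

    # Allow off-by-one pixel values, since the two implementations may round
    # the cumulative histogram differently.
    assert np.abs(actual.astype(np.int16) - expected.astype(np.int16)).max() <= 1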