Replies: 1 comment
-
Hey @erfanzar, to do this in JAX the easiest thing is just to create two different loss functions and calculate the gradients you need separately. While it may look wasteful, you will still get good performance because XLA […]
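A minimal sketch of that suggestion (the toy model, loss names, and parameter shapes here are made up for illustration, not EasyDel's actual API): define two loss functions that share the same forward pass, take a gradient of each, and `jit` the pair so XLA can deduplicate the shared computation.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model: a single linear layer.
def forward(params, x):
    return x @ params["w"] + params["b"]

# Two separate losses over the same forward pass.
def policy_loss(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

def value_loss(params, x, y):
    return jnp.mean(jnp.abs(forward(params, x) - y))

# One gradient function per loss; jitting the pair lets XLA's
# common-subexpression elimination share the duplicated forward work.
@jax.jit
def both_grads(params, x, y):
    g_policy = jax.grad(policy_loss)(params, x, y)
    g_value = jax.grad(value_loss)(params, x, y)
    return g_policy, g_value

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
x = jnp.ones((4, 3))
y = jnp.zeros((4, 2))
g_policy, g_value = both_grads(params, x, y)
```

Each returned gradient is a pytree with the same structure as `params`, so the two can be fed to two independent optimizer updates.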
-
I want to calculate two gradients for two different losses in the same function. I have a PPO function; I'm trying to implement RLHF for JAX/Flax models in my library EasyDel. For the learn function in my trainer: in PyTorch you usually just calculate the loss for the two models and call `.step()` and `.zero_grad()`, but when I calculate the loss and the policy loss I have no idea how to do that in JAX. I searched for this for a long time, but none of the current implementations were good enough (they were calculating or computing some algorithms twice), so I implemented my own method like this. This is my `TrainState` class, and this is my forward or step function.
You can also check the code here.