@@ -1113,17 +1113,21 @@ std::pair<std::vector<array>, std::vector<int>> Cosh::vmap(
11131113 return {{cosh (inputs[0 ], stream ())}, axes};
11141114}
11151115
1116- std::vector<array> CustomVJP ::vjp (
1116+ std::vector<array> CustomTransforms ::vjp (
11171117 const std::vector<array>& primals,
11181118 const std::vector<array>& cotangents,
11191119 const std::vector<int >& argnums,
11201120 const std::vector<array>& outputs) {
1121- std::vector<array> inputs (primals.begin (), primals.end () - outputs.size ());
1121+ // Extract the inputs to the VJP function
1122+ std::vector<array> inputs (primals.begin (), primals.end () - num_outputs_);
1123+
1124+ // Compute all the vjps
11221125 auto all_vjps = vjp_fun_ (inputs, cotangents, outputs);
11231126 for (const auto & cot : cotangents) {
11241127 all_vjps.emplace_back (cot);
11251128 }
11261129
1130+ // Select the vjps requested
11271131 std::vector<array> vjps;
11281132 vjps.reserve (argnums.size ());
11291133 for (auto arg : argnums) {
@@ -1133,6 +1137,26 @@ std::vector<array> CustomVJP::vjp(
11331137 return vjps;
11341138}
11351139
1140+ std::vector<array> CustomTransforms::jvp (
1141+ const std::vector<array>& primals,
1142+ const std::vector<array>& tangents,
1143+ const std::vector<int >& argnums) {
1144+ // Extract the inputs to the JVP function
1145+ std::vector<array> inputs (primals.begin (), primals.end () - num_outputs_);
1146+
1147+ // Compute the jvps
1148+ return jvp_fun_ (inputs, tangents, argnums);
1149+ }
1150+
1151+ std::pair<std::vector<array>, std::vector<int >> CustomTransforms::vmap (
1152+ const std::vector<array>& inputs_,
1153+ const std::vector<int >& axes_) {
1154+ // Extract the inputs to the vmap function
1155+ std::vector<array> inputs (inputs_.begin (), inputs_.end () - num_outputs_);
1156+ std::vector<int > axes (axes_.begin (), axes_.end () - num_outputs_);
1157+ return vmap_fun_ (inputs, axes);
1158+ }
1159+
11361160std::vector<array> Depends::vjp (
11371161 const std::vector<array>& primals,
11381162 const std::vector<array>& cotangents,
0 commit comments