2222
2323
2424
25+ namespace tnf = torch::nn::functional;
26+
27+
2528#define def_bridge_simple (Name ) \
2629 extern " C" bridge_tensor_t Name (bridge_tensor_t input) { \
2730 auto t_input = bridge_to_torch (input); \
28- auto t_output = torch::Name (t_input); \
31+ auto t_output = torch::nn::functional:: Name (t_input); \
2932 return torch_to_bridge (t_output); \
3033 }
3134
32-
33-
3435// Globals
3536
3637
@@ -591,6 +592,250 @@ extern "C" void split_loop_filler(int64_t n,int64_t* ret) {
591592// }
592593// }
593594
594- // cap.release();
595- // cv::destroyAllWindows();
596- // }
595+
596+ // Simple activation function defs
597+
598+ def_bridge_simple (gelu);
599+
600+ def_bridge_simple (logsigmoid);
601+
602+ def_bridge_simple (mish);
603+
604+ def_bridge_simple (relu);
605+
606+ def_bridge_simple (relu6);
607+
608+ def_bridge_simple (selu);
609+
610+ def_bridge_simple (silu);
611+
612+ def_bridge_simple (softsign);
613+
614+ def_bridge_simple (tanhshrink);
615+
616+
617+ // More complex activation functions with scary parameters
618+
619+ extern " C" bridge_tensor_t rrelu (
620+ bridge_tensor_t input,
621+ float lower,
622+ float upper,
623+ bool training
624+ ) {
625+ auto t_input = bridge_to_torch (input);
626+ auto t_output = tnf::rrelu (t_input,
627+ tnf::RReLUFuncOptions ()
628+ .lower (lower)
629+ .upper (upper)
630+ .training (training));
631+
632+ return torch_to_bridge (t_output);
633+ }
634+
635+
636+ extern " C" bridge_tensor_t hardshrink (
637+ bridge_tensor_t input,
638+ float lambda
639+ ) {
640+ auto t_input = bridge_to_torch (input);
641+ auto t_output = tnf::hardshrink (t_input,
642+ tnf::HardshrinkFuncOptions ()
643+ .lambda (lambda));
644+
645+ return torch_to_bridge (t_output);
646+ }
647+
648+
649+ extern " C" bridge_tensor_t hardtanh (
650+ bridge_tensor_t input,
651+ float min_val,
652+ float max_val
653+ ) {
654+ auto t_input = bridge_to_torch (input);
655+ auto t_output = tnf::hardtanh (t_input,
656+ tnf::HardtanhFuncOptions ()
657+ .min_val (min_val)
658+ .max_val (max_val));
659+
660+ return torch_to_bridge (t_output);
661+ }
662+
663+
664+ extern " C" bridge_tensor_t elu (
665+ bridge_tensor_t input,
666+ float alpha
667+ ) {
668+ auto t_input = bridge_to_torch (input);
669+ auto t_output = tnf::elu (t_input,
670+ tnf::ELUFuncOptions ()
671+ .alpha (alpha));
672+
673+ return torch_to_bridge (t_output);
674+ }
675+
676+
677+ extern " C" bridge_tensor_t softplus (
678+ bridge_tensor_t input,
679+ float beta,
680+ float threshold
681+ ) {
682+ auto t_input = bridge_to_torch (input);
683+ auto t_output = tnf::softplus (t_input,
684+ tnf::SoftplusFuncOptions ()
685+ .beta (beta)
686+ .threshold (threshold));
687+
688+ return torch_to_bridge (t_output)
689+ }
690+
691+
692+ extern " C" bridge_tensor_t threshold (
693+ bridge_tensor_t input,
694+ float threshold,
695+ float value
696+ ) {
697+ auto t_input = bridge_to_torch (input);
698+ auto t_output = tnf::threshold (t_input,
699+ tnf::ThresholdFuncOptions ()
700+ .threshold (threshold)
701+ .value (value));
702+
703+ return torch_to_bridge (t_output);
704+ }
705+
706+
707+ extern " C" bridge_tensor_t celu (
708+ bridge_tensor_t input,
709+ float alpha
710+ ) {
711+ auto t_input = bridge_to_torch (input);
712+ auto t_output = tnf::celu (t_input,
713+ tnf::CELUFuncOptions ()
714+ .alpha (alpha));
715+
716+ return t_output;
717+ }
718+
719+
720+ extern " C" bridge_tensor_t leaky_relu (
721+ bridge_tensor_t input,
722+ float negative_slope
723+ ) {
724+ auto t_input = bridge_to_torch (input);
725+ auto t_output = tnf::leaky_relu (t_input,
726+ tnf::LeakyReLUFuncOptions ()
727+ .negative_slope (negative_slope));
728+
729+ return torch_to_bridge (t_output);
730+ }
731+
732+
733+ extern " C" bridge_tensor_t softshrink (
734+ bridge_tensor_t input,
735+ float lambda
736+ ) {
737+ auto t_input = bridge_to_torch (input);
738+ auto t_output = tnf::softshrink (t_input,
739+ tnf::SoftshrinkFuncOptions (lambda));
740+
741+ return torch_to_bridge (t_output);
742+ }
743+
744+
745+ extern " C" bridge_tensor_t softmax (
746+ bridge_tensor_t input,
747+ std::int64_t dim
748+ ) {
749+ auto t_input = bridge_to_torch (input);
750+ auto t_output = tnf::softmax (t_input,
751+ tnf::SoftmaxFuncOptions (dim));
752+
753+ return torch_to_bridge (t_output);
754+ }
755+
756+
757+ extern " C" bridge_tensor_t softmin (
758+ bridge_tensor_t input,
759+ std::int64_t dim
760+ ) {
761+ auto t_input = bridge_to_torch (input);
762+ auto t_output = tnf::softmin (t_input,
763+ tnf::SoftminFuncOptions (dim));
764+
765+ return torch_to_bridge (t_output);
766+ }
767+
768+
769+ extern " C" bridge_tensor_t dropout (
770+ bridge_tensor_t input,
771+ double p,
772+ bool training
773+ ) {
774+ auto t_input = bridge_to_torch (input);
775+ auto t_output = tnf::dropout (
776+ tnf::DropoutFuncOptions ()
777+ .p (p)
778+ .training (training));
779+
780+ return torch_to_bridge (t_output);
781+ }
782+
783+
784+ extern " C" bridge_tensor_t alpha_dropout (
785+ bridge_tensor_t input,
786+ double p,
787+ bool training
788+ ) {
789+ auto t_input = bridge_to_torch (input);
790+ auto t_output = tnf::alpha_dropout (
791+ tnf::AlphaDropoutFuncOptions ()
792+ .p (p)
793+ .training (training));
794+
795+ return torch_to_bridge (t_output);
796+ }
797+
798+
799+ extern " C" bridge_tensor_t feature_alpha_dropout (
800+ bridge_tensor_t input,
801+ double p,
802+ bool training
803+ ) {
804+ auto t_input = bridge_to_torch (input);
805+ auto t_output = tnf::feature_alpha_dropout (
806+ tnf::FeatureAlphaDropoutFuncOptions ()
807+ .p (p)
808+ .training (training));
809+
810+ return torch_to_bridge (t_output);
811+ }
812+
813+
814+ extern " C" bridge_tensor_t dropout2d (
815+ bridge_tensor_t input,
816+ double p,
817+ bool training
818+ ) {
819+ auto t_input = bridge_to_torch (input);
820+ auto t_output = tnf::dropout2d (
821+ tnf::Dropout2dFuncOptions ()
822+ .p (p)
823+ .training (training));
824+
825+ return torch_to_bridge (t_output);
826+ }
827+
828+
829+ extern " C" bridge_tensor_t dropout3d (
830+ bridge_tensor_t input,
831+ double p,
832+ bool training
833+ ) {
834+ auto t_input = bridge_to_torch (input);
835+ auto t_output = tnf::dropout3d (
836+ tnf::Dropout3dFuncOptions ()
837+ .p (p)
838+ .training (training));
839+
840+ return torch_to_bridge (t_output);
841+ }
0 commit comments