@@ -77,6 +77,91 @@ pub trait Map2 {
7777 }
7878}
7979
80+ pub trait Map2Alpha {
81+ const OP : & ' static str ;
82+ fn f < T : WithDType > (
83+ & self ,
84+ v1 : & [ T ] ,
85+ l1 : & Layout ,
86+ v2 : & [ T ] ,
87+ l2 : & Layout ,
88+ alpha : f64 ,
89+ ) -> Result < Vec < T > > ;
90+
91+ fn map ( & self , v1 : & C , l1 : & Layout , v2 : & C , l2 : & Layout , alpha : f64 ) -> Result < C > {
92+ match ( v1, v2) {
93+ ( C :: U8 ( v1) , C :: U8 ( v2) ) => Ok ( C :: U8 ( self . f ( v1, l1, v2, l2, alpha) ?) ) ,
94+ ( C :: U32 ( v1) , C :: U32 ( v2) ) => Ok ( C :: U32 ( self . f ( v1, l1, v2, l2, alpha) ?) ) ,
95+ ( C :: I16 ( v1) , C :: I16 ( v2) ) => Ok ( C :: I16 ( self . f ( v1, l1, v2, l2, alpha) ?) ) ,
96+ ( C :: I32 ( v1) , C :: I32 ( v2) ) => Ok ( C :: I32 ( self . f ( v1, l1, v2, l2, alpha) ?) ) ,
97+ ( C :: I64 ( v1) , C :: I64 ( v2) ) => Ok ( C :: I64 ( self . f ( v1, l1, v2, l2, alpha) ?) ) ,
98+ ( C :: BF16 ( v1) , C :: BF16 ( v2) ) => Ok ( C :: BF16 ( self . f ( v1, l1, v2, l2, alpha) ?) ) ,
99+ ( C :: F16 ( v1) , C :: F16 ( v2) ) => Ok ( C :: F16 ( self . f ( v1, l1, v2, l2, alpha) ?) ) ,
100+ ( C :: F32 ( v1) , C :: F32 ( v2) ) => Ok ( C :: F32 ( self . f ( v1, l1, v2, l2, alpha) ?) ) ,
101+ ( C :: F64 ( v1) , C :: F64 ( v2) ) => Ok ( C :: F64 ( self . f ( v1, l1, v2, l2, alpha) ?) ) ,
102+ ( C :: F8E4M3 ( v1) , C :: F8E4M3 ( v2) ) => Ok ( C :: F8E4M3 ( self . f ( v1, l1, v2, l2, alpha) ?) ) ,
103+ _ => Err ( Error :: DTypeMismatchBinaryOp {
104+ lhs : v1. dtype ( ) ,
105+ rhs : v2. dtype ( ) ,
106+ op : Self :: OP ,
107+ }
108+ . bt ( ) ) ,
109+ }
110+ }
111+ }
112+
113+ pub trait Map3 {
114+ const OP : & ' static str ;
115+ fn f < T : WithDType > (
116+ & self ,
117+ v1 : & [ T ] ,
118+ l1 : & Layout ,
119+ v2 : & [ T ] ,
120+ l2 : & Layout ,
121+ v3 : & [ T ] ,
122+ l3 : & Layout ,
123+ ) -> Result < Vec < T > > ;
124+
125+ fn map ( & self , v1 : & C , l1 : & Layout , v2 : & C , l2 : & Layout , v3 : & C , l3 : & Layout ) -> Result < C > {
126+ match ( v1, v2, v3) {
127+ ( C :: U8 ( v1) , C :: U8 ( v2) , C :: U8 ( v3) ) => Ok ( C :: U8 ( self . f ( v1, l1, v2, l2, v3, l3) ?) ) ,
128+ ( C :: U32 ( v1) , C :: U32 ( v2) , C :: U32 ( v3) ) => {
129+ Ok ( C :: U32 ( self . f ( v1, l1, v2, l2, v3, l3) ?) )
130+ }
131+ ( C :: I16 ( v1) , C :: I16 ( v2) , C :: I16 ( v3) ) => {
132+ Ok ( C :: I16 ( self . f ( v1, l1, v2, l2, v3, l3) ?) )
133+ }
134+ ( C :: I32 ( v1) , C :: I32 ( v2) , C :: I32 ( v3) ) => {
135+ Ok ( C :: I32 ( self . f ( v1, l1, v2, l2, v3, l3) ?) )
136+ }
137+ ( C :: I64 ( v1) , C :: I64 ( v2) , C :: I64 ( v3) ) => {
138+ Ok ( C :: I64 ( self . f ( v1, l1, v2, l2, v3, l3) ?) )
139+ }
140+ ( C :: BF16 ( v1) , C :: BF16 ( v2) , C :: BF16 ( v3) ) => {
141+ Ok ( C :: BF16 ( self . f ( v1, l1, v2, l2, v3, l3) ?) )
142+ }
143+ ( C :: F16 ( v1) , C :: F16 ( v2) , C :: F16 ( v3) ) => {
144+ Ok ( C :: F16 ( self . f ( v1, l1, v2, l2, v3, l3) ?) )
145+ }
146+ ( C :: F32 ( v1) , C :: F32 ( v2) , C :: F32 ( v3) ) => {
147+ Ok ( C :: F32 ( self . f ( v1, l1, v2, l2, v3, l3) ?) )
148+ }
149+ ( C :: F64 ( v1) , C :: F64 ( v2) , C :: F64 ( v3) ) => {
150+ Ok ( C :: F64 ( self . f ( v1, l1, v2, l2, v3, l3) ?) )
151+ }
152+ ( C :: F8E4M3 ( v1) , C :: F8E4M3 ( v2) , C :: F8E4M3 ( v3) ) => {
153+ Ok ( C :: F8E4M3 ( self . f ( v1, l1, v2, l2, v3, l3) ?) )
154+ }
155+ _ => Err ( Error :: DTypeMismatchBinaryOp {
156+ lhs : v1. dtype ( ) ,
157+ rhs : v2. dtype ( ) ,
158+ op : Self :: OP ,
159+ }
160+ . bt ( ) ) ,
161+ }
162+ }
163+ }
164+
80165pub trait Map2InPlace {
81166 const OP : & ' static str ;
82167 fn f < T : WithDType > ( & self , v1 : & mut [ T ] , l1 : & Layout , v2 : & [ T ] , l2 : & Layout ) -> Result < ( ) > ;
0 commit comments