File tree 1 file changed +25
-0
lines changed
1 file changed +25
-0
lines changed Original file line number Diff line number Diff line change @@ -105,6 +105,15 @@ where
105
105
loop (i : USize) (x : β) : m β :=
106
106
if i = 0 then pure x else f (a.ugetImpl (i-1 ) lcProof) x >>= loop (i-1 )
107
107
108
+ @[specialize]
109
+ private unsafe def mapImpl (f : {i : Fin n} → α i → β i) (a : DArray n α) : DArray n β :=
110
+ let f := fun i x => (unsafeCast (f (i:=i.cast lcProof) (unsafeCast x)) : NonScalar)
111
+ unsafeCast <| a.data.mapIdx f
112
+
113
+ @[specialize]
114
+ private unsafe def amapImpl (f : {i : Fin n} → α i → β) (a : DArray n α) : Array β :=
115
+ unsafeCast <| a.mapImpl f
116
+
108
117
end unsafe_implementation
109
118
110
119
attribute [implemented_by mkImpl] DArray.mk
@@ -235,3 +244,19 @@ where
235
244
236
245
instance (α : Fin n → Type _) [Monad m] : ForIn m (DArray n α) (Sigma α) where
237
246
forIn := forIn
247
+
248
+ /--
249
+ Applies `f : {i : Fin n} → α i → β i` to each element of a `DArray n α`,
250
+ returns the dependent array of results.
251
+ -/
252
+ @[implemented_by mapImpl]
253
+ protected def map (f : {i : Fin n} → α i → β i) (a : DArray n α) : DArray n β :=
254
+ mk fun i => f (a.get i)
255
+
256
+ /--
257
+ Applies `f : {i : Fin n} → α i → β` to each element of a `DArray n α`,
258
+ returns the (non-dependent) array of results.
259
+ -/
260
+ @[implemented_by amapImpl]
261
+ def amap (f : {i : Fin n} → α i → β) (a : DArray n α) : Array β :=
262
+ Array.ofFn fun i => f (a.get i)
You can’t perform that action at this time.
0 commit comments