Skip to content

Commit 6d9236a

Browse files
committed
feat: add map and amap
1 parent bf6f528 commit 6d9236a

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

Batteries/Data/DArray/Basic.lean

+25
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,15 @@ where
105105
loop (i : USize) (x : β) : m β :=
106106
if i = 0 then pure x else f (a.ugetImpl (i-1) lcProof) x >>= loop (i-1)
107107

108+
@[specialize]
109+
private unsafe def mapImpl (f : {i : Fin n} → α i → β i) (a : DArray n α) : DArray n β :=
110+
let f := fun i x => (unsafeCast (f (i:=i.cast lcProof) (unsafeCast x)) : NonScalar)
111+
unsafeCast <| a.data.mapIdx f
112+
113+
@[specialize]
114+
private unsafe def amapImpl (f : {i : Fin n} → α i → β) (a : DArray n α) : Array β :=
115+
unsafeCast <| a.mapImpl f
116+
108117
end unsafe_implementation
109118

110119
attribute [implemented_by mkImpl] DArray.mk
@@ -235,3 +244,19 @@ where
235244

236245
instance (α : Fin n → Type _) [Monad m] : ForIn m (DArray n α) (Sigma α) where
237246
forIn := forIn
247+
248+
/--
249+
Applies `f : {i : Fin n} → α i → β i` to each element of a `DArray n α`,
250+
returns the dependent array of results.
251+
-/
252+
@[implemented_by mapImpl]
253+
protected def map (f : {i : Fin n} → α i → β i) (a : DArray n α) : DArray n β :=
254+
mk fun i => f (a.get i)
255+
256+
/--
257+
Applies `f : {i : Fin n} → α i → β` to each element of a `DArray n α`,
258+
returns the (non-dependent) array of results.
259+
-/
260+
@[implemented_by amapImpl]
261+
def amap (f : {i : Fin n} → α i → β) (a : DArray n α) : Array β :=
262+
Array.ofFn fun i => f (a.get i)

0 commit comments

Comments
 (0)