File tree 1 file changed +18
-0
lines changed
1 file changed +18
-0
lines changed Original file line number Diff line number Diff line change @@ -78,6 +78,14 @@ private unsafe def popImpl (a : DArray (n+1) α) : DArray n fun i => α i.castSu
78
78
private unsafe def copyImpl (a : DArray n α) : DArray n α :=
79
79
unsafeCast <| a.data.extract 0 n
80
80
81
+ private unsafe def toArrayImpl (a : DArray n fun _ => α) : Array α :=
82
+ unsafeCast a
83
+
84
+ @[specialize]
85
+ private unsafe def mapImpl (f : {i : Fin n} → α i → β i) (a : DArray n α) : DArray n β :=
86
+ let f := fun i x => (unsafeCast (f (i:=i.cast lcProof) (unsafeCast x)) : NonScalar)
87
+ unsafeCast <| a.data.mapIdx f
88
+
81
89
private unsafe def foldlMImpl [Monad m] (a : DArray n α) (f : β → {i : Fin n} → α i → m β)
82
90
(init : β) : m β :=
83
91
if n < USize.size then
@@ -185,6 +193,16 @@ protected def push (a : DArray n α) (v : β) :
185
193
protected def pop (a : DArray (n+1 ) α) : DArray n fun i => α i.castSucc :=
186
194
mk fun i => a.get i.castSucc
187
195
196
+ /-- Cast a dependent array with constant types to an array. `O(1)` if exclusive else `O(n)`. -/
197
+ @[implemented_by toArrayImpl]
198
+ protected def toArray (a : DArray n fun _ => α) : Array α :=
199
+ .ofFn fun i => a.get i
200
+
201
+ /-- Applies `f` to each element of a dependent array, returns the array of results. -/
202
+ @[implemented_by mapImpl]
203
+ protected def map (f : {i : Fin n} → α i → β i) (a : DArray n α) : DArray n β :=
204
+ mk fun i => f (a.get i)
205
+
188
206
/--
189
207
Folds a monadic function over a `DArray` from left to right:
190
208
```
You can’t perform that action at this time.
0 commit comments