22import Base. Broadcast: BroadcastStyle
33using Base. Broadcast: AbstractArrayStyle, Broadcasted, DefaultArrayStyle
44
5+ # combine_sizes moved from StaticArrays after https://github.com/JuliaArrays/StaticArrays.jl/pull/1008
6+ # see also https://github.com/JuliaArrays/HybridArrays.jl/issues/50
7+ @generated function combine_sizes (s:: Tuple{Vararg{Size}} )
8+ sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
9+ ndims = 0
10+ for i = 1 : length (sizes)
11+ ndims = max (ndims, length (sizes[i]))
12+ end
13+ newsize = StaticArrays. StaticDimension[Dynamic () for _ = 1 : ndims]
14+ for i = 1 : length (sizes)
15+ s = sizes[i]
16+ for j = 1 : length (s)
17+ if s[j] isa Dynamic
18+ continue
19+ elseif newsize[j] isa Dynamic || newsize[j] == 1
20+ newsize[j] = s[j]
21+ elseif newsize[j] ≠ s[j] && s[j] ≠ 1
22+ throw (DimensionMismatch (" Tried to broadcast on inputs sized $sizes " ))
23+ end
24+ end
25+ end
26+ quote
27+ Base. @_inline_meta
28+ Size ($ (tuple (newsize... )))
29+ end
30+ end
31+
32+ function broadcasted_index (oldsize, newindex)
33+ index = ones (Int, length (oldsize))
34+ for i = 1 : length (oldsize)
35+ if oldsize[i] != 1
36+ index[i] = newindex[i]
37+ end
38+ end
39+ return LinearIndices (oldsize)[index... ]
40+ end
41+
42+ scalar_getindex (x) = x
43+ scalar_getindex (x:: Ref ) = x[]
44+
545# Add a new BroadcastStyle for StaticArrays, derived from AbstractArrayStyle
646# A constructor that changes the style parameter N (array dimension) is also required
747struct HybridArrayStyle{N} <: AbstractArrayStyle{N} end
@@ -22,7 +62,7 @@ BroadcastStyle(::HybridArray{M}, ::StaticArrays.StaticArrayStyle{0}) where {M} =
2262@inline function Base. copy (B:: Broadcasted{HybridArrayStyle{M}} ) where M
2363 flat = Broadcast. flatten (B); as = flat. args; f = flat. f
2464 argsizes = StaticArrays. broadcast_sizes (as... )
25- destsize = StaticArrays . combine_sizes (argsizes)
65+ destsize = combine_sizes (argsizes)
2666 if Length (destsize) === Length {StaticArrays.Dynamic()} ()
2767 # destination dimension cannot be determined statically; fall back to generic broadcast
2868 return HybridArray {StaticArrays.size_tuple(destsize)} (copy (convert (Broadcasted{DefaultArrayStyle{M}}, B)))
3575@inline function _copyto! (dest, B:: Broadcasted{HybridArrayStyle{M}} ) where M
3676 flat = Broadcast. flatten (B); as = flat. args; f = flat. f
3777 argsizes = StaticArrays. broadcast_sizes (as... )
38- destsize = StaticArrays . combine_sizes ((Size (dest), argsizes... ))
78+ destsize = combine_sizes ((Size (dest), argsizes... ))
3979 if Length (destsize) === Length {StaticArrays.Dynamic()} ()
4080 # destination dimension cannot be determined statically; fall back to generic broadcast!
4181 return copyto! (dest, convert (Broadcasted{DefaultArrayStyle{M}}, B))
68108
69109 make_expr (i) = begin
70110 if ! (a[i] <: AbstractArray )
71- return :(StaticArrays . scalar_getindex (a[$ i]))
111+ return :(scalar_getindex (a[$ i]))
72112 elseif hasdynamic (Tuple{sizes[i]. .. })
73113 return :(a[$ i][$ (current_ind... )])
74114 else
75- :(a[$ i][$ (StaticArrays . broadcasted_index (sizes[i], current_ind))])
115+ :(a[$ i][$ (broadcasted_index (sizes[i], current_ind))])
76116 end
77117 end
78118
0 commit comments