@@ -73,7 +73,7 @@ num_parts(a::AbstractPData) = length(a)
73
73
"""
74
74
get_backend(a::AbstractPData) -> AbstractBackend
75
75
76
- Get the back-end associated with `a`.
76
+ Get the back-end associated with `a`.
77
77
"""
78
78
get_backend (a:: AbstractPData ) = @abstractmethod
79
79
@@ -639,6 +639,16 @@ function lids_are_equal(a::AbstractIndexSet,b::AbstractIndexSet)
639
639
a. lid_to_gid == b. lid_to_gid
640
640
end
641
641
642
+ function find_lid_map (a:: AbstractIndexSet ,b:: AbstractIndexSet )
643
+ alid_to_blid = fill (Int32 (- 1 ),num_lids (a))
644
+ for blid in 1 : num_lids (b)
645
+ gid = b. lid_to_gid[blid]
646
+ alid = a. gid_to_lid[gid]
647
+ alid_to_blid[alid] = blid
648
+ end
649
+ alid_to_blid
650
+ end
651
+
642
652
# The given ids are assumed to be a sub-set of the lids
643
653
function touched_hids (a:: AbstractIndexSet ,gids:: AbstractVector{<:Integer} )
644
654
i = 0
@@ -828,7 +838,7 @@ function async_exchange!(
828
838
data_snd = allocate_snd_buffer (Tsnd,exchanger)
829
839
830
840
# Fill snd buffer
831
- t1 = map_parts (t0,values_snd,data_snd,exchanger. lids_snd) do t0,values_snd,data_snd,lids_snd
841
+ t1 = map_parts (t0,values_snd,data_snd,exchanger. lids_snd) do t0,values_snd,data_snd,lids_snd
832
842
@task begin
833
843
wait (schedule (t0))
834
844
for p in 1 : length (lids_snd. data)
@@ -848,7 +858,7 @@ function async_exchange!(
848
858
849
859
# Fill values_rcv from rcv buffer
850
860
# asynchronously
851
- t3 = map_parts (t2,values_rcv,data_rcv,exchanger. lids_rcv) do t2,values_rcv,data_rcv,lids_rcv
861
+ t3 = map_parts (t2,values_rcv,data_rcv,exchanger. lids_rcv) do t2,values_rcv,data_rcv,lids_rcv
852
862
@task begin
853
863
wait (schedule (t2))
854
864
for p in 1 : length (lids_rcv. data)
@@ -946,7 +956,7 @@ mutable struct PRange{A,B,C} <: AbstractUnitRange{Int}
946
956
exchanger:: Exchanger ,
947
957
gid_to_part:: Union{AbstractPData{<:AbstractArray{<:Integer}},Nothing} = nothing ,
948
958
ghost:: Bool = true )
949
-
959
+
950
960
A = typeof (partition)
951
961
B = typeof (exchanger)
952
962
C = typeof (gid_to_part)
@@ -1455,12 +1465,22 @@ function Base.similar(
1455
1465
end
1456
1466
1457
1467
function Base. copy! (a:: PVector ,b:: PVector )
1458
- map_parts (copy!,a. values,b. values)
1468
+ @check oids_are_equal (a. rows,b. rows)
1469
+ if a. rows. partition === b. rows. partition
1470
+ map_parts (copy!,a. values,b. values)
1471
+ else
1472
+ map_parts (copy!,a. owned_values,b. owned_values)
1473
+ end
1459
1474
a
1460
1475
end
1461
1476
1462
1477
function Base. copyto! (a:: PVector ,b:: PVector )
1463
- map_parts (copyto!,a. values,b. values)
1478
+ @check oids_are_equal (a. rows,b. rows)
1479
+ if a. rows. partition === b. rows. partition
1480
+ map_parts (copyto!,a. values,b. values)
1481
+ else
1482
+ map_parts (copyto!,a. owned_values,b. owned_values)
1483
+ end
1464
1484
a
1465
1485
end
1466
1486
@@ -1470,6 +1490,19 @@ function Base.copy(b::PVector)
1470
1490
a
1471
1491
end
1472
1492
1493
+ function LinearAlgebra. rmul! (a:: PVector ,v:: Number )
1494
+ map_parts (a. values) do l
1495
+ rmul! (l,v)
1496
+ end
1497
+ a
1498
+ end
1499
+
1500
+ function Base.:(== )(a:: PVector ,b:: PVector )
1501
+ length (a) == length (b) &&
1502
+ num_parts (a. values) == num_parts (b. values) &&
1503
+ reduce (& ,map_parts (== ,a. owned_values,b. owned_values),init= true )
1504
+ end
1505
+
1473
1506
struct DistributedBroadcasted{A,B,C}
1474
1507
owned_values:: A
1475
1508
ghost_values:: B
@@ -1556,6 +1589,101 @@ function LinearAlgebra.norm(a::PVector,p::Real=2)
1556
1589
reduce (+ ,contibs;init= zero (eltype (contibs)))^ (1 / p)
1557
1590
end
1558
1591
1592
+ # Distances.jl related (needed eg for non-linear solvers)
1593
+
1594
+ for M in Distances. metrics
1595
+ @eval begin
1596
+ function (dist:: $M )(a:: PVector ,b:: PVector )
1597
+ _eval_dist (dist,a,b,Distances. parameters (dist))
1598
+ end
1599
+ end
1600
+ end
1601
+
1602
+ function _eval_dist (d,a,b,:: Nothing )
1603
+ partials = map_parts (a. owned_values,b. owned_values) do a,b
1604
+ _eval_dist_local (d,a,b,nothing )
1605
+ end
1606
+ s = reduce (
1607
+ (i,j)-> Distances. eval_reduce (d,i,j),
1608
+ partials,
1609
+ init= Distances. eval_start (d, a, b))
1610
+ Distances. eval_end (d,s)
1611
+ end
1612
+
1613
+ Base. @propagate_inbounds function _eval_dist_local (d,a,b,:: Nothing )
1614
+ @boundscheck if length (a) != length (b)
1615
+ throw (DimensionMismatch (" first array has length $(length (a)) which does not match the length of the second, $(length (b)) ." ))
1616
+ end
1617
+ if length (a) == 0
1618
+ return zero (Distances. result_type (d, a, b))
1619
+ end
1620
+ @inbounds begin
1621
+ s = Distances. eval_start (d, a, b)
1622
+ if (IndexStyle (a, b) === IndexLinear () && eachindex (a) == eachindex (b)) || axes (a) == axes (b)
1623
+ @simd for I in eachindex (a, b)
1624
+ ai = a[I]
1625
+ bi = b[I]
1626
+ s = Distances. eval_reduce (d, s, Distances. eval_op (d, ai, bi))
1627
+ end
1628
+ else
1629
+ for (ai, bi) in zip (a, b)
1630
+ s = Distances. eval_reduce (d, s, Distances. eval_op (d, ai, bi))
1631
+ end
1632
+ end
1633
+ return s
1634
+ end
1635
+ end
1636
+
1637
+ function _eval_dist (d,a,b,p)
1638
+ @notimplemented
1639
+ end
1640
+
1641
+ function _eval_dist_local (d,a,b,p)
1642
+ @notimplemented
1643
+ end
1644
+
1645
+ function Base. any (f:: Function ,x:: PVector )
1646
+ partials = map_parts (x. owned_values) do o
1647
+ any (f,o)
1648
+ end
1649
+ reduce (| ,partials,init= false )
1650
+ end
1651
+
1652
+ function Base. all (f:: Function ,x:: PVector )
1653
+ partials = map_parts (x. owned_values) do o
1654
+ all (f,o)
1655
+ end
1656
+ reduce (& ,partials,init= true )
1657
+ end
1658
+
1659
+ function Base. maximum (x:: PVector )
1660
+ partials = map_parts (maximum,x. owned_values)
1661
+ reduce (max,partials,init= typemin (eltype (x)))
1662
+ end
1663
+
1664
+ function Base. maximum (f:: Function ,x:: PVector )
1665
+ partials = map_parts (x. owned_values) do o
1666
+ maximum (f,o)
1667
+ end
1668
+ reduce (max,partials,init= typemin (eltype (x)))
1669
+ end
1670
+
1671
+ function Base. minimum (x:: PVector )
1672
+ partials = map_parts (minimum,x. owned_values)
1673
+ reduce (min,partials,init= typemax (eltype (x)))
1674
+ end
1675
+
1676
+ function Base. minimum (f:: Function ,x:: PVector )
1677
+ partials = map_parts (x. owned_values) do o
1678
+ minimum (f,o)
1679
+ end
1680
+ reduce (min,partials,init= typemax (eltype (x)))
1681
+ end
1682
+
1683
+ function Base. findall (f:: Function ,x:: PVector )
1684
+ @notimplemented
1685
+ end
1686
+
1559
1687
function PVector {T} (
1560
1688
:: UndefInitializer ,
1561
1689
rows:: PRange ) where T
@@ -1678,8 +1806,46 @@ function LinearAlgebra.dot(a::PVector,b::PVector)
1678
1806
end
1679
1807
1680
1808
function local_view (a:: PVector ,rows:: PRange )
1681
- @notimplementedif a. rows != = rows
1682
- a. values
1809
+ if a. rows === rows
1810
+ a. values
1811
+ else
1812
+ map_parts (a. values,rows. partition,a. rows. partition) do values,rows,arows
1813
+ LocalView (values,(find_lid_map (rows,arows),))
1814
+ end
1815
+ end
1816
+ end
1817
+
1818
+ struct LocalView{T,N,A,B} <: AbstractArray{T,N}
1819
+ plids_to_value:: A
1820
+ d_to_lid_to_plid:: B
1821
+ local_size:: NTuple{N,Int}
1822
+ function LocalView (
1823
+ plids_to_value:: AbstractArray{T,N} ,d_to_lid_to_plid:: NTuple{N} ) where {T,N}
1824
+ A = typeof (plids_to_value)
1825
+ B = typeof (d_to_lid_to_plid)
1826
+ local_size = map (length,d_to_lid_to_plid)
1827
+ new {T,N,A,B} (plids_to_value,d_to_lid_to_plid,local_size)
1828
+ end
1829
+ end
1830
+
1831
+ Base. size (a:: LocalView ) = a. local_size
1832
+ Base. IndexStyle (:: Type{<:LocalView} ) = IndexCartesian ()
1833
+ function Base. getindex (a:: LocalView{T,N} ,lids:: Vararg{Integer,N} ) where {T,N}
1834
+ plids = map (_lid_to_plid,lids,a. d_to_lid_to_plid)
1835
+ if all (i-> i> 0 ,plids)
1836
+ a. plids_to_value[plids... ]
1837
+ else
1838
+ zero (T)
1839
+ end
1840
+ end
1841
+ function Base. setindex! (a:: LocalView{T,N} ,v,lids:: Vararg{Integer,N} ) where {T,N}
1842
+ plids = map (_lid_to_plid,lids,a. d_to_lid_to_plid)
1843
+ @check all (i-> i> 0 ,plids) " You are trying to set a value that is not stored in the local portion"
1844
+ a. plids_to_value[plids... ] = v
1845
+ end
1846
+ function _lid_to_plid (lid,lid_to_plid)
1847
+ plid = lid_to_plid[lid]
1848
+ plid
1683
1849
end
1684
1850
1685
1851
function global_view (a:: PVector ,rows:: PRange )
@@ -1772,6 +1938,13 @@ struct PSparseMatrix{T,A,B,C,D} <: AbstractMatrix{T}
1772
1938
end
1773
1939
end
1774
1940
1941
+ function LinearAlgebra. fillstored! (a:: PSparseMatrix ,v)
1942
+ map_parts (a. values) do values
1943
+ LinearAlgebra. fillstored! (values,v)
1944
+ end
1945
+ a
1946
+ end
1947
+
1775
1948
function Base. copy (a:: PSparseMatrix )
1776
1949
PSparseMatrix (
1777
1950
copy (a. values),
@@ -1916,9 +2089,16 @@ function LinearAlgebra.mul!(
1916
2089
end
1917
2090
1918
2091
function local_view (a:: PSparseMatrix ,rows:: PRange ,cols:: PRange )
1919
- @notimplementedif a. rows != = rows
1920
- @notimplementedif a. cols != = cols
1921
- a. values
2092
+ if a. rows === rows && a. cols === cols
2093
+ a. values
2094
+ else
2095
+ map_parts (
2096
+ a. values,rows. partition,cols. partition,a. rows. partition,a. cols. partition) do values,rows,cols,arows,acols
2097
+ rmap = find_lid_map (rows,arows)
2098
+ cmap = (cols === rows && acols === arows) ? rmap : find_lid_map (cols,acols)
2099
+ LocalView (values,(rmap,cmap))
2100
+ end
2101
+ end
1922
2102
end
1923
2103
1924
2104
function global_view (a:: PSparseMatrix ,rows:: PRange ,cols:: PRange )
@@ -1978,7 +2158,7 @@ function matrix_exchanger(values,row_exchanger,row_lids,col_lids)
1978
2158
gj_rcv = Table (gj_rcv_data,ptrs)
1979
2159
k_rcv, gi_rcv, gj_rcv
1980
2160
end
1981
-
2161
+
1982
2162
k_rcv, gi_rcv, gj_rcv = map_parts (setup_rcv,part,parts_rcv,row_lids,col_lids,values)
1983
2163
1984
2164
gi_snd = exchange (gi_rcv,parts_snd,parts_rcv)
@@ -2087,9 +2267,9 @@ function async_assemble!(
2087
2267
gi_rcv = Table (gi_rcv_data,ptrs)
2088
2268
gj_rcv = Table (gj_rcv_data,ptrs)
2089
2269
v_rcv = Table (v_rcv_data,ptrs)
2090
- gi_rcv, gj_rcv, v_rcv
2270
+ gi_rcv, gj_rcv, v_rcv
2091
2271
end
2092
-
2272
+
2093
2273
gi_rcv, gj_rcv, v_rcv = map_parts (setup_rcv,part,parts_rcv,rows. partition,coo_values)
2094
2274
2095
2275
gi_snd, t1 = async_exchange (gi_rcv,parts_snd,parts_rcv)
@@ -2186,9 +2366,9 @@ function async_exchange!(
2186
2366
gi_snd = Table (gi_snd_data,ptrs)
2187
2367
gj_snd = Table (gj_snd_data,ptrs)
2188
2368
v_snd = Table (v_snd_data,ptrs)
2189
- gi_snd, gj_snd, v_snd
2369
+ gi_snd, gj_snd, v_snd
2190
2370
end
2191
-
2371
+
2192
2372
gi_snd, gj_snd, v_snd = map_parts (
2193
2373
setup_snd,part,parts_snd,lids_snd,rows. partition,coo_values)
2194
2374
@@ -2365,4 +2545,3 @@ function IterativeSolvers.zerox(A::PSparseMatrix,b::PVector)
2365
2545
fill! (x, zero (T))
2366
2546
return x
2367
2547
end
2368
-
0 commit comments