Skip to content

Commit 4f6e401

Browse files
committed
Add support for device vectors through a workaround moving them back to host to execute transfer before sending back to device.
1 parent bb11d08 commit 4f6e401

File tree

1 file changed

+137
-1
lines changed

1 file changed

+137
-1
lines changed

include/deal.II/multigrid/mg_transfer_matrix_free.h

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1685,9 +1685,10 @@ MGTransferMatrixFree<dim, Number, MemorySpace>::interpolate_to_mg(
16851685
{
16861686
assert_dof_handler(dof_handler);
16871687

1688-
const unsigned int min_level = dst.min_level();
1688+
const unsigned int min_level = transfer.min_level();
16891689
const unsigned int max_level = transfer.max_level();
16901690

1691+
AssertDimension(min_level, dst.min_level());
16911692
AssertDimension(max_level, dst.max_level());
16921693

16931694
for (unsigned int level = min_level; level <= max_level; ++level)
@@ -1745,6 +1746,141 @@ MGTransferMatrixFree<dim, Number, MemorySpace>::interpolate_to_mg(
17451746
}
17461747

17471748

1749+
/**
1750+
* Template override which allows for device vectors to be properly supported.
1751+
* Currently works by transferring device vectors back to the host and
1752+
* performing the transfer operation on the host. Eventually this should be
1753+
* replaced by all operations occurring on the device.
1754+
*/
1755+
1756+
template <int dim, typename Number>
1757+
class MGTransferMatrixFree<dim, Number, MemorySpace::Default>
1758+
: public MGTransferBase<
1759+
LinearAlgebra::distributed::Vector<Number, MemorySpace::Default>>
1760+
{
1761+
public:
1762+
using VectorType =
1763+
LinearAlgebra::distributed::Vector<Number, MemorySpace::Default>;
1764+
using VectorTypeHost =
1765+
LinearAlgebra::distributed::Vector<Number, dealii::MemorySpace::Host>;
1766+
1767+
MGTransferMatrixFree(
1768+
const MGLevelObject<MGTwoLevelTransfer<dim, VectorTypeHost>> &mg_transfers,
1769+
const std::function<void(const unsigned int, VectorTypeHost &)>
1770+
&initialize_dof_vector)
1771+
: transfer(mg_transfers, initialize_dof_vector)
1772+
{}
1773+
1774+
template <typename Number2>
1775+
void
1776+
copy_to_mg(
1777+
const DoFHandler<dim> &dof_handler,
1778+
MGLevelObject<VectorType> &dst,
1779+
const LinearAlgebra::distributed::Vector<Number2, MemorySpace::Default>
1780+
&src) const
1781+
{
1782+
MGLevelObject<VectorTypeHost> dst_host(dst.min_level(), dst.max_level());
1783+
LinearAlgebra::distributed::Vector<Number2, dealii::MemorySpace::Host>
1784+
src_host;
1785+
1786+
copy_to_host(src_host, src);
1787+
for (unsigned int l = dst.min_level(); l < dst.max_level(); ++l)
1788+
copy_to_host(dst_host[l], dst[l]);
1789+
1790+
transfer.copy_to_mg(dof_handler, dst_host, src_host);
1791+
1792+
for (unsigned int l = dst.min_level(); l <= dst.max_level(); ++l)
1793+
copy_from_host(dst[l], dst_host[l]);
1794+
}
1795+
1796+
template <typename Number2>
1797+
void
1798+
copy_from_mg(
1799+
const DoFHandler<dim> &dof_handler,
1800+
LinearAlgebra::distributed::Vector<Number2, MemorySpace::Default> &dst,
1801+
const MGLevelObject<VectorType> &src) const
1802+
{
1803+
LinearAlgebra::distributed::Vector<Number2, dealii::MemorySpace::Host>
1804+
dst_host;
1805+
MGLevelObject<VectorTypeHost> src_host(src.min_level(), src.max_level());
1806+
1807+
copy_to_host(dst_host, dst);
1808+
for (unsigned int l = src.min_level(); l <= src.max_level(); ++l)
1809+
copy_to_host(src_host[l], src[l]);
1810+
1811+
transfer.copy_from_mg(dof_handler, dst_host, src_host);
1812+
1813+
copy_from_host(dst, dst_host);
1814+
}
1815+
1816+
void
1817+
prolongate(const unsigned int to_level,
1818+
VectorType &dst,
1819+
const VectorType &src) const override
1820+
{
1821+
VectorTypeHost dst_host;
1822+
VectorTypeHost src_host;
1823+
1824+
copy_to_host(dst_host, dst);
1825+
copy_to_host(src_host, src);
1826+
1827+
transfer.prolongate(to_level, dst_host, src_host);
1828+
1829+
copy_from_host(dst, dst_host);
1830+
}
1831+
1832+
void
1833+
restrict_and_add(const unsigned int from_level,
1834+
VectorType &dst,
1835+
const VectorType &src) const override
1836+
{
1837+
VectorTypeHost dst_host;
1838+
VectorTypeHost src_host;
1839+
1840+
copy_to_host(dst_host, dst);
1841+
copy_to_host(src_host, src);
1842+
1843+
transfer.restrict_and_add(from_level, dst_host, src_host);
1844+
1845+
copy_from_host(dst, dst_host);
1846+
}
1847+
1848+
private:
1849+
const MGTransferMatrixFree<dim, Number, dealii::MemorySpace::Host> transfer;
1850+
1851+
template <typename Number2>
1852+
void
1853+
copy_to_host(
1854+
LinearAlgebra::distributed::Vector<Number2, dealii::MemorySpace::Host> &dst,
1855+
const LinearAlgebra::distributed::Vector<Number2, MemorySpace::Default>
1856+
&src) const
1857+
{
1858+
LinearAlgebra::ReadWriteVector<Number2> rw_vector(
1859+
src.get_partitioner()->locally_owned_range());
1860+
rw_vector.import_elements(src, VectorOperation::insert);
1861+
1862+
dst.reinit(src.get_partitioner());
1863+
dst.import_elements(rw_vector, VectorOperation::insert);
1864+
}
1865+
1866+
template <typename Number2>
1867+
void
1868+
copy_from_host(
1869+
LinearAlgebra::distributed::Vector<Number2, MemorySpace::Default> &dst,
1870+
const LinearAlgebra::distributed::Vector<Number2, dealii::MemorySpace::Host>
1871+
&src) const
1872+
{
1873+
LinearAlgebra::ReadWriteVector<Number2> rw_vector(
1874+
src.get_partitioner()->locally_owned_range());
1875+
rw_vector.import_elements(src, VectorOperation::insert);
1876+
1877+
if (dst.size() == 0)
1878+
dst.reinit(src.get_partitioner());
1879+
dst.import_elements(rw_vector, VectorOperation::insert);
1880+
}
1881+
};
1882+
1883+
17481884

17491885
template <int dim, typename Number, typename TransferType>
17501886
template <typename BlockVectorType2>

0 commit comments

Comments
 (0)