Skip to content

Commit 892c174

Browse files
committed
Refactor matrix multiply for compact code style
1 parent 01c3364 commit 892c174

File tree

1 file changed

+40
-42
lines changed

1 file changed

+40
-42
lines changed

IntroductionToRobotics/matrix_multiply.inl

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
template<typename T, size_t M, size_t N, size_t P>
2-
requires is_matrix_multiplicable_type<T>
3-
INLINE constexpr auto matrix_multiply(
4-
const std::array<std::array<T, N>, M>& A, const std::array<std::array<T, P>, N>& B
5-
) -> std::array<std::array<T, P>, M>
6-
{
71
/**
82
* @file matrix_multiply.inl
93
* @author Xuhua Huang
@@ -20,15 +14,18 @@ INLINE constexpr auto matrix_multiply(
2014
#define MATRIX_MULTIPLY_INL
2115

2216
namespace robotics {
17+
18+
template <typename T, size_t M, size_t N, size_t P>
19+
requires is_matrix_multipliable_type<T>
20+
constexpr inline auto
21+
matrix_multiply(const std::array<std::array<T, N>, M>& A, const std::array<std::array<T, P>, N>& B)
22+
-> std::array<std::array<T, P>, M> {
2323
std::array<std::array<T, P>, M> result{};
2424

25-
for (size_t i = 0; i < M; ++i)
26-
{
27-
for (size_t j = 0; j < P; ++j)
28-
{
25+
for (size_t i = 0; i < M; ++i) {
26+
for (size_t j = 0; j < P; ++j) {
2927
result[i][j] = 0;
30-
for (size_t k = 0; k < N; ++k)
31-
{
28+
for (size_t k = 0; k < N; ++k) {
3229
result[i][j] += A[i][k] * B[k][j];
3330
}
3431
}
@@ -37,45 +34,46 @@ namespace robotics {
3734
return result;
3835
}
3936

40-
template<typename T, std::size_t N>
41-
requires std::is_arithmetic<T>::value
42-
INLINE std::ostream& operator<<(std::ostream& os, const std::array<T, N>& A)
43-
{
37+
/// @brief Overloaded stream insertion operator for std::array.
38+
/// @tparam T The type of the elements in the std::array.
39+
/// @tparam N The size of the std::array.
40+
/// @param os The output stream to write to.
41+
/// @param A The std::array to be printed.
42+
/// @return std::ostream& The output stream after writing the std::array.
43+
/// @note This function allows the elements of a std::array to be printed to an output stream.
44+
/// The elements are enclosed in square brackets and separated by commas.
45+
template <typename T, std::size_t N>
46+
requires std::is_arithmetic_v<T>
47+
inline std::ostream& operator<<(std::ostream& os, const std::array<T, N>& A) {
4448
os << "[";
45-
for (std::size_t i = 0; i < A.size(); ++i)
46-
{
47-
std::cout << A[i];
48-
if (i < A.size() - 1) [[likely]]
49-
{
50-
os << ", ";
51-
}
52-
else [[unlikely]]
53-
{
54-
continue;
49+
if constexpr (N > 1) {
50+
for (std::size_t i = 0; i < N - 1; i++) {
51+
os << A[i] << ", ";
5552
}
5653
}
57-
os << "]";
54+
os << A[N - 1] << "]";
5855
return os;
5956
}
6057

61-
template<typename T, std::size_t M, std::size_t N>
62-
requires std::is_arithmetic<T>::value
63-
INLINE std::ostream& operator<<(std::ostream& os, const std::array<std::array<T, N>, M>& A)
64-
{
58+
/// @brief Overloaded stream insertion operator for printing a 2D array (matrix) to an output stream.
59+
/// @tparam T The type of the elements in the matrix. Must be an arithmetic type.
60+
/// @tparam M The number of rows in the matrix.
61+
/// @tparam N The number of columns in the matrix.
62+
/// @param os The output stream to write the matrix to.
63+
/// @param A The 2D array (matrix) to be printed.
64+
/// @return The output stream after writing the matrix.
65+
/// @note This function prints the elements of a 2D array in a matrix format, enclosed in square brackets.
66+
/// Each row of the matrix is printed on a new line, separated by commas.
67+
template <typename T, std::size_t M, std::size_t N>
68+
requires std::is_arithmetic_v<T>
69+
inline std::ostream& operator<<(std::ostream& os, const std::array<std::array<T, N>, M>& A) {
6570
os << "[";
66-
for (std::size_t i = 0; i < M; ++i)
67-
{
68-
std::cout << A[i];
69-
if (i < M - 1) [[likely]]
70-
{
71-
os << ",\n ";
72-
}
73-
else [[unlikely]]
74-
{
75-
continue;
71+
if constexpr (M > 1) {
72+
for (std::size_t i = 0; i < M - 1; ++i) {
73+
os << A[i] << ",\n ";
7674
}
7775
}
78-
os << "]";
76+
os << A[M - 1] << "]";
7977
return os;
8078
}
8179

0 commit comments

Comments
 (0)