Skip to content

Commit b95c19c

Browse files
Add GridNodeIndex and GridData for MPM (#22295)
1 parent fe6caf8 commit b95c19c

File tree

3 files changed

+300
-0
lines changed

3 files changed

+300
-0
lines changed

multibody/mpm/BUILD.bazel

+19
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ drake_cc_package_library(
1515
visibility = ["//visibility:public"],
1616
deps = [
1717
":bspline_weights",
18+
":grid_data",
1819
],
1920
)
2021

@@ -32,6 +33,17 @@ drake_cc_library(
3233
],
3334
)
3435

36+
drake_cc_library(
37+
name = "grid_data",
38+
hdrs = [
39+
"grid_data.h",
40+
],
41+
deps = [
42+
"//common:bit_cast",
43+
"//common:essential",
44+
],
45+
)
46+
3547
# TODO(xuchenhan-tri): when we enable SPGrid in our releases, we also need to
3648
# install its license file in drake/tools/workspace/BUILD.bazel.
3749

@@ -56,6 +68,13 @@ drake_cc_googletest(
5668
],
5769
)
5870

71+
drake_cc_googletest(
72+
name = "grid_data_test",
73+
deps = [
74+
":grid_data",
75+
],
76+
)
77+
5978
drake_cc_googletest_linux_only(
6079
name = "spgrid_test",
6180
deps = [

multibody/mpm/grid_data.h

+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
5+
#include "drake/common/bit_cast.h"
6+
#include "drake/common/drake_assert.h"
7+
#include "drake/common/drake_copyable.h"
8+
#include "drake/common/eigen_types.h"
9+
10+
namespace drake {
11+
namespace multibody {
12+
namespace mpm {
13+
namespace internal {
14+
15+
/* This class is a lightweight wrapper around an integer type (`int32_t` or
16+
`int64_t`) used to differentiate between active grid node indices, inactive
17+
states, and a special flag. An IndexOrFlag can be in exactly one of the
18+
following states:
19+
20+
1. Active index: A non-negative integer representing the index of a grid node.
21+
2. Inactive state (the default state): neither an index nor a flag.
22+
3. The `flag` state: A special state used to mark grid nodes for deferred
23+
processing. The flag state is intended as a marker that must only be set
24+
from the inactive state (not from an active index) to maintain consistency.
25+
26+
Transitions between states are as follows:
27+
- Any state can become inactive.
28+
- The inactive state can become flag state.
29+
- The inactive state can become active (with a non-negative index).
30+
- Active cannot directly become flag state (must go inactive first).
31+
32+
An IndexOrFlag object is guaranteed to have size equal to its template
33+
parameter T. In addition, the inactive state is guaranteed to be represented
34+
with the value -1.
35+
36+
@tparam T The integer type for the index. Must be `int32_t` or `int64_t`. */
37+
template <typename T>
38+
class IndexOrFlag {
39+
public:
40+
DRAKE_DEFAULT_COPY_AND_MOVE_AND_ASSIGN(IndexOrFlag);
41+
42+
static_assert(std::is_same_v<T, int32_t> || std::is_same_v<T, int64_t>,
43+
"T must be int32_t or int64_t.");
44+
45+
/* Default constructor initializes the object to the inactive state. */
46+
constexpr IndexOrFlag() = default;
47+
48+
/* Constructor for an active index.
49+
@pre index >= 0 */
50+
explicit constexpr IndexOrFlag(T index) { set_index(index); }
51+
52+
/* Sets the index to the given value, which must be non-negative, thereby
53+
making `this` active.
54+
@pre index >= 0 */
55+
void set_index(T index) {
56+
DRAKE_ASSERT(index >= 0);
57+
value_ = index;
58+
}
59+
60+
/* Sets `this` to the inactive state. */
61+
void set_inactive() { value_ = kInactive; }
62+
63+
/* Sets `this` to the flag state.
64+
@pre !is_index() (i.e., must currently be inactive) */
65+
void set_flag() {
66+
DRAKE_ASSERT(!is_index());
67+
value_ = kFlag;
68+
}
69+
70+
/* Returns true iff `this` is an active index (i.e., a non-negative integer).
71+
*/
72+
constexpr bool is_index() const { return value_ >= 0; }
73+
74+
/* Returns true iff `this` is in the inactive state. */
75+
constexpr bool is_inactive() const { return value_ == kInactive; }
76+
77+
/* Returns true iff `this` is in the flag state. */
78+
constexpr bool is_flag() const { return value_ == kFlag; }
79+
80+
/* Returns the index value.
81+
@pre is_index() == true; */
82+
constexpr T index() const {
83+
DRAKE_ASSERT(is_index());
84+
return value_;
85+
}
86+
87+
private:
88+
/* Note that the enum values are not arbitrary; kInactive must be -1 as in the
89+
class documentation. */
90+
enum : T { kInactive = -1, kFlag = -2 };
91+
/* We encode all information in `value_` to satisfy the size requirement laid
92+
out in the class documentation. */
93+
T value_{kInactive};
94+
};
95+
96+
/* GridData stores data at a single grid node of SparseGrid.
97+
98+
It contains the mass and velocity of the node along with a scratch space for
99+
temporary storage and an index to the node.
100+
101+
It's important to be conscious of the size of GridData since the MPM algorithm
102+
is usually memory-bound. We carefully pack GridData to be a power of 2 to work
103+
with SPGrid, which automatically packs the data to a power of 2.
104+
105+
@tparam T double or float. */
106+
template <typename T>
107+
struct GridData {
108+
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
109+
"T must be float or double.");
110+
111+
/* Resets `this` GridData to its default state where all floating point values
112+
are set to NAN and the index is inactive. */
113+
void reset() { *this = {}; }
114+
115+
/* Returns true iff `this` GridData is bit-wise equal to `other`. */
116+
bool operator==(const GridData<T>& other) const {
117+
return std::memcmp(this, &other, sizeof(GridData<T>)) == 0;
118+
}
119+
120+
Vector3<T> v{Vector3<T>::Constant(nan_with_all_bits_set())};
121+
T m{nan_with_all_bits_set()};
122+
Vector3<T> scratch{Vector3<T>::Constant(nan_with_all_bits_set())};
123+
std::conditional_t<std::is_same_v<T, float>, IndexOrFlag<int32_t>,
124+
IndexOrFlag<int64_t>>
125+
index_or_flag{};
126+
127+
private:
128+
/* Returns a floating point NaN value with all bits set to one. This choice
129+
makes the reset() function more efficient. In particlar, it allows the
130+
generated machine code to memset all bits to 1 instead of handling each field
131+
individually. */
132+
static T nan_with_all_bits_set() {
133+
using IntType =
134+
std::conditional_t<std::is_same_v<T, float>, int32_t, int64_t>;
135+
constexpr IntType kAllBitsOn = -1;
136+
return drake::internal::bit_cast<T>(kAllBitsOn);
137+
}
138+
};
139+
140+
/* With T = float, GridData is expected to be 32 bytes. With T = double,
141+
GridData is expected to be 64 bytes. We enforce these sizes at compile time
142+
with static_assert, so that if future changes to this code, compiler alignment,
143+
or Eigen alignment rules cause a size shift, it will be caught early. */
144+
static_assert(sizeof(GridData<float>) == 32,
145+
"Unexpected size for GridData<float>.");
146+
static_assert(sizeof(GridData<double>) == 64,
147+
"Unexpected size for GridData<double>.");
148+
149+
} // namespace internal
150+
} // namespace mpm
151+
} // namespace multibody
152+
} // namespace drake

multibody/mpm/test/grid_data_test.cc

+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#include "drake/multibody/mpm/grid_data.h"
2+
3+
#include <gtest/gtest.h>
4+
5+
namespace drake {
6+
namespace multibody {
7+
namespace mpm {
8+
namespace internal {
9+
namespace {
10+
11+
using IndexTypes = ::testing::Types<int32_t, int64_t>;
12+
13+
template <typename T>
14+
class IndexOrFlagTest : public ::testing::Test {};
15+
16+
TYPED_TEST_SUITE(IndexOrFlagTest, IndexTypes);
17+
18+
TYPED_TEST(IndexOrFlagTest, Basic) {
19+
using T = TypeParam;
20+
IndexOrFlag<T> dut;
21+
EXPECT_FALSE(dut.is_index());
22+
EXPECT_FALSE(dut.is_flag());
23+
EXPECT_TRUE(dut.is_inactive());
24+
25+
dut.set_index(123);
26+
EXPECT_EQ(dut.index(), 123);
27+
dut.set_inactive();
28+
EXPECT_TRUE(dut.is_inactive());
29+
dut.set_flag();
30+
EXPECT_TRUE(dut.is_flag());
31+
/* Setting flag twice is allowed. */
32+
dut.set_flag();
33+
EXPECT_TRUE(dut.is_flag());
34+
}
35+
36+
TYPED_TEST(IndexOrFlagTest, StateTransition) {
37+
using T = TypeParam;
38+
IndexOrFlag<T> dut(123);
39+
EXPECT_TRUE(dut.is_index());
40+
EXPECT_FALSE(dut.is_flag());
41+
EXPECT_FALSE(dut.is_inactive());
42+
43+
/* Active -> Inactive */
44+
dut.set_inactive();
45+
EXPECT_FALSE(dut.is_index());
46+
EXPECT_FALSE(dut.is_flag());
47+
EXPECT_TRUE(dut.is_inactive());
48+
49+
/* Inactive -> Flag */
50+
dut.set_flag();
51+
EXPECT_FALSE(dut.is_index());
52+
EXPECT_TRUE(dut.is_flag());
53+
EXPECT_FALSE(dut.is_inactive());
54+
55+
/* Flag -> Inactive */
56+
dut.set_inactive();
57+
EXPECT_FALSE(dut.is_index());
58+
EXPECT_FALSE(dut.is_flag());
59+
EXPECT_TRUE(dut.is_inactive());
60+
61+
/* Inactive -> Active */
62+
dut.set_index(123);
63+
EXPECT_TRUE(dut.is_index());
64+
EXPECT_FALSE(dut.is_flag());
65+
EXPECT_FALSE(dut.is_inactive());
66+
EXPECT_EQ(dut.index(), 123);
67+
68+
/* Additional scenario: Flag -> Active */
69+
IndexOrFlag<T> another_dut;
70+
another_dut.set_flag();
71+
another_dut.set_index(123);
72+
EXPECT_TRUE(another_dut.is_index());
73+
EXPECT_EQ(another_dut.index(), 123);
74+
EXPECT_FALSE(another_dut.is_flag());
75+
EXPECT_FALSE(another_dut.is_inactive());
76+
}
77+
78+
using FloatingPointTypes = ::testing::Types<float, double>;
79+
80+
template <typename T>
81+
class GridDataTest : public ::testing::Test {};
82+
83+
TYPED_TEST_SUITE(GridDataTest, FloatingPointTypes);
84+
85+
TYPED_TEST(GridDataTest, Reset) {
86+
using T = TypeParam;
87+
GridData<T> data;
88+
data.index_or_flag.set_index(123);
89+
data.scratch = Vector3<T>::Ones();
90+
data.v = Vector3<T>::Ones();
91+
data.m = 1;
92+
93+
data.reset();
94+
EXPECT_TRUE(data.index_or_flag.is_inactive());
95+
EXPECT_NE(data.scratch, data.scratch);
96+
EXPECT_NE(data.v, data.v);
97+
EXPECT_TRUE(std::isnan(data.m));
98+
}
99+
100+
TYPED_TEST(GridDataTest, Equality) {
101+
using T = TypeParam;
102+
GridData<T> data1;
103+
data1.index_or_flag.set_index(123);
104+
data1.scratch = Vector3<T>::Ones();
105+
data1.v = Vector3<T>::Ones();
106+
data1.m = 1;
107+
108+
GridData<T> data2;
109+
data2.index_or_flag.set_index(123);
110+
data2.scratch = Vector3<T>::Ones();
111+
data2.v = Vector3<T>::Ones();
112+
data2.m = 1;
113+
114+
EXPECT_EQ(data1, data2);
115+
data2.index_or_flag.set_inactive();
116+
EXPECT_NE(data1, data2);
117+
data1.index_or_flag.set_inactive();
118+
EXPECT_EQ(data1, data2);
119+
data1.reset();
120+
EXPECT_NE(data1, data2);
121+
data2.reset();
122+
EXPECT_EQ(data1, data2);
123+
}
124+
125+
} // namespace
126+
} // namespace internal
127+
} // namespace mpm
128+
} // namespace multibody
129+
} // namespace drake

0 commit comments

Comments
 (0)