Skip to content

Commit 9d18cac

Browse files
committed
Optimize GNS core performance
1 parent 81690d2 commit 9d18cac

File tree

4 files changed

+127
-119
lines changed

4 files changed

+127
-119
lines changed

gns/learned_simulator.py

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,13 @@ def __init__(
7474

7575
self._device = device
7676

77+
# Optimized: Register boundary tensor as buffer for automatic device management
78+
# This avoids recreating the tensor on every call to _encoder_preprocessor
79+
self.register_buffer(
80+
'_boundary_tensor',
81+
torch.tensor(boundaries, dtype=torch.float32)
82+
)
83+
7784
def forward(self):
7885
"""Forward hook runs on class instantiation"""
7986
pass
@@ -95,12 +102,11 @@ def _compute_graph_connectivity(
95102
add_self_edges: Boolean flag to include self edge (default: True)
96103
"""
97104
# Specify examples id for particles
98-
batch_ids = torch.cat(
99-
[
100-
torch.LongTensor([i for _ in range(n)])
101-
for i, n in enumerate(nparticles_per_example)
102-
]
103-
).to(self._device)
105+
# Optimized: Use repeat_interleave instead of list comprehension + cat
106+
batch_ids = torch.repeat_interleave(
107+
torch.arange(len(nparticles_per_example), device=self._device, dtype=torch.long),
108+
nparticles_per_example
109+
)
104110

105111
# radius_graph accepts r < radius not r <= radius
106112
# A torch tensor list of source and target nodes with shape (2, nedges)
@@ -161,11 +167,9 @@ def _encoder_preprocessor(
161167
# Normalized clipped distances to lower and upper boundaries.
162168
# boundaries are an array of shape [num_dimensions, 2], where the second
163169
# axis, provides the lower/upper boundaries.
164-
boundaries = (
165-
torch.tensor(self._boundaries, requires_grad=False).float().to(self._device)
166-
)
167-
distance_to_lower_boundary = most_recent_position - boundaries[:, 0][None]
168-
distance_to_upper_boundary = boundaries[:, 1][None] - most_recent_position
170+
# Optimized: Use pre-computed boundary tensor buffer
171+
distance_to_lower_boundary = most_recent_position - self._boundary_tensor[:, 0][None]
172+
distance_to_upper_boundary = self._boundary_tensor[:, 1][None] - most_recent_position
169173
distance_to_boundaries = torch.cat(
170174
[distance_to_lower_boundary, distance_to_upper_boundary], dim=1
171175
)
@@ -193,28 +197,19 @@ def _encoder_preprocessor(
193197
# 31 = 10 (5 velocity sequences*dim) + 4 boundaries + 16 particle embedding + 1 material property
194198

195199
# Collect edge features.
196-
edge_features = []
197-
198-
# Relative displacement and distances normalized to radius
199-
# with shape (nedges, 2)
200-
# normalized_relative_displacements = (
201-
# torch.gather(most_recent_position, 0, senders) -
202-
# torch.gather(most_recent_position, 0, receivers)
203-
# ) / self._connectivity_radius
204-
normalized_relative_displacements = (
205-
most_recent_position[senders, :] - most_recent_position[receivers, :]
206-
) / self._connectivity_radius
207-
208-
# Add relative displacement between two particles as an edge feature
209-
# with shape (nparticles, ndim)
210-
edge_features.append(normalized_relative_displacements)
211-
212-
# Add relative distance between 2 particles with shape (nparticles, 1)
213-
# Edge features has a final shape of (nparticles, ndim + 1)
214-
normalized_relative_distances = torch.norm(
215-
normalized_relative_displacements, dim=-1, keepdim=True
216-
)
217-
edge_features.append(normalized_relative_distances)
200+
# Optimized: Compute displacement and distance together to reduce indexing operations
201+
sender_pos = most_recent_position[senders, :]
202+
receiver_pos = most_recent_position[receivers, :]
203+
relative_displacements = sender_pos - receiver_pos
204+
205+
# Compute the distance from the raw displacements; dividing by the
# (positive scalar) connectivity radius commutes with the norm, so
# normalizing afterwards is equivalent and the norm is taken only once
206+
relative_distances = torch.norm(relative_displacements, dim=-1, keepdim=True)
207+
208+
# Normalize both by connectivity radius
209+
normalized_relative_displacements = relative_displacements / self._connectivity_radius
210+
normalized_relative_distances = relative_distances / self._connectivity_radius
211+
212+
edge_features = [normalized_relative_displacements, normalized_relative_distances]
218213

219214
return (
220215
torch.cat(node_features, dim=-1),

gns/noise_utils.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,27 @@ def get_random_walk_noise_for_position_sequence(
2121
# so to keep `std_last_step` fixed, we apply at each step:
2222
# std_each_step = `std_last_step / np.sqrt(num_input_velocities)`
2323
num_velocities = velocity_sequence.shape[1]
24-
velocity_sequence_noise = torch.randn(list(velocity_sequence.shape)) * (
25-
noise_std_last_step / num_velocities**0.5
26-
)
24+
25+
# Optimized: Create noise directly on same device as input
26+
velocity_sequence_noise = torch.randn(
27+
velocity_sequence.shape,
28+
device=position_sequence.device,
29+
dtype=position_sequence.dtype
30+
) * (noise_std_last_step / num_velocities**0.5)
2731

2832
# Apply the random walk.
2933
velocity_sequence_noise = torch.cumsum(velocity_sequence_noise, dim=1)
3034

3135
# Integrate the noise in the velocity to the positions, assuming
32-
# an Euler intergrator and a dt = 1, and adding no noise to the very first
36+
# an Euler integrator and a dt = 1, and adding no noise to the very first
3337
# position (since that will only be used to calculate the first position
3438
# change).
35-
position_sequence_noise = torch.cat(
36-
[
37-
torch.zeros_like(velocity_sequence_noise[:, 0:1]),
38-
torch.cumsum(velocity_sequence_noise, dim=1),
39-
],
40-
dim=1,
39+
# Optimized: Pre-allocate on correct device
40+
position_sequence_noise = torch.zeros(
41+
(velocity_sequence.shape[0], velocity_sequence.shape[1] + 1, velocity_sequence.shape[2]),
42+
device=position_sequence.device,
43+
dtype=position_sequence.dtype
4144
)
45+
position_sequence_noise[:, 1:] = torch.cumsum(velocity_sequence_noise, dim=1)
4246

4347
return position_sequence_noise

gns/particle_data_loader.py

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -131,41 +131,63 @@ def get_num_features(self):
131131

132132

133133
def collate_fn_sample(batch):
134+
"""Optimized collation function with pre-allocation and minimal copies."""
134135
features, labels = zip(*batch)
135136

136-
position_list = []
137-
particle_type_list = []
138-
material_property_list = []
139-
n_particles_per_example_list = []
140-
141-
for feature in features:
142-
position_list.append(feature[0])
143-
particle_type_list.append(feature[1])
144-
if len(feature) == 4: # If material property is present
145-
material_property_list.append(feature[2])
146-
n_particles_per_example_list.append(feature[3])
137+
# Pre-calculate total particles to avoid reallocation
138+
total_particles = sum(f[0].shape[0] for f in features)
139+
batch_size = len(features)
140+
has_material = len(features[0]) == 4
141+
142+
# Get dimensions from first sample
143+
seq_len = features[0][0].shape[1]
144+
dim = features[0][0].shape[2]
145+
146+
# Pre-allocate tensors with pinned memory for faster GPU transfer
147+
positions = torch.empty((total_particles, seq_len, dim),
148+
dtype=torch.float32, pin_memory=True)
149+
particle_types = torch.empty(total_particles,
150+
dtype=torch.long, pin_memory=True)
151+
n_particles = torch.empty(batch_size,
152+
dtype=torch.long, pin_memory=True)
153+
154+
if has_material:
155+
materials = torch.empty(total_particles,
156+
dtype=torch.float32, pin_memory=True)
157+
158+
# Fill pre-allocated tensors (single copy from numpy)
159+
offset = 0
160+
for i, feature in enumerate(features):
161+
n_part = feature[0].shape[0]
162+
163+
# Direct numpy-to-torch copy
164+
positions[offset:offset+n_part] = torch.from_numpy(feature[0])
165+
particle_types[offset:offset+n_part] = torch.from_numpy(feature[1])
166+
167+
if has_material:
168+
materials[offset:offset+n_part] = torch.from_numpy(feature[2])
169+
n_particles[i] = feature[3]
147170
else:
148-
n_particles_per_example_list.append(feature[2])
149-
150-
collated_features = (
151-
torch.tensor(np.vstack(position_list)).to(torch.float32).contiguous(),
152-
torch.tensor(np.concatenate(particle_type_list)).contiguous(),
153-
torch.tensor(n_particles_per_example_list).contiguous(),
154-
)
155-
156-
if material_property_list:
157-
material_property_tensor = (
158-
torch.tensor(np.concatenate(material_property_list))
159-
.to(torch.float32)
160-
.contiguous()
161-
)
162-
collated_features = (
163-
collated_features[:2] + (material_property_tensor,) + collated_features[2:]
164-
)
171+
n_particles[i] = feature[2]
165172

166-
collated_labels = torch.tensor(np.vstack(labels)).to(torch.float32).contiguous()
173+
offset += n_part
167174

168-
return collated_features, collated_labels
175+
# Build output tuple
176+
if has_material:
177+
collated_features = (positions, particle_types, materials, n_particles)
178+
else:
179+
collated_features = (positions, particle_types, n_particles)
180+
181+
# Labels - same optimization
182+
labels_tensor = torch.empty((total_particles, dim),
183+
dtype=torch.float32, pin_memory=True)
184+
offset = 0
185+
for label in labels:
186+
n_part = label.shape[0]
187+
labels_tensor[offset:offset+n_part] = torch.from_numpy(label)
188+
offset += n_part
189+
190+
return collated_features, labels_tensor
169191

170192

171193
def collate_fn_trajectory(batch):

gns/train.py

Lines changed: 33 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,20 @@ def rollout(
5353
initial_positions = position[:, : cfg.data.input_sequence_length]
5454
ground_truth_positions = position[:, cfg.data.input_sequence_length :]
5555

56-
current_positions = initial_positions
57-
predictions = []
56+
current_positions = initial_positions.clone()
57+
58+
# Pre-allocate predictions tensor to avoid memory fragmentation
59+
n_particles = position.shape[0]
60+
dim = position.shape[-1]
61+
predictions = torch.zeros(
62+
(nsteps, n_particles, dim),
63+
device=device,
64+
dtype=position.dtype
65+
)
66+
67+
# Pre-compute kinematic mask once (static for entire rollout)
68+
kinematic_mask = (particle_types == cfg.data.kinematic_particle_id).bool()
69+
kinematic_mask_expanded = kinematic_mask[:, None].expand(-1, dim)
5870

5971
for step in tqdm(range(nsteps), total=nsteps):
6072
# Get next position with shape (nnodes, dim)
@@ -66,29 +78,17 @@ def rollout(
6678
)
6779

6880
# Update kinematic particles from prescribed trajectory.
69-
kinematic_mask = (
70-
(particle_types == cfg.data.kinematic_particle_id)
71-
.clone()
72-
.detach()
73-
.to(device)
74-
)
7581
next_position_ground_truth = ground_truth_positions[:, step]
76-
kinematic_mask = kinematic_mask.bool()[:, None].expand(
77-
-1, current_positions.shape[-1]
78-
)
7982
next_position = torch.where(
80-
kinematic_mask, next_position_ground_truth, next_position
83+
kinematic_mask_expanded, next_position_ground_truth, next_position
8184
)
82-
predictions.append(next_position)
8385

84-
# Shift `current_positions`, removing the oldest position in the sequence
85-
# and appending the next position at the end.
86-
current_positions = torch.cat(
87-
[current_positions[:, 1:], next_position[:, None, :]], dim=1
88-
)
86+
# Store prediction in pre-allocated tensor
87+
predictions[step] = next_position
8988

90-
# Predictions with shape (time, nnodes, dim)
91-
predictions = torch.stack(predictions)
89+
# Shift `current_positions` in-place
90+
current_positions[:, :-1] = current_positions[:, 1:].clone()
91+
current_positions[:, -1] = next_position
9292
ground_truth_positions = ground_truth_positions.permute(1, 0, 2)
9393

9494
loss = (predictions - ground_truth_positions) ** 2
@@ -577,41 +577,28 @@ def train(rank, cfg, world_size, device, verbose, use_dist):
577577
labels,
578578
) = prepare_data(example, device_id)
579579

580-
n_particles_per_example = n_particles_per_example.to(device_id)
581-
labels = labels.to(device_id)
582-
583-
sampled_noise = (
584-
noise_utils.get_random_walk_noise_for_position_sequence(
585-
position, noise_std_last_step=cfg.data.noise_std
586-
).to(device_id)
587-
)
588-
non_kinematic_mask = (
589-
(particle_type != cfg.data.kinematic_particle_id)
590-
.clone()
591-
.detach()
592-
.to(device_id)
580+
# Optimized: Data already on device_id from prepare_data, no need to transfer again
581+
# Noise is now created directly on correct device (see noise_utils.py optimization)
582+
sampled_noise = noise_utils.get_random_walk_noise_for_position_sequence(
583+
position, noise_std_last_step=cfg.data.noise_std
593584
)
585+
# Optimized: Comparison already creates new tensor, no need for clone/detach
586+
non_kinematic_mask = (particle_type != cfg.data.kinematic_particle_id)
594587
sampled_noise *= non_kinematic_mask.view(-1, 1, 1)
595588

596-
device_or_rank = rank if device == torch.device("cuda") else device
597589
predict_fn = (
598590
simulator.module.predict_accelerations
599591
if use_dist
600592
else simulator.predict_accelerations
601593
)
594+
# Optimized: All tensors already on correct device, no transfers needed
602595
pred_acc, target_acc = predict_fn(
603-
next_positions=labels.to(device_or_rank),
604-
position_sequence_noise=sampled_noise.to(device_or_rank),
605-
position_sequence=position.to(device_or_rank),
606-
nparticles_per_example=n_particles_per_example.to(
607-
device_or_rank
608-
),
609-
particle_types=particle_type.to(device_or_rank),
610-
material_property=(
611-
material_property.to(device_or_rank)
612-
if n_features == 3
613-
else None
614-
),
596+
next_positions=labels,
597+
position_sequence_noise=sampled_noise,
598+
position_sequence=position,
599+
nparticles_per_example=n_particles_per_example,
600+
particle_types=particle_type,
601+
material_property=material_property if n_features == 3 else None,
615602
)
616603

617604
if (

0 commit comments

Comments
 (0)