Skip to content

Commit 2713510

Browse files
authored
Merge pull request #308 from Hespe/fix-clone-warnings
Fix warnings in test suite
2 parents b92e3c9 + 286d09b commit 2713510

File tree

5 files changed

+7
-10
lines changed

5 files changed

+7
-10
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des
4747
- Fix plotting for segments that contain tensors with `require_grad=True` (see #288) (@hespe)
4848
- Fix bug where `Element.length` could not be set as a `torch.nn.Parameter` (see #301) (@jank324, @hespe)
4949
- Fix registration of `torch.nn.Parameter` at initilization for elements and beams (see #303) (@hespe)
50+
- Fix warnings about NumPy deprecations and unintentional tensor clones (see #308) (@hespe)
5051

5152
### 🐆 Other
5253

cheetah/accelerator/dipole.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Literal, Optional
22

33
import matplotlib.pyplot as plt
4-
import numpy as np
54
import torch
65
from matplotlib.patches import Rectangle
76
from scipy.constants import physical_constants
@@ -104,14 +103,14 @@ def __init__(
104103
if fringe_integral is not None:
105104
self.fringe_integral = torch.as_tensor(fringe_integral, **factory_kwargs)
106105
self.fringe_integral_exit = (
107-
torch.tensor(fringe_integral_exit, **factory_kwargs)
106+
torch.as_tensor(fringe_integral_exit, **factory_kwargs)
108107
if fringe_integral_exit is not None
109108
else self.fringe_integral
110109
)
111110
if gap is not None:
112111
self.gap = torch.as_tensor(gap, **factory_kwargs)
113112
self.gap_exit = (
114-
torch.tensor(gap_exit, **factory_kwargs)
113+
torch.as_tensor(gap_exit, **factory_kwargs)
115114
if gap_exit is not None
116115
else self.gap
117116
)
@@ -493,7 +492,7 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No
493492
plot_angle = self.angle[vector_idx] if self.angle.dim() > 0 else self.angle
494493

495494
alpha = 1 if self.is_active else 0.2
496-
height = 0.8 * (np.sign(plot_angle) if self.is_active else 1)
495+
height = 0.8 * (torch.sign(plot_angle) if self.is_active else 1)
497496

498497
patch = Rectangle(
499498
(plot_s, 0), plot_length, height, color="tab:green", alpha=alpha, zorder=2

cheetah/accelerator/horizontal_corrector.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Optional
22

33
import matplotlib.pyplot as plt
4-
import numpy as np
54
import torch
65
from matplotlib.patches import Rectangle
76

@@ -88,7 +87,7 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No
8887
plot_angle = self.angle[vector_idx] if self.angle.dim() > 0 else self.angle
8988

9089
alpha = 1 if self.is_active else 0.2
91-
height = 0.8 * (np.sign(plot_angle) if self.is_active else 1)
90+
height = 0.8 * (torch.sign(plot_angle) if self.is_active else 1)
9291

9392
patch = Rectangle(
9493
(plot_s, 0), plot_length, height, color="tab:blue", alpha=alpha, zorder=2

cheetah/accelerator/quadrupole.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Literal, Optional
22

33
import matplotlib.pyplot as plt
4-
import numpy as np
54
import torch
65
from matplotlib.patches import Rectangle
76
from scipy.constants import physical_constants
@@ -215,7 +214,7 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No
215214
plot_length = self.length[vector_idx] if self.length.dim() > 0 else self.length
216215

217216
alpha = 1 if self.is_active else 0.2
218-
height = 0.8 * (np.sign(plot_k1) if self.is_active else 1)
217+
height = 0.8 * (torch.sign(plot_k1) if self.is_active else 1)
219218
patch = Rectangle(
220219
(plot_s, 0), plot_length, height, color="tab:red", alpha=alpha, zorder=2
221220
)

cheetah/accelerator/vertical_corrector.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Optional
22

33
import matplotlib.pyplot as plt
4-
import numpy as np
54
import torch
65
from matplotlib.patches import Rectangle
76
from scipy.constants import physical_constants
@@ -91,7 +90,7 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No
9190
plot_angle = self.angle[vector_idx] if self.angle.dim() > 0 else self.angle
9291

9392
alpha = 1 if self.is_active else 0.2
94-
height = 0.8 * (np.sign(plot_angle) if self.is_active else 1)
93+
height = 0.8 * (torch.sign(plot_angle) if self.is_active else 1)
9594

9695
patch = Rectangle(
9796
(plot_s, 0), plot_length, height, color="tab:cyan", alpha=alpha, zorder=2

0 commit comments

Comments
 (0)