Skip to content

Commit d2216f7

Browse files
committed
import and typing improvements
1 parent 9ef8e40 commit d2216f7

File tree

9 files changed

+124
-39
lines changed

9 files changed

+124
-39
lines changed

src/torchlensmaker/__init__.py

Lines changed: 102 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,51 +14,130 @@
1414
# You should have received a copy of the GNU General Public License
1515
# along with this program. If not, see <https://www.gnu.org/licenses/>.
1616

17+
# ruff: noqa: F401
18+
19+
######
1720
# Core
18-
from torchlensmaker.core.physics import *
19-
from torchlensmaker.core.outline import *
20-
from torchlensmaker.surfaces.sphere_r import *
21-
from torchlensmaker.core.sag_functions import *
22-
from torchlensmaker.core.transforms import *
21+
######
22+
23+
from torchlensmaker.core.physics import reflection, refraction
24+
from torchlensmaker.core.transforms import (
25+
TransformBase,
26+
LinearTransform,
27+
ComposeTransform,
28+
TranslateTransform,
29+
IdentityTransform,
30+
)
2331
from torchlensmaker.core.intersect import *
2432
from torchlensmaker.core.full_forward import *
2533
from torchlensmaker.core.collision_detection import *
2634
from torchlensmaker.core.parameter import *
2735
from torchlensmaker.core.geometry import *
36+
from torchlensmaker.core.sag_functions import (
37+
Spherical,
38+
Parabolic,
39+
Aspheric,
40+
XYPolynomial,
41+
Conical,
42+
SagSum,
43+
SagFunction,
44+
)
45+
# from torchlensmaker.core.outline import *
2846

47+
##########
2948
# Surfaces
30-
from torchlensmaker.surfaces.conics import *
31-
from torchlensmaker.surfaces.sphere_r import *
32-
from torchlensmaker.surfaces.implicit_surface import *
33-
from torchlensmaker.surfaces.local_surface import *
34-
from torchlensmaker.surfaces.plane import *
35-
from torchlensmaker.surfaces.implicit_cylinder import *
36-
37-
# Optics
49+
##########
50+
51+
from torchlensmaker.surfaces.sphere_r import SphereR
52+
from torchlensmaker.surfaces.conics import Sphere, Parabola, Conic, Asphere
53+
from torchlensmaker.surfaces.implicit_surface import ImplicitSurface
54+
from torchlensmaker.surfaces.local_surface import LocalSurface
55+
from torchlensmaker.surfaces.plane import Plane, CircularPlane, SquarePlane
56+
from torchlensmaker.surfaces.implicit_cylinder import ImplicitCylinder
57+
from torchlensmaker.surfaces.sag_surface import SagSurface
58+
59+
##################
60+
# Optical elements
61+
##################
62+
3863
from torchlensmaker.elements.sequential import Sequential
39-
from torchlensmaker.elements.utils import *
40-
from torchlensmaker.elements.kinematics import *
41-
from torchlensmaker.elements.optical_surfaces import *
42-
from torchlensmaker.elements.light_sources import *
43-
44-
from torchlensmaker.optical_data import *
45-
from torchlensmaker.lenses import *
46-
from torchlensmaker.materials import *
64+
from torchlensmaker.elements.utils import MixedDim
65+
from torchlensmaker.elements.kinematics import (
66+
SubChain,
67+
AbsoluteTransform,
68+
RelativeTransform,
69+
Gap,
70+
Rotate3D,
71+
Rotate2D,
72+
Translate2D,
73+
Translate3D,
74+
Rotate,
75+
Translate,
76+
)
77+
from torchlensmaker.elements.optical_surfaces import (
78+
CollisionSurface,
79+
ReflectiveSurface,
80+
RefractiveSurface,
81+
Aperture,
82+
FocalPoint,
83+
ImagePlane,
84+
linear_magnification,
85+
)
86+
from torchlensmaker.elements.light_sources import (
87+
LightSourceBase,
88+
RaySource,
89+
PointSourceAtInfinity,
90+
PointSource,
91+
ObjectAtInfinity,
92+
Object,
93+
Wavelength,
94+
)
95+
96+
# Top level stuff - to be reorganized
97+
from torchlensmaker.optical_data import OpticalData, default_input
98+
from torchlensmaker.lenses import LensBase, BiLens, Lens, PlanoLens
99+
from torchlensmaker.materials import (
100+
MaterialModel,
101+
NonDispersiveMaterial,
102+
CauchyMaterial,
103+
SellmeierMaterial,
104+
)
105+
106+
##########
107+
# Sampling
108+
##########
109+
47110
from torchlensmaker.sampling import *
48111

112+
##############
49113
# Optimization
50-
from torchlensmaker.optimize import *
114+
##############
115+
116+
import torch.optim as optim
117+
from torchlensmaker.optimize import (
118+
optimize,
119+
OptimizationRecord,
120+
plot_optimization_record,
121+
)
51122

123+
########
52124
# Viewer
125+
########
126+
53127
import torchlensmaker.viewer.tlmviewer as viewer
54128
from torchlensmaker.viewer.render_sequence import *
55129

56-
130+
##########
57131
# Analysis
132+
##########
133+
58134
from torchlensmaker.analysis.plot_magnification import plot_magnification
59135
from torchlensmaker.analysis.plot_material_model import plot_material_models
60136
from torchlensmaker.analysis.spot_diagram import spot_diagram
61137

138+
##################
62139
# Export build123d
140+
##################
141+
63142
import torchlensmaker.export_build123d as export
64143
from torchlensmaker.export_build123d import show_part

src/torchlensmaker/core/collision_detection.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,10 @@
1717
from __future__ import annotations
1818

1919
import torch
20-
import math
21-
from functools import partial
2220

2321
from dataclasses import dataclass
2422

25-
from typing import TYPE_CHECKING, Optional, Callable, Any
23+
from typing import TYPE_CHECKING, Optional, Any
2624

2725
if TYPE_CHECKING:
2826
from torchlensmaker.surfaces.implicit_cylinder import ImplicitSurface

src/torchlensmaker/core/tensorframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def get(self, names: str | list[str]) -> torch.Tensor:
5050
else:
5151
idx = [self.columns.index(n) for n in names]
5252
return self.data[:, idx]
53-
except:
53+
except ValueError:
5454
raise KeyError(f"TensorFrame doesn't have column(s): {names}")
5555

5656
def masked(self, mask: torch.Tensor) -> TensorFrame:

src/torchlensmaker/core/transforms.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
import functools
1919
from collections.abc import Sequence
2020

21-
from torchlensmaker.core.rot2d import rotation_matrix_2D
22-
from torchlensmaker.core.rot3d import euler_angles_to_matrix
2321

2422
# for shorter type annotations
2523
Tensor = torch.Tensor

src/torchlensmaker/elements/utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
import torch.nn as nn
1818
from torchlensmaker.optical_data import OpticalData
1919

20-
from typing import Any
21-
2220

2321
class Marker(nn.Module):
2422
"WIP"

src/torchlensmaker/export_build123d.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from torchlensmaker.elements.sequential import Sequential
2222

2323
import os
24-
import torch.nn as nn
2524
import build123d as bd
2625
from os.path import join
2726

src/torchlensmaker/sampling/samplers.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,10 @@
1818
import torch.nn as nn
1919
import numpy as np
2020

21-
from typing import Any, Sequence
22-
23-
Tensor = torch.Tensor
24-
2521
from torchlensmaker.core.tensor_manip import to_tensor
22+
from typing import Any, Sequence, TypeAlias
23+
24+
Tensor: TypeAlias = torch.Tensor
2625

2726

2827
class Sampler:

src/torchlensmaker/testing/check_local_collide.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ def check_local_collide(
2424
) -> None:
2525
"Call surface.local_collide() and performs tests on the output"
2626

27-
dim = P.shape[-1]
28-
2927
# Check that rays are the correct dtype
3028
assert P.dtype == surface.dtype
3129
assert V.dtype == surface.dtype
@@ -63,7 +61,6 @@ def check_local_collide(
6361
)
6462

6563
if isinstance(surface, tlm.ImplicitSurface):
66-
N = sum(P.shape[:-1])
6764
rmse = surface.rmse(local_points)
6865
else:
6966
rmse = None

tests/test_tensorframe.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
import torch
23
from torchlensmaker.core.tensorframe import TensorFrame
34

@@ -89,3 +90,19 @@ def test_stack() -> None:
8990

9091
assert torch.all(tf3.data == torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4]]))
9192
assert tf3.columns == tf1.columns == tf2.columns
93+
94+
95+
def test_missing_column() -> None:
96+
N = 2
97+
tf1 = TensorFrame(
98+
torch.column_stack(
99+
(
100+
torch.full((N,), 1),
101+
torch.full((N,), 2),
102+
)
103+
),
104+
["a", "b"],
105+
)
106+
107+
with pytest.raises(KeyError):
108+
tf1.get("c")

0 commit comments

Comments
 (0)