Skip to content

Commit f93d017

Browse files
committed
add types standalone module
1 parent 504c1ed commit f93d017

File tree

6 files changed

+52
-10
lines changed

6 files changed

+52
-10
lines changed

src/torchlensmaker/__init__.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,21 @@
1818

1919
import torch
2020

21+
#######
22+
# Types
23+
#######
24+
25+
from torchlensmaker.types import (
26+
ScalarTensor,
27+
BatchTensor,
28+
Batch2DTensor,
29+
Batch3DTensor,
30+
BatchNDTensor,
31+
HomMatrix2D,
32+
HomMatrix3D,
33+
HomMatrix,
34+
)
35+
2136
######
2237
# Core
2338
######
@@ -51,9 +66,6 @@
5166
############
5267

5368
from torchlensmaker.kinematics.homogeneous_geometry import (
54-
HomMatrix2D,
55-
HomMatrix3D,
56-
HomMatrix,
5769
transform_points,
5870
transform_vectors,
5971
hom_identity,

src/torchlensmaker/kinematics/homogeneous_geometry.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@
2222

2323
from torchlensmaker.core.rot3d import euler_angles_to_matrix
2424

25-
HomMatrix2D: TypeAlias = Float[torch.Tensor, "3 3"]
26-
HomMatrix3D: TypeAlias = Float[torch.Tensor, "4 4"]
27-
HomMatrix: TypeAlias = HomMatrix2D | HomMatrix3D
25+
from torchlensmaker.types import HomMatrix2D, HomMatrix3D, HomMatrix
2826

2927

3028
def hom_matrix_2d(M: Float[torch.Tensor, "2 2"]) -> HomMatrix2D:

src/torchlensmaker/kinematics/kinematics_elements.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from torchlensmaker.core.tensor_manip import to_tensor, init_param, expand_bool_tuple
2424
from torchlensmaker.optical_data import OpticalData
2525

26-
from .homogeneous_geometry import (
26+
from torchlensmaker.types import (
2727
HomMatrix2D,
2828
HomMatrix3D,
2929
HomMatrix,

src/torchlensmaker/kinematics/kinematics_kernels.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
import torch
1818
from jaxtyping import Float
1919

20-
from .homogeneous_geometry import (
20+
from torchlensmaker.types import (
2121
HomMatrix2D,
2222
HomMatrix3D,
23+
)
24+
25+
from .homogeneous_geometry import (
2326
hom_rotate_2d,
2427
hom_rotate_3d,
2528
hom_translate_2d,

src/torchlensmaker/kinematics/tests/test_kinematics_elements.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,11 @@
2121
import torch
2222
import torch.nn as nn
2323

24+
from torchlensmaker.types import HomMatrix2D, HomMatrix3D
2425

2526
from torchlensmaker.kinematics.homogeneous_geometry import (
2627
hom_identity_2d,
2728
hom_identity_3d,
28-
HomMatrix2D,
29-
HomMatrix3D,
3029
)
3130

3231
from torchlensmaker.kinematics.kinematics_elements import (

src/torchlensmaker/types.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# This file is part of Torch Lens Maker
2+
# Copyright (C) 2024-present Victor Poughon
3+
#
4+
# This program is free software: you can redistribute it and/or modify
5+
# it under the terms of the GNU General Public License as published by
6+
# the Free Software Foundation, either version 3 of the License, or
7+
# (at your option) any later version.
8+
#
9+
# This program is distributed in the hope that it will be useful,
10+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
# GNU General Public License for more details.
13+
#
14+
# You should have received a copy of the GNU General Public License
15+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
16+
17+
18+
from typing import TypeAlias
19+
from jaxtyping import Float
20+
import torch
21+
22+
ScalarTensor: TypeAlias = Float[torch.Tensor, ""]
23+
BatchTensor: TypeAlias = Float[torch.Tensor, "..."]
24+
Batch2DTensor: TypeAlias = Float[torch.Tensor, "... 2"]
25+
Batch3DTensor: TypeAlias = Float[torch.Tensor, "... 3"]
26+
BatchNDTensor: TypeAlias = Float[torch.Tensor, "... D"]
27+
28+
HomMatrix2D: TypeAlias = Float[torch.Tensor, "3 3"]
29+
HomMatrix3D: TypeAlias = Float[torch.Tensor, "4 4"]
30+
HomMatrix: TypeAlias = HomMatrix2D | HomMatrix3D

0 commit comments

Comments
 (0)