Skip to content

Commit 0054d1a

Browse files
committed
fix some typings
1 parent 5e6d621 commit 0054d1a

File tree

6 files changed

+81
-76
lines changed

6 files changed

+81
-76
lines changed

openglider/glider/parametric/table/attachment_points.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from openglider.glider.cell.attachment_point import CellAttachmentPoint
1212
from openglider.glider.cell.cell import Cell
13-
from openglider.glider.curve import GliderCurveType
1413
from openglider.glider.parametric.table.base import CellTable, Keyword, RibTable, dto
1514
from openglider.glider.parametric.table.base.parser import Parser
1615
from openglider.glider.rib.attachment_point import AttachmentPoint
@@ -25,7 +24,7 @@
2524

2625
logger = logging.getLogger(__name__)
2726

28-
class ATP(dto.DTO):
27+
class ATP(dto.DTO[AttachmentPoint]):
2928
name: str
3029
rib_pos: Percentage
3130
force: float | euklid.vector.Vector3D
@@ -92,7 +91,7 @@ class AttachmentPointTable(RibTable):
9291
#"AHP": ATP,
9392
}
9493

95-
def get_element(self, row: int, keyword: str, data: list[Any], resolvers: list[Parser]=None, rib: Rib=None, **kwargs: Any) -> AttachmentPoint:
94+
def get_element(self, row: int, keyword: str, data: list[Any], resolvers: list[Parser] | None=None, rib: Rib | None=None, **kwargs: Any) -> AttachmentPoint:
9695
# rib_no, rib_pos, cell_pos, force, name, is_cell
9796
force = data[2]
9897

@@ -123,7 +122,8 @@ def update_columns(keyword: str, data_length: int, force_position: int) -> None:
123122
if name in forces:
124123
force = forces[name]
125124
try:
126-
if isinstance(force, float):
125+
# TODO: why?
126+
if isinstance(force, (float, int)):
127127
raise TypeError()
128128
column[row, force_position] = str(list(force))
129129
except TypeError:
@@ -157,7 +157,7 @@ class CellAttachmentPointTable(CellTable):
157157
"ATPDIFF": Keyword([("name", str), ("cell_pos", float), ("rib_pos", float), ("force", Union[float, str]), ("offset", float)], target_cls=CellAttachmentPoint)
158158
}
159159

160-
def get_element(self, row: int, keyword: str, data: list[Any], resolvers: list[Parser], cell: Cell=None, **kwargs: Any) -> CellAttachmentPoint:
160+
def get_element(self, row: int, keyword: str, data: list[Any], resolvers: list[Parser], cell: Cell | None=None, **kwargs: Any) -> CellAttachmentPoint: # type: ignore
161161
force = data[3]
162162

163163
if isinstance(force, str):
@@ -171,7 +171,8 @@ def get_element(self, row: int, keyword: str, data: list[Any], resolvers: list[P
171171
if len(data) > 4:
172172
offset = resolvers[row].parse(data[4])
173173

174-
node.offset = offset
174+
if offset is not None:
175+
node.offset = Length(offset)
175176

176177
return node
177178

openglider/glider/parametric/table/base/dto.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,22 @@ def from_value(cls, value: TupleType) -> Self:
3737
@pydantic.model_validator(mode="before")
3838
@classmethod
3939
def _validate(cls, v: Any) -> dict[str, Any] | Self:
40-
if isinstance(v, tuple) and len(v) == 2:
41-
return {
42-
"first": v[0],
43-
"second": v[1]
44-
}
45-
else:
46-
return {
47-
"first": v,
48-
"second": v
49-
}
50-
40+
if isinstance(v, tuple):
41+
v_tuple = typing.cast(tuple[Any, Any], v)
42+
if len(v_tuple) == 2:
43+
return {
44+
"first": v_tuple[0],
45+
"second": v_tuple[1]
46+
}
47+
return {
48+
"first": v,
49+
"second": v
50+
}
5151
class SingleCellTuple(CellTuple[TupleType], Generic[TupleType]):
5252
index_offset: ClassVar[tuple[int, int]] = (0, 0)
5353

5454

55-
_type_cache: dict[type[DTO], list[tuple[str, str]]] = {}
55+
_type_cache: dict[type[DTO[Any]], list[tuple[str, str]]] = {}
5656

5757
class DTO(BaseModel, Generic[ReturnType], abc.ABC):
5858
model_config = pydantic.ConfigDict(
@@ -66,12 +66,12 @@ def get_object(self) -> ReturnType:
6666
raise NotImplementedError
6767

6868
@staticmethod
69-
def _get_type_string(type_: type | None) -> str:
69+
def _get_type_string(type_: type | types.UnionType | None) -> str:
7070
assert type_ is not None
7171

7272
if isinstance(type_, types.UnionType):
73-
names = []
74-
for subtype in type_.__args__:
73+
names: list[str] = []
74+
for subtype in typing.get_args(type_):
7575
names.append(subtype.__name__)
7676

7777
return " | ".join(names)
@@ -87,10 +87,10 @@ def _get_type_string(type_: type | None) -> str:
8787
return type_.__name__
8888

8989
@staticmethod
90-
def _is_cell_tuple(type: Any) -> CellTuple | None:
90+
def check_is_cell_tuple(type_: Any) -> type[CellTuple[Any]] | None:
9191
try:
92-
if issubclass(type, CellTuple):
93-
return type
92+
if isinstance(type_, type) and issubclass(type_, CellTuple):
93+
return typing.cast(type[CellTuple[Any]], type_)
9494
except TypeError:
9595
pass
9696

@@ -99,12 +99,12 @@ def _is_cell_tuple(type: Any) -> CellTuple | None:
9999
@classmethod
100100
def describe(cls) -> list[tuple[str, str]]:
101101
if cls not in _type_cache:
102-
result = []
102+
result: list[tuple[str, str]] = []
103103
for field_name, field in cls.model_fields.items():
104-
is_cell_tuple = cls._is_cell_tuple(field.annotation)
104+
is_cell_tuple = cls.check_is_cell_tuple(field.annotation)
105105

106106
if is_cell_tuple:
107-
inner_type = is_cell_tuple.__fields__["first"].annotation
107+
inner_type = is_cell_tuple.model_fields["first"].annotation
108108
inner_type_str = cls._get_type_string(inner_type)
109109

110110
if sum(is_cell_tuple.index_offset) > 0:

openglider/glider/parametric/table/base/table.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import enum
22
import logging
3-
import sys
43
import typing
54
from typing import Any, Generic, TypeVar
65

@@ -24,10 +23,10 @@ class TableType(enum.Enum):
2423

2524
class ElementTable(Generic[ElementType]):
2625
table_type: TableType = TableType.general
27-
keywords: dict[str, Keyword] = {}
28-
dtos: dict[str, type[DTO]] = {}
26+
keywords: dict[str, Keyword[Any]] = {}
27+
dtos: dict[str, type[DTO[Any]]] = {}
2928

30-
def __init__(self, table: Table=None, migrate_header: bool=False):
29+
def __init__(self, table: Table | None=None, migrate_header: bool=False):
3130
self.table = Table()
3231
if table is not None:
3332
if migrate_header:
@@ -36,18 +35,17 @@ def __init__(self, table: Table=None, migrate_header: bool=False):
3635
else:
3736
_table = table
3837

39-
if _table is not None:
40-
def add_data(keyword: str, data_length: int) -> None:
41-
for column in self.get_columns(_table, keyword, data_length):
42-
self.table.append_right(column)
38+
def add_data(keyword: str, data_length: int) -> None:
39+
for column in self.get_columns(_table, keyword, data_length):
40+
self.table.append_right(column)
4341

44-
for keyword in self.keywords:
45-
data_length = self.keywords[keyword].attribute_length
46-
add_data(keyword, data_length)
42+
for keyword in self.keywords:
43+
data_length = self.keywords[keyword].attribute_length
44+
add_data(keyword, data_length)
4745

48-
for dto in self.dtos:
49-
data_length = self.dtos[dto].column_length()
50-
add_data(dto, data_length)
46+
for dto in self.dtos:
47+
data_length = self.dtos[dto].column_length()
48+
add_data(dto, data_length)
5149

5250
def __json__(self) -> dict[str, Any]:
5351
return {
@@ -56,7 +54,7 @@ def __json__(self) -> dict[str, Any]:
5654

5755
@classmethod
5856
def get_columns(cls, table: Table, keyword: str, data_length: int) -> list[Table]:
59-
columns = []
57+
columns: list[Table] = []
6058
column = 0
6159

6260
if keyword in cls.keywords:
@@ -69,6 +67,8 @@ def get_columns(cls, table: Table, keyword: str, data_length: int) -> list[Table
6967
header[0, 0] = keyword
7068
for i, (field_name, field_type) in enumerate(types):
7169
header[1, i] = f"{field_name}: {field_type}"
70+
else:
71+
raise ValueError(f"unknown keyword {keyword}")
7272

7373
while column < table.num_columns:
7474
if table[0, column] == keyword:
@@ -85,7 +85,7 @@ def get_columns(cls, table: Table, keyword: str, data_length: int) -> list[Table
8585

8686
def get(self, row_no: int, keywords: list[str] | None=None, **kwargs: Any) -> list[ElementType]:
8787
row_no += 2 # skip header line
88-
elements = []
88+
elements: list[ElementType] = []
8989

9090
for keyword in list(self.keywords.keys()) + list(self.dtos.keys()):
9191
if keyword in self.keywords:
@@ -136,14 +136,14 @@ def get_one(self, row_no: int, keywords: list[str] | None=None, **kwargs: Any) -
136136

137137
return None
138138

139-
def _prepare_dto_data(self, row: int, dto: type[DTO], data: list[Any], resolvers: list[Parser]) -> dict[str, Any]:
139+
def _prepare_dto_data(self, row: int, dto: type[DTO[Any]], data: list[Any], resolvers: list[Parser]) -> dict[str, Any]:
140140
fields = dto.model_fields.items()
141141

142142
dct: dict[str, Any] = {}
143143
index = 0
144144

145145
for field_name, field in fields:
146-
if tuple_type := dto._is_cell_tuple(field.annotation):
146+
if tuple_type := dto.check_is_cell_tuple(field.annotation):
147147
offset1, offset2 = tuple_type.index_offset
148148
dct[field_name] = (
149149
resolvers[row].parse(data[index+offset1]),
@@ -176,11 +176,11 @@ def get_element(self, row: int, keyword: str, data: list[typing.Any], **kwargs:
176176
raise ValueError()
177177

178178
def _repr_html_(self) -> str:
179-
return self.table._repr_html_()
179+
return self.table._repr_html_() # type: ignore
180180

181181

182-
class CellTable(ElementTable):
182+
class CellTable(ElementTable[Any]):
183183
table_type = TableType.cell
184184

185-
class RibTable(ElementTable):
185+
class RibTable(ElementTable[Any]):
186186
table_type = TableType.rib

openglider/glider/rib/attachment_point.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class RoundReinforcement(BaseModel):
2929
def get_3d(self, rib: Rib, num_points: int=10) -> list[euklid.vector.PolyLine3D]:
3030
# create circle with center on the point
3131
polygons = self.get_flattened(rib, num_points=num_points)
32-
aligned_polygons = []
32+
aligned_polygons: list[euklid.vector.PolyLine3D] = []
3333
for polygon in polygons:
3434

3535
aligned_polygons.append(rib.align_all(polygon, scale=False))
@@ -61,7 +61,7 @@ class AttachmentPoint(Node):
6161
protoloops: int = 0
6262
protoloop_distance: Percentage | Length = Percentage("2%")
6363

64-
re_name: ClassVar[re.Pattern] = re.compile(r"^(?P<n>[0-9]+_)?([A-Za-z]+)([0-9]+)")
64+
re_name: ClassVar[re.Pattern[Any]] = re.compile(r"^(?P<n>[0-9]+_)?([A-Za-z]+)([0-9]+)")
6565

6666
def __repr__(self) -> str:
6767
return f"<{self.__class__.__name__}: '{self.name}' ({self.rib_pos})>"
@@ -76,7 +76,7 @@ def __json__(self) -> dict[str, Any]:
7676
}
7777

7878
@classmethod
79-
def __from_json__(self, **data: Any) -> AttachmentPoint:
79+
def __from_json__(cls, **data: Any) -> AttachmentPoint:
8080
data["force"] = euklid.vector.Vector3D(data["force"])
8181
return AttachmentPoint(**data)
8282

@@ -88,7 +88,7 @@ def get_x_values(self, rib: Rib) -> list[float]:
8888

8989
if self.protoloops:
9090
hull = rib.get_hull()
91-
ik_start = hull.get_ik(self.rib_pos)
91+
ik_start = hull.get_ik(self.rib_pos.si)
9292

9393
for i in range(self.protoloops):
9494
diff = (i+1) * self.protoloop_distance
@@ -104,11 +104,9 @@ def get_x_values(self, rib: Rib) -> list[float]:
104104

105105
return positions
106106

107-
@classmethod
108-
def calculate_force_rib_aligned(self, rib: Rib, force: float | None=None) -> euklid.vector.Vector3D:
109-
if force is None:
110-
force = self.force.length()
111-
return rib.rotation_matrix.apply([0, force, 0])
107+
@staticmethod
108+
def calculate_force_rib_aligned(rib: Rib, force: float) -> euklid.vector.Vector3D:
109+
return rib.rotation_matrix.apply(euklid.vector.Vector3D([0, force, 0]))
112110

113111
def get_position(self, rib: Rib) -> euklid.vector.Vector3D:
114112
hull = rib.get_hull()

openglider/lines/lineset.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import math
77
import os
88
import re
9-
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar
9+
from typing import TYPE_CHECKING, Any, TypeAlias
1010
from collections.abc import Callable
1111
from functools import cmp_to_key
1212

@@ -58,8 +58,7 @@ def get_length(self) -> float:
5858
return length
5959

6060

61-
T = TypeVar('T')
62-
LineTreePart: TypeAlias = tuple[Line, list[T]]
61+
LineTreePart: TypeAlias = tuple[Line, list["LineTreePart"]]
6362

6463
class LineSet:
6564
"""
@@ -69,7 +68,7 @@ class LineSet:
6968
knot_corrections = KnotCorrections.read_csv(os.path.join(os.path.dirname(__file__), "knots.csv"))
7069
mat: SagMatrix
7170

72-
def __init__(self, lines: list[Line], v_inf: euklid.vector.Vector3D=None):
71+
def __init__(self, lines: list[Line], v_inf: euklid.vector.Vector3D | None=None):
7372
self._v_inf = v_inf or euklid.vector.Vector3D([0,0,0])
7473
self.lines = lines or []
7574

@@ -101,7 +100,7 @@ def __json__(self) -> dict[str, Any]:
101100

102101
@classmethod
103102
def __from_json__(cls, lines: list[dict[str, Any]], nodes: list[Node], v_inf: euklid.vector.Vector3D) -> LineSet:
104-
lines_new = []
103+
lines_new: list[Line] = []
105104
for line in lines:
106105
if isinstance(line["upper_node"], int):
107106
line["upper_node"] = nodes[line["upper_node"]]
@@ -137,8 +136,9 @@ def uppermost_lines(self) -> list[Line]:
137136

138137
@property
139138
def nodes(self) -> list[Node]:
140-
nodes = set()
139+
nodes: set[Node] = set()
141140
for line in self.lines:
141+
# Collect unique Node instances from each line
142142
nodes.add(line.upper_node)
143143
nodes.add(line.lower_node)
144144
return list(nodes)
@@ -207,7 +207,7 @@ def recursive_count_floors(node: Node) -> int:
207207

208208
return {n: recursive_count_floors(n) for n in self.lower_attachment_points}
209209

210-
def get_lines_by_floor(self, target_floor: int=0, node: Node=None, en_style: bool=True) -> list[Line]:
210+
def get_lines_by_floor(self, target_floor: int=0, node: Node | None=None, en_style: bool=True) -> list[Line]:
211211
"""
212212
starting from node: walk up "target_floor" floors and return all the lines.
213213
@@ -229,8 +229,8 @@ def recursive_level(node: Node, current_level: int) -> list[Line]:
229229

230230
return recursive_level(node, 0)
231231

232-
def get_floor_strength(self, node: Node=None) -> list[float]:
233-
strength_list = []
232+
def get_floor_strength(self, node: Node | None=None) -> list[float]:
233+
strength_list: list[float] = []
234234
node = node or self.get_main_attachment_point()
235235
for i in range(self.floors[node]):
236236
lines = self.get_lines_by_floor(i, node, en_style=True)
@@ -637,7 +637,7 @@ def sort_lines(self, lines: list[Line] | None=None, x_factor: float=10., by_name
637637
matches = {line.name: re_name.match(line.name) for line in lines_new}
638638

639639
if all(matches.values()):
640-
line_values = {}
640+
line_values: dict[str, tuple[float, int, int]] = {}
641641
for name, match in matches.items():
642642
if match is None:
643643
raise ValueError(f"this is unreachable")
@@ -888,15 +888,15 @@ def get_checklength(line: Line, upper_lines: Any) -> list[tuple[str, float]]:
888888
if not len(upper_lines):
889889
return [(line.upper_node.name, line_length)]
890890
else:
891-
lengths = []
891+
lengths: list[tuple[str, float]] = []
892892
for upper in upper_lines:
893893
lengths += get_checklength(*upper)
894894

895895
return [
896896
(name, length + line_length) for name, length in lengths
897897
]
898898

899-
checklength_values = []
899+
checklength_values: list[tuple[str, float]] = []
900900
for line, upper_line in self.create_tree():
901901
checklength_values += get_checklength(line, upper_line)
902902

0 commit comments

Comments
 (0)