Skip to content

Commit 653a791

Browse files
committed
Refactor caching logic to use functools lru_cache, refactor reference spiral generation to improve performance, convert SpiralConfig to frozen dataclass for immutability necessary for caching, and fix center_spiral to modify a copied SpiralConfig instead of the original
1 parent be828ab commit 653a791

5 files changed

Lines changed: 91 additions & 294 deletions

File tree

src/graphomotor/core/config.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,31 @@
22

33
import dataclasses
44
import logging
5-
import warnings
5+
from typing import Any
66

77
import numpy as np
88

99

10-
@dataclasses.dataclass
10+
def get_logger() -> logging.Logger:
11+
"""Get the Graphomotor logger."""
12+
logger = logging.getLogger("graphomotor")
13+
if logger.handlers:
14+
return logger
15+
logger.setLevel(logging.INFO)
16+
formatter = logging.Formatter(
17+
"%(asctime)s - %(name)s - %(levelname)s - "
18+
"%(filename)s:%(lineno)s - %(funcName)s - %(message)s",
19+
)
20+
handler = logging.StreamHandler()
21+
handler.setFormatter(formatter)
22+
logger.addHandler(handler)
23+
return logger
24+
25+
26+
logger = get_logger()
27+
28+
29+
@dataclasses.dataclass(frozen=True)
1130
class SpiralConfig:
1231
"""Class for the parameters of anticipated spiral drawing."""
1332

@@ -29,32 +48,17 @@ def add_custom_params(cls, config_dict: dict[str, float | int]) -> "SpiralConfig
2948
Returns:
3049
SpiralConfig instance with updated parameters.
3150
"""
32-
config = cls()
51+
valid_params = {f.name for f in cls.__dataclass_fields__.values()}
52+
filtered_params: dict[str, Any] = {}
53+
3354
for key, value in config_dict.items():
34-
if hasattr(config, key):
35-
setattr(config, key, value)
55+
if key in valid_params:
56+
filtered_params[key] = value
3657
else:
37-
valid_params = ", ".join(
38-
f.name for f in cls.__dataclass_fields__.values()
39-
)
40-
warnings.warn(
58+
valid_param_names = ", ".join(valid_params)
59+
logger.warning(
4160
f"Unknown configuration parameters will be ignored: {key}. "
42-
f"Valid parameters are: {valid_params}"
61+
f"Valid parameters are: {valid_param_names}"
4362
)
44-
return config
45-
4663

47-
def get_logger() -> logging.Logger:
48-
"""Get the Graphomotor logger."""
49-
logger = logging.getLogger("graphomotor")
50-
if logger.handlers:
51-
return logger
52-
logger.setLevel(logging.INFO)
53-
formatter = logging.Formatter(
54-
"%(asctime)s - %(name)s - %(levelname)s - "
55-
"%(filename)s:%(lineno)s - %(funcName)s - %(message)s",
56-
)
57-
handler = logging.StreamHandler()
58-
handler.setFormatter(formatter)
59-
logger.addHandler(handler)
60-
return logger
64+
return cls(**filtered_params)

src/graphomotor/utils/center_spiral.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,17 @@ def center_spiral(spiral):
2727
spiral_config = config.SpiralConfig()
2828

2929
if isinstance(spiral, models.Spiral):
30-
spiral.data["x"] -= spiral_config.center_x
31-
spiral.data["y"] -= spiral_config.center_y
32-
return spiral
30+
centered_spiral = models.Spiral(
31+
data=spiral.data.copy(), metadata=spiral.metadata.copy()
32+
)
33+
centered_spiral.data["x"] -= spiral_config.center_x
34+
centered_spiral.data["y"] -= spiral_config.center_y
35+
return centered_spiral
3336
elif isinstance(spiral, np.ndarray):
34-
spiral[:, 0] -= spiral_config.center_x
35-
spiral[:, 1] -= spiral_config.center_y
36-
return spiral
37+
centered_spiral = spiral.copy()
38+
centered_spiral[:, 0] -= spiral_config.center_x
39+
centered_spiral[:, 1] -= spiral_config.center_y
40+
return centered_spiral
3741
else:
3842
raise TypeError(
3943
f"Expected models.Spiral or np.ndarray, got {type(spiral).__name__}"
Lines changed: 43 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,19 @@
11
"""Utility functions for generating an equidistant reference spiral."""
22

3-
import hashlib
4-
import pathlib
3+
import functools
54

65
import numpy as np
76
from scipy import integrate, optimize
87

98
from graphomotor.core import config
109

11-
logger = config.get_logger()
12-
1310

1411
def _arc_length_integrand(t: float, spiral_config: config.SpiralConfig) -> float:
1512
"""Calculate the differential arc length at angle t for an Archimedean spiral.
1613
1714
Args:
1815
t: Angle parameter.
19-
spiral_config: Configuration parameters for the spiral.
16+
spiral_config: Spiral configuration.
2017
2118
Returns:
2219
Differential arc length value.
@@ -25,138 +22,50 @@ def _arc_length_integrand(t: float, spiral_config: config.SpiralConfig) -> float
2522
return np.sqrt(r_t**2 + spiral_config.growth_rate**2)
2623

2724

28-
def _calculate_arc_length(theta: float, spiral_config: config.SpiralConfig) -> float:
29-
"""Calculate the arc length of the spiral from start_angle to theta.
25+
def _calculate_arc_length_between(
26+
theta_start: float, theta_end: float, spiral_config: config.SpiralConfig
27+
) -> float:
28+
"""Calculate the arc length of the spiral between two theta values.
3029
3130
Args:
32-
theta: The angle in radians.
33-
spiral_config: Configuration parameters for the spiral.
31+
theta_start: Starting angle in radians.
32+
theta_end: Ending angle in radians.
33+
spiral_config: Spiral configuration.
3434
3535
Returns:
36-
The arc length of the spiral from start_angle to theta.
36+
The arc length of the spiral from theta_start to theta_end.
3737
"""
3838
return integrate.quad(
3939
lambda t: _arc_length_integrand(t, spiral_config),
40-
spiral_config.start_angle,
41-
theta,
40+
theta_start,
41+
theta_end,
4242
)[0]
4343

4444

45-
def _find_theta_for_arc_length(
46-
target_arc_length: float, spiral_config: config.SpiralConfig
45+
def _find_theta_for_incremental_arc_length(
46+
target_increment: float,
47+
current_theta: float,
48+
spiral_config: config.SpiralConfig,
4749
) -> float:
48-
"""Find the theta value for a given arc length using numerical root finding.
50+
"""Find the theta value for a given incremental arc length from current position.
4951
5052
Args:
51-
target_arc_length: Target arc length.
52-
spiral_config: Configuration parameters for the spiral.
53+
target_increment: Target arc length increment from current position.
54+
current_theta: Current theta position.
55+
spiral_config: Spiral configuration.
5356
5457
Returns:
55-
Angle theta corresponding to the arc length.
58+
Angle theta corresponding to the target cumulative arc length.
5659
"""
5760
solution = optimize.root_scalar(
58-
lambda theta: _calculate_arc_length(theta, spiral_config) - target_arc_length,
59-
bracket=[spiral_config.start_angle, spiral_config.end_angle],
61+
lambda theta: _calculate_arc_length_between(current_theta, theta, spiral_config)
62+
- target_increment,
63+
bracket=(current_theta, spiral_config.end_angle),
6064
)
6165
return solution.root
6266

6367

64-
def _get_spiral_cache_key(spiral_config: config.SpiralConfig) -> str:
65-
"""Generate a cache key based on spiral configuration parameters.
66-
67-
Args:
68-
spiral_config: Configuration parameters for the spiral.
69-
70-
Returns:
71-
Hash string representing the configuration.
72-
"""
73-
config_str = (
74-
f"{spiral_config.center_x}_{spiral_config.center_y}_"
75-
f"{spiral_config.start_radius}_{spiral_config.growth_rate}_"
76-
f"{spiral_config.start_angle}_{spiral_config.end_angle}_"
77-
f"{spiral_config.num_points}"
78-
)
79-
return hashlib.md5(config_str.encode()).hexdigest()
80-
81-
82-
def _get_cache_path(spiral_config: config.SpiralConfig) -> pathlib.Path:
83-
"""Get the cache file path for a given spiral configuration.
84-
85-
Args:
86-
spiral_config: Configuration parameters for the spiral.
87-
88-
Returns:
89-
Path to the cache file.
90-
"""
91-
cache_key = _get_spiral_cache_key(spiral_config)
92-
package_cache_dir = pathlib.Path(__file__).parent.parent / "cache"
93-
94-
try:
95-
package_cache_dir.mkdir(parents=True, exist_ok=True)
96-
test_file = package_cache_dir / ".write_test"
97-
test_file.touch()
98-
test_file.unlink()
99-
except (PermissionError, OSError):
100-
logger.warning(
101-
"Package cache directory is not writable. "
102-
"Cannot save reference spiral to cache."
103-
)
104-
105-
return package_cache_dir / f"reference_spiral_{cache_key}.npy"
106-
107-
108-
def _load_reference_spiral(spiral_config: config.SpiralConfig) -> np.ndarray | None:
109-
"""Load a pre-computed reference spiral from disk.
110-
111-
Args:
112-
spiral_config: Configuration parameters for the spiral.
113-
114-
Returns:
115-
Reference spiral array if found, None otherwise.
116-
"""
117-
cache_path = _get_cache_path(spiral_config)
118-
119-
if cache_path.exists():
120-
try:
121-
spiral = np.load(cache_path)
122-
logger.info(f"Loaded pre-computed reference spiral from {cache_path}")
123-
return spiral
124-
except Exception as e:
125-
logger.warning(f"Error loading cached spiral from {cache_path}: {e}")
126-
return None
127-
128-
return None
129-
130-
131-
def _compute_reference_spiral(
132-
spiral_config: config.SpiralConfig,
133-
) -> np.ndarray:
134-
"""Generate a reference spiral using numerical computation.
135-
136-
This is the computation-heavy implementation that performs numerical integration and
137-
root finding to create equidistant points along the spiral.
138-
139-
Args:
140-
spiral_config: Configuration parameters for the spiral.
141-
142-
Returns:
143-
Array with shape (N, 2) containing Cartesian coordinates of the spiral points.
144-
"""
145-
total_arc_length = _calculate_arc_length(spiral_config.end_angle, spiral_config)
146-
147-
arc_length_values = np.linspace(0, total_arc_length, spiral_config.num_points)
148-
149-
theta_values = np.array(
150-
[_find_theta_for_arc_length(s, spiral_config) for s in arc_length_values]
151-
)
152-
153-
r_values = spiral_config.start_radius + spiral_config.growth_rate * theta_values
154-
x_values = spiral_config.center_x + r_values * np.cos(theta_values)
155-
y_values = spiral_config.center_y + r_values * np.sin(theta_values)
156-
157-
return np.column_stack((x_values, y_values))
158-
159-
68+
@functools.lru_cache(maxsize=48)
16069
def generate_reference_spiral(spiral_config: config.SpiralConfig) -> np.ndarray:
16170
"""Generate a reference spiral with equidistant points along its arc length.
16271
@@ -183,8 +92,7 @@ def generate_reference_spiral(spiral_config: config.SpiralConfig) -> np.ndarray:
18392
- Cartesian coordinates: x = cx + r·cos(θ), y = cy + r·sin(θ)
18493
18594
Parameters are defined in the SpiralConfig class:
186-
- Center coordinates: (cx, cy) = (spiral_config.center_x,
187-
spiral_config.center_y)
95+
- Center coordinates: cx, cy = spiral_config.center_x, spiral_config.center_y
18896
- Start radius: a = spiral_config.start_radius
18997
- Growth rate: b = spiral_config.growth_rate
19098
- Total rotation: θ = spiral_config.end_angle - spiral_config.start_angle
@@ -196,17 +104,24 @@ def generate_reference_spiral(spiral_config: config.SpiralConfig) -> np.ndarray:
196104
Returns:
197105
Array with shape (N, 2) containing Cartesian coordinates of the spiral points.
198106
"""
199-
cached_spiral = _load_reference_spiral(spiral_config)
200-
if cached_spiral is not None:
201-
return cached_spiral
107+
total_arc_length = _calculate_arc_length_between(
108+
spiral_config.start_angle, spiral_config.end_angle, spiral_config
109+
)
202110

203-
logger.info("No cached reference spiral found, generating new reference spiral...")
204-
spiral = _compute_reference_spiral(spiral_config)
111+
arc_length_increment = total_arc_length / (spiral_config.num_points - 1)
205112

206-
cache_path = _get_cache_path(spiral_config)
207-
cache_path.parent.mkdir(parents=True, exist_ok=True)
113+
theta_values = np.zeros(spiral_config.num_points)
114+
theta_values[0] = spiral_config.start_angle
208115

209-
logger.info(f"Saving generated reference spiral to cache: {cache_path}")
210-
np.save(cache_path, spiral)
116+
for i in range(1, spiral_config.num_points):
117+
theta_values[i] = _find_theta_for_incremental_arc_length(
118+
arc_length_increment,
119+
theta_values[i - 1],
120+
spiral_config,
121+
)
211122

212-
return spiral
123+
r_values = spiral_config.start_radius + spiral_config.growth_rate * theta_values
124+
x_values = spiral_config.center_x + r_values * np.cos(theta_values)
125+
y_values = spiral_config.center_y + r_values * np.sin(theta_values)
126+
127+
return np.column_stack((x_values, y_values))

tests/unit/test_config.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@
3434
def test_spiral_config_add_custom_params_valid(
3535
custom_params: dict[str, int | float],
3636
expected_params: dict[str, int | float],
37-
recwarn: pytest.WarningsRecorder,
37+
caplog: pytest.LogCaptureFixture,
3838
) -> None:
3939
"""Test that SpiralConfig.add_custom_params correctly sets parameter values."""
4040
spiral_config = config.SpiralConfig.add_custom_params(custom_params)
4141

4242
for key, value in expected_params.items():
4343
assert getattr(spiral_config, key) == value
44-
assert len(recwarn) == 0
44+
assert len(caplog.records) == 0
4545

4646

4747
@pytest.mark.parametrize(
@@ -71,17 +71,17 @@ def test_spiral_config_add_custom_params_warnings(
7171
custom_params: dict[str, int | float],
7272
expected_params: dict[str, int | float],
7373
expected_warnings: list[str],
74-
recwarn: pytest.WarningsRecorder,
74+
caplog: pytest.LogCaptureFixture,
7575
) -> None:
7676
"""Test that SpiralConfig.add_custom_params issues warnings appropriately."""
7777
spiral_config = config.SpiralConfig.add_custom_params(custom_params)
7878

79-
assert len(recwarn) == len(expected_warnings)
79+
assert len(caplog.records) == len(expected_warnings)
8080
for key, value in expected_params.items():
8181
assert getattr(spiral_config, key) == value
8282
for i, param in enumerate(expected_warnings):
8383
assert f"Unknown configuration parameters will be ignored: {param}" in str(
84-
recwarn[i].message
84+
caplog.records[i].message
8585
)
8686

8787

0 commit comments

Comments
 (0)