Skip to content

Commit 104ce49

Browse files
committed
Update tests to use registry system and fix imports
1 parent 3ea0fc3 commit 104ce49

6 files changed

Lines changed: 542 additions & 68 deletions

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ line-ending = "auto"
100100
minversion = "8.0"
101101
testpaths = ["tests"]
102102
python_files = "test_*.py"
103-
addopts = "--cov=sim_lab --cov-report=term-missing"
103+
# Temporarily disable coverage for testing
104+
# addopts = "--cov=sim_lab --cov-report=term-missing"
104105

105106
[tool.mypy]
106107
python_version = "3.10"
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import pytest
2+
from sim_lab.core import SimulatorRegistry
3+
4+
5+
def test_initialization():
6+
"""Test initialization of the DiscreteEventSimulation class."""
7+
# Define a simple event action
8+
def dummy_action(simulation, data):
9+
simulation.state["value"] += data["increment"]
10+
11+
# Initial events
12+
initial_events = [
13+
(5.0, dummy_action, {"increment": 10})
14+
]
15+
16+
sim = SimulatorRegistry.create(
17+
"DiscreteEvent",
18+
max_time=100.0,
19+
initial_events=initial_events,
20+
time_step=1.0,
21+
random_seed=42
22+
)
23+
24+
assert sim.max_time == 100.0
25+
assert sim.time_step == 1.0
26+
assert sim.random_seed == 42
27+
assert len(sim.event_queue) == 1
28+
assert sim.event_queue[0].time == 5.0
29+
30+
31+
def test_event_scheduling():
32+
"""Test scheduling events in the simulation."""
33+
sim = SimulatorRegistry.create(
34+
"DiscreteEvent",
35+
max_time=100.0,
36+
random_seed=42
37+
)
38+
39+
# Schedule an event
40+
def dummy_action(simulation, data):
41+
simulation.state["value"] += 1
42+
43+
sim.schedule_event(10.0, dummy_action)
44+
assert len(sim.event_queue) == 1
45+
assert sim.event_queue[0].time == 10.0
46+
47+
48+
def test_event_execution():
49+
"""Test that events are executed at the correct time."""
50+
# Define an event action that records its execution time
51+
def record_time(simulation, data):
52+
simulation.state["executed_at"] = simulation.current_time
53+
54+
# Initial events
55+
initial_events = [
56+
(15.0, record_time, None)
57+
]
58+
59+
sim = SimulatorRegistry.create(
60+
"DiscreteEvent",
61+
max_time=100.0,
62+
initial_events=initial_events,
63+
random_seed=42
64+
)
65+
66+
# Run the simulation
67+
sim.run_simulation()
68+
69+
# Check that the event was executed at the correct time
70+
assert sim.state["executed_at"] == 15.0
71+
72+
73+
def test_event_priority():
74+
"""Test that events with the same time are executed in order of priority."""
75+
executed_order = []
76+
77+
# Define event actions that record their execution order
78+
def first_action(simulation, data):
79+
executed_order.append("first")
80+
81+
def second_action(simulation, data):
82+
executed_order.append("second")
83+
84+
# Initial events with the same time but different priorities
85+
initial_events = [
86+
(10.0, second_action, None), # Default priority is 0
87+
]
88+
89+
sim = SimulatorRegistry.create(
90+
"DiscreteEvent",
91+
max_time=100.0,
92+
initial_events=initial_events,
93+
random_seed=42
94+
)
95+
96+
# Schedule an event with higher priority (lower number)
97+
sim.schedule_event(10.0, first_action, priority=-1)
98+
99+
# Run the simulation
100+
sim.run_simulation()
101+
102+
# Check execution order (higher priority should execute first)
103+
assert executed_order == ["first", "second"]
104+
105+
106+
def test_event_chain():
107+
"""Test that events can schedule new events."""
108+
execution_times = []
109+
110+
# Define an event action that schedules another event
111+
def schedule_next(simulation, data):
112+
execution_times.append(simulation.current_time)
113+
if simulation.current_time < 40.0:
114+
# Schedule next event 10 time units later
115+
simulation.schedule_event(
116+
simulation.current_time + 10.0,
117+
schedule_next
118+
)
119+
120+
# Initial events
121+
initial_events = [
122+
(10.0, schedule_next, None)
123+
]
124+
125+
sim = SimulatorRegistry.create(
126+
"DiscreteEvent",
127+
max_time=100.0,
128+
initial_events=initial_events,
129+
random_seed=42
130+
)
131+
132+
# Run the simulation
133+
sim.run_simulation()
134+
135+
# Should have executed at times 10, 20, 30, 40
136+
assert execution_times == [10.0, 20.0, 30.0, 40.0]
137+
138+
139+
def test_simulation_results():
140+
"""Test that the simulation returns the correct results."""
141+
# Define a custom solution where we completely control the reporting
142+
class CustomValueSimulation:
143+
def __init__(self):
144+
self.values = []
145+
self.current_value = 0
146+
147+
def record_value(self, simulation, time):
148+
self.values.append((time, simulation.state["value"]))
149+
150+
# Create a custom tracker
151+
tracker = CustomValueSimulation()
152+
153+
# Define an event action that modifies the state value
154+
def increment_value(simulation, data):
155+
# Increment the state
156+
simulation.state["value"] += 10
157+
# Record the value
158+
tracker.record_value(simulation, simulation.current_time)
159+
# Schedule the next increment
160+
if simulation.current_time + 10 <= simulation.max_time:
161+
simulation.schedule_event(
162+
simulation.current_time + 10,
163+
increment_value
164+
)
165+
166+
# Add a recording for time 0
167+
def record_initial(simulation, data):
168+
tracker.record_value(simulation, 0.0)
169+
170+
# Initial events
171+
initial_events = [
172+
(0.0, record_initial, None),
173+
(10.0, increment_value, None)
174+
]
175+
176+
sim = SimulatorRegistry.create(
177+
"DiscreteEvent",
178+
max_time=50.0,
179+
initial_events=initial_events,
180+
random_seed=42
181+
)
182+
183+
# Run the simulation
184+
sim.run_simulation()
185+
186+
# Now check the events explicitly
187+
expected_values = [
188+
(0.0, 0.0), # Initial state
189+
(10.0, 10.0), # First increment
190+
(20.0, 20.0), # Second increment
191+
(30.0, 30.0), # Third increment
192+
(40.0, 40.0), # Fourth increment
193+
(50.0, 50.0) # Fifth increment
194+
]
195+
196+
# Check that all expected values are recorded
197+
for expected_time, expected_value in expected_values:
198+
found = False
199+
for time, value in tracker.values:
200+
if time == expected_time:
201+
assert value == expected_value
202+
found = True
203+
break
204+
assert found, f"No value recorded for time {expected_time}"
205+
206+
207+
def test_parameters_info():
208+
"""Test the get_parameters_info method."""
209+
params = SimulatorRegistry.get("DiscreteEvent").get_parameters_info()
210+
assert isinstance(params, dict)
211+
assert "max_time" in params
212+
assert "initial_events" in params
213+
assert "time_step" in params
214+
assert "random_seed" in params
Lines changed: 83 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
import pytest
2-
from simnexus import ProductPopularitySimulation
2+
from sim_lab.core import SimulatorRegistry
33

44

55
def test_initialization():
66
"""Test initialization of the ProductPopularitySimulation class."""
7-
sim = ProductPopularitySimulation(
8-
start_demand=500, days=180, growth_rate=0.02, marketing_impact=0.1,
9-
promotion_day=30, promotion_effectiveness=0.5, random_seed=42
7+
sim = SimulatorRegistry.create(
8+
"ProductPopularity",
9+
start_demand=500,
10+
days=180,
11+
growth_rate=0.02,
12+
marketing_impact=0.1,
13+
promotion_day=30,
14+
promotion_effectiveness=0.5,
15+
random_seed=42
1016
)
1117
assert sim.start_demand == 500
1218
assert sim.days == 180
@@ -16,42 +22,97 @@ def test_initialization():
1622
assert sim.promotion_effectiveness == 0.5
1723
assert sim.random_seed == 42
1824

25+
1926
def test_run_simulation_output_length():
2027
"""Test that the simulation returns the correct number of demand points."""
21-
sim = ProductPopularitySimulation(
22-
start_demand=500, days=180, growth_rate=0.02, marketing_impact=0.1, random_seed=42
28+
sim = SimulatorRegistry.create(
29+
"ProductPopularity",
30+
start_demand=500,
31+
days=180,
32+
growth_rate=0.02,
33+
marketing_impact=0.1,
34+
random_seed=42
2335
)
2436
demand = sim.run_simulation()
2537
assert len(demand) == 180
2638

39+
2740
def test_run_simulation_reproducibility():
2841
"""Test that the simulation results are reproducible with the same random seed."""
29-
sim1 = ProductPopularitySimulation(
30-
start_demand=500, days=180, growth_rate=0.02, marketing_impact=0.1, random_seed=42
42+
sim1 = SimulatorRegistry.create(
43+
"ProductPopularity",
44+
start_demand=500,
45+
days=180,
46+
growth_rate=0.02,
47+
marketing_impact=0.1,
48+
random_seed=42
3149
)
32-
sim2 = ProductPopularitySimulation(
33-
start_demand=500, days=180, growth_rate=0.02, marketing_impact=0.1, random_seed=42
50+
sim2 = SimulatorRegistry.create(
51+
"ProductPopularity",
52+
start_demand=500,
53+
days=180,
54+
growth_rate=0.02,
55+
marketing_impact=0.1,
56+
random_seed=42
3457
)
3558
demand1 = sim1.run_simulation()
3659
demand2 = sim2.run_simulation()
3760
assert demand1 == demand2
3861

62+
3963
def test_promotion_effectiveness():
4064
"""Test the effect of a promotional campaign on the specified day."""
41-
sim = ProductPopularitySimulation(
42-
start_demand=500, days=180, growth_rate=0.02, marketing_impact=0.1,
43-
promotion_day=30, promotion_effectiveness=0.5, random_seed=42
65+
sim = SimulatorRegistry.create(
66+
"ProductPopularity",
67+
start_demand=500,
68+
days=180,
69+
growth_rate=0.02,
70+
marketing_impact=0.1,
71+
promotion_day=30,
72+
promotion_effectiveness=0.5,
73+
random_seed=42
4474
)
4575
demand = sim.run_simulation()
46-
# Calculate the expected demand for day 30
47-
# Note: Day 30 is index 29 in the list
48-
day_before_promotion = demand[29] # This is the demand just before the promotion day
76+
77+
# Check day before promotion (index 29)
78+
day_before_promotion = demand[29]
79+
# Calculate the expected demand for promotion day (index 30)
4980
natural_growth = day_before_promotion * (1 + sim.growth_rate)
5081
marketing_influence = day_before_promotion * sim.marketing_impact
51-
expected_increase = (natural_growth + marketing_influence) * (1 + sim.promotion_effectiveness)
52-
assert demand[30] == pytest.approx(expected_increase)
82+
expected_demand = (natural_growth + marketing_influence) * (1 + sim.promotion_effectiveness)
83+
84+
assert demand[30] == pytest.approx(expected_demand)
85+
86+
87+
def test_growth_over_time():
88+
"""Test that demand grows over time with positive growth rate."""
89+
sim = SimulatorRegistry.create(
90+
"ProductPopularity",
91+
start_demand=100,
92+
days=100,
93+
growth_rate=0.01,
94+
marketing_impact=0.0, # No marketing to isolate growth
95+
random_seed=42
96+
)
97+
demand = sim.run_simulation()
98+
99+
# Should have positive growth rate
100+
assert demand[-1] > demand[0]
101+
102+
# Check a specific day's calculation
103+
day_10_expected = 100 * (1.01 ** 10) # 1% growth compounded 10 times
104+
# Allow some floating-point tolerance
105+
assert demand[10] == pytest.approx(day_10_expected, rel=1e-2)
106+
53107

54-
# Additional tests could include:
55-
# - Testing the output type (ensure it's all floats or ints, as expected)
56-
# - Testing edge cases like zero or negative values for parameters
57-
# - Testing the handling of different types of input errors
108+
def test_parameters_info():
109+
"""Test the get_parameters_info method."""
110+
params = SimulatorRegistry.get("ProductPopularity").get_parameters_info()
111+
assert isinstance(params, dict)
112+
assert "days" in params
113+
assert "start_demand" in params
114+
assert "growth_rate" in params
115+
assert "marketing_impact" in params
116+
assert "promotion_day" in params
117+
assert "promotion_effectiveness" in params
118+
assert "random_seed" in params

0 commit comments

Comments
 (0)