1- # mypy: ignore-errors
21"""
32==========================
43Vivarium Testing Utilities
76Utility functions and classes to make testing ``vivarium`` components easier.
87
98"""
9+ from __future__ import annotations
1010
11+ from collections .abc import Callable , Sequence
12+ from datetime import datetime
1113from itertools import product
1214from pathlib import Path
1315from typing import Any
1618import pandas as pd
1719
1820from vivarium import Component
19- from vivarium .framework import randomness
2021from vivarium .framework .engine import Builder
2122from vivarium .framework .event import Event
2223from vivarium .framework .population import SimulantData
2324from vivarium .framework .randomness .index_map import IndexMap
25+ from vivarium .framework .randomness .stream import RandomnessStream
26+ from vivarium .types import ClockStepSize , ClockTime
2427
2528
2629class NonCRNTestPopulation (Component ):
@@ -64,7 +67,9 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
6467
6568 def on_time_step (self , event : Event ) -> None :
6669 population = self .population_view .get (event .index , query = "alive == 'alive'" )
67- population ["age" ] += event .step_size / pd .Timedelta (days = 365 )
70+ # This component won't work if event.step_size is an int
71+ if not isinstance (event .step_size , int ):
72+ population ["age" ] += event .step_size / pd .Timedelta (days = 365 )
6873 self .population_view .update (population )
6974
7075
@@ -85,7 +90,11 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
8590 )
8691 age_draw = self .age_randomness .get_draw (pop_data .index )
8792 if age_start == age_end :
88- age = age_draw * (pop_data .creation_window / pd .Timedelta (days = 365 )) + age_start
93+ # This component won't work if creation window is an int
94+ if not isinstance (pop_data .creation_window , int ):
95+ age = (
96+ age_draw * (pop_data .creation_window / pd .Timedelta (days = 365 )) + age_start
97+ )
8998 else :
9099 age = age_draw * (age_end - age_start ) + age_start
91100
@@ -104,7 +113,9 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
104113 self .population_view .update (population )
105114
106115
107- def _build_population (core_population , location , randomness_stream ):
116+ def _build_population (
117+ core_population : pd .DataFrame , location : str , randomness_stream : RandomnessStream
118+ ) -> pd .DataFrame :
108119 index = core_population .index
109120
110121 population = pd .DataFrame (
@@ -124,13 +135,20 @@ def _build_population(core_population, location, randomness_stream):
124135
125136
126137def _non_crn_build_population (
127- index , age_start , age_end , location , creation_time , creation_window , randomness_stream
128- ):
138+ index : pd .Index [int ],
139+ age_start : float ,
140+ age_end : float ,
141+ location : str ,
142+ creation_time : ClockTime ,
143+ creation_window : ClockStepSize ,
144+ randomness_stream : RandomnessStream ,
145+ ) -> pd .DataFrame :
129146 if age_start == age_end :
130- age = (
131- randomness_stream .get_draw (index ) * (creation_window / pd .Timedelta (days = 365 ))
132- + age_start
133- )
147+ if not isinstance (creation_window , int ):
148+ age = (
149+ randomness_stream .get_draw (index ) * (creation_window / pd .Timedelta (days = 365 ))
150+ + age_start
151+ )
134152 else :
135153 age = randomness_stream .get_draw (index ) * (age_end - age_start ) + age_start
136154
@@ -152,12 +170,12 @@ def _non_crn_build_population(
152170
153171def build_table (
154172 value : Any ,
155- parameter_columns : dict = {
173+ parameter_columns : dict [ str , Sequence [ int ]] = {
156174 "age" : (0 , 125 ),
157175 "year" : (1990 , 2020 ),
158176 },
159- key_columns : dict = {"sex" : ("Female" , "Male" )},
160- value_columns : list = ["value" ],
177+ key_columns : dict [ str , Sequence [ Any ]] = {"sex" : ("Female" , "Male" )},
178+ value_columns : list [ str ] = ["value" ],
161179) -> pd .DataFrame :
162180 """
163181
@@ -191,7 +209,7 @@ def build_table(
191209 }
192210 # Build out dict of items we will need cartesian product of to make dataframe
193211 product_dict = dict (range_parameter_product )
194- product_dict .update (key_columns )
212+ product_dict .update (key_columns ) # type: ignore [arg-type]
195213 products = product (* product_dict .values ())
196214
197215 rows = []
@@ -212,10 +230,12 @@ def build_table(
212230 # Transform parameter column values
213231 parameter_columns_index_values = item [: len (parameter_columns )]
214232 # Create intervals for parameter columns. Example year, year+1 for year_start and year_end
215- parameter_columns_index_values = [
233+ unpacked_parameter_columns_index_values : list [ Any ] = [
216234 v for val in parameter_columns_index_values for v in (val , val + 1 )
217235 ]
218- rows .append (parameter_columns_index_values + key_columns_index_values + r_values )
236+ rows .append (
237+ unpacked_parameter_columns_index_values + key_columns_index_values + r_values
238+ )
219239
220240 # Make list of parameter column names
221241 parameter_column_names = [
@@ -228,34 +248,13 @@ def build_table(
228248 )
229249
230250
231- def make_dummy_column (name , initial_value ):
232- class DummyColumnMaker :
233- @property
234- def name (self ):
235- return "dummy_column_maker"
236-
237- def setup (self , builder ):
238- self .population_view = builder .population .get_view (name )
239- builder .population .initializes_simulants (self .make_column , creates_columns = name )
240-
241- def make_column (self , pop_data ):
242- self .population_view .update (
243- pd .Series (initial_value , index = pop_data .index , name = name )
244- )
245-
246- def __repr__ (self ):
247- return f"dummy_column(name={ name } , initial_value={ initial_value } )"
248-
249- return DummyColumnMaker ()
250-
251-
252251def get_randomness (
253- key = "test" ,
254- clock = lambda : pd .Timestamp (1990 , 7 , 2 ),
255- seed = 12345 ,
256- initializes_crn_attributes = False ,
257- ):
258- return randomness . RandomnessStream (
252+ key : str = "test" ,
253+ clock : Callable [[], pd . Timestamp | datetime | int ] = lambda : pd .Timestamp (1990 , 7 , 2 ),
254+ seed : int = 12345 ,
255+ initializes_crn_attributes : bool = False ,
256+ ) -> RandomnessStream :
257+ return RandomnessStream (
259258 key ,
260259 clock ,
261260 seed = seed ,
@@ -264,10 +263,5 @@ def get_randomness(
264263 )
265264
266265
267- def reset_mocks (mocks ):
268- for mock in mocks :
269- mock .reset_mock ()
270-
271-
272266def metadata (file_path : str , layer : str = "override" ) -> dict [str , str ]:
273267 return {"layer" : layer , "source" : str (Path (file_path ).resolve ())}
0 commit comments