diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 07d8fffc2..3551e9967 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -24,7 +24,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: "3.11" + python-version: "3.13" - name: Install dependencies run: | python --version diff --git a/.readthedocs.yml b/.readthedocs.yml index 96b012aed..43bcd5186 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -11,7 +11,7 @@ sphinx: build: os: ubuntu-22.04 tools: - python: "3.11" + python: "3.13" python: install: # This runs pip install .[docs] from the project root. diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c952187a0..4c5c7ecc3 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,85 @@ +**4.0.0 - TBD TBD TBD** +----------------------- + +This release introduces a major refactor of the population management system as well +as various other miscellaneous changes. + +Breaking changes +---------------- + +Population management system refactor +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Interactive context: 'get_population()' will now error if requesting an attribute that doesn't exist. +- Population views: Replace subviews and 'get()' method with 'get_attributes()', + 'get_attribute_frame()', and 'get_private_columns()'. + + - You must now explicitly request which attributes you want to retrieve. + - Write access (via the 'update()' method) is now restricted to private columns + created by the component the view is attached to. + +- Population views: Remove support for population view default queries. +- Population interface: Replace the 'tracked' column and corresponding auto-filter + logic with a new 'register_tracked_query()' method and 'include_untracked' argument + when getting attributes or private columns from a population view. +- Population interface: Require explicit initializer method registration instead + of inferring from methods named 'on_initialize_simulants()'. 
Supports multiple + initializer methods per component. + + - Remove columns_created, columns_required, and initialization_requirements properties throughout. + - Changed the names of all initializer methods (no longer 'on_initialize_simulants'). + +- Population manager: 'get_population()' now requires an explicit attribute request ("all" is allowed). +- Stop returning AttributePipelines (previously Pipelines) when registering them. + +Miscellaneous +~~~~~~~~~~~~~ + +- Split managers and their corresponding interfaces into separate modules. +- Replace 'requires_columns' and 'requires_values' arguments with 'requires_attributes' throughout. +- Replace 'dependencies' arguments with 'required_resources' throughout. +- Change default behavior of state machine 'allow_self_transition' to True. + +Major changes +------------- + +- Replace all Pipelines with AttributePipelines throughout. +- Support attribute names as source and/or modifiers to AttributePipelines. +- Add support for python 3.12 and 3.13 + +Other changes +------------- + +Population management system refactor +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- InteractiveContext: Allow specific column request to 'get_population()'. +- InteractiveContext: Implement new 'get_columns()' method to get all attribute names. +- Population views: Implement new 'get_filtered_index()' method. +- Stop using population views inappropriately when using individualized clocks. +- Implement 'skip_post_processor' argument to population view 'get_attributes()' + and population manager 'get_population()' methods. +- Ensure Pipeline 'union_post_processor' always returns a Series or DataFrame. +- Change 'alive' string column to 'is_alive' boolean column in disease model example and various tests. +- Make Mortality a sub-component of BasePopulation component in disease model example. +- Update documentation. + +Miscellaneous +~~~~~~~~~~~~~ + +- LookupTables: Improve type hinting. +- LookupTables: Register tables as resources. 
+- LookupTables: Warn if unused tables are registered. +- Properly set Pipeline source dependencies. +- Infer component when creating a Resource. +- Clean up Resource registration. +- Clean up ComponentManager. +- Only get random draws for non-0 and non-1 probabilities when calling 'filter_for_probability()'. +- Fix mypy error in setup.py. +- Fix mypy errors in disease model example. +- Create a Component.logger property for better type-checking. +- Removed unnecessary InitializerComponentSet class. + **3.6.6 - 01/06/26** - Fail deployment if changelog date is not current date diff --git a/README.rst b/README.rst index 803fdba5f..71b63a6ee 100644 --- a/README.rst +++ b/README.rst @@ -19,7 +19,7 @@ Vivarium Vivarium is a simulation framework written using standard scientific Python tools. -**Vivarium requires Python 3.8-3.11 to run** +**Vivarium requires Python 3.10-3.13 to run** You can install ``vivarium`` from PyPI with pip: @@ -31,7 +31,7 @@ or build it from source with ``> cd vivarium`` - ``> conda create -n ENVIRONMENT_NAME python=3.11`` + ``> conda create -n ENVIRONMENT_NAME python=3.12`` ``> pip install -e .[dev]`` diff --git a/docs/nitpick-exceptions b/docs/nitpick-exceptions index 0dd933fa2..edbb95deb 100644 --- a/docs/nitpick-exceptions +++ b/docs/nitpick-exceptions @@ -47,9 +47,13 @@ py:exc ResultsConfigurationError py:exc vivarium.framework.results.exceptions.ResultsConfigurationError py:class vivarium.framework.plugins.M py:class vivarium.framework.plugins.I -py:class vivarium.framework.utilities.T +py:class vivarium.framework.utilities.TimeValue +py:class T +py:class vivarium.framework.components.manager.T +py:class vivarium.framework.components.manager.C # layered_config_tree +py:class LayeredConfigTree py:class layered_config_tree.main.LayeredConfigTree py:exc layered_config_tree.exceptions.ConfigurationError @@ -69,5 +73,8 @@ py:class Logger py:class Path py:class LookupTableData +# pathlib internals +py:class pathlib._local.Path + # typing py:data 
typing.Union diff --git a/docs/source/api_reference/framework/event.rst b/docs/source/api_reference/framework/event.rst deleted file mode 100644 index 05fdc115c..000000000 --- a/docs/source/api_reference/framework/event.rst +++ /dev/null @@ -1 +0,0 @@ -.. automodule:: vivarium.framework.event diff --git a/docs/source/api_reference/framework/event/index.rst b/docs/source/api_reference/framework/event/index.rst new file mode 100644 index 000000000..53d5795e4 --- /dev/null +++ b/docs/source/api_reference/framework/event/index.rst @@ -0,0 +1,11 @@ +================ +Event Management +================ + +.. automodule:: vivarium.framework.event + +.. toctree:: + :maxdepth: 1 + :glob: + + * \ No newline at end of file diff --git a/docs/source/api_reference/framework/event/interface.rst b/docs/source/api_reference/framework/event/interface.rst new file mode 100644 index 000000000..8215449e6 --- /dev/null +++ b/docs/source/api_reference/framework/event/interface.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.event.interface diff --git a/docs/source/api_reference/framework/event/manager.rst b/docs/source/api_reference/framework/event/manager.rst new file mode 100644 index 000000000..474e6079c --- /dev/null +++ b/docs/source/api_reference/framework/event/manager.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.event.manager diff --git a/docs/source/api_reference/framework/logging/index.rst b/docs/source/api_reference/framework/logging/index.rst index ceb142603..a1c33dc75 100644 --- a/docs/source/api_reference/framework/logging/index.rst +++ b/docs/source/api_reference/framework/logging/index.rst @@ -1,3 +1,7 @@ +================== +Logging Management +================== + .. automodule:: vivarium.framework.logging .. 
toctree:: diff --git a/docs/source/api_reference/framework/lookup/index.rst b/docs/source/api_reference/framework/lookup/index.rst index ca389d623..78861709c 100644 --- a/docs/source/api_reference/framework/lookup/index.rst +++ b/docs/source/api_reference/framework/lookup/index.rst @@ -1,3 +1,7 @@ +======================= +Lookup Table Management +======================= + .. automodule:: vivarium.framework.lookup .. toctree:: diff --git a/docs/source/api_reference/framework/population/interface.rst b/docs/source/api_reference/framework/population/interface.rst new file mode 100644 index 000000000..bc01c850b --- /dev/null +++ b/docs/source/api_reference/framework/population/interface.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.population.interface diff --git a/docs/source/api_reference/framework/population/utilities.rst b/docs/source/api_reference/framework/population/utilities.rst new file mode 100644 index 000000000..a54575522 --- /dev/null +++ b/docs/source/api_reference/framework/population/utilities.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.population.utilities \ No newline at end of file diff --git a/docs/source/api_reference/framework/time.rst b/docs/source/api_reference/framework/time.rst deleted file mode 100644 index 3008907c7..000000000 --- a/docs/source/api_reference/framework/time.rst +++ /dev/null @@ -1 +0,0 @@ -.. automodule:: vivarium.framework.time diff --git a/docs/source/api_reference/framework/time/index.rst b/docs/source/api_reference/framework/time/index.rst new file mode 100644 index 000000000..4179e9002 --- /dev/null +++ b/docs/source/api_reference/framework/time/index.rst @@ -0,0 +1,11 @@ +=============== +Time Management +=============== + +.. automodule:: vivarium.framework.time + +.. 
toctree:: + :maxdepth: 1 + :glob: + + * \ No newline at end of file diff --git a/docs/source/api_reference/framework/time/interface.rst b/docs/source/api_reference/framework/time/interface.rst new file mode 100644 index 000000000..d8390ac3a --- /dev/null +++ b/docs/source/api_reference/framework/time/interface.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.time.interface diff --git a/docs/source/api_reference/framework/time/manager.rst b/docs/source/api_reference/framework/time/manager.rst new file mode 100644 index 000000000..17b4eabb2 --- /dev/null +++ b/docs/source/api_reference/framework/time/manager.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.time.manager diff --git a/docs/source/api_reference/framework/values/interface.rst b/docs/source/api_reference/framework/values/interface.rst new file mode 100644 index 000000000..7f96119dd --- /dev/null +++ b/docs/source/api_reference/framework/values/interface.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.values.interface \ No newline at end of file diff --git a/docs/source/concepts/builder.rst b/docs/source/concepts/builder.rst index 8e1ed65b1..ae2524436 100644 --- a/docs/source/concepts/builder.rst +++ b/docs/source/concepts/builder.rst @@ -24,10 +24,10 @@ they register for services and provide information about their structure. For ex a component needing to leverage the simulation clock and step size to determine a numerical effect to apply on each time step, will get the simulation clock and step size though the Builder and will register -method(s) to apply the effect (e.g., via :meth:`vivarium.framework.values.manager.ValuesInterface.register_value_modifier`). -Another component, needing to initialize state for simulants at before the -simulation begin, might call :meth:`vivarium.framework.population.manager.PopulationInterface.initializes_simulants` in its setup -method to register method(s) that setup the additional state. 
+method(s) to apply the effect (e.g., via :meth:`vivarium.framework.values.interface.ValuesInterface.register_value_modifier`). +Another component, needing to initialize state for simulants before the +simulation begins, might call :meth:`vivarium.framework.population.interface.PopulationInterface.register_initializer` +in its setup method to register method(s) that set up the additional state. Outline diff --git a/docs/source/concepts/crn.rst b/docs/source/concepts/crn.rst index f5f18bf8d..ceb4b52a6 100644 --- a/docs/source/concepts/crn.rst +++ b/docs/source/concepts/crn.rst @@ -188,30 +188,31 @@ for this decision point of whether to move left or not. Here's how we'd do it: import pandas as pd - class MoveLeft: + from vivarium import Component + + class MoveLeft(Component): @property def name(self): return 'move_left' def setup(self, builder): - self.randomness = builder.randomness.get_stream('move_left') - - builder.population.initializes_simulants(self.on_initialize_simulants, - creates_columns=['location']) - - self.population_view = builder.population.get_view(['location']) + self.randomness = builder.randomness.get_stream(self.name) - builder.event.register_listener('time_step', self.on_time_step) + builder.population.register_initializer( + initializer=self.initialize_location, + columns='location', + required_resources=[self.randomness], + ) - def on_initialize_simulants(self, pop_data): + def initialize_location(self, pop_data): # all simulants start at position 10 self.population_view.update(pd.Series(10, index=pop_data.index, name='location')) def on_time_step(self, event): # with probability 0.5 simulants move to the left 1 position to_move_index = self.randomness.filter_for_probability(event.index, pd.Series(0.5, index=event.index)) - moved_locations = self.population_view.get(to_move_index).location - 1 + moved_locations = self.population_view.get_attributes(to_move_index, "location") - 1 self.population_view.update(moved_locations) diff --git 
a/docs/source/concepts/event.rst b/docs/source/concepts/event.rst index eb8d2edf0..1ce032dda 100644 --- a/docs/source/concepts/event.rst +++ b/docs/source/concepts/event.rst @@ -17,9 +17,9 @@ a means of coordinating across various components in a simulation. What is an Event? ----------------- -An :class:`Event ` is a simple container for a -group of attributes that provide all the necessary information to respond to -the event. Events have the following attributes: +An :class:`~ ` is a simple container for a +group of class attributes that provide all the necessary information to respond to +the event. Each Event contains the following: .. list-table:: **Event Attributes** :header-rows: 1 @@ -28,7 +28,7 @@ the event. Events have the following attributes: * - Name - Description * - | index - - | An index into the population table that contains all + - | An index into the population state table that contains all | individuals that may respond to the event. * - | time - | The time at which the event will resolve. The current simulation @@ -89,14 +89,14 @@ these phases. Interacting with Events ----------------------- -The :class:`EventInterface ` is +The :class:`~ ` is available off the :ref:`Builder ` and provides two options for interacting with the event system: -1. :func:`register_listener ` to add a +1. :func:`~ ` to add a listener to a given event to be called on emission -2. :func:`get_emitter ` +2. :func:`~ ` to retrieve a callable emitter for a given event Although methods for both getting emitters and registering listeners are @@ -109,9 +109,9 @@ Registering Listeners In order to register a listener to an event to respond when that event is emitted, we can use the -:func:`register_listener `. The listener +:func:`~ `. The listener itself should be a callable function that takes as its only argument -the :class:`Event ` that is emitted. +the :class:`~ ` that is emitted. Suppose we wish to track how many simulants are affected each time step. 
We could do this by creating an observer component with an ``on_time_step`` method @@ -144,11 +144,18 @@ another row to our dataframe tracking the number of affected simulants. simulation at the beginning of the next time step should only depend on the current state of the system. +.. note:: + + If a new component is being created that inherits from :class:`vivarium.component.Component`, + listeners are registered automatically if the component defines methods named + ``on_``, where ```` is one of the lifecycle names + (e.g. ``time_step``, ``collect_metrics``, etc.). + Emitting Events +++++++++++++++ -The :func:`get_emitter ` +The :func:`~ ` provides a way to get a callable emitter for a given named event. It can be used as follows: diff --git a/docs/source/concepts/lifecycle.rst b/docs/source/concepts/lifecycle.rst index c9483a0d9..340659df8 100644 --- a/docs/source/concepts/lifecycle.rst +++ b/docs/source/concepts/lifecycle.rst @@ -115,7 +115,7 @@ the simulation components will have their ``setup`` method called with the simulation :ref:`builder ` as an argument. The builder allows the components to request services like :ref:`randomness ` or views into the -:term:`population state table ` or to register themselves +:term:`state table ` or to register themselves with various simulation subsystems. Setting up components may also involve loading data, registering or getting :ref:`pipelines `, creating :ref:`lookup tables `, and registering @@ -173,7 +173,7 @@ the state transitions by emitting a series of events for each simulation outputs. By listening for these events, individual components can perform actions, -including manipulating the :ref:`state table `. This +including manipulating the :ref:`population state table `. This sequence of events is repeated until the simulation clock passes the simulation end time. 
diff --git a/docs/source/concepts/lookup.rst b/docs/source/concepts/lookup.rst index fb86488e3..95975f9b8 100644 --- a/docs/source/concepts/lookup.rst +++ b/docs/source/concepts/lookup.rst @@ -41,55 +41,32 @@ population in a simulation. The lookup table system is built in layers. At the top is the :class:`Lookup Table ` object which is responsible for providing a uniform interface to the user regardless -of the underlying implementation. From the user's perspective, it takes in -a data set or scalar value on initialization and then lets them query against -that data with a population index. - -The next layer is selected at initialization time based on the type of data -provided. The :class:`Lookup Table ` -picks a :class:`ScalarTable ` -if a single value is provided as the data, a -:class:`CategoricalTable ` if a -:class:`pandas.DataFrame` with only categorical variables is provided as the -data, and a :class:`InterpolatedTable ` -if a :class:`pandas.DataFrame` which has at least one continuous variable is -provided as the data. +of the underlying data. From the user's perspective, it takes in a data set +or scalar value on initialization and then lets them query against that data +with a population index. + +At initialization time, the +:class:`Lookup Table ` examines the +provided data and configures itself accordingly. If the data is a scalar value +(or list/tuple of scalars), the table simply broadcasts those values over the +population index when called. If the data is a :class:`pandas.DataFrame`, the +table delegates to an +:class:`Interpolation ` +object that handles both categorical and continuous parameter lookups. The +:class:`Interpolation ` +groups the data by any categorical (key) columns and then, for each group, +finds the correct bin for any continuous parameters. Tables with only +categorical parameters are simply the special case where there are no +continuous parameters to bin on. .. note:: - The :class:`InterpolatedTable ` - is a misnomer here. 
It confuses the data handling strategy with the - underlying data representation. A better name would be ``BinnedDataTable`` - to indicate that it wraps data where the continuous parameters are - represented by bin edges in the provided data. This would allow us - to easily think about and extend the lookup system to wrap data where the - continuous parameters are represented by points and to tables where all - parameters are categorical. - -If the underlying data is a single value or consists only of categorical variables, -this is the last layer of abstraction. The -:class:`ScalarTable ` and -:class:`CategoricalTable ` each -have only one reasonable strategy which is to broadcast the value over the -population index. If we have continuous variables and therefore an -:class:`InterpolatedTable `, -there are additional layers to the lookup system to allow the user to -control the strategy for turning the population index into values based on -the data. The -:class:`InterpolatedTable ` -is then responsible for turning the population index into a set of -attributes relevant to the value production based on the structure of -the input data and then providing those attributes to the value production -strategy. - -.. note:: - - I'm being careful with language here. We have objects named - ``Interpolation`` and ``InterpolatedTable`` though the operation they - perform is actually disaggregation. If we extend the system to - work with point estimates for the continuous parameters, then - interpolation would appropriately describe what we do. Both are - value production strategies based on the structure of the input data. + The ``Interpolation`` name is somewhat of a misnomer. For order 0 + (the only currently supported order), the operation is really + disaggregation -- finding the correct bin a value belongs to rather + than interpolating between points. 
If the system is extended to work + with point estimates for continuous parameters, then interpolation + would appropriately describe the operation. More information about the value production strategies can be found in :ref:`here `. @@ -116,7 +93,7 @@ corresponding column names which are used to query an internal when the table itself is called. This means the lookup table only needs to be called with a population index -- it gathers the population information it needs itself. It also means the data must be available in the -:term:`population state table ` with the same column name. +:term:`population state table ` with the same column name. In the table below is an example of (unrealistic) data that could be used to create a lookup table for a quantity of interest about a population, @@ -145,10 +122,10 @@ Female 60 100 27 Constructing Lookup Tables from a Component ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Components can register lookup tables to be built by specifying -a ``data_sources`` block in their :attr:`~vivarium.component.Component.configuration_defaults` property. -As a basic example, DiseaseModel in ``vivarium_public_health`` has the following -``data_sources`` configuration: +Components can build lookup tables as needed via the :meth:`~vivarium.component.Component.build_lookup_table` +method which will refer to the ``data_sources`` block in the component's +:attr:`~vivarium.component.Component.configuration_defaults` property. As a basic example, +DiseaseModel in ``vivarium_public_health`` has the following ``data_sources`` configuration: .. 
code-block:: python @@ -162,25 +139,14 @@ As a basic example, DiseaseModel in ``vivarium_public_health`` has the following }, } -which specifies a single lookup table named -``cause_specific_mortality_rate`` whose data is provided by the component's +which specifies that when building a lookup table named +``cause_specific_mortality_rate``, the data should be provided by the component's ``load_cause_specific_mortality_rate`` method. Each entry in ``data_sources`` maps a table name to a data source from one of -several supported types (see `Data Source Types`_). Barring edge cases (see -`Limitations and When to Override`_), one should specify all of a component's -lookup tables via ``data_sources``, instead of accessing the builder's lookup -interface directly. - -When a component configures ``data_sources``, the base -:class:`~vivarium.component.Component` class automatically builds -the lookup tables before the component's :meth:`~vivarium.component.Component.setup` method is called. The -resulting tables are stored in the component's :attr:`~vivarium.component.Component.lookup_tables` dictionary, -keyed by the name specified in ``data_sources``. - -This approach separates the *what* (which tables to build and where to get data) from the -*how* (the mechanics of table construction), making components easier to -write and configure. It also allows users to override data sources in model specification files +several supported types (see `Data Source Types`_). + +This approach allows users to override data sources in model specification files without modifying component code. 
Following the example above, a model specification could adjust the ``cause_specific_mortality_rate`` data source to point to different data or a scalar value: @@ -234,8 +200,8 @@ When building a lookup table from a :class:`pandas.DataFrame` using ``data_sourc the component automatically determines key columns, parameter columns, and value columns based on the data structure: -- **Value columns** are assumed by the structure of the artifact to be ``["value"]``. In principle, - this could be configured by implementing a custom :class:`~vivarium.framework.artifact.manager.ArtifactManager`. +- **Value columns** can be provided as an argument to :meth:`~vivarium.component.Component.build_lookup_table` + If value columns are not provided, it will default to ``"value"``. - **Parameter columns** are detected by finding columns ending in ``_start`` that have corresponding ``_end`` columns (e.g., ``age_start``/``age_end``). - **Key columns** are all remaining columns that are neither value columns @@ -290,35 +256,6 @@ code: # Or point to different artifact data life_expectancy: "alternative.life_expectancy.data" -Limitations and When to Override -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The automatic ``data_sources`` mechanism works well for straightforward cases, -but some situations require overriding the :meth:`~vivarium.component.Component.build_all_lookup_tables` method: - -**Non-standard value columns:** - The component defaults to ``["value"]`` as the value column name. If your - data has differently named value columns or multiple value columns, you - must call :meth:`~vivarium.component.Component.build_lookup_table` directly with explicit - ``value_columns``. - -**Complex data transformations:** - When data requires transformation before building tables (e.g., pivoting, - computing derived parameters, combining multiple data sources), override - :meth:`~vivarium.component.Component.build_all_lookup_tables` to perform the transformation first. 
- -**Delegation to sub-components:** - When lookup tables should be built by sub-components rather than the - parent component, override :meth:`~vivarium.component.Component.build_all_lookup_tables` to skip the - default behavior. - -Examples of these patterns can be found in `vivarium_public_health `_: - -- `RateTransition `_ and `DiseaseState `_ in `vivarium_public_health.disease `_ - demonstrate the basic ``data_sources`` pattern with various data source types. -- ``Risk`` in ``vivarium_public_health.risks`` overrides :meth:`~vivarium.component.Component.build_all_lookup_tables` - to delegate table construction to its exposure distribution sub-component. - Using the Lookup Interface Directly ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -336,11 +273,15 @@ integrating a lookup table into a :term:`component `, which is primar how they are used. Assuming you have a valid simulation object named ``sim`` and the data from the above table in a :class:`pandas.DataFrame` named ``data``, you can construct a lookup table in the following way, using the interface from the builder. +You don't have to provide a name for the table, but it is recommended to do so for clarity +and for ease of debugging. If you don't provide value column names, it will default to +``"value"``. + .. code-block:: python # value_columns implicitly set to remaining columns - > bmi = sim.builder.lookup.build_table(data, key_columns=['sex'], parameter_columns=['age']) + > bmi = sim.builder.lookup.build_table(data, name="bmi") > population = sim.get_population() > bmi(population.index).head() # returns BMI values for the population @@ -356,7 +297,7 @@ can construct a lookup table in the following way, using the interface from the Constructing a lookup table currently requires your data meet specific conditions. These are a consequence of the method the lookup table uses to arrive at the correct data. Specifically, your parameter columns must - represent bins and they must overlap. 
+ represent bins and they must not overlap or have gaps. Estimating Unknown Values ------------------------- diff --git a/docs/source/concepts/population.rst b/docs/source/concepts/population.rst index d73bbff27..c6706846f 100644 --- a/docs/source/concepts/population.rst +++ b/docs/source/concepts/population.rst @@ -10,53 +10,180 @@ Population Management :backlinks: none Since ``Vivarium`` is an agent-based simulation framework, managing a group of -:term:`simulants ` and their attributes is a critical task. +:term:`simulants ` and their :term:`attributes ` is a critical task. Fundamentally, to run a simulation we need to be able to create new simulates, update their state attributes, and facilitate access to their state so that :term:`components ` in the simulation can do interesting things based on it. The tooling to support working with our simulant population is called the population management system. -The State Table ---------------- +The Population State Table +-------------------------- The core representation of simulants and their state information in ``Vivarium`` -is a :class:`pandas.DataFrame` known as the state table. Under this -representation rows represent simulants while columns correspond to state -attributes like age, sex or systolic blood pressure. These columns represent one -of several important resources within ``Vivarium`` that other components can -draw on. Each of the actions we need to be able to take correspond to a -manipulation of this state table. The addition of new simulants is the creation -of rows, the creation of new state attributes is the creation of columns, and -the reading and updating of state is reading and updating the dataframe itself. +is a dynamically-generated :class:`pandas.DataFrame` known as the population state +table (or just "state table"). Under this representation, rows represent simulants +while columns correspond to attributes like age, sex or systolic blood pressure. 
+These columns represent one of several important resources within ``Vivarium`` that +other components can draw on. Each of the actions we need to be able to take corresponds +to a manipulation of this state table. The addition of new simulants is the creation +of rows, the creation of new attributes is the creation of columns, and the reading +and updating of state is reading and updating the dataframe itself. <> +Attributes +---------- + +Attributes are the fundamental characteristics of a population and are represented +by columns in the population state table. They are a particular type of :term:`values ` +that are produced on-demand by :term:`attribute pipelines `. +When a component requires the state table (or some subset of it), each attribute +requested is calculated via its corresponding attribute pipeline and returned in +tabular form. For example, when a component requests the entire population's age, +the "age" attribute pipeline calculates the age of all simulants and returns a +``pandas.Series`` of age values. + +.. note:: + The population system is distinct from the :ref:`values system documentation ` + although they are intimately related. While the values system is responsible for + populating the columns of the state table with attributes, the population system + is responsible for managing and providing access to said state table. + Population Views ---------------- -The population manager holds the state table directly and tightly controls read -and write access to it through a structure it provides known as a population -view. A population view itself represents a subset of columns and rows from the -state table. Through a view, components can read, update, or, under the right -circumstances, create new state in the state table. - -Views are created on-demand for components in a simulation by specifying a set -of columns and an optional query string to the population manager interface. 
The -columns dictate the subset of the state table that is viewable and modifiable -while the query string is a filter on the simulants returned. The view itself is -callable and accepts an index, which is the simulants to be viewed. It also -provides an update method that accepts a dataframe and will replace values in -the state table according to column and index. Only the columns that the view -was created with can be updated in this way. The only exception is at simulant -initialization time, when initial state must be created. - -Population views can themselves create subviews through the subview method. This -generates a new population view that is constrained by it's parents columns and -query string in addition to whatever arguments it is passed, with the -requirement that it's columns must be a subset of its parent view's. - -<> +As mentioned above, columns in the state table are dynamically generated via attribute +pipelines as needed. The population manager holds this logic and tightly controls +read and write access to it through a structure it provides known as a "population +view". A population view itself provides access to a subset of columns and rows +from the state table as well as any :term:`private columns ` created +by the component the view is attached to. Through a view, components can read, update, +or, under the right circumstances, create new state in the state table. + +Views are created for components in a simulation by specifying the component +needing it and an optional query string to the population manager interface. All +attributes are then viewable and the query string filters the simulants returned. +And as noted above, they also have read and write access to all private columns +created by the component they are attached to. This is how one might update the +source data for attributes, e.g. updating all simulants' ages on every time step. 
+ +There are several methods on a population view that facilitate working with the +state table, including ones to get the population index, attributes, or private +columns. There is also an :meth:`~vivarium.framework.population.population_view.PopulationView.update` +method that accepts a dataframe and replaces private column data according to column +and index. This method is also used at simulant initialization time to create initial +state. + +Filtering Simulants ++++++++++++++++++++ + +There are two types of filtering that can be applied when using a population view +to get attributes or private columns. + +First, a ``query`` argument can be passed in to any of the population view's +:meth:`~vivarium.framework.population.population_view.PopulationView.get_attributes`, +:meth:`~vivarium.framework.population.population_view.PopulationView.get_attribute_frame`, +:meth:`~vivarium.framework.population.population_view.PopulationView.get_private_columns`, or +:meth:`~vivarium.framework.population.population_view.PopulationView.get_filtered_index` +to filter the simulants returned for that specific call. + +Second, if any components have registered an untracking query, untracked simulants +will be automatically filtered out. There is an optional ``include_untracked`` argument +that defaults to False that can be used to bypass the untracked filtering if desired. + +.. note:: + + **Combining Queries** + All types of queries are combined using the logical AND operator. Be sure to + set up your query strings accordingly. + +Untracking Simulants +++++++++++++++++++++ + +As mentioned above, there is a ``Vivarium`` concept of untracking simulants. Untracking +a simulant allows for automatic filtering of those simulants from population views +so that components can ignore them. This is useful to reduce computational overhead +when simulants are no longer relevant to the simulation, e.g. deceased individuals +or those who have aged beyond the scope of interest. 
A component can register a +tracked query via :meth:`vivarium.framework.population.interface.PopulationInterface.register_tracked_query`. + +.. note:: + + **Tracked Queries and Including Untracked Simulants** + When a component wants to register a query to be used for filtering out untracked + simulants, it registers the *tracked* query, i.e. the query that defines which + simulants should be *kept*. This can perhaps be a bit confusing since we then + decide to include or exclude *untracked* simulants when using population views. + Despite this potential source of confusion, we feel it's more intuitive to think + about the query in terms of who to keep and then the population view call in + terms of who to exclude. + + For example, if a component wants to untrack simulants who have died, it would + register ``is_alive == True`` as a tracked query which tells ``Vivarium`` to + **keep** simulants who are alive (and, conversely, filter out those who are not). + Then, when using a population view to get data, we can decide whether or not to + include untracked simulants (i.e. deceased ones). + +Private Columns +--------------- + +We have mentioned private columns a few times now, but what exactly are they and +how do they differ from attributes (which can be thought of as "public" columns +in the state table)? To start, keep in mind that attributes are produced by attribute +pipelines and *all* pipelines - attribute or otherwise - require a source of data +to operate on. One of the things that an attribute pipeline's source can be is +a column of data. All such attribute pipeline source columns are maintained in a +:class:`pandas.DataFrame` attached to the population manager, *but are only accessible +by population views attached to the component that created the source data in the +first place*. These columns are thus referred to as private columns.
+ +Creating and Updating Private Columns ++++++++++++++++++++++++++++++++++++++ + +To create a private column to be used as a source for an attribute pipeline, a component +must register initializer methods during its setup. Any columns that are created +and passed to the population view's ``update`` method within these methods will +be automatically registered as private columns for that component. The corresponding +attribute pipelines will be registered automatically as well. + +To update private column data over the course of a simulation, a component can use +the same population view ``update`` method as needed. + +.. note:: + + **Private columns vs attributes** + The distinction between private columns and attributes can be confusing. It's + important to remember that attributes are dynamically calculated as needed + (via attribute pipelines) and are readable by all components (via population + views). Private columns, on the other hand, are static data stored in the population + manager that are only readable and writable by the component that created them + and serve as the source for their corresponding attributes. + + Private column data can be updated as needed by the owning component. These + updates are then reflected in the attributes calculated from them the next time + they are requested. For example, a component that creates an "age" private column + (and thus an "age" attribute) instantiates the starting ages for all simulants + at the start of the simulation. At each time step, the component can then update + the private column by incrementing all ages by the duration of the time step. + The next time any component then requests the "age" attribute, the updated ages + will be returned since the source data was updated. + +Creating Attributes +------------------- + +There are two ways to create attributes. The first, as described above, is to have +a component register an initializer method during its setup phase which creates +a private column.
This private column will act as the source of data for its corresponding +attribute pipeline which is automatically registered as well. For example, if a +component creates an "age" private column, an "age" attribute pipeline will be +automatically registered and so the "age" attribute will be available for use by +all components. + +Not all attributes use a private column as their source, however. A component can +also register an attribute pipeline explicitly during its setup phase by calling +the values manager interface's :meth:`~vivarium.framework.values.interface.ValuesInterface.register_attribute_producer` +or :meth:`~vivarium.framework.values.interface.ValuesInterface.register_rate_producer` methods. Creating Simulants ------------------ @@ -80,7 +207,8 @@ initialization state during the setup phase, and the main event loop. The simulant creator function first adds rows to the state table. It then loops through a set of functions that have been registered to it as population -initializers via `initializes_simulants`, passing in the index of the newly -created simulants. These functions generally proceed by using population views -to dictate the state of the newly created simulants they are responsible for. -It is the only time creating columns in the state table is acceptable. +initializers via :meth:`~vivarium.framework.population.interface.PopulationInterface.register_initializer`, +passing in the index of the newly created simulants. These functions generally proceed +by using population views to dictate the state of the newly created simulants they +are responsible for. It is the only time creating columns in the state table is +acceptable. diff --git a/docs/source/concepts/results.rst b/docs/source/concepts/results.rst index 03f9d5d3a..271824af7 100644 --- a/docs/source/concepts/results.rst +++ b/docs/source/concepts/results.rst @@ -83,15 +83,11 @@ to the existing number of people who have died from previous time steps.
} } - @property - def columns_required(self) -> list[str] | None: - return ["age", "alive"] - def register_observations(self, builder: Builder) -> None: builder.results.register_adding_observation( name="total_population_dead", - requires_columns=["alive"], - pop_filter='alive == "dead"', + requires_attributes=["is_alive"], + pop_filter='is_alive == True', ) And here is an example of how you might create an observer that records new @@ -123,7 +119,7 @@ as well as adds a new one ("birth_date"). f"and previous_pregnancy == 'pregnant' " f"and pregnancy == 'parturition'" ), - requires_columns=self.COLUMNS, + requires_attributes=self.COLUMNS, results_formatter=self.format, ) @@ -180,10 +176,10 @@ Here is an example of how you might register a "current_year" and "sex" as strat [str(year) for year in range(self.start_year, self.end_year + 1)], mapper=self.map_year, is_vectorized=True, - requires_columns=["current_time"], + requires_attributes=["current_time"], ) builder.results.register_stratification( - "sex", ["Female", "Male"], requires_columns=["sex"] + "sex", ["Female", "Male"], requires_attributes=["sex"] ) ########### @@ -215,7 +211,7 @@ Here is an example of how you might register a "current_year" and "sex" as strat method. Just because you've *registered* a stratification doesn't mean that the results will -actually *use* it. In order to use the stratification, you must add it to the +actually *use* it. In order to use the stratification, you can add it to the :ref:`model specification ` configuration block using the "stratification" key. You can provide "default" stratifications which will be used by all observations as well as observation-specific "include" and @@ -239,6 +235,15 @@ observations and then customize "deaths" observations to also include .. note:: All stratifications must be included as a list, even if there is only one. 
+Another way to include and exclude stratifications from different observations +(besides via the model specification as shown above) is to provide them via the +``additional_stratifications`` and ``excluded_stratifications`` arguments when +registering an observation. Note that not all observation registration methods +support these arguments, e.g. registering an unstratified observation by definition +does not support stratifying results and so the +:meth:`~ ` +method does not support these arguments. + Excluding Categories from Results +++++++++++++++++++++++++++++++++ @@ -264,6 +269,9 @@ For example, to exclude "stillbirth" as a pregnancy outcome during results proce excluded_categories: pregnancy_outcome: ['stillbirth'] +Alternatively, categories can be excluded from results stratifications by providing +an "excluded_categories" argument when registering a stratification. + Observers --------- @@ -315,9 +323,11 @@ abstract base class, which contains the common attributes between observation ty * - | :attr:`name ` - | Name of the observation. It will also be the name of the output results file | for this particular observation. - * - | :attr:`pop_filter ` - - | A Pandas query filter string to filter the population down to the simulants - | who should be considered for the observation. + * - | :attr:`population_filter ` - | A named tuple of population filtering details. The first item is a Pandas + | query string to filter the population down to the simulants who should be + | considered for the observation. The second item is a boolean indicating whether + | to include untracked simulants in the observation. * - | :attr:`when ` - | Name of the lifecycle phase the observation should happen. Valid values are: | "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". @@ -372,18 +382,16 @@ results of an observation: - Description * - | :attr:`name ` - | Name of the stratification.
- * - | :attr:`sources ` - - | A list of the columns needed as input for the `mapper`. - * - | :attr:`requires_values ` - - | A list of value pipelines needed as input for the `mapper`. + * - | :attr:`requires_attributes ` + - | The population attributes needed as input for the `mapper`. * - | :attr:`categories ` - | Exhaustive list of all possible stratification values. * - | :attr:`excluded_categories ` - | List of possible stratification values to exclude from results processing. | If None (the default), will use exclusions as defined in the configuration. * - | :attr:`mapper ` - - | A callable that maps the columns and value pipelines specified by the - | `requires_columns` and `requires_values` arguments to the stratification + - | A callable that maps the population attributes specified by the + | `requires_attributes` argument to the stratification | categories. It can either map the entire population or an individual | simulant. A simulation will fail if the `mapper` ever produces an invalid | value. @@ -394,7 +402,7 @@ results of an observation: Each **Stratification** also contains the :meth:`stratify ` method which is called at each :ref:`event ` and :ref:`time step ` -to use the **mapper** to map values in the **sources** columns to **categories** +to use the **mapper** to map values in the **requires_attributes** columns to **categories** (excluding any categories specified in **excluded_categories**). .. note:: diff --git a/docs/source/concepts/time.rst b/docs/source/concepts/time.rst index 077e47440..ff731bfca 100644 --- a/docs/source/concepts/time.rst +++ b/docs/source/concepts/time.rst @@ -11,44 +11,66 @@ Thinking about Time in the Simulation The Simulation Clock -------------------- -The :class:`SimulationClock ` plugin manages the progression of time throughout the simulation. 
-Fundamentally, that means it keeps track of the current time (beginning at the *start time*), provides -a mechanism to advance the simulation time by some duration (the *step size*), and determines when -the simulation is complete via a configured *end time*. The simplest -implementation of a Clock is the :class:`SimpleClock ` object, which is little more -than an integer counter that is incremented by a fixed step size until it reaches the -end time. Modeling real-world events is often dependent on data that are tied to particular years or rates, where the -desired step size is not necessarily known in advance. Therefore, it is common to use the :class:`DateTimeClock `, -which uses datetime-like objects (specifically :class:`~pandas.Timestamp` and :class:`~pandas.Timedelta`) as the temporal units. The DateTimeClock -can more easily facilitate the conversion rates to particular increments of time. +The :class:`~vivarium.framework.time.manager.SimulationClock` plugin manages +the progression of time throughout the simulation. Fundamentally, that means it +keeps track of the current time (beginning at the *start time*), provides a mechanism +to advance the simulation time by some duration (the *step size*), and determines +when the simulation is complete via a configured *end time*. The simplest implementation +of a Clock is the :class:`~vivarium.framework.time.manager.SimpleClock` object, +which is little more than an integer counter that is incremented by a fixed step +size until it reaches the end time. Modeling real-world events is often dependent +on data that are tied to particular years or rates, where the desired step size +is not necessarily known in advance. Therefore, it is common to use the +:class:`~vivarium.framework.time.manager.DateTimeClock`, which uses datetime-like +objects (specifically :class:`~pandas.Timestamp` and :class:`~pandas.Timedelta`) +as the temporal units. 
The DateTimeClock can more easily facilitate the conversion +of rates to particular increments of time. Event Times ----------- -Discrete time simulations assume that all changes to a simulant's state vector happen at the -end of the time step, that is, the current clock time *plus* the step size. :mod:`vivarium` explicates this important distinction -and labels this quantity the *event time*. `Events ` that correspond to (potential) state changes are mediated through the -:class:`Event Manager `, which propagates events to :ref:`components ` subscribed to them during particuar phases of the simulation lifecycle. -The Event Manager uses the event time when calculating time-related outcomes, for example, age- or year-dependent rates of morbidity and mortality. +Discrete time simulations assume that all changes to a simulant's state vector happen +at the end of the time step, that is, the current clock time *plus* the step size. +:mod:`vivarium` explicates this important distinction and labels this quantity the +*event time*. `Events ` that correspond to (potential) state changes +are mediated through the :class:`Event Manager `, +which propagates events to :ref:`components ` subscribed to +them during particular phases of the simulation lifecycle. The Event Manager uses +the event time when calculating time-related outcomes, for example, age- or year-dependent +rates of morbidity and mortality. Time Interface -------------- -The Time plugin provides, via the :ref:`Builder `, an :class:`interface ` to access several clock methods that might be needed -by other managers or components. In particular, components can access the current time and step size (and, implicitly, the event time). +The Time plugin provides, via the :ref:`Builder `, an +:class:`interface ` to access several clock +methods that might be needed by other managers or components. In particular, components +can access the current time and step size (and, implicitly, the event time).
Individual Clocks ----------------- -:mod:`vivarium` also allows one to update simulants asynchronously with different frequencies depending on their state information. -For example, a component that simulates the progression of a disease might need to update the state of each -simulant more frequently when infected than when in remission. The basic method is to give each simulant its own distinct clock time and step size instead of one global clock. -A simulant's *next event time*, that is, the sum of its clock time and step size, is when it is scheduled to be updated. -Currently, the :mod:`vivarium` still incorporates a global clock, which determines the start, end, and minimal step size of the simulation. The minimum step -size is the smallest value that a simulant's step size can take, and therefore determines the minimum duration by which the simulation can advance in a single iteration. -However, global step size changes from iteration to iteration and can be larger than the minimum step size. In each iteration of the simulation, the global clock is advanced to the earliest time in which some simulant is scheduled to be updated. -Simulants that are not scheduled to be updated in a particular iteration are simply excluded from the relevant events as propagated by the Event Manager. -In effect, if there are no simulants to be updated in a duration comprising several minimum timesteps, those "minimum timesteps" are skipped. +:mod:`vivarium` also allows one to update simulants asynchronously with different +frequencies depending on their state information. For example, a component that +simulates the progression of a disease might need to update the state of each simulant +more frequently when infected than when in remission. The basic method is to give +each simulant its own distinct clock time and step size instead of one global clock. +A simulant's *next event time*, that is, the sum of its clock time and step size, +is when it is scheduled to be updated. 
Currently, :mod:`vivarium` still incorporates +a global clock, which determines the start, end, and minimal step size of the simulation. +The minimum step size is the smallest value that a simulant's step size can take, +and therefore determines the minimum duration by which the simulation can advance +in a single iteration. However, global step size changes from iteration to iteration +and can be larger than the minimum step size. In each iteration of the simulation, +the global clock is advanced to the earliest time in which some simulant is scheduled +to be updated. Simulants that are not scheduled to be updated in a particular iteration +are simply excluded from the relevant events as propagated by the Event Manager. +In effect, if there are no simulants to be updated in a duration comprising several +minimum timesteps, those "minimum timesteps" are skipped. -The Time Interface provides a method to modify a simulant's step size based on some criteria, :func:`builder.time.register_step_size_modifier() `. -If there are multiple modifiers to the same simulant simultaneously, the time manager chooses the smallest one (bounded by the global minimum step size). -If a simulant has no step modifier, it is given a default value, either the global minimum or another optionally configurable value, the *standard* step size, -in the case that we want the "background" update frequency to be larger than the minimium size. -If *no* simulants have a step modifier, then the simulation behaves as if there were no individual clocks, reverting to the global clock. \ No newline at end of file +The Time Interface provides a method to modify a simulant's step size based on some +criteria, :func:`builder.time.register_step_size_modifier() `. +If there are multiple modifiers to the same simulant simultaneously, the time manager +chooses the smallest one (bounded by the global minimum step size). 
If a simulant +has no step modifier, it is given a default value, either the global minimum or +another optionally configurable value, the *standard* step size, in the case that +we want the "background" update frequency to be larger than the minimum size. If +*no* simulants have a step modifier, then the simulation behaves as if there were +no individual clocks, reverting to the global clock. \ No newline at end of file diff --git a/docs/source/concepts/values.rst b/docs/source/concepts/values.rst index 9def2fa36..85e6f4296 100644 --- a/docs/source/concepts/values.rst +++ b/docs/source/concepts/values.rst @@ -4,18 +4,29 @@ The Values System ================= -The values system provides an interface to an alternative representation of -:term:`state ` in the simulation: pipelines. -:class:`Pipelines ` are dynamically -calculated values that can be constructed across multiple -:ref:`components `. This ability for multiple -components to together compose a single value is the biggest advantage -pipelines provide over the standard state representation of the population -state table. +The values system provides an interface for working with dynamically calculated +:term:`values ` that can be constructed across multiple :ref:`components `. +A "value" is extremely general in this context and can be anything at all. However, +by far the most common types of values used in a ``Vivarium`` simulation are +:term:`attributes `. Attributes are simulant-specific characteristics +and are stored in the :term:`population state table `. +Examples of attributes include things such as age, sex, blood pressure, and any +other type of information of interest that describes each simulant. .. note:: - **You should use the values system when you have a value that must be - composed across multiple components.** + The values system is distinct from the :ref:`population management system ` + although they are intimately related.
While the population system is responsible + for managing and providing access to the state table, the values system is responsible + for populating the columns of said state table with attributes. + +As mentioned above, values and attributes are dynamically calculated as needed throughout +the simulation. The most prominent example is that any time during the simulation a component +requests some information from the state table, the desired attributes are calculated. +The producers of these values - the things that actually do the calculations - are +:class:`pipelines ` for generic values +and :class:`attribute pipelines ` +for attributes. + .. contents:: :depth: 2 @@ -30,15 +41,15 @@ We can visualize a pipeline as the following: .. image:: ../images/pipeline.jpg At the left, we have the original **source** of the pipeline. This is a -callable registered by a single component that returns a dataframe. To this +callable registered by a single component that can return anything. To this source, additional components can register **modifiers**. These modifiers are -also callables that return dataframes. +also callables that can return anything. The source and modifiers are composed into a single value by the **combiner** with which the pipeline is registered. The combiner is also a callable that -returns a dataframe - it is the function that dictates how the dataframe -produced by the source and the dataframes produced by the modifiers will be -combined into a single dataframe. The combiner also determines the required +can return anything - it is the function that dictates how the value +produced by the source and the values produced by the modifiers will be +combined into a single value. The combiner also determines the required signatures of modifiers in relation to the source. The values system provides three options for combiners, detailed in the following table.
- | Modifiers should have the same signature as the source. Pipelines may also optionally be registered with a **postprocessor**. This is -a callable that returns a dataframe that will be called on the output of the +a callable that can return anything that will be called on the output of the combiner to do some postprocessing. .. list-table:: **Pipeline Post-processors** @@ -80,17 +91,31 @@ combiner to do some postprocessing. | union of the underlying sample space -The values system also inverts the direction of control from information that -is stored in the state table. Components that update columns in the state -table can be seen as "pushing" that information out. Pipelines, however, are -"pulled" on by components, often components that did not play any part in the -construction of the pipeline value. +What are attribute pipelines? +----------------------------- + +An attribute pipeline is a specific type of pipeline whose calculated value is an +attribute, i.e. a simulant-specific characteristic stored in the population state +table such as age, sex, or body-mass index. Attribute pipelines differ from generic +pipelines in that they (and their sources and postprocessors) must accept an index +representing the population of interest and return data in tabular form (i.e. a +``pandas.DataFrame`` or ``pandas.Series``) with the same index. + +By far most pipelines used in ``Vivarium`` simulations are attribute pipelines. + +.. note:: + + Note that the values system inverts the direction of control from information that + is stored as private columns in the population manager. Components that update private + columns via a population view can be seen as "pushing" that information out. Pipelines, + however, are "pulled" on by components, often components that did not play any part + in the construction of the pipeline value.
How to use pipelines -------------------- -The values system provides four interface methods, available off the +The values system provides a handful of interface methods, available off the :ref:`builder ` during setup. .. list-table:: **Values System Interface Methods** @@ -99,27 +124,42 @@ The values system provides four interface methods, available off the * - Method - Description - * - | :meth:`register_value_producer ` - - | Register a new pipeline with the values system. Provide a name for the + * - | :meth:`register_value_producer ` - | Registers a new pipeline with the values system. Provide a name for the | pipeline and a source. Optionally provide a combiner (defaults to - | the replace combiner) and a postprocessor. Provide dependencies (see note). - * - | :meth:`register_rate_producer ` - - | A special case of :meth:`register_value_producer ` + | the replace combiner) and a postprocessor. Provide required resources (see note). + * - | :meth:`register_attribute_producer ` + - | Registers a new attribute pipeline with the values system. Provide a name + | for the attribute pipeline and a source. Optionally provide a combiner + | (defaults to the replace combiner) and a postprocessor. Provide required + | resources (see note). + * - | :meth:`register_rate_producer ` + - | A special case of :meth:`register_attribute_producer ` | for rates specifically. | Provide a name for the pipeline and a source and the values system will - | automatically use the rescale postprocessor. Provide dependencies (see note). - * - | :meth:`register_value_modifier ` - - | Register a modifier to a pipeline. Provide a name for the pipeline to - | modify and a modifier callable. Provide dependencies (see note). - * - | :meth:`get_value ` - - | Retrieve a reference to the pipeline with the given name. + | automatically use the rescale postprocessor. Provide required resources (see note). + * - | :meth:`register_value_modifier ` + - | Registers a modifier to a pipeline.
Provide a name for the pipeline to + | modify and a modifier callable. Provide required resources (see note). + * - | :meth:`register_attribute_modifier ` + - | Registers a modifier to an attribute pipeline. Provide a name for the attribute + | pipeline to modify and a modifier callable or name of an attribute pipeline + | that does the modifying. Provide required resources (see note). + * - | :meth:`get_value ` + - | Retrieves the pipeline with the given name. + * - | :meth:`get_attribute_pipelines ` + - | Retrieves a callable that in turn gets a dictionary of all attribute pipelines + | registered with the values system. This method is intended to be used only + | by backend managers as needed. Components should not need direct access + | to attribute pipelines as attributes are obtained via population views. .. note:: - The registration methods for the values system require dependencies be + The registration methods for the values system require that any required resources be specified in order for the :ref:`resource manager ` to - properly order and manage dependencies. These dependencies are the state - table columns, other pipelines, and randomness streams that the source or - modifier callable uses in producing the dataframe it returns. + properly order and manage dependencies. These required resources must include + any private columns, other pipelines or attribute pipelines, :ref:`randomness streams `, + and :ref:`lookup tables ` that the source or modifier callables + use in producing the values they return. For a view of the values system in action, see the diff --git a/docs/source/glossary.rst b/docs/source/glossary.rst index 55cb358e9..21fb3a8fc 100644 --- a/docs/source/glossary.rst +++ b/docs/source/glossary.rst @@ -7,15 +7,21 @@ Glossary .. glossary:: Attribute - A variable associated with each :term:`simulant `. For - example, each simulant may have an attribute to describe their age or - position.
+ A specific type of :term:`value ` that is associated with each + :term:`simulant `. For example, each simulant may have an attribute + to describe their age or position. These are computed by :term:`attribute pipelines ` + and stored in the :term:`population state table `. + + Attribute Pipeline + A specific type of :term:`pipeline ` whose callable must accept + a :term:`simulant ` index and return tabular data of + :term:`attributes ` for those simulants. Component Any self-contained piece of code that can be plugged into the simulation to add some functionality. In ``vivarium`` we typically think of components as encapsulating and managing some behavior or - attributes of the :term:`simulants `. + :term:`attributes ` of the :term:`simulants `. Configuration The configuration is a set of parameters a user can modify to @@ -23,8 +29,7 @@ Glossary themselves may provide defaults for several configuration parameters. Metrics - A special :term:`pipeline ` in the simulation that produces - the simulation outputs. + Simulation outputs, often used interchangeably with "results". Model Specification A complete description of a ``vivarium`` model. This description @@ -36,22 +41,36 @@ Glossary Pipeline A ``vivarium`` value pipeline. A pipeline is a framework tool that allows users to dynamically construct and share data across several - :term:`components `. + :term:`components `. See also: :term:`Attribute Pipeline`. Plugin A plugin is a python class intended to add additional functionality to the core framework. Unlike a normal :term:`component ` - which adds new behaviors and attributes to + which adds new behaviors and :term:`attributes ` to :term:`simulants `, a plugin adds new services to the framework itself. Examples might include a new simulation clock, a results handling service, or a logging service. - Simulant - An individual or agent. One member of the population being simulated. 
+ Population State Table + The core representation of the population in the simulation. Often referred + to as simply the "state table", it consists of a row for each :term:`simulant ` + and a column for each :term:`attribute `. - State Table - The core representation of the population in the simulation. The state - table consists of a row for each :term:`simulant ` and a - column for each :term:`attribute ` the simulant possesses. + Private Column + The source of a corresponding :term:`attribute pipeline `. + Only the :term:`component ` that creates a private column has + read or write access to it. Note that not all attribute pipelines require + private columns as a source, but any private columns created are automatically + paired with an attribute pipeline of the same name. + Public Column + A column in the :term:`population state table ` that + is generated by an :term:`attribute pipeline ` and is + read-accessible by all :term:`components `. + + Simulant + An individual or agent. One member of the population being simulated. + Value + A dynamic variable in the simulation that is computed by a :term:`pipeline ` + and can be shared across multiple :term:`components `. diff --git a/docs/source/tutorials/boids.rst b/docs/source/tutorials/boids.rst index 10f04cf4d..c686a6e2e 100644 --- a/docs/source/tutorials/boids.rst +++ b/docs/source/tutorials/boids.rst @@ -47,8 +47,8 @@ Imports +++++++ .. literalinclude:: ../../../src/vivarium/examples/boids/population.py - :lines: 1-6 - :linenos: + :start-after: # docs-start: imports + :end-before: # docs-end: imports `NumPy `_ is a library for doing high performance numerical computing in Python. `pandas `_ is a set @@ -78,40 +78,13 @@ configuration information. Components typically expose the values they use in the ``configuration_defaults`` class attribute. .. 
literalinclude:: ../../../src/vivarium/examples/boids/population.py - :lines: 13-17 + :start-after: # docs-start: configuration_defaults + :end-before: # docs-end: configuration_defaults :dedent: 4 - :linenos: - :lineno-start: 13 We'll talk more about configuration information later. For now observe that we're exposing a set of possible colors for our boids. -The ``columns_created`` property -++++++++++++++++++++++++++++++++ - -.. literalinclude:: ../../../src/vivarium/examples/boids/population.py - :lines: 18 - :dedent: 4 - :linenos: - :lineno-start: 18 - -The ``columns_created`` property tells Vivarium what columns (or "attributes") -the component will add to the population table. -See the next section for where we actually create these columns. - -.. note:: - - **The Population Table** - - When we talk about columns in the context of Vivarium, we are typically - talking about the simulant :term:`attributes `. Vivarium - represents the population of simulants as a single `pandas DataFrame`__. - We think of each simulant as a row in this table and each column as an - attribute of the simulants. - - __ https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html - - The ``setup`` method ++++++++++++++++++++ @@ -123,13 +96,13 @@ the setup method on components and providing the builder to them. We'll explore these tools that the builder provides in detail as we go. .. literalinclude:: ../../../src/vivarium/examples/boids/population.py - :lines: 24-25 + :start-after: # docs-start: setup + :end-before: # docs-end: setup :dedent: 4 - :linenos: - :lineno-start: 24 -Our setup method is pretty simple: we just save the configured colors for later use. -The component is accessing the subsection of the configuration that it cares about. +Our setup method is pretty simple: we just save the configured colors for later use, +get a randomness stream, and register some private columns. 
Regarding the colors, +note that the component is accessing the subsection of the configuration that it cares about. The full simulation configuration is available from the builder as ``builder.configuration``. You can treat the configuration object just like a nested python @@ -138,47 +111,57 @@ that's been extended to support dot-style attribute access. Our access here mirrors what's in the ``configuration_defaults`` at the top of the class definition. -The ``on_initialize_simulants`` method -++++++++++++++++++++++++++++++++++++++ +Note that the setup method is registering a method called ``initialize_population`` +as an initializer that will create two private columns (``color`` and ``entrance_time``) +and requires the randomness stream to do so. This tells Vivarium what columns +(or "attributes") the component will add to the population table and how. See the +next section for where we actually create these columns. -Finally we look at the ``on_initialize_simulants`` method, -which is automatically called by Vivarium when new simulants are -being initialized. -This is where we should initialize values in the ``columns_created`` -by this component. +.. note:: -.. literalinclude:: ../../../src/vivarium/examples/boids/population.py - :lines: 31-39 - :dedent: 4 - :linenos: - :lineno-start: 31 + **The Population State Table** + + When we talk about columns in the context of Vivarium, we are typically + talking about simulant :term:`attributes `. All attributes for all + simulants can be represented by a single `pandas DataFrame`__ where each simulant + is a row and each attribute is a column. This dataframe is referred to as the + :term:`population state table` (or simply "state table"). + + __ https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html -We see that like the ``setup`` method, ``on_initialize_simulants`` takes in a -special argument that we don't provide. 
This argument, ``pop_data`` is an -instance of :class:`~vivarium.framework.population.manager.SimulantData` containing a -handful of information useful when initializing simulants. +Initializers +++++++++++++ -The only two bits of information we need for now are the -``pop_data.index``, which supplies the index of the simulants to be -initialized, and the ``pop_data.creation_time`` which gives us a -representation (typically an ``int`` or :class:`pandas.Timestamp`) of the -simulation time when the simulant was generated. +Finally we look at the ``initialize_population`` method which was registered +as the one and only initializer method. Any registered initializers will be automatically +called by Vivarium when new simulants are being added to the simulation. + +We see that, like the ``setup`` method, initializer methods (``initialize_population`` +in this case) take in a special argument that we don't provide. This argument, +``pop_data``, is an instance of :class:`~vivarium.framework.population.manager.SimulantData` +containing a handful of information useful when initializing simulants. + +The only bits of information we need for now are the randomness stream we registered +as a required resource, the ``pop_data.index`` which supplies the index of the +simulants to be initialized, and the ``pop_data.creation_time`` which gives us a +representation of the simulation time when the simulant was generated (typically +an ``int`` or :class:`pandas.Timestamp`). .. note:: **The Population Index** - The population table we described before has an index that uniquely + The population state table we described before has an index that uniquely identifies each simulant. This index is used in several places in the simulation to look up information, calculate simulant-specific values, and update information about the simulants' state. 
-Using the population index, we generate a ``pandas.DataFrame`` on lines 32-38 -and fill it with the initial values of 'entrance_time' and 'color' for each -new simulant. Right now, this is just a table with data hanging out in our -simulation. To actually do something, we have to tell Vivarium's population -management system to update the underlying population table, which we do -on line 39. +Using the population index, we generate a ``pandas.DataFrame`` and fill it with +the initial values of 'entrance_time' and 'color' for each new simulant. However, +this new dataframe is just a table hanging out in memory - Vivarium knows nothing +about it and it cannot readily be used througout the simulation. To actually be +able to use the data as attributes, we need to tell Vivarium's population management +system to update the state table with the values. Putting it together +++++++++++++++++++ @@ -199,7 +182,7 @@ we can set up our simulation with the following code: ) # Peek at the population table - print(sim.get_population().head()) + print(sim.get_population(["entrance_time", "color"]).head()) .. testcode:: @@ -214,14 +197,16 @@ we can set up our simulation with the following code: logging_verbosity=0, ) -:: + print(sim.get_population(["entrance_time", "color"]).head()) + +.. testoutput:: - tracked entrance_time color - 0 True 2005-07-01 blue - 1 True 2005-07-01 red - 2 True 2005-07-01 red - 3 True 2005-07-01 red - 4 True 2005-07-01 red + entrance_time color + 0 2005-07-01 red + 1 2005-07-01 red + 2 2005-07-01 red + 3 2005-07-01 red + 4 2005-07-01 blue Movement @@ -234,50 +219,58 @@ It tracks the position and velocity of each boid, and creates an .. literalinclude:: ../../../src/vivarium/examples/boids/movement.py :caption: **File**: :file:`~/code/vivarium_examples/boids/movement.py` + :linenos: You'll notice that some parts of this component look very similar to our population component. 
Indeed, we can split up the responsibilities of initializing simulants over many different components. In Vivarium we tend to think of components as being -responsible for individual behaviors or :term:`attributes `. This +responsible for individual behaviors or attributes. This makes it very easy to build very complex models while only having to think about local pieces of it. However, there are also a few new Vivarium features on display in this component. We'll step through these in more detail. -Value pipelines -+++++++++++++++ +Attribute pipelines ++++++++++++++++++++ -A :term:`value pipeline ` is like a column in the population table, in that it contains information -about our simulants (boids, in this case). -The key difference is that it is not *stateful* -- each time it is accessed, its values are re-initialized +Each :term:`attribute pipeline ` creates a different column in the population +state table that contains information about our simulants (boids, in this case). +Importantly, these attribute pipelines are not *stateful* -- each time one is +accessed in order to generate the state table, its values are re-initialized from scratch, instead of "remembering" what they were on the previous timestep. This makes it appropriate for modeling acceleration, because we only want a boid to accelerate due to forces acting on it *now*. You can find an overview of the values system :ref:`here `. -The Builder class exposes an additional property for working with value pipelines: +.. note:: + + Attribute pipelines are a special type of the more generic :term:`value pipeline `. + While attribute pipelines dynamically calculate attributes of simulants, value + pipelines can be used to calculate *anything*. By far the most common type of + value pipeline used in Vivarium simulations are attribute pipelines and so we + will not discuss the more general concept further in this tutorial. 
+ +The Builder class exposes an additional property for working with attribute pipelines: :meth:`vivarium.framework.engine.Builder.value`. -We call the :meth:`vivarium.framework.values.manager.ValuesInterface.register_value_producer` -method to register a new pipeline. +We call the :meth:`vivarium.framework.values.interface.ValuesInterface.register_attribute_producer` +method to register a new attribute pipeline as the producer of some attribute. .. literalinclude:: ../../../src/vivarium/examples/boids/movement.py - :lines: 32-34 + :start-after: # docs-start: register_attribute_producer + :end-before: # docs-end: register_attribute_producer :dedent: 4 - :linenos: - :lineno-start: 32 -This call provides a ``source`` function for our pipeline, which initializes the values. +This call provides a ``source`` function for our pipeline which initializes the values. In this case, the default is zero acceleration: .. literalinclude:: ../../../src/vivarium/examples/boids/movement.py - :lines: 40-41 + :start-after: # docs-start: base_acceleration + :end-before: # docs-end: base_acceleration :dedent: 4 - :linenos: - :lineno-start: 40 -This may seem pointless, since acceleration will always be zero. -Value pipelines have another feature we will see later: other components can *modify* +This may seem pointless since acceleration will always be zero. +Pipelines have another feature we will see later: other components can *modify* their values. We'll create components later in this tutorial that modify this pipeline to exert forces on our boids. @@ -285,29 +278,62 @@ exert forces on our boids. The ``on_time_step`` method +++++++++++++++++++++++++++ -This is a lifecycle method, much like ``on_initialize_simulants``. +This is a lifecycle method, much like any registered initializer methods. However, this method will be called on each step forward in time, not only when new simulants are initialized. -It can use values from pipelines and update the population table. 
-In this case, we change boids' velocity according to their acceleration, +One very common thing components do on each time step is read and update the population +state table. In this case, we change boids' velocity according to their acceleration, limit their velocity to a maximum, and update their position according to their velocity. -To get the values of a pipeline such as ``acceleration`` inside on_time_step, -we simply call that pipeline as a function, using ``event.index``, -which is the set of simulants affected by the event (in this case, all of them). +To get population attributes such as ``acceleration`` inside on_time_step, +we leverage a :class:`~vivarium.framework.population.population_view.PopulationView` +which provides a handful of methods designed to get what you need. In this case, +we call :meth:`~vivarium.framework.population.population_view.PopulationView.get_attribute_frame` +to get the acceleration attribute. We pass in the ``event.index`` which is the set +of simulants affected by the event (in this case, all of them). Note that there is +also available a :meth:`~vivarium.framework.population.population_view.PopulationView.get_attributes` +method which is similar to ``get_attribute_frame`` but can request multiple attributes +at once and does not necessarily return a dataframe. + +.. note:: + + **Population Views** + + A :class:`~vivarium.framework.population.population_view.PopulationView` is a + read/write interface to the population state table. It provides a number of + convenience methods for getting and setting attributes, private columns, + and other bits of information about the population. + +We also make a call to the population view's ``get_private_columns`` method to get +all the private columns created by this component (``x``, ``y``, ``vx``, and ``vy``). +A :term:`private column ` is one that acts as a *source* of an +attribute. + +.. note:: + + **Private Columns vs. 
Attributes** + + Knowing whether you need a private column or an attribute depends on context, + but when you need to update the state table as we're doing here, it's important + to understand that what you are really updating are the appropriate private columns + that act as the source of the attributes you care about. In this case, we want + to update the source data for position (``x`` and ``y``) and velocity (``vx`` + and ``vy``) attributes which are this component's private columns (i.e. this + component registered the initializer that created those columns). It just so + happens that in order to update these columns we *also* need the acceleration + attributes and so we retrieve that as well. .. literalinclude:: ../../../src/vivarium/examples/boids/movement.py - :lines: 61-85 + :start-after: # docs-start: on_time_step + :end-before: # docs-end: on_time_step :dedent: 4 - :linenos: - :lineno-start: 61 Putting it together +++++++++++++++++++ -Let's run the simulation with our new component and look again at the population table. +Let's run the simulation with our new component and look again at the state table. .. code-block:: python @@ -322,7 +348,7 @@ Let's run the simulation with our new component and look again at the population ) # Peek at the population table - print(sim.get_population().head()) + print(sim.get_population(["color", "x", "y", "vx", "vy"]).head()) .. testcode:: :hide: @@ -337,7 +363,10 @@ Let's run the simulation with our new component and look again at the population ) # Peek at the population table - print(sim.get_population().head()[["color", "x", "y", "vx", "vy"]]) + pop = sim.get_population(["color", "x", "y", "vx", "vy"]).head() + # flatten MultiIndex for display + pop.columns = pop.columns.get_level_values(0) + print(pop) .. testoutput:: @@ -357,8 +386,8 @@ but their velocity stay the same. 
sim.step() - # Peek at the population table - print(sim.get_population().head()) + # Peek at the population state table + print(sim.get_population(["color", "x", "y", "vx", "vy"]).head()) .. testcode:: :hide: @@ -373,8 +402,10 @@ but their velocity stay the same. ) sim.step() - # Peek at the population table - print(sim.get_population().head()[["color", "x", "y", "vx", "vy"]]) + pop = sim.get_population(["color", "x", "y", "vx", "vy"]).head() + # flatten MultiIndex for display + pop.columns = pop.columns.get_level_values(0) + print(pop) .. testoutput:: @@ -404,7 +435,8 @@ boids and maybe some arrows to indicated their velocity. .. literalinclude:: ../../../src/vivarium/examples/boids/visualization.py :caption: **File**: :file:`~/code/vivarium_examples/boids/visualization.py` - :lines: 1-17 + :start-after: # docs-start: plot_boids + :end-before: # docs-end: plot_boids We can then visualize our flock with @@ -450,19 +482,21 @@ __ https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.cKDTree.ht .. literalinclude:: ../../../src/vivarium/examples/boids/neighbors.py :caption: **File**: :file:`~/code/vivarium_examples/boids/neighbors.py` + :linenos: -This component creates a value pipeline called ``neighbors`` that other components -can use to access the neighbors of each boid. +This component creates an attribute pipeline called ``neighbors`` so that other +components can access the neighbors of each boid (remember that every attribute +pipeline corresponds to a column in the population state table). Note that the only thing it does in ``on_time_step`` is ``self.neighbors_calculated = False``. That's because we only want to calculate the neighbors once per time step. When the pipeline -is called, we can tell with ``self.neighbors_calculated`` whether we need to calculate them, +is called, we can tell with ``self.neighbors_calculated`` whether we need to calculate them or use our cached value in ``self._neighbors``. 
Swarming behavior ----------------- -Now we know which boids are each others' neighbors, but we're not doing anything +Now we know which boids are each others' neighbors but we're not doing anything with that information. We need to teach the boids to swarm! There are lots of potential swarming behaviors to play around with, all of which @@ -474,34 +508,29 @@ and we'll gloss over most of the calculations. We define a base class for all our forces, since they will have a lot in common. We won't get into the details of this class, but at a high level it uses the -neighbors pipeline to find all the pairs of boids that are neighbors, +``neighbors`` attribute to find all the pairs of boids that are neighbors, applies some force to (some of) those pairs, and limits that force to a maximum magnitude. .. literalinclude:: ../../../src/vivarium/examples/boids/forces.py :caption: **File**: :file:`~/code/vivarium_examples/boids/forces.py` - :lines: 1-113 - :linenos: + :start-after: # docs-start: force_base_class + :end-before: # docs-end: force_base_class -To access the value pipeline we created in the Neighbors component, we use -``builder.value.get_value`` in the setup method. Then, as we saw with the ``acceleration`` -pipeline, we simply call that pipeline as a function inside ``on_time_step`` to retrieve -its values for a specified index. -The major new Vivarium feature seen here is that of the **value modifier**, -which we register with :meth:`vivarium.framework.values.manager.ValuesInterface.register_value_modifier`. -As the name suggests, this allows us to modify the values in a pipeline, -in this case adding the effect of a force to the values in the ``acceleration`` pipeline. -We register that the ``apply_force`` method will modify the acceleration values like so: +The major new Vivarium feature seen here is that of the **attribute modifier**, +which we register with :meth:`vivarium.framework.values.interface.ValuesInterface.register_attribute_modifier`. 
+As the name suggests, this allows us to modify attributes, +in this case adding the effect of a force to the ``acceleration`` attribute. +We register the ``apply_force`` method as the modifier like so: .. literalinclude:: ../../../src/vivarium/examples/boids/forces.py :caption: **File**: :file:`~/code/vivarium_examples/boids/forces.py` - :lines: 35-38 + :start-after: # docs-start: register_acceleration_modifier + :end-before: # docs-end: register_acceleration_modifier :dedent: 4 - :linenos: - :lineno-start: 35 -Once we start adding components with these modifiers into our simulation, acceleration won't always be -zero anymore! +Once we start adding components with these modifiers into our simulation, acceleration +won't always be zero anymore! We then define our three forces using the ``Force`` base class. We won't step through what these mean in detail. @@ -512,9 +541,8 @@ parameter: the distance within which it should act. .. literalinclude:: ../../../src/vivarium/examples/boids/forces.py :caption: **File**: :file:`~/code/vivarium_examples/boids/forces.py` - :lines: 116-167 - :linenos: - :lineno-start: 116 + :start-after: # docs-start: concrete_force_classes + :end-before: # docs-end: concrete_force_classes For a quick test of our swarming behavior, let's add in these forces and check in on our boids after 100 steps: @@ -562,7 +590,8 @@ Add this method to ``visualization.py``: .. literalinclude:: ../../../src/vivarium/examples/boids/visualization.py :caption: **File**: :file:`~/code/vivarium_examples/boids/visualization.py` - :lines: 20-41 + :start-after: # docs-start: plot_boids_animated + :end-before: # docs-end: plot_boids_animated Then, try it out like so: diff --git a/docs/source/tutorials/disease_model.rst b/docs/source/tutorials/disease_model.rst index e5f201147..a108aecf6 100644 --- a/docs/source/tutorials/disease_model.rst +++ b/docs/source/tutorials/disease_model.rst @@ -46,16 +46,16 @@ of some of the more complex pieces/systems until later. .. 
literalinclude:: ../../../src/vivarium/examples/disease_model/population.py :caption: **File**: :file:`~/code/vivarium/examples/disease_model/population.py` + :linenos: There are a lot of things here. Let's take them piece by piece. -*Note: docstrings are left out of the code snippets below.* - Imports +++++++ .. literalinclude:: ../../../src/vivarium/examples/disease_model/population.py - :lines: 1-8 + :start-after: # docs-start: imports + :end-before: # docs-end: imports It's typical to import all required objects at the top of each module. In this case, we are importing ``pandas`` and the Vivarium @@ -76,9 +76,6 @@ information in method signatures. BasePopulation Instantiation ++++++++++++++++++++++++++++ -.. literalinclude:: ../../../src/vivarium/examples/disease_model/population.py - :lines: 11 - We define a class called ``BasePopulation`` that inherits from the Vivarium :class:`Component `. This inheritance is what makes a class a proper Vivarium :term:`component` and all the affordances that @@ -88,9 +85,10 @@ Default Configuration +++++++++++++++++++++ .. literalinclude:: ../../../src/vivarium/examples/disease_model/population.py - :lines: 18-19, 25-32 + :start-after: # docs-start: configuration_defaults + :end-before: # docs-end: configuration_defaults -You'll see this sort of pattern repeated in many, many Vivarium components. +You'll see this sort of pattern repeated in many Vivarium components. We declare a configuration block as a property for components. Vivarium has a :doc:`cascading configuration system ` that @@ -106,16 +104,16 @@ initial population of simulants. It also notes that there is a `'population_size'` key. This key has a default value set by Vivarium's population management system. -Columns Created -+++++++++++++++ +Sub-components +++++++++++++++ + .. 
literalinclude:: ../../../src/vivarium/examples/disease_model/population.py - :lines: 34-36 + :start-after: # docs-start: sub_components + :end-before: # docs-end: sub_components -This property is a list of the columns that the component will create in the -population state table. The population management system uses information about -what columns are created by which components in order to determine what order to -call initializers defined in separate classes. We'll see what this means in -practice later. +This property is a list of components that are managed by this component. In this +case, we see that when ``BasePopulation`` is set up, it will also set up an instance +of the ``Mortality`` component. The ``__init__()`` method +++++++++++++++++++++++++ @@ -159,7 +157,7 @@ method on each component and calls that method with a to register listeners for ``'time_step'`` events. - ``builder.population`` : The population management system. Registers population initializers (functions that fill in initial state - information about simulants), give access to views of the simulation + information about simulants), gives access to views of the simulation state, and mediates updates to the simulation state. It also provides access to functionality for generating new simulants (e.g. via birth or migration), though we won't use that feature in this tutorial. @@ -176,28 +174,26 @@ method on each component and calls that method with a Let's step through the ``setup`` method and examine what's happening. .. literalinclude:: ../../../src/vivarium/examples/disease_model/population.py - :lines: 43, 55-71 - :dedent: 4 - :linenos: + :start-after: # docs-start: setup + :end-before: # docs-end: setup -Line 2 simply grabs a copy of the simulation +To start, we simply grab a copy of the simulation :class:`configuration `. This is essentially a dictionary that supports ``.``-access notation. 
-Lines 4-18 interact with Vivarium's +The next handful of lines interact with Vivarium's :class:`randomness system `. Several things are happening here. -Lines 4-13 deal with the topic of :doc:`Common Random Numbers `, +First, we deal with the topic of :doc:`Common Random Numbers `, a variance reduction technique employed by the Vivarium framework to make it easier to perform counterfactual analysis. It's not important to have a full grasp of this system at this point. .. literalinclude:: ../../../src/vivarium/examples/disease_model/population.py - :lines: 57-66 + :start-after: # docs-start: crn + :end-before: # docs-end: crn :dedent: 4 - :linenos: - :lineno-start: 4 .. note:: @@ -230,14 +226,13 @@ randomness system to let us know whether or not we care about using CRN. We'll explore this later when we're looking at running simulations with interventions. -Finally, we grab actual :class:`randomness streams ` +Next, we grab actual :class:`randomness streams ` from the framework. .. literalinclude:: ../../../src/vivarium/examples/disease_model/population.py - :lines: 68-71 + :start-after: # docs-start: randomness + :end-before: # docs-end: randomness :dedent: 4 - :linenos: - :lineno-start: 15 ``get_stream`` is the only call most components make to the randomness system. The best way to think about randomness streams is as decision points in your @@ -258,45 +253,64 @@ initialization streams in a simulation. The ``'sex_randomness'`` is a much more typical example of how to interact with the randomness system - we are simply getting the stream. +Finally, we register two initializers with the population manager. This tells +vivarium which initializers to call when creating the component as well as +which columns (if any) each initializer is responsible for creating. + +.. 
literalinclude:: ../../../src/vivarium/examples/disease_model/population.py + :start-after: # docs-start: initializers + :end-before: # docs-end: initializers + :dedent: 4 + +Note that each initializer registration requires us to specify which columns +that initializer is creating (if any), the initializer method itself, and any +required resources. In this case, each initializer depends on a randomness stream. +The system will ensure that these required resources are set up before calling the +initializer methods. + **That was a lot of stuff** -As I mentioned at the top the population component is one of the more +As I mentioned at the top, the population component is one of the more complicated pieces of any simulation. It's not important to grasp everything right now. We'll see many of the same patterns repeated in the ``setup`` method of other components later. The unique things here are worth coming back to at a later point once you have more familiarity with the framework conventions. -The ``on_initialize_simulants`` method -++++++++++++++++++++++++++++++++++++++ +The initializers +++++++++++++++++ -The primary purpose of this method (for this class) is to generate the initial -population. Specifically, it will generate the 'age', 'sex', 'alive', and -'entrance_time' columns for the population table (recall that the ``columns_created`` -property dictates that this component will indeed create these columns). +Two methods are registered as initializers. The primary purpose of initializer methods +is to generate the initial population. Specifically (for this class), the +``initialize_entrance_time_and_age`` method is registered to create the 'entrance_time' +and 'age' columns (with ``age_randomness`` as a dependency) and the +``initialize_sex`` method is registered to create the 'sex' column (with +``sex_randomness`` as a dependency). .. 
note:: - **The Population Table** + **The Population State Table** When we talk about columns in the context of Vivarium, we are typically - talking about the simulant :term:`attributes `. Vivarium - represents the population of simulants as a single - :class:`pandas.DataFrame`. We think of each simulant as a row in this table - and each column as an attribute of the simulants. + talking about simulant :term:`attributes `. All attributes for all + simulants can be represented by a single ``pandas.DataFrame`` where each simulant + is a row and each attribute is a column. This dataframe is referred to as the + :term:`population state table` (or simply "state table"). As previously mentioned, this class is a proper Vivarium :term:`Component`. Among other things, this means that much of the setup happens automatically during the simulation's ``Setup`` :doc:`lifecycle phase `. There are several methods available to define for a component's setup, depending -on what you want to happen when: ``on_post_setup()``, ``on_initialize_simulants()`` -(this one), ``on_time_step_prepare()``, ``on_time_step()``, ``on_time_step_cleanup()``., +on what you want to happen when: ``on_post_setup()`` +``on_time_step_prepare()``, ``on_time_step()``, ``on_time_step_cleanup()``, ``on_collect_metrics()``, and ``on_simulation_end()``. The framework looks for any of these methods during the setup phase and calls them if they are defined. -The fact that this method is called ``on_initialize_simulants`` guarantees that -it will be called during the population initialization phase of the simulation. +Further, the framework calls any registered initializer methods during population +creation. The fact that ``initialize_entrance_time_and_age`` and ``initialize_sex`` +were registered guarantees that they will be called during the population initialization +phase of the simulation. 
-This initializer method is called by the population management whenever simulants +Initializer methods are called by the population manager whenever simulants are created. For our purposes, this happens only once at the very beginning of the simulation. Typically, we'd task another component with responsibility for managing other ways simulants might enter (we might, for instance, have a @@ -304,15 +318,18 @@ managing other ways simulants might enter (we might, for instance, have a our location of interest or a ``Fertility`` component that handles new simulants being born). -We'll take this method line by line as we did with ``setup``. +Let's inspect the two initializer methods line by line as we did with ``setup``. + +The ``initialize_entrance_time_and_age`` method +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. literalinclude:: ../../../src/vivarium/examples/disease_model/population.py - :lines: 77, 102-132 + :start-after: # docs-start: initialize_entrance_time_and_age + :end-before: # docs-end: initialize_entrance_time_and_age :dedent: 4 - :linenos: First, we see that this method takes in a special argument that we don't provide. -This argument, ``pop_data`` is an instance of +This argument, ``pop_data``, is an instance of :class:`~vivarium.framework.population.manager.SimulantData` containing a handful of information useful when initializing simulants. @@ -338,10 +355,9 @@ property we specified an ``'age_start'`` and ``'age_end'``. Here we use these to generate the age distribution of our initial population. .. literalinclude:: ../../../src/vivarium/examples/disease_model/population.py - :lines: 102-111 + :start-after: # docs-start: ages + :end-before: # docs-end: ages :dedent: 4 - :linenos: - :lineno-start: 2 We've built in support for two different kinds of populations based on the ``'age_start'`` and ``'age_end'`` specified in the configuration. If we get @@ -364,11 +380,18 @@ in the index. 
**The Population Index** - The population table we described before has an index that uniquely + The population state table we described before has an index that uniquely identifies each simulant. This index is used in several places in the simulation to look up information, calculate simulant-specific values, and update information about the simulants' state. +With the simulant ages defined, we then create a dataframe of simulant +ages and entrance times. + +.. literalinclude:: ../../../src/vivarium/examples/disease_model/population.py + :start-after: # docs-start: population_dataframe + :end-before: # docs-end: population_dataframe + :dedent: 4 We then come back to the question of whether or not we're using common random numbers in our system. In the ``setup`` method, our criteria for @@ -377,31 +400,23 @@ were specified as the randomness ``key_columns`` in the configuration. These ``key_columns`` are what the randomness system uses to uniquely identify simulants across simulations. -.. literalinclude:: ../../../src/vivarium/examples/disease_model/population.py - :lines: 113-120 - :dedent: 4 - :linenos: - :lineno-start: 13 - -If we are using CRN, we must generate these columns before any other calls +Note that if we are using CRN, we must generate these columns before any other calls are made to the randomness system with the population index. We then register these simulants with the randomness system using ``self.register``, -a reference to ``register_simulants`` method in the randomness management +a reference to the ``register_simulants`` method in the randomness management system. This is responsible for mapping the attributes of interest (here ``'entrance_time'`` and ``'age'``) to a particular set of random numbers that will be used across simulations with the same random seed. +.. 
literalinclude:: ../../../src/vivarium/examples/disease_model/population.py + :start-after: # docs-start: crn_registration + :end-before: # docs-end: crn_registration + :dedent: 4 + Once registered, we can generate the remaining attributes of our simulants with guarantees around reproducibility. -If we're not using CRN, we can just generate the full set of simulant -attributes straightaway. - -.. literalinclude:: ../../../src/vivarium/examples/disease_model/population.py - :lines: 121-130 - :dedent: 4 - :linenos: - :lineno-start: 21 +If we're not using CRN, we simply don't register the simulants and move on. In either case, we are hanging on to a table representing some attributes of our new simulants. However, this table does not matter yet because the @@ -411,11 +426,14 @@ inform the simulation by passing in the ``DataFrame`` to our ``update`` method. This method is the only way to modify the underlying population table. -.. literalinclude:: ../../../src/vivarium/examples/disease_model/population.py - :lines: 132 - :dedent: 4 - :linenos: - :lineno-start: 32 +.. note:: + + **Population Views** + + A :class:`~vivarium.framework.population.population_view.PopulationView` is a + read/write interface to the population state table. It provides a number of + convenience methods for getting and setting attributes, private columns, + and other bits of information about the population. .. warning:: @@ -423,6 +441,24 @@ population table. must have the same index that was passed in with the ``pop_data``. You can potentially cause yourself a great deal of headache otherwise. +.. literalinclude:: ../../../src/vivarium/examples/disease_model/population.py + :start-after: # docs-start: update_entrance_time_and_age + :end-before: # docs-end: update_entrance_time_and_age + :dedent: 4 + +The ``initialize_sex`` method +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
literalinclude:: ../../../src/vivarium/examples/disease_model/population.py
+   :start-after: # docs-start: initialize_sex
+   :end-before: # docs-end: initialize_sex
+   :dedent: 4
+
+Thankfully, the ``initialize_sex`` method is much simpler than the previous
+initializer. Here, we simply call another randomness stream convenience method
+``self.sex_randomness.choice`` to randomly assign a sex to each simulant. We then
+update the population via the population view with these new values just like before.
+
 The ``on_time_step`` method
 +++++++++++++++++++++++++++
 
@@ -430,13 +466,13 @@ The last piece of our population component is the ``'time_step'`` listener
 method ``on_time_step``.
 
 .. literalinclude:: ../../../src/vivarium/examples/disease_model/population.py
-   :lines: 134, 144-146
+   :start-after: # docs-start: on_time_step
+   :end-before: # docs-end: on_time_step
    :dedent: 4
-   :linenos:
 
-This method takes in an :class:`~vivarium.framework.event.Event` argument
+This method takes in an :class:`~vivarium.framework.event.Event` argument
 provided by the simulation. This is very similar to the ``SimulantData``
-argument provided to ``on_initialize_simulants``. It carries around
+argument provided to initializer methods. It carries around
 some information about what's happening in the event.
 
 .. note::
 
@@ -456,22 +492,34 @@ some information about what's happening in the event.
 
     It also supports some method for generating new events that
     we don't care about here.
 
-In order to age our simulants, we first acquire a copy of the current
-population state from our population view. In addition to the ``update``
-method, population views also support a ``get`` method that takes in
-an index and an optional ``query`` used to filter down the returned
-population. Here, we only want to increase the age of people still living.
-The ``query`` argument needs to be consistent with the
-:meth:`pandas.DataFrame.query` method.
+In order to age our simulants, we first acquire a copy of the current ages via
+the :meth:`vivarium.framework.population.population_view.PopulationView.get_private_columns`
+method. This method takes in an index and optional ``private_columns`` and/or ``query``
+arguments used to request specific columns and to filter down the returned
+population, respectively. Here, we only need the simulant ages and we only want
+to increase the age of people still living. Note that the ``query`` argument needs
+to be consistent with the :meth:`pandas.DataFrame.query` method. What we get back
+is a ``pandas.Series`` of simulant ages containing the filtered rows corresponding
+to the index we passed in.
+
+.. note::
+
+    **Private Columns vs. Attributes**
 
-What we get back is another ``pandas.DataFrame`` containing the filtered
-rows corresponding to the index we passed in. The columns of the returned
-``DataFrame`` are precisely the columns this component created (as well as
-any additional ``columns_required``, of which this component has none).
+    Population views provide methods to get both attributes and private columns.
+    If you want attributes, you can use :meth:`vivarium.framework.population.population_view.PopulationView.get_attributes`
+    or :meth:`vivarium.framework.population.population_view.PopulationView.get_attribute_frame`.
+    If you want private columns, use :meth:`vivarium.framework.population.population_view.PopulationView.get_private_columns`.
+
+    Knowing whether you need a private column or an attribute depends on context,
+    but when you need to update the state table as we're doing here, it's important
+    to understand that what you are really updating are the appropriate private columns
+    that act as the source of the attributes you care about. Refer to the
+    :ref:`population management documentation ` for more details.
We next update the age of our simulants by adding on the width of the time step to their current age and passing the updated table to the ``update`` method -of our population view as we did in ``on_initialize_simulants`` +of our population view as we did in our initializer methods. Examining our work ++++++++++++++++++ @@ -487,16 +535,16 @@ Now that we've done all this hard work, let's see what it gives us. sim = InteractiveContext(components=[BasePopulation()], configuration=config) - print(sim.get_population().head()) + print(sim.get_population(['age', 'sex']).head()) :: - tracked sex age entrance_time alive - 0 True Female 13.806775818385496 2005-07-01 alive - 1 True Male 59.17289327893596 2005-07-01 alive - 2 True Female 11.030887339897 2005-07-01 alive - 3 True Female 27.72319127598699 2005-07-01 alive - 4 True Female 51.05218820533359 2005-07-01 alive + age sex + 0 13.806776 Female + 1 59.172893 Male + 2 11.030887 Female + 3 27.723191 Female + 4 51.052188 Female .. testcode:: :hide: @@ -508,11 +556,17 @@ Now that we've done all this hard work, let's see what it gives us. config = {'randomness': {'key_columns': ['entrance_time', 'age']}} sim = InteractiveContext(components=[BasePopulation()], configuration=config) - expected = pd.DataFrame({ - 'age': [13.806775818385496, 59.17289327893596, 11.030887339897, 27.72319127598699, 51.05218820533359], - 'sex': ['Female', 'Male', 'Female', 'Female', 'Female'], - }) - pd.testing.assert_frame_equal(sim.get_population().head()[['age', 'sex']], expected) + + print(sim.get_population(['age', 'sex']).head()) + +.. testoutput:: + + age sex + 0 13.806776 Female + 1 59.172893 Male + 2 11.030887 Female + 3 27.723191 Female + 4 51.052188 Female Great! We generate a population with a non-trivial age and sex distribution. Let's see what happens when our simulation takes a time step. @@ -520,16 +574,16 @@ Let's see what happens when our simulation takes a time step. .. 
code-block:: python sim.step() - print(sim.get_population().head()) + print(sim.get_population(['age', 'sex']).head()) :: - tracked sex age entrance_time alive - 0 True Female 13.806775818385496 2005-07-01 alive - 1 True Male 59.17289327893596 2005-07-01 alive - 2 True Female 11.030887339897 2005-07-01 alive - 3 True Female 27.72319127598699 2005-07-01 alive - 4 True Female 51.05218820533359 2005-07-01 alive + age sex + 0 13.809516 Female + 1 59.175633 Male + 2 11.033627 Female + 3 27.725931 Female + 4 51.054928 Female .. testcode:: @@ -538,7 +592,17 @@ Let's see what happens when our simulation takes a time step. import numpy as np sim.step() - assert np.isclose((sim.get_population().head()['age'] - expected['age'])*365, 1, 0.000001).all() + + print(sim.get_population(['age', 'sex']).head()) + +.. testoutput:: + + age sex + 0 13.809516 Female + 1 59.175633 Male + 2 11.033627 Female + 3 27.725931 Female + 4 51.054928 Female Everyone gets older by exactly one time step! We could just keep taking steps in our simulation and people would continue getting infinitely older. This, of @@ -547,11 +611,14 @@ course, does not reflect how the world goes. Time to introduce the grim reaper. Mortality --------- -Now that we have population generation and aging working, the next step -is introducing mortality into our simulation. +Now that we have demonstrated that population generation and aging works, let's +investigate the Mortality component. Note that Mortality is a sub-component +of the BasePopulation component and comes for free when we request BasePopulation +via the model specification; there is no need to add Mortality separately. .. literalinclude:: ../../../src/vivarium/examples/disease_model/mortality.py :caption: **File**: :file:`~/code/vivarium/examples/disease_model/mortality.py` + :linenos: The purpose of this component is to determine who dies every time step based on a mortality rate. 
You'll see many of the same framework features we used @@ -566,7 +633,8 @@ Since we're building our disease model without data to inform it, we'll expose all the important bits of the model as parameters in the configuration. .. literalinclude:: ../../../src/vivarium/examples/disease_model/mortality.py - :lines: 16-17, 23-27 + :start-after: # docs-start: configuration_defaults + :end-before: # docs-end: configuration_defaults Here we're specifying the overall mortality rate in our simulation. Rates have units! We'll phrase our model with rates specified in terms of events per @@ -574,17 +642,6 @@ person-year. So here we're specifying a uniform mortality rate of 0.01 deaths per person-year. This is obviously not realistic, but using toy data like this is often extremely useful in validating a model. -Columns Required -++++++++++++++++ - -.. literalinclude:: ../../../src/vivarium/examples/disease_model/mortality.py - :lines: 29-31 - -While this component does not create any new columns like the ``BasePopulation`` -component, it does require the ``'tracked'`` and ``'alive'`` columns to be -present in the population table. You'll see that these columns are indeed used -in the ``on_time_step`` and ``on_time_step_prepare`` methods. - The ``setup`` method ++++++++++++++++++++ @@ -592,36 +649,65 @@ There is not a whole lot going on in this setup method, but there is one new con we should discuss. .. literalinclude:: ../../../src/vivarium/examples/disease_model/mortality.py - :lines: 38, 50-55 + :start-after: # docs-start: setup + :end-before: # docs-end: setup -The first two lines are simply adding some useful attributes: the mortality-specific -configuration and the mortality randomness stream (which is used to answer the -question "which simulants died at this time step?"). +The first line simply adds a useful class attribute: the mortality randomness stream +(which is used to answer the question "which simulants died at this time step?"). 
-The main feature of note is the introduction of the
-:class:`values system `.
+The next bit is the main feature of note: the introduction of the
+:class:`values system `.
 The values system provides a way of distributing the computation of a value over
 multiple components. This can be a bit difficult to grasp, but is vital to
 the way we think about components in Vivarium. The best way to understand
 this system is by :doc:`example. `
 
-In our current context we register a named value "pipeline" into the
+.. note::
+
+    **Values vs Attributes**
+
+    A :term:`value ` is generic and is simply something that is computed from
+    a :term:`pipeline `. An :term:`attribute ` is a specific
+    kind of value that is simulant-specific, stored in the population state table,
+    and is computed from an :term:`attribute pipeline `. Most
+    values in vivarium are attributes.
+
+In our current context we register a named attribute pipeline into the
 simulation called ``'mortality_rate'`` via the ``builder.value.register_rate_producer``
-method. The source for a value is always a callable function or method
-(``self.base_mortality_rate`` in this case) which typically takes in a
-``pandas.Index`` as its only argument. Other things are possible, but not
+method. The source for a value is always a callable which typically takes in a
+``pandas.Index`` as its only argument. In this case, the source is a LookupTable,
+which is callable, so it meets this requirement. Other things are possible, but not
 necessary for our current use case. The ``'mortality_rate'`` source
 is then responsible for returning a ``pandas.Series`` containing a base
 mortality rate for each simulant in the index to the values system.
 Other components may register themselves as modifiers to this base rate.
 We'll see more of this once we get to the
-disease modelling portion of the tutorial.
+disease modeling portion of the tutorial.
The value system will coordinate how the base value is modified behind the scenes and return the results of all computations whenever the pipeline is called (in the ``on_time_step`` method in this case - see below). +Finally, we register an initializer method which is responsible for creating an +``'is_alive'`` column in the state table. + +The ``initialize_is_alive`` method +++++++++++++++++++++++++++++++++++++++ + +.. literalinclude:: ../../../src/vivarium/examples/disease_model/mortality.py + :start-after: # docs-start: initialize_is_alive + :end-before: # docs-end: initialize_is_alive + :dedent: 4 + +This very simple initializer method simply creates an ``'is_alive'`` column in the state +table and sets it to True for all simulants being initialized. Note again +that we need to call the population view's ``update`` method to actually modify +the state table. + +Notice also that when registering this method, we did not specify any required resources +(since every simulant is set as alive regardless of anything else). + The ``on_time_step`` method +++++++++++++++++++++++++++ @@ -629,62 +715,38 @@ Similar to how we aged simulants in the population component, we determine which simulants die during ``'time_step'`` events. .. literalinclude:: ../../../src/vivarium/examples/disease_model/mortality.py - :lines: 61, 71-77 + :start-after: # docs-start: on_time_step + :end-before: # docs-end: on_time_step :dedent: 4 - :linenos: -Line 2 is where we actually call the pipeline we constructed during setup. -It will return the effective mortality rate for each person in the simulation. +The very first thing we do is get the ``'mortality_rate'`` attribute (which is calculated +from the attribute pipeline we constructed during setup); these are the effective +mortality rate for each person in the simulation. Right now this will just be the base mortality rate, but we'll see how this changes once we bring in a disease. 
Importantly for now though, the pipeline is automatically rescaling the rate down to the size of the time steps we're taking. -In lines 3-5, we determine who died this time step. We turn our mortality rate +We then determine who died this time step. We turn our mortality rate into a probability of death in the given time step by assuming deaths are `exponentially distributed `_ and using the inverse distribution function. We then draw a uniformly distributed random number for each person and determine who died by comparing that number to the computed probability of death for the individual. -Finally, we update the state table ``'alive'`` column with the newly dead simulants. - -Note that when getting a view of the state table to update, we are using the -``subview`` method which returns only the columns requested. - -The ``on_time_step_prepare`` method -+++++++++++++++++++++++++++++++++++ - -This method simply updates any simulants who died during the previous time step -to be marked as untracked (that is, their ``'tracked'`` value is set to ``False``). +Finally, we update the state table ``is_alive`` column with the newly dead simulants. -.. literalinclude:: ../../../src/vivarium/examples/disease_model/mortality.py - :lines: 79, 92-96 - -Why didn't we update the newly-dead simulants ``'tracked'`` values at the same time -as their ``'alive'`` values in the ``on_time_step`` method? The reason is that the -deaths observer (discussed later) records the number of deaths that occurred during -the previous time step during the ``collect_metrics`` phase. By updating -the ``'alive'`` column during the ``time_step`` phase (which occurs *before* -``collect_metrics``) and the ``'tracked'`` column during the ``time_step_prepare`` -phase (which occurs *after* ``collect_metrics``), we ensure that the observer -can distinguish which simulants died specifically during the previous time step. 
Supplying a base mortality rate
 +++++++++++++++++++++++++++++++
 
-As discussed above, the ``base_mortality_rate`` method is the source for
-the ``'mortality_rate'`` value. Here we take in an index and build
-a ``pandas.Series`` that assigns each individual the mortality rate
-specified in the configuration.
-
-.. literalinclude:: ../../../src/vivarium/examples/disease_model/mortality.py
-   :lines: 102, 115
+As discussed above, the source for the ``'mortality_rate'`` value is a LookupTable
+defined in the component's configuration.
 
 In an actual simulation, we'd inform the base mortality rate with data
 specific to the age, sex, location, year (and potentially other demographic
 factors) that represent each simulant. We might disaggregate or interpolate
-our data here as well. Which is all to say, the source of a data pipeline can
+our data here as well. Which is all to say, the source of a pipeline can
 do some pretty complicated stuff.
 
 Did it work?
@@ -699,94 +761,108 @@ can see the impact of our mortality component without taking too many steps.
from vivarium import InteractiveContext from vivarium.examples.disease_model.population import BasePopulation - from vivarium.examples.disease_model.mortality import Mortality config = { - 'population': { - 'population_size': 100_000 - }, - 'randomness': { - 'key_columns': ['entrance_time', 'age'] - } + 'population': { + 'population_size': 100_000 + }, + 'randomness': { + 'key_columns': ['entrance_time', 'age'] + } } - sim = InteractiveContext(components=[BasePopulation(), Mortality()], configuration=config) - print(sim.get_population().head()) + sim = InteractiveContext(components=[BasePopulation()], configuration=config) + print(sim.get_population(['age', 'sex', 'mortality_rate', 'is_alive']).head()) :: - tracked sex age entrance_time alive - 0 True Female 13.806775818385496 2005-07-01 alive - 1 True Male 59.17289327893596 2005-07-01 alive - 2 True Female 11.030887339897 2005-07-01 alive - 3 True Female 27.72319127598699 2005-07-01 alive - 4 True Female 51.05218820533359 2005-07-01 alive + age sex mortality_rate is_alive + 0 13.806776 Female 0.000027 True + 1 59.172893 Male 0.000027 True + 2 11.030887 Female 0.000027 True + 3 27.723191 Female 0.000027 True + 4 51.052188 Female 0.000027 True .. 
testcode:: :hide: - from vivarium.examples.disease_model.mortality import Mortality - config = { - 'population': { - 'population_size': 100_000 - }, - 'randomness': { - 'key_columns': ['entrance_time', 'age'] - } + 'population': { + 'population_size': 100_000 + }, + 'randomness': { + 'key_columns': ['entrance_time', 'age'] + } } - sim = InteractiveContext(components=[BasePopulation(), Mortality()], configuration=config) + sim = InteractiveContext(components=[BasePopulation()], configuration=config) + + print(sim.get_population(['age', 'sex', 'mortality_rate', 'is_alive']).head()) - expected = pd.DataFrame({ - 'age': [13.806775818385496, 59.17289327893596, 11.030887339897, 27.72319127598699, 51.05218820533359], - 'sex': ['Female', 'Male', 'Female', 'Female', 'Female'], - }) - pd.testing.assert_frame_equal(sim.get_population().head()[['age', 'sex']], expected) +.. testoutput:: -This looks (exactly!) the same as it did prior to implementing mortality. Good - -we haven't taken a time step yet and so no one should have died. + age sex mortality_rate is_alive + 0 13.806776 Female 0.000027 True + 1 59.172893 Male 0.000027 True + 2 11.030887 Female 0.000027 True + 3 27.723191 Female 0.000027 True + 4 51.052188 Female 0.000027 True + +Note that aside from modifying the population size in the config, we haven't actually +done anything different than before. Indeed, the ages and sexes of the first five +simulants are the same. Here, however, we are not subsetting the dataframe to only +show the ``'age'`` and ``'sex'`` columns, however, and so we see various others +(notably, the ``'mortality_rate'`` and ``'is_alive'`` columns created by the Mortality +component). + +As we haven't taken a time step yet, everyone should still be alive. .. code-block:: python - print(sim.get_population().alive.value_counts()) + print(sim.get_population("is_alive").value_counts()) :: - alive - alive 100000 - Name: count, dtype: int64 + is_alive + True 100000 + Name: count, dtype: int64 .. 
testcode:: :hide: - assert sim.get_population().alive.value_counts().alive == 100_000 + print(sim.get_population("is_alive").value_counts()) -Just checking that everyone is alive. Let's run our simulation for a while -and see what happens. +.. testoutput:: + + is_alive + True 100000 + Name: count, dtype: int64 + + +Now let's run our simulation for a while and see what happens. .. code-block:: python sim.take_steps(365) # Run for one year with one day time steps - sim.get_population('tracked==True').alive.value_counts() + sim.get_population("is_alive").value_counts() :: - alive - alive 99015 - dead 985 + is_alive + True 99023 + False 977 Name: count, dtype: int64 -We simulated somewhere between 99,015 (if everyone died in the first time step) +We simulated somewhere between 99,023 (if everyone died in the first time step) and 100,000 (if everyone died in the last time step) living person-years and -saw 985 deaths. This means our empirical mortality rate is somewhere close -to 0.0099 deaths per person-year, very close to the 0.01 rate we provided. +saw 977 deaths. This means our empirical mortality rate is somewhere close +to 0.0098 deaths per person-year, very close to the 0.01 rate we provided. .. testcode:: :hide: - - sim = InteractiveContext(components=[BasePopulation(), Mortality()], configuration=config) - sim.take_steps(2) - assert sim.get_population('tracked==True')['alive'].value_counts()['dead'] == 6 + + # It takes too long to run 365 steps in the test, so we just run 10 steps here + sim.take_steps(10) + assert sim.get_population("is_alive").value_counts()[False] == 27 Disease ------- @@ -815,7 +891,7 @@ the simulation itself to record more sophisticated output. Further, we frequentl work in non-interactive (or even distributed) environments where we simply don't have access to the simulation object and so would like to write our output to disk. These recorded outputs (i.e. 
results) are referred to in vivarium as **observations** -and it is the job of so-called **observers** to register them to the simulation. +and it is the job of **observers** to register them to the simulation. :class:`Observers ` are vivarium :class:`components ` that are created by the user and added to the simulation via the model specification. @@ -824,11 +900,26 @@ This example's observers are shown below. .. literalinclude:: ../../../src/vivarium/examples/disease_model/observer.py :caption: **File**: :file:`~/code/vivarium/examples/disease_model/observer.py` + :linenos: There are two observers that have each registered a single observation to the simulation: deaths and years of life lost (YLLs). It is important to note that neither of those observations are population state table columns; they are -more complex results that require some computation to determine. +more complex results that require some computation to determine. + +Note that the deaths observer actually creates a private column called ``'previous_alive'``. +The purpose of this column is to distinguish newly-dead simulants (for adding purposes) +from those that died in previous time steps. We update this column in the +``on_time_step_prepare`` method of the observer. + +Why didn't we update the ``'previous_alive'`` values at the same time +as the ``'is_alive'`` values in the Mortality component's ``on_time_step`` method? +The reason is that the deaths observer records the number of deaths that occurred during +the previous time step during the ``collect_metrics`` phase. By updating +the ``'is_alive'`` column during the ``time_step`` phase (which occurs *before* +``collect_metrics``) and the ``'previous_alive'`` column during the ``time_step_prepare`` +phase (which occurs *after* ``collect_metrics``), we ensure that the observer +can distinguish which simulants died specifically during the previous time step. 
In an interactive setting, we can access these observations via the ``sim.get_results()`` command. This will return a dictionary of all @@ -838,7 +929,6 @@ observations up to this point in the simulation. from vivarium import InteractiveContext from vivarium.examples.disease_model.population import BasePopulation - from vivarium.examples.disease_model.mortality import Mortality from vivarium.examples.disease_model.observer import DeathsObserver, YllsObserver config = { @@ -853,27 +943,26 @@ observations up to this point in the simulation. sim = InteractiveContext( components=[ BasePopulation(), - Mortality(), DeathsObserver(), YllsObserver(), ], configuration=config ) sim.take_steps(365) # Run for one year with one day time steps - + print(sim.get_results()["dead"]) print(sim.get_results()["ylls"]) :: - stratification value - 0 all 985.0 + stratification value + 0 all 977.0 - stratification value - 0 all 27966.647762 + stratification value + 0 all 27720.319912 -We see that after 365 days of simulation, 985 simlants have died and there has -been a total of 27,987 years of life lost. +We see that after 365 days of simulation, 977 simlants have died and there has +been a total of 27,720 years of life lost. .. testcode:: :hide: @@ -883,19 +972,20 @@ been a total of 27,987 years of life lost. sim = InteractiveContext( components=[ BasePopulation(), - Mortality(), DeathsObserver(), YllsObserver(), ], configuration=config ) - sim.take_steps(2) + + # It takes too long to run 365 steps in the test, so we just run 10 steps here + sim.take_steps(10) dead = sim.get_results()["dead"] assert len(dead) == 1 - assert dead["value"][0] == 6 + assert dead["value"][0] == 27 ylls = sim.get_results()["ylls"] assert len(ylls) == 1 - assert ylls["value"][0] == 333.9956932528944 + assert ylls["value"][0] == 1030.7382838676458 .. 
note:: diff --git a/docs/source/tutorials/exploration.rst b/docs/source/tutorials/exploration.rst index ba17d8963..9b2c1d440 100644 --- a/docs/source/tutorials/exploration.rst +++ b/docs/source/tutorials/exploration.rst @@ -14,8 +14,8 @@ In this tutorial we'll focus on exploring simulations in an interactive setting. The only prerequisite is that you've set up your programming environment (See :ref:`the getting started section `). We'll look -at how to examine the :term:`population table `, how to -print and interpret the simulation :term:`configuration `, +at how to examine the :term:`population state table `, how +to print and interpret the simulation :term:`configuration `, and how to get values from the :term:`value pipeline ` system. We'll work through all this with a few case studies using the simulations @@ -90,12 +90,12 @@ configuration by simply printing it. .. testsetup:: configuration - from vivarium.examples.disease_model import get_disease_model_simulation + from vivarium.examples.disease_model import get_disease_model_simulation - sim = get_disease_model_simulation() + sim = get_disease_model_simulation() - del sim.configuration['input_data'] - del sim.configuration['stratification']['excluded_categories'] + del sim.configuration['input_data'] + del sim.configuration['stratification']['excluded_categories'] .. testcode:: configuration @@ -145,6 +145,9 @@ configuration by simply printing it. model_override: 0.0114 life_expectancy: model_override: 88.9 + data_sources: + mortality_rate: + component_configs: 0.01 lower_respiratory_infections: incidence_rate: model_override: 0.871 @@ -176,14 +179,23 @@ configuration by simply printing it. 
component_configs: [] include: component_configs: [] + ylls: + exclude: + component_configs: [] + include: + component_configs: [] disease_state.susceptible_to_lower_respiratory_infections: data_sources: initialization_weights: component_configs: 1.0 + excess_mortality_rate: + component_configs: 0.0 disease_state.infected_with_lower_respiratory_infections: data_sources: initialization_weights: component_configs: 0.0 + excess_mortality_rate: + component_configs: 0.0 What do we see here? The configuration is *hierarchical*. There are a set of @@ -192,20 +204,20 @@ just those subsets if we like. .. testcode:: - print(sim.configuration.randomness) + print(sim.configuration.randomness) .. testoutput:: - key_columns: - model_override: ['entrance_time', 'age'] - map_size: - component_configs: 1000000 - random_seed: - component_configs: 0 - additional_seed: - component_configs: None - rate_conversion_type: - component_configs: linear + key_columns: + model_override: ['entrance_time', 'age'] + map_size: + component_configs: 1000000 + random_seed: + component_configs: 0 + additional_seed: + component_configs: None + rate_conversion_type: + component_configs: linear This subset of configuration data contains more keys. All of the keys in our example here (key_columns, map_size, random_seed, additional_seed, @@ -214,52 +226,52 @@ as well. .. testcode:: - print(sim.configuration.randomness.key_columns) - print(sim.configuration.randomness.map_size) - print(sim.configuration.randomness.random_seed) - print(sim.configuration.randomness.additional_seed) - print(sim.configuration.randomness.rate_conversion_type) + print(sim.configuration.randomness.key_columns) + print(sim.configuration.randomness.map_size) + print(sim.configuration.randomness.random_seed) + print(sim.configuration.randomness.additional_seed) + print(sim.configuration.randomness.rate_conversion_type) .. 
testoutput:: - ['entrance_time', 'age'] - 1000000 - 0 - None - linear + ['entrance_time', 'age'] + 1000000 + 0 + None + linear However, we can no longer modify the configuration since the simulation has already been setup. .. testcode:: - from layered_config_tree import ConfigurationError + from layered_config_tree import ConfigurationError - try: - sim.configuration.randomness.update({'random_seed': 5}) - except ConfigurationError: - print("Can't update configuration after setup") + try: + sim.configuration.randomness.update({'random_seed': 5}) + except ConfigurationError: + print("Can't update configuration after setup") .. testoutput:: - Can't update configuration after setup + Can't update configuration after setup If we look again at the randomness configuration, it appears that there should be one more layer of keys. .. code-block:: python - key_columns: - model_override: ['entrance_time', 'age'] - map_size: - component_configs: 1000000 - random_seed: - component_configs: 0 - additional_seed: - component_configs: None + key_columns: + model_override: ['entrance_time', 'age'] + map_size: + component_configs: 1000000 + random_seed: + component_configs: 0 + additional_seed: + component_configs: None rate_conversion_type: - component_configs: linear + component_configs: linear This last layer reflects a priority level in the way simulation configuration is managed. The ``component_configs`` under ``map_size``, ``random_seed``, and @@ -283,35 +295,33 @@ your starting population. .. 
code-block:: python - pop = sim.get_population() - print(pop.head()) + pop = sim.get_population() + print(pop.head()) :: - tracked age alive sex entrance_time lower_respiratory_infections child_wasting_propensity - - 0 True 4.341734 alive Male 2021-12-31 12:00:00 susceptible_to_lower_respiratory_infections 0.612086 - 1 True 1.009906 alive Male 2021-12-31 12:00:00 susceptible_to_lower_respiratory_infections 0.395465 - 2 True 1.166290 alive Male 2021-12-31 12:00:00 susceptible_to_lower_respiratory_infections 0.670765 - 3 True 4.075051 alive Female 2021-12-31 12:00:00 susceptible_to_lower_respiratory_infections 0.289266 - 4 True 2.133430 alive Female 2021-12-31 12:00:00 susceptible_to_lower_respiratory_infections 0.700001 + age is_alive entrance_time lower_respiratory_infections child_wasting_propensity + 0 1.707662 True 2021-12-31 12:00:00 susceptible_to_lower_respiratory_infections 0.579157 + 1 2.731665 True 2021-12-31 12:00:00 susceptible_to_lower_respiratory_infections 0.280783 + 2 0.511246 True 2021-12-31 12:00:00 susceptible_to_lower_respiratory_infections 0.332681 + 3 2.898714 True 2021-12-31 12:00:00 susceptible_to_lower_respiratory_infections 0.505482 + 4 1.381896 True 2021-12-31 12:00:00 susceptible_to_lower_respiratory_infections 0.017806 This gives you a ``pandas.DataFrame`` representing your starting population. You can use it to check all sorts of characteristics about individuals or the population as a whole. .. 
testcode:: - :hide: + :hide: - pop = sim.get_population() - pop = pop.reindex(sorted(pop.columns), axis=1) - print(pop.age.describe()) - print(pop.alive.value_counts()) - print(pop.child_wasting_propensity.describe()) - print(pop.lower_respiratory_infections.value_counts()) - print(pop.entrance_time.value_counts()) - print(pop.sex.value_counts()) - print(pop.tracked.value_counts()) + pop = sim.get_population() + pop = pop.reindex(sorted(pop.columns), axis=1) + print(pop.age.describe()) + print(pop.is_alive.value_counts()) + print(pop.child_wasting_propensity.describe()) + print(pop.lower_respiratory_infections.value_counts()) + print(pop.entrance_time.value_counts()) + print(pop.sex.value_counts()) .. testoutput:: @@ -325,8 +335,8 @@ the population as a whole. 75% 3.730555 max 4.999957 Name: age, dtype: float64 - alive - alive 100000 + is_alive + True 100000 Name: count, dtype: int64 count 1.000000e+05 mean 5.007716e-01 @@ -347,10 +357,6 @@ the population as a whole. Female 50011 Male 49989 Name: count, dtype: int64 - tracked - True 100000 - Name: count, dtype: int64 - Understanding the Simulation Data diff --git a/docs/source/tutorials/running_a_simulation/cli.rst b/docs/source/tutorials/running_a_simulation/cli.rst index 9e1afdb0f..9677eb32c 100644 --- a/docs/source/tutorials/running_a_simulation/cli.rst +++ b/docs/source/tutorials/running_a_simulation/cli.rst @@ -29,10 +29,7 @@ to the ``~/vivarium_results`` directory. If you navigate to that directory, you should see a subdirectory with the name of your model specification. Inside the model specification directory, there will be another subdirectory named for the start time of the run. In here, you -should see two hdf files: ``final_state.hdf``, which is the population -:term:`state table ` at the end of the simulation, and -``output.hdf``, which is the results of the :term:`metrics ` generated -by the simulation. +should see a results directory which contains all of the simulation results files. 
For example, say we've run a simulation for a model specification called ``potatoes.yaml`` (maybe we're really into gardening). Our directory tree @@ -41,8 +38,9 @@ will look like:: ~/vivarium_results/ potatoes/ 2019_04_20_15_44_20/ - final_state.hdf - output.hdf + results/ + mashed.parquet + eye_counts.parquet If we decide we don't like our results, or want to rerun the simulation with a different set of :term:`configuration parameters `, we'll add @@ -51,11 +49,13 @@ new time stamped sub-directories to our ``potatoes`` model results directory:: ~/vivarium_results/ potatoes/ 2019_04_20_15_44_20/ - final_state.hdf - output.hdf + results/ + mashed.parquet + eye_counts.parquet 2019_04_20_16_34_12/ - final_state.hdf - output.hdf + results/ + mashed.parquet + eye_counts.parquet ``simulate run`` also provides various flags which you can use to configure options for the run. These are: @@ -119,10 +119,6 @@ the following: DEBUG:vivarium.framework.engine:2005-07-07 00:00:00 DEBUG:vivarium.framework.engine:2005-07-10 00:00:00 DEBUG:vivarium.framework.engine:2005-07-13 00:00:00 - DEBUG:vivarium.framework.engine:{'simulation_run_time': 0.7717499732971191, - 'total_population': 10000, - 'total_population_tracked': 10000, - 'total_population_untracked': 0} DEBUG:vivarium.framework.engine:Some configuration keys not used during run: {'input_data.cache_data', 'output_data.results_directory', 'input_data.intermediary_data_cache_path'} The specifics of these messages will depend on your model specification, but diff --git a/docs/source/tutorials/running_a_simulation/interactive.rst b/docs/source/tutorials/running_a_simulation/interactive.rst index 93633b201..ae47f43d4 100644 --- a/docs/source/tutorials/running_a_simulation/interactive.rst +++ b/docs/source/tutorials/running_a_simulation/interactive.rst @@ -11,7 +11,7 @@ Running a simulation in this way is useful for a variety of reasons, foremost for debugging and validation work. 
It allows for changing simulation :term:`configuration ` programmatically, stepping through a simulation in a controlled fashion, and examining the -:term:`state ` of the simulation itself. +:term:`state ` of the simulation itself. For the following tutorial, we will assume you have set up an environment and @@ -197,7 +197,7 @@ one last way to set up the simulation in an interactive setting. .. testcode:: :hide: - from vivarium.examples.disease_model import (BasePopulation, Mortality, DeathsObserver, + from vivarium.examples.disease_model import (BasePopulation, DeathsObserver, YllsObserver, SISDiseaseModel, Risk, RiskEffect, TreatmentIntervention) from vivarium import InteractiveContext @@ -229,7 +229,6 @@ one last way to set up the simulation in an interactive setting. } components = [BasePopulation(), - Mortality(), SISDiseaseModel('diarrhea'), Risk('child_growth_failure'), RiskEffect('child_growth_failure', 'infected_with_diarrhea.incidence_rate'), diff --git a/pyproject.toml b/pyproject.toml index 5baa3b734..31e78ec11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,13 +30,6 @@ exclude = [ 'build', # Files below here should have their errors fixed and then be removed from this list # You will need to remove the mypy: ignore-errors comment from the file heading as well - 'setup.py', - 'src/vivarium/examples/disease_model/disease.py', - 'src/vivarium/examples/disease_model/intervention.py', - 'src/vivarium/examples/disease_model/mortality.py', - 'src/vivarium/examples/disease_model/observer.py', - 'src/vivarium/examples/disease_model/population.py', - 'src/vivarium/examples/disease_model/risk.py', ] disable_error_code = [] diff --git a/python_versions.json b/python_versions.json index bb92d342c..350c415d8 100644 --- a/python_versions.json +++ b/python_versions.json @@ -1 +1 @@ -["3.10", "3.11"] \ No newline at end of file +["3.10", "3.11", "3.12", "3.13"] \ No newline at end of file diff --git a/setup.py b/setup.py index a1a0e7c88..50df42402 100644 --- 
a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ base_dir = Path(__file__).parent src_dir = base_dir / "src" - about = {} + about: dict[str, str] = {} with (src_dir / "vivarium" / "__about__.py").open() as f: exec(f.read(), about) diff --git a/src/vivarium/component.py b/src/vivarium/component.py index 5639fc7cf..37333bd56 100644 --- a/src/vivarium/component.py +++ b/src/vivarium/component.py @@ -11,23 +11,18 @@ from __future__ import annotations import re -import warnings from abc import ABC -from collections.abc import Callable, Sequence -from datetime import datetime, timedelta +from collections.abc import Sequence from importlib import import_module from inspect import signature -from typing import TYPE_CHECKING, Any, Literal -from typing import SupportsFloat as Numeric -from typing import cast +from typing import TYPE_CHECKING, Any, overload import pandas as pd from layered_config_tree import ConfigurationError, LayeredConfigTree from vivarium.framework.artifact import ArtifactException -from vivarium.framework.lifecycle import lifecycle_states -from vivarium.framework.population import PopulationError -from vivarium.types import ScalarValue +from vivarium.framework.lifecycle import LifeCycleError, lifecycle_states +from vivarium.types import LookupTableData if TYPE_CHECKING: import loguru @@ -35,8 +30,7 @@ from vivarium.framework.engine import Builder from vivarium.framework.event import Event from vivarium.framework.lookup import LookupTable - from vivarium.framework.population import PopulationView, SimulantData - from vivarium.framework.resource import Resource + from vivarium.framework.population import PopulationView from vivarium.types import DataInput DEFAULT_EVENT_PRIORITY = 5 @@ -66,10 +60,6 @@ class Component(ABC): - :attr:`sub_components` - :attr:`configuration_defaults` - - :attr:`columns_created` - - :attr:`columns_required` - - :attr:`initialization_requirements` - - :attr:`population_view_query` - :attr:`post_setup_priority` - 
:attr:`time_step_prepare_priority` - :attr:`time_step_priority` @@ -82,7 +72,6 @@ class Component(ABC): - :meth:`setup` - :meth:`on_post_setup` - - :meth:`on_initialize_simulants` - :meth:`on_time_step_prepare` - :meth:`on_time_step` - :meth:`on_time_step_cleanup` @@ -108,28 +97,9 @@ def __init__(self) -> None: self._repr: str = "" self._name: str = "" self._sub_components: Sequence["Component"] = [] - self.logger: loguru.Logger | None = None - """A :class:`loguru.Logger` instance for this component. - - The logger is initialized during :meth:`setup_component` and can be used - to log messages specific to this component. The logger name is set to the - component's :attr:`name`. - """ - self.get_value_columns: ( - Callable[ - [str | pd.DataFrame | dict[str, list[ScalarValue] | list[str]]], list[str] - ] - | None - ) = None - self.configuration: LayeredConfigTree | None = None + self._logger: loguru.Logger | None = None + self.configuration: LayeredConfigTree = LayeredConfigTree() self._population_view: PopulationView | None = None - self.lookup_tables: dict[str, LookupTable] = {} - """A dictionary of lookup tables built for this component, keyed by table name. - - Lookup tables are built automatically from the ``data_sources`` block in - :attr:`configuration_defaults` before the component's :meth:`setup` method - is called. Tables can be accessed by name, e.g., ``self.lookup_tables["my_table"]``. - """ def __repr__(self) -> str: """Returns a string representation of the :meth:`__init__` call made to create @@ -146,8 +116,7 @@ def __repr__(self) -> str: Returns ------- - A string representation of the __init__ call made to create this - object. + A string representation of the __init__ call made to create this object. """ if not self._repr: args = ", ".join( @@ -169,7 +138,7 @@ def __str__(self) -> str: @property def name(self) -> str: - """Returns the name of the component. + """The name of the component. 
By convention, these are in snake case with arguments of the :meth:`__init__` appended and separated by ``.``. @@ -186,11 +155,6 @@ def name(self) -> str: IMPORTANT: this property must not be accessed within the :meth:`__init__` functions of this component or its subclasses or its value may not be initialized correctly. - - Returns - ------- - str - The name of the component. """ if not self._name: base_name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", type(self).__name__) @@ -205,13 +169,24 @@ def name(self) -> str: return self._name @property - def population_view(self) -> PopulationView: - """Provides the :class:`~vivarium.framework.population.PopulationView` for this component. + def logger(self) -> loguru.Logger: + """The logger for this component. - Returns - ------- - PopulationView - The PopulationView for this component + Raises + ------ + LifeCycleError + If the logger has not been initialized. + """ + if self._logger is None: + raise LifeCycleError( + f"Logger for component '{self.name}' has not been initialized. " + "This is likely due to having called this prior to simulation setup." + ) + return self._logger + + @property + def population_view(self) -> PopulationView: + """The :class:`~vivarium.framework.population.PopulationView` for this component. Raises ------ @@ -219,152 +194,74 @@ def population_view(self) -> PopulationView: If the component does not have access to the state table. """ if self._population_view is None: + from vivarium.framework.population.exceptions import PopulationError + raise PopulationError( - f"Component '{self.name}' does not have access to the state " - "table. This is likely due to a failure to set columns_required " - "or columns_created for this component." + f"Component '{self.name}' does not have access to the state table. " + "This is likely due to having called this prior to simulation setup." 
) return self._population_view @property - def sub_components(self) -> Sequence["Component"]: - """Provide components managed by this component. + def private_columns(self) -> list[str]: + """The list of private columns created by this component.""" + return self.population_view.private_columns - Returns - ------- - List[Component] - The sub-components that are managed by this component. - """ + @property + def sub_components(self) -> Sequence["Component"]: + """The components managed by this component.""" return self._sub_components @property def configuration_defaults(self) -> dict[str, Any]: - """Provides a dictionary containing the defaults for any configurations - managed by this component. + """The dictionary containing the defaults for any configurations managed + by this component. These default values will be stored at the ``component_configs`` layer of the simulation's :class:`~layered_config_tree.main.LayeredConfigTree`. - - Returns - ------- - A dictionary containing the defaults for any configurations managed by - this component. """ return self.CONFIGURATION_DEFAULTS @property - def columns_created(self) -> list[str]: - """Provides names of columns created by the component. - - Returns - ------- - Names of the columns created by this component, or an empty list if - none. - """ - return [] - - @property - def columns_required(self) -> list[str] | Literal["all"] | None: - """Provides names of columns required by the component. - - Returns - ------- - Names of required columns not created by this component. A string of - ``"all"`` means all available columns are needed. :obj:`None` means no - additional columns are necessary. 
- """ - return None - - @property - def initialization_requirements( - self, - ) -> list[str | Resource]: - """A list containing the columns, pipelines, and randomness streams - required by this component's simulant initializer.""" - return [] - - @property - def population_view_query(self) -> str: - """Provides a query to use when filtering the component's :class:`~vivarium.framework.population.PopulationView`. - - Returns - ------- - A pandas query string for filtering the component's :class:`~vivarium.framework.population.PopulationView`. - Returns an empty string if no filtering is required. - """ - return "" + def lookup_table_value_columns(self) -> dict[str, str | list[str]]: + """A mapping of lookup table names to their value columns.""" + return {} @property def post_setup_priority(self) -> int: - """Provides the priority of this component's ``post_setup`` listener. - - Returns - ------- - The priority of this component's ``post_setup`` listener. This value - can range from 0 to 9, inclusive. - """ + """The priority of this component's ``post_setup`` listener.""" return DEFAULT_EVENT_PRIORITY @property def time_step_prepare_priority(self) -> int: - """Provides the priority of this component's ``time_step__prepare`` listener. - - Returns - ------- - The priority of this component's ``time_step__prepare`` listener. This value - can range from 0 to 9, inclusive. - """ + """The priority of this component's ``time_step__prepare`` listener.""" return DEFAULT_EVENT_PRIORITY @property def time_step_priority(self) -> int: - """Provides the priority of this component's ``time_step`` listener. - - Returns - ------- - The priority of this component's ``time_step`` listener. This value - can range from 0 to 9, inclusive. - """ + """The priority of this component's ``time_step`` listener.""" return DEFAULT_EVENT_PRIORITY @property def time_step_cleanup_priority(self) -> int: - """Provides the priority of this component's ``time_step__cleanup`` listener. 
- - Returns - ------- - The priority of this component's ``time_step__cleanup`` listener. This value - can range from 0 to 9, inclusive. - """ + """The priority of this component's ``time_step__cleanup`` listener.""" return DEFAULT_EVENT_PRIORITY @property def collect_metrics_priority(self) -> int: - """Provides the priority of this component's ``collect_metrics`` listener. - - Returns - ------- - The priority of this component's ``collect_metrics`` listener. This value - can range from 0 to 9, inclusive. - """ + """The priority of this component's ``collect_metrics`` listener.""" return DEFAULT_EVENT_PRIORITY @property def simulation_end_priority(self) -> int: - """Provides the priority of this component's ``simulation_end`` listener. - - Returns - ------- - The priority of this component's ``simulation_end`` listener. This value - can range from 0 to 9, inclusive. - """ + """The priority of this component's ``simulation_end`` listener.""" return DEFAULT_EVENT_PRIORITY ##################### # Lifecycle methods # ##################### - def setup_component(self, builder: "Builder") -> None: + def setup_component(self, builder: Builder) -> None: """Sets up the component for a Vivarium simulation. This method is run by Vivarium during the setup phase. It performs a series @@ -380,14 +277,11 @@ def setup_component(self, builder: "Builder") -> None: builder The builder object used to set up the component. 
""" - self.logger = builder.logging.get_logger(self.name) - self.get_value_columns = builder.data.value_columns() + self._logger = builder.logging.get_logger(self.name) self.configuration = self.get_configuration(builder) - self.build_all_lookup_tables(builder) self.setup(builder) self._set_population_view(builder) self._register_post_setup_listener(builder) - self._register_simulant_initializer(builder) self._register_time_step_prepare_listener(builder) self._register_time_step_listener(builder) self._register_time_step_cleanup_listener(builder) @@ -430,25 +324,6 @@ def on_post_setup(self, event: Event) -> None: """ pass - def on_initialize_simulants(self, pop_data: SimulantData) -> None: - """ - Method that vivarium will run during simulant initialization. - - This method is intended to be overridden by subclasses if there are - operations they need to perform specifically during the simulant - initialization phase. - - Parameters - ---------- - pop_data : SimulantData - The data associated with the simulants being initialized. - - Returns - ------- - None - """ - pass - def on_time_step_prepare(self, event: Event) -> None: """Method that vivarium will run during the ``time_step__prepare`` event. @@ -542,7 +417,7 @@ def get_initialization_parameters(self) -> dict[str, Any]: if hasattr(self, parameter_name) } - def get_configuration(self, builder: "Builder") -> LayeredConfigTree | None: + def get_configuration(self, builder: Builder) -> LayeredConfigTree: """Retrieves the configuration for this component from the builder. This method retrieves the configuration for this component from the @@ -556,61 +431,61 @@ def get_configuration(self, builder: "Builder") -> LayeredConfigTree | None: Returns ------- - The configuration for this component, or :obj:`None` if the component has - no configuration. + The configuration for this component, or a default empty configuration. 
""" if self.name in builder.configuration: return builder.configuration.get_tree(self.name) - return None - - def build_all_lookup_tables(self, builder: "Builder") -> None: - """Builds all lookup tables for this component. + return LayeredConfigTree({"data_sources": {}}) - This method builds lookup tables for this component based on the data - sources specified in the configuration. If no data sources are specified, - no lookup tables are built. - - The created lookup tables are stored in the :attr:`lookup_tables` attribute of - the component, with the table name as the key. + @overload + def build_lookup_table( + self, + builder: Builder, + name: str, + data_source: DataInput | None = None, + value_columns: str | None = None, + ) -> LookupTable[pd.Series[Any]]: + ... - Parameters - ---------- - builder - The builder object used to set up the component. - """ - if self.configuration and "data_sources" in self.configuration: - for table_name in self.configuration.data_sources.keys(): - try: - self.lookup_tables[table_name] = self.build_lookup_table( - builder, self.configuration.data_sources[table_name] - ) - except ConfigurationError as e: - raise ConfigurationError( - f"Error building lookup table '{table_name}': {e}" - ) + @overload + def build_lookup_table( + self, + builder: Builder, + name: str, + data_source: DataInput | None = None, + value_columns: list[str] | tuple[str, ...] = ..., + ) -> LookupTable[pd.DataFrame]: + ... def build_lookup_table( self, builder: Builder, - data_source: DataInput, - value_columns: Sequence[str] | None = None, - ) -> LookupTable: - """Builds a :class:`~vivarium.framework.lookup.table.LookupTable` from a data source. - - Uses :meth:`get_data` to parse the data source and retrieve the lookup table - data. The :class:`~vivarium.framework.lookup.table.LookupTable` is built from the - data source, with the value columns specified in the ``value_columns`` parameter. 
- If ``value_columns`` is :obj:`None` and the data is a :class:`~pandas.DataFrame`, - the :class:`~vivarium.framework.artifact.manager.ArtifactManager` will determine the - value columns. + name: str, + data_source: DataInput | None = None, + value_columns: list[str] | tuple[str, ...] | str | None = None, + ) -> LookupTable[pd.Series[Any]] | LookupTable[pd.DataFrame]: + """Builds a LookupTable. + + If a data_source is not provided, the method will look for a data source + in the component's configuration under the key "data_sources" with the + provided name. + + If value_columns provided is a list or tuple, a LookupTable returning a + DataFrame will be built. If it is a string or None, a LookupTable + returning a Series will be built. If value_columns is None, the name of the + returned Series will be "value". Parameters ---------- builder The builder object used to set up the component. data_source - The data source to build the LookupTable from. + The data source to build the LookupTable from. If None, the data source + will be retrieved from the component's configuration. + name + The name of the lookup table, used to retrieve the data source from + the configuration if data_source is None. value_columns The columns to include in the LookupTable. @@ -623,77 +498,24 @@ def build_lookup_table( layered_config_tree.exceptions.ConfigurationError If the data source is invalid. 
""" - data = self.get_data(builder, data_source) - # TODO update this to use vivarium.types.LookupTableData once we drop - # support for Python 3.9 - if not isinstance( - data, (Numeric, timedelta, datetime, pd.DataFrame, list, tuple, dict) - ): - raise ConfigurationError(f"Data '{data}' must be a LookupTableData instance.") - - if isinstance(data, list): - return builder.lookup.build_table( - data, value_columns=list(value_columns) if value_columns else () - ) - if isinstance(data, pd.DataFrame): - duplicated_columns = set(data.columns[data.columns.duplicated()]) - if duplicated_columns: - raise ConfigurationError( - f"Dataframe contains duplicate columns {duplicated_columns}." - ) - value_columns, parameter_columns, key_columns = self._get_columns( - value_columns, data + if data_source is None: + data_source = self.configuration.get(["data_sources", name]) + + if data_source is None: + raise ConfigurationError( + f"No data source provided for lookup table '{name}', " + "and no data source found in configuration." 
) + try: + data = self.get_data(builder, data_source) return builder.lookup.build_table( - data=data, - key_columns=key_columns, - parameter_columns=parameter_columns, - value_columns=value_columns, + data=data, name=name, value_columns=value_columns ) + except ConfigurationError as e: + raise ConfigurationError(f"Error building lookup table '{name}': {e}") - return builder.lookup.build_table(data) - - def _get_columns( - self, - value_columns: Sequence[str] | None, - data: pd.DataFrame | dict[str, list[ScalarValue] | list[str]], - ) -> tuple[Sequence[str], list[str], list[str]]: - if isinstance(data, pd.DataFrame): - all_columns = list(data.columns) - else: - all_columns = list(data.keys()) - if value_columns is None: - # NOTE: self.get_value_columns cannot be None at this point of the call stack - value_column_getter = cast( - Callable[ - [str | pd.DataFrame | dict[str, list[ScalarValue] | list[str]]], list[str] - ], - self.get_value_columns, - ) - value_columns = value_column_getter(data) - - potential_parameter_columns = [ - str(col).removesuffix("_start") - for col in all_columns - if str(col).endswith("_start") - ] - parameter_columns = [] - bin_edge_columns = [] - for column in potential_parameter_columns: - if f"{column}_end" in all_columns: - parameter_columns.append(column) - bin_edge_columns += [f"{column}_start", f"{column}_end"] - - key_columns = [ - col - for col in all_columns - if col not in value_columns and col not in bin_edge_columns - ] - - return value_columns, parameter_columns, key_columns - - def get_data(self, builder: Builder, data_source: DataInput) -> Any: + def get_data(self, builder: Builder, data_source: DataInput) -> LookupTableData: """Retrieves data from a data source. If the data source is a float or a DataFrame, it is treated as the data @@ -737,7 +559,7 @@ def get_data(self, builder: Builder, data_source: DataInput) -> Any: raise ConfigurationError( f"There is no method '{method}' for the {module_string}." 
) - data = data_source_callable(builder) + data: LookupTableData = data_source_callable(builder) else: try: data = builder.data.load(data_source) @@ -752,50 +574,17 @@ def get_data(self, builder: Builder, data_source: DataInput) -> Any: return data - def _set_population_view(self, builder: "Builder") -> None: - """Creates the PopulationView for this component if it needs access to - the state table. - - The method determines the necessary columns for the PopulationView - based on the columns required and created by this component. If no - columns are required or created, no PopulationView is set. + def _set_population_view(self, builder: Builder) -> None: + """Creates the PopulationView for this component. Parameters ---------- builder The builder object used to set up the component. """ - requires_all_columns = False - if self.columns_required == []: - warnings.warn( - "The empty list [] format for requiring all columns is deprecated. Please " - "use the string 'all' instead.", - DeprecationWarning, - stacklevel=2, - ) - if self.columns_required and self.columns_required != "all": - # Get all columns created and required - population_view_columns = self.columns_created + self.columns_required - elif self.columns_required == "all" or self.columns_required == []: - # Empty list means population view needs all available columns - requires_all_columns = True - if self.columns_created: - population_view_columns = self.columns_created - else: - population_view_columns = [] - elif self.columns_required is None and self.columns_created: - # No additional columns required, so just get columns created - population_view_columns = self.columns_created - else: - # no need for a population view if no columns created or required - population_view_columns = None - - if population_view_columns is not None: - self._population_view = builder.population.get_view( - population_view_columns, self.population_view_query, requires_all_columns - ) + self._population_view = 
builder.population.get_view(self) - def _register_post_setup_listener(self, builder: "Builder") -> None: + def _register_post_setup_listener(self, builder: Builder) -> None: """Registers a ``post_setup`` listener if this component has defined one. This method allows the component to respond to ``post_setup`` events if it @@ -815,39 +604,7 @@ def _register_post_setup_listener(self, builder: "Builder") -> None: self.post_setup_priority, ) - def _register_simulant_initializer(self, builder: Builder) -> None: - """Registers a simulant initializer if this component has defined one. - - This method allows the component to initialize simulants if it has its - own :meth:`on_initialize_simulants` method. It registers this method with the - builder's :class:`~vivarium.framework.population.PopulationManager`. It also - specifies the columns that the component creates and any additional - requirements for initialization. - - Parameters - ---------- - builder - The builder with which to register the initializer. - """ - if isinstance(self.initialization_requirements, list): - initialization_requirements = { - "required_resources": self.initialization_requirements - } - else: - initialization_requirements = self.initialization_requirements - warnings.warn( - "The dict format for initialization_requirements is deprecated." - " You should use provide a list of the required resources.", - DeprecationWarning, - stacklevel=2, - ) - - if type(self).on_initialize_simulants != Component.on_initialize_simulants: - builder.population.initializes_simulants( - self, creates_columns=self.columns_created, **initialization_requirements # type: ignore[arg-type] - ) - - def _register_time_step_prepare_listener(self, builder: "Builder") -> None: + def _register_time_step_prepare_listener(self, builder: Builder) -> None: """Registers a ``time_step__prepare`` listener if this component has defined one. 
This method allows the component to respond to ``time_step__prepare`` events @@ -866,7 +623,7 @@ def _register_time_step_prepare_listener(self, builder: "Builder") -> None: self.time_step_prepare_priority, ) - def _register_time_step_listener(self, builder: "Builder") -> None: + def _register_time_step_listener(self, builder: Builder) -> None: """Registers a ``time_step`` listener if this component has defined one. This method allows the component to respond to ``time_step`` events @@ -885,7 +642,7 @@ def _register_time_step_listener(self, builder: "Builder") -> None: self.time_step_priority, ) - def _register_time_step_cleanup_listener(self, builder: "Builder") -> None: + def _register_time_step_cleanup_listener(self, builder: Builder) -> None: """Registers a ``time_step__cleanup`` listener if this component has defined one. This method allows the component to respond to ``time_step__cleanup`` events @@ -904,7 +661,7 @@ def _register_time_step_cleanup_listener(self, builder: "Builder") -> None: self.time_step_cleanup_priority, ) - def _register_collect_metrics_listener(self, builder: "Builder") -> None: + def _register_collect_metrics_listener(self, builder: Builder) -> None: """Registers a ``collect_metrics`` listener if this component has defined one. This method allows the component to respond to ``collect_metrics`` events @@ -923,7 +680,7 @@ def _register_collect_metrics_listener(self, builder: "Builder") -> None: self.collect_metrics_priority, ) - def _register_simulation_end_listener(self, builder: "Builder") -> None: + def _register_simulation_end_listener(self, builder: Builder) -> None: """Registers a ``simulation_end`` listener if this component has defined one. 
This method allows the component to respond to ``simulation_end`` events diff --git a/src/vivarium/examples/boids/forces.py b/src/vivarium/examples/boids/forces.py index e89674a41..da46525be 100644 --- a/src/vivarium/examples/boids/forces.py +++ b/src/vivarium/examples/boids/forces.py @@ -1,7 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any import numpy as np import pandas as pd @@ -10,20 +9,19 @@ from vivarium.framework.engine import Builder +# docs-start: force_base_class class Force(Component, ABC): ############## # Properties # ############## @property - def configuration_defaults(self) -> dict[str, Any]: + def configuration_defaults(self) -> dict[str, dict[str, float]]: return { self.__class__.__name__.lower(): { "max_force": 0.03, }, } - columns_required = [] - ##################### # Lifecycle methods # ##################### @@ -32,21 +30,23 @@ def setup(self, builder: Builder) -> None: self.config = builder.configuration[self.__class__.__name__.lower()] self.max_speed = builder.configuration.movement.max_speed - self.neighbors = builder.value.get_value("neighbors") - - builder.value.register_value_modifier( + # docs-start: register_acceleration_modifier + builder.value.register_attribute_modifier( "acceleration", modifier=self.apply_force, - required_resources=self.columns_required + [self.neighbors], + required_resources=["x", "y", "vx", "vy", "neighbors"], ) + # docs-end: register_acceleration_modifier ################################## # Pipeline sources and modifiers # ################################## def apply_force(self, index: pd.Index[int], acceleration: pd.DataFrame) -> pd.DataFrame: - neighbors = self.neighbors(index) - pop = self.population_view.get(index) + neighbors = self.population_view.get_attributes(index, "neighbors") + pop = self.population_view.get_attributes(index, ["x", "y", "vx", "vy"]) + if not (isinstance(neighbors, pd.Series) and isinstance(pop, pd.DataFrame)): + raise 
ValueError("Neighbors must be a pd.Series of ints and population a pd.DataFrame") pairs = self._get_pairs(neighbors, pop) raw_force = self.calculate_force(pairs) @@ -57,7 +57,8 @@ def apply_force(self, index: pd.Index[int], acceleration: pd.DataFrame) -> pd.Da max_speed=self.max_speed, ) - acceleration.loc[force.index, ["x", "y"]] += force[["x", "y"]] + acceleration.loc[force.index, "acc_x"] += force["x"] + acceleration.loc[force.index, "acc_y"] += force["y"] return acceleration ################## @@ -68,7 +69,7 @@ def apply_force(self, index: pd.Index[int], acceleration: pd.DataFrame) -> pd.Da def calculate_force(self, neighbors: pd.DataFrame) -> pd.DataFrame: pass - def _get_pairs(self, neighbors: pd.Series[int], pop: pd.DataFrame) -> pd.DataFrame: + def _get_pairs(self, neighbors: pd.Series[int | float], pop: pd.DataFrame) -> pd.DataFrame: pairs = ( pop.join(neighbors.rename("neighbors")) .reset_index() @@ -114,8 +115,10 @@ def _normalize_and_limit_force( def _magnitude(self, df: pd.DataFrame) -> pd.Series[float]: return pd.Series(np.sqrt(np.square(df.x) + np.square(df.y)), dtype=float) +# docs-end: force_base_class +# docs-start: concrete_force_classes class Separation(Force): """Push boids apart when they get too close.""" @@ -170,3 +173,4 @@ def calculate_force(self, pairs: pd.DataFrame) -> pd.DataFrame: .sum() .rename(columns=lambda c: c.replace("v", "").replace("_other", "")) ) +# docs-end: concrete_force_classes diff --git a/src/vivarium/examples/boids/movement.py b/src/vivarium/examples/boids/movement.py index 1245dea08..f334b5af7 100644 --- a/src/vivarium/examples/boids/movement.py +++ b/src/vivarium/examples/boids/movement.py @@ -6,7 +6,6 @@ from vivarium import Component from vivarium.framework.engine import Builder from vivarium.framework.population import SimulantData -from vivarium.framework.resource.resource import Resource class Movement(Component): @@ -23,15 +22,6 @@ class Movement(Component): }, } - @property - def columns_created(self) -> 
list[str]: - return ["x", "y", "vx", "vy"] - - @property - def initialization_requirements(self) -> list[str | Resource]: - return [self.randomness] - - ##################### # Lifecycle methods # ##################### @@ -39,23 +29,32 @@ def initialization_requirements(self) -> list[str | Resource]: def setup(self, builder: Builder) -> None: self.config = builder.configuration - self.acceleration = builder.value.register_value_producer( + # docs-start: register_attribute_producer + builder.value.register_attribute_producer( "acceleration", source=self.base_acceleration ) + # docs-end: register_attribute_producer self.randomness = builder.randomness.get_stream(self.name) + builder.population.register_initializer( + initializer=self.initialize_movement, + columns=["x", "y", "vx", "vy"], + required_resources=[self.randomness] + ) ################################## # Pipeline sources and modifiers # ################################## + # docs-start: base_acceleration def base_acceleration(self, index: pd.Index[int]) -> pd.DataFrame: - return pd.DataFrame(0.0, columns=["x", "y"], index=index) + return pd.DataFrame(0.0, columns=["acc_x", "acc_y"], index=index) + # docs-end: base_acceleration ######################## # Event-driven methods # ######################## - def on_initialize_simulants(self, pop_data: SimulantData) -> None: + def initialize_movement(self, pop_data: SimulantData) -> None: # Start randomly distributed, with random velocities new_population = pd.DataFrame( { @@ -68,13 +67,15 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None: ) self.population_view.update(new_population) + # docs-start: on_time_step def on_time_step(self, event: Event) -> None: - pop = self.population_view.get(event.index) - - acceleration = self.acceleration(event.index) + pop = self.population_view.get_private_columns(event.index) + acceleration = self.population_view.get_attribute_frame(event.index, "acceleration") # Accelerate and limit velocity - pop[["vx", 
"vy"]] += acceleration.rename(columns=lambda c: f"v{c}") + if not isinstance(acceleration, pd.DataFrame): + raise ValueError("Acceleration must be a pd.DataFrame") + pop[["vx", "vy"]] += acceleration.rename(columns=lambda c: c.replace("acc_", "v")) speed = np.sqrt(np.square(pop.vx) + np.square(pop.vy)) velocity_scaling_factor = np.where( speed > self.config.movement.max_speed, @@ -93,3 +94,4 @@ def on_time_step(self, event: Event) -> None: pop["y"] = pop.y % self.config.field.height self.population_view.update(pop) + # docs-end: on_time_step diff --git a/src/vivarium/examples/boids/neighbors.py b/src/vivarium/examples/boids/neighbors.py index 3cac08ec6..28acf63c8 100644 --- a/src/vivarium/examples/boids/neighbors.py +++ b/src/vivarium/examples/boids/neighbors.py @@ -14,8 +14,6 @@ class Neighbors(Component): ############## configuration_defaults = {"neighbors": {"radius": 60}} - columns_required = ["x", "y"] - ##################### # Lifecycle methods # ##################### @@ -25,15 +23,20 @@ def setup(self, builder: Builder) -> None: self.neighbors_calculated = False self._neighbors = pd.Series() - self.neighbors = builder.value.register_value_producer( - "neighbors", source=self.get_neighbors, required_resources=self.columns_required + builder.value.register_attribute_producer( + "neighbors", source=self.get_neighbors, required_resources=["x", "y"] + ) + builder.population.register_initializer( + initializer=self.initialize_neighbors, + columns=None, + required_resources=[], ) ######################## # Event-driven methods # ######################## - def on_initialize_simulants(self, pop_data: SimulantData) -> None: + def initialize_neighbors(self, pop_data: SimulantData) -> None: self._neighbors = pd.Series([[]] * len(pop_data.index), index=pop_data.index) def on_time_step(self, event: Event) -> None: @@ -54,7 +57,7 @@ def get_neighbors(self, index: pd.Index[int]) -> pd.Series[list[int]]: # type: def _calculate_neighbors(self) -> None: # Reset our list of 
neighbors - pop = self.population_view.get(self._neighbors.index) + pop = self.population_view.get_attributes(self._neighbors.index, ["x", "y"]) self._neighbors = pd.Series([[] for _ in range(len(pop))], index=pop.index) tree = spatial.KDTree(pop[["x", "y"]]) diff --git a/src/vivarium/examples/boids/population.py b/src/vivarium/examples/boids/population.py index dcf4f3711..0f860bf00 100644 --- a/src/vivarium/examples/boids/population.py +++ b/src/vivarium/examples/boids/population.py @@ -1,43 +1,45 @@ -import numpy as np +# docs-start: imports import pandas as pd from vivarium import Component from vivarium.framework.engine import Builder from vivarium.framework.population import SimulantData -from vivarium.framework.resource.resource import Resource +# docs-end: imports class Population(Component): ############## # Properties # ############## + + # docs-start: configuration_defaults CONFIGURATION_DEFAULTS = { "population": { "colors": ["red", "blue"], } } - - @property - def columns_created(self) -> list[str]: - return ["color", "entrance_time"] - - @property - def initialization_requirements(self) -> list[str | Resource]: - return [self.randomness] + # docs-end: configuration_defaults ##################### # Lifecycle methods # ##################### + # docs-start: setup def setup(self, builder: Builder) -> None: self.colors = builder.configuration.population.colors self.randomness = builder.randomness.get_stream(self.name) + builder.population.register_initializer( + initializer=self.initialize_population, + columns=["color", "entrance_time"], + required_resources=[self.randomness] + ) + # docs-end: setup ######################## # Event-driven methods # ######################## - def on_initialize_simulants(self, pop_data: SimulantData) -> None: + def initialize_population(self, pop_data: SimulantData) -> None: new_population = pd.DataFrame( { "color": self.randomness.choice(pop_data.index, self.colors), diff --git a/src/vivarium/examples/boids/visualization.py 
b/src/vivarium/examples/boids/visualization.py index 2a41a5a4f..5be55e869 100644 --- a/src/vivarium/examples/boids/visualization.py +++ b/src/vivarium/examples/boids/visualization.py @@ -4,6 +4,7 @@ from vivarium import InteractiveContext +# docs-start: plot_boids def plot_boids(simulation: InteractiveContext, plot_velocity: bool=False) -> None: width = simulation.configuration.field.width height = simulation.configuration.field.height @@ -17,8 +18,9 @@ def plot_boids(simulation: InteractiveContext, plot_velocity: bool=False) -> Non plt.ylabel("y") plt.axis((0, width, 0, height)) plt.show() +# docs-end: plot_boids - +# docs-start: plot_boids_animated def plot_boids_animated(simulation: InteractiveContext) -> FuncAnimation: width = simulation.configuration.field.width height = simulation.configuration.field.height @@ -35,9 +37,10 @@ def plot_boids_animated(simulation: InteractiveContext) -> FuncAnimation: frame_pops = [] for _ in frames: simulation.step() - frame_pops.append(simulation.get_population()[["x", "y"]]) + frame_pops.append(simulation.get_population(["x", "y"])) def animate(i: int) -> None: s.set_offsets(frame_pops[i]) return FuncAnimation(fig, animate, frames=frames, interval=10) # type: ignore[arg-type] +# docs-end: plot_boids_animated diff --git a/src/vivarium/examples/disease_model/disease.py b/src/vivarium/examples/disease_model/disease.py index 619c67fe6..f66c5ce26 100644 --- a/src/vivarium/examples/disease_model/disease.py +++ b/src/vivarium/examples/disease_model/disease.py @@ -1,12 +1,15 @@ -# mypy: ignore-errors +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any import pandas as pd from vivarium import Component from vivarium.framework.engine import Builder -from vivarium.framework.state_machine import Machine, State, Transition +from vivarium.framework.state_machine import Machine, State, Transition, Trigger from vivarium.framework.utilities import rate_to_probability from 
vivarium.framework.values import list_combiner, union_post_processor - +from collections.abc import Iterable class DiseaseTransition(Transition): ##################### @@ -19,75 +22,75 @@ def __init__( cause_key: str, measure: str, rate_name: str, - **kwargs, + triggered: Trigger = Trigger.NOT_TRIGGERED, ): super().__init__( - input_state, output_state, probability_func=self._probability, **kwargs + input_state, output_state, probability_func=self._probability, triggered=triggered ) self.cause_key = cause_key self.measure = measure self.rate_name = rate_name + self.joint_paf_pipeline = f"{self.rate_name}.population_attributable_fraction" + self.base_rate: float | int # noinspection PyAttributeOutsideInit def setup(self, builder: Builder) -> None: super().setup(builder) - rate = builder.configuration[self.cause_key][self.measure] - self.base_rate = lambda index: pd.Series(rate, index=index) - self.joint_population_attributable_fraction = builder.value.register_value_producer( - f"{self.rate_name}.population_attributable_fraction", + self.base_rate = builder.configuration[self.cause_key][self.measure] + builder.value.register_attribute_producer( + self.joint_paf_pipeline, source=lambda index: [pd.Series(0.0, index=index)], preferred_combiner=list_combiner, preferred_post_processor=union_post_processor, ) - self.transition_rate = builder.value.register_rate_producer( + builder.value.register_rate_producer( self.rate_name, source=self._risk_deleted_rate, - required_resources=[self.joint_population_attributable_fraction], + required_resources=[self.joint_paf_pipeline], ) ################################## # Pipeline sources and modifiers # ################################## - def _risk_deleted_rate(self, index: pd.Index) -> pd.Series: - return self.base_rate(index) * ( - 1 - self.joint_population_attributable_fraction(index) - ) + def _risk_deleted_rate(self, index: pd.Index[int]) -> pd.Series[float]: + joint_paf = self.population_view.get_attributes(index, 
self.joint_paf_pipeline) + return self.base_rate * (1 - joint_paf) ################## # Helper methods # ################## - def _probability(self, index: pd.Index) -> pd.Series: - effective_rate = self.transition_rate(index) + def _probability(self, index: pd.Index[int]) -> pd.Series[float]: + effective_rate = self.population_view.get_attributes(index, self.rate_name) return pd.Series(rate_to_probability(effective_rate)) class DiseaseState(State): + ############## # Properties # ############## @property - def columns_required(self) -> list[str] | None: - return [self.model, "alive"] - - @property - def population_view_query(self) -> str | None: - return f"alive == 'alive' and {self.model} == '{self.state_id}'" + def configuration_defaults(self) -> dict[str, Any]: + configuration_defaults = super().configuration_defaults + configuration_defaults[self.name]["data_sources"]["excess_mortality_rate"] = 0.0 + return configuration_defaults ##################### # Lifecycle methods # ##################### - def __init__(self, state_id: str, cause_key: str, with_excess_mortality: bool = False): + def __init__(self, state_id: str, cause_key: str): super().__init__(state_id) self._cause_key = cause_key - self._with_excess_mortality = with_excess_mortality + self.emr_paf_pipeline = f"{self.state_id}.excess_mortality_rate.population_attributable_fraction" + self.emr_pipeline = f"{self.state_id}.excess_mortality_rate" # noinspection PyAttributeOutsideInit - def setup(self, builder: Builder): + def setup(self, builder: Builder) -> None: """Performs this component's simulation setup. Parameters @@ -96,31 +99,26 @@ def setup(self, builder: Builder): Interface to several simulation tools. 
""" super().setup(builder) - if self._with_excess_mortality: - self._excess_mortality_rate = builder.configuration[ - self._cause_key - ].excess_mortality_rate - else: - self._excess_mortality_rate = 0 self.clock = builder.time.clock() - self.excess_mortality_rate_paf = builder.value.register_value_producer( - f"{self.state_id}.excess_mortality_rate.population_attributable_fraction", + self.emr_table = self.build_lookup_table(builder, "excess_mortality_rate") + builder.value.register_attribute_producer( + self.emr_paf_pipeline, source=lambda index: [pd.Series(0.0, index=index)], preferred_combiner=list_combiner, preferred_post_processor=union_post_processor, ) - self.excess_mortality_rate = builder.value.register_rate_producer( - f"{self.state_id}.excess_mortality_rate", + builder.value.register_rate_producer( + self.emr_pipeline, source=self.risk_deleted_excess_mortality_rate, - required_resources=[self.excess_mortality_rate_paf], + required_resources=[self.emr_paf_pipeline], ) - builder.value.register_value_modifier( + builder.value.register_attribute_modifier( "mortality_rate", - self.add_in_excess_mortality, - required_resources=[self.excess_mortality_rate] + modifier=self.add_in_excess_mortality, + required_resources=[self.emr_pipeline] ) ################## @@ -128,25 +126,27 @@ def setup(self, builder: Builder): ################## def add_disease_transition( - self, output: "DiseaseState", measure: str, rate_name: str, **kwargs + self, output: "DiseaseState", measure: str, rate_name: str, triggered: Trigger = Trigger.NOT_TRIGGERED ) -> DiseaseTransition: - t = DiseaseTransition(self, output, self._cause_key, measure, rate_name, **kwargs) + t = DiseaseTransition(self, output, self._cause_key, measure, rate_name, triggered=triggered) self.add_transition(t) return t ################################## # Pipeline sources and modifiers # ################################## - - def risk_deleted_excess_mortality_rate(self, index: pd.Index) -> pd.Series: - return 
pd.Series(self._excess_mortality_rate, index=index) * ( - 1 - self.excess_mortality_rate_paf(index) - ) + + def risk_deleted_excess_mortality_rate(self, index: pd.Index[int]) -> pd.Series[float]: + base_emr = self.emr_table(index) + emr_paf = self.population_view.get_attributes(index, self.emr_paf_pipeline) + return base_emr * (1 - emr_paf) def add_in_excess_mortality( - self, index: pd.Index, mortality_rates: pd.Series - ) -> pd.Series: - mortality_rates.loc[index] += self.excess_mortality_rate(index) + self, index: pd.Index[int], mortality_rates: pd.Series[float] + ) -> pd.Series[float]: + mortality_rates.loc[index] += self.population_view.get_attributes( + index, self.emr_pipeline + ) return mortality_rates @@ -156,6 +156,15 @@ class DiseaseModel(Machine): # Lifecycle methods # ##################### + def __init__( + self, + state_column: str, + states: Iterable[State] = (), + initial_state: State | None = None, + ) -> None: + super().__init__(state_column, states, initial_state) + self.csmr_pipeline = f"{self.state_column}.cause_specific_mortality_rate" + # noinspection PyAttributeOutsideInit def setup(self, builder: Builder) -> None: super().setup(builder) @@ -166,22 +175,23 @@ def setup(self, builder: Builder) -> None: ) cause_specific_mortality_rate = config.incidence_rate * case_fatality_rate - self.cause_specific_mortality_rate = builder.value.register_rate_producer( - f"{self.state_column}.cause_specific_mortality_rate", + builder.value.register_rate_producer( + self.csmr_pipeline, source=lambda index: pd.Series(cause_specific_mortality_rate, index=index), ) - builder.value.register_value_modifier( + builder.value.register_attribute_modifier( "mortality_rate", modifier=self.delete_cause_specific_mortality, - required_resources=[self.cause_specific_mortality_rate], + required_resources=[self.csmr_pipeline], ) ################################## # Pipeline sources and modifiers # ################################## - def delete_cause_specific_mortality(self, 
index: pd.Index, rates: pd.Series) -> pd.Series: - return rates - self.cause_specific_mortality_rate(index) + def delete_cause_specific_mortality(self, index: pd.Index[int], rates: pd.Series[float]) -> pd.Series[float]: + csmr = self.population_view.get_attributes(index, self.csmr_pipeline) + return rates - csmr class SISDiseaseModel(Component): @@ -201,17 +211,13 @@ def __init__(self, disease_name: str): } susceptible_state = DiseaseState(f"susceptible_to_{self._name}", self._name) - infected_state = DiseaseState( - f"infected_with_{self._name}", self._name, with_excess_mortality=True - ) + infected_state = DiseaseState(f"infected_with_{self._name}", self._name) - susceptible_state.allow_self_transitions() susceptible_state.add_disease_transition( infected_state, measure="incidence_rate", rate_name=f"{infected_state.state_id}.incidence_rate", ) - infected_state.allow_self_transitions() infected_state.add_disease_transition( susceptible_state, measure="remission_rate", diff --git a/src/vivarium/examples/disease_model/disease_model.yaml b/src/vivarium/examples/disease_model/disease_model.yaml index c2eb4e0bf..4be929949 100644 --- a/src/vivarium/examples/disease_model/disease_model.yaml +++ b/src/vivarium/examples/disease_model/disease_model.yaml @@ -2,8 +2,6 @@ components: vivarium.examples.disease_model: population: - BasePopulation() - mortality: - - Mortality() disease: - SISDiseaseModel('lower_respiratory_infections') risk: diff --git a/src/vivarium/examples/disease_model/intervention.py b/src/vivarium/examples/disease_model/intervention.py index 0395c3239..ed90dc6e0 100644 --- a/src/vivarium/examples/disease_model/intervention.py +++ b/src/vivarium/examples/disease_model/intervention.py @@ -1,4 +1,5 @@ -# mypy: ignore-errors +from __future__ import annotations + from typing import Any import pandas as pd @@ -30,23 +31,25 @@ def __init__(self, intervention: str, affected_value: str): super().__init__() self.intervention = intervention self.affected_value = 
affected_value + self.effect_size_pipeline = f"{self.intervention}.effect_size" # noinspection PyAttributeOutsideInit def setup(self, builder: Builder) -> None: effect_size = builder.configuration[self.intervention].effect_size - self.effect_size = builder.value.register_value_producer( - f"{self.intervention}.effect_size", + builder.value.register_attribute_producer( + self.effect_size_pipeline, source=lambda index: pd.Series(effect_size, index=index), ) - builder.value.register_value_modifier( + builder.value.register_attribute_modifier( self.affected_value, modifier=self.intervention_effect, - required_resources=[self.effect_size], + required_resources=[self.effect_size_pipeline], ) ################################## # Pipeline sources and modifiers # ################################## - def intervention_effect(self, index: pd.Index, value: pd.Series) -> pd.Series: - return value * (1 - self.effect_size(index)) + def intervention_effect(self, index: pd.Index[int], value: pd.Series[float]) -> pd.Series[float]: + effect_size = self.population_view.get_attributes(index, self.effect_size_pipeline) + return value * (1 - effect_size) diff --git a/src/vivarium/examples/disease_model/mortality.py b/src/vivarium/examples/disease_model/mortality.py index 806d09dd7..933809533 100644 --- a/src/vivarium/examples/disease_model/mortality.py +++ b/src/vivarium/examples/disease_model/mortality.py @@ -1,4 +1,5 @@ -# mypy: ignore-errors +from __future__ import annotations + from typing import Any import numpy as np @@ -7,13 +8,16 @@ from vivarium import Component from vivarium.framework.engine import Builder from vivarium.framework.event import Event +from vivarium.framework.population import SimulantData class Mortality(Component): + ############## # Properties # ############## + # docs-start: configuration_defaults @property def configuration_defaults(self) -> dict[str, Any]: """A set of default configuration values for this component. 
@@ -21,21 +25,15 @@ def configuration_defaults(self) -> dict[str, Any]: These can be overwritten in the simulation model specification or by providing override values when constructing an interactive simulation. """ - return { - "mortality": { - "mortality_rate": 0.01, - } - } - - @property - def columns_required(self) -> list[str] | None: - return ["tracked", "alive"] + return {self.name: {"data_sources": {"mortality_rate": 0.01}}} + # docs-end: configuration_defaults ##################### # Lifecycle methods # ##################### # noinspection PyAttributeOutsideInit + # docs-start: setup def setup(self, builder: Builder) -> None: """Performs this component's simulation setup. @@ -48,40 +46,43 @@ def setup(self, builder: Builder) -> None: builder Access to simulation tools and subsystems. """ - self.config = builder.configuration.mortality self.randomness = builder.randomness.get_stream("mortality") - - self.mortality_rate = builder.value.register_rate_producer( - "mortality_rate", source=self.base_mortality_rate + builder.value.register_rate_producer( + "mortality_rate", source=self.build_lookup_table(builder, "mortality_rate") + ) + builder.population.register_initializer( + initializer=self.initialize_is_alive, + columns="is_alive", + required_resources=[] ) + # docs-end: setup ######################## # Event-driven methods # ######################## - def on_time_step(self, event: Event) -> None: - """Determines who dies each time step. + # docs-start: initialize_is_alive + def initialize_is_alive(self, pop_data: SimulantData) -> None: + """Called by the simulation whenever new simulants are added. + + This component is responsible for creating and filling the 'is_alive' column + in the population state table. Parameters ---------- - event - An event object emitted by the simulation containing an index - representing the simulants affected by the event and timing - information. 
+ pop_data + A record containing the index of the new simulants, the + start of the time step the simulants are added on, the width + of the time step, and the age boundaries for the simulants to + generate. """ - effective_rate = self.mortality_rate(event.index) - effective_probability = 1 - np.exp(-effective_rate) - draw = self.randomness.get_draw(event.index) - affected_simulants = draw < effective_probability - self.population_view.subview("alive").update( - pd.Series("dead", index=event.index[affected_simulants]) - ) + self.population_view.update(pd.Series(True, index=pop_data.index, name="is_alive")) - def on_time_step_prepare(self, event: Event) -> None: - """Untrack any simulants who died during the previous time step. + # docs-end: initialize_is_alive - We do this after the previous time step because the mortality - observer needs to collect observations before updating. + # docs-start: on_time_step + def on_time_step(self, event: Event) -> None: + """Determines who dies each time step. Parameters ---------- @@ -90,27 +91,11 @@ def on_time_step_prepare(self, event: Event) -> None: representing the simulants affected by the event and timing information. """ - population = self.population_view.get(event.index) - population.loc[ - (population["alive"] == "dead") & population["tracked"] == True, "tracked" - ] = False - self.population_view.update(population) - - ################################## - # Pipeline sources and modifiers # - ################################## - - def base_mortality_rate(self, index: pd.Index) -> pd.Series: - """Computes the base mortality rate for every individual. - - Parameters - ---------- - index - A representation of the simulants to compute the base mortality - rate for. - - Returns - ------- - The base mortality rate for all simulants in the index. 
- """ - return pd.Series(self.config.mortality_rate, index=index) + effective_rate = self.population_view.get_attributes(event.index, "mortality_rate") + effective_probability = 1 - np.exp(-effective_rate) + draw = self.randomness.get_draw(event.index) + affected_simulants = draw < effective_probability + self.population_view.update( + pd.Series(False, index=event.index[affected_simulants], name="is_alive") + ) + # docs-end: on_time_step diff --git a/src/vivarium/examples/disease_model/observer.py b/src/vivarium/examples/disease_model/observer.py index a6d47bf36..75328b883 100644 --- a/src/vivarium/examples/disease_model/observer.py +++ b/src/vivarium/examples/disease_model/observer.py @@ -1,37 +1,47 @@ -# mypy: ignore-errors from typing import Any -from layered_config_tree.main import LayeredConfigTree import pandas as pd from vivarium.framework.engine import Builder from vivarium.framework.results import Observer - +from vivarium.framework.population import SimulantData +from vivarium.framework.event import Event class DeathsObserver(Observer): """Observes the number of deaths.""" - ############## - # Properties # - ############## - - @property - def columns_required(self) -> list[str] | None: - return ["alive"] - ################# # Setup methods # ################# + def setup(self, builder: Builder) -> None: + builder.population.register_initializer( + initializer=self.initialize_previous_alive, + columns="previous_alive", + required_resources=["is_alive"] + ) + def register_observations(self, builder: Builder) -> None: - """We define a newly-dead simulant as one who is 'dead' but who has not - yet become untracked.""" builder.results.register_adding_observation( name="dead", - requires_columns=["alive"], - pop_filter='tracked == True and alive == "dead"', + requires_attributes=["is_alive", "previous_alive"], + pop_filter='previous_alive == True and is_alive == False', ) + ######################## + # Event-driven methods # + ######################## + + def 
initialize_previous_alive(self, pop_data: SimulantData) -> None: + """Initialize the 'previous_alive' column to True for new simulants.""" + self.population_view.update(pd.Series(True, index=pop_data.index, name="previous_alive")) + + def on_time_step_prepare(self, event: Event) -> None: + """Update the 'previous_alive' column to the current alive status.""" + previous_alive = self.population_view.get_attributes(event.index, "is_alive") + previous_alive.name = "previous_alive" + self.population_view.update(previous_alive) + + class YllsObserver(Observer): """Observes the years of lives lost.""" @@ -40,42 +50,29 @@ class YllsObserver(Observer): # Properties # ############## - @property - def columns_required(self) -> list[str] | None: - return ["age", "alive"] - @property def configuration_defaults(self) -> dict[str, Any]: - return { - "mortality": { - "life_expectancy": 80, - } - } - + config = super().configuration_defaults + config["mortality"] = {"life_expectancy": 80.0} + return config ##################### # Lifecycle methods # ##################### - # noinspection PyAttributeOutsideInit def setup(self, builder: Builder) -> None: - self.life_expectancy = builder.configuration.mortality.life_expectancy + self.life_expectancy = float(builder.configuration.mortality.life_expectancy) ################# # Setup methods # ################# - def get_configuration(self, builder: "Builder") -> LayeredConfigTree | None: - # Use component configuration - if self.name in builder.configuration: - return builder.configuration.get_tree(self.name) - return None - def register_observations(self, builder: Builder) -> None: builder.results.register_adding_observation( name="ylls", - requires_columns=["age", "alive"], + requires_attributes=["age", "is_alive", "previous_alive"], + pop_filter='previous_alive == True and is_alive == False', aggregator=self.calculate_ylls, ) def calculate_ylls(self, df: pd.DataFrame) -> float: - return (self.life_expectancy - df.loc[df["alive"] == "dead", "age"]).sum() + return 
float((self.life_expectancy - df["age"]).sum()) \ No newline at end of file diff --git a/src/vivarium/examples/disease_model/population.py b/src/vivarium/examples/disease_model/population.py index 9edc42221..a2bc1eac3 100644 --- a/src/vivarium/examples/disease_model/population.py +++ b/src/vivarium/examples/disease_model/population.py @@ -1,4 +1,4 @@ -# mypy: ignore-errors +# docs-start: imports from typing import Any import pandas as pd @@ -7,6 +7,8 @@ from vivarium.framework.engine import Builder from vivarium.framework.event import Event from vivarium.framework.population import SimulantData +from vivarium.examples.disease_model import Mortality +# docs-end: imports class BasePopulation(Component): @@ -16,6 +18,7 @@ class BasePopulation(Component): # Properties # ############## + # docs-start: configuration_defaults @property def configuration_defaults(self) -> dict[str, Any]: """A set of default configuration values for this component. @@ -31,16 +34,20 @@ def configuration_defaults(self) -> dict[str, Any]: # Note: There is also a 'population_size' key. }, } - + # docs-end: configuration_defaults + + # docs-start: sub_components @property - def columns_created(self) -> list[str]: - return ["age", "sex", "alive", "entrance_time"] + def sub_components(self) -> list[Component]: + return [Mortality()] + # docs-end: sub_components ##################### # Lifecycle methods # ##################### # noinspection PyAttributeOutsideInit + # docs-start: setup def setup(self, builder: Builder) -> None: """Performs this component's simulation setup. @@ -55,6 +62,7 @@ def setup(self, builder: Builder) -> None: """ self.config = builder.configuration + # docs-start: crn self.with_common_random_numbers = bool(self.config.randomness.key_columns) self.register = builder.randomness.register_simulants if ( @@ -65,73 +73,73 @@ def setup(self, builder: Builder) -> None: "If running with CRN, you must specify ['entrance_time', 'age'] as" "the randomness key columns." 
) + # docs-end: crn + # docs-start: randomness self.age_randomness = builder.randomness.get_stream( - "age_initialization", initializes_crn_attributes=self.with_common_random_numbers + "age_initialization", initializes_crn_attributes=self.with_common_random_numbers, ) self.sex_randomness = builder.randomness.get_stream("sex_initialization") + # docs-end: randomness + + # docs-start: initializers + builder.population.register_initializer( + initializer=self.initialize_entrance_time_and_age, + columns=["entrance_time", "age"], + required_resources=[self.age_randomness] + ) + builder.population.register_initializer( + initializer=self.initialize_sex, + columns="sex", + required_resources=[self.sex_randomness] + ) + # docs-end: initializers + # docs-end: setup ######################## # Event-driven methods # ######################## - def on_initialize_simulants(self, pop_data: SimulantData) -> None: - """Called by the simulation whenever new simulants are added. - - This component is responsible for creating and filling four columns - in the population state table: - - 'age' - The age of the simulant in fractional years. - 'sex' - The sex of the simulant. One of {'Male', 'Female'} - 'alive' - Whether or not the simulant is alive. One of {'alive', 'dead'} - 'entrance_time' - The time that the simulant entered the simulation. The 'birthday' - for simulants that enter as newborns. A `pandas.Timestamp`. - - Parameters - ---------- - pop_data - A record containing the index of the new simulants, the - start of the time step the simulants are added on, the width - of the time step, and the age boundaries for the simulants to - generate. 
- """ - + # docs-start: initialize_entrance_time_and_age + def initialize_entrance_time_and_age(self, pop_data: SimulantData) -> None: + # docs-start: ages age_start = pop_data.user_data.get("age_start", self.config.population.age_start) age_end = pop_data.user_data.get("age_end", self.config.population.age_end) if age_start == age_end: - age_window = pop_data.creation_window / pd.Timedelta(days=365) + age_window = pop_data.creation_window / pd.Timedelta(days=365) # type: ignore[operator] else: age_window = age_end - age_start age_draw = self.age_randomness.get_draw(pop_data.index) age = age_start + age_draw * age_window + # docs-end: ages + # docs-start: population_dataframe + population = pd.DataFrame( + { + "entrance_time": pop_data.creation_time, + "age": age.values, + }, + index=pop_data.index, + ) + # docs-end: population_dataframe + # docs-start: crn_registration if self.with_common_random_numbers: - population = pd.DataFrame( - {"entrance_time": pop_data.creation_time, "age": age.values}, - index=pop_data.index, - ) self.register(population) - population["sex"] = self.sex_randomness.choice(pop_data.index, ["Male", "Female"]) - population["alive"] = "alive" - else: - population = pd.DataFrame( - { - "age": age.values, - "sex": self.sex_randomness.choice(pop_data.index, ["Male", "Female"]), - "alive": pd.Series("alive", index=pop_data.index), - "entrance_time": pop_data.creation_time, - }, - index=pop_data.index, - ) + # docs-end: crn_registration + # docs-start: update_entrance_time_and_age self.population_view.update(population) + # docs-end: update_entrance_time_and_age + # docs-end: initialize_entrance_time_and_age + + # docs-start: initialize_sex + def initialize_sex(self, pop_data: SimulantData) -> None: + self.population_view.update(pd.Series(self.sex_randomness.choice(pop_data.index, ["Male", "Female"]), name="sex")) + # docs-end: initialize_sex + # docs-start: on_time_step def on_time_step(self, event: Event) -> None: """Updates simulant age on every 
time step. @@ -142,6 +150,7 @@ def on_time_step(self, event: Event) -> None: representing the simulants affected by the event and timing information. """ - population = self.population_view.get(event.index, query="alive == 'alive'") - population["age"] += event.step_size / pd.Timedelta(days=365) + population = self.population_view.get_private_columns(event.index, "age", query="is_alive == True") + population += event.step_size / pd.Timedelta(days=365) # type: ignore[operator] self.population_view.update(population) + # docs-end: on_time_step diff --git a/src/vivarium/examples/disease_model/risk.py b/src/vivarium/examples/disease_model/risk.py index 931c33164..098d87228 100644 --- a/src/vivarium/examples/disease_model/risk.py +++ b/src/vivarium/examples/disease_model/risk.py @@ -1,4 +1,3 @@ -# mypy: ignore-errors from __future__ import annotations from typing import TYPE_CHECKING, Any @@ -9,7 +8,7 @@ if TYPE_CHECKING: from vivarium.framework.engine import Builder - from vivarium.framework.resource import Resource + from vivarium.framework.population import SimulantData class Risk(Component): @@ -27,14 +26,6 @@ class Risk(Component): def configuration_defaults(self) -> dict[str, Any]: return {self.risk: self.CONFIGURATION_DEFAULTS["risk"]} - @property - def columns_created(self) -> list[str]: - return [self.propensity_column] - - @property - def initialization_requirements(self) -> list[str | Resource]: - return [self.randomness] - ##################### # Lifecycle methods # ##################### @@ -43,30 +34,38 @@ def __init__(self, risk: str): super().__init__() self.risk = risk self.propensity_column = f"{risk}_propensity" + self.base_proportion_exposed_pipeline = f"{risk}.base_proportion_exposed" + self.exposure_threshold_pipeline = f"{self.risk}.proportion_exposed" # noinspection PyAttributeOutsideInit def setup(self, builder: Builder) -> None: proportion_exposed = builder.configuration[self.risk].proportion_exposed - self.base_exposure_threshold = 
builder.value.register_value_producer( - f"{self.risk}.base_proportion_exposed", + builder.value.register_attribute_producer( + self.base_proportion_exposed_pipeline, source=lambda index: pd.Series(proportion_exposed, index=index), ) - self.exposure_threshold = builder.value.register_value_producer( - f"{self.risk}.proportion_exposed", source=self.base_exposure_threshold + builder.value.register_attribute_producer( + self.exposure_threshold_pipeline, + source=[self.base_proportion_exposed_pipeline], ) - self.exposure = builder.value.register_value_producer( + builder.value.register_attribute_producer( f"{self.risk}.exposure", source=self._exposure, - required_resources=[self.propensity_column, self.exposure_threshold], + required_resources=[self.propensity_column, self.exposure_threshold_pipeline], ) self.randomness = builder.randomness.get_stream(self.risk) + builder.population.register_initializer( + initializer=self.initialize_propensity, + columns=self.propensity_column, + required_resources=[self.randomness] + ) ######################## # Event-driven methods # ######################## - def on_initialize_simulants(self, pop_data): + def initialize_propensity(self, pop_data: SimulantData) -> None: draw = self.randomness.get_draw(pop_data.index) self.population_view.update(pd.Series(draw, name=self.propensity_column)) @@ -74,9 +73,10 @@ def on_initialize_simulants(self, pop_data): # Pipeline sources and modifiers # ################################## - def _exposure(self, index): - propensity = self.population_view.get(index)[self.propensity_column] - return self.exposure_threshold(index) > propensity + def _exposure(self, index: pd.Index[int]) -> pd.Series[bool]: + propensity = self.population_view.get_attributes(index, self.propensity_column) + exposure_threshold = self.population_view.get_attributes(index, self.exposure_threshold_pipeline) + return exposure_threshold > propensity class RiskEffect(Component): @@ -103,42 +103,43 @@ def __init__(self, risk_name: 
str, disease_rate: str): self.risk_name = risk_name self.disease_rate = disease_rate self.risk = f"effect_of_{risk_name}_on_{disease_rate}" + self.base_exposure_pipeline = f"{self.risk_name}.base_proportion_exposed" + self.exposure_pipeline = f"{self.risk_name}.exposure" + self.relative_risk_pipeline = f"{self.risk}.relative_risk" # noinspection PyAttributeOutsideInit def setup(self, builder: Builder) -> None: - self.base_risk_exposure = builder.value.get_value( - f"{self.risk_name}.base_proportion_exposed" - ) - self.actual_risk_exposure = builder.value.get_value(f"{self.risk_name}.exposure") - relative_risk = builder.configuration[self.risk].relative_risk - self.relative_risk = builder.value.register_value_producer( - f"{self.risk}.relative_risk", + builder.value.register_attribute_producer( + self.relative_risk_pipeline, source=lambda index: pd.Series(relative_risk, index=index), ) - builder.value.register_value_modifier( + builder.value.register_attribute_modifier( f"{self.disease_rate}.population_attributable_fraction", - self.population_attributable_fraction, - required_resources=[self.base_risk_exposure, self.relative_risk], + modifier=self.population_attributable_fraction, + required_resources=[self.base_exposure_pipeline, self.relative_risk_pipeline], ) - builder.value.register_value_modifier( + builder.value.register_attribute_modifier( f"{self.disease_rate}", - self.rate_adjustment, - required_resources=[self.actual_risk_exposure, self.relative_risk], + modifier=self.rate_adjustment, + required_resources=[self.exposure_pipeline, self.relative_risk_pipeline], ) ################################## # Pipeline sources and modifiers # ################################## - def population_attributable_fraction(self, index): - exposure = self.base_risk_exposure(index) - relative_risk = self.relative_risk(index) + def population_attributable_fraction(self, index: pd.Index[int]) -> pd.Series[float]: + pop = self.population_view.get_attributes( + index, 
[self.base_exposure_pipeline, self.relative_risk_pipeline] + ) + exposure = pop[self.base_exposure_pipeline] + relative_risk = pop[self.relative_risk_pipeline] return exposure * (relative_risk - 1) / (exposure * (relative_risk - 1) + 1) - def rate_adjustment(self, index, rates): - exposed = self.actual_risk_exposure(index) - rr = self.relative_risk(index) + def rate_adjustment(self, index: pd.Index[int], rates: pd.Series[float]) -> pd.Series[float]: + exposed = self.population_view.get_attributes(index, self.exposure_pipeline) + rr = self.population_view.get_attributes(index, self.relative_risk_pipeline) rates[exposed] *= rr[exposed] return rates diff --git a/src/vivarium/framework/artifact/interface.py b/src/vivarium/framework/artifact/interface.py index e5d46741f..bc0f4eac0 100644 --- a/src/vivarium/framework/artifact/interface.py +++ b/src/vivarium/framework/artifact/interface.py @@ -10,17 +10,12 @@ from __future__ import annotations -from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any - -import pandas as pd +from collections.abc import Sequence +from typing import Any from vivarium.framework.artifact.manager import ArtifactManager from vivarium.manager import Interface -if TYPE_CHECKING: - from vivarium.types import ScalarValue - class ArtifactInterface(Interface): """The builder interface for accessing a data artifact.""" @@ -60,20 +55,5 @@ def load(self, entity_key: str, **column_filters: int | str | Sequence[int | str """ return self._manager.load(entity_key, **column_filters) - def value_columns( - self, - ) -> Callable[[str | pd.DataFrame | dict[str, list[ScalarValue] | list[str]]], list[str]]: - """Returns a function that returns the value columns for the given input. - - The function can be called with either a string or a pandas DataFrame. - If a string is provided, it is interpreted as an artifact key, and the - value columns for the data stored at that key are returned. 
- - Returns - ------- - A function that returns the value columns for the given input. - """ - return self._manager.value_columns() - def __repr__(self) -> str: return "ArtifactManagerInterface()" diff --git a/src/vivarium/framework/artifact/manager.py b/src/vivarium/framework/artifact/manager.py index 78554a6e9..1fe48aa70 100644 --- a/src/vivarium/framework/artifact/manager.py +++ b/src/vivarium/framework/artifact/manager.py @@ -1,7 +1,7 @@ """ -==================== -The Artifact Manager -==================== +================ +Artifact Manager +================ This module contains the :class:`ArtifactManager`, a ``vivarium`` plugin for handling complex data bound up in a data artifact. @@ -21,6 +21,7 @@ from vivarium.framework.artifact.artifact import Artifact from vivarium.framework.lifecycle import lifecycle_states from vivarium.manager import Interface, Manager +from vivarium.types import LookupTableData if TYPE_CHECKING: from vivarium.framework.engine import Builder @@ -112,23 +113,6 @@ def load(self, entity_key: str, **column_filters: int | str | Sequence[int | str return data - def value_columns( - self, - ) -> Callable[[str | pd.DataFrame | dict[str, list[ScalarValue] | list[str]]], list[str]]: - """Returns a function that returns the value columns for the given input. - - The function can be called with either a string or a pandas DataFrame. - If a string is provided, it is interpreted as an artifact key, and the - value columns for the data stored at that key are returned. - - Currently, the returned function will always return ["value"]. - - Returns - ------- - A function that returns the value columns for the given input. 
- """ - return lambda _: [self._default_value_column] - def __repr__(self) -> str: return "ArtifactManager()" diff --git a/src/vivarium/framework/components/interface.py b/src/vivarium/framework/components/interface.py index eec7bcbab..621a4b3e1 100644 --- a/src/vivarium/framework/components/interface.py +++ b/src/vivarium/framework/components/interface.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Union from vivarium import Component -from vivarium.framework.components.manager import ComponentManager +from vivarium.framework.components.manager import C, ComponentManager from vivarium.manager import Interface, Manager if TYPE_CHECKING: @@ -21,34 +21,27 @@ class ComponentInterface(Interface): - """The builder interface for the component manager system. - - This class defines component manager methods a ``vivarium`` component can - access from the builder. It provides methods for querying and adding components - to the :class:`ComponentManager `. - """ + """The builder interface for the component manager system.""" def __init__(self, manager: ComponentManager): self._manager = manager def get_component(self, name: str) -> Component | Manager: - """Get the component that has ``name`` if presently held by the component - manager. Names are guaranteed to be unique. + """Get the component or manager that has ``name`` if presently held by the + component manager. Names are guaranteed to be unique. Parameters ---------- name - A component name. + A component or manager name. Returns ------- - A component that has name ``name``. + A component or manager that has name ``name``. """ return self._manager.get_component(name) - def get_components_by_type( - self, component_type: type[Component | Manager] | Sequence[type[Component | Manager]] - ) -> list[Component | Manager]: + def get_components_by_type(self, component_type: type[C] | Sequence[type[C]]) -> list[C]: """Get all components that are an instance of ``component_type``. 
Parameters @@ -64,10 +57,44 @@ def get_components_by_type( return self._manager.get_components_by_type(component_type) def list_components(self) -> dict[str, Component | Manager]: - """Get a mapping of component names to components held by the manager. + """Get a mapping of names to components or managers held by the manager. Returns ------- - A dictionary mapping component names to components. + A dictionary mapping names to components or managers. """ return self._manager.list_components() + + def get_current_component(self) -> Component: + """Get the component currently being set up, if any. + + This method is primarily used internally by the framework to support + automatic component injection in interface methods. + + Returns + ------- + The component currently being set up. + + Raises + ------ + LifeCycleError + If there is no component currently being set up. + """ + return self._manager.get_current_component() + + def get_current_component_or_manager(self) -> Component | Manager: + """Get the component or manager currently being set up, if any. + + This method exists to allow for cases where a manager is needed during + setup, such as if a manager creates a Value Pipeline. + + Returns + ------- + The component or manager currently being set up. + + Raises + ------ + LifeCycleError + If there is no component or manager currently being set up. 
+ """ + return self._manager.get_current_component_or_manager() diff --git a/src/vivarium/framework/components/manager.py b/src/vivarium/framework/components/manager.py index 2cb4edfec..7cf0696fa 100644 --- a/src/vivarium/framework/components/manager.py +++ b/src/vivarium/framework/components/manager.py @@ -1,7 +1,7 @@ """ -============================ -The Component Manager System -============================ +================= +Component Manager +================= The :mod:`vivarium` component manager system is responsible for maintaining a reference to all of the managers and components in a simulation, providing an @@ -19,7 +19,7 @@ import inspect from collections.abc import Iterator, Sequence -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Generic, TypeVar from layered_config_tree import ( ConfigurationError, @@ -29,13 +29,14 @@ from vivarium import Component from vivarium.exceptions import VivariumError -from vivarium.framework.lifecycle import LifeCycleManager, lifecycle_states +from vivarium.framework.lifecycle import LifeCycleError, LifeCycleManager, lifecycle_states from vivarium.manager import Manager if TYPE_CHECKING: from vivarium.framework.engine import Builder - _ComponentsType = Sequence[Union[Component, Manager, "_ComponentsType"]] +C = TypeVar("C", bound=Component) +T = TypeVar("T", Component, Manager) class ComponentConfigError(VivariumError): @@ -44,7 +45,7 @@ class ComponentConfigError(VivariumError): pass -class OrderedComponentSet: +class OrderedComponentSet(Generic[T]): """A container for Vivarium components. 
It preserves ordering, enforces uniqueness by name, and provides a @@ -52,35 +53,34 @@ class OrderedComponentSet: """ - def __init__(self, *args: Component | Manager): - self.components: list[Component | Manager] = [] + def __init__(self, *args: T) -> None: + self.components: list[T] = [] if args: self.update(args) - def add(self, component: Component | Manager) -> None: + def add(self, component: T) -> None: if component in self: raise ComponentConfigError( f"Attempting to add a component with duplicate name: {component}" ) self.components.append(component) - def update( - self, - components: Sequence[Component | Manager], - ) -> None: + def update(self, components: Sequence[T]) -> None: for c in components: self.add(c) - def pop(self) -> Component | Manager: + def pop(self) -> T: component = self.components.pop(0) return component - def __contains__(self, component: Component | Manager) -> bool: + def __contains__(self, component: T) -> bool: if not hasattr(component, "name"): - raise ComponentConfigError(f"Component {component} has no name attribute") + raise ComponentConfigError( + f"{type(component).__name__} {component} has no name attribute" + ) return component.name in [c.name for c in self.components] - def __iter__(self) -> Iterator[Component | Manager]: + def __iter__(self) -> Iterator[T]: return iter(self.components) def __len__(self) -> int: @@ -89,7 +89,7 @@ def __len__(self) -> int: def __bool__(self) -> bool: return bool(self.components) - def __add__(self, other: "OrderedComponentSet") -> "OrderedComponentSet": + def __add__(self, other: OrderedComponentSet[T]) -> OrderedComponentSet[T]: return OrderedComponentSet(*(self.components + other.components)) def __eq__(self, other: object) -> bool: @@ -102,7 +102,7 @@ def __eq__(self, other: object) -> bool: except TypeError: return False - def __getitem__(self, index: int) -> Any: + def __getitem__(self, index: int) -> T: return self.components[index] def __repr__(self) -> str: @@ -126,9 +126,10 @@ 
class ComponentManager(Manager): """ def __init__(self) -> None: - self._managers = OrderedComponentSet() - self._components = OrderedComponentSet() + self._managers: OrderedComponentSet[Manager] = OrderedComponentSet() + self._components: OrderedComponentSet[Component] = OrderedComponentSet() self._configuration: LayeredConfigTree | None = None + self._current_component: Component | Manager | None = None @property def configuration(self) -> LayeredConfigTree: @@ -162,7 +163,7 @@ def setup_manager( self.list_components, restrict_during=[lifecycle_states.INITIALIZATION] ) - def add_managers(self, managers: list[Manager] | tuple[Manager]) -> None: + def add_managers(self, managers: Sequence[Manager]) -> None: """Registers new managers with the component manager. Managers are configured and setup before components. @@ -172,11 +173,11 @@ def add_managers(self, managers: list[Manager] | tuple[Manager]) -> None: managers Instantiated managers to register. """ - for m in self._flatten(list(managers)): - self.apply_configuration_defaults(m) - self._managers.add(m) + for manager in managers: + self.apply_configuration_defaults(manager) + self._managers.add(manager) - def add_components(self, components: list[Component] | tuple[Component]) -> None: + def add_components(self, components: Sequence[Component]) -> None: """Register new components with the component manager. Components are configured and setup after managers. @@ -186,13 +187,11 @@ def add_components(self, components: list[Component] | tuple[Component]) -> None components Instantiated components to register. 
""" - for c in self._flatten(list(components)): + for c in self._flatten_subcomponents(list(components)): self.apply_configuration_defaults(c) self._components.add(c) - def get_components_by_type( - self, component_type: type[Component | Manager] | Sequence[type[Component | Manager]] - ) -> list[Component | Manager]: + def get_components_by_type(self, component_type: type[C] | Sequence[type[C]]) -> list[C]: """Get all components that are an instance of ``component_type``. Parameters @@ -208,7 +207,7 @@ def get_components_by_type( component_type = ( component_type if isinstance(component_type, type) else tuple(component_type) ) - return [c for c in self._components if isinstance(c, component_type)] + return [c for c in self._components if isinstance(c, component_type)] # type: ignore[misc] def get_component(self, name: str) -> Component | Manager: """Get the component with name ``name``. @@ -244,6 +243,38 @@ def list_components(self) -> dict[str, Component | Manager]: """ return {c.name: c for c in self._components} + def get_current_component(self) -> Component: + """Get the component currently being set up, if any. + + Returns + ------- + The component currently being set up. + + Raises + ------ + LifeCycleError + No component is currently being set up. + """ + if not isinstance(self._current_component, Component): + raise LifeCycleError("No component is currently being set up.") + return self._current_component + + def get_current_component_or_manager(self) -> Component | Manager: + """Get the component or manager currently being set up, if any. + + Returns + ------- + The component or manager currently being set up. + + Raises + ------ + LifeCycleError + No component or manager is currently being set up. 
+ """ + if self._current_component is None: + raise LifeCycleError("No component or manager is currently being set up.") + return self._current_component + def setup_components(self, builder: Builder) -> None: """Separately configure and set up the managers and components held by the component manager, in that order. @@ -258,7 +289,15 @@ def setup_components(self, builder: Builder) -> None: builder Interface to several simulation tools. """ - self._setup_components(builder, self._managers + self._components) + for manager in self._managers: + self._current_component = manager + manager.setup(builder) + + for component in self._components: + self._current_component = component + component.setup_component(builder) + + self._current_component = None def apply_configuration_defaults(self, component: Component | Manager) -> None: try: @@ -299,34 +338,12 @@ def _get_file(component: Component | Manager) -> str: else: return inspect.getfile(component.__class__) - @staticmethod - def _flatten(components: _ComponentsType) -> list[Component | Manager]: - out: list[Component | Manager] = [] - # Reverse the order of components so we can pop appropriately - components = list(components)[::-1] - while components: - current = components.pop() - if isinstance(current, (list, tuple)): - components.extend(current[::-1]) - elif isinstance(current, Component): - components.extend(current.sub_components[::-1]) - out.append(current) - elif isinstance(current, Manager): - out.append(current) - else: - raise TypeError( - "Expected Component, Manager, List, or Tuple. 
" - f"Got {type(current)}: {current}" - ) - return out - - @staticmethod - def _setup_components(builder: Builder, components: OrderedComponentSet) -> None: + def _flatten_subcomponents(self, components: Sequence[Component]) -> list[Component]: + out: list[Component] = [] for component in components: - if isinstance(component, Component): - component.setup_component(builder) - elif isinstance(component, Manager): - component.setup(builder) + out.append(component) + out.extend(self._flatten_subcomponents(component.sub_components)) + return out def __repr__(self) -> str: return "ComponentManager()" diff --git a/src/vivarium/framework/configuration.py b/src/vivarium/framework/configuration.py index da3d6246e..daadeb4ef 100644 --- a/src/vivarium/framework/configuration.py +++ b/src/vivarium/framework/configuration.py @@ -47,7 +47,7 @@ def build_model_specification( def validate_model_specification_file(file_path: str | Path) -> None: - """Ensures the provided file is a yaml file""" + """Ensures the provided file is a yaml file.""" file_path = Path(file_path) if not file_path.exists(): raise ConfigurationError( diff --git a/src/vivarium/framework/engine.py b/src/vivarium/framework/engine.py index 77ec66edc..5e3294611 100644 --- a/src/vivarium/framework/engine.py +++ b/src/vivarium/framework/engine.py @@ -19,6 +19,8 @@ """ +from __future__ import annotations + from pathlib import Path from pprint import pformat from time import time @@ -62,7 +64,7 @@ class SimulationContext: @staticmethod def _get_context_name(sim_name: str | None) -> str: - """Get a unique name for a simulation context. + """Gets a unique name for a simulation context. Parameters ---------- @@ -96,7 +98,7 @@ def _get_context_name(sim_name: str | None) -> str: @staticmethod def _clear_context_cache() -> None: - """Clear the cache of simulation context names. + """Clears the cache of simulation context names. 
Notes ----- @@ -199,7 +201,8 @@ def __init__( self._tables, self._data, self._results, - ] + list(self._plugin_manager.get_optional_controllers().values()) + *self._plugin_manager.get_optional_controllers().values(), + ] self._component_manager.add_managers(managers) component_config_parser = self._plugin_manager.get_component_config_parser() @@ -242,11 +245,11 @@ def current_time(self) -> ClockTime: return self._clock.time def get_results(self) -> dict[str, pd.DataFrame]: - """Return the formatted results.""" + """Returns a dictionary of formatted results.""" return self._results.get_results() def run_simulation(self) -> None: - """A wrapper method to run all steps of a simulation""" + """Runs all steps of a simulation.""" self.setup() self.initialize_simulants() self.run() @@ -278,7 +281,7 @@ def initialize_simulants(self) -> None: self._clock.step_backward() population_size = pop_params.population_size self.simulant_creator(population_size, {"sim_state": lifecycle_states.SETUP}) - self._clock.step_forward(self.get_population().index) + self._clock.step_forward(self.get_population_index()) def step(self) -> None: self._logger.info(self.current_time) @@ -286,12 +289,12 @@ def step(self) -> None: self._logger.debug(f"Event: {event}") self._lifecycle.set_state(event) pop_to_update = self._clock.get_active_simulants( - self.get_population().index, + self.get_population_index(), self._clock.event_time, ) self._logger.debug(f"Updating: {len(pop_to_update)}") self.time_step_emitters[event](pop_to_update, None) - self._clock.step_forward(self.get_population().index) + self._clock.step_forward(self.get_population_index()) def run( self, @@ -312,7 +315,7 @@ def run( def finalize(self) -> None: self._lifecycle.set_state(lifecycle_states.SIMULATION_END) - self.end_emitter(self.get_population().index, None) + self.end_emitter(self.get_population_index(), None) unused_config_keys = self.configuration.unused_keys() if unused_config_keys: self._logger.warning( @@ -321,7 +324,7 
@@ def finalize(self) -> None: def report(self, print_results: bool = True) -> None: self._lifecycle.set_state(lifecycle_states.REPORT) - self.report_emitter(self.get_population().index, None) + self.report_emitter(self.get_population_index(), None) results = self.get_results() if print_results: for measure, df in results.items(): @@ -335,7 +338,7 @@ def report(self, print_results: bool = True) -> None: self._write_results(results) def _write_results(self, results: dict[str, pd.DataFrame]) -> None: - """Iterate through the measures and write out the formatted results""" + """Iterates through the measures and writes out the formatted results.""" try: results_dir = self.configuration.output_data.results_directory for measure, df in results.items(): @@ -370,8 +373,11 @@ def add_components(self, component_list: list[Component]) -> None: """Adds new components to the simulation.""" self._component_manager.add_components(component_list) - def get_population(self, untracked: bool = True) -> pd.DataFrame: - return self._population.get_population(untracked) + def get_population(self) -> pd.Series[Any] | pd.DataFrame: + return self._population.get_population("all") + + def get_population_index(self) -> pd.Index[int]: + return self._population.get_population_index() def __repr__(self) -> str: return f"SimulationContext({self.name})" @@ -381,7 +387,7 @@ def get_number_of_steps_remaining(self) -> int: @classmethod def load_from_backup(cls, backup_path: Path) -> "SimulationContext": - """Load a simulation context from a backup file.""" + """Loads a simulation context from a backup file.""" with open(backup_path, "rb") as f: backup: SimulationContext = dill.load(f) return backup @@ -422,7 +428,7 @@ def __init__( :ref:`event` system.""" self.population = plugin_manager.get_plugin_interface(PopulationInterface) - """Provides access to simulant state table via the + """Provides access to population state table via the :ref:`population` system.""" self.resources = 
plugin_manager.get_plugin_interface(ResourceInterface) diff --git a/src/vivarium/framework/event/__init__.py b/src/vivarium/framework/event/__init__.py new file mode 100644 index 000000000..379ff78ac --- /dev/null +++ b/src/vivarium/framework/event/__init__.py @@ -0,0 +1,2 @@ +from vivarium.framework.event.interface import EventInterface +from vivarium.framework.event.manager import Event, EventChannel, EventManager diff --git a/src/vivarium/framework/event/interface.py b/src/vivarium/framework/event/interface.py new file mode 100644 index 000000000..4958ae594 --- /dev/null +++ b/src/vivarium/framework/event/interface.py @@ -0,0 +1,92 @@ +""" +=============== +Event Interface +=============== + +The :class:`EventInterface` is exposed off the :ref:`builder ` +and provides two methods: :func:`get_emitter `, +which returns a callable emitter for the given event type and +:func:`register_listener `, which adds the +given listener to the event channel for the given event. This is the only part +of the event framework with which client code should interact. + +For more information, see the associated event :ref:`concept note `. + +""" +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import pandas as pd + +from vivarium.framework.event.manager import Event, EventManager +from vivarium.manager import Interface + + +class EventInterface(Interface): + """The public interface for the :class:`~ ` system.""" + + def __init__(self, manager: EventManager): + self._manager = manager + + def get_emitter( + self, event_name: str + ) -> Callable[[pd.Index[int], dict[str, Any] | None], Event]: + """Gets an emitter for a named ``Event``. + + Parameters + ---------- + event_name + The name of the ``Event`` the requested emitter will emit. + Users may provide their own named ``Events`` by requesting an emitter + with this function, but should do so with caution as it makes time + much more difficult to think about. 
+ + Returns + ------- + An emitter for the named ``Event``. The emitter should be called by + the requesting component at the appropriate point in the simulation + lifecycle. + """ + return self._manager.get_emitter(event_name) + + def register_listener( + self, event_name: str, listener: Callable[[Event], None], priority: int = 5 + ) -> None: + """Registers a callable as a listener to an ``Event`` with the given name. + + The listening callable will be called with a named ``Event`` as its + only argument any time the ``Event`` emitter is invoked from somewhere in + the simulation. + + The framework creates the following ``Events`` and emits them at different + points in the simulation: + + - At the end of the setup phase: ``post_setup`` + - Every time step: + - ``time_step__prepare`` + - ``time_step`` + - ``time_step__cleanup`` + - ``collect_metrics`` + - At simulation end: ``simulation_end`` + + Parameters + ---------- + event_name + The name of the ``Event`` to listen for. + listener + The callable to be invoked any time an ``Event`` with the given + name is emitted. + priority + One of {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}. + An indication of the order in which ``Event`` listeners should be + called. Listeners with smaller priority values will be called + earlier. Listeners with the same priority have no guaranteed + ordering. This feature should be avoided if possible. Components + should strive to obey the Markov property as they transform the + state table (the state of the simulation at the beginning of the + next time step should only depend on the current state of the + system). 
+ """ + self._manager.register_listener(event_name, listener, priority) diff --git a/src/vivarium/framework/event.py b/src/vivarium/framework/event/manager.py similarity index 68% rename from src/vivarium/framework/event.py rename to src/vivarium/framework/event/manager.py index 04dfe8c63..1d01ad79a 100644 --- a/src/vivarium/framework/event.py +++ b/src/vivarium/framework/event/manager.py @@ -1,13 +1,11 @@ """ -============================ -The Vivarium Event Framework -============================ +============= +Event Manager +============= ``vivarium`` constructs and manages the flow of :ref:`time ` through the emission of regularly scheduled events. The tools in this module -manage the relationships between event emitters and listeners and provide -an interface for user :ref:`components ` to register -themselves as emitters or listeners to particular events. +manage the relationships between event emitters and listeners. The :class:`EventManager` maintains a mapping between event types and channels. Each event type (and event types must be unique so event type is equivalent to @@ -15,17 +13,11 @@ :class:`EventChannel`, which tracks listeners to that event in prioritized levels and passes on the event to those listeners when emitted. -The :class:`EventInterface` is exposed off the :ref:`builder ` -and provides two methods: :func:`get_emitter `, -which returns a callable emitter for the given event type and -:func:`register_listener `, which adds the -given listener to the event channel for the given event. This is the only part -of the event framework with which client code should interact. - -For more information, see the associated event -:ref:`concept note `. +For more information, see the associated event :ref:`concept note `. 
""" + + from __future__ import annotations from collections.abc import Callable @@ -36,7 +28,7 @@ import pandas as pd from vivarium.framework.lifecycle import ConstraintError, lifecycle_states -from vivarium.manager import Interface, Manager +from vivarium.manager import Manager from vivarium.types import ClockStepSize, ClockTime if TYPE_CHECKING: @@ -50,6 +42,7 @@ class Event: Events themselves are just a bundle of data. They must be emitted along an :class:`EventChannel` in order for other simulation components to respond to them. + """ name: str @@ -65,11 +58,10 @@ class Event: """The current step size at the time of the event.""" def split(self, new_index: pd.Index[int]) -> "Event": - """Create a copy of this event with a new index. + """Creates a copy of this event with a new index. - This function should be used to emit an event in a new - :class:`EventChannel` in response to an event emitted from a - different channel. + This function should be used to emit an event in a new :class:`EventChannel` + in response to an event emitted from a different channel. Parameters ---------- @@ -204,7 +196,7 @@ def on_post_setup(self, event: Event) -> None: def get_emitter( self, event_name: str ) -> Callable[[pd.Index[int], dict[str, Any] | None], Event]: - """Get an emitter function for the named event. + """Gets an emitter function for the named event. Parameters ---------- @@ -244,7 +236,7 @@ def register_listener( self.get_channel(event_name).listeners[priority].append(listener) def get_listeners(self, event_name: str) -> dict[int, list[Callable[[Event], None]]]: - """Get all listeners registered for the named event. + """Gets all listeners registered for the named event. Parameters ---------- @@ -264,7 +256,7 @@ def get_listeners(self, event_name: str) -> dict[int, list[Callable[[Event], Non } def list_events(self) -> list[str]: - """List all event names known to the event system. + """Lists all event names known to the event system. 
Returns ------- @@ -282,71 +274,3 @@ def __contains__(self, item: str) -> bool: def __repr__(self) -> str: return "EventManager()" - - -class EventInterface(Interface): - """The public interface for the event system.""" - - def __init__(self, manager: EventManager): - self._manager = manager - - def get_emitter( - self, event_name: str - ) -> Callable[[pd.Index[int], dict[str, Any] | None], Event]: - """Gets an emitter for a named event. - - Parameters - ---------- - event_name - The name of the event the requested emitter will emit. - Users may provide their own named events by requesting an emitter - with this function, but should do so with caution as it makes time - much more difficult to think about. - - Returns - ------- - An emitter for the named event. The emitter should be called by - the requesting component at the appropriate point in the simulation - lifecycle. - """ - return self._manager.get_emitter(event_name) - - def register_listener( - self, event_name: str, listener: Callable[[Event], None], priority: int = 5 - ) -> None: - """Registers a callable as a listener to a events with the given name. - - The listening callable will be called with a named ``Event`` as its - only argument any time the event emitter is invoked from somewhere in - the simulation. - - The framework creates the following events and emits them at different - points in the simulation: - - - At the end of the setup phase: ``post_setup`` - - Every time step: - - ``time_step__prepare`` - - ``time_step`` - - ``time_step__cleanup`` - - ``collect_metrics`` - - At simulation end: ``simulation_end`` - - Parameters - ---------- - event_name - The name of the event to listen for. - listener - The callable to be invoked any time an :class:`Event` with the given - name is emitted. - priority - One of {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}. - An indication of the order in which event listeners should be - called. Listeners with smaller priority values will be called - earlier. 
Listeners with the same priority have no guaranteed - ordering. This feature should be avoided if possible. Components - should strive to obey the Markov property as they transform the - state table (the state of the simulation at the beginning of the - next time step should only depend on the current state of the - system). - """ - self._manager.register_listener(event_name, listener, priority) diff --git a/src/vivarium/framework/logging/manager.py b/src/vivarium/framework/logging/manager.py index c23f7ba76..6bcc7b74d 100644 --- a/src/vivarium/framework/logging/manager.py +++ b/src/vivarium/framework/logging/manager.py @@ -1,7 +1,7 @@ """ -=================== -The Logging Manager -=================== +=============== +Logging Manager +=============== """ from __future__ import annotations diff --git a/src/vivarium/framework/lookup/__init__.py b/src/vivarium/framework/lookup/__init__.py index 1a2731713..4896cc055 100644 --- a/src/vivarium/framework/lookup/__init__.py +++ b/src/vivarium/framework/lookup/__init__.py @@ -1,7 +1,4 @@ from vivarium.framework.lookup.interface import LookupTableInterface -from vivarium.framework.lookup.manager import ( - LookupTableManager, - validate_build_table_parameters, -) -from vivarium.framework.lookup.table import LookupTable +from vivarium.framework.lookup.manager import LookupTableManager +from vivarium.framework.lookup.table import DEFAULT_VALUE_COLUMN, LookupTable from vivarium.types import LookupTableData, ScalarValue diff --git a/src/vivarium/framework/lookup/interface.py b/src/vivarium/framework/lookup/interface.py index 7b159761e..2f63c892f 100644 --- a/src/vivarium/framework/lookup/interface.py +++ b/src/vivarium/framework/lookup/interface.py @@ -7,8 +7,11 @@ """ +from __future__ import annotations -from collections.abc import Sequence +from typing import Any, overload + +import pandas as pd from vivarium.framework.lookup.manager import LookupTableManager from vivarium.framework.lookup.table import LookupTable @@ -17,7 
+20,7 @@ class LookupTableInterface(Interface): - """The lookup table management system. + """The interface to the lookup table management system. Simulations tend to require a large quantity of data to run. ``vivarium`` provides the :class:`Lookup Table ` @@ -30,47 +33,63 @@ class LookupTableInterface(Interface): def __init__(self, manager: LookupTableManager): self._manager = manager + @overload + def build_table( + self, + data: LookupTableData, + name: str = "", + value_columns: str | None = None, + ) -> LookupTable[pd.Series[Any]]: + ... + + @overload + def build_table( + self, + data: LookupTableData, + name: str = "", + value_columns: list[str] | tuple[str, ...] = ..., + ) -> LookupTable[pd.DataFrame]: + ... + def build_table( self, data: LookupTableData, - key_columns: Sequence[str] = (), - parameter_columns: Sequence[str] = (), - value_columns: Sequence[str] = (), - ) -> LookupTable: + name: str = "", + value_columns: list[str] | tuple[str, ...] | str | None = None, + ) -> LookupTable[pd.Series[Any]] | LookupTable[pd.DataFrame]: """Construct a LookupTable from input data. - If data is a :class:`pandas.DataFrame`, an interpolation function of - the order specified in the simulation - :term:`configuration ` will be calculated for each - permutation of the set of key_columns. The columns in parameter_columns - will be used as parameters for the interpolation functions which will - estimate all remaining columns in the table. + If the data is a scalar value, this will return a table that when called + will return that scalar value for each index entry. - If data is a number, time, list, or tuple, a scalar table will be - constructed with the values in data as the values in each column of - the table, named according to value_columns. + If the data is a pandas DataFrame columns with names in value_columns + will be returned directly when the table is called with a population index. 
+ The value to return for each index entry will be looked up based on the values + at those indices of other columns of the DataFrame in the simulation population. + Non-value columns which exist as a pair of the form "some_column_start" and + "some_column_end" will be treated as ranges, and the column "some_column" + will be interpolated using order 0 (step function) interpolation over that range. + Other non-value columns will be treated as exact matches for lookups. + If value_columns is a single string, the returned table will return a + :class:`pandas.Series` when called. If value_columns is a list or tuple + of strings, the returned table will return a pandas DataFrame + when called. If value_columns is None, it will return a :class:`pandas.Series` + with the name "value". Parameters ---------- data The source data which will be used to build the resulting :class:`Lookup Table `. - key_columns - Columns used to select between interpolation functions. These - should be the non-continuous variables in the data. For example - 'sex' in data about a population. - parameter_columns - The columns which contain the parameters to the interpolation - functions. These should be the continuous variables. For example - 'age' in data about a population. + name + The name of the table. If not provided, a generic name will be assigned. value_columns - The data columns that will be in the resulting LookupTable. Columns - to be interpolated over if interpolation or the names of the columns - in the scalar table. + The name(s) of the column(s) in the data to return when + the table is called. 
Returns ------- LookupTable """ - return self._manager.build_table(data, key_columns, parameter_columns, value_columns) + return self._manager.build_table(data, name, value_columns) diff --git a/src/vivarium/framework/lookup/interpolation.py b/src/vivarium/framework/lookup/interpolation.py index 076f5891d..4de5a2b2e 100644 --- a/src/vivarium/framework/lookup/interpolation.py +++ b/src/vivarium/framework/lookup/interpolation.py @@ -20,17 +20,17 @@ class Interpolation: Attributes ---------- - data : + data The data from which to build the interpolation. Contains categorical_parameters and continuous_parameters. - categorical_parameters : + categorical_parameters Column names to be used as categorical parameters in Interpolation to select between interpolation functions. - continuous_parameters : + continuous_parameters Column names to be used as continuous parameters in Interpolation. If bin edges, should be of the form (column name used in call, column name for left bin edge, column name for right bin edge). - order : + order Order of interpolation. """ @@ -95,8 +95,8 @@ def __call__(self, interpolants: pd.DataFrame) -> pd.DataFrame: Parameters ---------- - interpolants : - Data frame containing the parameters to interpolate.. + interpolants + Data frame containing the parameters to interpolate. Returns ------- @@ -145,11 +145,6 @@ def validate_parameters( if data.empty: raise ValueError("You must supply non-empty data to create the interpolation.") - if len(continuous_parameters) < 1: - raise ValueError( - "You must supply at least one continuous parameter over which to interpolate." - ) - for p in continuous_parameters: if not isinstance(p, (tuple, list)) or len(p) != 3: raise ValueError( @@ -160,7 +155,6 @@ def validate_parameters( ) # break out the individual columns from binned column name lists - param_cols = [col for p in continuous_parameters for col in p] if not value_columns: raise ValueError( f"No non-parameter data. 
Available columns: {data.columns}, " @@ -343,12 +337,20 @@ def __call__(self, interpolants: pd.DataFrame) -> pd.DataFrame: Parameters ---------- interpolants - Data frame containing the parameters to interpolate.. + Data frame containing the parameters to interpolate. Returns ------- A table with the interpolated values for the given interpolants. """ + if not self.parameter_bins: + # No continuous parameters — just broadcast the data values. + # With only categorical parameters, each sub-table has a single row. + return pd.DataFrame( + {col: self.data[col].iloc[0] for col in self.value_columns}, + index=interpolants.index, + ) + # build a dataframe where we have the start of each parameter bin for each interpolant interpolant_bins = pd.DataFrame(index=interpolants.index) diff --git a/src/vivarium/framework/lookup/manager.py b/src/vivarium/framework/lookup/manager.py index b9c1aab4b..ac4c99859 100644 --- a/src/vivarium/framework/lookup/manager.py +++ b/src/vivarium/framework/lookup/manager.py @@ -1,7 +1,7 @@ """ -============= -Lookup Tables -============= +==================== +Lookup Table Manager +==================== Simulations tend to require a large quantity of data to run. 
:mod:`vivarium` provides the :class:`Lookup Table ` @@ -12,24 +12,22 @@ """ -from collections.abc import Mapping, Sequence -from datetime import datetime, timedelta -from typing import TYPE_CHECKING -from typing import SupportsFloat as Numeric +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, overload import pandas as pd +from layered_config_tree import LayeredConfigTree +from vivarium.framework.event import Event from vivarium.framework.lifecycle import lifecycle_states -from vivarium.framework.lookup.table import ( - CategoricalTable, - InterpolatedTable, - LookupTable, - ScalarTable, -) +from vivarium.framework.lookup.table import DEFAULT_VALUE_COLUMN, LookupTable from vivarium.manager import Manager from vivarium.types import LookupTableData if TYPE_CHECKING: + from vivarium import Component from vivarium.framework.engine import Builder @@ -51,30 +49,74 @@ class LookupTableManager(Manager): def name(self) -> str: return "lookup_table_manager" - def setup(self, builder: "Builder") -> None: - self.tables: dict[int, LookupTable] = {} - self._pop_view_builder = builder.population.get_view + def __init__(self) -> None: + super().__init__() + self.tables: dict[str, LookupTable[pd.Series[Any]] | LookupTable[pd.DataFrame]] = {} + + def setup(self, builder: Builder) -> None: + self._logger = builder.logging.get_logger(self.name) + self._configuration = builder.configuration + self._get_view = builder.population.get_view self.clock = builder.time.clock() - self._interpolation_order = builder.configuration.interpolation.order - self._extrapolate = builder.configuration.interpolation.extrapolate - self._validate = builder.configuration.interpolation.validate + self.interpolation_order = builder.configuration.interpolation.order + self.extrapolate = builder.configuration.interpolation.extrapolate + self.validate_interpolation = builder.configuration.interpolation.validate + self._add_resources = 
builder.resources.add_resources self._add_constraint = builder.lifecycle.add_constraint + self._get_current_component = builder.components.get_current_component builder.lifecycle.add_constraint( self.build_table, allow_during=[lifecycle_states.SETUP] ) + builder.event.register_listener(lifecycle_states.POST_SETUP, self.on_post_setup) + + def on_post_setup(self, event: Event) -> None: + configured_lookup_tables: dict[str, list[str]] = {} + for config_key, config in self._configuration.items(): + if isinstance(config, LayeredConfigTree) and "data_sources" in config: + configured_lookup_tables[config_key] = list( + config.get_tree("data_sources").keys() + ) + + for component_name, table_names in configured_lookup_tables.items(): + for table_name in table_names: + full_table_name = LookupTable.get_name(component_name, table_name) + if full_table_name not in self.tables: + self._logger.warning( + f"Component '{component_name}' configured, but didn't build lookup" + f" table '{table_name}' during setup." + ) + + @overload + def build_table( + self, + data: LookupTableData, + name: str, + value_columns: str | None, + ) -> LookupTable[pd.Series[Any]]: + ... + + @overload + def build_table( + self, + data: LookupTableData, + name: str, + value_columns: list[str] | tuple[str, ...], + ) -> LookupTable[pd.DataFrame]: + ... def build_table( self, data: LookupTableData, - key_columns: Sequence[str], - parameter_columns: Sequence[str], - value_columns: Sequence[str], - ) -> LookupTable: + name: str, + value_columns: list[str] | tuple[str, ...] 
| str | None, + ) -> LookupTable[pd.Series[Any]] | LookupTable[pd.DataFrame]: """Construct a lookup table from input data.""" - table = self._build_table(data, key_columns, parameter_columns, value_columns) + component = self._get_current_component() + table = self._build_table(component, data, name, value_columns) + self._add_resources(component, table, table.required_resources) self._add_constraint( - table.call, + table._call, restrict_during=[ lifecycle_states.INITIALIZATION, lifecycle_states.SETUP, @@ -85,130 +127,33 @@ def build_table( def _build_table( self, + component: Component, data: LookupTableData, - key_columns: Sequence[str], - parameter_columns: Sequence[str], - value_columns: Sequence[str], - ) -> LookupTable: + name: str, + value_columns: list[str] | tuple[str, ...] | str | None, + ) -> LookupTable[pd.Series[Any]] | LookupTable[pd.DataFrame]: # We don't want to require explicit names for tables, but giving them # generic names is useful for introspection. - table_number = len(self.tables) + if not name: + name = f"lookup_table_{len(self.tables)}" if isinstance(data, Mapping): data = pd.DataFrame(data) - if self._validate: - validate_build_table_parameters( - data, key_columns, parameter_columns, value_columns - ) - - # Note datetime catches pandas timestamps - if isinstance(data, (Numeric, datetime, timedelta, list, tuple)): - table: LookupTable = ScalarTable( - table_number=table_number, - data=data, - key_columns=key_columns, - parameter_columns=parameter_columns, - value_columns=value_columns, - validate=self._validate, - ) - elif parameter_columns: - table = InterpolatedTable( - table_number=table_number, - data=data, - population_view_builder=self._pop_view_builder, - key_columns=key_columns, - parameter_columns=parameter_columns, - value_columns=value_columns, - interpolation_order=self._interpolation_order, - clock=self.clock, - extrapolate=self._extrapolate, - validate=self._validate, - ) - else: - table = CategoricalTable( - 
table_number=table_number, - data=data, - population_view_builder=self._pop_view_builder, - key_columns=key_columns, - value_columns=value_columns, - ) - - self.tables[table_number] = table - return table + value_columns_ = value_columns if value_columns else DEFAULT_VALUE_COLUMN - def __repr__(self) -> str: - return "LookupTableManager()" + table = LookupTable( + name=name, + component=component, + data=data, + value_columns=value_columns_, + manager=self, + population_view=self._get_view(), + ) + self.tables[table.name] = table -def validate_build_table_parameters( - data: LookupTableData, - key_columns: Sequence[str], - parameter_columns: Sequence[str], - value_columns: Sequence[str], -) -> None: - """Makes sure the data format agrees with the provided column layout.""" - if ( - data is None - or (isinstance(data, pd.DataFrame) and data.empty) - or (isinstance(data, (list, tuple)) and not data) - ): - raise ValueError("Must supply some data") - - acceptable_types = (Numeric, datetime, timedelta, list, tuple, pd.DataFrame) - if not isinstance(data, acceptable_types): - raise TypeError( - f"The only allowable types for data are {acceptable_types}. " - f"You passed {type(data)}." - ) + return table - if isinstance(data, (list, tuple)): - if not value_columns: - raise ValueError( - "To invoke scalar view with multiple values, you must supply value_columns" - ) - if len(value_columns) != len(data): - raise ValueError( - "The number of value columns must match the number of values." - f"You supplied values: {data} and value_columns: {value_columns}" - ) - if key_columns: - raise ValueError( - f"key_columns are not allowed for scalar view: Provided {key_columns}." - ) - if parameter_columns: - raise ValueError( - "parameter_columns are not allowed for scalar view: " - f"Provided {parameter_columns}." 
- ) - - if isinstance(data, pd.DataFrame): - if not key_columns and not parameter_columns: - raise ValueError( - "Must supply either key_columns or parameter_columns with a DataFrame." - ) - - bin_edge_columns = [] - for p in parameter_columns: - bin_edge_columns.extend([f"{p}_start", f"{p}_end"]) - all_parameter_columns = set(parameter_columns) | set(bin_edge_columns) - - if set(key_columns).intersection(all_parameter_columns): - raise ValueError( - f"There should be no overlap between key columns: {key_columns} " - f"and parameter columns: {parameter_columns}." - ) - - lookup_columns = set(key_columns) | all_parameter_columns - if set(value_columns).intersection(lookup_columns): - raise ValueError( - f"There should be no overlap between value columns: {value_columns} " - f"and key or parameter columns: {lookup_columns}." - ) - - specified_columns = set(key_columns) | set(bin_edge_columns) | set(value_columns) - if specified_columns.difference(data.columns): - raise ValueError( - f"The columns supplied: {specified_columns} must all be " - f"present in the passed data: {data.columns}" - ) + def __repr__(self) -> str: + return "LookupTableManager()" diff --git a/src/vivarium/framework/lookup/table.py b/src/vivarium/framework/lookup/table.py index a7013aed1..a7dafde05 100644 --- a/src/vivarium/framework/lookup/table.py +++ b/src/vivarium/framework/lookup/table.py @@ -13,20 +13,29 @@ from __future__ import annotations -from abc import ABC, abstractmethod -from collections.abc import Callable, Sequence -from datetime import datetime -from typing import Any +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any, Generic +from typing import SupportsFloat as Numeric +from typing import TypeVar -import numpy as np import pandas as pd +from vivarium.component import Component from vivarium.framework.lookup.interpolation import Interpolation from vivarium.framework.population.population_view import PopulationView -from vivarium.types import 
ClockTime, ScalarValue +from vivarium.framework.resource import Resource +from vivarium.types import LookupTableData +if TYPE_CHECKING: + from vivarium.framework.lookup.manager import LookupTableManager -class LookupTable(ABC): +T = TypeVar("T", pd.Series, pd.DataFrame) # type: ignore [type-arg] + + +DEFAULT_VALUE_COLUMN = "value" + + +class LookupTable(Resource, Generic[T]): """A callable to produces values for a population index. In :mod:`vivarium` simulations, the index is synonymous with the simulated @@ -37,37 +46,80 @@ class LookupTable(ABC): Notes ----- - These should not be created directly. Use the `lookup` method on the builder - during setup. + These should not be created directly. Use the :attr:`~vivarium.framework.engine.Builder.lookup` + attribute on the :class:`~vivarium.framework.engine.Builder` class during setup. """ def __init__( self, - table_number: int, - key_columns: Sequence[str] = (), - parameter_columns: Sequence[str] = (), - value_columns: Sequence[str] = (), - validate: bool = True, + component: Component, + data: LookupTableData, + name: str, + value_columns: list[str] | tuple[str, ...] 
| str, + manager: LookupTableManager, + population_view: PopulationView, ): - self.table_number = table_number - """Unique identifier of the table.""" - self.key_columns = key_columns + super().__init__("lookup_table", self.get_name(component.name, name), component) + self._validate_data_inputs(data, value_columns) + + self.data: LookupTableData = data + """The data this table will use to produce values.""" + self.return_type: type[T] = ( + pd.Series if isinstance(value_columns, str) else pd.DataFrame + ) + """The type of data returned by the lookup table (pd.Series or pd.DataFrame).""" + self.key_columns: list[str] = [] """Column names to be used as categorical parameters in Interpolation to select between interpolation functions.""" - self.parameter_columns = parameter_columns + self.parameter_columns: list[str] = [] """Column names to be used as continuous parameters in Interpolation.""" - self.value_columns = list(value_columns) - """Names of value columns to be interpolated over.""" - self.validate = validate - """Whether to validate the data before building the LookupTable.""" + self.value_columns: list[str] = ( + list(value_columns) if not isinstance(value_columns, str) else [value_columns] + ) + """Names of value columns that will be returned by the lookup table.""" + self._manager: LookupTableManager = manager + """The manager that created this lookup table.""" + self.population_view: PopulationView = population_view + """PopulationView to use to get attributes for interpolation or categorization.""" + self.interpolation: Interpolation | None = None + """Interpolation object to use when data is a DataFrame. 
Will be None if data is + a scalar or list of scalars.""" + + if isinstance(data, pd.DataFrame): + self.parameter_columns, self.key_columns = self._get_columns( + self.value_columns, data + ) + parameter_columns_with_edges: list[tuple[str, str, str]] = [ + (p, f"{p}_start", f"{p}_end") for p in self.parameter_columns + ] + required_cols = { + *self.key_columns, + *{col for p in parameter_columns_with_edges for col in p}, + *self.value_columns, + } + if extra_columns := list(data.columns.difference(list(required_cols))): + raise ValueError( + f"Data contains extra columns not in " + f"key_columns, parameter_columns, or value_columns: {extra_columns}" + ) + + self.interpolation = Interpolation( + data, + self.key_columns, + parameter_columns_with_edges, + self.value_columns, + order=self._manager.interpolation_order, + extrapolate=self._manager.extrapolate, + validate=self._manager.validate_interpolation, + ) @property - def name(self) -> str: - """Tables are generically named after the order they were created.""" - return f"lookup_table_{self.table_number}" + def required_resources(self) -> list[str]: + lookup_columns = list(self.key_columns) + list(self.parameter_columns) + return [col for col in lookup_columns if col != "year"] - def __call__(self, index: pd.Index[int]) -> pd.Series[Any] | pd.DataFrame: + def __call__(self, index: pd.Index[int]) -> T: """Get the mapped values for the given index. 
Parameters @@ -81,241 +133,131 @@ def __call__(self, index: pd.Index[int]) -> pd.Series[Any] | pd.DataFrame: columns """ - mapped_values: pd.Series[Any] | pd.DataFrame = self.call(index).squeeze(axis=1) + mapped_values = self._call(index).squeeze(axis=1) + if not isinstance(mapped_values, self.return_type): + raise TypeError( + f"LookupTable expected to return {self.return_type}, " + f"but got {type(mapped_values)}" + ) return mapped_values - @abstractmethod - def call(self, index: pd.Index[int]) -> pd.DataFrame: + def _call(self, index: pd.Index[int]) -> pd.DataFrame: """Private method to allow LookupManager to add constraints.""" - pass - - def __repr__(self) -> str: - return "LookupTable()" - - -class InterpolatedTable(LookupTable): - """A callable that interpolates data according to a given strategy. - - Notes - ----- - These should not be created directly. Use the `lookup` interface on the - :class:`builder ` during setup. - - """ - - def __init__( - self, - table_number: int, - data: pd.DataFrame, - population_view_builder: Callable[[list[str]], PopulationView], - key_columns: Sequence[str], - parameter_columns: Sequence[str], - value_columns: Sequence[str], - interpolation_order: int, - clock: Callable[[], ClockTime], - extrapolate: bool, - validate: bool, - ): - super().__init__( - table_number=table_number, - key_columns=key_columns, - parameter_columns=parameter_columns, - value_columns=value_columns, - validate=validate, - ) - self.data = data - self.clock = clock - self.interpolation_order = interpolation_order - self.extrapolate = extrapolate - """Callable for current time in simulation.""" - param_cols_with_edges = [] - for p in parameter_columns: - param_cols_with_edges += [(p, f"{p}_start", f"{p}_end")] - view_columns = sorted((set(key_columns) | set(parameter_columns)) - {"year"}) + [ - "tracked" - ] - - self.parameter_columns_with_edges = param_cols_with_edges - - required_cols = ( - set(self.key_columns) - | set([col for p in 
self.parameter_columns_with_edges for col in p]) - | set(self.value_columns) - ) - extra_columns = self.data.columns.difference(list(required_cols)) - - if not self.value_columns: - self.value_columns = list(extra_columns) - else: - self.data = self.data.drop(columns=extra_columns) - - self.population_view = population_view_builder(view_columns) - self.interpolation = Interpolation( - data, - self.key_columns, - self.parameter_columns_with_edges, - self.value_columns, - order=self.interpolation_order, - extrapolate=self.extrapolate, - validate=self.validate, - ) - - def call(self, index: pd.Index[int]) -> pd.DataFrame: - """Get the interpolated values for the rows in ``index``. - - Parameters - ---------- - index - Index of the population to interpolate for. - - Returns - ------- - A table with the interpolated values for the population requested. - - """ - pop = self.population_view.get(index) - del pop["tracked"] - if "year" in [col for col in self.parameter_columns]: - current_time = self.clock() - # TODO: [MIC-5478] handle Number output from clock - if isinstance(current_time, pd.Timestamp) or isinstance(current_time, datetime): - fractional_year = float(current_time.year) - fractional_year += current_time.timetuple().tm_yday / 365.25 - pop["year"] = fractional_year - else: - raise ValueError( - "You cannot use the column 'year' in a simulation unless your simulation uses a DateTimeClock." + if self.interpolation is None: + # Broadcast scalar or list of scalars to the index. + if not isinstance(self.data, (list, tuple)): + values_series: pd.Series[Any] = pd.Series( + self.data, index=index, name=self.value_columns[0] ) - - return self.interpolation(pop) - - def __repr__(self) -> str: - return "InterpolatedTable()" - - -class CategoricalTable(LookupTable): - """ - A callable that selects values from a table based on categorical parameters - across an index. - - Notes - ----- - These should not be created directly. 
Use the `lookup` interface on the - :class:`builder ` during setup. - - """ - - def __init__( - self, - table_number: int, - data: pd.DataFrame, - population_view_builder: Callable[[list[str]], PopulationView], - key_columns: Sequence[str], - value_columns: Sequence[str], - ): - super().__init__( - table_number=table_number, - key_columns=key_columns, - value_columns=value_columns, - ) - self.data = data - self.population_view = population_view_builder(list(self.key_columns) + ["tracked"]) - - extra_columns = self.data.columns.difference( - list(set(self.key_columns) | set(self.value_columns)) - ) - - if not self.value_columns: - self.value_columns = list(extra_columns) + return pd.DataFrame(values_series) + else: + values_list: list[pd.Series[Any]] = [ + pd.Series(v, index=index) for v in self.data + ] + return pd.DataFrame(dict(zip(self.value_columns, values_list))) else: - self.data = self.data.drop(columns=extra_columns) - - def call(self, index: pd.Index[int]) -> pd.DataFrame: - """Get the mapped values for the rows in ``index``. - - Parameters - ---------- - index - Index of the population to interpolate for. - - Returns - ------- - A table with the mapped values for the population requested. - """ - pop = self.population_view.get(index) - del pop["tracked"] - - # specify some numeric type for columns, so they won't be objects but - # will be updated with whatever column type it actually is - result = pd.DataFrame(index=pop.index, columns=self.value_columns, dtype=np.float64) - - sub_tables = pop.groupby(list(self.key_columns)) - for key, sub_table in list(sub_tables): - if sub_table.empty: - continue - - category_masks: list[pd.Series[bool]] = [ - self.data[self.key_columns[i]] == category for i, category in enumerate(key) + # Interpolate continuous parameters and categorize categorical parameters based on + # the population attributes. 
+ requested_columns = [ + col + for col in list(self.key_columns) + list(self.parameter_columns) + if col != "year" ] - joint_mask = pd.Series(True, index=self.data.index) - for category_mask in category_masks: - joint_mask = joint_mask & category_mask - values = self.data.loc[joint_mask, self.value_columns].values - result.loc[sub_table.index, self.value_columns] = values - - return result + pop = pd.DataFrame(self.population_view.get_attributes(index, requested_columns)) + if "year" in self.parameter_columns: + current_time = self._manager.clock() + if isinstance(current_time, pd.Timestamp) or isinstance( + current_time, datetime + ): + fractional_year = float(current_time.year) + fractional_year += current_time.timetuple().tm_yday / 365.25 + pop["year"] = fractional_year + else: + raise ValueError( + "You cannot use the column 'year' in a simulation unless " + "your simulation uses a DateTimeClock." + ) + return self.interpolation(pop) def __repr__(self) -> str: - return "CategoricalTable()" - - -class ScalarTable(LookupTable): - """A callable that broadcasts a scalar or list of scalars over an index. - - Notes - ----- - These should not be created directly. Use the `lookup` interface on the - builder during setup. - """ - - def __init__( - self, - table_number: int, - data: ScalarValue | list[ScalarValue] | tuple[ScalarValue, ...], - key_columns: Sequence[str] = (), - parameter_columns: Sequence[str] = (), - value_columns: Sequence[str] = (), - validate: bool = True, - ): - super().__init__( - table_number, key_columns, parameter_columns, value_columns, validate - ) - self.data = data + return "LookupTable()" - def call(self, index: pd.Index[int]) -> pd.DataFrame: - """Broadcast this tables values over the provided index. + @staticmethod + def get_name(component_name: str, table_name: str) -> str: + """Get the fully qualified name for a lookup table. Parameters ---------- - index - Index of the population to construct table for. 
+ component_name + Name of the component the lookup table belongs to. + table_name + Name of the lookup table. Returns ------- - A table with a column for each of the scalar values for the - population requested. + Fully qualified name for the lookup table. """ - if not isinstance(self.data, (list, tuple)): - values_series: pd.Series[Any] = pd.Series( - self.data, - index=index, - name=self.value_columns[0] if self.value_columns else None, + return f"{component_name}.{table_name}" + + @staticmethod + def _get_columns( + value_columns: list[str], data: pd.DataFrame + ) -> tuple[list[str], list[str]]: + all_columns = list(data.columns) + + potential_parameter_columns = [ + str(col).removesuffix("_start") + for col in all_columns + if str(col).endswith("_start") + ] + parameter_columns = [] + bin_edge_columns = [] + for column in potential_parameter_columns: + if f"{column}_end" in all_columns: + parameter_columns.append(column) + bin_edge_columns += [f"{column}_start", f"{column}_end"] + + key_columns = [ + col + for col in all_columns + if col not in value_columns and col not in bin_edge_columns + ] + + return parameter_columns, key_columns + + @staticmethod + def _validate_data_inputs( + data: LookupTableData, + value_columns: list[str] | tuple[str, ...] | str, + ) -> None: + """Makes sure the data format agrees with the provided column layout.""" + if ( + data is None + or (isinstance(data, pd.DataFrame) and data.empty) + or (isinstance(data, (list, tuple)) and not data) + ): + raise ValueError("Must supply some data") + + acceptable_types = (Numeric, datetime, timedelta, list, tuple, pd.DataFrame) + if not isinstance(data, acceptable_types): + raise TypeError( + f"The only allowable types for data are {acceptable_types}. " + f"You passed {type(data)}." 
) - return pd.DataFrame(values_series) - else: - values_list: list[pd.Series[Any]] = [pd.Series(v, index=index) for v in self.data] - return pd.DataFrame(dict(zip(self.value_columns, values_list))) - def __repr__(self) -> str: - return "ScalarTable(value(s)={})".format(self.data) + if isinstance(data, (list, tuple)): + if isinstance(value_columns, str): + raise ValueError( + "When supplying multiple values, value_columns must be a list or tuple of strings." + ) + if len(value_columns) != len(data): + raise ValueError( + "The number of value columns must match the number of values." + f"You supplied values: {data} and value_columns: {value_columns}" + ) + elif not isinstance(data, pd.DataFrame): + if not isinstance(value_columns, str): + raise ValueError( + "When supplying a single value, value_columns must be a string if provided." + ) diff --git a/src/vivarium/framework/plugins.py b/src/vivarium/framework/plugins.py index 614d48f8c..fdf72045e 100644 --- a/src/vivarium/framework/plugins.py +++ b/src/vivarium/framework/plugins.py @@ -16,19 +16,22 @@ from vivarium.exceptions import VivariumError from vivarium.framework.artifact import ArtifactInterface, ArtifactManager -from vivarium.framework.components import ComponentInterface, ComponentManager -from vivarium.framework.components.parser import ComponentConfigurationParser +from vivarium.framework.components import ( + ComponentConfigurationParser, + ComponentInterface, + ComponentManager, +) from vivarium.framework.event import EventInterface, EventManager from vivarium.framework.lifecycle import LifeCycleInterface, LifeCycleManager from vivarium.framework.logging import LoggingInterface, LoggingManager from vivarium.framework.lookup import LookupTableInterface, LookupTableManager -from vivarium.framework.population.manager import PopulationInterface, PopulationManager +from vivarium.framework.population import PopulationInterface, PopulationManager from vivarium.framework.randomness import RandomnessInterface, 
RandomnessManager from vivarium.framework.resource import ResourceInterface, ResourceManager from vivarium.framework.results import ResultsInterface, ResultsManager from vivarium.framework.time import SimulationClock, TimeInterface from vivarium.framework.utilities import import_by_path -from vivarium.framework.values.manager import ValuesInterface, ValuesManager +from vivarium.framework.values import ValuesInterface, ValuesManager from vivarium.manager import Interface, Manager I = TypeVar("I", bound=Interface) diff --git a/src/vivarium/framework/population/__init__.py b/src/vivarium/framework/population/__init__.py index 9a96d14eb..6d23247e3 100644 --- a/src/vivarium/framework/population/__init__.py +++ b/src/vivarium/framework/population/__init__.py @@ -3,7 +3,7 @@ The Population Management System ================================ -This subpackage provides tools for managing the :term:`state table ` +This subpackage provides tools for managing the :term:`population state table ` in a :mod:`vivarium` simulation, which is the record of all simulants in a simulation and their state. 
Its main tasks are managing the creation of new simulants and providing the ability for components to view and update simulant @@ -12,9 +12,6 @@ """ from vivarium.framework.population.exceptions import PopulationError -from vivarium.framework.population.manager import ( - PopulationInterface, - PopulationManager, - SimulantData, -) +from vivarium.framework.population.interface import PopulationInterface +from vivarium.framework.population.manager import PopulationManager, SimulantData from vivarium.framework.population.population_view import PopulationView diff --git a/src/vivarium/framework/population/exceptions.py b/src/vivarium/framework/population/exceptions.py index db15fea55..f2afe3276 100644 --- a/src/vivarium/framework/population/exceptions.py +++ b/src/vivarium/framework/population/exceptions.py @@ -3,7 +3,7 @@ Population Management Exceptions ================================ -Errors related to the mishandling of the simulation state table. +Errors related to the mishandling of the population state table. """ from vivarium.exceptions import VivariumError diff --git a/src/vivarium/framework/population/interface.py b/src/vivarium/framework/population/interface.py new file mode 100644 index 000000000..909979058 --- /dev/null +++ b/src/vivarium/framework/population/interface.py @@ -0,0 +1,120 @@ +""" +==================== +Population Interface +==================== + +This module provides a :class:`PopulationInterface ` class with +methods to initialize simulants and get a population view. 
+ +""" + +from __future__ import annotations + +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any + +import pandas as pd + +from vivarium.framework.population.population_view import PopulationView +from vivarium.framework.resource import Resource +from vivarium.manager import Interface + +if TYPE_CHECKING: + from vivarium import Component + from vivarium.framework.population import SimulantData + from vivarium.framework.population.manager import PopulationManager + + +class PopulationInterface(Interface): + """Provides access to the system for reading and updating the population. + + The most important aspect of the simulation state is the ``population state table`` + (or simply ``state table``). It is a table with a row for every individual or + cohort (referred to as a simulant) being simulated and a column for each of + the attributes of the simulant being modeled. All access to the state table + is mediated by :class:`population views `, + which may be requested from this system during setup time. + + """ + + def __init__(self, manager: PopulationManager): + self._manager = manager + + def get_view(self, component: Component | None = None) -> PopulationView: + """Gets a time-varying view of the population state table. + + The requested population view can be used to view the current state or + to update the state with new values. + + Parameters + ---------- + component + The component requesting this view. If None, the view will provide + read-only access. + + Returns + ------- + A view of the requested columns of the population state table. + """ + return self._manager.get_view(component) + + def get_simulant_creator(self) -> Callable[[int, dict[str, Any] | None], pd.Index[int]]: + """Gets a function that can generate new simulants. 
+ + The creator function takes the number of simulants to be created as its + first argument and a population configuration dict that will be available + to simulant initializers as its second argument. It generates the new rows + in the population state table and then calls each initializer registered + with the population system with a data object containing the state table + index of the new simulants, the configuration info passed to the creator, + the current simulation time, and the size of the next time step. + + Returns + ------- + The simulant creator function. + """ + return self._manager.get_simulant_creator() + + def register_initializer( + self, + initializer: Callable[[SimulantData], None], + columns: str | Sequence[str] | None, + required_resources: Sequence[str | Resource] = (), + ) -> None: + """Registers a component's initializers and any (private) columns created by them. + + This does three primary things: + 1. Registers each private column's corresponding attribute producer. + 2. Records metadata about which component created which private columns. + 3. Registers the initializer as a resource. + + A `columns` value of None indicates that no private columns are being registered. + This is useful when a component or manager needs to register an initializer + that does not create any private columns. + + Parameters + ---------- + initializer + A function that will be called to initialize the state of new simulants. + columns + The private columns that the given initializer provides the initial state + information for. + required_resources + The resources that the initializer requires to run. Strings are interpreted + as attributes. + """ + self._manager.register_initializer(initializer, columns, required_resources) + + def register_tracked_query(self, query: str) -> None: + """Adds a new query to the population manager's tracked query string. 
+ + Parameters + ---------- + query + The new query to be added to the population manager's tracked query string. + """ + self._manager.register_tracked_query(query) + + def get_tracked_query(self) -> Callable[[], str]: + """Gets a callable that returns the combined tracked query for the population.""" + return self._manager.get_tracked_query diff --git a/src/vivarium/framework/population/manager.py b/src/vivarium/framework/population/manager.py index 23cba8ced..8fe5168b8 100644 --- a/src/vivarium/framework/population/manager.py +++ b/src/vivarium/framework/population/manager.py @@ -1,134 +1,55 @@ """ -====================== -The Population Manager -====================== - -The manager and :ref:`builder ` interface for the -:ref:`population management system `. +================== +Population Manager +================== """ from __future__ import annotations -import warnings -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass -from types import MethodType -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, overload import pandas as pd +import vivarium.framework.population.utilities as pop_utils +from vivarium.component import Component +from vivarium.framework.event import Event from vivarium.framework.lifecycle import lifecycle_states from vivarium.framework.population.exceptions import PopulationError from vivarium.framework.population.population_view import PopulationView from vivarium.framework.resource import Resource -from vivarium.manager import Interface, Manager +from vivarium.manager import Manager if TYPE_CHECKING: - from vivarium import Component from vivarium.framework.engine import Builder from vivarium.types import ClockStepSize, ClockTime +from collections import defaultdict + @dataclass class SimulantData: """Data to help components initialize simulants. 
Any time simulants are added to the simulation, each initializer is called - with this structure containing information relevant to their - initialization. + with this structure containing information relevant to their initialization. """ - #: The index representing the new simulants being added to the simulation. index: pd.Index[int] - #: A dictionary of extra data passed in by the component creating the - #: population. + """The index representing the new simulants being added to the simulation.""" user_data: dict[str, Any] - #: The time when the simulants enter the simulation. + """A dictionary of extra data passed in by the component creating the population.""" creation_time: ClockTime - #: The span of time over which the simulants are created. Useful for, - #: e.g., distributing ages over the window. + """The time when the simulants enter the simulation.""" creation_window: ClockStepSize - - -class InitializerComponentSet: - """Set of unique components with population initializers.""" - - def __init__(self) -> None: - self._components: dict[str, list[str]] = {} - self._columns_produced: dict[str, str] = {} - - def add( - self, initializer: Callable[[SimulantData], None], columns_produced: Sequence[str] - ) -> None: - """Adds an initializer and columns to the set, enforcing uniqueness. - - Parameters - ---------- - initializer - The population initializer to add to the set. - columns_produced - The columns the initializer produces. - - Raises - ------ - TypeError - If the initializer is not an object method. - AttributeError - If the object bound to the method does not have a name attribute. - PopulationError - If the component bound to the method already has an initializer - registered or if the columns produced are duplicates of columns - another initializer produces. - """ - if not isinstance(initializer, MethodType): - raise TypeError( - "Population initializers must be methods of vivarium Components " - "or the simulation's PopulationManager. 
" - f"You provided {initializer} which is of type {type(initializer)}." - ) - component = initializer.__self__ - # TODO: raise error once all active Component implementations have been refactored - # if not (isinstance(component, Component) or isinstance(component, PopulationManager)): - # raise AttributeError( - # "Population initializers must be methods of vivarium Components " - # "or the simulation's PopulationManager. " - # f"You provided {initializer} which is bound to {component} that " - # f"is of type {type(component)} which does not inherit from " - # "Component." - # ) - if not hasattr(component, "name"): - raise AttributeError( - "Population initializers must be methods of named simulation components. " - f"You provided {initializer} which is bound to {component} that has no " - f"name attribute." - ) - - component_name = component.name - if component_name in self._components: - raise PopulationError( - f"Component {component_name} has multiple population initializers. " - "This is not allowed." - ) - for column in columns_produced: - if column in self._columns_produced: - raise PopulationError( - f"Component {component_name} and component " - f"{self._columns_produced[column]} have both registered initializers " - f"for column {column}." - ) - self._columns_produced[column] = component_name - self._components[component_name] = list(columns_produced) - - def __repr__(self) -> str: - return repr(self._components) - - def __str__(self) -> str: - return str(self._components) + """The span of time over which the simulants are created. Useful for, e.g., distributing + ages over the window.""" class PopulationManager(Manager): - """Manages the state of the simulated population.""" + """Manages the population state table.""" # TODO: Move the configuration for initial population creation to # user components. 
@@ -139,38 +60,52 @@ class PopulationManager(Manager): } @property - def population(self) -> pd.DataFrame: - """The current population state table.""" - if self._population is None: - raise PopulationError("Population has not been initialized.") - return self._population + def name(self) -> str: + """The name of this component.""" + return "population_manager" - def __init__(self) -> None: - self._population: pd.DataFrame | None = None - self._initializer_components = InitializerComponentSet() - self.creating_initial_population = False - self.adding_simulants = False - self._last_id = -1 + @property + def private_columns(self) -> pd.DataFrame: + """The dataframe of all population private columns. + + Notes + ----- + Critically, the private columns dataframe not only contains all private + columns created for the simulation, but also serves as the simulant + index for the entire population. Even if no private columns are created, + this dataframe will exist and all simulants will be represented by its index. 
+ """ + if self._private_columns is None: + raise PopulationError("Population has not been initialized.") + return self._private_columns ############################ # Normal Component Methods # ############################ - @property - def name(self) -> str: - """The name of this component.""" - return "population_manager" - - @property - def columns_created(self) -> list[str]: - return ["tracked"] + def __init__(self) -> None: + self._private_columns: pd.DataFrame | None = None + self._private_column_metadata: defaultdict[str, list[str]] = defaultdict(list) + self._registered_initializers: list[Callable[[SimulantData], None]] = [] + self.creating_initial_population = False + self.adding_simulants = False + self._last_id = -1 + self.tracked_queries: list[str] = [] def setup(self, builder: Builder) -> None: """Registers the population manager with other vivarium systems.""" + super().setup(builder) + self.logger = builder.logging.get_logger(self.name) self.clock = builder.time.clock() self.step_size = builder.time.step_size() self.resources = builder.resources self._add_constraint = builder.lifecycle.add_constraint + self._get_attribute_pipelines = builder.value.get_attribute_pipelines() + self._register_attribute_producer = builder.value.register_attribute_producer + self._get_current_component_or_manager = ( + builder.components.get_current_component_or_manager + ) + self.get_current_state = builder.lifecycle.current_state() builder.lifecycle.add_constraint( self.get_view, @@ -186,16 +121,21 @@ def setup(self, builder: Builder) -> None: self.get_simulant_creator, allow_during=[lifecycle_states.SETUP] ) builder.lifecycle.add_constraint( - self.register_simulant_initializer, allow_during=[lifecycle_states.SETUP] + self.register_initializer, allow_during=[lifecycle_states.SETUP] + ) + self._add_constraint( + self.get_population, + restrict_during=[ + lifecycle_states.SETUP, + lifecycle_states.POST_SETUP, + ], ) - self.register_simulant_initializer(self, 
creates_columns=self.columns_created) - self._view = self.get_view("tracked") + builder.event.register_listener(lifecycle_states.POST_SETUP, self.on_post_setup) - def on_initialize_simulants(self, pop_data: SimulantData) -> None: - """Adds a ``tracked`` column to the state table for new simulants.""" - status = pd.Series(True, index=pop_data.index) - self._view.update(status) + def on_post_setup(self, event: Event) -> None: + # All pipelines are registered during setup and so exist at this point. + self._attribute_pipelines = self._get_attribute_pipelines() def __repr__(self) -> str: return "PopulationManager()" @@ -204,56 +144,166 @@ def __repr__(self) -> str: # Builder API and helpers # ########################### - def get_view( + def register_tracked_query(self, query: str) -> None: + """Updates list of registered tracked queries with the provided query. + + Parameters + ---------- + query + The new query to add to the running list of tracked queries. + + Notes + ----- + While we log a warning if the same query is registered multiple times, + we make no attempt to de-duplicate functionally-equivalent queries that + are syntactically different, e.g. "x > 5" and "5 < x". In such cases, + duplicate queries will be applied which is not optimal but will not + affect correctness. + """ + if query in self.tracked_queries: + self.logger.warning( + f"The tracked query '{query}' has already been registered. " + "Duplicate registrations are ignored." + ) + return + self.tracked_queries.append(query) + + def get_private_column_names(self, component_name: str) -> list[str]: + """Gets the names of private columns created by a given component. + + Parameters + ---------- + component_name + The name of the component whose private column names are to be retrieved. + + Returns + ------- + The list of private column names created by the specified component. + If the component has not created any private columns, an empty list is returned. 
+ """ + return self._private_column_metadata[component_name] + + @overload + def get_private_columns( self, - columns: str | Sequence[str], - query: str = "", - requires_all_columns: bool = False, - ) -> PopulationView: - """Get a time-varying view of the population state table. + component: Component | Manager, + index: pd.Index[int] | None = None, + columns: str = ..., + ) -> pd.Series[Any]: + ... - The requested population view can be used to view the current state or - to update the state with new values. + @overload + def get_private_columns( + self, + component: Component | Manager, + index: pd.Index[int] | None = None, + columns: list[str] | tuple[str, ...] = ..., + ) -> pd.DataFrame: + ... - If the column 'tracked' is not specified in the ``columns`` argument, - the query string 'tracked == True' will be added to the provided - query argument. This allows components to ignore untracked simulants - by default. If the columns argument is empty, the population view will - have access to the entire state table. + @overload + def get_private_columns( + self, + component: Component | Manager, + index: pd.Index[int] | None = None, + columns: None = None, + ) -> pd.Series[Any] | pd.DataFrame: + ... + + def get_private_columns( + self, + component: Component | Manager, + index: pd.Index[int] | None = None, + columns: str | list[str] | tuple[str, ...] | None = None, + ) -> pd.DataFrame | pd.Series[Any]: + """Gets the private columns for a given component. + + While the ``private_columns`` property provides a dataframe of all private + columns in population, this method returns only the private columns created + by the specified component. If no component is specified, then no columns + are returned. Parameters ---------- + component + The component whose private columns are to be retrieved. If None, + no columns are returned. + index + The index of simulants to include in the returned dataframe. If None, + all simulants are included. 
columns - A subset of the state table columns that will be available in the - returned view. If requires_all_columns is True, this should be set to - the columns created by the component containing the population view. - query - A filter on the population state. This filters out particular - simulants (rows in the state table) based on their current state. - The query should be provided in a way that is understood by the - :meth:`pandas.DataFrame.query` method and may reference state - table columns not requested in the ``columns`` argument. - requires_all_columns - If True, all columns in the population state table will be - included in the population view. + The specific column(s) to include. If None, all columns created by the + component are included. + + Raises + ------ + PopulationError + If ``columns`` are requested during initial population creation + (when no columns yet exist) or if the provided ``component`` does not + create one or more of them. + + Returns + ------- + The private column(s) created by the specified component. Will return + a Series if a single column is requested or a Dataframe otherwise. + """ + + if self.creating_initial_population: + if columns: + raise PopulationError( + "Cannot get private columns during initial population " + "creation when no columns yet exist." + ) + returned_cols = [] + squeeze = False # does not really matter (will return an empty df anyway) + else: + all_private_columns = self._private_column_metadata.get(component.name, []) + if columns is None: + returned_cols = all_private_columns + squeeze = True + else: + if isinstance(columns, str): + columns = [columns] + squeeze = True + else: + columns = list(columns) + squeeze = False + missing_cols = set(columns).difference(set(all_private_columns)) + if missing_cols: + raise PopulationError( + f"Component {component.name} is requesting the following " + f"private columns to which it does not have access: {missing_cols}." 
+ ) + returned_cols = columns + private_columns = self.private_columns[returned_cols] + if squeeze: + private_columns = private_columns.squeeze(axis=1) + return private_columns.loc[index] if index is not None else private_columns + + def get_population_index(self) -> pd.Index[int]: + """Gets the index of the current population.""" + return self.private_columns.index + + def get_view(self, component: Component | None = None) -> PopulationView: + """Gets a time-varying view of the population state table. + + The requested population view can be used to view the current state or + to update the state with new values. + + Parameters + ---------- + component + The component requesting this view. If None, the view will provide + read-only access. Returns ------- - A filtered view of the requested columns of the population state - table. + A view of the requested private columns of the population state table. """ - if not columns and not requires_all_columns: - warnings.warn( - "The empty list [] format for requiring all columns is deprecated. 
Please " - "use the new argument 'requires_all_columns' instead.", - DeprecationWarning, - stacklevel=2, - ) - requires_all_columns = True - view = self._get_view(columns, query, requires_all_columns) + view = self._get_view(component) self._add_constraint( - view.get, + view.get_attributes, restrict_during=[ lifecycle_states.INITIALIZATION, lifecycle_states.SETUP, @@ -272,97 +322,21 @@ def get_view( ) return view - def _get_view( - self, columns: str | Sequence[str], query: str, requires_all_columns: bool = False - ) -> PopulationView: - if isinstance(columns, str): - columns = [columns] - - if columns and "tracked" not in columns: - if not query: - query = "tracked == True" - elif "tracked" not in query: - query += " and tracked == True" + def _get_view(self, component: Component | None) -> PopulationView: self._last_id += 1 - return PopulationView(self, self._last_id, columns, query, requires_all_columns) - - def register_simulant_initializer( - self, - component: Component | Manager, - creates_columns: str | Sequence[str] = (), - requires_columns: str | Sequence[str] = (), - requires_values: str | Sequence[str] = (), - requires_streams: str | Sequence[str] = (), - required_resources: Iterable[str | Resource] = (), - ) -> None: - """Marks a source of initial state information for new simulants. - - Parameters - ---------- - component - The component or manager that will add or update initial state - information about new simulants. - creates_columns - The state table columns that the given initializer provides the - initial state information for. - requires_columns - The state table columns that already need to be present and - populated in the state table before the provided initializer is - called. - requires_values - The value pipelines that need to be properly sourced before the - provided initializer is called. - requires_streams - The randomness streams necessary to initialize the simulant - attributes. 
- required_resources - The resources that the initializer requires to run. Strings are - interpreted as column names. - """ - if requires_columns or requires_values or requires_streams: - if required_resources: - raise ValueError( - "If requires_columns, requires_values, or requires_streams are provided, " - "requirements must be empty." - ) - - if isinstance(requires_columns, str): - requires_columns = [requires_columns] - if isinstance(requires_values, str): - requires_values = [requires_values] - if isinstance(requires_streams, str): - requires_streams = [requires_streams] - - required_resources = ( - list(requires_columns) - + [Resource("value", name, component) for name in requires_values] - + [Resource("stream", name, component) for name in requires_streams] - ) - - if isinstance(creates_columns, str): - creates_columns = [creates_columns] - - if "tracked" not in creates_columns: - # The population view itself uses the tracked column, so include - # to be safe. - all_dependencies = list(required_resources) + ["tracked"] - else: - all_dependencies = list(required_resources) - - self._initializer_components.add(component.on_initialize_simulants, creates_columns) - self.resources.add_resources(component, creates_columns, all_dependencies) + view = PopulationView(self, component, self._last_id) + return view def get_simulant_creator(self) -> Callable[[int, dict[str, Any] | None], pd.Index[int]]: """Gets a function that can generate new simulants. - The creator function takes the number of simulants to be created as it's - first argument and a dict population configuration that will be available - to simulant initializers as it's second argument. It generates the new rows - in the population state table and then calls each initializer - registered with the population system with a data - object containing the state table index of the new simulants, the - configuration info passed to the creator, the current simulation - time, and the size of the next time step. 
+ The creator function takes the number of simulants to be created as its + first argument and a population configuration dict that will be available + to simulant initializers as its second argument. It generates the new rows + in the population state table and then calls each initializer registered + with the population system with a data object containing the state table + index of the new simulants, the configuration info passed to the creator, + the current simulation time, and the size of the next time step. Returns ------- @@ -376,14 +350,14 @@ def _create_simulants( population_configuration = ( population_configuration if population_configuration else {} ) - if self._population is None: + if self._private_columns is None: self.creating_initial_population = True - self._population = pd.DataFrame() + self._private_columns = pd.DataFrame() - new_index = range(len(self._population) + count) - new_population = self._population.reindex(new_index) - index = new_population.index.difference(self._population.index) - self._population = new_population + new_index = range(len(self._private_columns) + count) + new_population = self._private_columns.reindex(new_index) + index = new_population.index.difference(self._private_columns.index) + self._private_columns = new_population self.adding_simulants = True for initializer in self.resources.get_population_initializers(): initializer( @@ -392,153 +366,390 @@ def _create_simulants( self.creating_initial_population = False self.adding_simulants = False + missing = {} + for component, cols_created in self._private_column_metadata.items(): + missing_cols = [col for col in cols_created if col not in self._private_columns] + if missing_cols: + missing[component] = missing_cols + if missing: + raise PopulationError( + "The following components registered initializers to create columns " + f"that were not actually created: {missing}." 
+ ) + return index - ############### - # Context API # - ############### + def register_initializer( + self, + initializer: Callable[[SimulantData], None], + columns: str | Sequence[str] | None, + required_resources: Sequence[str | Resource] = (), + ) -> None: + """Registers a component's initializers and any (private) columns created by them. - def get_population(self, untracked: bool) -> pd.DataFrame: - """Provides a copy of the full population state table. + This does three primary things: + 1. Registers each private column's corresponding attribute producer. + 2. Records metadata about which component created which private columns. + 3. Registers the initializer as a resource. + + A `columns` value of None indicates that no private columns are being registered. + This is useful when a component or manager needs to register an initializer + that does not create any private columns. Parameters ---------- - untracked - Whether to include untracked simulants in the returned population. + initializer + A function that will be called to initialize the state of new simulants. + columns + The private columns that the given initializer provides the initial state + information for. + required_resources + The resources that the initializer requires to run. Strings are interpreted + as attributes. - Returns - ------- - A copy of the population table. + Raises + ------ + PopulationError + If this initializer has already been registered or if the columns being + created by this initializer overlap with columns created by another initializer. """ - pop = self._population.copy() if self._population is not None else pd.DataFrame() - if not untracked and "tracked" in pop.columns: - pop = pop[pop.tracked] - return pop + if initializer in self._registered_initializers: + raise PopulationError( + f"The initializer '{initializer.__qualname__}' has already been registered. " + "Each initializer may only be registered once." 
+ ) -class PopulationInterface(Interface): - """Provides access to the system for reading and updating the population. + component = self._get_current_component_or_manager() - The most important aspect of the simulation state is the ``population - table`` or ``state table``. It is a table with a row for every - individual or cohort (referred to as a simulant) being simulated and a - column for each of the attributes of the simulant being modeled. All - access to the state table is mediated by - :class:`population views `, - which may be requested from this system during setup time. + if columns is None: + columns = [] + elif isinstance(columns, str): + columns = [columns] + for column_name in columns: + # Check for duplicate registration + for component_name, columns_list in self._private_column_metadata.items(): + if column_name in columns_list: + raise PopulationError( + f"Component '{component.name}' is attempting to register " + f"private column '{column_name}' but it is already registered " + f"by component '{component_name}'." + ) + # Register each private column's attribute producer + self._register_attribute_producer( + column_name, + source=[column_name], + source_is_private_column=True, + ) - The population system itself manages a single attribute of simulants - called ``tracked``. This attribute allows global control of which - simulants are available to read and update in the state table by - default. + # Register private column metadata + self._private_column_metadata[component.name].extend(columns) - For example, in a simulation of childhood illness, we might not - need information about individuals or cohorts once they reach five years - of age, and so we can have them "age out" of the simulation at five years - old by setting the ``tracked`` attribute to ``False``. 
+ # Track the initializer to prevent duplicate registration + self._registered_initializers.append(initializer) - """ + # Register the initializer as a resource + self.resources.add_private_columns( + initializer=initializer, + columns=columns, + required_resources=required_resources, + ) + + ############### + # Context API # + ############### - def __init__(self, manager: PopulationManager): - self._manager = manager + def get_all_attribute_names(self) -> list[str]: + """Gets the names of all attributes in the population. - def get_view( + Returns + ------- + A list of all attribute names in the population. + """ + return list(self._attribute_pipelines.keys()) + + @overload + def get_population( self, - columns: str | Sequence[str], + attributes: list[str] | tuple[str, ...] | Literal["all"], + index: pd.Index[int] | None = None, query: str = "", - requires_all_columns: bool = False, - ) -> PopulationView: - """Get a time-varying view of the population state table. + squeeze: Literal[True] = True, + skip_post_processor: Literal[False] = False, + ) -> pd.Series[Any] | pd.DataFrame: + ... - The requested population view can be used to view the current state or - to update the state with new values. + @overload + def get_population( + self, + attributes: list[str] | tuple[str, ...] | Literal["all"], + index: pd.Index[int] | None = None, + query: str = "", + squeeze: Literal[False] = ..., + skip_post_processor: Literal[False] = False, + ) -> pd.DataFrame: + ... - If the column 'tracked' is not specified in the ``columns`` argument, - the query string 'tracked == True' will be added to the provided - query argument. This allows components to ignore untracked simulants - by default. If the columns argument is empty, the population view will - have access to the entire state table. + @overload + def get_population( + self, + attributes: list[str] | tuple[str, ...] 
| Literal["all"], + index: pd.Index[int] | None = None, + query: str = "", + squeeze: Literal[True, False] = True, + skip_post_processor: Literal[True] = ..., + ) -> Any: + ... + + def get_population( + self, + attributes: list[str] | tuple[str, ...] | Literal["all"], + index: pd.Index[int] | None = None, + query: str = "", + squeeze: Literal[True, False] = True, + skip_post_processor: Literal[True, False] = False, + ) -> Any: + """Provides a copy of the population state table. Parameters ---------- - columns - A subset of the state table columns that will be available in the - returned view. If requires_all_columns is True, this should be set to - the columns created by the component containing the population view. + attributes + The attributes to include as the state table. If "all", all attributes are included. + index + The index of simulants to include in the returned population. If None, + all simulants are included. query - A filter on the population state. This filters out particular - simulants (rows in the state table) based on their current state. - The query should be provided in a way that is understood by the - :meth:`pandas.DataFrame.query` method and may reference state - table columns not requested in the ``columns`` argument. - requires_all_columns - If True, all columns in the population state table will be - included in the population view. + Additional conditions used to filter the index. + squeeze + Whether or not to attempt to squeeze a multi-level column into a single-level + column and/or a single-column dataframe into a series. + skip_post_processor + Whether we should invoke the post-processor on the combined + source and mutator output or return without post-processing. + This is useful when the post-processor acts as some sort of final + unit conversion (e.g. the rescale post processor). + + Notes + ----- + If ``skip_post_processor`` is True, the returned data will not be squeezed + regardless of the ``squeeze`` argument passed. 
Returns ------- - A filtered view of the requested columns of the population state - table. + A copy of the population state table. + + Raises + ------ + TypeError + If ``attributes`` is not a list or tuple of strings or "all". + PopulationError + - If any of the requested attributes do not exist in the state table. + - If a required column for querying is missing from the state table. + - If the population has not yet been initialized. + ValueError + If multiple attributes are requested when ``skip_post_processor`` is True. """ - return self._manager.get_view(columns, query, requires_all_columns) - def get_simulant_creator(self) -> Callable[[int, dict[str, Any] | None], pd.Index[int]]: - """Gets a function that can generate new simulants. + if self._private_columns is None: + return pd.DataFrame() - The creator function takes the number of simulants to be created as it's - first argument and a dict population configuration that will be available - to simulant initializers as it's second argument. It generates the new rows - in the population state table and then calls each initializer - registered with the population system with a data - object containing the state table index of the new simulants, the - configuration info passed to the creator, the current simulation - time, and the size of the next time step. + if isinstance(attributes, str) and attributes != "all": + raise TypeError( + f"Attributes must be a list of strings or 'all'; got '{attributes}'." + ) + + if attributes == "all": + requested_attributes = self.get_all_attribute_names() + else: + attributes = list(attributes) + # check for duplicate request + if len(attributes) != len(set(attributes)): + # deduplicate while preserving order + requested_attributes = list(dict.fromkeys(attributes)) + self.logger.warning( + f"Duplicate attributes requested: {set(attributes) - set(requested_attributes)}\n" + "Only returning one instance of each of these duplicate requests." 
+ ) + else: + requested_attributes = attributes + + non_existent_attributes = set(requested_attributes) - set( + self._attribute_pipelines.keys() + ) + if non_existent_attributes: + raise PopulationError( + f"Requested attribute(s) {non_existent_attributes} not in population state table. " + "This is likely due to a failure to require some columns, randomness " + "streams, or pipelines when registering a simulant initializer, an attribute " + "producer, or an attribute modifier. NOTE: It is possible for a run to " + "succeed even if resource requirements were not properly specified in " + "the simulant initializers or pipeline creation/modification calls. This " + "success depends on component initialization order which may change in " + "different run settings." + ) + + idx = index if index is not None else self._private_columns.index + + # Filter the index based on the query + columns_to_get = set(requested_attributes) + if query: + query_columns = pop_utils.extract_columns_from_query(query) + # We can remove these query columns from requested columns (and will fetch later) + columns_to_get = columns_to_get.difference(query_columns) + missing_query_columns = query_columns.difference(set(self._attribute_pipelines)) + if missing_query_columns: + raise PopulationError( + "Columns used for querying missing from population state table:\n" + f"Missing columns: {missing_query_columns}\n" + f"Query: {query}" + ) + query_df = self._get_attributes(idx, list(query_columns)) + query_df = query_df.query(query) + idx = query_df.index + + data = self._get_attributes( + idx, + requested_attributes if skip_post_processor else list(columns_to_get), + skip_post_processor, + ) + if skip_post_processor: + return data + + # Add on any query columns that are actually requested to be returned + requested_query_columns = ( + query_columns.intersection(set(requested_attributes)) if query else set() + ) + if requested_query_columns: + requested_query_df = 
query_df[list(requested_query_columns)] + if isinstance(data.columns, pd.MultiIndex): + # Make the query df multi-index to prevent converting columns from + # multi-index to single index w/ tuples for column names + requested_query_df.columns = pd.MultiIndex.from_product( + [requested_query_df.columns, [""]] + ) + data = pd.concat([data, requested_query_df], axis=1) + + # Maintain column ordering + data = data[requested_attributes] + + if squeeze: + if ( + isinstance(data.columns, pd.MultiIndex) + and len(set(data.columns.get_level_values(0))) == 1 + ): + # If multi-index columns with a single outer level, drop the outer level + data = data.droplevel(0, axis=1) + if len(data.columns) == 1: + # If single column df, squeeze to series + data = data.squeeze(axis=1) + + return data + + def get_tracked_query(self) -> str: + """Gets the combined tracked query for the population. Returns ------- - The simulant creator function. + A query string combining all registered tracked queries with "and" operators. """ - return self._manager.get_simulant_creator() + return " and ".join(self.tracked_queries) - def initializes_simulants( + @overload + def _get_attributes( self, - component: Component | Manager, - creates_columns: str | Sequence[str] = (), - requires_columns: str | Sequence[str] = (), - requires_values: str | Sequence[str] = (), - requires_streams: str | Sequence[str] = (), - required_resources: Sequence[str | Resource] = (), - ) -> None: - """Marks a source of initial state information for new simulants. + idx: pd.Index[int], + requested_attributes: Sequence[str], + skip_post_processor: Literal[False] = ..., + ) -> pd.DataFrame: + ... + + @overload + def _get_attributes( + self, + idx: pd.Index[int], + requested_attributes: Sequence[str], + skip_post_processor: Literal[True] = ..., + ) -> Any: + ... - Parameters - ---------- - component - The component or manager that will add or update initial state - information about new simulants. 
- creates_columns - The state table columns that the given initializer - provides the initial state information for. - requires_columns - The state table columns that already need to be present - and populated in the state table before the provided initializer - is called. - requires_values - The value pipelines that need to be properly sourced - before the provided initializer is called. - requires_streams - The randomness streams necessary to initialize the - simulant attributes. - required_resources - The resources that the initializer requires to run. Strings are - interpreted as column names, and Pipelines and RandomnessStreams - are interpreted as value pipelines and randomness streams, - """ - self._manager.register_simulant_initializer( - component, - creates_columns, - requires_columns, - requires_values, - requires_streams, - required_resources, + def _get_attributes( + self, + idx: pd.Index[int], + requested_attributes: Sequence[str], + skip_post_processor: Literal[True, False] = False, + ) -> Any: + """Gets the population for a given index and requested attributes.""" + + if skip_post_processor: + if len(requested_attributes) != 1: + raise ValueError( + "When skip_post_processor is True, a single attribute must " + f"be requested. You requested {requested_attributes}." 
+ ) + return self._attribute_pipelines[requested_attributes[0]]( + idx, skip_post_processor=skip_post_processor + ) + + attributes_list: list[pd.Series[Any] | pd.DataFrame] = [] + + # batch simple attributes and directly leverage private column backing dataframe + simple_attributes = [ + name + for name, pipeline in self._attribute_pipelines.items() + if name in requested_attributes and pipeline.is_simple + ] + if simple_attributes: + if self._private_columns is None: + raise PopulationError("Population has not been initialized.") + attributes_list.append(self._private_columns.loc[idx, simple_attributes]) + + # handle remaining non-simple attributes one by one + remaining_attributes = [ + attribute + for attribute in requested_attributes + if attribute not in simple_attributes + ] + contains_column_multi_index = False + for name in remaining_attributes: + values = self._attribute_pipelines[name](idx) + + # Handle column names + if isinstance(values, pd.Series): + if values.name is not None and values.name != name: + self.logger.warning( + f"The '{name}' attribute pipeline returned a pd.Series with a " + f"different name '{values.name}'. For the column being added to the " + f"population state table, we will use '{name}'." + ) + values.name = name + else: + # Must be a dataframe. Coerce the columns to multi-index and set the + # attribute name as the outer level. + if isinstance(values.columns, pd.MultiIndex): + # FIXME [MIC-6645] + raise NotImplementedError( + f"The '{name}' attribute pipeline returned a DataFrame with multi-level " + f"columns (nlevels={values.columns.nlevels}). Multi-level columns in " + "attribute pipeline outputs are not supported." 
+ ) + values.columns = pd.MultiIndex.from_product([[name], values.columns]) + contains_column_multi_index = True + attributes_list.append(values) + + # Make sure all items of the list have consistent column levels + if contains_column_multi_index: + for i, item in enumerate(attributes_list): + if isinstance(item, pd.Series): + item_df = item.to_frame() + item_df.columns = pd.MultiIndex.from_tuples([(item.name, "")]) + attributes_list[i] = item_df + if isinstance(item, pd.DataFrame) and item.columns.nlevels == 1: + item.columns = pd.MultiIndex.from_product([item.columns, [""]]) + df = ( + pd.concat(attributes_list, axis=1) if attributes_list else pd.DataFrame(index=idx) ) + + return df + + def update(self, update: pd.DataFrame) -> None: + self.private_columns[update.columns] = update diff --git a/src/vivarium/framework/population/population_view.py b/src/vivarium/framework/population/population_view.py index 92f0d62c8..41a046ba0 100644 --- a/src/vivarium/framework/population/population_view.py +++ b/src/vivarium/framework/population/population_view.py @@ -3,49 +3,46 @@ The Population View =================== -The :class:`PopulationView` is a user-facing abstraction that manages read and write access -to the underlying simulation :term:`State Table`. It has two primary responsibilities: +The :class:`PopulationView` is a user-facing abstraction that manages read and write +access to the underlying :term:`population state table `. +It has two primary responsibilities: - 1. To provide user access to subsets of the simulation state table - when it is safe to do so. - 2. To allow the user to update the simulation state in a controlled way. + 1. To provide user access to subsets of the state table when it is safe to do so. + 2. To allow the user to update private data in a controlled way. 
""" from __future__ import annotations -from collections.abc import Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, overload import pandas as pd +import vivarium.framework.population.utilities as pop_utils +from vivarium.framework.lifecycle import lifecycle_states from vivarium.framework.population.exceptions import PopulationError if TYPE_CHECKING: + from vivarium.component import Component from vivarium.framework.population.manager import PopulationManager class PopulationView: - """A read/write manager for the simulation state table. + """A read/write manager for the population state table. + + It can be used to both read and update the state of the population. While a + PopulationView can read any column, it can only write those columns that the + component it is attached to created (i.e. that component's private columns). - It can be used to both read and update the state of the population. A - PopulationView can only read and write columns for which it is configured. Attempts to update non-existent columns are ignored except during simulant creation when new columns are allowed to be created. - Notes - ----- - By default, this view will filter out ``untracked`` simulants unless - the ``tracked`` column is specified in the initialization arguments. - """ def __init__( self, manager: PopulationManager, + component: Component | None, view_id: int, - columns: Sequence[str] = (), - query: str = "", - requires_all_columns: bool = False, ): """ @@ -53,204 +50,391 @@ def __init__( ---------- manager The population manager for the simulation. + component + The component requesting this view. If None, the view will provide + read-only access. view_id The unique identifier for this view. - columns - The set of columns this view should have access too. If - requies_all_columns is True, this should be set to - the columns created by the component containing the population view. 
- query - A :mod:`pandas`-style filter that will be applied any time this - view is read from. - requires_all_columns - If True, all columns from the population state table will be - included in the population view. """ self._manager = manager + self._component = component self._id = view_id - self._columns = list(columns) - self.query = query - self.requires_all_columns = requires_all_columns + + ############## + # Properties # + ############## @property def name(self) -> str: return f"population_view_{self._id}" @property - def columns(self) -> list[str]: - """The columns that the view can read and update. + def private_columns(self) -> list[str]: + """The names of private columns managed by this PopulationView. - Set of columns the view will have access to. This is a subset of the - population state table columns. If self.requires_all_columns is True, - then this will be the columns created by the component containing the - population view. + These private columns are those that were created by the component + that created this view. """ - if self.requires_all_columns: - all_columns = list(self._manager.get_population(True).columns) + self._columns - return list(set(all_columns)) - return self._columns + if self._component is None: + raise PopulationError( + "This PopulationView is read-only, so it doesn't have access to private_columns." + ) + return self._manager.get_private_column_names(self._component.name) + + ########### + # Methods # + ########### + + @overload + def get_attributes( + self, + index: pd.Index[int], + attributes: str, + query: str = "", + include_untracked: bool = False, + skip_post_processor: bool = False, + ) -> pd.Series[Any]: + ... + + @overload + def get_attributes( + self, + index: pd.Index[int], + attributes: list[str] | tuple[str, ...], + query: str = "", + include_untracked: bool = False, + skip_post_processor: bool = False, + ) -> pd.DataFrame: + ... 
+ + @overload + def get_attributes( + self, + index: pd.Index[int], + attributes: str | list[str] | tuple[str, ...], + query: str = "", + include_untracked: bool = False, + skip_post_processor: bool = True, + ) -> Any: + ... + + def get_attributes( + self, + index: pd.Index[int], + attributes: str | list[str] | tuple[str, ...], + query: str = "", + include_untracked: bool = False, + skip_post_processor: Literal[True, False] = False, + ) -> Any: + """Gets a specific subset of the population state table. - def subview(self, columns: str | Sequence[str]) -> PopulationView: - """Retrieves a new view with a subset of this view's columns. + For the rows in ``index``, return the ``attributes`` (i.e. columns) from the + state table. The resulting rows may be further filtered by the call's ``query`` + and whether or not to include untracked simulants. Parameters ---------- - columns - The set of columns to provide access to in the subview. Must be - a proper subset of this view's columns. + index + Index of the population to get. This may be further filtered by various + query conditions. + attributes + The attributes to retrieve. If a single attribute is passed in via a + string, the result will be squeezed to a Series if possible. + query + Additional conditions used to filter the index. + include_untracked + Whether to include untracked simulants. + skip_post_processor + Whether we should invoke the post-processor on the combined + source and mutator output or return without post-processing. + This is useful when the post-processor acts as some sort of final + unit conversion (e.g. the rescale post processor). + + Notes + ----- + If ``skip_post_processor`` is True, the returned data will not be squeezed. Returns ------- - A new view with access to the requested columns. + The attribute(s) requested subset to the ``index`` and filtered using + the various optional queries. 
If ``skip_post_processor`` is False, will + return a Series if a single attribute is requested or a Dataframe otherwise. Raises ------ - PopulationError - If the requested columns are not a proper subset of this view's - columns or no columns are requested. + ValueError + If the result is expected to be a Series but is not. + """ + + squeeze: Literal[True, False] = isinstance(attributes, str) + attributes = [attributes] if isinstance(attributes, str) else list(attributes) + + population = self._manager.get_population( + attributes=attributes, + index=index, + query=self._build_query(query, include_untracked), + squeeze=squeeze, + skip_post_processor=skip_post_processor, + ) + if not skip_post_processor and squeeze and not isinstance(population, pd.Series): + raise ValueError( + "Expected a pandas Series to be returned when requesting a single " + "attribute, but got a DataFrame instead. If you expect this attribute " + "to be a DataFrame, you should call `get_attribute_frame()` instead." + ) + return population + + def get_attribute_frame( + self, + index: pd.Index[int], + attribute: str, + query: str = "", + include_untracked: bool = False, + ) -> pd.DataFrame: + """Gets a single attribute as a DataFrame. + + For the rows in ``index``, return the ``attributes`` (i.e. columns) from the + state table. The resulting rows may be further filtered by the call's ``query`` + and whether or not to include untracked simulants. + + Parameters + ---------- + index + Index of the population to get. + attribute + The attribute to retrieve. This attribute may contain one or more columns. + query + Additional conditions used to filter the index. + include_untracked + Whether to include untracked simulants. Notes ----- - Subviews are useful during population initialization. The original - view may contain both columns that a component needs to create and - update as well as columns that the component needs to read. 
By - requesting a subview, a component can read the sections it needs - without running the risk of trying to access uncreated columns - because the component itself has not created them. - """ - if isinstance(columns, str): - columns = [columns] + The difference between this method and ``get_attributes`` is subtle. This + method always returns a dataframe even if the requested attribute contains + a single column. Further, in the event the attribute has multi-level columns, + it will be squeezed to only return the inner columns. - if not columns or set(columns) - set(self.columns): - raise PopulationError( - f"Invalid subview requested. Requested columns must be a non-empty " - f"subset of this view's columns. Requested columns: {columns}, " - f"Available columns: {self.columns}" + Calling ``get_attributes`` to request a list of a single attribute seems + identical to this, but in that case the underlying data would not be squeezed + at all, i.e. a dataframe with multi-level columns would also return the + outer columns. + + Returns + ------- + The attribute requested subset to the ``index`` and filtered using + the various optional queries. Will always return a DataFrame. + + """ + return pd.DataFrame( + self._manager.get_population( + index=index, + attributes=[attribute], + query=self._build_query(query, include_untracked), ) - # Skip constraints for requesting subviews. - return self._manager._get_view(columns, self.query) + ) + + @overload + def get_private_columns( + self, + index: pd.Index[int], + private_columns: str = ..., + query: str = "", + include_untracked: bool = False, + ) -> pd.Series[Any]: + ... + + @overload + def get_private_columns( + self, + index: pd.Index[int], + private_columns: list[str] | tuple[str, ...] = ..., + query: str = "", + include_untracked: bool = False, + ) -> pd.DataFrame: + ... 
+ + @overload + def get_private_columns( + self, + index: pd.Index[int], + private_columns: None = None, + query: str = "", + include_untracked: bool = False, + ) -> pd.Series[Any] | pd.DataFrame: + ... - def get(self, index: pd.Index[int], query: str = "") -> pd.DataFrame: - """Select the rows represented by the given index from this view. + def get_private_columns( + self, + index: pd.Index[int], + private_columns: str | list[str] | tuple[str, ...] | None = None, + query: str = "", + include_untracked: bool = False, + ) -> pd.Series[Any] | pd.DataFrame: + """Gets a specific subset of this ``PopulationView's`` private columns. - For the rows in ``index`` get the columns from the simulation's - state table to which this view has access. The resulting rows may be - further filtered by the view's query and only return a subset - of the population represented by the index. + For the rows in ``index``, return the requested ``private_columns``. The + resulting rows may be further filtered by the call's ``query`` and whether + or not to include untracked simulants. Parameters ---------- index Index of the population to get. + private_columns + The private columns to retrieve. If None, all columns created by the + component that created this view are included. query - Additional conditions used to filter the index. These conditions - will be unioned with the default query of this view. The query - provided may use columns that this view does not have access to. + Additional conditions used to filter the index. + include_untracked + Whether to include untracked simulants. Returns ------- - A table with the subset of the population requested. + The private column(s) requested subset to the ``index`` and filtered + using the various optional queries. Will return a Series if a single + column is requested or a Dataframe otherwise. Raises ------ PopulationError - If this view has access to columns that have not yet been created - and this method is called. 
If you see this error, you should - request a subview with the columns you need read access to. - - See Also - -------- - :meth:`subview ` + If there is no component attached to this view (indicating that this + view is to be read-only and thus cannot access private columns). """ - pop = self._manager.get_population(True).loc[index] - - if not index.empty: - if self.query: - pop = pop.query(self.query) - if query: - pop = pop.query(query) - non_existent_columns = set(self.columns) - set(pop.columns) - if non_existent_columns: + if self._component is None: raise PopulationError( - f"Requested column(s) {non_existent_columns} not in population table. " - "This is likely due to a failure to require some columns, randomness " - "streams, or pipelines when registering a simulant initializer, a value " - "producer, or a value modifier. NOTE: It is possible for a run to " - "succeed even if resource requirements were not properly specified in " - "the simulant initializers or pipeline creation/modification calls. This " - "success depends on component initialization order which may change in " - "different run settings." + "This PopulationView is read-only, so it doesn't have access to get_private_columns()." ) - return pop.loc[:, self.columns] - def update(self, population_update: pd.Series[Any] | pd.DataFrame) -> None: - """Updates the state table with the provided data. + index = self.get_filtered_index( + index, + query=self._build_query(query, include_untracked), + include_untracked=True, + ) + + return self._manager.get_private_columns(self._component, index, private_columns) + + def get_filtered_index( + self, + index: pd.Index[int], + query: str = "", + include_untracked: bool = False, + ) -> pd.Index[int]: + """Gets a specific index of the population. + + The requested index may be further filtered by the call's ``query`` and + whether or not to include untracked simulants. + + Parameters + ---------- + index + Index of the population to get. 
+ query + Additional conditions used to filter the index. + include_untracked + Whether to include untracked simulants. + + Returns + ------- + The requested and filtered population index. + """ + + return self.get_attributes( + index, + attributes=[], + query=query, + include_untracked=include_untracked, + ).index + + def update(self, update: pd.Series[Any] | pd.DataFrame) -> None: + """Updates private columns with the provided data. Parameters ---------- - population_update - The data which should be copied into the simulation's state. If - the update is a :class:`pandas.DataFrame`, it can contain a subset - of the view's columns but no extra columns. If ``pop`` is a - :class:`pandas.Series` it must have a name that matches one of - this view's columns unless the view only has one column in which - case the Series will be assumed to refer to that regardless of its - name. + update + The data which should be copied into the simulation's private columns. + If the ``update`` is a :class:`pandas.DataFrame`, it can contain a subset + of the view's columns but no extra columns. If it's a :class:`pandas.Series` + it must have a name that matches one of this view's columns unless the + view only has one column in which case the Series will be assumed to + refer to that regardless of its name. Raises ------ PopulationError - If the provided data name or columns do not match columns that - this view manages or if the view is being updated with a data - type inconsistent with the original population data. + - If there is no component attached to this view (indicating that this + view is to be read-only and thus cannot be updated). + - If the update has simulants that are not in the existing private data. + - If the update is missing simulants during initial population creation. + - If this view manages multiple private columns but the update is an + unnamed :class:`pandas.Series`. + - If the update contains columns not managed by this view. + - If the update is empty. 
+ - If the update includes different dtypes than the existing data (unless + new simulants are being added). + TypeError + If the update is not a :class:`pandas.Series` or a :class:`pandas.DataFrame`. """ - state_table = self._manager.get_population(True) - population_update = self._format_update_and_check_preconditions( - population_update, - state_table, - self.columns, + + if self._component is None: + raise PopulationError( + "This PopulationView is read-only, so it doesn't have access to update()." + ) + + existing = pd.DataFrame(self._manager.get_private_columns(self._component)) + update_df: pd.DataFrame = self._format_update_and_check_preconditions( + self._component.name, + update, + existing, + self.private_columns, self._manager.creating_initial_population, self._manager.adding_simulants, ) if self._manager.creating_initial_population: - new_columns = list(set(population_update).difference(state_table)) - self._manager.population[new_columns] = population_update[new_columns] - elif not population_update.empty: - update_columns = list(set(population_update).intersection(state_table)) + new_columns = list(set(update_df.columns).difference(existing.columns)) + self._manager.update(update_df[new_columns]) + elif not update_df.empty: + update_columns = list(set(update_df.columns).intersection(existing.columns)) + updated_cols_list = [] for column in update_columns: column_update = self._update_column_and_ensure_dtype( - population_update[column], - state_table[column], + update_df[column], + existing[column], self._manager.adding_simulants, ) - self._manager.population[column] = column_update + updated_cols_list.append(column_update) + self._manager.update(pd.concat(updated_cols_list, axis=1)) def __repr__(self) -> str: - return f"PopulationView(_id={self._id}, _columns={self.columns}, query={self.query})" + name = self._component.name if self._component else "None" + private_columns = self.private_columns if self._component else "N/A" + return 
f"PopulationView(_id={self._id}, _component={name}, private_columns={private_columns})" ################## # Helper methods # ################## + # FIXME: make this not a static method @staticmethod def _format_update_and_check_preconditions( - population_update: pd.Series[Any] | pd.DataFrame, - state_table: pd.DataFrame, - view_columns: list[str], + component_name: str, + update: pd.Series[Any] | pd.DataFrame, + existing: pd.DataFrame, + private_columns: list[str], creating_initial_population: bool, adding_simulants: bool, ) -> pd.DataFrame: """Standardizes the population update format and checks preconditions. - Managing how values get written to the underlying population state table is critical - to rule out several categories of error in client simulation code. The state table + Managing how values get written to the underlying population private data is critical + to rule out several categories of error in client simulation code. The private data is modified at three different times. In the first, the initial population table - is being created and new columns are being added to the state table with their + is being created and new columns are being added to the private data with their initial values. In the second, the population manager has added new rows with - appropriate null values to the state table in response to population creation + appropriate null values to the private data in response to population creation dictated by client code, and population updates are being provided to fill in - initial values for those new rows. In the final case, state table values for + initial values for those new rows. In the final case, private data values for existing simulants are being overridden as part of a time step. All of these modification scenarios require that certain preconditions are met. @@ -258,27 +442,26 @@ def _format_update_and_check_preconditions( 1. The update is a DataFrame or a Series. 2. 
If it is a series, it is nameless and this view manages a single column - or it is named and it's name matches a column in this PopulationView. + or it is named and its name matches a column in this PopulationView. 3. The update matches at least one column in this PopulationView. 4. The update columns are a subset of the columns managed by this PopulationView. - 5. The update index is a subset of the existing state table index. + 5. The update index is a subset of the existing private data index. PopulationViews don't make rows, they just fill them in. - For initial population creation additional preconditions are documented in - :meth:`PopulationView._ensure_coherent_initialization`. Outside population - initialization, we require that all columns in the update to be present in - the existing state table. When new simulants are added in the middle of the - simulation, we require that only one component provide updates to a column. + Note that except during population initialization, we require that all columns + in the update to be present in the existing private data. Parameters ---------- - population_update - The update to the simulation state table. - state_table - The existing simulation state table. - view_columns - The columns managed by this PopulationView. + component_name + The name of the component requesting the update. + update + The update to the private data owned by the component that created this view. + existing + The existing private data owned by the component that created this view. + private_columns + The private columns managed by this PopulationView. creating_initial_population Whether the initial population is being created. adding_simulants @@ -288,166 +471,80 @@ def _format_update_and_check_preconditions( ------- The input data formatted as a DataFrame. - Raises - ------ - TypeError - If the population update is not a :class:`pandas.Series` or a - :class:`pandas.DataFrame`. 
- PopulationError - If the update violates any preconditions relevant to the context in which - the update is provided (initial population creation, population creation on - time steps, or population state changes on time steps). - """ assert not creating_initial_population or adding_simulants - population_update = PopulationView._coerce_to_dataframe( - population_update, - view_columns, - ) + update = PopulationView._coerce_to_dataframe(update, private_columns) - unknown_simulants = len(population_update.index.difference(state_table.index)) + unknown_simulants = len(update.index.difference(existing.index)) if unknown_simulants: raise PopulationError( "Population updates must have an index that is a subset of the current " - f"population state table. {unknown_simulants} simulants were provided " - f"in an update with no matching index in the existing table." + f"private data. {unknown_simulants} simulants were provided " + "in an update with no matching index in the existing table." ) if creating_initial_population: - PopulationView._ensure_coherent_initialization(population_update, state_table) - else: - new_columns = list(set(population_update).difference(state_table)) - if new_columns: + missing_pops = len(existing.index.difference(update.index)) + if missing_pops: raise PopulationError( - f"Attempting to add new columns {new_columns} to the state table " - f"outside the initial population creation phase." + "Components must initialize all simulants during population initialization. " + f"Component '{component_name}' is missing updates for {missing_pops} simulants." 
) - if adding_simulants: - state_table_new_simulants = state_table.loc[population_update.index, :] - conflicting_columns = [ - column - for column in population_update - if state_table_new_simulants[column].notnull().any() - and not population_update[column].equals( - state_table_new_simulants[column] - ) - ] - if conflicting_columns: - raise PopulationError( - "Two components are providing conflicting initialization data " - f"for the state table columns: {conflicting_columns}." - ) - - return population_update + return update @staticmethod def _coerce_to_dataframe( - population_update: pd.Series[Any] | pd.DataFrame, - view_columns: list[str], + update: pd.Series[Any] | pd.DataFrame, + private_columns: list[str], ) -> pd.DataFrame: - """Coerce all population updates to a :class:`pandas.DataFrame` format. + """Coerces all population updates to a :class:`pandas.DataFrame` format. Parameters ---------- - population_update - The update to the simulation state table. + update + The update to the private data owned by the component that created this view. + private_columns + The private column names owned by the component that created this view. Returns ------- The input data formatted as a DataFrame. - - Raises - ------ - TypeError - If the population update is not a :class:`pandas.Series` or a - :class:`pandas.DataFrame`. - PopulationError - If the input data is a :class:`pandas.Series` and this :class:`PopulationView` - manages multiple columns or if the population update contains columns not - managed by this view. """ - if not isinstance(population_update, (pd.Series, pd.DataFrame)): + if not isinstance(update, (pd.Series, pd.DataFrame)): raise TypeError( "The population update must be a pandas Series or DataFrame. " - f"A {type(population_update)} was provided." + f"A {type(update)} was provided." 
) - if isinstance(population_update, pd.Series): - if population_update.name is None: - if len(view_columns) == 1: - population_update.name = view_columns[0] + if isinstance(update, pd.Series): + if update.name is None: + if len(private_columns) == 1: + update.name = private_columns[0] else: raise PopulationError( "Cannot update with an unnamed pandas series unless there " "is only a single column in the view." ) - population_update = pd.DataFrame(population_update) + update = pd.DataFrame(update) - if not set(population_update.columns).issubset(view_columns): + if not set(update.columns).issubset(private_columns): raise PopulationError( f"Cannot update with a DataFrame or Series that contains columns " f"the view does not. Dataframe contains the following extra columns: " - f"{set(population_update.columns).difference(view_columns)}." + f"{set(update.columns).difference(private_columns)}." ) - update_columns = list(population_update) + update_columns = list(update) if not update_columns: raise PopulationError( - "The update method of population view is being called " - "on a DataFrame with no columns." + "The update method of population view is being called on a DataFrame " + "with no columns." ) - return population_update - - @staticmethod - def _ensure_coherent_initialization( - population_update: pd.DataFrame, state_table: pd.DataFrame - ) -> None: - """Ensure that overlapping population updates have the same information. - - During population initialization, each state table column should be updated by - exactly one component and each component with an initializer should create at - least one column. Sometimes components are a little sloppy and provide - duplicate column information, which we should continue to allow. We want to ensure - that a column is only getting one set of unique values though. - - Parameters - ---------- - population_update - The update to the simulation state table. - state_table - The existing simulation state table. 
When this method is called, the table - should be in a partially complete state. That is the provided population - update should carry some new attributes we need to assign. - - Raises - ----- - PopulationError - If the population update contains no new information or if it contains - information in conflict with the existing state table. - """ - missing_pops = len(state_table.index.difference(population_update.index)) - if missing_pops: - raise PopulationError( - f"Components should initialize the same population at the simulation start. " - f"A component is missing updates for {missing_pops} simulants." - ) - new_columns = set(population_update).difference(state_table) - overlapping_columns = set(population_update).intersection(state_table) - if not new_columns: - raise PopulationError( - f"A component is providing a population update for {list(population_update)} " - "but all provided columns are initialized by other components." - ) - for column in overlapping_columns: - if not population_update[column].equals(state_table[column]): - raise PopulationError( - "Two components are providing conflicting initialization data for the " - f"{column} state table column." - ) + return update @staticmethod def _update_column_and_ensure_dtype( @@ -455,14 +552,19 @@ def _update_column_and_ensure_dtype( existing: pd.Series[Any], adding_simulants: bool, ) -> pd.Series[Any]: - """Build the updated state table column with an appropriate dtype. + """Builds the updated private column with an appropriate dtype. + + This method updates any existing private column values with their corresponding + new values from the update; existing values not in the update are preserved. + It also ensures that the resulting column has a dtype consistent with the + original column (unless new simulants are being added). Parameters ---------- update The new column values for a subset of the existing index. existing - The existing column values for all simulants in the state table. 
+ The existing column values for all simulants. adding_simulants Whether new simulants are currently being initialized. @@ -475,7 +577,7 @@ def _update_column_and_ensure_dtype( # I've also seen this error, though I don't have a reproducible and useful example. # I'm reasonably sure what's really being accounted for here is non-nullable columns # that temporarily have null values introduced in the space between rows being - # added to the state table and initializers filling them with their first values. + # added to the private data and initializers filling them with their first values. # That means the space of dtype casting issues is actually quite small. What should # actually happen in the long term is to separate the population creation entirely # from the mutation of existing state. I.e. there's not an actual reason we need @@ -483,13 +585,13 @@ def _update_column_and_ensure_dtype( # the creation of new simulants besides the fact that it's the existing # implementation. update_values = update.array.copy() - new_state_table_values = existing.array.copy() + new_values = existing.array.copy() update_index_positional = existing.index.get_indexer(update.index) # type: ignore [no-untyped-call] # Assumes the update index labels can be interpreted as an array position. - new_state_table_values[update_index_positional] = update_values + new_values[update_index_positional] = update_values - unmatched_dtypes = new_state_table_values.dtype != update_values.dtype + unmatched_dtypes = new_values.dtype != update_values.dtype if unmatched_dtypes and not adding_simulants: # This happens when the population is being grown because extending # the index forces columns that don't have a natural null type @@ -498,8 +600,28 @@ def _update_column_and_ensure_dtype( "A component is corrupting the population table by modifying the dtype of " f"the {update.name} column from {existing.dtype} to {update.dtype}." 
) - new_state_table_values = new_state_table_values.astype(update_values.dtype) - new_state_table: pd.Series[Any] = pd.Series( - new_state_table_values, index=existing.index, name=existing.name + new_values = new_values.astype(update_values.dtype) + new_data: pd.Series[Any] = pd.Series( + new_values, index=existing.index, name=existing.name + ) + return new_data + + def _build_query(self, query: str, include_untracked: bool) -> str: + """Builds the full query for this PopulationView. + + This combines the provided query with the population manager's tracked query + as appropriate. + + Notes + ----- + We explicitly set 'include_untracked' to True during initialization and + population creation lifecycle phases. + """ + include_untracked = include_untracked or self._manager.get_current_state() in [ + lifecycle_states.INITIALIZATION, + lifecycle_states.POPULATION_CREATION, + ] + return pop_utils.combine_queries( + query, + self._manager.get_tracked_query() if not include_untracked else "", ) - return new_state_table diff --git a/src/vivarium/framework/population/utilities.py b/src/vivarium/framework/population/utilities.py new file mode 100644 index 000000000..2e6778d44 --- /dev/null +++ b/src/vivarium/framework/population/utilities.py @@ -0,0 +1,45 @@ +""" +============================ +Population Utility Functions +============================ + +""" +import re + + +def extract_columns_from_query(query: str) -> set[str]: + """Extracts the column names required by a query string.""" + + # Extract columns with backticks + columns = re.findall(r"`([^`]*)`", query) + + # Begin dropping known non-columns from query + # Remove backticked content + query = re.sub(r"`[^`]*`", "", query) + # Remove keywords including "in" and "not in" + query = re.sub(r"\b(and|if|or|True|False|in|not\s+in)\b", "", query, flags=re.IGNORECASE) + # Remove quoted strings + query = re.sub(r"'[^']*'|\"[^\"]*\"", "", query) + # Remove standalone numbers (not part of identifiers) + query = 
re.sub(r"\b\d+\b", "", query) + # Remove @ references + query = re.sub(r"@\S+", "", query) + # Remove list/array syntax + query = re.sub(r"\[[^\]]*\]", "", query) + # Remove operators and punctuation but preserve column names + query = re.sub(r"[!=<>]+|[()&|~\-+*/,.]", " ", query) + + # Combine query words and columns + query = re.sub(r"\s+", " ", query).strip() + query_words = [word for word in query.split(" ") if word] + return set(query_words + columns) + + +def combine_queries(*queries: str) -> str: + """Combines any number of queries with an 'and' operator. + + Notes + ----- + Empty queries (i.e., '') are ignored. + """ + return " and ".join([f"({query})" for query in filter(None, queries)]) diff --git a/src/vivarium/framework/randomness/interface.py b/src/vivarium/framework/randomness/interface.py index 09aaf1592..8619e8efc 100644 --- a/src/vivarium/framework/randomness/interface.py +++ b/src/vivarium/framework/randomness/interface.py @@ -1,7 +1,7 @@ """ -=================== -Component Interface -=================== +==================== +Randomness Interface +==================== This module provides an interface to the :class:`RandomnessManager `. @@ -9,17 +9,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING - import pandas as pd from vivarium.framework.randomness.manager import RandomnessManager from vivarium.framework.randomness.stream import RandomnessStream from vivarium.manager import Interface -if TYPE_CHECKING: - from vivarium import Component - class RandomnessInterface(Interface): def __init__(self, manager: RandomnessManager): @@ -28,13 +23,11 @@ def __init__(self, manager: RandomnessManager): def get_stream( self, decision_point: str, - # TODO [MIC-5452]: all calls should have a component - component: Component | None = None, initializes_crn_attributes: bool = False, ) -> RandomnessStream: """Provides a new source of random numbers for the given decision point. 
- ``vivarium`` provides a framework for Common Random Numbers which + Vivarium provides a Common Random Number framework which allows for variance reduction when modeling counter-factual scenarios. Users interested in causal analysis and comparisons between simulation scenarios should be careful to use randomness streams provided by the @@ -46,8 +39,6 @@ def get_stream( A unique identifier for a stream of random numbers. Typically, this represents a decision that needs to be made each time step like 'moves_left' or 'gets_disease'. - component - The component that is requesting the randomness stream. initializes_crn_attributes A flag indicating whether this stream is used to generate key initialization information that will be used to identify simulants @@ -57,33 +48,31 @@ def get_stream( Returns ------- - An entry point into the Common Random Number generation framework. + An entry point into the Common Random Number framework. The stream provides vectorized access to random numbers and a few other utilities. """ - return self._manager.get_randomness_stream( - decision_point, component, initializes_crn_attributes - ) + return self._manager.get_randomness_stream(decision_point, initializes_crn_attributes) def get_seed(self, decision_point: str) -> int: - """Get a randomly generated seed for use with external randomness tools. + """Gets a randomly generated seed for use with external randomness tools. Parameters ---------- - decision_point : - A unique identifier for a stream of random numbers. Typically + decision_point + A unique identifier for a stream of random numbers. Typically represents a decision that needs to be made each time step like 'moves_left' or 'gets_disease'. Returns ------- A seed for a random number generation that is linked to Vivarium's - common random number framework. + Common Random Number framework. 
""" return self._manager.get_seed(decision_point) def register_simulants(self, simulants: pd.DataFrame) -> None: - """Registers simulants with the Common Random Number Framework. + """Registers simulants with the Common Random Number framework. Parameters ---------- diff --git a/src/vivarium/framework/randomness/manager.py b/src/vivarium/framework/randomness/manager.py index 5dbe28f74..60d8c16a3 100644 --- a/src/vivarium/framework/randomness/manager.py +++ b/src/vivarium/framework/randomness/manager.py @@ -1,7 +1,7 @@ """ -========================= -Randomness System Manager -========================= +================== +Randomness Manager +================== """ @@ -20,7 +20,7 @@ from vivarium.types import ClockTime if TYPE_CHECKING: - from vivarium import Component + from vivarium.component import Component from vivarium.framework.engine import Builder @@ -76,9 +76,12 @@ def setup(self, builder: Builder) -> None: pop_size = builder.configuration.population.population_size map_size = max(map_size, 10 * pop_size) self._key_mapping_ = IndexMap(self._key_columns, map_size) + + self._get_current_component = builder.components.get_current_component self._rate_conversion_type = builder.configuration.randomness.rate_conversion_type - self.resources = builder.resources self._add_constraint = builder.lifecycle.add_constraint + self._add_resources = builder.resources.add_resources + self._add_constraint(self.get_seed, restrict_during=[lifecycle_states.INITIALIZATION]) self._add_constraint( self.get_randomness_stream, allow_during=[lifecycle_states.SETUP] @@ -97,7 +100,6 @@ def setup(self, builder: Builder) -> None: def get_randomness_stream( self, decision_point: str, - component: Component | None, initializes_crn_attributes: bool = False, rate_conversion_type: Literal["linear", "exponential"] = "linear", ) -> RandomnessStream: @@ -109,8 +111,6 @@ def get_randomness_stream( A unique identifier for a stream of random numbers. 
Typically represents a decision that needs to be made each time step like 'moves_left' or 'gets_disease'. - component - The component that is requesting the randomness stream. initializes_crn_attributes A flag indicating whether this stream is used to generate key initialization information that will be used to identify simulants @@ -124,7 +124,7 @@ def get_randomness_stream( Returns ------- - An entry point into the Common Random Number generation framework. + An entry point into the Common Random Number framework. The stream provides vectorized access to random numbers and a few other utilities. @@ -134,12 +134,21 @@ def get_randomness_stream( If another location in the simulation has already created a randomness stream with the same identifier. """ + component = self._get_current_component() stream = self._get_randomness_stream( - decision_point, component, initializes_crn_attributes, rate_conversion_type + decision_point, + component, + initializes_crn_attributes, + rate_conversion_type, ) - if not initializes_crn_attributes: - # We need the key columns to be created before this stream can be called. - self.resources.add_resources(component, [stream], self._key_columns) + + # We need the key columns to be created before this stream can be called. 
+ self._add_resources( + component=component, + resources=stream, + required_resources=self._key_columns if not initializes_crn_attributes else [], + ) + self._add_constraint( stream.get_draw, restrict_during=[ @@ -178,7 +187,7 @@ def get_randomness_stream( def _get_randomness_stream( self, decision_point: str, - component: Component | None, + component: Component, initializes_crn_attributes: bool = False, rate_conversion_type: Literal["linear", "exponential"] = "linear", ) -> RandomnessStream: @@ -192,7 +201,7 @@ def _get_randomness_stream( clock=self._clock, seed=self._seed, index_map=self._key_mapping, - component=component, + component=self._get_current_component(), initializes_crn_attributes=initializes_crn_attributes, rate_conversion_type=rate_conversion_type, ) @@ -200,7 +209,7 @@ def _get_randomness_stream( return stream def get_seed(self, decision_point: str) -> int: - """Get a randomly generated seed for use with external randomness tools. + """Gets a randomly generated seed for use with external randomness tools. Parameters ---------- @@ -212,7 +221,7 @@ def get_seed(self, decision_point: str) -> int: Returns ------- A seed for a random number generation that is linked to Vivarium's - common random number framework. + Common Random Number framework. """ return get_hash("_".join([decision_point, str(self._clock()), str(self._seed)])) @@ -222,7 +231,7 @@ def register_simulants(self, simulants: pd.DataFrame) -> None: Parameters ---------- simulants - A table with state data representing the new simulants. Each + A table with state data representing the new simulants. Each simulant should pass through this function exactly once. 
Raises diff --git a/src/vivarium/framework/randomness/stream.py b/src/vivarium/framework/randomness/stream.py index ce8b0f457..e31d44a67 100644 --- a/src/vivarium/framework/randomness/stream.py +++ b/src/vivarium/framework/randomness/stream.py @@ -3,8 +3,8 @@ Randomness Streams ================== -This module provides a wrapper around numpy's randomness system with the intent of coupling -it to vivarium's tools for Common Random Number genereration. +This module provides a wrapper around numpy's randomness system with the intent +of coupling it to vivarium's tools for common random number generation. Attributes @@ -16,12 +16,6 @@ [0.2, 0.2, RESIDUAL_CHOICE] => [0.2, 0.2, 0.6] - -Notes ------ -Currently this object is only used in the `choice` function of this -module. - """ from __future__ import annotations @@ -43,6 +37,7 @@ if TYPE_CHECKING: from vivarium import Component + from vivarium.manager import Manager RESIDUAL_CHOICE = object() @@ -60,7 +55,7 @@ def get_hash(key: str) -> int: Parameters ---------- - key : + key A string used to create a seed for the random number generator. Returns @@ -72,10 +67,10 @@ class RandomnessStream(Resource): - """A stream for producing common random numbers. + """A stream for producing Common Random Numbers (CRN). `RandomnessStream` objects provide an interface to Vivarium's - common random number generation. They provide a number of methods + common random number generation. They provide a number of methods for doing common simulation tasks that require random numbers like making decisions among a number of choices.
@@ -99,8 +94,7 @@ def __init__( clock: Callable[[], ClockTime], seed: Any, index_map: IndexMap, - # TODO [MIC-5452]: all resources should have a component - component: Component | None = None, + component: Component | Manager, initializes_crn_attributes: bool = False, rate_conversion_type: Literal["linear", "exponential"] = "linear", ): @@ -126,7 +120,7 @@ def __init__( ) def _key(self, additional_key: Any = None) -> str: - """Construct a hashable key from this object's state. + """Constructs a hashable key from this object's state. Parameters ---------- @@ -140,7 +134,7 @@ def _key(self, additional_key: Any = None) -> str: return "_".join([self.key, str(self.clock()), str(additional_key), str(self.seed)]) def get_draw(self, index: pd.Index[int], additional_key: Any = None) -> pd.Series[float]: - """Get an indexed set of numbers uniformly drawn from the unit interval. + """Gets an indexed set of numbers uniformly drawn from the unit interval. Parameters ---------- @@ -207,7 +201,7 @@ def filter_for_rate( rate: float | list[float] | tuple[float] | NumericArray | pd.Series[float], additional_key: Any = None, ) -> PandasObject: - """Decide an event outcome for each individual from rates. + """Decides an event outcome for each individual from rates. Given a population or its index and an array of associated rates for some event to happen, we create and return the subpopulation for whom @@ -242,10 +236,14 @@ def filter_for_rate( def filter_for_probability( self, population: PandasObject, - probability: float | list[float] | tuple[float] | NumericArray | pd.Series[float], + probability: float + | list[float] + | tuple[float, ...] + | NumericArray + | pd.Series[float], additional_key: Any = None, ) -> PandasObject: - """Decide an outcome for each individual from probabilities. + """Decides an outcome for each individual from probabilities. 
Given a population or its index and an array of associated probabilities for some event to happen, we create and return the subpopulation for @@ -282,8 +280,16 @@ def filter_for_probability( else: index = population.index - draws = self.get_draw(index, additional_key) - mask = np.array(draws < probability) + probabilities = pd.Series(probability, index=index) + # We skip draws for simulants who have a zero or one probability + zeros_idx = probabilities[probabilities == 0].index + ones_idx = probabilities[probabilities == 1].index + get_draws_idx = probabilities.index.difference(zeros_idx).difference(ones_idx) + draws = self.get_draw(get_draws_idx, additional_key) + # instantiate mask as False and fill in True where appropriate + mask = np.zeros(len(index), dtype=bool) + mask[index.get_indexer(ones_idx)] = True # type: ignore [no-untyped-call] + mask[index.get_indexer(get_draws_idx)] = draws < probabilities[get_draws_idx] # type: ignore [no-untyped-call] return population[mask] def choice( @@ -313,9 +319,9 @@ def choice( choices A set of options to choose from. p - The relative weights of the choices. Can be either a 1-d array of + The relative weights of the choices. Can be either a 1-d array of the same length as `choices` or a 2-d array with `len(index)` rows - and `len(choices)` columns. In the 1-d case, the same set of + and `len(choices)` columns. In the 1-d case, the same set of weights are used to decide among the choices for every item in the `index`. In the 2-d case, each row in `p` contains a separate set of weights for every item in the `index`. @@ -344,8 +350,7 @@ def sample_from_distribution( additional_key: Any = None, **distribution_kwargs: Any, ) -> pd.Series[Any]: - """Given a distribution, returns an indexed set of samples from that - distribution. + """Returns an indexed set of samples from a given distribution. Parameters ---------- @@ -409,9 +414,9 @@ def _choice( A set of options to choose from. Choices must be the same for every simulant. 
p - The relative weights of the choices. Can be either a 1-d array of + The relative weights of the choices. Can be either a 1-d array of the same length as `choices` or a 2-d array with `len(draws)` rows - and `len(choices)` columns. In the 1-d case, the same set of weights + and `len(choices)` columns. In the 1-d case, the same set of weights are used to decide among the choices for every item in the `index`. In the 2-d case, each row in `p` contains a separate set of weights for every item in the `index`. diff --git a/src/vivarium/framework/resource/__init__.py b/src/vivarium/framework/resource/__init__.py index 614842d26..764ded94e 100644 --- a/src/vivarium/framework/resource/__init__.py +++ b/src/vivarium/framework/resource/__init__.py @@ -6,7 +6,7 @@ This module provides a tool to manage dependencies on resources within a :mod:`vivarium` simulation. These resources take the form of things that can be created and utilized by components, for example columns in the -:mod:`state table ` +:mod:`population state table ` or :mod:`named value pipelines `. 
Because these resources need to be created before they can be used, they are @@ -22,4 +22,4 @@ from vivarium.framework.resource.interface import ResourceInterface from vivarium.framework.resource.manager import ResourceManager -from vivarium.framework.resource.resource import Resource +from vivarium.framework.resource.resource import Column, Resource diff --git a/src/vivarium/framework/resource/group.py b/src/vivarium/framework/resource/group.py index 8fa21fe56..aa7d92928 100644 --- a/src/vivarium/framework/resource/group.py +++ b/src/vivarium/framework/resource/group.py @@ -1,37 +1,41 @@ from __future__ import annotations -from collections.abc import Callable, Iterator, Sequence +from collections.abc import Callable, Iterable, Iterator from typing import TYPE_CHECKING from vivarium.framework.resource.exceptions import ResourceError -from vivarium.framework.resource.resource import Resource +from vivarium.framework.resource.resource import Column, Resource if TYPE_CHECKING: from vivarium.framework.population import SimulantData + from vivarium.framework.values import AttributePipeline class ResourceGroup: """Resource groups are the nodes in the resource dependency graph. A resource group represents the pool of resources produced by a single - callable and all the dependencies necessary to produce that resource. + callable and all the required resources necessary to produce them. When thinking of the dependency graph, this represents a vertex and - all in-edges. This is a local-information representation that can be + all in-edges. This is a local-information representation that can be used to construct the entire dependency graph once all resources are specified. """ def __init__( - self, initialized_resources: Sequence[Resource], dependencies: Sequence[Resource] + self, + initialized_resources: Iterable[Column] | Resource, + required_resources: Iterable[str | Resource], + initializer: Callable[[SimulantData], None] | None, ): - """Create a new resource group. 
+ """Creates a new resource group. Parameters ---------- initialized_resources The resources initialized by this resource group's initializer. - dependencies + required_resources The resources this resource group's initializer depends on. Raises @@ -42,21 +46,27 @@ def __init__( if not initialized_resources: raise ResourceError("Resource groups must have at least one resource.") - if len(set(r.component for r in initialized_resources)) != 1: - raise ResourceError("All initialized resources must have the same component.") + initialized_resources_ = ( + [initialized_resources] + if isinstance(initialized_resources, Resource) + else list(initialized_resources) + ) - if len(set(r.resource_type for r in initialized_resources)) != 1: + if len(set(res.component for res in initialized_resources_)) != 1: + raise ResourceError("All initialized resources must have the same component.") + if len(set(res.resource_type for res in initialized_resources_)) != 1: raise ResourceError("All initialized resources must be of the same type.") - self.component = initialized_resources[0].component + self.component = initialized_resources_[0].component """The component or manager that produces the resources in this group.""" - self.type = initialized_resources[0].resource_type + self.type = initialized_resources_[0].resource_type """The type of resource in this group.""" - self.is_initialized = initialized_resources[0].is_initialized - """Whether this resource group contains initialized resources.""" - self._dependencies = dependencies - self.resources = {r.resource_id: r for r in initialized_resources} + self._required_resources = required_resources + self.resources = {res.resource_id: res for res in initialized_resources_} """A dictionary of resources produced by this group, keyed by resource_id.""" + self.initializer = initializer + self.is_initialized = initializer is not None + """Whether this resource group contains initialized resources.""" @property def names(self) -> list[str]: @@ 
-64,24 +74,32 @@ def names(self) -> list[str]: return list(self.resources) @property - def initializer(self) -> Callable[[SimulantData], None]: - """The method that initializes this group of resources.""" - # TODO [MIC-5452]: all resource groups should have a component - if not self.component: - raise ResourceError(f"Resource group {self} does not have an initializer.") - return self.component.on_initialize_simulants - - @property - def dependencies(self) -> list[str]: - """The long names (including type) of dependencies for this group.""" - return [dependency.resource_id for dependency in self._dependencies] + def required_resources(self) -> list[str]: + """The long names (including type) of required resources for this group.""" + dependency_strings = [dep for dep in self._required_resources if isinstance(dep, str)] + if dependency_strings: + raise ResourceError( + "Resource group has not been finalized; required_resources are still strings.\n" + f"Resource group: {self}\n" + f"String required_resources: {dependency_strings}" + ) + return [dep.resource_id for dep in self._required_resources] # type: ignore[union-attr] + + def set_required_resources( + self, attribute_pipelines: dict[str, AttributePipeline] + ) -> None: + """Converts any required resources specified as strings to :class:`AttributePipelines `.""" + self._required_resources = [ + attribute_pipelines[dep] if isinstance(dep, str) else dep + for dep in self._required_resources + ] def __iter__(self) -> Iterator[str]: return iter(self.names) def __repr__(self) -> str: resources = ", ".join(self) - return f"ResourceProducer({resources})" + return f"ResourceGroup({resources})" def __str__(self) -> str: resources = ", ".join(self) diff --git a/src/vivarium/framework/resource/interface.py b/src/vivarium/framework/resource/interface.py index 3d1b1b04d..e51cb4760 100644 --- a/src/vivarium/framework/resource/interface.py +++ b/src/vivarium/framework/resource/interface.py @@ -9,7 +9,6 @@ from __future__ import 
annotations -from collections.abc import Iterable from typing import TYPE_CHECKING, Any from vivarium.framework.resource.manager import ResourceManager @@ -17,24 +16,27 @@ from vivarium.manager import Interface, Manager if TYPE_CHECKING: + from collections.abc import Callable, Iterable + from vivarium import Component + from vivarium.framework.population.manager import SimulantData class ResourceInterface(Interface): """The resource management system. - A resource in :mod:`vivarium` is something like a state table column - or a randomness stream. These resources are used to initialize or alter - the state of the simulation. Many of these resources might depend on each - other and therefore need to be created or updated in a particular order. + A "resource" in :mod:`vivarium` is something like a state table private column, + a lookup table, or a randomness stream. These resources are used to initialize + or alter the state of the simulation. Many of these resources might depend on + each other and therefore need to be created or updated in a particular order. These dependency chains can be quite long and complex. Placing the ordering responsibility on end users makes simulations very fragile and difficult to understand. Instead, the resource management - system allows users to only specify local dependencies. The system then - uses the local dependency information to construct a full dependency - graph, validate that there are no cyclic dependencies, and return - resources and their producers in an order that makes sense. + system allows users to only specify local dependencies (referred to throughout + as "required resources"). The system then uses the local dependency information + to construct a full dependency graph, validate that there are no cyclic dependencies, + and return resources and their producers in an order that makes sense. 
""" @@ -43,10 +45,9 @@ def __init__(self, manager: ResourceManager): def add_resources( self, - # TODO [MIC-5452]: all resource groups should have a component - component: Component | Manager | None, - resources: Iterable[str | Resource], - dependencies: Iterable[str | Resource], + component: Component | Manager, + resources: Resource, + required_resources: Iterable[str | Resource], ) -> None: """Adds managed resources to the resource pool. @@ -55,19 +56,45 @@ def add_resources( component The component or manager adding the resources. resources - The resources being added. A string represents a column resource. - dependencies + The resources being added. A string represents an attribute pipeline. + required_resources A list of resources that the producer requires. A string represents - a column resource. + a population attribute. Raises ------ ResourceError - If either the resource type is invalid, a component has multiple - resource producers for the ``column`` resource type, or - there are multiple producers of the same resource. + If there are multiple producers of the same resource. + """ + self._manager.add_resources( + component, + initializer=None, + resources=resources, + required_resources=required_resources, + ) + + def add_private_columns( + self, + initializer: Callable[[SimulantData], None], + columns: Iterable[str] | str, + required_resources: Iterable[str | Resource], + ) -> None: + """Adds private column resources to the resource pool. + + Parameters + ---------- + initializer + A function that will be called to initialize the state of new simulants. + columns + The population state table private columns that the given initializer + provides initial state information for. + required_resources + The resources that the initializer requires to run. Strings are interpreted + as attributes. 
""" - self._manager.add_resources(component, resources, dependencies) + self._manager.add_private_columns( + initializer=initializer, columns=columns, required_resources=required_resources + ) def get_population_initializers(self) -> list[Any]: """Returns a dependency-sorted list of population initializers. diff --git a/src/vivarium/framework/resource/manager.py b/src/vivarium/framework/resource/manager.py index b83efdb8d..2e7039073 100644 --- a/src/vivarium/framework/resource/manager.py +++ b/src/vivarium/framework/resource/manager.py @@ -7,11 +7,12 @@ from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Callable, Iterable from typing import TYPE_CHECKING, Any import networkx as nx +from vivarium.framework.lifecycle import lifecycle_states from vivarium.framework.resource.exceptions import ResourceError from vivarium.framework.resource.group import ResourceGroup from vivarium.framework.resource.resource import Column, NullResource, Resource @@ -20,6 +21,8 @@ if TYPE_CHECKING: from vivarium import Component from vivarium.framework.engine import Builder + from vivarium.framework.event import Event + from vivarium.framework.population.manager import SimulantData class ResourceManager(Manager): @@ -42,7 +45,6 @@ def __init__(self) -> None: @property def name(self) -> str: - """The name of this manager.""" return "resource_manager" @property @@ -66,20 +68,31 @@ def sorted_nodes(self) -> list[ResourceGroup]: self._sorted_nodes = list(nx.algorithms.topological_sort(self.graph)) # type: ignore[func-returns-value] except nx.NetworkXUnfeasible: raise ResourceError( - "The resource pool contains at least one cycle: " + "The resource pool contains at least one cycle:\n" f"{nx.find_cycle(self.graph)}." 
) return self._sorted_nodes def setup(self, builder: Builder) -> None: self.logger = builder.logging.get_logger(self.name) + self._get_attribute_pipelines = builder.value.get_attribute_pipelines() + self._get_current_component_or_manager = ( + builder.components.get_current_component_or_manager + ) + builder.event.register_listener(lifecycle_states.POST_SETUP, self.on_post_setup) + + def on_post_setup(self, _: Event) -> None: + # Finalize the resource group dependencies + attribute_pipelines = self._get_attribute_pipelines() + for group in self._resource_group_map.values(): + group.set_required_resources(attribute_pipelines) def add_resources( self, - # TODO [MIC-5452]: all resource groups should have a component - component: Component | Manager | None, - resources: Iterable[str | Resource], - dependencies: Iterable[str | Resource], + component: Component | Manager, + initializer: Callable[[SimulantData], None] | None, + resources: Iterable[Column] | Resource, + required_resources: Iterable[str | Resource], ) -> None: """Adds managed resources to the resource pool. @@ -87,40 +100,68 @@ def add_resources( ---------- component The component or manager adding the resources. + initializer + A method that will be called to initialize the state of new simulants. resources - The resources being added. A string represents a column resource. - dependencies + The resources being added. A string represents a population attribute. + required_resources A list of resources that the producer requires. A string represents - a column resource. + a population attribute. Raises ------ ResourceError - If a component has multiple resource producers for the ``column`` - resource type or there are multiple producers of the same resource. + If there are multiple producers of the same resource. 
""" - resource_group = self._get_resource_group(component, resources, dependencies) - + resource_group = self._get_resource_group( + component, initializer, resources, required_resources + ) for resource_id, resource in resource_group.resources.items(): if resource_id in self._resource_group_map: other_resource = self._resource_group_map[resource_id] - # TODO [MIC-5452]: all resource groups should have a component - resource_component = resource.component.name if resource.component else None - other_resource_component = ( - other_resource.component.name if other_resource.component else None - ) raise ResourceError( - f"Component '{resource_component}' is attempting to register" + f"Component '{resource.component.name}' is attempting to register" f" resource '{resource_id}' but it is already registered by" - f" '{other_resource_component}'." + f" '{other_resource.component.name}'." ) self._resource_group_map[resource_id] = resource_group + def add_private_columns( + self, + initializer: Callable[[SimulantData], None], + columns: Iterable[str] | str, + required_resources: Iterable[str | Resource], + ) -> None: + """Adds private column resources to the resource pool. + + Parameters + ---------- + initializer + A method that will be called to initialize the state of new simulants. + columns + The population state table private columns that the given initializer + provides initial state information for. + required_resources + The resources that the initializer requires to run. Strings are interpreted + as attributes. 
+ """ + if isinstance(columns, str): + columns = [columns] + component = self._get_current_component_or_manager() + columns_ = [Column(col, component) for col in columns] + self.add_resources( + component=component, + initializer=initializer, + resources=columns_, + required_resources=required_resources, + ) + def _get_resource_group( self, - component: Component | Manager | None, - resources: Iterable[str | Resource], - dependencies: Iterable[str | Resource], + component: Component | Manager, + initializer: Callable[[SimulantData], None] | None, + resources: Iterable[Column] | Resource, + required_resources: Iterable[str | Resource], ) -> ResourceGroup: """Packages resource information into a resource group. @@ -128,26 +169,21 @@ def _get_resource_group( -------- :class:`ResourceGroup` """ - resources_ = [Column(r, component) if isinstance(r, str) else r for r in resources] - dependencies_ = [Column(d, None) if isinstance(d, str) else d for d in dependencies] - - if not resources_: + if not resources: # We have a "producer" that doesn't produce anything, but # does have dependencies. This is necessary for components that # want to track private state information. - resources_ = [NullResource(self._null_producer_count, component)] + resources = NullResource(self._null_producer_count, component) self._null_producer_count += 1 - # TODO [MIC-5452]: all resource groups should have a component - if component and ( - have_other_component := [r for r in resources_ if r.component != component] - ): + if isinstance(resources, Resource) and resources.component != component: raise ResourceError( - f"All initialized resources must have the component '{component.name}'." - f" The following resources have a different component: {have_other_component}" + "All initialized resources in this resource group must have the" + f" component '{component.name}'. 
The following resource has a different" + f" component: {resources.name}" ) - return ResourceGroup(resources_, dependencies_) + return ResourceGroup(resources, required_resources, initializer) def _to_graph(self) -> nx.DiGraph: """Constructs the full resource graph from information in the groups. @@ -169,7 +205,7 @@ def _to_graph(self) -> nx.DiGraph: resource_graph.add_nodes_from(self._resource_group_map.values()) for resource_group in resource_graph.nodes: - for dependency in resource_group.dependencies: + for dependency in resource_group.required_resources: if dependency not in self._resource_group_map: # Warn here because this sometimes happens naturally # if observer components are missing from a simulation. @@ -196,5 +232,5 @@ def __repr__(self) -> str: out = {} for resource_group in set(self._resource_group_map.values()): produced = ", ".join(resource_group) - out[produced] = ", ".join(resource_group.dependencies) + out[produced] = ", ".join(resource_group.required_resources) return "\n".join([f"{produced} : {depends}" for produced, depends in out.items()]) diff --git a/src/vivarium/framework/resource/resource.py b/src/vivarium/framework/resource/resource.py index 343368d48..175c817a2 100644 --- a/src/vivarium/framework/resource/resource.py +++ b/src/vivarium/framework/resource/resource.py @@ -1,57 +1,59 @@ from __future__ import annotations -from dataclasses import dataclass from typing import TYPE_CHECKING +from vivarium.framework.lifecycle import LifeCycleError + if TYPE_CHECKING: from vivarium import Component from vivarium.manager import Manager -@dataclass class Resource: """A generic resource representing a node in the dependency graph.""" - resource_type: str - """The type of the resource.""" - name: str - """The name of the resource.""" - # TODO [MIC-5452]: all resources should have a component - component: Component | Manager | None - """The component that creates the resource.""" + def __init__( + self, resource_type: str, name: str, component: 
Component | Manager | None + ) -> None: + """Create a new resource.""" + self.resource_type = resource_type + """The type of the resource.""" + self.name = name + """The name of the resource.""" + self._component = component + """The component that creates the resource. Can be None if not yet set.""" + + @property + def component(self) -> Component | Manager: + """The component that creates the resource.""" + if self._component is None: + raise LifeCycleError( + f"The component for the resource '{self.resource_id}' has not been set yet." + ) + return self._component @property def resource_id(self) -> str: """The long name of the resource, including the type.""" return f"{self.resource_type}.{self.name}" - @property - def is_initialized(self) -> bool: - """Return True if the resource needs to be initialized.""" - return False - class NullResource(Resource): """A node in the dependency graph that does not produce any resources.""" - # TODO [MIC-5452]: all resources should have a component - def __init__(self, index: int, component: Component | Manager | None): + def __init__(self, index: int, component: Component | Manager) -> None: super().__init__("null", f"{index}", component) - @property - def is_initialized(self) -> bool: - """Return True if the resource needs to be initialized.""" - return True - class Column(Resource): - """A resource representing a column in the state table.""" + """A resource representing a column in the population private data.""" - # TODO [MIC-5452]: all resources should have a component - def __init__(self, name: str, component: Component | Manager | None): + def __init__(self, name: str, component: Component | Manager) -> None: super().__init__("column", name, component) - @property - def is_initialized(self) -> bool: - """Return True if the resource needs to be initialized.""" - return True + def __eq__(self, value: object) -> bool: + return ( + isinstance(value, Column) + and self.resource_id == value.resource_id + and 
self.component.name == value.component.name + ) diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index c0939dcd5..3b277a6c9 100644 --- a/src/vivarium/framework/results/context.py +++ b/src/vivarium/framework/results/context.py @@ -15,18 +15,15 @@ from pandas.core.groupby.generic import DataFrameGroupBy from vivarium.framework.event import Event +from vivarium.framework.population import utilities as pop_utils from vivarium.framework.results.exceptions import ResultsConfigurationError from vivarium.framework.results.observation import Observation -from vivarium.framework.results.stratification import ( - Stratification, - get_mapped_col_name, - get_original_col_name, -) -from vivarium.framework.values import Pipeline +from vivarium.framework.results.stratification import Stratification, get_mapped_col_name from vivarium.types import ScalarMapper, VectorMapper if TYPE_CHECKING: from vivarium.framework.engine import Builder + from vivarium.framework.results.interface import PopulationFilter class ResultsContext: @@ -52,7 +49,7 @@ class ResultsContext: objects to be produced keyed by the observation name. grouped_observations Dictionary of observation details. It is of the format - {lifecycle_state: {pop_filter: {stratifications: list[Observation]}}}. + {lifecycle_state: {PopulationFilter: {stratifications: list[Observation]}}}. Allowable lifecycle_states are "time_step__prepare", "time_step", "time_step__cleanup", and "collect_metrics". logger @@ -65,7 +62,11 @@ def __init__(self) -> None: self.excluded_categories: dict[str, list[str]] = {} self.observations: dict[str, Observation] = {} self.grouped_observations: defaultdict[ - str, defaultdict[str, defaultdict[tuple[str, ...] | None, list[Observation]]] + str, + defaultdict[ + PopulationFilter, + defaultdict[tuple[str, ...] 
| None, list[Observation]], + ], ] = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) @property @@ -73,7 +74,7 @@ def name(self) -> str: return "results_context" def setup(self, builder: Builder) -> None: - """Set up the results context. + """Sets up the results context. This method is called by the :class:`ResultsManager ` during the setup phase of that object. @@ -82,10 +83,11 @@ def setup(self, builder: Builder) -> None: self.excluded_categories = ( builder.configuration.stratification.excluded_categories.to_dict() ) + self.get_tracked_query = builder.population.get_tracked_query() # noinspection PyAttributeOutsideInit def set_default_stratifications(self, default_grouping_columns: list[str]) -> None: - """Set the default stratifications to be used by stratified observations. + """Sets the default stratifications to be used by stratified observations. Parameters ---------- @@ -105,7 +107,7 @@ def set_default_stratifications(self, default_grouping_columns: list[str]) -> No self.default_stratifications = default_grouping_columns def set_stratifications(self) -> None: - """Set stratifications on all Observers. + """Sets stratifications on all Observers. Emits a warning if any registered stratifications are not being used by any observation. @@ -140,29 +142,28 @@ def set_stratifications(self) -> None: def add_stratification( self, name: str, - requires_columns: list[str], - requires_values: list[Pipeline], + requires_attributes: list[str], categories: list[str], excluded_categories: list[str] | None, mapper: VectorMapper | ScalarMapper | None, is_vectorized: bool, ) -> None: - """Add a stratification to the results context. + """Adds a stratification to the results context. Parameters ---------- name Name of the stratification. - sources - A list of the columns and values needed as input for the `mapper`. + requires_attributes + The population attributes needed as input for the `mapper`. categories Exhaustive list of all possible stratification values. 
excluded_categories List of possible stratification values to exclude from results processing. If None (the default), will use exclusions as defined in the configuration. mapper - A callable that maps the columns and value pipelines specified by - `sources` to the stratification categories. It can either map the entire + A callable that maps the population attributes specified by + `requires_attributes` to the stratification categories. It can either map the entire population or an individual simulant. A simulation will fail if the `mapper` ever produces an invalid value. is_vectorized @@ -172,11 +173,9 @@ def add_stratification( Raises ------ ValueError - If the stratification `name` is already used. - ValueError - If there are duplicate `categories`. - ValueError - If any `excluded_categories` are not in `categories`. + - If the stratification `name` is already used. + - If there are duplicate `categories`. + - If any `excluded_categories` are not in `categories`. """ if name in self.stratifications: raise ValueError(f"Stratification name '{name}' is already used.") @@ -210,8 +209,7 @@ def add_stratification( self.stratifications[name] = Stratification( name=name, - requires_columns=requires_columns, - requires_values=requires_values, + requires_attributes=requires_attributes, categories=categories, excluded_categories=to_exclude, mapper=mapper, @@ -222,14 +220,13 @@ def register_observation( self, observation_type: type[Observation], name: str, - pop_filter: str, + population_filter: PopulationFilter, when: str, - requires_columns: list[str], - requires_values: list[Pipeline], + requires_attributes: list[str], stratifications: tuple[str, ...] | None, **kwargs: Any, ) -> Observation: - """Add an observation to the results context. + """Adds an observation to the results context. Parameters ---------- @@ -238,9 +235,11 @@ def register_observation( name Name of the observation. It will also be the name of the output results file for this particular observation. 
- pop_filter - A Pandas query filter string to filter the population down to the simulants who should - be considered for the observation. + population_filter + A named tuple of population filtering details. The first item is a Pandas + query string to filter the population down to the simulants who should be + considered for the observation. The second item is a boolean indicating whether + to include untracked simulants from the observation. when Name of the lifecycle state the observation should happen. Valid values are: "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". @@ -261,18 +260,15 @@ def register_observation( f"Observation name '{name}' is already used: {self.observations[name]}." ) - # Instantiate the observation and add it and its (pop_filter, stratifications) - # tuple as a key-value pair to the self.observations[when] dictionary. observation = observation_type( name=name, - pop_filter=pop_filter, + population_filter=population_filter, when=when, - requires_columns=requires_columns, - requires_values=requires_values, + requires_attributes=requires_attributes, **kwargs, ) self.observations[name] = observation - self.grouped_observations[observation.when][observation.pop_filter][ + self.grouped_observations[observation.when][observation.population_filter][ stratifications ].append(observation) return observation @@ -287,12 +283,10 @@ def gather_results( None, None, ]: - """Generate and yield current results for all observations at this lifecycle - state and event. + """Generates and yields current results for all observations at this lifecycle state and event. - Each set of results are stratified and grouped by - all registered stratifications as well as filtered by their respective - observation's pop_filter. + Each set of results are stratified and grouped by all registered stratifications + as well as filtered by their respective observation's pop_filter. 
Parameters ---------- @@ -316,11 +310,12 @@ def gather_results( If a stratification's temporary column name already exists in the population DataFrame. """ - # Optimization: We store all the producers by pop_filter and stratifications + # Optimization: We store all the producers by population_filter and stratifications # so that we only have to apply them once each time we compute results. - for pop_filter, stratification_observations in self.grouped_observations[ - lifecycle_state - ].items(): + for ( + population_filter, + stratification_observations, + ) in self.grouped_observations[lifecycle_state].items(): event_pop_filter_observations = [ observation for observations in stratification_observations.values() @@ -330,7 +325,7 @@ def gather_results( if not event_pop_filter_observations: continue - filtered_population = self._filter_population(population, pop_filter) + filtered_population = self._filter_population(population, population_filter) if filtered_population.empty: continue @@ -353,7 +348,7 @@ def gather_results( yield (results, observation.name, observation.results_updater) def get_observations(self, event: Event) -> list[Observation]: - """Get all observations for a given event. + """Gets all observations for a given event. Parameters ---------- @@ -374,7 +369,7 @@ def get_observations(self, event: Event) -> list[Observation]: ] def get_stratifications(self, observations: list[Observation]) -> list[Stratification]: - """Get all stratifications for a given set of observations. + """Gets all stratifications for a given set of observations. Parameters ---------- @@ -395,10 +390,10 @@ def get_stratifications(self, observations: list[Observation]) -> list[Stratific }.values() ) - def get_required_columns( + def get_required_attributes( self, observations: list[Observation], stratifications: list[Stratification] ) -> list[str]: - """Get all columns required for producing results for a given Event. 
+ """Gets all population attributes required for producing results for a given Event. Parameters ---------- @@ -412,44 +407,32 @@ def get_required_columns( Returns ------- - A list of all columns required for producing results for the given Event. + All population attributes required for producing results for the given Event. """ - required_columns = {"tracked"} + required_attributes = set() for observation in observations: - required_columns.update(observation.requires_columns) - for stratification in stratifications: - required_columns.update(stratification.requires_columns) - return list(required_columns) - - def get_required_values( - self, observations: list[Observation], stratifications: list[Stratification] - ) -> list[Pipeline]: - """Get all values required for producing results for a given Event. - - Parameters - ---------- - observations - List of observations to be gathered for this specific event. Note that this - excludes all observations whose `to_observe` method returns False. - stratifications - List of stratifications to be gathered for this specific event. This only - includes stratifications which are needed by the observations which will be - made during this `Event`. - - Returns - ------- - A list of all values required for producing results for the given Event. 
- """ - required_values = set() - for observation in observations: - required_values.update(observation.requires_values) + required_attributes.update(set(observation.requires_attributes)) + required_attributes.update( + pop_utils.extract_columns_from_query(self.get_tracked_query()) + if not observation.population_filter.include_untracked + else set() + ) + required_attributes.update( + pop_utils.extract_columns_from_query(observation.population_filter.query) + ) for stratification in stratifications: - required_values.update(stratification.requires_values) - return list(required_values) + required_attributes.update(stratification.requires_attributes) + return list(required_attributes) - def _filter_population(self, population: pd.DataFrame, pop_filter: str) -> pd.DataFrame: + def _filter_population( + self, population: pd.DataFrame, population_filter: PopulationFilter + ) -> pd.DataFrame: """Filter out simulants not to observe.""" - return population.query(pop_filter) if pop_filter else population.copy() + query = population_filter.query + if not population_filter.include_untracked: + # combine the tracking query with the population filter query + query = pop_utils.combine_queries(query, self.get_tracked_query()) + return population.query(query) if query else population.copy() def _drop_na_stratifications( self, population: pd.DataFrame, stratification_names: tuple[str, ...] | None @@ -470,7 +453,7 @@ def _drop_na_stratifications( def _get_groups( stratifications: tuple[str, ...], filtered_pop: pd.DataFrame ) -> DataFrameGroupBy[tuple[str, ...] | str, bool]: - """Group the population by stratification. + """Groups the population by stratification. 
Notes ----- diff --git a/src/vivarium/framework/results/interface.py b/src/vivarium/framework/results/interface.py index cc5a9277d..2b1c8bc71 100644 --- a/src/vivarium/framework/results/interface.py +++ b/src/vivarium/framework/results/interface.py @@ -9,7 +9,7 @@ from __future__ import annotations from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Sequence, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Sequence, Union import pandas as pd from pandas.core.groupby.generic import DataFrameGroupBy @@ -30,14 +30,16 @@ ResultsUpdater = Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] -"""This is a Callable that takes existing results and new observations and returns updated results.""" +"""A Callable that takes existing results and new observations and returns updated results.""" ResultsFormatter = Callable[[str, pd.DataFrame], pd.DataFrame] -"""This is a Callable that takes a measure as a string and a DataFrame of observation results and returns formatted results.""" +"""A Callable that takes a measure as a string and a DataFrame of observation results +and returns formatted results.""" ResultsGathererInput = Union[ pd.DataFrame, DataFrameGroupBy, tuple[str, ...], None # type: ignore [type-arg] ] ResultsGatherer = Callable[[ResultsGathererInput], pd.DataFrame] -"""This is a Callable that optionally takes a possibly stratified population and returns new observation results.""" +"""A Callable that optionally takes a possibly stratified population and returns +new observation results.""" def _required_function_placeholder( @@ -46,35 +48,41 @@ def _required_function_placeholder( | tuple[str, pd.DataFrame], **kwargs: Any, ) -> pd.DataFrame: - """Placeholder function to indicate that a required function is missing.""" + """Returns and empty dataframe. + + Placeholder function to indicate that a required function is missing. 
+ """ return pd.DataFrame() def _default_stratified_observation_formatter( measure: str, results: pd.DataFrame ) -> pd.DataFrame: - """Default formatter for stratified observations.""" + """Resets the results index.""" return results.reset_index() def _default_unstratified_observation_formatter( measure: str, results: pd.DataFrame ) -> pd.DataFrame: - """Default formatter for unstratified observations.""" + """Returns the results unchanged.""" return results +class PopulationFilter(NamedTuple): + """Container class for population query string and include_untracked flag.""" + + query: str = "" + include_untracked: bool = False + + class ResultsInterface(Interface): """Builder interface for the results management system. The results management system allows users to delegate results production to the simulation framework. This process attempts to roughly mimic the groupby-apply logic commonly done when manipulating :mod:`pandas` - DataFrames. The representation of state in the simulation is complex, - however, as it includes information both in the population state table - and dynamically generated information available from the - :class:`value pipelines `. - Additionally, good encapsulation of simulation logic typically has + DataFrames. Good encapsulation of simulation logic typically has results production separated from the modeling code into specialized `Observer` components. This often highlights the need for transformations of the simulation state into representations that aren't needed for @@ -105,8 +113,7 @@ def register_stratification( excluded_categories: list[str] | None = None, mapper: VectorMapper | ScalarMapper | None = None, is_vectorized: bool = False, - requires_columns: list[str] = [], - requires_values: list[str] = [], + requires_attributes: list[str] = [], ) -> None: """Registers a stratification that can be used by stratified observations. 
@@ -120,20 +127,16 @@ def register_stratification( List of possible stratification values to exclude from results processing. If None (the default), will use exclusions as defined in the configuration. mapper - A callable that maps the columns and value pipelines specified by the - `requires_columns` and `requires_values` arguments to the stratification - categories. It can either map the entire population or an individual - simulant. A simulation will fail if the `mapper` ever produces an invalid - value. + A callable that maps the population attributes specified by the + `requires_attributes` argumnt to the stratification categories. It can + either map the entire population or an individual simulant. A simulation + will fail if the `mapper` ever produces an invalid value. is_vectorized True if the `mapper` function will map the entire population, and False if it will only map a single simulant. - requires_columns - A list of the state table columns that are required by the `mapper` - to produce the stratification. - requires_values - A list of the value pipelines that are required by the `mapper` to - produce the stratification. + requires_attributes + The population attributes that are required by the `mapper` to produce + the stratification. """ self._manager.register_stratification( name, @@ -141,8 +144,7 @@ def register_stratification( excluded_categories, mapper, is_vectorized, - requires_columns, - requires_values, + requires_attributes, ) def register_binned_stratification( @@ -152,7 +154,6 @@ def register_binned_stratification( bin_edges: Sequence[int | float] = [], labels: list[str] = [], excluded_categories: list[str] | None = None, - target_type: str = "column", **cut_kwargs: int | str | bool, ) -> None: """Registers a binned stratification that can be used by stratified observations. @@ -160,7 +161,7 @@ def register_binned_stratification( Parameters ---------- target - Name of the state table column or value pipeline to be binned. 
+ Name of the population attribute to be binned. binned_column Name of the (binned) stratification. bin_edges @@ -172,9 +173,6 @@ def register_binned_stratification( excluded_categories List of possible stratification values to exclude from results processing. If None (the default), will use exclusions as defined in the configuration. - target_type - Type specification of the `target` to be binned. "column" if it's a - state table column or "value" if it's a value pipeline. **cut_kwargs Keyword arguments for :meth: pandas.cut. """ @@ -184,7 +182,6 @@ def register_binned_stratification( bin_edges, labels, excluded_categories, - target_type, **cut_kwargs, ) @@ -195,10 +192,10 @@ def register_binned_stratification( def register_stratified_observation( self, name: str, - pop_filter: str = "tracked==True", + pop_filter: str = "", + include_untracked: bool = False, when: str = lifecycle_states.COLLECT_METRICS, - requires_columns: list[str] = [], - requires_values: list[str] = [], + requires_attributes: list[str] = [], results_updater: ResultsUpdater = _required_function_placeholder, results_formatter: ResultsFormatter = _default_stratified_observation_formatter, additional_stratifications: list[str] = [], @@ -217,13 +214,13 @@ def register_stratified_observation( pop_filter A Pandas query filter string to filter the population down to the simulants who should be considered for the observation. + include_untracked + Whether to include simulants who are untracked from this observation. when Name of the lifecycle phase the observation should happen. Valid values are: "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". - requires_columns - List of the state table columns that are required by either the `pop_filter` or the `aggregator`. - requires_values - List of the value pipelines that are required by either the `pop_filter` or the `aggregator`. + requires_attributes + The population attributes that are required by the `aggregator`. 
results_updater Function that updates existing raw observation results with newly gathered results. results_formatter @@ -250,10 +247,9 @@ def register_stratified_observation( self._manager.register_observation( observation_type=StratifiedObservation, name=name, - pop_filter=pop_filter, + population_filter=PopulationFilter(pop_filter, include_untracked), when=when, - requires_columns=requires_columns, - requires_values=requires_values, + requires_attributes=requires_attributes, results_updater=results_updater, results_formatter=results_formatter, additional_stratifications=additional_stratifications, @@ -266,10 +262,10 @@ def register_stratified_observation( def register_unstratified_observation( self, name: str, - pop_filter: str = "tracked==True", + pop_filter: str = "", + include_untracked: bool = False, when: str = lifecycle_states.COLLECT_METRICS, - requires_columns: list[str] = [], - requires_values: list[str] = [], + requires_attributes: list[str] = [], results_gatherer: ResultsGatherer = _required_function_placeholder, results_updater: ResultsUpdater = _required_function_placeholder, results_formatter: ResultsFormatter = _default_unstratified_observation_formatter, @@ -285,15 +281,13 @@ def register_unstratified_observation( pop_filter A Pandas query filter string to filter the population down to the simulants who should be considered for the observation. + include_untracked + Whether to include simulants who are untracked from this observation. when Name of the lifecycle phase the observation should happen. Valid values are: "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". - requires_columns - List of the state table columns that are required by either the `pop_filter` or the - `results_gatherer`. - requires_values - List of the value pipelines that are required by either the `pop_filter` or the - `results_gatherer`. + requires_attributes + The population attributes that are required by the `results_gatherer`. 
results_gatherer Function that gathers the latest observation results. results_updater @@ -316,10 +310,9 @@ def register_unstratified_observation( self._manager.register_observation( observation_type=UnstratifiedObservation, name=name, - pop_filter=pop_filter, + population_filter=PopulationFilter(pop_filter, include_untracked), when=when, - requires_columns=requires_columns, - requires_values=requires_values, + requires_attributes=requires_attributes, results_updater=results_updater, results_gatherer=results_gatherer, results_formatter=results_formatter, @@ -329,10 +322,10 @@ def register_unstratified_observation( def register_adding_observation( self, name: str, - pop_filter: str = "tracked==True", + pop_filter: str = "", + include_untracked: bool = False, when: str = lifecycle_states.COLLECT_METRICS, - requires_columns: list[str] = [], - requires_values: list[str] = [], + requires_attributes: list[str] = [], results_formatter: ResultsFormatter = _default_stratified_observation_formatter, additional_stratifications: list[str] = [], excluded_stratifications: list[str] = [], @@ -357,13 +350,13 @@ def register_adding_observation( pop_filter A Pandas query filter string to filter the population down to the simulants who should be considered for the observation. + include_untracked + Whether to include simulants who are untracked from this observation. when Name of the lifecycle phase the observation should happen. Valid values are: "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". - requires_columns - List of the state table columns that are required by either the `pop_filter` or the `aggregator`. - requires_values - List of the value pipelines that are required by either the `pop_filter` or the `aggregator`. + requires_attributes + The population attributes that are required by the `aggregator`. results_formatter Function that formats the raw observation results. 
additional_stratifications @@ -382,10 +375,9 @@ def register_adding_observation( self._manager.register_observation( observation_type=AddingObservation, name=name, - pop_filter=pop_filter, + population_filter=PopulationFilter(pop_filter, include_untracked), when=when, - requires_columns=requires_columns, - requires_values=requires_values, + requires_attributes=requires_attributes, results_formatter=results_formatter, additional_stratifications=additional_stratifications, excluded_stratifications=excluded_stratifications, @@ -397,10 +389,10 @@ def register_adding_observation( def register_concatenating_observation( self, name: str, - pop_filter: str = "tracked==True", + pop_filter: str = "", + include_untracked: bool = False, when: str = lifecycle_states.COLLECT_METRICS, - requires_columns: list[str] = [], - requires_values: list[str] = [], + requires_attributes: list[str] = [], results_formatter: ResultsFormatter = _default_unstratified_observation_formatter, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: @@ -421,13 +413,13 @@ def register_concatenating_observation( pop_filter A Pandas query filter string to filter the population down to the simulants who should be considered for the observation. + include_untracked + Whether to include simulants who are untracked from this observation. when Name of the lifecycle phase the observation should happen. Valid values are: "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". - requires_columns - List of the state table columns that are required by either the `pop_filter` or the `aggregator`. - requires_values - List of the value pipelines that are required by either the `pop_filter` or the `aggregator`. + requires_attributes + The population attributes that are required by the `aggregator`. results_formatter Function that formats the raw observation results. 
to_observe @@ -436,10 +428,9 @@ def register_concatenating_observation( self._manager.register_observation( observation_type=ConcatenatingObservation, name=name, - pop_filter=pop_filter, + population_filter=PopulationFilter(pop_filter, include_untracked), when=when, - requires_columns=requires_columns, - requires_values=requires_values, + requires_attributes=requires_attributes, results_formatter=results_formatter, to_observe=to_observe, ) diff --git a/src/vivarium/framework/results/manager.py b/src/vivarium/framework/results/manager.py index 6127706c4..3ec61c6ee 100644 --- a/src/vivarium/framework/results/manager.py +++ b/src/vivarium/framework/results/manager.py @@ -1,15 +1,14 @@ """ -====================== -Results System Manager -====================== +=============== +Results Manager +=============== """ from __future__ import annotations from collections import defaultdict -from enum import Enum -from typing import TYPE_CHECKING, Any, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Sequence import pandas as pd @@ -18,17 +17,12 @@ from vivarium.framework.results.context import ResultsContext from vivarium.framework.results.observation import Observation from vivarium.framework.results.stratification import Stratification, get_mapped_col_name -from vivarium.framework.values import Pipeline from vivarium.manager import Manager from vivarium.types import ScalarMapper, VectorMapper if TYPE_CHECKING: from vivarium.framework.engine import Builder - - -class SourceType(Enum): - COLUMN = 0 - VALUE = 1 + from vivarium.framework.results.interface import PopulationFilter class ResultsManager(Manager): @@ -57,15 +51,7 @@ def name(self) -> str: return self._name def get_results(self) -> dict[str, pd.DataFrame]: - """Return the measure-specific formatted results in a dictionary. - - Notes - ----- - self._results_context.observations is a list where each item is a dictionary - of the form {lifecycle_phase: {(pop_filter, stratification_names): List[Observation]}}. 
- We use a triple-nested for loop to iterate over only the list of Observations - (i.e. we do not need the lifecycle_phase, pop_filter, or stratification_names - for this method). + """Gets the measure-specific formatted results in a dictionary. Returns ------- @@ -80,11 +66,11 @@ def get_results(self) -> dict[str, pd.DataFrame]: # noinspection PyAttributeOutsideInit def setup(self, builder: "Builder") -> None: - """Set up the results manager.""" + """Sets up the results manager.""" self._results_context.setup(builder) self.logger = builder.logging.get_logger(self.name) - self.population_view = builder.population.get_view([]) + self.population_view = builder.population.get_view() self.clock = builder.time.clock() self.step_size = builder.time.step_size() @@ -100,8 +86,6 @@ def setup(self, builder: "Builder") -> None: lifecycle_states.COLLECT_METRICS, self.on_collect_metrics ) - self.get_value = builder.value.get_value - self.set_default_stratifications(builder) def on_post_setup(self, _: Event) -> None: @@ -111,30 +95,29 @@ def on_post_setup(self, _: Event) -> None: self._raw_results[name] = observation.results_initializer() def on_time_step_prepare(self, event: Event) -> None: - """Define the listener callable for the time_step__prepare phase.""" + """Defines the listener callable for the time_step__prepare phase.""" self.gather_results(event) def on_time_step(self, event: Event) -> None: - """Define the listener callable for the time_step phase.""" + """Defines the listener callable for the time_step phase.""" self.gather_results(event) def on_time_step_cleanup(self, event: Event) -> None: - """Define the listener callable for the time_step__cleanup phase.""" + """Defines the listener callable for the time_step__cleanup phase.""" self.gather_results(event) def on_collect_metrics(self, event: Event) -> None: - """Define the listener callable for the collect_metrics phase.""" + """Defines the listener callable for the collect_metrics phase.""" 
self.gather_results(event) def gather_results(self, event: Event) -> None: - """Update existing results with any new results.""" + """Updates existing results with any new results.""" observations = self._results_context.get_observations(event) stratifications = self._results_context.get_stratifications(observations) if not observations or event.index.empty: return population = self._prepare_population(event, observations, stratifications) - for results_group, measure, updater in self._results_context.gather_results( population, event.name, observations ): @@ -145,7 +128,7 @@ def gather_results(self, event: Event) -> None: ########################## def set_default_stratifications(self, builder: "Builder") -> None: - """Set the default stratifications for the results context. + """Sets the default stratifications for the results context. This passes the default stratifications from the configuration to the :class:`ResultsContext ` @@ -166,13 +149,11 @@ def register_stratification( excluded_categories: list[str] | None, mapper: VectorMapper | ScalarMapper | None, is_vectorized: bool, - requires_columns: list[str] = [], - requires_values: list[str] = [], + requires_attributes: list[str] = [], ) -> None: - """Manager-level stratification registration. + """Registers a stratification that can be used by stratified observations. - Adds a stratification to the - :class:`ResultsContext ` + Adds a stratification to the :class:`ResultsContext ` as well as the stratification's required resources to this manager. Parameters @@ -185,26 +166,21 @@ def register_stratification( List of possible stratification values to exclude from results processing. If None (the default), will use exclusions as defined in the configuration. mapper - A callable that maps the columns and value pipelines specified by the - `requires_columns` and `requires_values` arguments to the stratification - categories. It can either map the entire population or an individual - simulant. 
A simulation will fail if the `mapper` ever produces an invalid - value. + A callable that maps population attributes specified by the + `requires_attributes` argument to the stratification categories. It can + either map the entire population or an individual simulant. A simulation + will fail if the `mapper` ever produces an invalid value. is_vectorized True if the `mapper` function will map the entire population, and False if it will only map a single simulant. - requires_columns + requires_attributes A list of the state table columns that are required by the `mapper` to produce the stratification. - requires_values - A list of the value pipelines that are required by the `mapper` to - produce the stratification. """ self.logger.debug(f"Registering stratification {name}") self._results_context.add_stratification( name=name, - requires_columns=requires_columns, - requires_values=[self.get_value(value) for value in requires_values], + requires_attributes=requires_attributes, categories=categories, excluded_categories=excluded_categories, mapper=mapper, @@ -218,16 +194,14 @@ def register_binned_stratification( bin_edges: Sequence[int | float], labels: list[str], excluded_categories: list[str] | None, - target_type: str, **cut_kwargs: int | str | bool, ) -> None: - """Manager-level registration of a continuous `target` quantity to observe - into bins in a `binned_column`. + """Registers a continuous `target` quantity to observe into bins in a `binned_column`. Parameters ---------- target - Name of the state table column or value pipeline to be binned. + Name of population attribute to be binned. binned_column Name of the (binned) stratification. bin_edges @@ -239,9 +213,6 @@ def register_binned_stratification( excluded_categories List of possible stratification values to exclude from results processing. If None (the default), will use exclusions as defined in the configuration. - target_type - Type specification of the `target` to be binned. 
"column" if it's a - state table column or "value" if it's a value pipeline. **cut_kwargs Keyword arguments for :meth: pandas.cut. """ @@ -268,32 +239,25 @@ def _bin_data(data: pd.DataFrame) -> pd.Series[Any]: f"match the number of labels ({len(labels)})" ) - target_arg = "requires_columns" if target_type == "column" else "requires_values" - target_kwargs = {target_arg: [target]} - self.register_stratification( name=binned_column, categories=labels, excluded_categories=excluded_categories, mapper=_bin_data, is_vectorized=True, - **target_kwargs, + requires_attributes=[target], ) def register_observation( self, observation_type: type[Observation], name: str, - pop_filter: str, + population_filter: PopulationFilter, when: str, - requires_columns: list[str], - requires_values: list[str], + requires_attributes: list[str], **kwargs: Any, ) -> None: - """Manager-level observation registration. - - Adds an observation to the - :class:`ResultsContext `. + """Registers an observation to the results system. Parameters ---------- @@ -302,28 +266,24 @@ def register_observation( name Name of the observation. It will also be the name of the output results file for this particular observation. - pop_filter - A Pandas query filter string to filter the population down to the simulants who should - be considered for the observation. + population_filter + A named tuple of population filtering details. The first item is a Pandas + query string to filter the population down to the simulants who should be + considered for the observation. The second item is a boolean indicating whether + to include untracked simulants from the observation. when Name of the lifecycle phase the observation should happen. Valid values are: "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". - requires_columns - List of the state table columns that are required to compute the observation. - requires_values - List of the value pipelines that are required to compute the observation. 
+ requires_attributes + The population attributes that are required to compute the observation. **kwargs Additional keyword arguments to be passed to the observation's constructor. """ self.logger.debug(f"Registering observation {name}") - if any(not isinstance(column, str) for column in requires_columns): - raise TypeError( - f"All required columns must be strings, but got {requires_columns} when registering observation {name}." - ) - if any(not isinstance(value, str) for value in requires_values): + if any(not isinstance(attribute, str) for attribute in requires_attributes): raise TypeError( - f"All required values must be strings, but got {requires_values} when registering observation {name}." + f"All required attributes must be strings, but got {requires_attributes} when registering observation {name}." ) if observation_type.is_stratified(): @@ -341,10 +301,9 @@ def register_observation( self._results_context.register_observation( observation_type=observation_type, name=name, - pop_filter=pop_filter, + population_filter=population_filter, when=when, - requires_columns=requires_columns, - requires_values=[self.get_value(value) for value in requires_values], + requires_attributes=requires_attributes, stratifications=stratifications, **kwargs, ) @@ -359,7 +318,7 @@ def _get_stratifications( additional_stratifications: list[str] = [], excluded_stratifications: list[str] = [], ) -> tuple[str, ...]: - """Resolve the stratifications required for the observation.""" + """Resolves the stratifications required for the observation.""" self._warn_check_stratifications(additional_stratifications, excluded_stratifications) stratifications = list( @@ -379,39 +338,43 @@ def _prepare_population( observations: list[Observation], stratifications: list[Stratification], ) -> pd.DataFrame: - """Prepare the population for results gathering.""" - required_columns = self._results_context.get_required_columns( - observations, stratifications - ) - required_values = 
self._results_context.get_required_values( + """Prepares the population for results gathering.""" + required_attributes = self._results_context.get_required_attributes( observations, stratifications ) - required_columns = required_columns.copy() - population = pd.DataFrame(index=event.index) - if "current_time" in required_columns: + attributes_to_get = [ + attribute + for attribute in required_attributes + if attribute + not in ["current_time", "event_step_size", "event_time"] + + list(event.user_data.keys()) + ] + if attributes_to_get: + # FIXME: (Inefficiency) In the event every single observation has some identical + # query string (e.g. 'is_alive == True'), we still calculate all attributes for + # the entire population and then apply the query downstream. + population = self.population_view.get_attributes( + event.index, + attributes_to_get, + include_untracked=any( + obs.population_filter.include_untracked for obs in observations + ), + ) + else: + population = pd.DataFrame(index=event.index) + + if "current_time" in required_attributes: population["current_time"] = self.clock() - required_columns.remove("current_time") - if "event_step_size" in required_columns: + if "event_step_size" in required_attributes: population["event_step_size"] = event.step_size - required_columns.remove("event_step_size") - if "event_time" in required_columns: + if "event_time" in required_attributes: population["event_time"] = self.clock() + event.step_size # type: ignore [operator] - required_columns.remove("event_time") - for k, v in event.user_data.items(): - if k in required_columns: - population[k] = v - required_columns.remove(k) + for key, val in event.user_data.items(): + if key in required_attributes: + population[key] = val - for pipeline in required_values: - population[pipeline.name] = pipeline(event.index) - - if required_columns: - population = pd.concat( - [self.population_view.subview(required_columns).get(event.index), population], - axis=1, - ) for 
stratification in stratifications: new_column = get_mapped_col_name(stratification.name) if new_column in population.columns: @@ -426,7 +389,7 @@ def _prepare_population( def _warn_check_stratifications( self, additional_stratifications: list[str], excluded_stratifications: list[str] ) -> None: - """Check additional and excluded stratifications if they'd not affect + """Checks additional and excluded stratifications if they'd not affect stratifications (i.e., would be NOP), and emit warning.""" nop_additional = [ s diff --git a/src/vivarium/framework/results/observation.py b/src/vivarium/framework/results/observation.py index f5b477864..0d308cf87 100644 --- a/src/vivarium/framework/results/observation.py +++ b/src/vivarium/framework/results/observation.py @@ -24,6 +24,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass +from typing import TYPE_CHECKING import pandas as pd from pandas.api.types import CategoricalDtype @@ -32,7 +33,9 @@ from vivarium.exceptions import VivariumError from vivarium.framework.event import Event from vivarium.framework.results.stratification import Stratification, get_original_col_name -from vivarium.framework.values import Pipeline + +if TYPE_CHECKING: + from vivarium.framework.results.interface import PopulationFilter VALUE_COLUMN = "value" @@ -49,16 +52,16 @@ class Observation(ABC): name: str """Name of the observation. It will also be the name of the output results file for this particular observation.""" - pop_filter: str - """A Pandas query filter string to filter the population down to the simulants - who should be considered for the observation.""" + population_filter: PopulationFilter + """A named tuple of population filtering details. The first item is a Pandas + query string to filter the population down to the simulants who should be + considered for the observation. 
The second item is a boolean indicating whether + to include untracked simulants from the observation.""" when: str """Name of the lifecycle phase the observation should happen. Valid values are: "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics".""" - requires_columns: list[str] - """List of columns required for this observation.""" - requires_values: list[Pipeline] - """List of values required for this observation.""" + requires_attributes: list[str] + """The population attributes required for this observation.""" results_initializer: Callable[[], pd.DataFrame] """Method or function that initializes the raw observation results prior to starting the simulation. This could return, for example, an empty @@ -87,7 +90,7 @@ def observe( df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool], stratifications: tuple[str, ...] | None, ) -> pd.DataFrame: - """Gather the results of the observation. + """Gathers the results of the observation. Parameters ---------- @@ -119,16 +122,16 @@ class UnstratifiedObservation(Observation): name Name of the observation. It will also be the name of the output results file for this particular observation. - pop_filter - A Pandas query filter string to filter the population down to the simulants who should - be considered for the observation. + population_filter + A named tuple of population filtering details. The first item is a Pandas + query string to filter the population down to the simulants who should be + considered for the observation. The second item is a boolean indicating whether + to include untracked simulants from the observation. when Name of the lifecycle phase the observation should happen. Valid values are: "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". - requires_columns - List of columns required for this observation. - requires_values - List of values required for this observation. 
+ requires_attributes + The population attributes required for this observation. results_gatherer Method or function that gathers the new observation results. results_updater @@ -143,10 +146,9 @@ class UnstratifiedObservation(Observation): def __init__( self, name: str, - pop_filter: str, + population_filter: PopulationFilter, when: str, - requires_columns: list[str], - requires_values: list[Pipeline], + requires_attributes: list[str], results_gatherer: Callable[[pd.DataFrame], pd.DataFrame], results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame], results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], @@ -165,10 +167,9 @@ def _wrap_results_gatherer( super().__init__( name=name, - pop_filter=pop_filter, + population_filter=population_filter, when=when, - requires_columns=requires_columns, - requires_values=requires_values, + requires_attributes=requires_attributes, results_initializer=self.create_empty_df, results_gatherer=_wrap_results_gatherer, results_updater=results_updater, @@ -182,7 +183,7 @@ def is_stratified(cls) -> bool: @staticmethod def create_empty_df() -> pd.DataFrame: - """Initialize an empty dataframe. + """Initializes an empty dataframe. Returns ------- @@ -203,16 +204,16 @@ class StratifiedObservation(Observation): name Name of the observation. It will also be the name of the output results file for this particular observation. - pop_filter - A Pandas query filter string to filter the population down to the simulants who should - be considered for the observation. + population_filter + A named tuple of population filtering details. The first item is a Pandas + query string to filter the population down to the simulants who should be + considered for the observation. The second item is a boolean indicating whether + to include untracked simulants from the observation. when Name of the lifecycle phase the observation should happen. Valid values are: "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". 
- requires_columns - List of columns required for this observation. - requires_values - List of values required for this observation. + requires_attributes + The population attributes required for this observation. results_updater Method or function that updates existing raw observation results with newly gathered results. results_formatter @@ -229,10 +230,9 @@ class StratifiedObservation(Observation): def __init__( self, name: str, - pop_filter: str, + population_filter: PopulationFilter, when: str, - requires_columns: list[str], - requires_values: list[Pipeline], + requires_attributes: list[str], results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame], results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], aggregator_sources: list[str] | None, @@ -241,10 +241,9 @@ def __init__( ): super().__init__( name=name, - pop_filter=pop_filter, + population_filter=population_filter, when=when, - requires_columns=requires_columns, - requires_values=requires_values, + requires_attributes=requires_attributes, results_initializer=self.create_expanded_df, results_gatherer=self.get_complete_stratified_results, # type: ignore [arg-type] results_updater=results_updater, @@ -263,7 +262,7 @@ def observe( df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool], stratifications: tuple[str, ...] | None, ) -> pd.DataFrame: - """Gather the results of the observation. + """Gathers the results of the observation. Parameters ---------- @@ -291,7 +290,7 @@ def _rename_stratification_columns(self, results: pd.DataFrame) -> None: results.index.rename(get_original_col_name(idx_name), inplace=True) def create_expanded_df(self) -> pd.DataFrame: - """Initialize a dataframe of 0s with complete set of stratifications as the index. + """Initializes a dataframe of 0s with complete set of stratifications as the index. 
Returns ------- @@ -337,7 +336,7 @@ get_complete_stratified_results( pop_groups: DataFrameGroupBy[str, bool], stratifications: tuple[str, ...], ) -> pd.DataFrame: - """Gather results for this observation. + """Gathers results for this observation. Parameters ---------- @@ -363,9 +362,7 @@ def _aggregate( aggregator_sources: list[str] | None, aggregator: Callable[[pd.DataFrame], float | pd.Series[float]], ) -> pd.Series[float] | pd.DataFrame: - """Apply the `aggregator` to the population groups and their - `aggregator_sources` columns. - """ + """Applies the provided aggregator to the population groups.""" aggregates = ( pop_groups[aggregator_sources].apply(aggregator).fillna(0.0) # type: ignore [arg-type] if aggregator_sources @@ -375,9 +372,7 @@ def _aggregate( @staticmethod def _format(aggregates: pd.Series[float] | pd.DataFrame) -> pd.DataFrame: - """Convert the results to a pandas DataFrame if necessary and ensure the - results column name is 'value'. - """ + """Converts the results to a dataframe and ensures the results column name is 'value'.""" df = pd.DataFrame(aggregates) if isinstance(aggregates, pd.Series) else aggregates if df.shape[1] == 1: df.rename(columns={df.columns[0]: "value"}, inplace=True) @@ -385,7 +380,7 @@ def _format(aggregates: pd.Series[float] | pd.DataFrame) -> pd.DataFrame: @staticmethod def _expand_index(aggregates: pd.DataFrame) -> pd.DataFrame: - """Include all stratifications in the results by filling missing values with 0.""" + """Includes all stratifications in the results by filling missing values with 0.""" full_idx = ( pd.MultiIndex.from_product(aggregates.index.levels) if isinstance(aggregates.index, pd.MultiIndex) @@ -406,16 +401,16 @@ class AddingObservation(StratifiedObservation): name Name of the observation. It will also be the name of the output results file for this particular observation. 
- pop_filter - A Pandas query filter string to filter the population down to the simulants who should - be considered for the observation. + population_filter + A named tuple of population filtering details. The first item is a Pandas + query string to filter the population down to the simulants who should be + considered for the observation. The second item is a boolean indicating whether + to include untracked simulants from the observation. when Name of the lifecycle phase the observation should happen. Valid values are: "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". - requires_columns - List of columns required for this observation. - requires_values - List of values required for this observation. + requires_attributes + The population attributes required for this observation. results_formatter Method or function that formats the raw observation results. stratifications @@ -433,10 +428,9 @@ class AddingObservation(StratifiedObservation): def __init__( self, name: str, - pop_filter: str, + population_filter: PopulationFilter, when: str, - requires_columns: list[str], - requires_values: list[Pipeline], + requires_attributes: list[str], results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], aggregator_sources: list[str] | None, aggregator: Callable[[pd.DataFrame], float | pd.Series[float]], @@ -444,10 +438,9 @@ def __init__( ): super().__init__( name=name, - pop_filter=pop_filter, + population_filter=population_filter, when=when, - requires_columns=requires_columns, - requires_values=requires_values, + requires_attributes=requires_attributes, results_updater=self.add_results, results_formatter=results_formatter, aggregator_sources=aggregator_sources, @@ -459,7 +452,7 @@ def __init__( def add_results( existing_results: pd.DataFrame, new_observations: pd.DataFrame ) -> pd.DataFrame: - """Add newly-observed results to the existing results. + """Adds newly-observed results to the existing results. 
Parameters ---------- @@ -493,25 +486,23 @@ class ConcatenatingObservation(UnstratifiedObservation): """Concrete class for observing concatenating (and by extension, unstratified) results. The parent class `results_gatherer` and `results_updater` methods are explicitly - defined and attribute `included_columns` is added. + defined. Attributes ---------- name Name of the observation. It will also be the name of the output results file for this particular observation. - pop_filter - A Pandas query filter string to filter the population down to the simulants who should - be considered for the observation. + population_filter + A named tuple of population filtering details. The first item is a Pandas + query string to filter the population down to the simulants who should be + considered for the observation. The second item is a boolean indicating whether + to include untracked simulants from the observation. when Name of the lifecycle phase the observation should happen. Valid values are: "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". - requires_columns - List of columns required for this observation. - requires_values - List of values required for this observation. - included_columns - Columns to include in the observation + requires_attributes + The population attributes required for this observation. results_formatter Method or function that formats the raw observation results. 
to_observe @@ -522,38 +513,33 @@ class ConcatenatingObservation(UnstratifiedObservation): def __init__( self, name: str, - pop_filter: str, + population_filter: PopulationFilter, when: str, - requires_columns: list[str], - requires_values: list[Pipeline], + requires_attributes: list[str], results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame], to_observe: Callable[[Event], bool] = lambda event: True, ): - requires_columns = ["event_time"] + requires_columns + requires_attributes = ["event_time"] + requires_attributes super().__init__( name=name, - pop_filter=pop_filter, + population_filter=population_filter, when=when, - requires_columns=requires_columns, - requires_values=requires_values, + requires_attributes=requires_attributes, results_gatherer=self.get_results_of_interest, results_updater=self.concatenate_results, results_formatter=results_formatter, to_observe=to_observe, ) - self.included_columns = self.requires_columns + [ - value.name for value in self.requires_values - ] def get_results_of_interest(self, pop: pd.DataFrame) -> pd.DataFrame: """Return the population with only the `included_columns`.""" - return pop[self.included_columns] + return pop[self.requires_attributes] @staticmethod def concatenate_results( existing_results: pd.DataFrame, new_observations: pd.DataFrame ) -> pd.DataFrame: - """Concatenate the existing results with the new observations. + """Concatenates the existing results with the new observations. 
Parameters ---------- diff --git a/src/vivarium/framework/results/observer.py b/src/vivarium/framework/results/observer.py index 6b659ed33..734217a6b 100644 --- a/src/vivarium/framework/results/observer.py +++ b/src/vivarium/framework/results/observer.py @@ -50,7 +50,7 @@ def configuration_defaults(self) -> dict[str, Any]: } def get_configuration_name(self) -> str: - """Return the name of a concrete observer for use in the configuration""" + """Returns the name of a concrete observer for use in the configuration""" return self.name.split("_observer")[0] def get_configuration(self, builder: Builder) -> LayeredConfigTree: @@ -60,17 +60,17 @@ def get_configuration(self, builder: Builder) -> LayeredConfigTree: @abstractmethod def register_observations(self, builder: Builder) -> None: - """(Required). Register observations with within each observer.""" + """Registers observations within each observer.""" pass def setup_component(self, builder: Builder) -> None: - """Set up the observer component.""" + """Sets up the observer component.""" super().setup_component(builder) self.register_observations(builder) self.set_results_dir(builder) def set_results_dir(self, builder: Builder) -> None: - """Define the results directory from the configuration.""" + """Defines the results directory from the configuration.""" self.results_dir = ( builder.configuration.to_dict() .get("output_data", {}) diff --git a/src/vivarium/framework/results/stratification.py b/src/vivarium/framework/results/stratification.py index 1db0d3174..3c0a970f4 100644 --- a/src/vivarium/framework/results/stratification.py +++ b/src/vivarium/framework/results/stratification.py @@ -12,7 +12,6 @@ import pandas as pd from pandas.api.types import CategoricalDtype -from vivarium.framework.values import Pipeline from vivarium.types import ScalarMapper, VectorMapper STRATIFICATION_COLUMN_SUFFIX: str = "mapped_values" @@ -23,42 +22,42 @@ @dataclass class Stratification: """Class for stratifying observed quantities 
by specified characteristics. + Each Stratification represents a set of mutually exclusive and collectively exhaustive categories into which simulants can be assigned. This class includes a :meth:`stratify ` method that produces an output column by calling the mapper on the source columns. + """ name: str """Name of the stratification.""" - requires_columns: list[str] - """A list of the columns needed as input for the `mapper`.""" - requires_values: list[Pipeline] - """A list of the values needed as input for the `mapper`.""" + requires_attributes: list[str] + """The population attributes needed as input for the `mapper`.""" categories: list[str] """Exhaustive list of all possible stratification values.""" excluded_categories: list[str] """List of possible stratification values to exclude from results processing. If None (the default), will use exclusions as defined in the configuration.""" mapper: VectorMapper | ScalarMapper | None - """A callable that maps the columns and value pipelines specified by the - `requires_columns` and `requires_values` arguments to the stratification - categories. It can either map the entire population or an individual - simulant. A simulation will fail if the `mapper` ever produces an invalid - value.""" + """A callable that maps the population attributes specified by the + `requires_attributes` argument to the stratification categories. It can either + map the entire population or an individual simulant. 
A simulation will fail if + the `mapper` ever produces an invalid value.""" is_vectorized: bool = False """True if the `mapper` function will map the entire population, and False if it will only map a single simulant.""" def __str__(self) -> str: return ( - f"Stratification '{self.name}' with sources {self._sources}, " + f"Stratification '{self.name}' with required attributes {self.requires_attributes}, " f"categories {self.categories}, and mapper {getattr(self.mapper, '__name__', repr(self.mapper))}" ) def __post_init__(self) -> None: - """Assign a default `mapper` if none was provided and check for non-empty - `categories` and `sources` otherwise. + """Assigns a default `mapper` if none was provided and check for non-empty + `categories` and `requires_attributes` otherwise. + Raises ------ ValueError @@ -66,24 +65,21 @@ def __post_init__(self) -> None: ValueError If the categories argument is empty. ValueError - If the sources argument is empty. + If the requires_attributes argument is empty. """ - self._sources = self.requires_columns + [ - pipeline.name for pipeline in self.requires_values - ] + self.vectorized_mapper = self._get_vectorized_mapper(self.mapper, self.is_vectorized) if not self.categories: raise ValueError("The categories argument must be non-empty.") - if not self._sources: - raise ValueError("The sources argument must be non-empty.") + if not self.requires_attributes: + raise ValueError("The requires_attributes argument must be non-empty.") def stratify(self, population: pd.DataFrame) -> pd.Series[CategoricalDtype]: - """Apply the `mapper` to the population `sources` columns to create a new - Series to be added to the population. + """Applies the `mapper` to the population `sources` columns. - Any `excluded_categories` (which have already been removed from `categories`) - will be converted to NaNs in the new column and dropped later at the - observation level. + This creates a new Series to be added to the population. 
Any `excluded_categories` + (which have already been removed from `categories`) will be converted to + NaNs in the new column and dropped later at the observation level. Parameters ---------- @@ -99,7 +95,7 @@ def stratify(self, population: pd.DataFrame) -> pd.Series[CategoricalDtype]: ValueError If the mapper returns any values not in `categories` or `excluded_categories`. """ - mapped_column = self.vectorized_mapper(population[self._sources]) + mapped_column = self.vectorized_mapper(population[self.requires_attributes]) unknown_categories = set(mapped_column) - set( self.categories + self.excluded_categories ) @@ -121,14 +117,12 @@ def _get_vectorized_mapper( user_provided_mapper: VectorMapper | ScalarMapper | None, is_vectorized: bool, ) -> VectorMapper: - """ - Choose a VectorMapper based on the inputted callable mapper. - """ + """Chooses a VectorMapper based on the provided callable mapper.""" if user_provided_mapper is None: - if len(self._sources) != 1: + if len(self.requires_attributes) != 1: raise ValueError( - f"No mapper but {len(self._sources)} stratification sources are " - f"provided for stratification {self.name}. The list of sources " + f"No mapper but {len(self.requires_attributes)} required attributes are " + f"provided for stratification {self.name}. The list of required attributes " "must be of length 1 if no mapper is provided." ) return self._default_mapper @@ -139,7 +133,7 @@ def _get_vectorized_mapper( @staticmethod def _default_mapper(pop: pd.DataFrame) -> pd.Series[Any]: - """Default stratification mapper that squeezes a DataFrame to a Series. + """Squeezes a DataFrame to a Series. 
Parameters ---------- @@ -159,12 +153,12 @@ def _default_mapper(pop: pd.DataFrame) -> pd.Series[Any]: def get_mapped_col_name(col_name: str) -> str: - """Return a new column name to be used for mapped values""" + """Returns a new column name to be used for mapped values""" return f"{col_name}_{STRATIFICATION_COLUMN_SUFFIX}" def get_original_col_name(col_name: str) -> str: - """Return the original column name given a modified mapped column name.""" + """Returns the original column name given a modified mapped column name.""" return ( col_name[: -(len(STRATIFICATION_COLUMN_SUFFIX)) - 1] if col_name.endswith(f"_{STRATIFICATION_COLUMN_SUFFIX}") diff --git a/src/vivarium/framework/state_machine.py b/src/vivarium/framework/state_machine.py index fdc736c64..1079dafb5 100644 --- a/src/vivarium/framework/state_machine.py +++ b/src/vivarium/framework/state_machine.py @@ -22,12 +22,14 @@ from vivarium.framework.engine import Builder from vivarium.framework.event import Event from vivarium.framework.population import PopulationView, SimulantData - from vivarium.framework.resource import Resource from vivarium.types import ClockTime, DataInput, NumericArray def default_probability_function(index: pd.Index[int]) -> pd.Series[float]: - """Transition decision function that always triggers this transition.""" + """Returns a series of ones for the provided index. + + This is the default transition decision function (always triggers this transition). + """ return pd.Series(1.0, index=index) @@ -130,7 +132,7 @@ def __init__( [pd.Index[int]], pd.Series[float] ] = lambda index: pd.Series(1.0, index=index), triggered: Trigger = Trigger.NOT_TRIGGERED, - ): + ) -> None: """Initializes a transition between two states. 
Parameters @@ -221,7 +223,7 @@ def model(self) -> str | None: def __init__( self, state_id: str, - allow_self_transition: bool = False, + allow_self_transition: bool = True, initialization_weights: DataInput = 0.0, ) -> None: super().__init__() @@ -232,6 +234,15 @@ def __init__( self.initialization_weights = initialization_weights self._model: str | None = None self._sub_components = [self.transition_set] + self.initialization_weights_pipeline = f"{self.state_id}.initialization_weights" + + def setup(self, builder: Builder) -> None: + self.initialization_weights_table = self.build_lookup_table( + builder, "initialization_weights" + ) + builder.value.register_attribute_producer( + self.initialization_weights_pipeline, self.initialization_weights_table + ) ################## # Public methods # @@ -239,9 +250,9 @@ def __init__( def has_initialization_weights(self) -> bool: """Determines if state has explicitly defined initialization weights.""" - return ( - not isinstance(self.initialization_weights, (float, int)) - or self.initialization_weights != 0.0 + return not ( + not isinstance(self.initialization_weights_table.data, pd.DataFrame) + and self.initialization_weights_table.data == 0.0 ) def set_model(self, model_name: str) -> None: @@ -330,9 +341,6 @@ def add_transition( self.transition_set.append(transition) return transition - def allow_self_transitions(self) -> None: - self.transition_set.allow_null_transition = True - ################## # Helper methods # ################## @@ -381,11 +389,11 @@ def name(self) -> str: ##################### def __init__( - self, state_id: str, *transitions: Transition, allow_self_transition: bool = False + self, state_id: str, *transitions: Transition, allow_self_transition: bool = True ): super().__init__() self.state_id = state_id - self.allow_null_transition = allow_self_transition + self.allow_self_transition = allow_self_transition self.transitions: list[Transition] = [] self._sub_components = self.transitions @@ -451,7 
+459,7 @@ def extend(self, transitions: Iterable[Transition]) -> None: def _normalize_probabilities( self, outputs: list[State | str], probabilities: NumericArray ) -> tuple[list[State | str], NumericArray]: - """Normalize probabilities to sum to 1 and add a null transition. + """Normalizes probabilities to sum to 1 and add a null transition. Parameters ---------- @@ -485,7 +493,7 @@ def _normalize_probabilities( probabilities[has_default] /= total[has_default, np.newaxis] total = np.sum(probabilities, axis=1) # All totals should be ~<= 1 at this point. - if self.allow_null_transition: + if self.allow_self_transition: if np.any(total > 1 + 1e-08): # Accommodate rounding errors raise ValueError( f"Null transition requested with un-normalized " @@ -534,16 +542,6 @@ class Machine(Component): def sub_components(self) -> Sequence[Component]: return self.states - @property - def columns_created(self) -> list[str]: - return [self.state_column] - - @property - def initialization_requirements( - self, - ) -> list[str | Resource]: - return [self.randomness] - ##################### # Lifecycle methods # ##################### @@ -557,53 +555,54 @@ def __init__( super().__init__() self.states: list[State] = [] self.state_column = state_column + self._initial_state = initial_state + self.initialization_weights_pipelines: list[str] = [] + if states: self.add_states(states) - states_with_initialization_weights = [ - state for state in self.states if state.has_initialization_weights() - ] - if initial_state is not None: if initial_state not in self.states: raise ValueError( f"Initial state '{initial_state}' must be one of the" f" states: {self.states}." ) - if states_with_initialization_weights: - raise ValueError( - "Cannot specify both an initial state and provide" - " initialization weights to states." 
- ) initial_state.initialization_weights = 1.0 - # TODO: [MIC-5403] remove this on_initialize_simulants check once - # VPH's DiseaseModel has a compatible initialization strategy - elif ( - type(self).on_initialize_simulants == Machine.on_initialize_simulants - and not states_with_initialization_weights - ): + def setup(self, builder: Builder) -> None: + self.randomness = builder.randomness.get_stream(self.name) + builder.population.register_initializer( + initializer=self.initialize_state, + columns=self.state_column, + required_resources=[self.randomness, *self.initialization_weights_pipelines], + ) + + def on_post_setup(self, event: Event) -> None: + states_with_initialization_weights = [ + state for state in self.states if state.has_initialization_weights() + ] + if self._initial_state is not None and states_with_initialization_weights != [ + self._initial_state + ]: + raise ValueError( + "Cannot specify both an initial state and provide initialization" + " weights to states." + ) + elif self._initial_state is None and not states_with_initialization_weights: raise ValueError( "Must specify either an initial state or provide" " initialization weights to states." 
) - def setup(self, builder: Builder) -> None: - self.randomness = builder.randomness.get_stream(self.name) - - def on_initialize_simulants(self, pop_data: SimulantData) -> None: + def initialize_state(self, pop_data: SimulantData) -> None: state_ids = [s.state_id for s in self.states] - state_weights = pd.concat( - [ - state.lookup_tables["initialization_weights"](pop_data.index) - for state in self.states - ], - axis=1, - ).to_numpy() + state_weights = self.population_view.get_attributes( + pop_data.index, self.initialization_weights_pipelines + ) initial_states = self.randomness.choice( - pop_data.index, state_ids, state_weights, "initialization" + pop_data.index, state_ids, state_weights.to_numpy(), "initialization" ).rename(self.state_column) self.population_view.update(initial_states) @@ -620,6 +619,9 @@ def on_time_step_cleanup(self, event: Event) -> None: def add_states(self, states: Iterable[State]) -> None: for state in states: self.states.append(state) + self.initialization_weights_pipelines.append( + state.initialization_weights_pipeline + ) state.set_model(self.state_column) def transition(self, index: pd.Index[int], event_time: ClockTime) -> None: @@ -637,7 +639,7 @@ def transition(self, index: pd.Index[int], event_time: ClockTime) -> None: state.next_state( affected.index, event_time, - self.population_view.subview(self.state_column), + self.population_view, ) def cleanup(self, index: pd.Index[int], event_time: ClockTime) -> None: @@ -645,12 +647,14 @@ def cleanup(self, index: pd.Index[int], event_time: ClockTime) -> None: if not affected.empty: state.cleanup_effect(affected.index, event_time) - def _get_state_pops(self, index: pd.Index[int]) -> list[tuple[State, pd.DataFrame]]: - population = self.population_view.get(index) - return [ - (state, population[population[self.state_column] == state.state_id]) - for state in self.states - ] + def _get_state_pops(self, index: pd.Index[int]) -> list[tuple[State, pd.Series[Any]]]: + population = 
self.population_view.get_attributes(index, self.state_column) + if not isinstance(population, pd.Series): + raise TypeError( + "Expected population view to return a pandas Series for" + f" state column '{self.state_column}', but got: {type(population)}" + ) + return [(state, population[population == state.state_id]) for state in self.states] ################## # Helper methods # diff --git a/src/vivarium/framework/time/__init__.py b/src/vivarium/framework/time/__init__.py new file mode 100644 index 000000000..52a296fc6 --- /dev/null +++ b/src/vivarium/framework/time/__init__.py @@ -0,0 +1,7 @@ +from vivarium.framework.time.interface import TimeInterface +from vivarium.framework.time.manager import ( + DateTimeClock, + SimpleClock, + SimulationClock, + get_time_stamp, +) diff --git a/src/vivarium/framework/time/interface.py b/src/vivarium/framework/time/interface.py new file mode 100644 index 000000000..cf88a4dc0 --- /dev/null +++ b/src/vivarium/framework/time/interface.py @@ -0,0 +1,75 @@ +""" +============== +Time Interface +============== + +This module provides an interface to the various types of +:class:`simulation clocks ` for +use in ``vivarium``. + +For more information about time in the simulation, see the associated +:ref:`concept note `. 
+ +""" + +from __future__ import annotations + +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING + +import pandas as pd + +from vivarium.types import ClockStepSize, ClockTime + +if TYPE_CHECKING: + from vivarium.framework.resource import Resource + from vivarium.framework.time.manager import SimulationClock + +from vivarium.manager import Interface + + +class TimeInterface(Interface): + """Public interface for the simulation time management system.""" + + def __init__(self, manager: SimulationClock) -> None: + self._manager = manager + + def clock(self) -> Callable[[], ClockTime]: + """Gets a callable that returns the current simulation time.""" + return lambda: self._manager.time + + def step_size(self) -> Callable[[], ClockStepSize]: + """Gets a callable that returns the current simulation step size.""" + return lambda: self._manager.step_size + + def simulant_next_event_times(self) -> Callable[[pd.Index[int]], pd.Series[ClockTime]]: + """Gets a callable that returns the next event times for simulants.""" + return self._manager.simulant_next_event_times + + def simulant_step_sizes(self) -> Callable[[pd.Index[int]], pd.Series[ClockStepSize]]: + """Gets a callable that returns the simulant step sizes.""" + return self._manager.simulant_step_sizes + + def move_simulants_to_end(self) -> Callable[[pd.Index[int]], None]: + """Gets a callable that moves simulants to the end of the simulation""" + return self._manager.move_simulants_to_end + + def register_step_size_modifier( + self, + modifier: Callable[[pd.Index[int]], pd.Series[ClockStepSize]], + required_resources: Sequence[str | Resource] = (), + ) -> None: + """Registers a step size modifier. + + Parameters + ---------- + modifier + Modifier of the step size pipeline. Modifiers can take an index + and should return a series of step sizes. + required_resources + A list of resources that the producer requires. A string represents + a population attribute. 
+ """ + return self._manager.register_step_modifier( + modifier=modifier, required_resources=required_resources + ) diff --git a/src/vivarium/framework/time.py b/src/vivarium/framework/time/manager.py similarity index 63% rename from src/vivarium/framework/time.py rename to src/vivarium/framework/time/manager.py index 28c206523..ee267dbd0 100644 --- a/src/vivarium/framework/time.py +++ b/src/vivarium/framework/time/manager.py @@ -1,7 +1,7 @@ """ -==================== -The Simulation Clock -==================== +============ +Time Manager +============ The components here provide implementations of different kinds of simulation clocks for use in ``vivarium``. @@ -14,7 +14,6 @@ from __future__ import annotations import math -from collections.abc import Callable from functools import partial from typing import TYPE_CHECKING, Any @@ -26,30 +25,21 @@ if TYPE_CHECKING: from vivarium.framework.engine import Builder - from vivarium.framework.population.population_view import PopulationView from vivarium.framework.event import Event from vivarium.framework.population import SimulantData from vivarium.framework.values import ValuesManager from vivarium.framework.values import list_combiner -from vivarium.manager import Interface, Manager +from vivarium.manager import Manager class SimulationClock(Manager): - """A base clock that includes global clock and a pandas series of clocks for each simulant""" + """A time manager that includes a global clock and simulant-specific clocks.""" @property def name(self) -> str: return "simulation_clock" - @property - def columns_created(self) -> list[str]: - return ["next_event_time", "step_size"] - - @property - def columns_required(self) -> list[str]: - return ["tracked"] - @property def time(self) -> ClockTime: """The current simulation time.""" @@ -107,38 +97,38 @@ def __init__(self) -> None: self._minimum_step_size: ClockStepSize | None = None self._standard_step_size: ClockStepSize | None = None self._clock_step_size: ClockStepSize | 
None = None - self._individual_clocks: PopulationView | None = None - self._pipeline_name = "simulant_step_size" - # TODO: Delegate this functionality to "tracked" or similar when appropriate + self._individual_clocks: pd.DataFrame | None = None + self._simulant_step_size_pipeline = "simulant_step_size" + # TODO: Delegate this functionality a better place when appropriate self._simulants_to_snooze = pd.Index([]) - def setup(self, builder: "Builder") -> None: + def setup(self, builder: Builder) -> None: + super().setup(builder) self._step_size_pipeline = builder.value.register_value_producer( - self._pipeline_name, + self._simulant_step_size_pipeline, source=lambda idx: [pd.Series(np.nan, index=idx).astype("timedelta64[ns]")], preferred_combiner=list_combiner, preferred_post_processor=self.step_size_post_processor, ) self.register_step_modifier = partial( builder.value.register_value_modifier, - self._pipeline_name, - component=self, + self._simulant_step_size_pipeline, ) - builder.population.initializes_simulants(self, creates_columns=self.columns_created) - builder.event.register_listener(lifecycle_states.POST_SETUP, self.on_post_setup) - self._individual_clocks = builder.population.get_view( - columns=self.columns_created + self.columns_required + builder.population.register_initializer( + initializer=self.initialize_individual_clock, columns=None ) + builder.event.register_listener(lifecycle_states.POST_SETUP, self.on_post_setup) + self._individual_clocks = pd.DataFrame() - def on_post_setup(self, event: "Event") -> None: + def on_post_setup(self, event: Event) -> None: if not self._step_size_pipeline.mutators: - ## No components modify the step size, so we use the default - ## and remove the population view + # No components modify the step size, so we use the default + # and remove the dataframe self._individual_clocks = None - def on_initialize_simulants(self, pop_data: "SimulantData") -> None: + def initialize_individual_clock(self, pop_data: SimulantData) 
-> None: """Sets the next_event_time and step_size columns for each simulant""" - if self._individual_clocks: + if self._individual_clocks is not None: clocks_to_initialize = pd.DataFrame( { "next_event_time": [self.event_time] * len(pop_data.index), @@ -146,23 +136,21 @@ def on_initialize_simulants(self, pop_data: "SimulantData") -> None: }, index=pop_data.index, ) - self._individual_clocks.update(clocks_to_initialize) + self._individual_clocks = pd.concat( + [self._individual_clocks, clocks_to_initialize] + ) def simulant_next_event_times(self, index: pd.Index[int]) -> pd.Series[ClockTime]: """The next time each simulant will be updated.""" - if not self._individual_clocks: + if self._individual_clocks is None: return pd.Series(self.event_time, index=index) - return self._individual_clocks.subview(["next_event_time", "tracked"]).get(index)[ - "next_event_time" - ] + return self._individual_clocks.loc[index, "next_event_time"] def simulant_step_sizes(self, index: pd.Index[int]) -> pd.Series[ClockStepSize]: """The step size for each simulant.""" - if not self._individual_clocks: + if self._individual_clocks is None: return pd.Series(self.step_size, index=index) - return self._individual_clocks.subview(["step_size", "tracked"]).get(index)[ - "step_size" - ] + return self._individual_clocks.loc[index, "step_size"] def step_backward(self) -> None: """Rewinds the clock by the current step size.""" @@ -173,33 +161,32 @@ def step_backward(self) -> None: def step_forward(self, index: pd.Index[int]) -> None: """Advances the clock by the current step size, and updates aligned simulant clocks.""" self._clock_time += self.step_size # type: ignore [assignment, operator] - if self._individual_clocks and not index.empty: + if self._individual_clocks is not None and not index.empty: update_index = self.get_active_simulants(index, self.time) - clocks_to_update = self._individual_clocks.get(update_index) - if not clocks_to_update.empty: - clocks_to_update["step_size"] = 
self._step_size_pipeline(update_index) - # Simulants that were flagged to get moved to the end should have a next event time - # of stop time + 1 minimum timestep - clocks_to_update.loc[self._simulants_to_snooze, "step_size"] = ( + if not update_index.empty: + self._individual_clocks.loc[ + update_index, "step_size" + ] = self._step_size_pipeline(update_index) + self._individual_clocks.loc[self._simulants_to_snooze, "step_size"] = ( self.stop_time + self.minimum_step_size - self.time # type: ignore [operator] ) - # TODO: Delegate this functionality to "tracked" or similar when appropriate + # TODO: Delegate this functionality to a better place when appropriate self._simulants_to_snooze = pd.Index([]) - clocks_to_update["next_event_time"] = ( - self.time + clocks_to_update["step_size"] + self._individual_clocks.loc[update_index, "next_event_time"] = ( + self.time + self._individual_clocks.loc[update_index, "step_size"] ) - self._individual_clocks.update(clocks_to_update) + self._clock_step_size = self.simulant_next_event_times(index).min() - self.time # type: ignore [operator] def get_active_simulants(self, index: pd.Index[int], time: ClockTime) -> pd.Index[int]: """Gets population that is aligned with global clock""" - if index.empty or not self._individual_clocks: + if index.empty or self._individual_clocks is None: return index next_event_times = self.simulant_next_event_times(index) return next_event_times[next_event_times <= time].index def move_simulants_to_end(self, index: pd.Index[int]) -> None: - if self._individual_clocks and not index.empty: + if self._individual_clocks is not None and not index.empty: self._simulants_to_snooze = self._simulants_to_snooze.union(index) def step_size_post_processor(self, value: Any, manager: ValuesManager) -> Any: @@ -211,21 +198,29 @@ def step_size_post_processor(self, value: Any, manager: ValuesManager) -> Any: Parameters ---------- - values + index + The index of the population for which the attribute is being produced + 
(not used by this post processor but is required to be used by + AttributePipelines). + value A list of step sizes + manager + The ValuesManager for this simulation (not used by this post processor + but is required to be used by AttributePipelines). Returns ------- - The largest feasible step size for each simulant + The largest feasible step size for each simulant (not used by this + post processor but is required to be used by AttributePipelines). """ min_modified = pd.DataFrame(value).min(axis=0).fillna(self.standard_step_size) - ## Rescale pipeline values to global minimum step size + # Rescale pipeline values to global minimum step size discretized_step_sizes = ( np.floor(min_modified / self.minimum_step_size).replace(0, 1) # type: ignore [attr-defined, operator] * self.minimum_step_size ) - ## Make sure we don't get zero + # Make sure we don't get zero return discretized_step_sizes @@ -265,7 +260,7 @@ def get_time_stamp(time: dict[str, int]) -> pd.Timestamp: class DateTimeClock(SimulationClock): - """A date-time based simulation clock.""" + """A time manager that uses a date-time based simulation clock.""" CONFIGURATION_DEFAULTS = { "time": { @@ -303,59 +298,3 @@ def setup(self, builder: Builder) -> None: def __repr__(self) -> str: return "DateTimeClock()" - - -class TimeInterface(Interface): - def __init__(self, manager: SimulationClock) -> None: - self._manager = manager - - def clock(self) -> Callable[[], ClockTime]: - """Gets a callable that returns the current simulation time.""" - return lambda: self._manager.time - - def step_size(self) -> Callable[[], ClockStepSize]: - """Gets a callable that returns the current simulation step size.""" - return lambda: self._manager.step_size - - def simulant_next_event_times(self) -> Callable[[pd.Index[int]], pd.Series[ClockTime]]: - """Gets a callable that returns the next event times for simulants.""" - return self._manager.simulant_next_event_times - - def simulant_step_sizes(self) -> Callable[[pd.Index[int]], 
pd.Series[ClockStepSize]]: - """Gets a callable that returns the simulant step sizes.""" - return self._manager.simulant_step_sizes - - def move_simulants_to_end(self) -> Callable[[pd.Index[int]], None]: - """Gets a callable that moves simulants to the end of the simulation""" - return self._manager.move_simulants_to_end - - def register_step_size_modifier( - self, - modifier: Callable[[pd.Index[int]], pd.Series[ClockStepSize]], - requires_columns: list[str] = [], - requires_values: list[str] = [], - requires_streams: list[str] = [], - ) -> None: - """Registers a step size modifier. - - Parameters - ---------- - modifier - Modifier of the step size pipeline. Modifiers can take an index - and should return a series of step sizes. - requires_columns - A list of the state table columns that already need to be present - and populated in the state table before the modifier - is called. - requires_values - A list of the value pipelines that need to be properly sourced - before the modifier is called. - requires_streams - A list of the randomness streams that need to be properly sourced - before the modifier is called.""" - return self._manager.register_step_modifier( - modifier=modifier, - requires_columns=requires_columns, - requires_values=requires_values, - requires_streams=requires_streams, - ) diff --git a/src/vivarium/framework/utilities.py b/src/vivarium/framework/utilities.py index 6284c2638..58e9a949b 100644 --- a/src/vivarium/framework/utilities.py +++ b/src/vivarium/framework/utilities.py @@ -6,25 +6,84 @@ Collection of utility functions shared by the ``vivarium`` framework. 
""" +from __future__ import annotations + import functools from bdb import BdbQuit from collections.abc import Callable, Sequence from importlib import import_module -from typing import Any, Literal +from typing import Any, Literal, TypeVar, overload import numpy as np +import pandas as pd from loguru import logger from vivarium.types import NumberLike, NumericArray, Timedelta +TimeValue = TypeVar("TimeValue", bound=NumberLike) + + +@overload +def from_yearly(value: int, time_step: Timedelta) -> float: + ... + + +@overload +def from_yearly(value: float, time_step: Timedelta) -> float: + ... + + +@overload +def from_yearly(value: NumericArray, time_step: Timedelta) -> NumericArray: + ... + + +@overload +def from_yearly( + value: pd.Series[int] | pd.Series[float], time_step: Timedelta +) -> pd.Series[float]: + ... + + +@overload +def from_yearly(value: pd.DataFrame, time_step: Timedelta) -> pd.DataFrame: + ... + def from_yearly(value: NumberLike, time_step: Timedelta) -> NumberLike: - """Rescale a yearly rate to the size of a time step.""" + """Rescales a yearly rate to the size of a time step.""" return value * (time_step.total_seconds() / (60 * 60 * 24 * 365.0)) +@overload +def to_yearly(value: int, time_step: Timedelta) -> float: + ... + + +@overload +def to_yearly(value: float, time_step: Timedelta) -> float: + ... + + +@overload +def to_yearly(value: NumericArray, time_step: Timedelta) -> NumericArray: + ... + + +@overload +def to_yearly( + value: pd.Series[int] | pd.Series[float], time_step: Timedelta +) -> pd.Series[float]: + ... + + +@overload +def to_yearly(value: pd.DataFrame, time_step: Timedelta) -> pd.DataFrame: + ... 
+ + def to_yearly(value: NumberLike, time_step: Timedelta) -> NumberLike: - """Convert a time-step-scaled rate back to a yearly rate.""" + """Converts a time-step-scaled rate back to a yearly rate.""" return value / (time_step.total_seconds() / (60 * 60 * 24 * 365.0)) @@ -48,7 +107,19 @@ def rate_to_probability( Returns ------- - An array of floats representing the probability of the converted rates + An array of floats representing the probability of the converted rates. + + Raises + ------ + ValueError + If an unsupported rate conversion type is provided. + + Notes + ----- + Beware machine-specific floating point issues. We have encountered underflow + when using the exponential conversion for rates greater than ~30,000. To avoid + this, we cap the rate at 250 when using the exponential conversion since + exp(-250) is effectively zero for practical purposes. """ if rate_conversion_type not in ["linear", "exponential"]: raise ValueError( @@ -70,9 +141,7 @@ def rate_to_probability( "The probability has been clipped to 1.0 and indicates the rate is too high. " ) else: - # encountered underflow from rate > 30k - # for rates greater than 250, exp(-rate) evaluates to 1e-109 - # beware machine-specific floating point issues + # NOTE: Cap the rate at 250 to avoid floating point underflow issues rate = np.asarray(rate) rate[rate > 250] = 250.0 probability = 1 - np.exp(-rate * time_scaling_factor) @@ -85,7 +154,7 @@ def probability_to_rate( time_scaling_factor: float | int = 1.0, rate_conversion_type: Literal["linear", "exponential"] = "linear", ) -> NumericArray: - """Function to convert a probability to a rate. + """Converts a probability to a rate. Parameters ---------- @@ -100,7 +169,12 @@ def probability_to_rate( Returns ------- - An array of floats representing the rate of the converted probabilities + An array of floats representing the rate of the converted probabilities. + + Raises + ------ + ValueError + If an unsupported rate conversion type is provided. 
""" # NOTE: The default behavior for randomness streams is to use a rate that is already # scaled to the time step which is why the default time scaling factor is 1.0. @@ -116,6 +190,7 @@ def probability_to_rate( else: probability = np.asarray(probability) rate = -np.log(1 - probability) + return rate @@ -133,16 +208,16 @@ def collapse_nested_dict( def import_by_path(path: str) -> Callable[..., Any]: - """Import a class or function given its absolute path. + """Imports a class or function given its absolute path. Parameters ---------- path - Path to object to import + Fully qualified dotted path to the object (e.g. "module.submodule.ClassName") Returns ------- - The imported class or function + The imported class or function. """ module_path, _, class_name = path.rpartition(".") diff --git a/src/vivarium/framework/values/__init__.py b/src/vivarium/framework/values/__init__.py index 27dd71bf5..b9bb8fc44 100644 --- a/src/vivarium/framework/values/__init__.py +++ b/src/vivarium/framework/values/__init__.py @@ -14,9 +14,16 @@ """ from vivarium.framework.values.combiners import ValueCombiner, list_combiner, replace_combiner from vivarium.framework.values.exceptions import DynamicValueError -from vivarium.framework.values.manager import ValuesInterface, ValuesManager -from vivarium.framework.values.pipeline import Pipeline, ValueModifier, ValueSource +from vivarium.framework.values.interface import ValuesInterface +from vivarium.framework.values.manager import ValuesManager +from vivarium.framework.values.pipeline import ( + AttributePipeline, + Pipeline, + ValueModifier, + ValueSource, +) from vivarium.framework.values.post_processors import ( + AttributePostProcessor, PostProcessor, rescale_post_processor, union_post_processor, diff --git a/src/vivarium/framework/values/combiners.py b/src/vivarium/framework/values/combiners.py index ba2ad5292..c37f90ca3 100644 --- a/src/vivarium/framework/values/combiners.py +++ b/src/vivarium/framework/values/combiners.py @@ -12,7 
"""
================
Values Interface
================

This module provides a :class:`ValuesInterface` class with methods to register
different types of value and attribute producers and modifiers.

"""

from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any

from vivarium.framework.resource import Resource
from vivarium.framework.values.combiners import ValueCombiner, replace_combiner
from vivarium.framework.values.pipeline import AttributePipeline, Pipeline
from vivarium.framework.values.post_processors import (
    AttributePostProcessor,
    PostProcessor,
    rescale_post_processor,
)
from vivarium.manager import Interface

if TYPE_CHECKING:
    import pandas as pd

    from vivarium.framework.values import ValuesManager


class ValuesInterface(Interface):
    """Public interface for the simulation values management system.

    The values system provides tools to build up values across many components,
    allowing users to build components that focus on small groups of simulation
    variables.

    Notes
    -----
    This is the only public interface for the values system; different methods
    exist for working with generic value :class:`Pipelines <Pipeline>` and
    :class:`AttributePipelines <AttributePipeline>`.

    """

    def __init__(self, manager: ValuesManager) -> None:
        # All interface methods are thin delegations to the values manager.
        self._manager = manager

    def register_value_producer(
        self,
        value_name: str,
        source: Callable[..., Any],
        required_resources: Sequence[str | Resource] = (),
        preferred_combiner: ValueCombiner = replace_combiner,
        preferred_post_processor: PostProcessor | None = None,
    ) -> Pipeline:
        """Registers a ``Pipeline`` as the producer of a named value.

        Parameters
        ----------
        value_name
            The name of the new dynamic value pipeline.
        source
            A callable source for the dynamic value pipeline.
        required_resources
            A list of resources that the producer requires. A string represents
            a population attribute.
        preferred_combiner
            A strategy for combining the source and the results of any calls
            to mutators in the pipeline. ``vivarium`` provides the strategies
            ``replace_combiner`` (the default) and ``list_combiner``, which
            are importable from ``vivarium.framework.values``. Client code
            may define additional strategies as necessary.
        preferred_post_processor
            A strategy for processing the final output of the pipeline.
            ``vivarium`` provides the strategies ``rescale_post_processor``
            and ``union_post_processor`` which are importable from
            ``vivarium.framework.values``. Client code may define additional
            strategies as necessary.

        Returns
        -------
        The ``Pipeline`` that is registered as the producer of the named value.
        """
        return self._manager.register_value_producer(
            value_name,
            source,
            required_resources,
            preferred_combiner,
            preferred_post_processor,
        )

    def register_attribute_producer(
        self,
        value_name: str,
        source: Callable[[pd.Index[int]], Any] | list[str],
        required_resources: Sequence[str | Resource] = (),
        preferred_combiner: ValueCombiner = replace_combiner,
        preferred_post_processor: AttributePostProcessor | None = None,
        source_is_private_column: bool = False,
    ) -> None:
        """Registers an ``AttributePipeline`` as the producer of a named attribute.

        Parameters
        ----------
        value_name
            The name of the new dynamic attribute pipeline.
        source
            The source for the dynamic attribute pipeline. This can be a callable,
            a list containing a single name of a private column created by this
            component, or a list of population attributes. If a private column name
            is passed, `source_is_private_column` must also be set to True.
        required_resources
            A list of resources that the producer requires. A string represents
            a population attribute.
        preferred_combiner
            A strategy for combining the source and the results of any calls
            to mutators in the pipeline. ``vivarium`` provides the strategies
            ``replace_combiner`` (the default) and ``list_combiner``, which
            are importable from ``vivarium.framework.values``. Client code
            may define additional strategies as necessary.
        preferred_post_processor
            A strategy for processing the final output of the pipeline.
            ``vivarium`` provides the strategies ``rescale_post_processor``
            and ``union_post_processor`` which are importable from
            ``vivarium.framework.values``. Client code may define additional
            strategies as necessary.
        source_is_private_column
            Whether or not the source is the name of a private column created by
            this component.
        """
        self._manager.register_attribute_producer(
            value_name,
            source,
            required_resources,
            preferred_combiner,
            preferred_post_processor,
            source_is_private_column,
        )

    def register_rate_producer(
        self,
        rate_name: str,
        source: Callable[[pd.Index[int]], Any] | list[str],
        required_resources: Sequence[str | Resource] = (),
    ) -> None:
        """Registers an ``AttributePipeline`` as the producer of a named rate.

        This is a convenience wrapper around ``register_attribute_producer`` that
        makes sure rate data is appropriately scaled to the size of the simulation
        time step. It is equivalent to calling ``register_attribute_producer()``
        with the ``rescale_post_processor`` as the preferred post processor.

        Parameters
        ----------
        rate_name
            The name of the new dynamic rate pipeline.
        source
            The source for the dynamic rate pipeline. This can be a callable or
            a list of population attribute names, with the same semantics as the
            ``source`` argument to ``register_attribute_producer``.
        required_resources
            A list of resources that the producer requires. A string represents
            a population attribute.
        """
        self.register_attribute_producer(
            rate_name,
            source,
            required_resources,
            preferred_post_processor=rescale_post_processor,
        )

    def register_value_modifier(
        self,
        value_name: str,
        modifier: Callable[..., Any],
        required_resources: Sequence[str | Resource] = (),
    ) -> None:
        """Marks a ``Callable`` as the modifier of a named value.

        Parameters
        ----------
        value_name
            The name of the dynamic value ``Pipeline`` to be modified.
        modifier
            A function that modifies the source of the dynamic value ``Pipeline``
            when called. If the pipeline has a ``replace_combiner``, the
            modifier must have the same arguments as the pipeline source
            with an additional last positional argument for the results of the
            previous stage in the pipeline. For the ``list_combiner`` strategy,
            the pipeline modifiers should have the same signature as the pipeline
            source.
        required_resources
            A list of resources that the modifier requires. A string represents
            a population attribute.
        """
        self._manager.register_value_modifier(value_name, modifier, required_resources)

    def register_attribute_modifier(
        self,
        value_name: str,
        modifier: Callable[..., Any] | str,
        required_resources: Sequence[str | Resource] = (),
    ) -> None:
        """Marks a ``Callable`` as the modifier of a named attribute.

        Parameters
        ----------
        value_name
            The name of the dynamic ``AttributePipeline`` to be modified.
        modifier
            A function that modifies the source of the dynamic ``AttributePipeline``
            when called. If a string is passed, it refers to the name of an
            ``AttributePipeline``. If the pipeline has a ``replace_combiner``,
            the modifier should accept the same arguments as the pipeline source
            with an additional last positional argument for the results of the
            previous stage in the pipeline. For the ``list_combiner`` strategy,
            the pipeline modifiers should have the same signature as the pipeline
            source.
        required_resources
            A list of resources that the modifier requires. A string represents
            a population attribute.
        """
        self._manager.register_attribute_modifier(
            value_name,
            modifier,
            required_resources=required_resources,
        )

    def get_value(self, name: str) -> Pipeline:
        """Retrieves the ``Pipeline`` representing the named value.

        Parameters
        ----------
        name
            Name of the ``Pipeline`` to return.

        Returns
        -------
        The requested ``Pipeline``.

        Notes
        -----
        This will create a new ``Pipeline`` if one does not already exist.
        """
        return self._manager.get_value(name)

    def get_attribute_pipelines(self) -> Callable[[], dict[str, AttributePipeline]]:
        """Returns a ``Callable`` that retrieves a dictionary of ``AttributePipelines``.

        Returns
        -------
        A ``Callable`` that returns a dictionary mapping all registered attribute
        names to their corresponding ``AttributePipelines``.

        Notes
        -----
        This is not the preferred access method to getting population attributes
        since it does not implement various features (e.g. querying, simulant
        tracking, etc); it exists for other managers to use if needed. Use
        :meth:`vivarium.framework.population.population_view.PopulationView.get_attributes`
        or :meth:`vivarium.framework.population.population_view.PopulationView.get_attribute_frame`
        instead.
        """
        # NOTE: returns the bound method itself (not its result) so callers can
        # defer the lookup until after all producers are registered.
        return self._manager.get_attribute_pipelines
"""
==============
Values Manager
==============

"""

from __future__ import annotations

from collections.abc import Callable, Iterator, Sequence
from typing import TYPE_CHECKING, Any, TypeVar

from vivarium.framework.lifecycle import lifecycle_states
from vivarium.framework.resource import Resource
from vivarium.framework.values.combiners import ValueCombiner, replace_combiner
from vivarium.framework.values.pipeline import (
    AttributePipeline,
    AttributesValueSource,
    DynamicValueError,
    Pipeline,
    PrivateColumnValueSource,
    ValueSource,
)
from vivarium.framework.values.post_processors import AttributePostProcessor, PostProcessor
from vivarium.manager import Manager

if TYPE_CHECKING:
    import pandas as pd

    from vivarium.framework.engine import Builder
    # NOTE(review): ``Event`` is used in on_post_setup's annotation but was not
    # visible in the changed import hunk — confirm it is imported in the file.
    from vivarium.framework.event import Event

T = TypeVar("T")


class ValuesManager(Manager):
    """Manager for the dynamic value system.

    Notes
    -----
    This is the only manager for the values system; different methods exist for
    working with generic value :class:`Pipelines <Pipeline>` and
    :class:`AttributePipelines <AttributePipeline>`.
    """

    def __init__(self) -> None:
        # Pipelines are lazily initialized by get_value / get_attribute. Value
        # and attribute pipelines live in disjoint registries; a name may only
        # ever belong to one of them.
        self._value_pipelines: dict[str, Pipeline] = {}
        self._attribute_pipelines: dict[str, AttributePipeline] = {}

    @property
    def name(self) -> str:
        return "values_manager"

    @property
    def _all_pipelines(self) -> dict[str, Pipeline]:
        # Merged read-only view over both registries; names never collide
        # because get_value/get_attribute enforce disjointness.
        return {**self._value_pipelines, **self._attribute_pipelines}

    def setup(self, builder: Builder) -> None:
        # Wire up collaborating managers and cache the builder callables this
        # manager needs after setup has completed.
        self._population_mgr = builder.population._manager
        self.logger = builder.logging.get_logger(self.name)
        self.step_size = builder.time.step_size()
        self.simulant_step_sizes = builder.time.simulant_step_sizes()
        builder.event.register_listener("post_setup", self.on_post_setup)
        self._add_resources = builder.resources.add_resources
        self._get_current_component = builder.components.get_current_component_or_manager
        self._add_constraint = builder.lifecycle.add_constraint

        # Registration is only legal during setup; pipeline enumeration is
        # blocked during setup since registration is still in flight.
        builder.lifecycle.add_constraint(
            self.register_value_producer, allow_during=[lifecycle_states.SETUP]
        )
        builder.lifecycle.add_constraint(
            self.register_value_modifier, allow_during=[lifecycle_states.SETUP]
        )
        builder.lifecycle.add_constraint(
            self.register_attribute_producer, allow_during=[lifecycle_states.SETUP]
        )
        builder.lifecycle.add_constraint(
            self.register_attribute_modifier, allow_during=[lifecycle_states.SETUP]
        )
        builder.lifecycle.add_constraint(
            self.get_attribute_pipelines, restrict_during=[lifecycle_states.SETUP]
        )

    def on_post_setup(self, _event: Event) -> None:
        """Finalizes dependency structure for the pipelines."""
        for pipeline in self._all_pipelines.values():
            # Unsourced pipelines might occur when generic components register
            # modifiers to values that aren't required in a simulation.
            if not pipeline.source:
                self.logger.warning(
                    f"Pipeline {pipeline.name} has no source. It will not be usable."
                )
                continue

            # register_value_producer and register_value_modifier record the
            # dependency structure for the pipeline source and pipeline modifiers,
            # respectively. We don't have enough information to record the
            # dependency structure for the pipeline itself until now, where
            # we say the pipeline value depends on its source and all its
            # modifiers.
            self._add_resources(
                component=pipeline.component,
                resources=pipeline,
                required_resources=[pipeline.source] + list(pipeline.mutators),
            )

    def register_value_producer(
        self,
        value_name: str,
        source: Callable[..., Any],
        required_resources: Sequence[str | Resource] = (),
        preferred_combiner: ValueCombiner = replace_combiner,
        preferred_post_processor: PostProcessor | None = None,
    ) -> Pipeline:
        """Registers a ``Pipeline`` as the producer of a named value.

        Parameters
        ----------
        value_name
            The name of the new dynamic value pipeline.
        source
            A callable source for the dynamic value pipeline.
        required_resources
            A list of resources that the producer requires. A string represents
            a population attribute.
        preferred_combiner
            A strategy for combining the source and the results of any calls to
            mutators in the pipeline (``replace_combiner`` by default).
        preferred_post_processor
            A strategy for processing the final output of the pipeline.

        Returns
        -------
        The ``Pipeline`` that is registered as the producer of the named value.
        """
        self.logger.debug(f"Registering value pipeline {value_name}")
        pipeline = self.get_value(value_name)
        self._configure_pipeline(
            pipeline,
            source,
            required_resources,
            preferred_combiner,
            preferred_post_processor,
        )
        return pipeline

    def register_attribute_producer(
        self,
        value_name: str,
        source: Callable[[pd.Index[int]], Any] | list[str],
        required_resources: Sequence[str | Resource] = (),
        preferred_combiner: ValueCombiner = replace_combiner,
        preferred_post_processor: AttributePostProcessor | None = None,
        source_is_private_column: bool = False,
    ) -> None:
        """Registers an ``AttributePipeline`` as the producer of a named attribute.

        Parameters
        ----------
        value_name
            The name of the new dynamic attribute pipeline.
        source
            The source for the dynamic attribute pipeline: a callable, a list
            containing a single name of a private column created by this
            component, or a list of population attributes. If a private column
            name is passed, `source_is_private_column` must also be True.
        required_resources
            A list of resources that the producer requires. A string represents
            a population attribute.
        preferred_combiner
            A strategy for combining the source and the results of any calls to
            mutators in the pipeline (``replace_combiner`` by default).
        preferred_post_processor
            A strategy for processing the final output of the pipeline.
        source_is_private_column
            Whether or not the source is the name of a private column created by
            this component.
        """
        self.logger.debug(f"Registering attribute pipeline {value_name}")
        pipeline = self.get_attribute(value_name)
        self._configure_pipeline(
            pipeline,
            source,
            required_resources=required_resources,
            preferred_combiner=preferred_combiner,
            preferred_post_processor=preferred_post_processor,
            source_is_private_column=source_is_private_column,
        )

    def register_value_modifier(
        self,
        value_name: str,
        modifier: Callable[..., Any],
        required_resources: Sequence[str | Resource] = (),
    ) -> None:
        """Marks a ``Callable`` as the modifier of a named value.

        Parameters
        ----------
        value_name
            The name of the dynamic value ``Pipeline`` to be modified.
        modifier
            A function that modifies the source of the dynamic value ``Pipeline``
            when called. With a ``replace_combiner`` the modifier takes the same
            arguments as the source plus a final positional argument holding the
            previous stage's result; with ``list_combiner`` it matches the
            source's signature exactly.
        required_resources
            A list of resources that the modifier requires. A string represents
            a population attribute.
        """
        self._configure_modifier(
            self.get_value(value_name),
            modifier,
            required_resources,
        )

    def register_attribute_modifier(
        self,
        value_name: str,
        modifier: Callable[..., Any] | str,
        required_resources: Sequence[str | Resource] = (),
    ) -> None:
        """Marks a ``Callable`` as the modifier of a named attribute.

        Parameters
        ----------
        value_name
            The name of the dynamic ``AttributePipeline`` to be modified.
        modifier
            A function that modifies the source of the dynamic
            ``AttributePipeline`` when called, or the name of another
            ``AttributePipeline`` to use as the modifier.
        required_resources
            A list of resources that the modifier requires: attribute names,
            pipelines, or randomness streams.
        """
        # A string modifier names another attribute pipeline; resolve it first.
        resolved = self.get_attribute(modifier) if isinstance(modifier, str) else modifier
        self._configure_modifier(
            self.get_attribute(value_name),
            resolved,
            required_resources=required_resources,
        )

    def get_value(self, name: str) -> Pipeline:
        """Retrieves the ``Pipeline`` representing the named value.

        Parameters
        ----------
        name
            Name of the ``Pipeline`` to return.

        Returns
        -------
        The requested ``Pipeline``.

        Raises
        ------
        DynamicValueError
            If `name` is already registered as an attribute pipeline.

        Notes
        -----
        This will create a new ``Pipeline`` if one does not already exist.
        """
        if name in self._attribute_pipelines:
            raise DynamicValueError(
                f"'{name}' is already registered as an attribute pipeline."
            )
        # Create lazily; avoid constructing a throwaway Pipeline when the name
        # is already registered.
        if name not in self._value_pipelines:
            self._value_pipelines[name] = Pipeline(name)
        return self._value_pipelines[name]

    def get_value_pipelines(self) -> dict[str, Pipeline]:
        """Retrieves a dictionary of all registered value ``Pipelines``.

        To get all ``AttributePipelines``, use :meth:`get_attribute_pipelines`.

        Returns
        -------
        A dictionary mapping value names to their corresponding ``Pipelines``.
        """
        return self._value_pipelines

    def get_attribute(self, name: str) -> AttributePipeline:
        """Retrieves the ``AttributePipeline`` representing the named attribute.

        To get a value ``Pipeline``, use :meth:`get_value`.

        Parameters
        ----------
        name
            Name of the ``AttributePipeline`` to return.

        Returns
        -------
        The requested ``AttributePipeline``.

        Raises
        ------
        DynamicValueError
            If `name` is already registered as a value pipeline.

        Notes
        -----
        This will create a new ``AttributePipeline`` if one does not already exist.
        """
        if name in self._value_pipelines:
            raise DynamicValueError(f"'{name}' is already registered as a value pipeline.")
        if name not in self._attribute_pipelines:
            self._attribute_pipelines[name] = AttributePipeline(name)
        return self._attribute_pipelines[name]

    def get_attribute_pipelines(self) -> dict[str, AttributePipeline]:
        """Returns a dictionary of ``AttributePipelines``.

        Returns
        -------
        A dictionary mapping all registered attribute names to their
        corresponding ``AttributePipelines``.

        Notes
        -----
        This is not the preferred access method to getting population attributes
        since it does not implement various features (e.g. querying, simulant
        tracking, etc); it exists for other managers to use if needed. Use
        :meth:`vivarium.framework.population.population_view.PopulationView.get_attributes`
        or :meth:`vivarium.framework.population.population_view.PopulationView.get_attribute_frame`
        instead.
        """
        return self._attribute_pipelines

    ##################
    # Helper methods #
    ##################

    def _configure_pipeline(
        self,
        pipeline: Pipeline | AttributePipeline,
        source: Callable[..., Any] | list[str],
        required_resources: Sequence[str | Resource] = (),
        preferred_combiner: ValueCombiner = replace_combiner,
        preferred_post_processor: PostProcessor | AttributePostProcessor | None = None,
        source_is_private_column: bool = False,
    ) -> None:
        # Attach a source, combiner, and post-processor to `pipeline` and record
        # the source's resource dependencies. The source kind is chosen from the
        # shape of `source` and the private-column flag.
        component = self._get_current_component()
        value_source: ValueSource
        if source_is_private_column:
            value_source = PrivateColumnValueSource(
                pipeline, source, component, required_resources
            )
        elif isinstance(source, list):
            value_source = AttributesValueSource(
                pipeline, source, component, required_resources
            )
        else:
            value_source = ValueSource(pipeline, source, component, required_resources)

        pipeline.set_attributes(
            component=component,
            source=value_source,
            combiner=preferred_combiner,
            post_processor=preferred_post_processor,  # type: ignore[arg-type]
            manager=self,
        )

        # The resource we add here is just the pipeline source. The pipeline
        # itself is added as a resource at post-setup, once all sources and
        # modifiers are known.
        self._add_resources(
            component=pipeline.component,
            resources=pipeline.source,
            required_resources=pipeline.source.required_resources,
        )

        # Pipelines may not be called before the simulation proper begins.
        self._add_constraint(
            pipeline._call,
            restrict_during=[
                lifecycle_states.INITIALIZATION,
                lifecycle_states.SETUP,
                lifecycle_states.POST_SETUP,
            ],
        )

    def _configure_modifier(
        self,
        pipeline: Pipeline | AttributePipeline,
        modifier: Callable[..., Any],
        required_resources: Sequence[str | Resource] = (),
    ) -> None:
        # Register `modifier` as a mutator of `pipeline` and record its
        # resource dependencies.
        component = self._get_current_component()
        value_modifier = pipeline.get_value_modifier(modifier, component)
        self.logger.debug(f"Registering {value_modifier.name} as modifier to {pipeline.name}")
        if isinstance(modifier, Resource):
            # A Resource modifier (e.g. another pipeline) fully determines its
            # own dependencies, so infer them unconditionally. Previously the
            # inference only ran when explicit resources were also passed,
            # which silently dropped the dependency edge otherwise.
            if required_resources:
                self.logger.warning(
                    f"Conflicting information for {pipeline.name}. Ignoring 'required_resources' "
                    f"since the `modifier` is of type {type(modifier)} and we can infer "
                    "the required resources directly."
                )
            required_resources = [modifier]
        self._add_resources(
            component=component,
            resources=value_modifier,
            required_resources=required_resources,
        )

    def __contains__(self, item: str) -> bool:
        return item in self._all_pipelines

    def __iter__(self) -> Iterator[str]:
        # Iterates names of all pipelines (values and attributes alike).
        return iter(self._all_pipelines)

    def __repr__(self) -> str:
        return "ValuesManager()"
+ """ def __init__( self, pipeline: Pipeline, source: Callable[..., Any] | None, - component: Component | None, + component: Component | Manager | None, + required_resources: Sequence[str | Resource] | None = None, ) -> None: + self._pipeline_type = ( + "attribute" if isinstance(pipeline, AttributePipeline) else "value" + ) super().__init__( - "value_source" if source else "missing_value_source", pipeline.name, component + f"{self._pipeline_type}_source" + if source + else f"missing_{self._pipeline_type}_source", + pipeline.name, + component, ) self._pipeline = pipeline self._source = source + self.required_resources = required_resources or [] + if isinstance(source, Resource): + self.required_resources.append(source) def __bool__(self) -> bool: return self._source is not None - def __call__(self, *args: Any, **kwargs: Any) -> Any: - if not self._source: + def __call__(self, population_mgr: PopulationManager, *args: Any, **kwargs: Any) -> Any: + if self._source is None: raise DynamicValueError( - f"The dynamic value pipeline for {self.name} has no source." + f"The dynamic {self._pipeline_type} pipeline for {self.name} has no source." " This likely means you are attempting to modify a value that" " hasn't been created." ) + return self._source(*args, **kwargs) +class PrivateColumnValueSource(ValueSource): + """A resource representing private column source of a value pipeline.""" + + def __init__( + self, + pipeline: Pipeline, + source: Callable[..., Any] | list[str], + component: Component | Manager, + required_resources: Sequence[str | Resource], + ) -> None: + generic_error_msg = ( + f"Invalid source for {pipeline.name}. `source` must be list containing a single" + " private column name." 
+ ) + if not isinstance(source, list): + raise ValueError(generic_error_msg + f"Got `source` type {type(source)} instead.") + if len(source) != 1: + raise ValueError(generic_error_msg + f"Got {len(source)} names instead.") + + self.column = Column(source[0], component) + required_resources = [self.column, *required_resources] + super().__init__( + pipeline, source=None, component=component, required_resources=required_resources + ) + + def __bool__(self) -> bool: + return True + + def __call__( + self, population_mgr: PopulationManager, index: pd.Index[int] + ) -> pd.Series[Any]: + return population_mgr.get_private_columns( + component=self.component, columns=self.column.name, index=index + ) + + +class AttributesValueSource(ValueSource): + """A resource representing the list of attributes source of an attribute pipeline.""" + + def __init__( + self, + pipeline: Pipeline, + source: list[str], + component: Component | Manager, + required_resources: Sequence[str | Resource], + ) -> None: + self.attributes = source + required_resources = [*self.attributes, *required_resources] + super().__init__( + pipeline, source=None, component=component, required_resources=required_resources + ) + + def __bool__(self) -> bool: + return True + + def __call__( + self, population_mgr: PopulationManager, index: pd.Index[int] + ) -> pd.Series[Any] | pd.DataFrame: + return population_mgr.get_population(attributes=self.attributes, index=index) + + class ValueModifier(Resource): """A resource representing a modifier of a value pipeline.""" @@ -53,7 +133,7 @@ def __init__( self, pipeline: Pipeline, modifier: Callable[..., Any], - component: Component | Manager | None, + component: Component | Manager, ) -> None: mutator_name = self._get_modifier_name(modifier) mutator_index = len(pipeline.mutators) + 1 @@ -98,10 +178,20 @@ class Pipeline(Resource): need a source or to be configured. 
This might occur when writing generic components that create a set of pipeline modifiers for values that won't be used in the particular simulation. + + Notes + ----- + Pipelines are highly generic and can be used to calculate values of any type + through a simulation. *Most* pipelines are intended to calculate simulant + attributes; for those, use :class:`~vivarium.framework.values.pipeline.AttributePipeline`. """ def __init__(self, name: str, component: Component | None = None) -> None: - super().__init__("value", name, component=component) + super().__init__( + "attribute" if isinstance(self, AttributePipeline) else "value", + name, + component=component, + ) self.source: ValueSource = ValueSource(self, source=None, component=None) """The callable source of the value represented by the pipeline.""" @@ -154,7 +244,7 @@ def __call__(self, *args: Any, skip_post_processor: bool = False, **kwargs: Any) This is useful when the post-processor acts as some sort of final unit conversion (e.g. the rescale post processor). args, kwargs - Pipeline arguments. These should be the arguments to the + Pipeline arguments. These should be the arguments to the callable source of the pipeline. Returns @@ -174,7 +264,7 @@ def _call(self, *args: Any, skip_post_processor: bool = False, **kwargs: Any) -> f"The dynamic value pipeline for {self.name} has no source. This likely means " f"you are attempting to modify a value that hasn't been created." 
) - value = self.source(*args, **kwargs) + value = self.source(self.manager._population_mgr, *args, **kwargs) for mutator in self.mutators: value = self.combiner(value, mutator, *args, **kwargs) if self.post_processor and not skip_post_processor: @@ -191,9 +281,9 @@ def __hash__(self) -> int: return hash(self.name) def get_value_modifier( - self, modifier: Callable[..., Any], component: Component | Manager | None + self, modifier: Callable[..., Any], component: Component | Manager ) -> ValueModifier: - """Add a value modifier to the pipeline and return it. + """Adds a value modifier to the pipeline and returns it. Parameters ---------- @@ -208,21 +298,24 @@ def get_value_modifier( def set_attributes( self, - component: Component | None, - source: Callable[..., Any], + component: Component | Manager, + source: ValueSource, combiner: ValueCombiner, post_processor: PostProcessor | None, manager: ValuesManager, ) -> None: """ - Add a source, combiner, post-processor, and manager to a pipeline. + Adds a source, combiner, post-processor, and manager to a pipeline. Parameters ---------- component The component that creates the pipeline. source - The callable source of the value represented by the pipeline. + The source for the dynamic attribute pipeline. This can be a callable + or a list of column names. If a list of column names is provided, + the component that is registering this attribute producer must be the + one that creates those columns. combiner A strategy for combining the source and mutator values into the final value represented by the pipeline. @@ -232,8 +325,79 @@ def set_attributes( manager The simulation values manager. """ - self.component = component - self.source = ValueSource(self, source, component) + self._component = component + self.source = source self._combiner = combiner self.post_processor = post_processor self._manager = manager + + +class AttributePipeline(Pipeline): + """A type of value pipeline for calculating simulant attributes. 
+ + An attribute pipeline is a specific type of :class:`~vivarium.framework.values.pipeline.Pipeline` + where the source and callable must take a pd.Index of integers and return a pd.Series + or pd.DataFrame that has that same index. + + """ + + @property + def is_simple(self) -> bool: + """Whether or not this ``AttributePipeline`` is simple, i.e. it has a list + of columns as its source and no modifiers or postprocessors.""" + return ( + isinstance(self.source, PrivateColumnValueSource) + and not self.mutators + and not self.post_processor + ) + + def __init__(self, name: str, component: Component | None = None) -> None: + super().__init__(name, component=component) + # Re-define the post-processor type to be more specific + self.post_processor: AttributePostProcessor | None = None # type: ignore[assignment] + """An optional final transformation to perform on the combined output of + the source and mutators.""" + + def __call__( # type: ignore[override] + self, index: pd.Index[int], skip_post_processor: bool = False + ) -> pd.Series[Any] | pd.DataFrame: + """Generates the attributes represented by this pipeline. + + Arguments + --------- + index + A pd.Index of integers representing the simulants for which we + want to calculate the attribute. + skip_post_processor + Whether we should invoke the post-processor on the combined + source and mutator output or return without post-processing. + + Returns + ------- + A pd.Series or pd.DataFrame of attributes for the simulants in `index`. + + Raises + ------ + DynamicValueError + If the pipeline is invoked without a source set. + """ + # NOTE: must pass index in as arg (NOT kwarg!) 
to match signature of parent Pipeline._call() + attribute = self._call(index, skip_post_processor=True) + if self.post_processor and not skip_post_processor: + attribute = self.post_processor(index, attribute, self.manager) + if not isinstance(attribute, (pd.Series, pd.DataFrame)): + raise DynamicValueError( + f"The dynamic attribute pipeline for {self.name} returned a {type(attribute)} " + "but pd.Series' or pd.DataFrames are expected for attribute pipelines." + ) + if not attribute.index.equals(index): + raise DynamicValueError( + f"The dynamic attribute pipeline for {self.name} returned a series " + "or dataframe with a different index than was passed in. " + f"\nReturned index: {attribute.index}" + f"\nExpected index: {index}" + ) + return attribute + + def __repr__(self) -> str: + return f"_AttributePipeline({self.name})" diff --git a/src/vivarium/framework/values/post_processors.py b/src/vivarium/framework/values/post_processors.py index 811372fea..7c3f27614 100644 --- a/src/vivarium/framework/values/post_processors.py +++ b/src/vivarium/framework/values/post_processors.py @@ -3,11 +3,12 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any, Protocol +import numpy as np import pandas as pd from vivarium.framework.utilities import from_yearly from vivarium.framework.values.exceptions import DynamicValueError -from vivarium.types import NumberLike +from vivarium.types import NumberLike, NumericArray if TYPE_CHECKING: from vivarium.framework.values.manager import ValuesManager @@ -18,19 +19,29 @@ def __call__(self, value: Any, manager: ValuesManager) -> Any: ... -def rescale_post_processor(value: NumberLike, manager: ValuesManager) -> NumberLike: +class AttributePostProcessor(Protocol): + """An attribute pipeline post-processor must return a pd.Series or pd.DataFrame.""" + + def __call__( + self, index: pd.Index[int], value: Any, manager: ValuesManager + ) -> pd.Series[Any] | pd.DataFrame: + ... 
+ + +def rescale_post_processor( + index: pd.Index[int], value: NumberLike, manager: ValuesManager +) -> pd.Series[float] | pd.DataFrame: """Rescales annual rates to time-step appropriate rates. - This should only be used with a simulation using a - :class:`~vivarium.framework.time.DateTimeClock` or another implementation - of a clock that traffics in pandas date-time objects. + This should only be used with a simulation using a :class:`~vivarium.framework.time.manager.DateTimeClock` + or another implementation of a clock that traffics in pandas date-time objects. Parameters ---------- + index + The index of the population for which the attribute is being produced. value - Annual rates, either as a number or something we can broadcast - multiplication over like a :mod:`numpy` array or :mod:`pandas` - data frame. + Annual rates. manager The ValuesManager for this simulation. @@ -46,22 +57,35 @@ def rescale_post_processor(value: NumberLike, manager: ValuesManager) -> NumberL / (60 * 60 * 24 * 365.0), axis=0, ) - else: - time_step = manager.step_size() - if not isinstance(time_step, (pd.Timedelta, timedelta)): + time_step = manager.step_size() + if not isinstance(time_step, (pd.Timedelta, timedelta)): + raise DynamicValueError( + "The rescale post processor requires a time step size that is a " + "datetime timedelta or pandas Timedelta object." + ) + if isinstance(value, (int, float)): + return pd.Series(from_yearly(value, time_step), index=index) + elif isinstance(value, np.ndarray): + if value.ndim == 1: + return pd.Series(from_yearly(value, time_step), index=index) + elif value.ndim == 2: + return pd.DataFrame(from_yearly(value, time_step), index=index) + else: raise DynamicValueError( - "The rescale post processor requires a time step size that is a " - "datetime timedelta or pandas Timedelta object." + f"Numpy arrays with {value.ndim} dimensions are not supported. " + "Only 1D and 2D arrays are allowed." 
) - return from_yearly(value, time_step) + else: + raise NotImplementedError -def union_post_processor(values: list[NumberLike], _: Any) -> NumberLike: +def union_post_processor( + index: pd.Index[int], value: list[NumberLike], manager: ValuesManager +) -> pd.Series[Any] | pd.DataFrame: """Computes a probability on the union of the sample spaces in the values. Given a list of values where each value is a probability of an independent - event, this post processor computes the probability of the union of the - events. + event, this post processor computes the probability of the union of the events. .. list-table:: :width: 100% @@ -79,23 +103,48 @@ def union_post_processor(values: list[NumberLike], _: Any) -> NumberLike: Parameters ---------- values - A list of independent proportions or probabilities, either - as numbers or as a something we can broadcast addition and - multiplication over. + A list of independent proportions or probabilities, either as numbers or + as something we can broadcast addition and multiplication over. Returns ------- The probability over the union of the sample spaces represented by the original probabilities. """ - # if there is only one value, return the value - if len(values) == 1: - return values[0] - - # if there are multiple values, calculate the joint value - product: NumberLike = 1 - for v in values: - new_value = 1 - v - product = product * new_value - joint_value = 1 - product - return joint_value + if not isinstance(value, list): + raise DynamicValueError("The union post processor requires a list of values.") + + for v in value: + if not isinstance(v, (np.ndarray, pd.Series, pd.DataFrame, float, int)): + raise DynamicValueError( + "The union post processor only supports numeric types, " + f"pandas Series/DataFrames, and numpy ndarrays. " + f"You provided a value of type {type(v)}." 
+ ) + + joint_value: NumericArray | pd.Series[float] | pd.DataFrame | float | int + if len(value) == 1: + # if there is only one value, return the value + joint_value = value[0] + else: + # if there are multiple values, calculate the joint value + product: NumberLike = 1 + for v in value: + new_value = 1 - v + product = product * new_value + joint_value = 1 - product + + if isinstance(joint_value, np.ndarray): + if joint_value.ndim == 1: + return pd.Series(joint_value, index=index) + elif joint_value.ndim == 2: + return pd.DataFrame(joint_value, index=index) + else: + raise DynamicValueError( + f"Numpy arrays with {joint_value.ndim} dimensions are not supported. " + "Only 1D and 2D arrays are allowed." + ) + elif isinstance(joint_value, (float, int)): + return pd.Series(joint_value, index=index) + else: + return joint_value diff --git a/src/vivarium/interface/interactive.py b/src/vivarium/interface/interactive.py index 7fd95574a..b3ebe9ebd 100644 --- a/src/vivarium/interface/interactive.py +++ b/src/vivarium/interface/interactive.py @@ -16,7 +16,7 @@ from collections.abc import Callable from math import ceil -from typing import Any +from typing import Any, Literal, overload import pandas as pd @@ -164,29 +164,55 @@ def take_steps( for _ in range(number_of_steps): self.step(step_size) - def get_population(self, untracked: bool = False) -> pd.DataFrame: + @overload + def get_population(self, attributes: str | None = None) -> pd.Series[Any] | pd.DataFrame: + ... + + @overload + def get_population(self, attributes: list[str] | tuple[str, ...] = ...) -> pd.DataFrame: + ... + + def get_population( + self, attributes: str | list[str] | tuple[str, ...] | None = None + ) -> pd.Series[Any] | pd.DataFrame: """Get a copy of the population state table. Parameters ---------- - untracked - Whether or not to return simulants who are no longer being tracked - by the simulation. + attributes + The attribute pipelines to include in the returned table. 
If None, all + attributes are included. Returns ------- - The population state table. + The current state of requested population attributes. """ - return self._population.get_population(untracked) + returned_attributes: list[str] | tuple[str, ...] | Literal["all"] = "all" + squeeze: Literal[True, False] = True + if isinstance(attributes, str): + returned_attributes = [attributes] + elif attributes is not None: + squeeze = False + returned_attributes = list(attributes) + return self._population.get_population( + attributes=returned_attributes, squeeze=squeeze + ) + + def get_attribute_names(self) -> list[str]: + """List all attributes in the population state table.""" + return self._population.get_all_attribute_names() def list_values(self) -> list[str]: - """List the names of all pipelines in the simulation.""" - return list(self._values.keys()) + """List the names of all value pipelines in the simulation.""" + return list(self._values.get_value_pipelines().keys()) def get_value(self, value_pipeline_name: str) -> Pipeline: """Get the value pipeline associated with the given name.""" if value_pipeline_name not in self.list_values(): - raise ValueError(f"No value pipeline '{value_pipeline_name}' registered.") + raise ValueError( + f"No value pipeline '{value_pipeline_name}' registered. " + "Are you looking for an attribute pipeline?" + ) return self._values.get_value(value_pipeline_name) def list_events(self) -> list[str]: diff --git a/src/vivarium/manager.py b/src/vivarium/manager.py index 17b073aa0..2b7e925c1 100644 --- a/src/vivarium/manager.py +++ b/src/vivarium/manager.py @@ -14,10 +14,10 @@ if TYPE_CHECKING: from vivarium.framework.engine import Builder - from vivarium.framework.population import SimulantData class Manager(ABC): + CONFIGURATION_DEFAULTS: dict[str, Any] = {} """A dictionary containing the defaults for any configurations managed by this manager. An empty dictionary indicates no managed configurations. 
@@ -43,11 +43,6 @@ def configuration_defaults(self) -> dict[str, Any]: """ return self.CONFIGURATION_DEFAULTS - @property - def columns_created(self) -> list[str]: - """Provides names of columns created by the manager.""" - return [] - ##################### # Lifecycle methods # ##################### @@ -57,8 +52,7 @@ def setup(self, builder: Builder) -> None: lifecycle phase. This method is intended to be overridden by subclasses to perform any - necessary setup operations specific to the manager. By default, it - does nothing. + necessary setup operations specific to the manager. Parameters ---------- @@ -67,21 +61,6 @@ def setup(self, builder: Builder) -> None: """ pass - def on_initialize_simulants(self, pop_data: SimulantData) -> None: - """ - Method that vivarium will run during simulant initialization. - - This method is intended to be overridden by subclasses if there are - operations they need to perform specifically during the simulant - initialization phase. - - Parameters - ---------- - pop_data : SimulantData - The data associated with the simulants being initialized. 
- """ - pass - class Interface: """An interface class to be used to manage different systems for a simulation in ``vivarium``""" diff --git a/src/vivarium/testing_utilities.py b/src/vivarium/testing_utilities.py index 5cf50743f..cb4a566aa 100644 --- a/src/vivarium/testing_utilities.py +++ b/src/vivarium/testing_utilities.py @@ -35,17 +35,18 @@ class NonCRNTestPopulation(Component): }, } - @property - def columns_created(self) -> list[str]: - return ["age", "sex", "location", "alive", "entrance_time", "exit_time"] - def setup(self, builder: Builder) -> None: self.config = builder.configuration self.randomness = builder.randomness.get_stream( "population_age_fuzz", initializes_crn_attributes=True ) + builder.population.register_initializer( + initializer=self.initialize_population, + columns=["age", "sex", "location", "is_alive", "entrance_time", "exit_time"], + required_resources=[self.randomness], + ) - def on_initialize_simulants(self, pop_data: SimulantData) -> None: + def initialize_population(self, pop_data: SimulantData) -> None: age_start = pop_data.user_data.get( "age_start", self.config.population.initialization_age_min ) @@ -66,7 +67,9 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None: self.population_view.update(population) def on_time_step(self, event: Event) -> None: - population = self.population_view.get(event.index, query="alive == 'alive'") + population = self.population_view.get_attributes( + event.index, ["is_alive", "age"], query="is_alive == True" + ) # This component won't work if event.step_size is an int if not isinstance(event.step_size, int): population["age"] += event.step_size / pd.Timedelta(days=365) @@ -75,13 +78,21 @@ def on_time_step(self, event: Event) -> None: class TestPopulation(NonCRNTestPopulation): def setup(self, builder: Builder) -> None: - super().setup(builder) + self.config = builder.configuration + self.randomness = builder.randomness.get_stream( + "population_age_fuzz", initializes_crn_attributes=True 
+ ) self.age_randomness = builder.randomness.get_stream( "age_initialization", initializes_crn_attributes=True ) self.register = builder.randomness.register_simulants + builder.population.register_initializer( + initializer=self.initialize_population, + columns=["age", "sex", "location", "is_alive", "entrance_time", "exit_time"], + required_resources=[self.randomness, self.age_randomness], + ) - def on_initialize_simulants(self, pop_data: SimulantData) -> None: + def initialize_population(self, pop_data: SimulantData) -> None: age_start = pop_data.user_data.get( "age_start", self.config.population.initialization_age_min ) @@ -125,7 +136,7 @@ def _build_population( "sex": randomness_stream.choice( index, ["Male", "Female"], additional_key="sex_choice" ), - "alive": pd.Series("alive", index=index), + "is_alive": pd.Series(True, index=index), "location": location, "exit_time": pd.NaT, }, @@ -158,7 +169,7 @@ def _non_crn_build_population( "sex": randomness_stream.choice( index, ["Male", "Female"], additional_key="sex_choice" ), - "alive": pd.Series("alive", index=index), + "is_alive": pd.Series(True, index=index), "location": location, "entrance_time": creation_time, "exit_time": pd.NaT, @@ -253,12 +264,22 @@ def get_randomness( clock: Callable[[], pd.Timestamp | datetime | int] = lambda: pd.Timestamp(1990, 7, 2), seed: int = 12345, initializes_crn_attributes: bool = False, + component: Component | None = None, ) -> RandomnessStream: + if component is None: + # Create a simple mock component for testing + class _MockComponent(Component): + @property + def name(self) -> str: + return "mock_component" + + component = _MockComponent() return RandomnessStream( key, clock, seed=seed, index_map=IndexMap(), + component=component, initializes_crn_attributes=initializes_crn_attributes, ) diff --git a/src/vivarium/types.py b/src/vivarium/types.py index 9f5ac83f5..27bdd3e89 100644 --- a/src/vivarium/types.py +++ b/src/vivarium/types.py @@ -19,12 +19,16 @@ ClockStepSize = 
Timedelta | int ScalarValue = Numeric | Timedelta | Time +DataFrameMapping = Mapping[str, list[ScalarValue] | list[str]] LookupTableData = ( + # FIXME: this is not correct - a LookupTable can return a str, so it should be + # possible to build a LookupTable from str data as well, but adding a string here + # will break some assumptions of Component.get_data() ScalarValue | pd.DataFrame | list[ScalarValue] | tuple[ScalarValue, ...] - | Mapping[str, list[ScalarValue] | list[str]] + | DataFrameMapping ) DataInput = LookupTableData | str | Callable[["Builder"], LookupTableData] diff --git a/tests/examples/test_disease_model.py b/tests/examples/test_disease_model.py index 5899af4cf..db7127ebe 100644 --- a/tests/examples/test_disease_model.py +++ b/tests/examples/test_disease_model.py @@ -14,7 +14,9 @@ def test_disease_model(fuzzy_checker: FuzzyChecker, disease_model_spec: Path) -> { "configuration": { "mortality": { - "mortality_rate": 20.0, + "data_sources": { + "mortality_rate": 20.0, + }, }, "lower_respiratory_infections": { "incidence_rate": 25.0, @@ -29,18 +31,34 @@ def test_disease_model(fuzzy_checker: FuzzyChecker, disease_model_spec: Path) -> pop = simulation.get_population() expected_columns = { - "tracked", - "alive", + "is_alive", + "previous_alive", "age", "sex", "entrance_time", + "mortality_rate", "lower_respiratory_infections", + "lower_respiratory_infections.cause_specific_mortality_rate", "child_wasting_propensity", + "child_wasting.proportion_exposed", + "child_wasting.base_proportion_exposed", + "child_wasting.exposure", + "susceptible_to_lower_respiratory_infections.initialization_weights", + "susceptible_to_lower_respiratory_infections.excess_mortality_rate", + "susceptible_to_lower_respiratory_infections.excess_mortality_rate.population_attributable_fraction", + "infected_with_lower_respiratory_infections.initialization_weights", + "infected_with_lower_respiratory_infections.incidence_rate.population_attributable_fraction", + 
"infected_with_lower_respiratory_infections.remission_rate", + "infected_with_lower_respiratory_infections.excess_mortality_rate.population_attributable_fraction", + "infected_with_lower_respiratory_infections.excess_mortality_rate", + "infected_with_lower_respiratory_infections.incidence_rate", + "infected_with_lower_respiratory_infections.remission_rate.population_attributable_fraction", + "effect_of_child_wasting_on_infected_with_lower_respiratory_infections.incidence_rate.relative_risk", + "sqlns.effect_size", } assert set(pop.columns) == expected_columns assert len(pop) == 100_000 - assert np.all(pop["tracked"] == True) - assert np.all(pop["alive"] == "alive") + assert np.all(pop["is_alive"] == True) assert np.all((pop["age"] >= 0) & (pop["age"] <= 5)) assert np.all(pop["entrance_time"] == datetime(2021, 12, 31, 12)) @@ -57,8 +75,8 @@ def test_disease_model(fuzzy_checker: FuzzyChecker, disease_model_spec: Path) -> assert np.all((pop["child_wasting_propensity"] >= 0) & (pop["child_wasting_propensity"] <= 1)) simulation.step() - pop = simulation.get_population() - is_alive = pop["alive"] == "alive" + pop = simulation.get_population(["is_alive", "lower_respiratory_infections"]) + is_alive = pop["is_alive"] == True alive_target = from_yearly(20, timedelta(days=0.5)) assert isinstance(alive_target, float) diff --git a/tests/framework/artifact/test_artifact.py b/tests/framework/artifact/test_artifact.py index 1bb95ce31..daad723ef 100644 --- a/tests/framework/artifact/test_artifact.py +++ b/tests/framework/artifact/test_artifact.py @@ -102,7 +102,7 @@ def test_artifact_creation( assert a.filter_terms is None assert a._cache == {} assert a.keys == keys_mock - hdf_mock.load.called_once_with("metadata.keyspace") + hdf_mock.load.assert_called_once_with(artifact_path, "metadata.keyspace", None, None) a = Artifact(artifact_path, filter_terms) @@ -110,7 +110,7 @@ def test_artifact_creation( assert a.filter_terms == filter_terms assert a._cache == {} assert a.keys == 
keys_mock - hdf_mock.load.called_once_with("metadata.keyspace") + hdf_mock.load.assert_called_with(artifact_path, "metadata.keyspace", None, None) def test_artifact_load_missing_key(hdf_mock: MagicMock, artifact_path: Path) -> None: @@ -118,7 +118,7 @@ def test_artifact_load_missing_key(hdf_mock: MagicMock, artifact_path: Path) -> key = "not.a_real.key" a = Artifact(artifact_path, filter_terms) - hdf_mock.load.called_once_with("metadata.keyspace") + hdf_mock.load.assert_called_once_with(artifact_path, "metadata.keyspace", None, None) hdf_mock.load.reset_mock() with pytest.raises(ArtifactException) as err_info: a.load(key) @@ -140,7 +140,7 @@ def test_artifact_load_key_has_no_data(hdf_mock: MagicMock, artifact_path: Path) assert f"Data for {key} is not available. Check your model specification." == str( err_info.value ) - assert hdf_mock.load.called_once_with(artifact_path, key, filter_terms) + hdf_mock.load.assert_called_with(artifact_path, key, filter_terms, ["draw_10", "value"]) assert a._cache == {} @@ -163,7 +163,9 @@ def test_artifact_load( result = a.load(key) - assert hdf_mock.load.called_once_with(artifact_path, key, filter_terms) + hdf_mock.load.assert_called_with( + artifact_path, key, filter_terms, ["draw_10", "value"] + ) assert key in a._cache assert a._cache[key] == "data" assert result == "data" @@ -205,7 +207,6 @@ def test_artifact_write_duplicate_key( assert f"{key} already in artifact." 
== str(err_info.value) assert key in art assert key not in art._cache - hdf_mock.write.called_once_with(artifact_path, "metadata.keyspace", ["metadata.keyspace"]) hdf_mock.remove.assert_not_called() assert art.keys == initial_keys @@ -224,7 +225,6 @@ def test_artifact_write_no_data(hdf_mock: MagicMock, artifact_path: Path) -> Non assert key not in a assert key not in a._cache - hdf_mock.write.called_once_with(artifact_path, "metadata.keyspace", ["metadata.keyspace"]) hdf_mock.remove.assert_not_called() assert a.keys == initial_keys @@ -272,7 +272,9 @@ def test_artifact_write_and_load_with_different_key_types( a.load(load_key) - assert hdf_mock.load.called_once_with(artifact_path, load_key, filter_terms) + hdf_mock.load.assert_called_with( + artifact_path, load_key, filter_terms, ["draw_10", "value"] + ) assert load_key in a._cache hdf_mock.load.reset_mock() @@ -298,7 +300,7 @@ def test_artifact_write_and_reopen_then_load_with_entity_key( a_again.load(key) - assert hdf_mock.load.called_once_with(artifact_path, key, filter_terms) + hdf_mock.load.assert_called_with(artifact_path, key, filter_terms, ["draw_10", "value"]) assert key in a_again._cache hdf_mock.load.reset_mock() @@ -321,7 +323,7 @@ def test_remove_bad_key(hdf_mock: MagicMock, artifact_path: Path) -> None: assert key not in a assert key not in a._cache hdf_mock.remove.assert_not_called() - hdf_mock.write.called_once_with(artifact_path, "metadata.keyspace", ["metadata.keyspace"]) + hdf_mock.write.assert_not_called() assert a.keys == initial_keys @@ -445,9 +447,8 @@ def test_replace_nonexistent_key(hdf_mock: MagicMock, artifact_path: Path) -> No key = "new.key" a = Artifact(artifact_path, filter_terms=filter_terms) - hdf_mock.called_once_with(key) + hdf_mock.load.assert_called_once_with(artifact_path, "metadata.keyspace", None, None) assert key not in a - hdf_mock.reset_mock() with pytest.raises(ArtifactException): a.replace(key, "new_data") diff --git a/tests/framework/components/test_component.py 
b/tests/framework/components/test_component.py index ece6c517f..df563c6a2 100644 --- a/tests/framework/components/test_component.py +++ b/tests/framework/components/test_component.py @@ -7,14 +7,10 @@ from layered_config_tree.exceptions import ConfigurationError from tests.helpers import ( - AllColumnsRequirer, ColumnCreator, - ColumnCreatorAndAllRequirer, ColumnCreatorAndRequirer, - ColumnRequirer, CustomPriorities, DefaultPriorities, - FilteredPopulationView, LookupCreator, NoPopulationView, OrderedColumnsLookupCreator, @@ -25,7 +21,7 @@ from vivarium import Artifact, InteractiveContext from vivarium.framework.engine import Builder from vivarium.framework.lifecycle import lifecycle_states -from vivarium.framework.lookup.table import CategoricalTable, InterpolatedTable, ScalarTable +from vivarium.framework.lookup.table import LookupTable from vivarium.framework.population import PopulationError @@ -86,39 +82,9 @@ def test_component_that_creates_columns_population_view() -> None: # Assert population view is set and has the correct columns assert component.population_view is not None - assert set(component.population_view.columns) == set(component.columns_created) - - -def test_component_that_requires_columns_population_view() -> None: - component = ColumnRequirer() - InteractiveContext(components=[ColumnCreator(), component]) - - # Assert population view is set and has the correct columns - assert component.population_view is not None - assert set(component.population_view.columns) == set(component.columns_required) - - -def test_component_that_creates_and_requires_columns_population_view() -> None: - component = ColumnCreatorAndRequirer() - InteractiveContext(components=[ColumnCreator(), component]) - - # Assert population view is set and has the correct columns - expected_columns = component.columns_required + component.columns_created - - assert component.population_view is not None - assert set(component.population_view.columns) == set(expected_columns) - - 
-def test_component_that_creates_column_and_requires_all_columns_population_view() -> None: - component = ColumnCreatorAndAllRequirer() - simulation = InteractiveContext(components=[ColumnCreator(), component]) - population = simulation.get_population() - - # Assert population view is set and has the correct columns - expected_columns = population.columns - - assert component.population_view is not None - assert set(component.population_view.columns) == set(expected_columns) + assert set(component.population_view.private_columns) == set( + ["test_column_1", "test_column_2", "test_column_3"] + ) def test_component_with_initialization_requirements() -> None: @@ -127,50 +93,30 @@ def test_component_with_initialization_requirements() -> None: ) # Assert required resources have been recorded by the ResourceManager - component_dependencies_list = [ - r.dependencies + component_required_resources_list = [ + r.required_resources # get all resources in the dependency graph for r in simulation._resource.sorted_nodes # if the resource is an initializer if r.is_initialized # its initializer is an instance method and hasattr(r.initializer, "__self__") + # and is not None + and r.initializer is not None # and is a method of ColumnCreatorAndRequirer and isinstance(r.initializer.__self__, ColumnCreatorAndRequirer) ] - assert len(component_dependencies_list) == 1 - component_dependencies = component_dependencies_list[0] - - assert "value.pipeline_1" in component_dependencies - assert "column.test_column_2" in component_dependencies - assert "stream.stream_1" in component_dependencies - - -def test_component_that_requires_all_columns_population_view() -> None: - component = AllColumnsRequirer() - simulation = InteractiveContext( - components=[ColumnCreator(), ColumnCreatorAndRequirer(), component] - ) - population = simulation.get_population() - - # Assert population view is set and has the correct columns - expected_columns = population.columns - - assert component.population_view 
is not None - assert set(component.population_view.columns) == set(expected_columns) + assert len(component_required_resources_list) == 1 + component_required_resources = component_required_resources_list[0] + assert "value.pipeline_1" in component_required_resources + assert "attribute.test_column_2" in component_required_resources + assert "stream.stream_1" in component_required_resources -def test_component_with_filtered_population_view() -> None: - component = FilteredPopulationView() - InteractiveContext(components=[ColumnCreator(), component]) - # Assert population view is being filtered using the desired query - assert component.population_view.query == "test_column_1 == 5 and tracked == True" - - -def test_component_with_no_population_view() -> None: +def test_component_population_view_raises_before_setup() -> None: component = NoPopulationView() - InteractiveContext(components=[ColumnCreator(), component]) + sim = InteractiveContext(components=[ColumnCreator(), component], setup=False) # Assert population view is not set assert component._population_view is None @@ -179,16 +125,20 @@ def test_component_with_no_population_view() -> None: with pytest.raises(PopulationError, match=f"'{component.name}' does not have access"): _ = component.population_view + sim.setup() + assert component._population_view is not None + def test_component_initializer_is_not_registered_if_not_defined() -> None: component = NoPopulationView() simulation = InteractiveContext(components=[component]) - # Assert that simulant initializer has been registered - assert ( - component.on_initialize_simulants - not in simulation._resource.get_population_initializers() - ) + # Assert that component did not register an initializer + initializers = [ + initializer.__repr__() + for initializer in simulation._resource.get_population_initializers() + ] + assert "NoPopulationView" not in ";".join(initializers) def test_component_initializer_is_registered_and_called_if_defined() -> None: @@ -198,15 
+148,15 @@ def test_component_initializer_is_registered_and_called_if_defined() -> None: config = {"population": {"population_size": pop_size}} simulation = InteractiveContext(components=[component], configuration=config) - population = simulation.get_population() - + population = simulation.get_population(component.private_columns) + assert isinstance(population, pd.DataFrame) # Assert that simulant initializer has been registered assert ( - component.on_initialize_simulants + component.initialize_test_columns in simulation._resource.get_population_initializers() ) # and that created columns are correctly initialized - pd.testing.assert_frame_equal(population[component.columns_created], expected_pop_view) + pd.testing.assert_frame_equal(population[component.private_columns], expected_pop_view) def test_listeners_are_not_registered_if_not_defined() -> None: @@ -276,23 +226,22 @@ def test_listeners_are_registered_at_custom_priorities() -> None: def test_component_configuration_gets_set() -> None: without_config = ColumnCreator() - with_config = ColumnRequirer() + with_config = ColumnCreatorAndRequirer() column_requirer_config = { - "column_requirer": {"test_configuration": "some_config_value"}, + "column_creator_and_requirer": {"test_configuration": "some_config_value"}, } sim = InteractiveContext(components=[with_config, without_config], setup=False) sim.configuration.update(column_requirer_config) - - assert without_config.configuration is None - assert with_config.configuration is None - sim.setup() - assert without_config.configuration is None + assert without_config.configuration.to_dict() == {"data_sources": {}} assert with_config.configuration is not None - assert with_config.configuration.to_dict() == column_requirer_config["column_requirer"] + assert ( + with_config.configuration.to_dict() + == column_requirer_config["column_creator_and_requirer"] + ) def test_component_lookup_table_configuration(hdf_file_path: Path) -> None: @@ -338,41 +287,34 @@ def 
test_component_lookup_table_configuration(hdf_file_path: Path) -> None: ) sim.setup() - # Assertions for specific lookup tables - expected_tables = { - "favorite_team", - "favorite_color", - "favorite_number", - "favorite_scalar", - "favorite_list", - "baking_time", - "cooling_time", - } - assert expected_tables == set(component.lookup_tables.keys()) - # check that tables have correct type - assert isinstance(component.lookup_tables["favorite_team"], CategoricalTable) - assert isinstance(component.lookup_tables["favorite_color"], InterpolatedTable) - assert isinstance(component.lookup_tables["favorite_scalar"], ScalarTable) - assert isinstance(component.lookup_tables["favorite_list"], ScalarTable) - assert isinstance(component.lookup_tables["baking_time"], ScalarTable) - assert isinstance(component.lookup_tables["cooling_time"], CategoricalTable) + # Check that lookup table backing data is of the correct type + assert isinstance(component.favorite_team_table.data, pd.DataFrame) + assert isinstance(component.favorite_color_table.data, pd.DataFrame) + assert isinstance(component.favorite_number_table.data, pd.DataFrame) + assert isinstance(component.favorite_scalar_table.data, float) + assert isinstance(component.favorite_list_table.data, list) + assert isinstance(component.cooling_time_table.data, pd.DataFrame) # Check for correct columns in lookup tables - assert component.lookup_tables["favorite_team"].key_columns == ["test_column_1"] - assert not component.lookup_tables["favorite_team"].parameter_columns - assert component.lookup_tables["favorite_color"].key_columns == ["test_column_2"] - assert component.lookup_tables["favorite_color"].parameter_columns == ["test_column_3"] - assert component.lookup_tables["favorite_list"].value_columns == ["column_1", "column_2"] - assert component.lookup_tables["cooling_time"].key_columns == ["test_column_1"] - assert not component.lookup_tables["cooling_time"].parameter_columns + assert 
component.favorite_team_table.key_columns == ["test_column_1"] + assert not component.favorite_team_table.parameter_columns + assert component.favorite_color_table.key_columns == ["test_column_2"] + assert component.favorite_color_table.parameter_columns == ["test_column_3"] + assert component.favorite_number_table.key_columns == [] + assert component.favorite_number_table.parameter_columns == ["test_column_3"] + assert component.favorite_scalar_table.value_columns == ["scalar"] + assert component.favorite_list_table.value_columns == ["column_1", "column_2"] + assert component.cooling_time_table.key_columns == ["test_column_1"] + assert not component.cooling_time_table.parameter_columns # Check for correct data in lookup tables - assert component.lookup_tables["favorite_team"].data.equals(favorite_team.reset_index()) - assert component.lookup_tables["favorite_color"].data.equals(favorite_color.reset_index()) - assert component.lookup_tables["favorite_scalar"].data == 0.4 - assert component.lookup_tables["favorite_list"].data == [9, 4] - assert component.lookup_tables["baking_time"].data == 0.5 - assert component.lookup_tables["cooling_time"].data.equals(cooling_time.reset_index()) + assert component.favorite_team_table.data.equals(favorite_team.reset_index()) + assert component.favorite_color_table.data.equals(favorite_color.reset_index()) + assert component.favorite_number_table.data.equals(favorite_number.reset_index()) + assert component.favorite_scalar_table.data == 0.4 + assert component.favorite_list_table.data == [9, 4] + assert component.baking_time_table.data == 0.5 + assert component.cooling_time_table.data.equals(cooling_time.reset_index()) @pytest.mark.parametrize( @@ -422,10 +364,7 @@ def test_failing_component_lookup_table_configurations( sim.setup() -@pytest.mark.parametrize( - "table", ["ordered_columns_categorical", "ordered_columns_interpolated"] -) -def test_value_column_order_is_maintained(table: str) -> None: +def 
test_value_column_order_is_maintained() -> None: """Tests that the order of value columns is maintained when creating a LookupTable. Notes @@ -438,9 +377,20 @@ def test_value_column_order_is_maintained(table: str) -> None: """ component = OrderedColumnsLookupCreator() sim = InteractiveContext(components=[component]) - lookup_table = component.lookup_tables[table] - assert isinstance( - lookup_table, CategoricalTable if "categorical" in table else InterpolatedTable - ) - data = lookup_table(sim.get_population().index) - assert list(data.columns) == ["one", "two", "three", "four", "five", "six", "seven"] + columns = list(component.categorical_table(sim.get_population_index()).columns) + assert columns == OrderedColumnsLookupCreator.VALUE_COLUMNS + columns = list(component.interpolated_table(sim.get_population_index()).columns) + assert columns == OrderedColumnsLookupCreator.VALUE_COLUMNS + + +def test_attribute_pipelines_from_private_columns() -> None: + idx = pd.Index([4, 8, 15, 16, 23, 42]) + component = ColumnCreator() + sim = InteractiveContext(components=[component]) + for column in component.private_columns: + pipeline = sim._builder.value.get_attribute_pipelines()()[column] + assert pipeline.name == column + assert pipeline.mutators == [] + attributes = pipeline(idx) + assert attributes.equals(pd.Series([i % 3 for i in idx], index=idx)) + assert attributes.name == column diff --git a/tests/framework/components/test_manager.py b/tests/framework/components/test_manager.py index 84c8c3e00..3e668208e 100644 --- a/tests/framework/components/test_manager.py +++ b/tests/framework/components/test_manager.py @@ -1,7 +1,9 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import pytest -from pytest_mock import MockerFixture +from layered_config_tree import LayeredConfigTree from tests.helpers import MockComponentA, MockComponentB, MockGenericComponent, MockManager from vivarium import Component @@ -12,8 +14,13 @@ 
OrderedComponentSet, ) from vivarium.framework.configuration import build_simulation_configuration +from vivarium.framework.engine import Builder +from vivarium.interface import InteractiveContext from vivarium.manager import Manager +if TYPE_CHECKING: + from pytest_mock import MockerFixture + def test_component_set_add() -> None: component_list = OrderedComponentSet() @@ -136,7 +143,7 @@ def test_manager_get_file() -> None: def test_flatten_simple() -> None: components = [MockComponentA(name=str(i)) for i in range(10)] - assert ComponentManager._flatten(components) == components + assert ComponentManager()._flatten_subcomponents(components) == components def test_flatten_with_lists() -> None: @@ -144,7 +151,7 @@ def test_flatten_with_lists() -> None: for i in range(5): for j in range(5): components.append(MockComponentA(name=str(5 * i + j))) - out = ComponentManager._flatten(components) + out = ComponentManager()._flatten_subcomponents(components) expected = [MockComponentA(name=str(i)) for i in range(25)] assert out == expected @@ -154,7 +161,7 @@ def test_flatten_with_sub_components() -> None: for i in range(5): name, *args = [str(5 * i + j) for j in range(5)] components.append(MockComponentB(*args, name=name)) - out = ComponentManager._flatten(components) + out = ComponentManager()._flatten_subcomponents(components) expected = [MockComponentB(name=str(i)) for i in range(25)] assert out == expected @@ -170,14 +177,10 @@ def nest(start: int, depth: int) -> Component: components: list[Component] = [] for i in range(5): components.append(nest(5 * i, 5)) - out = ComponentManager._flatten(components) + out = ComponentManager()._flatten_subcomponents(components) expected = [MockComponentA(name=str(i)) for i in range(25)] assert out == expected - # Lists with nested subcomponents - out = ComponentManager._flatten([components, components]) - assert out == 2 * expected - def test_setup_components(mocker: MockerFixture) -> None: builder = mocker.Mock() @@ -186,8 +189,9 @@ 
def test_setup_components(mocker: MockerFixture) -> None: mocker.patch("vivarium.framework.results.observer.Observer.get_configuration") mock_a = MockComponentA("test_a") mock_b = MockComponentB("test_b") - components = OrderedComponentSet(mock_a, mock_b) - ComponentManager._setup_components(builder, components) + manager = ComponentManager() + manager._components.update([mock_a, mock_b]) + manager.setup_components(builder) assert mock_a.builder_used_for_setup is None # class has no setup method assert mock_b.builder_used_for_setup is builder @@ -289,7 +293,7 @@ def test_component_manager_add_components(components: list[Component]) -> None: cm = ComponentManager() cm._configuration = config cm.add_components(components) - assert cm._components == OrderedComponentSet(*ComponentManager._flatten(components)) + assert cm._components == OrderedComponentSet(*cm._flatten_subcomponents(components)) @pytest.mark.parametrize( @@ -320,3 +324,27 @@ def test_component_manager_add_components_duplicated(components: list[Component] match="is attempting to set the configuration value mock_component_a, but it has already been set by mock_component_a", ): cm.add_components(components) + + +def test_get_current_component_outside_setup() -> None: + cm = ComponentManager() + with pytest.raises(VivariumError, match="No component is currently being set up"): + _ = cm.get_current_component() + + +def test_setting_current_component() -> None: + class TestComponent(Component): + def __init__(self, some_name: str) -> None: + super().__init__() + self.some_name = some_name + + def setup(self, builder: Builder) -> None: + self.component = builder.components.get_current_component() + + component1 = TestComponent("component1") + component2 = TestComponent("component2") + + InteractiveContext(components=[component1, component2]) + + assert component1.component == component1 + assert component2.component == component2 diff --git a/tests/framework/lookup/test_lookup.py 
b/tests/framework/lookup/test_lookup.py index e23dc3821..ff4c572b3 100644 --- a/tests/framework/lookup/test_lookup.py +++ b/tests/framework/lookup/test_lookup.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import itertools -from typing import Sequence +from typing import Any import numpy as np import pandas as pd @@ -7,15 +9,60 @@ from layered_config_tree import LayeredConfigTree from pytest_mock import MockerFixture -from vivarium import InteractiveContext -from vivarium.framework.lookup import ( - LookupTableInterface, - LookupTableManager, - validate_build_table_parameters, -) -from vivarium.framework.lookup.table import InterpolatedTable +from tests.helpers import LookupCreator +from vivarium import Component, InteractiveContext +from vivarium.framework.lifecycle import lifecycle_states +from vivarium.framework.lookup.manager import LookupTableManager +from vivarium.framework.lookup.table import LookupTable from vivarium.testing_utilities import TestPopulation, build_table -from vivarium.types import LookupTableData +from vivarium.types import LookupTableData, ScalarValue + + +def test_build_table_calls_methods_correctly(mocker: MockerFixture) -> None: + """Test that build_table orchestrates calls to helper methods correctly.""" + # Setup + manager = LookupTableManager() + test_component = Component() + test_data = pd.DataFrame({"a": [1, 2, 3], "value": [10, 20, 30]}) + test_name = "test_table" + test_value_columns = "value" + + # Set up a mock LookupTable + mock_table = mocker.Mock() + mock_table.required_resources = ["resource1", "resource2"] + mock_table.call = mocker.Mock() + + # Inject mocks into the manager + manager._get_current_component = mocker.Mock(return_value=test_component) + manager._build_table = mocker.Mock(return_value=mock_table) # type: ignore[method-assign] + manager._add_resources = mocker.Mock() + manager._add_constraint = mocker.Mock() + + # Execute + result = manager.build_table(test_data, test_name, test_value_columns) + + # 
Assert _build_table was called with correct arguments + manager._build_table.assert_called_once_with( # type: ignore[attr-defined] + test_component, test_data, test_name, test_value_columns + ) + + # Assert _add_resources was called with correct arguments + manager._add_resources.assert_called_once_with( # type: ignore[attr-defined] + test_component, mock_table, ["resource1", "resource2"] + ) + + # Assert correct constraint have been set on table.call + manager._add_constraint.assert_called_once() # type: ignore[attr-defined] + call_args = manager._add_constraint.call_args # type: ignore[attr-defined] + assert call_args[0][0] == mock_table._call + assert call_args[1]["restrict_during"] == [ + lifecycle_states.INITIALIZATION, + lifecycle_states.SETUP, + lifecycle_states.POST_SETUP, + ] + + # Assert the table is returned + assert result == mock_table @pytest.mark.skip(reason="only order 0 interpolation currently supported") @@ -41,55 +88,38 @@ def test_interpolated_tables(base_config: LayeredConfigTree) -> None: {"population": {"population_size": 10000}, "interpolation": {"order": 1}} ) # the results we're checking later assume interp order 1 - simulation = InteractiveContext(components=[TestPopulation()], configuration=base_config) + component = TestPopulation() + simulation = InteractiveContext(components=[component], configuration=base_config) manager = simulation._tables - years = manager.build_table( - years_df, - key_columns=("sex",), - parameter_columns=( - "age", - "year", - ), - value_columns=(), - ) - ages = manager.build_table( - ages_df, - key_columns=("sex",), - parameter_columns=( - "age", - "year", - ), - value_columns=(), - ) - one_d_age = manager.build_table( - one_d_age_df, key_columns=("sex",), parameter_columns=("age",), value_columns=() - ) + years = manager.build_table(years_df, "", value_columns=()) + age_table = manager.build_table(ages_df, "", value_columns=()) + one_d_age = manager.build_table(one_d_age_df, "", value_columns=()) - pop = 
simulation.get_population(untracked=True) - result_years = years(pop.index) - result_ages = ages(pop.index) - result_ages_1d = one_d_age(pop.index) + ages = simulation.get_population("age") + result_years = years(ages.index) + result_ages = age_table(ages.index) + result_ages_1d = one_d_age(ages.index) fractional_year = simulation._clock.time.year # type: ignore [union-attr] fractional_year += simulation._clock.time.timetuple().tm_yday / 365.25 # type: ignore [union-attr] assert np.allclose(result_years, fractional_year) - assert np.allclose(result_ages, pop.age) - assert np.allclose(result_ages_1d, pop.age) + assert np.allclose(result_ages, ages) + assert np.allclose(result_ages_1d, ages) simulation._clock._clock_time += pd.Timedelta(30.5 * 125, unit="D") # type: ignore [operator] - simulation._population._population.age += 125 / 12 # type: ignore [union-attr] + simulation._population._private_columns.age += 125 / 12 # type: ignore [union-attr] - result_years = years(pop.index) - result_ages = ages(pop.index) - result_ages_1d = one_d_age(pop.index) + result_years = years(ages.index) + result_ages = age_table(ages.index) + result_ages_1d = one_d_age(ages.index) fractional_year = simulation._clock.time.year # type: ignore [union-attr] fractional_year += simulation._clock.time.timetuple().tm_yday / 365.25 # type: ignore [union-attr] assert np.allclose(result_years, fractional_year) - assert np.allclose(result_ages, pop.age) - assert np.allclose(result_ages_1d, pop.age) + assert np.allclose(result_ages, ages) + assert np.allclose(result_ages_1d, ages) @pytest.mark.skip(reason="only order 0 interpolation currently supported") @@ -111,19 +141,12 @@ def test_interpolated_tables_without_uninterpolated_columns( {"population": {"population_size": 10000}, "interpolation": {"order": 1}} ) # the results we're checking later assume interp order 1 - simulation = InteractiveContext(components=[TestPopulation()], configuration=base_config) + component = TestPopulation() + 
simulation = InteractiveContext(components=[component], configuration=base_config) manager = simulation._tables - years = manager.build_table( - years_df, - key_columns=(), - parameter_columns=( - "year", - "age", - ), - value_columns=(), - ) + years = manager.build_table(years_df, "", value_columns=()) - result_years = years(simulation.get_population().index) + result_years = years(simulation.get_population_index()) fractional_year = simulation._clock.time.year # type: ignore [union-attr] fractional_year += simulation._clock.time.timetuple().tm_yday / 365.25 # type: ignore [union-attr] @@ -132,7 +155,7 @@ def test_interpolated_tables_without_uninterpolated_columns( simulation._clock._clock_time += pd.Timedelta(30.5 * 125, unit="D") # type: ignore [operator] - result_years = years(simulation.get_population().index) + result_years = years(simulation.get_population_index()) fractional_year = simulation._clock.time.year # type: ignore [union-attr] fractional_year += simulation._clock.time.timetuple().tm_yday / 365.25 # type: ignore [union-attr] @@ -156,16 +179,15 @@ def test_interpolated_tables__exact_values_at_input_points( input_years = years_df.year_start.unique() base_config.update({"population": {"population_size": 10000}}) - simulation = InteractiveContext(components=[TestPopulation()], configuration=base_config) + component = TestPopulation() + simulation = InteractiveContext(components=[component], configuration=base_config) manager = simulation._tables - years = manager._build_table( - years_df, key_columns=["sex"], parameter_columns=["age", "year"], value_columns=() - ) + years = manager._build_table(component, years_df, "", value_columns="value") for year in input_years: simulation._clock._clock_time = pd.Timestamp(year, 1, 1) assert np.allclose( - years(simulation.get_population().index), simulation._clock.time.year + 1 / 365 # type: ignore [union-attr] + years(simulation.get_population_index()), simulation._clock.time.year + 1 / 365 # type: ignore 
[union-attr] ) @@ -184,13 +206,12 @@ def test_interpolated_tables__only_categorical_parameters( base_config.update({"population": {"population_size": 10000}}) - simulation = InteractiveContext(components=[TestPopulation()], configuration=base_config) + component = TestPopulation() + simulation = InteractiveContext(components=[component], configuration=base_config) manager = simulation._tables - lookup_table = manager._build_table( - input_data, key_columns=["sex", "location"], parameter_columns=(), value_columns=() - ) + lookup_table = manager._build_table(component, input_data, "", value_columns="some_value") - population = simulation.get_population()[["sex", "location"]] + population = simulation.get_population(["sex", "location"]) output_data = lookup_table(population.index) for i, (sex, location) in combinations: @@ -200,13 +221,14 @@ def test_interpolated_tables__only_categorical_parameters( @pytest.mark.parametrize("data", [(1, 2), [1, 2]]) def test_lookup_table_scalar_from_list( - base_config: LayeredConfigTree, data: Sequence[int] + base_config: LayeredConfigTree, data: list[ScalarValue] | tuple[ScalarValue, ...] 
) -> None: - simulation = InteractiveContext(components=[TestPopulation()], configuration=base_config) + component = TestPopulation() + simulation = InteractiveContext(components=[component], configuration=base_config) manager = simulation._tables - table = manager._build_table( - data, key_columns=(), parameter_columns=(), value_columns=["a", "b"] # type: ignore [arg-type] - )(simulation.get_population().index) + table = manager._build_table(component, data, "", value_columns=["a", "b"])( + simulation.get_population_index() + ) assert isinstance(table, pd.DataFrame) assert table.columns.values.tolist() == ["a", "b"] @@ -215,20 +237,22 @@ def test_lookup_table_scalar_from_list( def test_lookup_table_scalar_from_single_value(base_config: LayeredConfigTree) -> None: - simulation = InteractiveContext(components=[TestPopulation()], configuration=base_config) + component = TestPopulation() + simulation = InteractiveContext(components=[component], configuration=base_config) manager = simulation._tables - table = manager._build_table( - 1, key_columns=(), parameter_columns=(), value_columns=["a"] - )(simulation.get_population().index) + table = manager._build_table(component, 1, "", value_columns="a")( + simulation.get_population_index() + ) assert isinstance(table, pd.Series) assert np.all(table == 1) def test_invalid_data_type_build_table(base_config: LayeredConfigTree) -> None: - simulation = InteractiveContext(components=[TestPopulation()], configuration=base_config) + component = TestPopulation() + simulation = InteractiveContext(components=[component], configuration=base_config) manager = simulation._tables with pytest.raises(TypeError): - manager._build_table("break", key_columns=(), parameter_columns=(), value_columns=()) # type: ignore [arg-type] + manager._build_table(component, "break", "", value_columns=()) # type: ignore [arg-type] def test_lookup_table_interpolated_return_types(base_config: LayeredConfigTree) -> None: @@ -241,120 +265,126 @@ def 
test_lookup_table_interpolated_return_types(base_config: LayeredConfigTree) "age": (0, 125), }, ) - - simulation = InteractiveContext(components=[TestPopulation()], configuration=base_config) + component = TestPopulation() + simulation = InteractiveContext(components=[component], configuration=base_config) manager = simulation._tables - table = manager._build_table( - data, key_columns=["sex"], parameter_columns=["age", "year"], value_columns=() - )(simulation.get_population().index) + table = manager._build_table(component, data, "", value_columns="value")( + simulation.get_population_index() + ) # make sure a single value column is returned as a series assert isinstance(table, pd.Series) # now add a second value column to make sure the result is a df data["value2"] = data.value - table = manager._build_table( - data, key_columns=["sex"], parameter_columns=["age", "year"], value_columns=() - )(simulation.get_population().index) + table = manager._build_table(component, data, "", value_columns=["value", "value2"])( + simulation.get_population_index() + ) assert isinstance(table, pd.DataFrame) -@pytest.mark.parametrize( - "data", [None, pd.DataFrame(), pd.DataFrame(columns=["a", "b", "c"]), [], tuple()] -) -def test_validate_parameters_no_data(data: LookupTableData) -> None: - with pytest.raises(ValueError, match="supply some data"): - validate_build_table_parameters(data, [], [], []) - - -@pytest.mark.parametrize( - "key_cols, param_cols, val_cols, match", - [ - ((), (), (), "supply value_columns"), - ((), (), [], "supply value_columns"), - ((), (), ["a", "b"], "match the number of values"), - (("a", "b"), (), ["d", "e", "f"], "key_columns are not allowed"), - ((), ("a", "b"), ["d", "e", "f"], "parameter_columns are not allowed"), - ], -) -def test_validate_parameters_error_scalar_data( - key_cols: Sequence[str], param_cols: Sequence[str], val_cols: Sequence[str], match: str -) -> None: - with pytest.raises(ValueError, match=match): - 
validate_build_table_parameters([1, 2, 3], key_cols, param_cols, val_cols) - - -@pytest.mark.parametrize("data", ["FAIL", pd.Interval(5, 10), "2019-05-17"]) -def test_validate_parameters_fail_other_data(data: LookupTableData) -> None: - with pytest.raises(TypeError, match="only allowable types"): - validate_build_table_parameters(data, [], [], []) - - -@pytest.mark.parametrize( - "key_cols, param_cols, val_cols, match", - [ - ([], [], ["c"], "either key_columns or parameter_columns"), - (["a", "b"], ["b"], ["c"], "no overlap between key.*and parameter columns"), - (["a"], ["b"], ["a", "c"], "no overlap between value.*and key.*columns"), - (["a"], ["b"], ["b", "c"], "no overlap between value.*and.*parameter columns"), - (["d"], ["b"], ["c"], "columns.*must all be present"), - (["a"], ["d"], ["c"], "columns.*must all be present"), - (["a"], ["b"], ["d"], "columns.*must all be present"), - ], -) -def test_validate_parameters_error_dataframe( - key_cols: Sequence[str], param_cols: Sequence[str], val_cols: Sequence[str], match: str -) -> None: - data = pd.DataFrame({"a": [1, 2], "b_start": [0, 5], "b_end": [5, 10], "c": [100, 150]}) - with pytest.raises(ValueError, match=match): - validate_build_table_parameters(data, key_cols, param_cols, val_cols) - - -def test_validate_parameters_pass_scalar_data() -> None: - validate_build_table_parameters([1, 2, 3], (), (), ["a", "b", "c"]) - - -@pytest.mark.parametrize( - "key_cols, param_cols, val_cols", - [ - (["a"], ["b"], ["c"]), - ([], ["b"], ["c", "a"]), - ([], ["b"], ["a", "c"]), - ([], ["b"], ["c"]), - (["a"], [], ["c"]), - (["a"], ["b"], []), - (["a"], [], []), - ([], ["b"], []), - ], -) -def test_validate_parameters_pass_dataframe( - key_cols: Sequence[str], param_cols: Sequence[str], val_cols: Sequence[str] -) -> None: - data = pd.DataFrame({"a": [1, 2], "b_start": [0, 5], "b_end": [5, 10], "c": [100, 150]}) - validate_build_table_parameters(data, key_cols, param_cols, val_cols) - +class TestLookupTableResource: + 
@pytest.fixture + def manager(self, mocker: MockerFixture) -> LookupTableManager: + manager = LookupTableManager() + manager.clock = mocker.Mock() + manager._get_view = mocker.Mock() + manager._add_resources = mocker.Mock() + manager._add_constraint = mocker.Mock() + manager._get_current_component = mocker.Mock() + manager.interpolation_order = 0 + manager.extrapolate = True + manager.validate_interpolation = True + return manager + + def test_scalar_table_resource_attributes(self, manager: LookupTableManager) -> None: + table = manager._build_table(LookupCreator(), 5, "test_table", value_columns="value") + assert table.resource_type == "lookup_table" + assert table.name == "lookup_creator.test_table" + assert table.resource_id == "lookup_table.lookup_creator.test_table" + assert table.required_resources == [] + + def test_categorical_table_resource_attributes(self, manager: LookupTableManager) -> None: + table = manager._build_table( + LookupCreator(), + pd.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6], "baz": [7, 8, 9]}), + "test_table", + value_columns="baz", + ) + assert table.resource_type == "lookup_table" + assert table.name == "lookup_creator.test_table" + assert table.resource_id == "lookup_table.lookup_creator.test_table" + assert table.required_resources == ["foo", "bar"] + + def test_interpolated_table_resource_attributes( + self, + manager: LookupTableManager, + ) -> None: + data = pd.DataFrame( + { + "foo": [1, 2, 3], + "bar_start": [0, 1, 2], + "bar_end": [1, 2, 3], + "year_start": [2000, 2001, 2002], + "year_end": [2001, 2002, 2003], + "baz": [7, 8, 9], + } + ) + table = manager._build_table(LookupCreator(), data, "test_table", value_columns="baz") + assert table.resource_type == "lookup_table" + assert table.name == "lookup_creator.test_table" + assert table.resource_id == "lookup_table.lookup_creator.test_table" + assert table.required_resources == ["foo", "bar"] + + def test_adding_resources(self, manager: LookupTableManager) -> None: + component 
= LookupCreator() + manager._get_current_component.return_value = component # type: ignore [attr-defined] + table = manager.build_table(5, "test_table", value_columns="value") + manager._add_resources.assert_called_once_with( # type: ignore[attr-defined] + component, table, table.required_resources + ) -@pytest.mark.parametrize("validate", [True, False]) -def test_validate_flag(mocker: MockerFixture, validate: bool) -> None: - manager = LookupTableManager() - manager.setup(mocker.Mock()) - manager._validate = validate - interface = LookupTableInterface(manager) - mock_validator = mocker.patch( - "vivarium.framework.lookup.manager.validate_build_table_parameters" +class TestValidateBuildTableParameters: + @pytest.mark.parametrize( + "data", [None, pd.DataFrame(), pd.DataFrame(columns=["a", "b", "c"]), [], tuple()] ) + def test_no_data(self, data: LookupTableData) -> None: + with pytest.raises(ValueError, match="supply some data"): + LookupTable._validate_data_inputs(data, []) - interface.build_table(0, value_columns=["a"]) - - if validate: - mock_validator.assert_called_once() - else: - mock_validator.assert_not_called() + @pytest.mark.parametrize( + "data, val_cols, match", + [ + ([1, 2, 3], "a", "value_columns must be a list or tuple of strings"), + ([1, 2, 3], ["a", "b"], "match the number of values"), + (5, ["a", "b"], "value_columns must be a string"), + ], + ) + def test_scalar_data_value_columns_mismatch( + self, data: LookupTableData, val_cols: str | list[str], match: str + ) -> None: + with pytest.raises(ValueError, match=match): + LookupTable._validate_data_inputs(data, val_cols) + + @pytest.mark.parametrize("data", ["FAIL", pd.Interval(5, 10), "2019-05-17"]) + def test_validate_parameters_fail_other_data(self, data: LookupTableData) -> None: + with pytest.raises(TypeError, match="only allowable types"): + LookupTable._validate_data_inputs(data, []) + + def test_validate_parameters_pass_scalar_data(self) -> None: + LookupTable._validate_data_inputs([1, 2, 
3], ["a", "b", "c"]) + + def test_validate_parameters_pass_dataframe_data(self) -> None: + data = pd.DataFrame( + {"a": [1, 2], "b_start": [0, 5], "b_end": [5, 10], "c": [100, 150]} + ) + LookupTable._validate_data_inputs(data, ["c"]) def test__build_table_from_dict(base_config: LayeredConfigTree) -> None: - simulation = InteractiveContext(components=[TestPopulation()], configuration=base_config) + component = TestPopulation() + simulation = InteractiveContext(components=[component], configuration=base_config) manager = simulation._tables data = { "a_start": [0.0, 0.5, 1.0, 1.5], @@ -365,13 +395,37 @@ def test__build_table_from_dict(base_config: LayeredConfigTree) -> None: # We convert the dict to a dataframe before we call validate_build_table_parameters so # this test is really going to just ensure we don't error out when we pass in a dict and # we get the expected return type from _build_table - table = manager._build_table( - data, # type: ignore [arg-type] - key_columns=["b"], - parameter_columns=["a"], - value_columns=["c"], - ) - assert isinstance(table, InterpolatedTable) + table = manager._build_table(component, data, "", value_columns=["c"]) # type: ignore [arg-type] + assert isinstance(table, LookupTable) assert table.key_columns == ["b"] assert table.parameter_columns == ["a"] assert table.value_columns == ["c"] + + +def test_uncreated_lookup_table_warning( + base_config: LayeredConfigTree, caplog: pytest.LogCaptureFixture +) -> None: + """Test that a warning is logged when a lookup table is configured but not created.""" + + class ComponentWithUnusedLookupTable(Component): + @property + def configuration_defaults(self) -> dict[str, Any]: + return { + "component_with_unused_lookup_table": { + "data_sources": { + "unused_table": 42, + } + } + } + + InteractiveContext( + components=[ComponentWithUnusedLookupTable()], configuration=base_config + ) + + # Check that the warning was logged at WARNING level + warning_records = [record for record in 
caplog.records if record.levelname == "WARNING"] + assert len(warning_records) == 1 + assert ( + "Component 'component_with_unused_lookup_table' configured, but didn't build " + "lookup table 'unused_table' during setup." in warning_records[0].message + ) diff --git a/tests/framework/population/conftest.py b/tests/framework/population/conftest.py new file mode 100644 index 000000000..d2e19317b --- /dev/null +++ b/tests/framework/population/conftest.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import itertools +import math +from collections import defaultdict +from typing import Any + +import pandas as pd +import pytest +from pytest_mock import MockerFixture + +from vivarium import Component +from vivarium.framework.engine import Builder, SimulationContext +from vivarium.framework.population import PopulationManager, SimulantData +from vivarium.framework.values import ValuesManager + +# FIXME: Streamline with already-existing classes in tests/helpers.py +PIE_COL_NAMES = ["pie", "pi"] +PIES = ["apple", "chocolate", "pecan", "pumpkin", "sweet_potato"] +PIS = [math.pi**i for i in range(1, 11)] +PIE_RECORDS = [(pie, pi) for pie, pi in itertools.product(PIES, PIS)] +PIE_DF = pd.DataFrame(data=PIE_RECORDS, columns=PIE_COL_NAMES) +CUBE_COL_NAMES = ["cube", "cube_string"] +CUBE = [i**3 for i in range(len(PIE_RECORDS))] +CUBE_STRING = [str(i**3) for i in range(len(PIE_RECORDS))] +CUBE_DF = pd.DataFrame( + zip(CUBE, CUBE_STRING), + columns=CUBE_COL_NAMES, + index=PIE_DF.index, +) + + +class PieComponent(Component): + def setup(self, builder: Builder) -> None: + builder.population.register_initializer( + initializer=self.make_pie, columns=PIE_COL_NAMES + ) + + def make_pie(self, pop_data: SimulantData) -> None: + self.population_view.update(self.get_initial_state(pop_data.index)) + + def get_initial_state(self, index: pd.Index[int]) -> pd.DataFrame: + return PIE_DF + + +class CubeComponent(Component): + def setup(self, builder: Builder) -> None: + 
builder.population.register_initializer( + initializer=self.cubify, columns=CUBE_COL_NAMES + ) + + def cubify(self, pop_data: SimulantData) -> None: + self.population_view.update(self.get_initial_state(pop_data.index)) + + def get_initial_state(self, index: pd.Index[int]) -> pd.DataFrame: + return CUBE_DF + + +@pytest.fixture(scope="function") +def pies_and_cubes_pop_mgr(mocker: MockerFixture) -> PopulationManager: + """A mocked PopulationManager with some private columns set up. + + This fixture is tied directly to the PieComponent and CubeComponent helper classes. + + """ + + class _PopulationManager(PopulationManager): + def __init__(self) -> None: + super().__init__() + self._private_columns: pd.DataFrame = pd.concat([PIE_DF, CUBE_DF], axis=1) + + def _add_constraint(self, *args: Any, **kwargs: Any) -> None: + pass + + mgr = _PopulationManager() + + # Use SimulationContext just for builder and mock as appropriate + sim = SimulationContext() + builder = sim._builder + mocker.patch.object(ValuesManager, "logger", mocker.Mock(), create=True) + mocker.patch.object(ValuesManager, "resources", mocker.Mock(), create=True) + mocker.patch.object(ValuesManager, "add_constraint", mocker.Mock(), create=True) + mocker.patch.object(ValuesManager, "_population_mgr", mgr, create=True) + mocked_attribute_pipelines = {} + sim._lifecycle.set_state("setup") + mgr.setup(builder) + sim._lifecycle.set_state("post_setup") + sim._lifecycle.set_state("population_creation") + + for col in mgr._private_columns.columns: + mocked_attribute_pipelines[col] = mocker.Mock() + mgr._attribute_pipelines = mocked_attribute_pipelines + mgr._private_column_metadata = defaultdict( + list, + { + "pie_component": PIE_COL_NAMES, + "cube_component": CUBE_COL_NAMES, + }, + ) + # Change lifecycle phase to ensure tracked queries are applied appropriately + mocker.patch.object(mgr, "get_current_state", lambda: "on_time_step") + return mgr diff --git a/tests/framework/population/helpers.py 
b/tests/framework/population/helpers.py new file mode 100644 index 000000000..76d858149 --- /dev/null +++ b/tests/framework/population/helpers.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import Any + +import pandas as pd + + +def assert_squeezing_multi_level_multi_outer( + unsqueezed: pd.Series[Any] | pd.DataFrame, squeezed: pd.Series[Any] | pd.DataFrame +) -> None: + assert isinstance(squeezed, pd.DataFrame) + assert isinstance(squeezed.columns, pd.MultiIndex) + assert squeezed.equals(unsqueezed) + + +def assert_squeezing_multi_level_single_outer_multi_inner( + unsqueezed: pd.Series[Any] | pd.DataFrame, squeezed: pd.Series[Any] | pd.DataFrame +) -> None: + assert isinstance(unsqueezed, pd.DataFrame) + assert isinstance(unsqueezed.columns, pd.MultiIndex) + assert isinstance(squeezed, pd.DataFrame) + assert not isinstance(squeezed.columns, pd.MultiIndex) + assert squeezed.equals(unsqueezed.droplevel(0, axis=1)) + + +def assert_squeezing_multi_level_single_outer_single_inner( + unsqueezed: pd.Series[Any] | pd.DataFrame, + squeezed: pd.Series[Any] | pd.DataFrame, + column: tuple[str, str] = ("attribute_generating_column_8", "test_column_8"), +) -> None: + assert isinstance(unsqueezed, pd.DataFrame) + assert isinstance(unsqueezed.columns, pd.MultiIndex) + assert isinstance(squeezed, pd.Series) + assert unsqueezed[column].equals(squeezed) + + +def assert_squeezing_single_level_multi_col( + unsqueezed: pd.Series[Any] | pd.DataFrame, squeezed: pd.Series[Any] | pd.DataFrame +) -> None: + assert isinstance(squeezed, pd.DataFrame) + assert not isinstance(squeezed.columns, pd.MultiIndex) + assert squeezed.equals(unsqueezed) + + +def assert_squeezing_single_level_single_col( + unsqueezed: pd.Series[Any] | pd.DataFrame, + squeezed: pd.Series[Any] | pd.DataFrame, + column: str = "test_column_1", +) -> None: + assert isinstance(unsqueezed, pd.DataFrame) + assert not isinstance(unsqueezed.columns, pd.MultiIndex) + assert isinstance(squeezed, pd.Series) + 
assert unsqueezed[column].equals(squeezed) diff --git a/tests/framework/population/test_manager.py b/tests/framework/population/test_manager.py index 9ddd6e34a..23835dfef 100644 --- a/tests/framework/population/test_manager.py +++ b/tests/framework/population/test_manager.py @@ -1,30 +1,31 @@ -import pytest - -from vivarium import Component -from vivarium.framework.population.exceptions import PopulationError -from vivarium.framework.population.manager import ( - InitializerComponentSet, - PopulationManager, - SimulantData, -) - - -def test_initializer_set_fail_type() -> None: - component_set = InitializerComponentSet() +from __future__ import annotations - with pytest.raises(TypeError): - component_set.add(lambda _: None, ["test_column"]) +from typing import Any, Literal - def initializer(simulant_data: SimulantData) -> None: - pass - - with pytest.raises(TypeError): - component_set.add(initializer, ["test_column"]) - - -class NonComponent: - def initializer(self, simulant_data: SimulantData) -> None: - pass +import pandas as pd +import pytest +from pytest_mock import MockerFixture + +from tests.framework.population.conftest import CUBE_COL_NAMES, PIE_COL_NAMES, PIE_RECORDS +from tests.framework.population.helpers import ( + assert_squeezing_multi_level_multi_outer, + assert_squeezing_multi_level_single_outer_multi_inner, + assert_squeezing_multi_level_single_outer_single_inner, + assert_squeezing_single_level_multi_col, + assert_squeezing_single_level_single_col, +) +from tests.helpers import ( + AttributePipelineCreator, + ColumnCreator, + ColumnCreatorAndRequirer, + MultiLevelMultiColumnCreator, + MultiLevelSingleColumnCreator, + SingleColumnCreator, +) +from vivarium import Component, InteractiveContext +from vivarium.framework.engine import Builder +from vivarium.framework.population.exceptions import PopulationError +from vivarium.framework.population.manager import PopulationManager, SimulantData class InitializingComponent(Component): @@ -43,78 +44,500 @@ 
def other_initializer(self, simulant_data: SimulantData) -> None: pass -def test_initializer_set_fail_attr() -> None: - component_set = InitializerComponentSet() - - with pytest.raises(AttributeError): - component_set.add(NonComponent().initializer, ["test_column"]) - - -def test_initializer_set_duplicate_component() -> None: - component_set = InitializerComponentSet() - component = InitializingComponent("test") - - component_set.add(component.initializer, ["test_column1"]) - with pytest.raises(PopulationError, match="multiple population initializers"): - component_set.add(component.other_initializer, ["test_column2"]) - - -def test_initializer_set_duplicate_columns() -> None: - component_set = InitializerComponentSet() - component1 = InitializingComponent("test1") - component2 = InitializingComponent("test2") - columns = ["test_column"] - - component_set.add(component1.initializer, columns) - with pytest.raises(PopulationError, match="both registered initializers"): - component_set.add(component2.initializer, columns) - - with pytest.raises(PopulationError, match="both registered initializers"): - component_set.add(component2.initializer, ["sneaky_column"] + columns) - - -def test_initializer_set_population_manager() -> None: - component_set = InitializerComponentSet() - population_manager = PopulationManager() - - component_set.add(population_manager.on_initialize_simulants, ["tracked"]) - - -def test_initializer_set() -> None: - component_set = InitializerComponentSet() - for i in range(10): - component = InitializingComponent(str(i)) - columns = [f"test_column_{i}_{j}" for j in range(5)] - component_set.add(component.initializer, columns) - - +@pytest.mark.parametrize("private_columns", [[], ["age", "sex"]]) +def test_setting_columns_with_get_view( + private_columns: list[str], mocker: MockerFixture +) -> None: + manager = PopulationManager() + component = mocker.Mock() + component.name = "test_component" + manager._private_column_metadata["test_component"] = 
private_columns + view = manager._get_view(component=component) + assert view.private_columns == private_columns + + +@pytest.mark.parametrize("attributes", ("all", PIE_COL_NAMES, ["pie", "cube"])) +@pytest.mark.parametrize("index", [None, pd.RangeIndex(0, len(PIE_RECORDS) // 2)]) +@pytest.mark.parametrize("query", [None, "pie == 'apple'"]) +def test_get_population( + attributes: Literal["all"] | list[str], + index: pd.Index[int] | None, + query: str, + pies_and_cubes_pop_mgr: PopulationManager, +) -> None: + kwargs: dict[str, Any] = {"attributes": attributes} + if index is not None: + kwargs["index"] = index + if query is not None: + kwargs["query"] = query + assert attributes == "all" or isinstance(attributes, list) + pop = pies_and_cubes_pop_mgr.get_population(**kwargs) + assert ( + set(pop.columns) == set(PIE_COL_NAMES + CUBE_COL_NAMES) + if attributes == "all" + else set(attributes) + ) + if query is not None: + assert (pop["pie"] == "apple").all() + + +def test_get_population_different_attribute_types() -> None: + """Test that get_population works with simple attributes, non-simple attributes, + and attribute pipelines that return dataframes instead of series'.""" + component1 = ColumnCreator() + component2 = AttributePipelineCreator() + sim = InteractiveContext(components=[component1, component2], setup=True) + pop = sim._population.get_population("all") + # We have columnar multi-index due to AttributePipelines that return dataframes + assert isinstance(pop.columns, pd.MultiIndex) + assert set(pop.columns) == { + ("test_column_1", ""), + ("test_column_2", ""), + ("test_column_3", ""), + ("attribute_generating_columns_4_5", "test_column_4"), + ("attribute_generating_columns_4_5", "test_column_5"), + ("attribute_generating_column_8", "test_column_8"), + ("test_attribute", ""), + ("attribute_generating_columns_6_7", "test_column_6"), + ("attribute_generating_columns_6_7", "test_column_7"), + } + value_cols = [col for col in pop.columns if col != 
("simulant_step_size", "")] + expected = pd.Series([idx % 3 for idx in pop.index]) + for col in value_cols: + pd.testing.assert_series_equal(pop[col], expected, check_names=False) + + +class TestGetPopulationSqueezing: + """Tests for squeeze behavior on get_population with specific columns.""" + + @pytest.fixture(scope="class") + def sim(self) -> InteractiveContext: + return InteractiveContext(components=[ColumnCreator(), AttributePipelineCreator()]) + + def assert_squeezing( + self, + sim: InteractiveContext, + columns: list[str] | Literal["all"], + assert_fn: Any, + *assert_args: Any, + ) -> None: + unsqueezed = sim._population.get_population(columns, squeeze=False) + squeezed = sim._population.get_population(columns, squeeze=True) + assert_fn(unsqueezed, squeezed, *assert_args) + + def test_single_level_single_column_returns_series(self, sim: InteractiveContext) -> None: + self.assert_squeezing( + sim, ["test_column_1"], assert_squeezing_single_level_single_col + ) + + def test_single_level_multi_column_returns_dataframe( + self, sim: InteractiveContext + ) -> None: + self.assert_squeezing( + sim, ["test_column_1", "test_column_2"], assert_squeezing_single_level_multi_col + ) + + def test_multi_level_single_outer_single_inner_returns_series( + self, sim: InteractiveContext + ) -> None: + self.assert_squeezing( + sim, + ["attribute_generating_column_8"], + assert_squeezing_multi_level_single_outer_single_inner, + ) + + def test_multi_level_single_outer_multi_inner_returns_inner_dataframe( + self, sim: InteractiveContext + ) -> None: + self.assert_squeezing( + sim, + ["attribute_generating_columns_4_5"], + assert_squeezing_multi_level_single_outer_multi_inner, + ) + + def test_multi_level_multi_outer_returns_full_dataframe( + self, sim: InteractiveContext + ) -> None: + self.assert_squeezing( + sim, + ["test_column_1", "attribute_generating_columns_6_7"], + assert_squeezing_multi_level_multi_outer, + ) + + def 
test_all_columns_single_level_single_column_returns_series(self) -> None: + sim = InteractiveContext(components=[SingleColumnCreator()]) + self.assert_squeezing( + sim, "all", assert_squeezing_single_level_single_col, "test_column_1" + ) + + def test_all_columns_single_level_multi_column_returns_dataframe(self) -> None: + sim = InteractiveContext(components=[ColumnCreator()]) + self.assert_squeezing(sim, "all", assert_squeezing_single_level_multi_col) + + def test_all_columns_multi_level_single_outer_single_inner_returns_series(self) -> None: + sim = InteractiveContext(components=[MultiLevelSingleColumnCreator()]) + self.assert_squeezing( + sim, + "all", + assert_squeezing_multi_level_single_outer_single_inner, + ("some_attribute", "some_column"), + ) + + def test_all_columns_multi_level_single_outer_multi_inner_returns_inner_dataframe( + self, + ) -> None: + sim = InteractiveContext(components=[MultiLevelMultiColumnCreator()]) + sim._population._attribute_pipelines.pop("some_other_attribute") + self.assert_squeezing( + sim, "all", assert_squeezing_multi_level_single_outer_multi_inner + ) + + def test_all_columns_multi_level_multi_outer_returns_full_dataframe(self) -> None: + sim = InteractiveContext(components=[ColumnCreator(), AttributePipelineCreator()]) + self.assert_squeezing(sim, "all", assert_squeezing_multi_level_multi_outer) + + +@pytest.mark.parametrize("include_duplicates", [False, True]) @pytest.mark.parametrize( - "contains_tracked, query, expected_query", + "query", [ - (True, "", ""), - (True, "foo == True", "foo == True"), - (False, "", "tracked == True"), - (False, "foo == True", "foo == True and tracked == True"), + None, # default + "test_column_1 < 2", # query on a requested column + "test_column_2 < 2", # query on a non-requested column ], ) -def test_setting_query_with_get_view( - contains_tracked: bool, query: str, expected_query: str +def test_get_population_column_ordering(include_duplicates: bool, query: str | None) -> None: + def 
_extract_ordered_list(cols: list[str]) -> list[tuple[str, str]]: + col_mapping = { + "test_column_1": ("test_column_1", ""), + "attribute_generating_columns_4_5": [ + ("attribute_generating_columns_4_5", "test_column_4"), + ("attribute_generating_columns_4_5", "test_column_5"), + ], + "test_attribute": ("test_attribute", ""), + } + expected_cols = [] + for col in cols: + col_tuple = col_mapping[col] + if isinstance(col_tuple, list): + for item in col_tuple: + if item not in expected_cols: + expected_cols.append(item) + else: + if col_tuple not in expected_cols: + expected_cols.append(col_tuple) + return expected_cols + + def _check_col_ordering( + sim: InteractiveContext, kwargs: dict[str, str | list[str]] + ) -> None: + pop = sim._population.get_population(**kwargs) # type: ignore[call-overload] + expected_cols = _extract_ordered_list(cols) + assert isinstance(pop.columns, pd.MultiIndex) + returned_cols = pop.columns.tolist() + assert returned_cols == expected_cols + + component1 = ColumnCreator() + component2 = AttributePipelineCreator() + sim = InteractiveContext(components=[component1, component2], setup=True) + + cols = ["test_column_1", "attribute_generating_columns_4_5", "test_attribute"] + if include_duplicates: + cols.extend(cols) # duplicate the list + kwargs: dict[str, str | list[str]] = {} + kwargs["attributes"] = cols + if query is not None: + kwargs["query"] = query + _check_col_ordering(sim, kwargs) + # Now try reversing the order + # NOTE: we specifically do not parametrize this test to ensure that the two + # 'get_population' calls are happening on exactly the same population manager + cols.reverse() + _check_col_ordering(sim, kwargs) + + +@pytest.mark.parametrize( + "attributes", + ( + ["age", "sex"], + PIE_COL_NAMES + ["age", "sex"], + ["age", "sex"], + ["color", "count", "age"], + ), +) +def test_get_population_raises_missing_attributes( + attributes: list[str], pies_and_cubes_pop_mgr: PopulationManager ) -> None: - manager = PopulationManager() 
- columns = ["age", "sex"] - if contains_tracked: - columns.append("tracked") - view = manager._get_view(columns=columns, query=query) - assert view.query == expected_query + with pytest.raises(PopulationError, match="not in population state table"): + pies_and_cubes_pop_mgr.get_population(attributes) + + +def test_get_population_raises_bad_string(pies_and_cubes_pop_mgr: PopulationManager) -> None: + with pytest.raises(TypeError, match="Attributes must be a list of strings or 'all'"): + pies_and_cubes_pop_mgr.get_population("invalid_string") # type: ignore[call-overload] + + +def test__get_attributes_three_or_more_levels_not_implemented() -> None: + class BadAttributeCreator(Component): + def setup(self, builder: Builder) -> None: + builder.value.register_attribute_producer( + "animals", + lambda idx: pd.DataFrame( + { + ("cat", "size"): "teeny-tiny", + ("cat", "color"): "tuxedo", + ("dog", "size"): "huge", + ("dog", "color"): "spotted", + }, + index=idx, + ), + ) + + sim = InteractiveContext(components=[BadAttributeCreator()], setup=True) + with pytest.raises( + NotImplementedError, + match="Multi-level columns in attribute pipeline outputs are not supported.", + ): + sim._population.get_population(["animals"]) + + +def test_get_population_deduplicates_requested_columns( + pies_and_cubes_pop_mgr: PopulationManager, +) -> None: + pop = pies_and_cubes_pop_mgr.get_population(["pie", "pie", "pie"], squeeze=False) + assert set(pop.columns) == {"pie"} + + +def test_register_initializer(mocker: MockerFixture) -> None: + class ColumnCreator2(ColumnCreator): + @property + def name(self) -> str: + return "column_creator_2" + + class ColumnCreator3(ColumnCreator): + @property + def name(self) -> str: + return "column_creator_3" + + # The metadata for the manager should be empty because the fixture does not + # actually go through setup. 
+ mgr = PopulationManager() + mock_register_attr = mocker.Mock() + mocker.patch.object(mgr, "_register_attribute_producer", mock_register_attr, create=True) + mock_resources = mocker.Mock() + mocker.patch.object(mgr, "resources", mock_resources, create=True) + mock_add_private_cols = mocker.Mock() + mocker.patch.object( + mgr.resources, "add_private_columns", mock_add_private_cols, create=True + ) + + assert mgr._private_column_metadata == {} + + component1 = ColumnCreator() + mocker.patch.object( + mgr, "_get_current_component_or_manager", return_value=component1, create=True + ) + mgr.register_initializer( + initializer=component1.initialize_test_columns, + columns=["foo", "bar"], + required_resources=["dep1", "dep2"], + ) + + component2 = ColumnCreator2() + mocker.patch.object( + mgr, "_get_current_component_or_manager", return_value=component2, create=True + ) + mgr.register_initializer( + initializer=component2.initialize_test_columns, + columns=None, + required_resources=["dep3", "dep4"], + ) + + component3 = ColumnCreator3() + mocker.patch.object( + mgr, "_get_current_component_or_manager", return_value=component3, create=True + ) + mgr.register_initializer( + initializer=component3.initialize_test_columns, columns="qux", required_resources=[] + ) + + # Check that register_attribute_producer was called appropriately + assert mock_register_attr.call_count == 3 + for column in ["foo", "bar", "qux"]: + mock_register_attr.assert_any_call( + column, source=[column], source_is_private_column=True + ) + + # Check the private column metadata + assert mgr._private_column_metadata == { + component1.name: ["foo", "bar"], + component2.name: [], + component3.name: ["qux"], + } + + # Check that resources.add_private_columns was called appropriately + assert mock_add_private_cols.call_count == 3 + mock_add_private_cols.assert_any_call( + columns=["foo", "bar"], + required_resources=["dep1", "dep2"], + initializer=component1.initialize_test_columns, + ) + 
mock_add_private_cols.assert_any_call( + columns=[], + required_resources=["dep3", "dep4"], + initializer=component2.initialize_test_columns, + ) + mock_add_private_cols.assert_any_call( + columns=["qux"], required_resources=[], initializer=component3.initialize_test_columns + ) + + +def test_register_initializer_duplicate_raises(mocker: MockerFixture) -> None: + component = ColumnCreator() + mgr = PopulationManager() + mocker.patch.object( + mgr, "_get_current_component_or_manager", return_value=component, create=True + ) + mocker.patch.object(mgr, "_register_attribute_producer", mocker.Mock(), create=True) + mock_resources = mocker.Mock() + mocker.patch.object(mgr, "resources", mock_resources, create=True) + + # First registration should succeed + mgr.register_initializer( + initializer=component.initialize_test_columns, + columns=["col_a"], + ) + + # Registering the same initializer again should raise + with pytest.raises(PopulationError, match="has already been registered"): + mgr.register_initializer( + initializer=component.initialize_test_columns, + columns=["col_b"], + ) @pytest.mark.parametrize( - "columns, expected_columns", [("age", ["age"]), (["age"], None), (["age", "sex"], None)] + "components, index, columns", + [ + ([ColumnCreator(), ColumnCreatorAndRequirer()], None, None), + ([ColumnCreator()], pd.Index([4, 8, 15, 16, 23, 42]), None), + ([ColumnCreator()], None, ["test_column_2"]), + ( + [ColumnCreator()], + pd.Index([4, 8, 15, 16, 23, 42]), + ["test_column_1", "test_column_3"], + ), + ], ) -def test_setting_columns_with_get_view( - columns: str | list[str], expected_columns: list[str] | None +def test_get_private_columns( + components: list[Component], index: pd.Index[int] | None, columns: list[str] | None ) -> None: - view_columns = expected_columns or columns - manager = PopulationManager() - view = manager._get_view(columns=columns, query="") - assert view._columns == view_columns + sim = InteractiveContext(components=components) + kwargs: 
dict[str, pd.Index[int] | list[str]] = {} + if index is not None: + kwargs["index"] = index + if columns is not None: + kwargs["columns"] = columns + for component in components: + private_columns = pd.DataFrame(sim._population.get_private_columns(component, **kwargs)) # type: ignore[arg-type] + if index is not None: + assert private_columns.index.equals(index) + else: + assert private_columns.index.equals(sim._population.get_population_index()) + if columns is not None: + assert list(private_columns.columns) == columns + else: + assert list(private_columns.columns) == component.private_columns + + +def test_get_private_columns_squeezing() -> None: + + # Single-level, single-column -> series + single_col_creator = SingleColumnCreator() + sim = InteractiveContext(components=[single_col_creator], setup=True) + unsqueezed = sim._population.get_private_columns( + single_col_creator, columns=["test_column_1"] + ) + squeezed = sim._population.get_private_columns( + single_col_creator, columns="test_column_1" + ) + assert_squeezing_single_level_single_col(unsqueezed, squeezed) + default = sim._population.get_private_columns(single_col_creator) + assert isinstance(default, pd.Series) and isinstance(squeezed, pd.Series) + assert default.equals(squeezed) + + # Single-level, multiple-column -> dataframe + col_creator = ColumnCreator() + sim = InteractiveContext(components=[col_creator], setup=True) + # There's no way to squeeze here. 
+ df = sim._population.get_private_columns( + col_creator, columns=["test_column_1", "test_column_2", "test_column_3"] + ) + assert isinstance(df, pd.DataFrame) + assert not isinstance(df.columns, pd.MultiIndex) + default = sim._population.get_private_columns(col_creator) + assert isinstance(default, pd.DataFrame) + assert default.equals(df) + + +def test_get_private_columns_raises_on_initial_pop_creation() -> None: + mgr = PopulationManager() + mgr.creating_initial_population = True + with pytest.raises( + PopulationError, + match="Cannot get private columns during initial population creation", + ): + mgr.get_private_columns(ColumnCreator(), columns=["test_column_1"]) + + +def test_get_private_columns_raises_bad_column_request() -> None: + mgr = PopulationManager() + with pytest.raises( + PopulationError, + match="is requesting the following private columns to which it does not have access", + ): + mgr.get_private_columns(ColumnCreator(), columns=["foo"]) + + +def test_get_population_index() -> None: + component = AttributePipelineCreator() + sim = InteractiveContext(components=[component], setup=False) + with pytest.raises(PopulationError, match="Population has not been initialized."): + sim._population.get_population_index() + sim.setup() + private_cols = pd.DataFrame(sim._population._private_columns) + private_cols.index.equals(sim._population.get_population_index()) + + +def test_forget_to_create_columns() -> None: + class ColumnForgetter(ColumnCreator): + def initialize_test_columns(self, pop_data: SimulantData) -> None: + pass + + with pytest.raises(PopulationError, match="not actually created"): + InteractiveContext(components=[ColumnForgetter()]) + + +def test_create_already_existing_columns_fails() -> None: + class SameColumnCreator(ColumnCreator): + ... 
+ + with pytest.raises( + PopulationError, + match="Component 'same_column_creator' is attempting to register private column 'test_column_1' but it is already registered by component 'column_creator'.", + ): + InteractiveContext(components=[ColumnCreator(), SameColumnCreator()]) + + +def test_register_tracked_query(mocker: MockerFixture) -> None: + mgr = PopulationManager() + assert mgr.tracked_queries == [] + mgr.register_tracked_query("foo == 'bar'") + assert mgr.tracked_queries == ["foo == 'bar'"] + mgr.register_tracked_query("cat != dog") + assert mgr.tracked_queries == ["foo == 'bar'", "cat != dog"] + # Check duplicates are ignored + mocker.patch.object(mgr, "logger", mocker.Mock(), create=True) + mgr.register_tracked_query("foo == 'bar'") + mgr.logger.warning.assert_called_once() # type: ignore[attr-defined] + assert mgr.tracked_queries == ["foo == 'bar'", "cat != dog"] diff --git a/tests/framework/population/test_population_view.py b/tests/framework/population/test_population_view.py index f191f0607..443ee6b46 100644 --- a/tests/framework/population/test_population_view.py +++ b/tests/framework/population/test_population_view.py @@ -1,64 +1,42 @@ from __future__ import annotations -import itertools -import math import random +import re from typing import Any import pandas as pd import pytest - +from pytest_mock import MockerFixture + +from tests.framework.population.conftest import ( + CUBE_COL_NAMES, + CUBE_DF, + PIE_COL_NAMES, + PIE_DF, + PIE_RECORDS, + CubeComponent, + PieComponent, +) +from tests.framework.population.helpers import ( + assert_squeezing_multi_level_single_outer_single_inner, + assert_squeezing_single_level_single_col, +) +from tests.helpers import AttributePipelineCreator, ColumnCreator, SingleColumnCreator +from vivarium import InteractiveContext +from vivarium.framework.engine import Builder +from vivarium.framework.lifecycle import lifecycle_states from vivarium.framework.population import PopulationError, PopulationManager, 
PopulationView ########################## # Mock data and fixtures # ########################## -COL_NAMES = ["color", "count", "pie", "pi", "tracked"] -COLORS = ["red", "green", "yellow"] -COUNTS = [10, 20, 30] -PIES = ["apple", "chocolate", "pecan"] -PIS = [math.pi**i for i in range(1, 4)] -TRACKED_STATUSES = [True, False] -RECORDS = [ - (color, count, pie, pi, ts) - for color, count, pie, pi, ts in itertools.product( - COLORS, COUNTS, PIES, PIS, TRACKED_STATUSES - ) -] -BASE_POPULATION = pd.DataFrame(data=RECORDS, columns=COL_NAMES) - -NEW_COL_NAMES = ["cube", "cube_string"] -CUBE = [i**3 for i in range(len(RECORDS))] -CUBE_STRING = [str(i**3) for i in range(len(RECORDS))] -NEW_ATTRIBUTES = pd.DataFrame( - zip(CUBE, CUBE_STRING), - columns=NEW_COL_NAMES, - index=BASE_POPULATION.index, -) - - -@pytest.fixture(scope="function") -def population_manager() -> PopulationManager: - class _PopulationManager(PopulationManager): - def __init__(self) -> None: - super().__init__() - self._population = pd.DataFrame( - data=RECORDS, - columns=COL_NAMES, - ) - - def _add_constraint(self, *args: Any, **kwargs: Any) -> None: - pass - - return _PopulationManager() - @pytest.fixture( params=[ - BASE_POPULATION.index, - BASE_POPULATION.index[::2], - BASE_POPULATION.index[:0], + PIE_DF.index, + PIE_DF.index[::2], + PIE_DF.index[:0], ] ) def update_index(request: pytest.FixtureRequest) -> pd.Index[int]: @@ -69,10 +47,10 @@ def update_index(request: pytest.FixtureRequest) -> pd.Index[int]: @pytest.fixture( params=[ - BASE_POPULATION.copy(), - BASE_POPULATION[COL_NAMES[:2]].copy(), - BASE_POPULATION[[COL_NAMES[0]]].copy(), - BASE_POPULATION[COL_NAMES[0]].copy(), + PIE_DF.copy(), + PIE_DF[PIE_COL_NAMES[1:2]].copy(), + PIE_DF[[PIE_COL_NAMES[0]]].copy(), + PIE_DF[PIE_COL_NAMES[0]].copy(), ] ) def population_update( @@ -85,11 +63,9 @@ def population_update( @pytest.fixture( params=[ - NEW_ATTRIBUTES.copy(), - NEW_ATTRIBUTES[[NEW_COL_NAMES[0]]].copy(), - 
NEW_ATTRIBUTES[NEW_COL_NAMES[0]].copy(), - pd.concat([BASE_POPULATION, NEW_ATTRIBUTES], axis=1), - pd.concat([BASE_POPULATION.iloc[:, 0], NEW_ATTRIBUTES.iloc[:, 0]], axis=1), + CUBE_DF.copy(), + CUBE_DF[[CUBE_COL_NAMES[0]]].copy(), + CUBE_DF[CUBE_COL_NAMES[0]].copy(), ] ) def population_update_new_cols( @@ -100,154 +76,437 @@ def population_update_new_cols( return update +@pytest.fixture( + params=[ + None, + "pie != 'pecan' and pi > 1000 and cube > 100000", + ] +) +def query(request: pytest.FixtureRequest) -> str | None: + assert isinstance(request.param, (str, type(None))) + return request.param + + +@pytest.fixture(params=["pie != 'apple'"]) +def tracked_query(request: pytest.FixtureRequest) -> str: + assert isinstance(request.param, str) + return request.param + + ################## # Initialization # ################## -def test_initialization(population_manager: PopulationManager) -> None: - pv = population_manager.get_view(COL_NAMES) +def test_initialization(pies_and_cubes_pop_mgr: PopulationManager) -> None: + component = PieComponent() + expected_private_columns = set(["pie", "pi"]) + pv = pies_and_cubes_pop_mgr.get_view(component) assert pv._id == 0 - assert pv.name == "population_view_0" - assert set(pv.columns) == set(COL_NAMES) - assert pv.query == "" - - # Failure here is lazy. The manager should give you back views for - # columns that don't exist since views are built during setup when - # we don't necessarily know all the columns yet. 
- cols = ["age", "sex", "tracked"] - pv = population_manager.get_view(cols) + assert pv.name == f"population_view_{pv._id}" + assert set(pv.private_columns) == expected_private_columns + + pv = pies_and_cubes_pop_mgr.get_view(component) assert pv._id == 1 - assert pv.name == "population_view_1" - assert set(pv.columns) == set(cols) - assert pv.query == "" - - col_subset = ["color", "count"] - pv = population_manager.get_view(col_subset) - assert pv._id == 2 - assert pv.name == "population_view_2" - assert set(pv.columns) == set(col_subset) - # View will filter to tracked by default if it's not requested as a column - assert pv.query == "tracked == True" - - q_string = "color == 'red'" - pv = population_manager.get_view(COL_NAMES, query=q_string) - assert pv._id == 3 - assert pv.name == "population_view_3" - assert set(pv.columns) == set(COL_NAMES) - assert pv.query == q_string + assert pv.name == f"population_view_{pv._id}" + assert set(pv.private_columns) == expected_private_columns -########################## -# PopulationView.subview # -########################## +################################# +# PopulationView.get_attributes # +################################# -@pytest.mark.parametrize( - "columns", - [["color", "count"], ["color"], ["color", "count", "tracked"]], - ids=["multiple columns", "single column", "including tracked"], -) -def test_subview_columns_list_input( - population_manager: PopulationManager, columns: list[str] -) -> None: - pv = population_manager.get_view(COL_NAMES) - sub_pv = pv.subview(columns) - assert set(sub_pv.columns) == set(columns) +def test_get_attributes(pies_and_cubes_pop_mgr: PopulationManager) -> None: + ######################## + # Full population view # + ######################## + component = PieComponent() + pv = pies_and_cubes_pop_mgr.get_view(component) + full_idx = pd.RangeIndex(0, len(PIE_RECORDS)) + + # Get full data set + pop_full = pv.get_attributes(full_idx, PIE_COL_NAMES) + assert set(pop_full.columns) == 
set(PIE_COL_NAMES) + assert pop_full.index.equals(full_idx) + + # Get data subset + pop = pv.get_attributes(full_idx, PIE_COL_NAMES, query=f"pie == 'apple'") + assert set(pop.columns) == set(PIE_COL_NAMES) + assert pop.index.equals(pop_full[pop_full["pie"] == "apple"].index) -def test_subview_columns_string_input(population_manager: PopulationManager) -> None: - pv = population_manager.get_view(COL_NAMES) - sub_pv = pv.subview("color") - assert set(sub_pv.columns) == {"color"} +def test_get_attributes_empty_idx(pies_and_cubes_pop_mgr: PopulationManager) -> None: + pv = pies_and_cubes_pop_mgr.get_view(PieComponent()) + + pop = pv.get_attributes(pd.Index([]), PIE_COL_NAMES) + assert isinstance(pop, pd.DataFrame) + assert set(pop.columns) == set(PIE_COL_NAMES) + assert pop.empty -@pytest.mark.parametrize("columns", [["color", "count"], ["color", "count", "tracked"]]) -def test_subview_queries(population_manager: PopulationManager, columns: list[str]) -> None: - pv = population_manager.get_view(COL_NAMES, query="foo == 'red'") +def test_get_attributes_raises(pies_and_cubes_pop_mgr: PopulationManager) -> None: + pv = pies_and_cubes_pop_mgr.get_view(PieComponent()) + index = pd.Index([]) - # get subview - sub_pv = pv.subview(columns) + with pytest.raises( + PopulationError, + match="Requested attribute\(s\) \{'foo'\} not in population state table.", + ): + pv.get_attributes(index, "foo") + + +@pytest.mark.parametrize("attribute", ["pie", ["pie"]]) +def test_get_attributes_skip_post_processor( + attribute: str | list[str], pies_and_cubes_pop_mgr: PopulationManager +) -> None: + pv = pies_and_cubes_pop_mgr.get_view(PieComponent()) + full_idx = pd.RangeIndex(0, len(PIE_RECORDS)) - expected_query = pv.query if "tracked" in columns else f"{pv.query} and tracked == True" - assert sub_pv.query == expected_query + key = attribute if isinstance(attribute, str) else attribute[0] + mocked_pie_pipeline = pies_and_cubes_pop_mgr._attribute_pipelines[key] + pv.get_attributes(full_idx, 
attribute, skip_post_processor=True) + mocked_pie_pipeline.assert_called_once_with(full_idx, skip_post_processor=True) # type: ignore[attr-defined] + + +def test_get_attributes_skip_post_processor_raises( + pies_and_cubes_pop_mgr: PopulationManager, +) -> None: + pv = pies_and_cubes_pop_mgr.get_view(PieComponent()) + full_idx = pd.RangeIndex(0, len(PIE_RECORDS)) + + with pytest.raises( + ValueError, + match="When skip_post_processor is True, a single attribute must be requested.", + ): + pv.get_attributes(full_idx, ["pie", "pi"], skip_post_processor=True) @pytest.mark.parametrize( - "columns", - [["age", "sex"], COL_NAMES + ["age"], []], - ids=["columns not in pop_view", "one column not in pop_view", "no columns"], + "attribute, query", [("pie", "pie == 'apple'"), ("pie", "cube > 1000")] ) -def test_subview_bad_columns_input( - population_manager: PopulationManager, columns: list[str] +def test_get_attributes_skip_post_processor_with_query( + attribute: str, + query: str, + pies_and_cubes_pop_mgr: PopulationManager, +) -> None: + """Test that the index is reduced when a query is passed with skip_post_processor=True.""" + pv = pies_and_cubes_pop_mgr.get_view(PieComponent()) + full_idx = pd.RangeIndex(0, len(PIE_RECORDS)) + + # Set up the mocked pipelines to return actual data from the private columns + # so that the query can be executed + def mock_pie_pipeline(idx: pd.Index[int], skip_post_processor: bool) -> pd.Series[Any]: + private_col_df = pies_and_cubes_pop_mgr._private_columns + assert isinstance(private_col_df, pd.DataFrame) + return private_col_df.loc[idx, "pie"] + + def mock_cube_pipeline(idx: pd.Index[int], skip_post_processor: bool) -> pd.Series[Any]: + private_col_df = pies_and_cubes_pop_mgr._private_columns + assert isinstance(private_col_df, pd.DataFrame) + return private_col_df.loc[idx, "cube"] + + pies_and_cubes_pop_mgr._attribute_pipelines["pie"].side_effect = mock_pie_pipeline # type: ignore[attr-defined] + 
pies_and_cubes_pop_mgr._attribute_pipelines["cube"].side_effect = mock_cube_pipeline # type: ignore[attr-defined] + + # Execute get_attributes with a query and skip_post_processor=True + # Query should filter the data + result = pv.get_attributes(full_idx, attribute, query=query, skip_post_processor=True) + + # The expected index should be the filtered index based on the query + expected_index = pd.concat([PIE_DF, CUBE_DF], axis=1).query(query).index + assert len(expected_index) < len(full_idx) + + # Assert that the returned data has the reduced index + assert result.index.equals(expected_index) + + # Verify that the pipeline was called with the reduced index, not the full index + pies_and_cubes_pop_mgr._attribute_pipelines[attribute].assert_called_once() # type: ignore[attr-defined] + call_args = pies_and_cubes_pop_mgr._attribute_pipelines[attribute].call_args # type: ignore[attr-defined] + assert call_args[0][0].equals(expected_index) + assert call_args[1] == {"skip_post_processor": True} + + +@pytest.mark.parametrize("register_tracked_query", [True, False]) +@pytest.mark.parametrize("include_untracked", [True, False]) +def test_get_attributes_combined_query( + register_tracked_query: bool, + include_untracked: bool, + update_index: pd.Index[int], + query: str | None, + tracked_query: str, + pies_and_cubes_pop_mgr: PopulationManager, ) -> None: - pv = population_manager.get_view(COL_NAMES) - with pytest.raises(PopulationError): - pv.subview(columns) + """Test that queries provided to the pop view and via get_attributes are combined correctly.""" + + if register_tracked_query: + pies_and_cubes_pop_mgr.register_tracked_query(tracked_query) + kwargs = _resolve_kwargs(include_untracked, query) + combined_query = _combine_queries( + include_untracked, + pies_and_cubes_pop_mgr.tracked_queries, + query, + ) + col_request = PIE_COL_NAMES.copy() + if combined_query and "cube" in combined_query: + col_request += ["cube"] -###################### -# PopulationView.get # 
-###################### + pv = pies_and_cubes_pop_mgr.get_view(PieComponent()) + pop = pv.get_attributes(update_index, col_request, **kwargs) + assert isinstance(pop, pd.DataFrame) + expected_pop = _get_expected(update_index, combined_query) + if expected_pop.empty and not update_index.empty: + raise RuntimeError("Bad test setup: expected population is empty.") + if not update_index.empty and combined_query and expected_pop.index.equals(update_index): + raise RuntimeError("Bad test setup: no filtering occurred.") + assert pop.equals(expected_pop) -def test_get(population_manager: PopulationManager) -> None: - ######################## - # Full population view # - ######################## - pv = population_manager.get_view(COL_NAMES) - full_idx = pd.RangeIndex(0, len(RECORDS)) - # Get full data set - pop = pv.get(full_idx) - assert set(pop.columns) == set(COL_NAMES) - assert len(pop) == len(RECORDS) +def test_get_attributes_empty_list(pies_and_cubes_pop_mgr: PopulationManager) -> None: + pv = pies_and_cubes_pop_mgr.get_view(PieComponent()) + full_index = pd.RangeIndex(0, len(PIE_RECORDS)) + no_attributes = pv.get_attributes(full_index, []) + assert no_attributes.empty + assert no_attributes.index.equals(full_index) - # Get data subset - pop = pv.get(full_idx, query=f"color == 'red'") - assert set(pop.columns) == set(COL_NAMES) - assert len(pop) == len(RECORDS) // len(COLORS) - ############################### - # View without tracked column # - ############################### - cols_without_tracked = COL_NAMES[:-1] - pv = population_manager.get_view(cols_without_tracked) +def test_get_attributes_query_removes_all(pies_and_cubes_pop_mgr: PopulationManager) -> None: + pv = pies_and_cubes_pop_mgr.get_view(PieComponent()) + full_index = pd.RangeIndex(0, len(PIE_RECORDS)) + empty_pop = pv.get_attributes(full_index, PIE_COL_NAMES, "pi == 'oops'") + assert isinstance(empty_pop, pd.DataFrame) + assert empty_pop.equals(PIE_DF.iloc[0:0][PIE_COL_NAMES]) - # Get all tracked - pop 
= pv.get(full_idx) - assert set(pop.columns) == set(cols_without_tracked) - assert len(pop) == len(RECORDS) // 2 - # get subset without tracked - pop = pv.get(full_idx, query=f"color == 'red'") - assert set(pop.columns) == set(cols_without_tracked) - assert len(pop) == len(RECORDS) // (2 * len(COLORS)) +class TestGetAttributesReturnTypes: + class SomeComponent(ColumnCreator, AttributePipelineCreator): + """Class that creates multi-level column attributes and private columns.""" + def setup(self, builder: Builder) -> None: + ColumnCreator.setup(self, builder) + AttributePipelineCreator.setup(self, builder) -def test_get_empty_idx(population_manager: PopulationManager) -> None: - pv = population_manager.get_view(COL_NAMES) + @pytest.fixture(scope="class") + def simulation(self) -> InteractiveContext: + return InteractiveContext( + components=[TestGetAttributesReturnTypes.SomeComponent()], setup=True + ) - pop = pv.get(pd.Index([])) - assert isinstance(pop, pd.DataFrame) - assert set(pop.columns) == set(COL_NAMES) - assert pop.empty + @pytest.fixture(scope="class") + def population_view(self, simulation: InteractiveContext) -> PopulationView: + return simulation._population.get_view() + + @pytest.fixture(scope="class") + def index(self, simulation: InteractiveContext) -> pd.Index[int]: + return simulation._population.get_population_index() + + def test_single_level_single_column( + self, population_view: PopulationView, index: pd.Index[int] + ) -> None: + # Single-level, single-column -> series + unsqueezed = population_view.get_attributes(index, ["test_column_1"]) + squeezed = population_view.get_attributes(index, "test_column_1") + assert_squeezing_single_level_single_col(unsqueezed, squeezed) + + def test_single_level_multiple_columns( + self, population_view: PopulationView, index: pd.Index[int] + ) -> None: + # Single-level, multiple-column -> dataframe + # There's no way to request a squeezed dataframe here. 
+ df = population_view.get_attributes(index, ["test_column_1", "test_column_2"]) + assert isinstance(df, pd.DataFrame) + assert not isinstance(df.columns, pd.MultiIndex) + + def test_multi_level_single_outer_single_inner( + self, population_view: PopulationView, index: pd.Index[int] + ) -> None: + # Multi-level, single outer, single inner -> series + unsqueezed = population_view.get_attributes(index, ["attribute_generating_column_8"]) + squeezed = population_view.get_attributes(index, "attribute_generating_column_8") + assert_squeezing_multi_level_single_outer_single_inner(unsqueezed, squeezed) + + def test_single_dataframe_attribute_raises( + self, population_view: PopulationView, index: pd.Index[int] + ) -> None: + with pytest.raises(ValueError, match="Expected a pandas Series to be returned"): + population_view.get_attributes(index, "attribute_generating_columns_4_5") + + def test_multi_level_multiple_outer( + self, population_view: PopulationView, index: pd.Index[int] + ) -> None: + # Multi-level, multiple outer -> full unsqueezed multi-level dataframe + # There's no way to request a squeezed dataframe here. 
+ df = population_view.get_attributes( + index, ["test_column_1", "attribute_generating_columns_6_7"] + ) + assert isinstance(df, pd.DataFrame) + assert isinstance(df.columns, pd.MultiIndex) + @pytest.mark.parametrize( + "attribute", ["test_column_1", "attribute_generating_columns_6_7"] + ) + def test_get_attribute_frame( + self, population_view: PopulationView, index: pd.Index[int], attribute: str + ) -> None: + df = population_view.get_attribute_frame(index, attribute) + assert isinstance(df, pd.DataFrame) + assert not isinstance(df.columns, pd.MultiIndex) + + expected = population_view.get_attributes(index, [attribute]) + assert (df.values == expected.values).all().all() -def test_get_fail(population_manager: PopulationManager) -> None: - bad_pvs = [ - population_manager.get_view(["age", "sex"]), - population_manager.get_view(COL_NAMES + ["age", "sex"]), - population_manager.get_view(["age", "sex", "tracked"]), - population_manager.get_view(["age", "sex"]), - population_manager.get_view(["color", "count", "age"]), - ] - full_idx = pd.RangeIndex(0, len(RECORDS)) +###################################### +# PopulationView.get_private_columns # +###################################### - for pv in bad_pvs: - with pytest.raises(PopulationError, match="not in population table"): - pv.get(full_idx) + +@pytest.mark.parametrize("private_columns", [None, PIE_COL_NAMES[1:]]) +@pytest.mark.parametrize("register_tracked_query", [True, False]) +@pytest.mark.parametrize("include_untracked", [True, False]) +def test_get_private_columns( + private_columns: list[str] | None, + register_tracked_query: bool, + include_untracked: bool, + update_index: pd.Index[int], + query: str | None, + tracked_query: str, + pies_and_cubes_pop_mgr: PopulationManager, +) -> None: + if register_tracked_query: + pies_and_cubes_pop_mgr.register_tracked_query(tracked_query) + kwargs = _resolve_kwargs(include_untracked, query) + if private_columns is not None: + kwargs["private_columns"] = private_columns 
+ + pv = pies_and_cubes_pop_mgr.get_view(PieComponent()) + pop = pv.get_private_columns(update_index, **kwargs) + assert isinstance(pop, pd.DataFrame) + combined_query = _combine_queries( + include_untracked, + pies_and_cubes_pop_mgr.tracked_queries, + query, + ) + expected_pop = _get_expected(update_index, combined_query) + # We need to remove public columns that were used for filtering + if "cube" in expected_pop.columns: + expected_pop.drop("cube", axis=1, inplace=True) + if private_columns: + expected_pop = expected_pop[private_columns] + if expected_pop.empty and not update_index.empty: + raise RuntimeError("Bad test setup: expected population is empty.") + if not update_index.empty and combined_query and expected_pop.index.equals(update_index): + raise RuntimeError("Bad test setup: no filtering occurred.") + assert pop.equals(expected_pop) + + +def test_get_private_columns_raises(pies_and_cubes_pop_mgr: PopulationManager) -> None: + pv = pies_and_cubes_pop_mgr.get_view(PieComponent()) + index = pd.Index([]) + + with pytest.raises( + PopulationError, + match=re.escape( + "is requesting the following private columns to which it does not have access" + ), + ): + pv.get_private_columns(index, private_columns=["pie", "pi", "foo"]) + + pv._component = None + with pytest.raises( + PopulationError, + match="This PopulationView is read-only, so it doesn't have access to get_private_columns().", + ): + pv.get_private_columns(index) + + +def test_get_private_columns_empty_list(pies_and_cubes_pop_mgr: PopulationManager) -> None: + pv = pies_and_cubes_pop_mgr.get_view(PieComponent()) + full_index = pd.RangeIndex(0, len(PIE_RECORDS)) + no_attributes = pv.get_private_columns(full_index, []) + assert isinstance(no_attributes, pd.DataFrame) + assert no_attributes.empty + assert no_attributes.index.equals(full_index) + assert no_attributes.equals(pd.DataFrame(index=full_index)) + + apples = pv.get_private_columns(full_index, [], query="pie == 'apple'") + assert 
isinstance(apples, pd.DataFrame) + apple_index = PIE_DF[PIE_DF["pie"] == "apple"].index + assert apples.equals(pd.DataFrame(index=apple_index)) + + +def test_get_private_columns_query_removes_all( + pies_and_cubes_pop_mgr: PopulationManager, +) -> None: + pv = pies_and_cubes_pop_mgr.get_view(PieComponent()) + full_index = pd.RangeIndex(0, len(PIE_RECORDS)) + empty_pop = pv.get_private_columns(full_index, query="pi == 'oops'") + assert isinstance(empty_pop, pd.DataFrame) + assert empty_pop.equals(PIE_DF.iloc[0:0][PIE_COL_NAMES]) + + +def test_get_private_columns_squeezing() -> None: + + # Single-level, single-column -> series + single_col_creator = SingleColumnCreator() + sim = InteractiveContext(components=[single_col_creator], setup=True) + pv = sim._population.get_view(single_col_creator) + index = sim._population.get_population_index() + unsqueezed = pv.get_private_columns(index, ["test_column_1"]) + squeezed = pv.get_private_columns(index, "test_column_1") + assert_squeezing_single_level_single_col(unsqueezed, squeezed) + default = pv.get_private_columns(index) + assert isinstance(default, pd.Series) and isinstance(squeezed, pd.Series) + assert default.equals(squeezed) + + # Single-level, multiple-column -> dataframe + col_creator = ColumnCreator() + sim = InteractiveContext(components=[col_creator], setup=True) + pv = sim._population.get_view(col_creator) + index = sim._population.get_population_index() + # There's no way to squeeze here. 
+ df = pv.get_private_columns(index, ["test_column_1", "test_column_2", "test_column_3"]) + assert isinstance(df, pd.DataFrame) + assert not isinstance(df.columns, pd.MultiIndex) + default = pv.get_private_columns(index) + assert isinstance(default, pd.DataFrame) + assert default.equals(df) + + +##################################### +# PopulationView.get_filtered_index # +##################################### + + +@pytest.mark.parametrize("register_tracked_query", [True, False]) +@pytest.mark.parametrize("include_untracked", [True, False]) +def test_get_filtered_index( + register_tracked_query: bool, + include_untracked: bool, + update_index: pd.Index[int], + query: str | None, + tracked_query: str, + pies_and_cubes_pop_mgr: PopulationManager, +) -> None: + if register_tracked_query: + pies_and_cubes_pop_mgr.register_tracked_query(tracked_query) + kwargs = _resolve_kwargs(include_untracked, query) + pv = pies_and_cubes_pop_mgr.get_view(PieComponent()) + pop_idx = pv.get_filtered_index(update_index, **kwargs) + + combined_query = _combine_queries( + include_untracked, + pies_and_cubes_pop_mgr.tracked_queries, + query, + ) + expected_pop = _get_expected(update_index, combined_query) + if expected_pop.empty and not update_index.empty: + raise RuntimeError("Bad test setup: expected population is empty.") + if not update_index.empty and combined_query and expected_pop.index.equals(update_index): + raise RuntimeError("Bad test setup: no filtering occurred.") + assert pop_idx.equals(expected_pop.index) ################################# @@ -265,37 +524,35 @@ def test_full_population_view__coerce_to_dataframe( cols = [population_update.name] else: cols = list(population_update.columns) - coerced_df = PopulationView._coerce_to_dataframe(population_update, COL_NAMES) - assert BASE_POPULATION.loc[update_index, cols].equals(coerced_df) + coerced_df = PopulationView._coerce_to_dataframe(population_update, PIE_COL_NAMES) + assert PIE_DF.loc[update_index, cols].equals(coerced_df) 
def test_full_population_view__coerce_to_dataframe_fail( population_update_new_cols: pd.Series[Any] | pd.DataFrame, ) -> None: with pytest.raises(TypeError, match="must be a pandas Series or DataFrame"): - PopulationView._coerce_to_dataframe(BASE_POPULATION.iloc[:, 0].tolist(), COL_NAMES) # type: ignore[arg-type] + PopulationView._coerce_to_dataframe(PIE_DF.iloc[:, 0].tolist(), PIE_COL_NAMES) # type: ignore[arg-type] with pytest.raises(PopulationError, match="unnamed pandas series"): PopulationView._coerce_to_dataframe( - BASE_POPULATION.iloc[:, 0].rename(None), - COL_NAMES, + PIE_DF.iloc[:, 0].rename(None), + PIE_COL_NAMES, ) with pytest.raises(PopulationError, match="extra columns"): - PopulationView._coerce_to_dataframe(population_update_new_cols, COL_NAMES) + PopulationView._coerce_to_dataframe(population_update_new_cols, PIE_COL_NAMES) with pytest.raises(PopulationError, match="no columns"): - PopulationView._coerce_to_dataframe( - pd.DataFrame(index=BASE_POPULATION.index), COL_NAMES - ) + PopulationView._coerce_to_dataframe(pd.DataFrame(index=PIE_DF.index), PIE_COL_NAMES) def test_single_column_population_view__coerce_to_dataframe( update_index: pd.Index[int], ) -> None: - column = COL_NAMES[0] - update = BASE_POPULATION.loc[update_index].copy() - output = BASE_POPULATION.loc[update_index, [column]] + column = PIE_COL_NAMES[0] + update = PIE_DF.loc[update_index].copy() + output = PIE_DF.loc[update_index, [column]] passing_cases = [ update[[column]], # Single col df @@ -313,80 +570,18 @@ def test_single_column_population_view__coerce_to_dataframe_fail( ) -> None: with pytest.raises(TypeError, match="must be a pandas Series or DataFrame"): PopulationView._coerce_to_dataframe( - BASE_POPULATION.iloc[:, 0].tolist(), [COL_NAMES[0]] # type: ignore[arg-type] + PIE_DF.iloc[:, 0].tolist(), [PIE_COL_NAMES[0]] # type: ignore[arg-type] ) with pytest.raises(PopulationError, match="extra columns"): - PopulationView._coerce_to_dataframe(population_update_new_cols, 
[COL_NAMES[0]]) + PopulationView._coerce_to_dataframe(population_update_new_cols, [PIE_COL_NAMES[0]]) with pytest.raises(PopulationError, match="no columns"): PopulationView._coerce_to_dataframe( - pd.DataFrame(index=BASE_POPULATION.index), [COL_NAMES[0]] - ) - - -################################# -# PopulationView.update helpers # -################################################## -# PopulationView._ensure_coherent_initialization # -################################################## - - -def test__ensure_coherent_initialization_no_new_columns( - population_update: pd.Series[Any] | pd.DataFrame, - update_index: pd.Index[int], -) -> None: - if isinstance(population_update, pd.Series): - pytest.skip() - - # Missing population - if not update_index.empty: - with pytest.raises(PopulationError, match="missing updates"): - PopulationView._ensure_coherent_initialization( - population_update.loc[update_index[::2]], BASE_POPULATION.loc[update_index] - ) - - # No new columns - with pytest.raises(PopulationError, match="all provided columns"): - PopulationView._ensure_coherent_initialization( - population_update, - BASE_POPULATION.loc[update_index], + pd.DataFrame(index=PIE_DF.index), [PIE_COL_NAMES[0]] ) -def test__ensure_coherent_initialization_new_columns( - population_update_new_cols: pd.Series[Any] | pd.DataFrame, - update_index: pd.Index[int], -) -> None: - if isinstance(population_update_new_cols, pd.Series): - pytest.skip() - - # All new cols, should be good - PopulationView._ensure_coherent_initialization( - population_update_new_cols, - BASE_POPULATION.loc[update_index], - ) - - # Missing rows - if not update_index.equals(BASE_POPULATION.index): - with pytest.raises(PopulationError, match="missing updates"): - PopulationView._ensure_coherent_initialization( - population_update_new_cols, - BASE_POPULATION, - ) - - # Conflicting data in existing cols. 
- cols_overlap = [c for c in population_update_new_cols if c in COL_NAMES] - if not update_index.empty and cols_overlap: - update = population_update_new_cols.copy() - update[COL_NAMES[0]] = "bad_values" - with pytest.raises(PopulationError, match="conflicting"): - PopulationView._ensure_coherent_initialization( - update, - BASE_POPULATION.loc[update_index], - ) - - ################################# # PopulationView.update helpers # ######################################################### @@ -397,18 +592,20 @@ def test__ensure_coherent_initialization_new_columns( def test__format_update_and_check_preconditions_bad_args() -> None: with pytest.raises(AssertionError): PopulationView._format_update_and_check_preconditions( - BASE_POPULATION, - BASE_POPULATION, - COL_NAMES, + "foo", + PIE_DF, + PIE_DF, + PIE_COL_NAMES, creating_initial_population=True, adding_simulants=False, ) with pytest.raises(TypeError, match="must be a pandas Series or DataFrame"): PopulationView._format_update_and_check_preconditions( - BASE_POPULATION.iloc[:, 0].tolist(), # type: ignore[arg-type] - BASE_POPULATION, - COL_NAMES, + "foo", + PIE_DF.iloc[:, 0].tolist(), # type: ignore[arg-type] + PIE_DF, + PIE_COL_NAMES, True, True, ) @@ -419,18 +616,20 @@ def test__format_update_and_check_preconditions_coerce_failures( ) -> None: with pytest.raises(PopulationError, match="unnamed pandas series"): PopulationView._format_update_and_check_preconditions( - BASE_POPULATION.iloc[:, 0].rename(None), - BASE_POPULATION, - COL_NAMES, + "foo", + PIE_DF.iloc[:, 0].rename(None), + PIE_DF, + PIE_COL_NAMES, True, True, ) - for view_cols in [COL_NAMES, [COL_NAMES[0]]]: + for view_cols in [PIE_COL_NAMES, [PIE_COL_NAMES[0]]]: with pytest.raises(PopulationError, match="extra columns"): PopulationView._format_update_and_check_preconditions( + "foo", population_update_new_cols, - BASE_POPULATION, + PIE_DF, view_cols, True, True, @@ -438,8 +637,9 @@ def test__format_update_and_check_preconditions_coerce_failures( with 
pytest.raises(PopulationError, match="no columns"): PopulationView._format_update_and_check_preconditions( - pd.DataFrame(index=BASE_POPULATION.index), - BASE_POPULATION, + "foo", + pd.DataFrame(index=PIE_DF.index), + PIE_DF, view_cols, True, True, @@ -455,11 +655,15 @@ def test__format_update_and_check_preconditions_unknown_pop_fail( update = population_update.copy() update.index += 2 * update.index.max() - with pytest.raises(PopulationError, match=f"{len(update)} simulants"): + with pytest.raises( + PopulationError, + match="Population updates must have an index that is a subset of the current private data.", + ): PopulationView._format_update_and_check_preconditions( + "foo", update, - BASE_POPULATION, - COL_NAMES, + PIE_DF, + PIE_COL_NAMES, True, True, ) @@ -471,104 +675,48 @@ def test__format_update_and_check_preconditions_coherent_initialization_fail( ) -> None: # Missing population if not update_index.empty: - with pytest.raises(PopulationError, match="missing updates"): + with pytest.raises(PopulationError, match="Component 'foo' is missing updates"): PopulationView._format_update_and_check_preconditions( + "foo", population_update.loc[update_index[::2]], - BASE_POPULATION.loc[update_index], - COL_NAMES, + PIE_DF.loc[update_index], + PIE_COL_NAMES, True, True, ) - # No new columns - with pytest.raises(PopulationError, match="all provided columns"): - PopulationView._format_update_and_check_preconditions( - population_update, - BASE_POPULATION.loc[update_index], - COL_NAMES, - True, - True, - ) - def test__format_update_and_check_preconditions_coherent_initialization_fail_new_cols( population_update_new_cols: pd.Series[Any] | pd.DataFrame, update_index: pd.Index[int], ) -> None: - if not update_index.equals(BASE_POPULATION.index): - with pytest.raises(PopulationError, match="missing updates"): + if not update_index.equals(PIE_DF.index): + with pytest.raises(PopulationError, match="Component 'foo' is missing updates"): 
PopulationView._format_update_and_check_preconditions( + "foo", population_update_new_cols, - BASE_POPULATION, - COL_NAMES + NEW_COL_NAMES, - True, - True, - ) - - # Conflicting data in existing cols. - cols_overlap = [c for c in population_update_new_cols if c in COL_NAMES] - if not update_index.empty and cols_overlap: - update = population_update_new_cols.copy() - update[COL_NAMES[0]] = "bad_values" - with pytest.raises(PopulationError, match="conflicting"): - PopulationView._format_update_and_check_preconditions( - update, - BASE_POPULATION.loc[update_index], - COL_NAMES + NEW_COL_NAMES, + PIE_DF, + PIE_COL_NAMES + CUBE_COL_NAMES, True, True, ) -def test__format_update_and_check_preconditions_new_columns_non_init( - population_update_new_cols: pd.Series[Any] | pd.DataFrame, - update_index: pd.Index[int], -) -> None: - for adding_simulants in [True, False]: - with pytest.raises(PopulationError, match="outside the initial population creation"): - PopulationView._format_update_and_check_preconditions( - population_update_new_cols, - BASE_POPULATION.loc[update_index], - COL_NAMES + NEW_COL_NAMES, - False, - adding_simulants, - ) - - -def test__format_update_and_check_preconditions_conflicting_non_init( - population_update: pd.Series[Any] | pd.DataFrame, - update_index: pd.Index[int], -) -> None: - update = population_update.copy() - if isinstance(update, pd.Series): - update[:] = "bad_value" - else: - update.loc[:, COL_NAMES[0]] = "bad_value" - if not update_index.empty: - with pytest.raises(PopulationError, match="conflicting"): - PopulationView._format_update_and_check_preconditions( - update, - BASE_POPULATION.loc[update_index], - COL_NAMES + NEW_COL_NAMES, - False, - True, - ) - - def test__format_update_and_check_preconditions_init_pass( population_update_new_cols: pd.Series[Any] | pd.DataFrame, update_index: pd.Index[int], ) -> None: result = PopulationView._format_update_and_check_preconditions( + "foo", population_update_new_cols, - 
BASE_POPULATION.loc[update_index], - COL_NAMES + NEW_COL_NAMES, + PIE_DF.loc[update_index], + PIE_COL_NAMES + CUBE_COL_NAMES, True, True, ) update = PopulationView._coerce_to_dataframe( population_update_new_cols, - COL_NAMES + NEW_COL_NAMES, + PIE_COL_NAMES + CUBE_COL_NAMES, ) assert set(result.columns) == set(update) @@ -580,17 +728,18 @@ def test__format_update_and_check_preconditions_add_pass( population_update: pd.Series[Any] | pd.DataFrame, update_index: pd.Index[int], ) -> None: - state_table = BASE_POPULATION.drop(update_index).reindex(BASE_POPULATION.index) + state_table = PIE_DF.drop(update_index).reindex(PIE_DF.index) result = PopulationView._format_update_and_check_preconditions( + "foo", population_update, state_table, - COL_NAMES + NEW_COL_NAMES, + PIE_COL_NAMES + CUBE_COL_NAMES, False, True, ) update = PopulationView._coerce_to_dataframe( population_update, - COL_NAMES + NEW_COL_NAMES, + PIE_COL_NAMES + CUBE_COL_NAMES, ) assert set(result.columns) == set(update) @@ -602,15 +751,16 @@ def test__format_update_and_check_preconditions_time_step_pass( population_update: pd.Series[Any] | pd.DataFrame, ) -> None: result = PopulationView._format_update_and_check_preconditions( + "foo", population_update, - BASE_POPULATION, - COL_NAMES + NEW_COL_NAMES, + PIE_DF, + PIE_COL_NAMES + CUBE_COL_NAMES, False, False, ) update = PopulationView._coerce_to_dataframe( population_update, - COL_NAMES + NEW_COL_NAMES, + PIE_COL_NAMES + CUBE_COL_NAMES, ) assert set(result.columns) == set(update) @@ -622,15 +772,16 @@ def test__format_update_and_check_preconditions_adding_simulants_replace_identic population_update: pd.Series[Any] | pd.DataFrame, ) -> None: result = PopulationView._format_update_and_check_preconditions( + "foo", population_update, - BASE_POPULATION, - COL_NAMES + NEW_COL_NAMES, + PIE_DF, + PIE_COL_NAMES + CUBE_COL_NAMES, False, True, ) update = PopulationView._coerce_to_dataframe( population_update, - COL_NAMES + NEW_COL_NAMES, + PIE_COL_NAMES + 
CUBE_COL_NAMES, ) assert set(result.columns) == set(update) @@ -649,14 +800,15 @@ def test__update_column_and_ensure_dtype() -> None: random.seed("test__update_column_and_ensure_dtype") for adding_simulants in [True, False]: - for update_index in [BASE_POPULATION.index, BASE_POPULATION.index[::2]]: - for col in BASE_POPULATION: + # Test full and partial column updates + for update_index in [PIE_DF.index, PIE_DF.index[::2]]: + for col in PIE_DF: update = pd.Series( - random.sample(BASE_POPULATION[col].tolist(), k=len(update_index)), + random.sample(PIE_DF[col].tolist(), k=len(update_index)), index=update_index, name=col, ) - existing = BASE_POPULATION[col].copy() + existing = PIE_DF[col].copy() new_values = PopulationView._update_column_and_ensure_dtype( update, @@ -674,16 +826,16 @@ def test__update_column_and_ensure_dtype() -> None: def test__update_column_and_ensure_dtype_unmatched_dtype() -> None: # This tests a very specific failure case as the code is # not robust to general dtype silliness. - update_index = BASE_POPULATION.index - col = "count" + update_index = PIE_DF.index + col = "pi" update = pd.Series( - random.sample(BASE_POPULATION[col].tolist(), k=len(update_index)), + random.sample(PIE_DF[col].tolist(), k=len(update_index)), index=update_index, name=col, ) - existing = BASE_POPULATION[col].copy() - # Count is an int, this coerces it to a float since there's no null type for ints. - existing[:] = None + existing = PIE_DF[col].copy() + # change the type + existing = existing.astype(str) # Should work fine when we're adding simulants new_values = PopulationView._update_column_and_ensure_dtype( @@ -694,7 +846,10 @@ def test__update_column_and_ensure_dtype_unmatched_dtype() -> None: assert new_values.loc[update_index].equals(update) # And be bad news otherwise. 
- with pytest.raises(PopulationError, match="corrupting"): + with pytest.raises( + PopulationError, + match="A component is corrupting the population table by modifying the dtype", + ): PopulationView._update_column_and_ensure_dtype( update, existing, @@ -702,153 +857,196 @@ def test__update_column_and_ensure_dtype_unmatched_dtype() -> None: ) +################################# +# PopulationView.update helpers # +################################################## +# PopulationView._build_query # +################################################## + + +def test__skip_tracked_query_if_initializing( + pies_and_cubes_pop_mgr: PopulationManager, +) -> None: + pies_and_cubes_pop_mgr.tracked_queries = ["one == 1"] + pv = pies_and_cubes_pop_mgr.get_view(PieComponent()) + # lifecycle_states is not directly iterable so just look for constants manually + states = [state for state in dir(lifecycle_states) if not state.startswith("_")] + for state in states: + query = pv._build_query("", include_untracked=False) + if state in [lifecycle_states.INITIALIZATION, lifecycle_states.POPULATION_CREATION]: + # We DO include the untracked people here so make sure the query does + # NOT include the tracking query + assert query == "" + else: + assert query == "(one == 1)" + + ######################### # PopulationView.update # ######################### def test_population_view_update_format_fail( - population_manager: PopulationManager, + pies_and_cubes_pop_mgr: PopulationManager, population_update: pd.Series[Any] | pd.DataFrame, update_index: pd.Index[int], ) -> None: - pv = population_manager.get_view(COL_NAMES) - - population_manager.creating_initial_population = True - population_manager.adding_simulants = True + pv = pies_and_cubes_pop_mgr.get_view(PieComponent()) + pies_and_cubes_pop_mgr.creating_initial_population = True + pies_and_cubes_pop_mgr.adding_simulants = True # Bad type with pytest.raises(TypeError): - pv.update(BASE_POPULATION.iloc[:, 0].tolist()) # type: 
ignore[arg-type] + pv.update(PIE_DF.iloc[:, 0].tolist()) # type: ignore[arg-type] # Unknown population index if not update_index.empty: update = population_update.copy() update.index += 2 * update.index.max() - with pytest.raises(PopulationError, match=f"{len(update)} simulants"): + with pytest.raises( + PopulationError, + match=f"{len(update)} simulants were provided in an update with no matching index in the existing table", + ): pv.update(update) - # Missing population - population_manager._population = BASE_POPULATION.loc[update_index] + # Missing an update + pies_and_cubes_pop_mgr._private_columns = PIE_DF.loc[update_index] if not update_index.empty: - with pytest.raises(PopulationError, match="missing updates"): + with pytest.raises( + PopulationError, match="Component 'pie_component' is missing updates for" + ): pv.update(population_update.loc[update_index[::2]]) - # No new columns - with pytest.raises(PopulationError, match="all provided columns"): - pv.update(population_update) - - population_manager.creating_initial_population = False - update = population_update.copy() - if isinstance(update, pd.Series): - update[:] = "bad_value" - else: - update.loc[:, COL_NAMES[0]] = "bad_value" - if not update_index.empty: - with pytest.raises(PopulationError, match="conflicting"): - pv.update(update) - def test_population_view_update_format_fail_new_cols( - population_manager: PopulationManager, + pies_and_cubes_pop_mgr: PopulationManager, population_update_new_cols: pd.Series[Any] | pd.DataFrame, update_index: pd.Index[int], ) -> None: - pv = population_manager.get_view(COL_NAMES) - population_manager.creating_initial_population = True - population_manager.adding_simulants = True + pv_pies = pies_and_cubes_pop_mgr.get_view(PieComponent()) - with pytest.raises(PopulationError, match="unnamed pandas series"): - pv.update(BASE_POPULATION.iloc[:, 0].rename(None)) + pies_and_cubes_pop_mgr.creating_initial_population = True + pies_and_cubes_pop_mgr.adding_simulants = True 
- for view_cols in [COL_NAMES, [COL_NAMES[0]]]: - pv = population_manager.get_view(view_cols) + with pytest.raises(PopulationError, match="unnamed pandas series"): + pv_pies.update(PIE_DF.iloc[:, 0].rename(None)) - with pytest.raises(PopulationError, match="extra columns"): - pv.update(population_update_new_cols) + with pytest.raises(PopulationError, match="extra columns"): + pv_pies.update(population_update_new_cols) - with pytest.raises(PopulationError, match="no columns"): - pv.update(pd.DataFrame(index=BASE_POPULATION.index)) + with pytest.raises(PopulationError, match="no columns"): + pv_pies.update(pd.DataFrame(index=PIE_DF.index)) - pv = population_manager.get_view(COL_NAMES + NEW_COL_NAMES) - if not update_index.equals(BASE_POPULATION.index): + pv_cubes = pies_and_cubes_pop_mgr.get_view(CubeComponent()) + if not update_index.equals(CUBE_DF.index): with pytest.raises(PopulationError, match="missing updates"): - pv.update(population_update_new_cols) - - # Conflicting data in existing cols. 
- population_manager._population = BASE_POPULATION.loc[update_index] - cols_overlap = [c for c in population_update_new_cols if c in COL_NAMES] - if not update_index.empty and cols_overlap: - update = population_update_new_cols.copy() - update[COL_NAMES[0]] = "bad_values" - with pytest.raises(PopulationError, match="conflicting"): - pv.update(update) - - population_manager.creating_initial_population = False - for adding_simulants in [True, False]: - population_manager.adding_simulants = adding_simulants - with pytest.raises(PopulationError, match="outside the initial population creation"): - pv.update(population_update_new_cols) + pv_cubes.update(population_update_new_cols) def test_population_view_update_init( - population_manager: PopulationManager, + pies_and_cubes_pop_mgr: PopulationManager, population_update_new_cols: pd.Series[Any] | pd.DataFrame, update_index: pd.Index[int], ) -> None: if isinstance(population_update_new_cols, pd.Series): pytest.skip() - pv = population_manager.get_view(COL_NAMES + NEW_COL_NAMES) + # Remove the cubes backing data to test that initialization works + pies_and_cubes_pop_mgr._private_columns = PIE_DF.loc[update_index] + + pv = pies_and_cubes_pop_mgr.get_view(CubeComponent()) + + pies_and_cubes_pop_mgr.creating_initial_population = True + pies_and_cubes_pop_mgr.adding_simulants = True - population_manager._population = BASE_POPULATION.loc[update_index] - population_manager.creating_initial_population = True - population_manager.adding_simulants = True pv.update(population_update_new_cols) for col in population_update_new_cols: - assert population_manager._population[col].equals(population_update_new_cols[col]) + assert pies_and_cubes_pop_mgr._private_columns[col].equals( + population_update_new_cols[col] + ) def test_population_view_update_add( - population_manager: PopulationManager, + pies_and_cubes_pop_mgr: PopulationManager, population_update: pd.Series[Any] | pd.DataFrame, update_index: pd.Index[int], ) -> None: if 
isinstance(population_update, pd.Series): pytest.skip() - pv = population_manager.get_view(COL_NAMES + NEW_COL_NAMES) - - population_manager._population = BASE_POPULATION.loc[update_index] + pv_pies = pies_and_cubes_pop_mgr.get_view(PieComponent()) + pies_and_cubes_pop_mgr._private_columns = PIE_DF.loc[update_index] for col in population_update: - population_manager._population[col] = None - population_manager.creating_initial_population = False - population_manager.adding_simulants = True - pv.update(population_update) + pies_and_cubes_pop_mgr._private_columns[col] = None + pies_and_cubes_pop_mgr.creating_initial_population = False + pies_and_cubes_pop_mgr.adding_simulants = True + pv_pies.update(population_update) for col in population_update: if update_index.empty: - assert population_manager._population[col].empty + assert pies_and_cubes_pop_mgr._private_columns[col].empty else: - assert population_manager._population[col].equals(population_update[col]) + assert pies_and_cubes_pop_mgr._private_columns[col].equals(population_update[col]) def test_population_view_update_time_step( - population_manager: PopulationManager, + pies_and_cubes_pop_mgr: PopulationManager, population_update: pd.Series[Any] | pd.DataFrame, update_index: pd.Index[int], ) -> None: if isinstance(population_update, pd.Series): pytest.skip() - pv = population_manager.get_view(COL_NAMES + NEW_COL_NAMES) + pv = pies_and_cubes_pop_mgr.get_view(PieComponent()) - population_manager.creating_initial_population = False - population_manager.adding_simulants = False + pies_and_cubes_pop_mgr.creating_initial_population = False + pies_and_cubes_pop_mgr.adding_simulants = False pv.update(population_update) for col in population_update.columns: - pop = population_manager._population + pop = pies_and_cubes_pop_mgr._private_columns assert pop is not None assert pop.loc[update_index, col].equals(population_update[col]) + + +#################### +# Helper functions # +#################### + + +def 
_resolve_kwargs( + include_untracked: bool, + query: str | None, +) -> dict[str, Any]: + kwargs: dict[str, bool | str | list[str]] = {} + kwargs["include_untracked"] = include_untracked + if query is not None: + kwargs["query"] = query + + return kwargs + + +def _get_expected(update_index: pd.Index[int], combined_query: str | None) -> pd.DataFrame: + expected_pop = PIE_DF.loc[update_index] + if combined_query: + if "cube" in combined_query: + expected_pop = pd.concat( + [expected_pop, CUBE_DF.loc[update_index, "cube"]], axis=1 + ) + expected_pop = expected_pop.query(combined_query) + return expected_pop + + +def _combine_queries( + include_untracked: bool, + tracked_queries: list[str], + query: str | None, +) -> str: + combined_query_parts = [] + if not include_untracked and tracked_queries: + combined_query_parts += tracked_queries + if query is not None: + combined_query_parts.append(f"{query}") + combined_query = " and ".join(combined_query_parts) + return combined_query diff --git a/tests/framework/population/test_utilities.py b/tests/framework/population/test_utilities.py new file mode 100644 index 000000000..39e5fc9e2 --- /dev/null +++ b/tests/framework/population/test_utilities.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import pytest + +from vivarium.framework.population.utilities import ( + combine_queries, + extract_columns_from_query, +) + + +@pytest.mark.parametrize( + "query, expected", + [ + ("", set()), + ( + ( + # Basic + "is_alive == True and is_aged_out == False and size == 'large' and " + # No spaces + "answer==42 or " + "answer_str=='forty-two' or " + "43!=correct_answer and " + "duck!=goose or " + # Mixed operators and casing + "(10 < age < 20 OR sex == 'Female') IF tiger == 'hobbes' AND " + # Column-column comparisons and @constants + "(bar >= baz if @some_const <= 100) or " + # Names w/ 'and', 'or', or 'if' + "band == xplor or iffy == 'sketchy' and " + # Underscores + "some_col == True or " + "test_column_1 != 5 and " + # 
Casing + "Foo != Bar and " + # Special names requiring backticks + "`spaced column` == False or " + "`???` != 'unknown' and " + "`column(1)` < 50 or " + "`column[2]` < 50 or " + "`column{3}` < 50 or " + # Quotes + "`\"quz\"` == 'value' or " + 'nothing != "something" and ' + # in logic + "color in ['red', 'blue', 'green'] or " + "shape not in ['circle', 'square']" + ), + { + "is_alive", + "is_aged_out", + "size", + "answer", + "answer_str", + "correct_answer", + "duck", + "goose", + "age", + "sex", + "tiger", + "bar", + "baz", + "band", + "xplor", + "iffy", + "some_col", + "test_column_1", + "Foo", + "Bar", + "spaced column", + "???", + "column(1)", + "column[2]", + "column{3}", + '"quz"', + "nothing", + "color", + "shape", + }, + ), + ], +) +def test_extract_columns_from_query(query: str, expected: set[str]) -> None: + query_columns = extract_columns_from_query(query) + assert query_columns == expected + + +@pytest.mark.parametrize( + "queries, expected_query", + [ + [("", ""), ""], + [("is_alive == True", "age < 5"), "(is_alive == True) and (age < 5)"], + [ + ("is_alive == Trie or is_aged_out == False", "age < 5", "sex == 'Female'"), + "(is_alive == Trie or is_aged_out == False) and (age < 5) and (sex == 'Female')", + ], + ], +) +def test_combine_queries(queries: tuple[str, ...], expected_query: str) -> None: + combined = combine_queries(*queries) + assert combined == expected_query diff --git a/tests/framework/randomness/test_crn.py b/tests/framework/randomness/test_crn.py index 5cbf309cc..613e30005 100644 --- a/tests/framework/randomness/test_crn.py +++ b/tests/framework/randomness/test_crn.py @@ -12,6 +12,7 @@ import pytest from pandas.testing import assert_frame_equal +from tests.helpers import ColumnCreator from vivarium import Component from vivarium.framework.engine import Builder from vivarium.framework.event import Event @@ -34,6 +35,7 @@ def test_basic_repeatability(initializes_crn_attributes: bool) -> None: "clock": Callable[[], pd.Timestamp], "seed": 
str, "index_map": IndexMap, + "component": Component, "initializes_crn_attributes": bool, }, total=False, @@ -44,6 +46,7 @@ def test_basic_repeatability(initializes_crn_attributes: bool) -> None: "clock": lambda: pd.Timestamp("2020-01-01"), "seed": "abc", "index_map": index_map, + "component": ColumnCreator(), "initializes_crn_attributes": initializes_crn_attributes, } @@ -79,10 +82,6 @@ class BasePopulation(Component): def name(self) -> str: return "population" - @property - def columns_created(self) -> list[str]: - return ["crn_attr1", "crn_attr2", "other_attr1"] - def __init__(self, with_crn: bool, sims_to_add: Iterator[int] = cycle([0])) -> None: """ Parameters @@ -101,15 +100,11 @@ def __init__(self, with_crn: bool, sims_to_add: Iterator[int] = cycle([0])) -> N def setup(self, builder: Builder) -> None: self.register = builder.randomness.register_simulants self.randomness_init = builder.randomness.get_stream( - "crn_init", - initializes_crn_attributes=self.with_crn, + "crn_init", initializes_crn_attributes=self.with_crn ) self.randomness_other = builder.randomness.get_stream("other") self.simulant_creator = builder.population.get_simulant_creator() - def on_initialize_simulants(self, pop_data: SimulantData) -> None: - pass - def on_time_step(self, event: Event) -> None: sims_to_add = next(self.sims_to_add) if sims_to_add > 0: @@ -119,7 +114,20 @@ def on_time_step(self, event: Event) -> None: class EntranceTimePopulation(BasePopulation): """Population that bases identity on entrance time and a random number""" - def on_initialize_simulants(self, pop_data: SimulantData) -> None: + def setup(self, builder: Builder) -> None: + super().setup(builder) + builder.population.register_initializer( + initializer=self.register_crn_attributes, + columns=["crn_attr1", "crn_attr2"], + required_resources=[self.randomness_init], + ) + builder.population.register_initializer( + initializer=self.register_other_attribute, + columns="other_attr1", + 
required_resources=[self.randomness_other], + ) + + def register_crn_attributes(self, pop_data: SimulantData) -> None: crn_attr = (1_000_000 * self.randomness_init.get_draw(index=pop_data.index)).astype( int ) @@ -131,12 +139,13 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None: if self.with_crn: self.register(population) - population["other_attr1"] = self.randomness_other.get_draw( - pop_data.index, - additional_key="attr1", - ) self.population_view.update(population) + def register_other_attribute(self, pop_data: SimulantData) -> None: + attr1 = self.randomness_other.get_draw(pop_data.index, additional_key="attr1") + attr1.name = "other_attr1" + self.population_view.update(attr1) + class SequentialPopulation(BasePopulation): """ @@ -149,8 +158,14 @@ class SequentialPopulation(BasePopulation): def setup(self, builder: Builder) -> None: super().setup(builder) self.count = 0 + builder.population.register_initializer( + initializer=self.register_crn_attributes, columns=["crn_attr1", "crn_attr2"] + ) + builder.population.register_initializer( + initializer=self.register_other_attribute, columns="other_attr1" + ) - def on_initialize_simulants(self, pop_data: SimulantData) -> None: + def register_crn_attributes(self, pop_data: SimulantData) -> None: new_people = len(pop_data.index) population = pd.DataFrame( @@ -164,13 +179,14 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None: if self.with_crn: self.register(population) - population["other_attr1"] = self.randomness_other.get_draw( - pop_data.index, - additional_key="attr1", - ) self.population_view.update(population) self.count += new_people + def register_other_attribute(self, pop_data: SimulantData) -> None: + attr1 = self.randomness_other.get_draw(pop_data.index, additional_key="attr1") + attr1.name = "other_attr1" + self.population_view.update(attr1) + @pytest.mark.parametrize( "pop_class, with_crn, sims_to_add", @@ -254,8 +270,8 @@ def 
test_multi_sim_reproducibility_with_different_pop_growth( sim1.step() sim2.step() - pop1 = sim1.get_population().set_index(["crn_attr1", "crn_attr2"]).drop(columns="tracked") - pop2 = sim2.get_population().set_index(["crn_attr1", "crn_attr2"]).drop(columns="tracked") + pop1 = sim1.get_population().set_index(["crn_attr1", "crn_attr2"]) + pop2 = sim2.get_population().set_index(["crn_attr1", "crn_attr2"]) if with_crn: overlap = pop1.index.intersection(pop2.index) @@ -273,7 +289,20 @@ class UnBrokenPopulation(BasePopulation): This is now a regression testing class. """ - def on_initialize_simulants(self, pop_data: SimulantData) -> None: + def setup(self, builder: Builder) -> None: + super().setup(builder) + builder.population.register_initializer( + initializer=self.register_crn_attributes, + columns=["crn_attr1", "crn_attr2"], + required_resources=[self.randomness_init], + ) + builder.population.register_initializer( + initializer=self.register_other_attribute, + columns="other_attr1", + required_resources=[self.randomness_other], + ) + + def register_crn_attributes(self, pop_data: SimulantData) -> None: crn_attr = (1_000_000 * self.randomness_init.get_draw(index=pop_data.index)).astype( int ) @@ -285,12 +314,13 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None: if self.with_crn: self.register(population) - population["other_attr1"] = self.randomness_other.get_draw( - pop_data.index, - additional_key="attr1", - ) self.population_view.update(population) + def register_other_attribute(self, pop_data: SimulantData) -> None: + attr1 = self.randomness_other.get_draw(pop_data.index, additional_key="attr1") + attr1.name = "other_attr1" + self.population_view.update(attr1) + @pytest.mark.parametrize( "with_crn, sims_to_add", diff --git a/tests/framework/randomness/test_manager.py b/tests/framework/randomness/test_manager.py index e10a40c25..399902f3c 100644 --- a/tests/framework/randomness/test_manager.py +++ b/tests/framework/randomness/test_manager.py @@ 
-1,14 +1,83 @@ +from typing import Literal + import pandas as pd import pytest from layered_config_tree import LayeredConfigTree +from pytest_mock import MockerFixture -from tests.helpers import ColumnCreator, ColumnRequirer -from vivarium import InteractiveContext +from tests.helpers import ColumnCreator +from vivarium import Component, InteractiveContext +from vivarium.framework.lifecycle import lifecycle_states from vivarium.framework.randomness.index_map import IndexMap from vivarium.framework.randomness.manager import RandomnessError, RandomnessManager from vivarium.framework.randomness.stream import get_hash +@pytest.mark.parametrize("initializes_crn_attributes", [True, False]) +def test_get_randomness_stream_calls_methods_correctly( + mocker: MockerFixture, initializes_crn_attributes: bool +) -> None: + """Test that get_randomness_stream orchestrates calls to helper methods correctly.""" + # Setup + manager = RandomnessManager() + test_component = Component() + test_decision_point = "test_decision" + test_rate_conversion: Literal["linear", "exponential"] = "linear" + + # Set up a mock RandomnessStream + mock_stream = mocker.Mock() + mock_stream.get_draw = mocker.Mock() + mock_stream.filter_for_probability = mocker.Mock() + mock_stream.filter_for_rate = mocker.Mock() + mock_stream.choice = mocker.Mock() + + # Inject mocks into the manager + manager._get_current_component = mocker.Mock(return_value=test_component) + manager._get_randomness_stream = mocker.Mock(return_value=mock_stream) # type: ignore[method-assign] + manager._add_resources = mocker.Mock() + manager._add_constraint = mocker.Mock() + manager._key_columns = ["age", "sex"] + + # Execute + result = manager.get_randomness_stream( + test_decision_point, initializes_crn_attributes, test_rate_conversion + ) + + # Assert _get_randomness_stream was called with correct arguments + manager._get_randomness_stream.assert_called_once_with( # type: ignore[attr-defined] + test_decision_point, + test_component, 
+ initializes_crn_attributes, + test_rate_conversion, + ) + + # Assert _add_resources was called with correct arguments + expected_required_resources = [] if initializes_crn_attributes else ["age", "sex"] + manager._add_resources.assert_called_once_with( # type: ignore[attr-defined] + component=test_component, + resources=mock_stream, + required_resources=expected_required_resources, + ) + + # Assert _add_constraint was called for each stream method + assert manager._add_constraint.call_count == 4 # type: ignore[attr-defined] + restricted_states = [ + lifecycle_states.INITIALIZATION, + lifecycle_states.SETUP, + lifecycle_states.POST_SETUP, + ] + expected_calls = [ + mocker.call(mock_stream.get_draw, restrict_during=restricted_states), + mocker.call(mock_stream.filter_for_probability, restrict_during=restricted_states), + mocker.call(mock_stream.filter_for_rate, restrict_during=restricted_states), + mocker.call(mock_stream.choice, restrict_during=restricted_states), + ] + manager._add_constraint.assert_has_calls(expected_calls) # type: ignore[attr-defined] + + # Assert the stream is returned + assert result == mock_stream + + def mock_clock() -> pd.Timestamp: return pd.Timestamp("1/1/2005") @@ -18,6 +87,7 @@ def test_randomness_manager_get_randomness_stream() -> None: component = ColumnCreator() rm = RandomnessManager() + rm._get_current_component = lambda: component rm._add_constraint = lambda f, **kwargs: f rm._seed = seed rm._clock_ = mock_clock @@ -30,10 +100,9 @@ def test_randomness_manager_get_randomness_stream() -> None: assert stream.seed == seed assert stream.clock is mock_clock assert set(rm._decision_points.keys()) == {"test"} - assert stream.component == component with pytest.raises(RandomnessError): - rm._get_randomness_stream("test", ColumnRequirer()) + rm._get_randomness_stream("test", component) def test_randomness_manager_register_simulants() -> None: diff --git a/tests/framework/randomness/test_reproducibility.py 
b/tests/framework/randomness/test_reproducibility.py index 8ba03f1da..23d92a525 100644 --- a/tests/framework/randomness/test_reproducibility.py +++ b/tests/framework/randomness/test_reproducibility.py @@ -23,7 +23,7 @@ def test_reproducibility(tmp_path: Path, disease_model_spec: Path) -> None: check=True, ) - files = [file for file in results_dir.rglob("**/*.parquet")] + files = list(results_dir.rglob("**/*.parquet")) assert len(files) == 4 for filename in ["dead", "ylls"]: df_paths = [file for file in files if file.stem == filename] diff --git a/tests/framework/randomness/test_stream.py b/tests/framework/randomness/test_stream.py index 393329a33..ffb447ad3 100644 --- a/tests/framework/randomness/test_stream.py +++ b/tests/framework/randomness/test_stream.py @@ -1,11 +1,12 @@ from __future__ import annotations -from typing import Any, Callable +from typing import Any import numpy as np import pandas as pd import pytest from layered_config_tree import LayeredConfigTree +from pytest_mock import MockerFixture from scipy import stats from vivarium_testing_utils import FuzzyChecker @@ -21,9 +22,16 @@ @pytest.fixture -def randomness_stream() -> RandomnessStream: - dates = [pd.Timestamp(1991, 1, 1), pd.Timestamp(1990, 1, 1)] - randomness = RandomnessStream("test", dates.pop, 1, IndexMap()) +def randomness_stream(mocker: MockerFixture) -> RandomnessStream: + # Provide different dates to each fixture to ensure independence + dates = [ + pd.Timestamp(1991, 1, 1), + pd.Timestamp(1990, 1, 1), + pd.Timestamp(1989, 1, 1), + pd.Timestamp(1988, 1, 1), + pd.Timestamp(1987, 1, 1), + ] + randomness = RandomnessStream("test", dates.pop, 1, IndexMap(), mocker.Mock()) return randomness @@ -84,6 +92,56 @@ def test_filter_for_probability_multiple_probabilities( ) +def test_filter_for_probability_with_zeros_and_or_ones( + randomness_stream: RandomnessStream, index: pd.Index[int] +) -> None: + probabilities = pd.Series( + [0.0, 0.6, 1.0, 0.0, 0.0, 0.6, 1.0, 0.6, 0.0, 0.6] * (index.size // 
10), index=index + ) + threshold_0_0 = probabilities.index[probabilities == 0.0] + threshold_1_0 = probabilities.index[probabilities == 1.0] + threshold_0_6 = probabilities.index.difference(threshold_0_0).difference(threshold_1_0) + sub_index = randomness_stream.filter_for_probability(index, probabilities) + # Nothing should be selected from the zero-probability group + assert len(sub_index.intersection(threshold_0_0)) == 0 + assert np.isclose( + len(sub_index.intersection(threshold_0_6)) / len(threshold_0_6), 0.6, rtol=0.1 + ) + assert len(sub_index.intersection(threshold_1_0)) == len(threshold_1_0) + + +def test_filter_for_probability_all_zeros_or_ones( + randomness_stream: RandomnessStream, index: pd.Index[int] +) -> None: + sub_index_zeros = randomness_stream.filter_for_probability(index, 0.0) + assert len(sub_index_zeros) == 0 + + sub_index_ones = randomness_stream.filter_for_probability(index, 1.0) + assert len(sub_index_ones) == len(index) + + +def test_filter_for_probability_types( + randomness_stream: RandomnessStream, index: pd.Index[int] +) -> None: + sub_index_float = randomness_stream.filter_for_probability(index, 0.42) + sub_index_list = randomness_stream.filter_for_probability(index, [0.42] * len(index)) + sub_index_tuple = randomness_stream.filter_for_probability(index, (0.42,) * len(index)) + sub_index_array = randomness_stream.filter_for_probability( + index, np.array([0.42] * len(index)) + ) + sub_index_series = randomness_stream.filter_for_probability( + index, pd.Series(0.42, index=index) + ) + for sub_index in [ + sub_index_float, + sub_index_list, + sub_index_tuple, + sub_index_array, + sub_index_series, + ]: + assert np.isclose(len(sub_index) / len(index), 0.42, rtol=0.1) + + @pytest.mark.parametrize( "rate, time_scaling_factor", [ @@ -210,10 +268,13 @@ def test_sample_from_distribution_bad_args( ], ) def test_sample_from_distribution_using_scipy( - index: pd.Index[int], distribution: stats.rv_continuous, params: dict[str, int] + mocker: 
MockerFixture, + index: pd.Index[int], + distribution: stats.rv_continuous, + params: dict[str, int], ) -> None: randomness_stream = RandomnessStream( - "test", lambda: pd.Timestamp(2020, 1, 1), 1, IndexMap() + "test", lambda: pd.Timestamp(2020, 1, 1), 1, IndexMap(), mocker.Mock() ) draws = randomness_stream.get_draw(index, "some_key") expected = distribution.ppf(draws, **params) @@ -227,7 +288,9 @@ def test_sample_from_distribution_using_scipy( assert np.allclose(sample, expected) -def test_sample_from_distribution_using_ppf(index: pd.Index[int]) -> None: +def test_sample_from_distribution_using_ppf( + mocker: MockerFixture, index: pd.Index[int] +) -> None: def silly_ppf(x: pd.Series[Any], **kwargs: Any) -> pd.Series[Any]: add = kwargs["add"] mult = kwargs["mult"] @@ -236,7 +299,7 @@ def silly_ppf(x: pd.Series[Any], **kwargs: Any) -> pd.Series[Any]: return output randomness_stream = RandomnessStream( - "test", lambda: pd.Timestamp(2020, 1, 1), 1, IndexMap() + "test", lambda: pd.Timestamp(2020, 1, 1), 1, IndexMap(), mocker.Mock() ) draws = randomness_stream.get_draw(index, "some_key") expected = 2 * (draws**2) + 1 @@ -284,10 +347,14 @@ def test_stream_rate_conversion_config( ], ) def test_filter_for_probability_error_with_null_values( - probs: float | list[float] | pd.Series[float], + probs: float | list[float] | pd.Series[float], mocker: MockerFixture ) -> None: randomness_stream = RandomnessStream( - "test", lambda: pd.Timestamp(2020, 1, 1), 1, IndexMap() + key="test", + clock=lambda: pd.Timestamp(2020, 1, 1), + seed=1, + index_map=IndexMap(), + component=mocker.Mock(), ) pop = pd.DataFrame({"age": [10, 11, 12, 13, 14], "id": [1, 2, 3, 4, 5]}).set_index("id") with pytest.raises(ValueError, match="Probabilities contain null values"): diff --git a/tests/framework/resource/test_interface.py b/tests/framework/resource/test_interface.py new file mode 100644 index 000000000..e440eb08b --- /dev/null +++ b/tests/framework/resource/test_interface.py @@ -0,0 +1,29 @@ +from 
pytest_mock import MockerFixture + +from vivarium.framework.resource.interface import ResourceInterface +from vivarium.framework.resource.manager import ResourceManager +from vivarium.framework.resource.resource import Column + + +def test_add_private_columns(mocker: MockerFixture) -> None: + mgr = ResourceManager() + interface = ResourceInterface(mgr) + mocker.patch.object( + mgr, + "_get_current_component_or_manager", + return_value=mocker.MagicMock(), + create=True, + ) + interface.add_private_columns( + initializer=lambda pop_data: None, + columns=["private_col_1", "private_col_2"], + required_resources=[], + ) + resource_map = mgr._resource_group_map + resource_ids = ["column.private_col_1", "column.private_col_2"] + assert set(resource_map.keys()) == set(resource_ids) + for resource_id in resource_ids: + resource = resource_map[resource_id].resources[resource_id] + assert isinstance(resource, Column) + assert resource.name == resource_id.split(".", 1)[1] + assert resource.resource_type == "column" diff --git a/tests/framework/resource/test_manager.py b/tests/framework/resource/test_manager.py index 8a2a1ff5c..8b2ea5f9a 100644 --- a/tests/framework/resource/test_manager.py +++ b/tests/framework/resource/test_manager.py @@ -5,21 +5,21 @@ from typing import Any import pytest -import pytest_mock +from pytest_mock import MockerFixture -from tests.helpers import ColumnCreator, ColumnCreatorAndRequirer, ColumnRequirer +from tests.helpers import ColumnCreator, ColumnCreatorAndRequirer from vivarium import Component from vivarium.framework.population import SimulantData from vivarium.framework.randomness import RandomnessStream from vivarium.framework.randomness.index_map import IndexMap -from vivarium.framework.resource import ResourceManager +from vivarium.framework.resource import Resource, ResourceManager from vivarium.framework.resource.exceptions import ResourceError from vivarium.framework.resource.resource import Column, NullResource -from 
vivarium.framework.values import Pipeline, ValueModifier, ValueSource +from vivarium.framework.values import AttributePipeline, Pipeline, ValueModifier, ValueSource @pytest.fixture -def manager(mocker: pytest_mock.MockFixture) -> ResourceManager: +def manager(mocker: MockerFixture) -> ResourceManager: manager = ResourceManager() manager.logger = mocker.Mock() return manager @@ -32,18 +32,82 @@ def resource_producers() -> dict[int, ResourceProducer]: @pytest.fixture def manager_with_resources( - manager: ResourceManager, resource_producers: dict[int, ResourceProducer] + manager: ResourceManager, + resource_producers: dict[int, ResourceProducer], + mocker: MockerFixture, ) -> ResourceManager: stream = RandomnessStream( - "B", lambda: datetime.now(), 1, IndexMap(), resource_producers[1] + key="B", + clock=lambda: datetime.now(), + seed=1, + index_map=IndexMap(), + component=resource_producers[1], ) pipeline = Pipeline("C", resource_producers[2]) + A_component = resource_producers[0] + attribute_A = AttributePipeline("A", A_component) + mocker.patch.object(A_component, "initialize_A", create=True) + D_component = resource_producers[3] + attribute_D = AttributePipeline("D", D_component) + mocker.patch.object(D_component, "initialize_D", create=True) + null_resource_component = resource_producers[4] + mocker.patch.object(null_resource_component, "initialize_nothing", create=True) + + manager.add_resources( + component=D_component, + initializer=None, + resources=attribute_D, + required_resources=[stream, pipeline], + ) + # Add the private column resource + manager.add_resources( + component=D_component, + initializer=D_component.initialize_D, + resources=[Column("D", D_component)], + required_resources=[stream, pipeline], + ) + + stream_component = stream.component + assert isinstance(stream_component, Component) + manager.add_resources( + component=stream_component, + initializer=None, + resources=stream, + required_resources=["A"], + ) + + pipeline_component = 
pipeline.component + assert isinstance(pipeline_component, Component) + manager.add_resources( + component=pipeline_component, + initializer=None, + resources=pipeline, + required_resources=["A"], + ) + + manager.add_resources( + component=A_component, initializer=None, resources=attribute_A, required_resources=[] + ) + # Add the private column resource + manager.add_resources( + component=A_component, + initializer=A_component.initialize_A, + resources=[Column("A", A_component)], + required_resources=[], + ) + + manager.add_resources( + component=null_resource_component, + initializer=null_resource_component.initialize_nothing, + resources=[], + required_resources=[stream], + ) + + # Call each resource group's on_post_setup to finalize dependencies + attribute_pipelines = {"A": attribute_A, "D": attribute_D} + for rg in manager._resource_group_map.values(): + rg.set_required_resources(attribute_pipelines) - manager.add_resources(resource_producers[3], ["D"], [stream, pipeline]) - manager.add_resources(stream.component, [stream], ["A"]) - manager.add_resources(pipeline.component, [pipeline], ["A"]) - manager.add_resources(resource_producers[0], ["A"], []) - manager.add_resources(resource_producers[4], [], [stream]) return manager @@ -63,71 +127,133 @@ def __init__(self, name: str): super().__init__() self._name = name - def on_initialize_simulants(self, _simulant_data: SimulantData) -> None: + def initialize_A(self, pop_data: SimulantData) -> None: + pass + + def initialize_D(self, pop_data: SimulantData) -> None: + pass + + def initialize_nothing(self, pop_data: SimulantData) -> None: pass @pytest.mark.parametrize( - "resource_class, init_args, type_string, is_initializer", + "resource_class, init_args, type_string", [ - (Pipeline, ["foo"], "value", False), - (ValueSource, [Pipeline("foo"), lambda: 1], "value_source", False), - (ValueModifier, [Pipeline("foo"), lambda: 1], "value_modifier", False), - (Column, ["foo"], "column", True), - (NullResource, [1], "null", 
True), + (Pipeline, ["foo"], "value"), + (AttributePipeline, ["foo"], "attribute"), + (ValueSource, [Pipeline("foo"), lambda: 1], "value_source"), + (ValueModifier, [Pipeline("foo"), lambda: 1], "value_modifier"), + (Column, ["foo"], "column"), + (NullResource, [1], "null"), ], ids=lambda x: [x.__name__ if isinstance(x, type) else x], ) +@pytest.mark.parametrize("initialized", [True, False]) def test_resource_manager_get_resource_group( resource_class: type, init_args: list[Any], type_string: str, - is_initializer: bool, + initialized: bool, manager: ResourceManager, ) -> None: component = ColumnCreator() group = manager._get_resource_group( - component, [resource_class(*init_args, component=component)], [] + component=component, + initializer=component.initialize_test_columns if initialized else None, + resources=[resource_class(*init_args, component=component)], + required_resources=[], ) assert group.type == type_string assert group.names == [r.resource_id for r in group.resources.values()] - assert not group.dependencies - assert group.is_initialized == is_initializer - assert group.initializer == component.on_initialize_simulants + assert not group.required_resources + assert group.is_initialized == initialized + assert group.initializer == (component.initialize_test_columns if initialized else None) def test_resource_manager_get_resource_group_null(manager: ResourceManager) -> None: component_1 = ColumnCreator() component_2 = ColumnCreatorAndRequirer() - group_1 = manager._get_resource_group(component_1, [], []) - group_2 = manager._get_resource_group(component_2, [], []) + group_1 = manager._get_resource_group( + component=component_1, + initializer=component_1.initialize_test_columns, + resources=[], + required_resources=[], + ) + group_2 = manager._get_resource_group( + component=component_2, + initializer=component_2.initialize_test_column_4, + resources=[], + required_resources=[], + ) assert group_1.type == "null" assert group_1.names == ["null.0"] - 
assert group_1.initializer == component_1.on_initialize_simulants - assert not group_1.dependencies + assert group_1.initializer == component_1.initialize_test_columns + assert not group_1.required_resources assert group_2.type == "null" assert group_2.names == ["null.1"] - assert group_2.initializer == component_2.on_initialize_simulants - assert not group_2.dependencies + assert group_2.initializer == component_2.initialize_test_column_4 + assert not group_2.required_resources + + +def test_get_resource_group_multiple_initializers(manager: ResourceManager) -> None: + class SomeComponent(Component): + def initializer_1(self, pop_data: SimulantData) -> None: + pass + + def initializer_2(self, pop_data: SimulantData) -> None: + pass + + component = SomeComponent() + + group = manager._get_resource_group( + component=component, + initializer=component.initializer_1, + resources=[Column("foo", component), Column("bar", component)], + required_resources=[], + ) + + assert group.type == "column" + assert group.names == ["column.foo", "column.bar"] + assert group.initializer == component.initializer_1 + + # Create another group with the same resources but a different initializer + group2 = manager._get_resource_group( + component=component, + initializer=component.initializer_2, + resources=Resource("test", "baz", component), + required_resources=[], + ) + + assert group2.type == "test" + assert group2.names == ["test.baz"] + assert group2.initializer == component.initializer_2 def test_add_resource_wrong_component(manager: ResourceManager) -> None: resource = Pipeline("foo", ColumnCreatorAndRequirer()) - error_message = "All initialized resources must have the component 'column_creator'." + error_message = "All initialized resources in this resource group must have the component 'column_creator'." 
+ component = ColumnCreator() with pytest.raises(ResourceError, match=error_message): - manager.add_resources(ColumnCreator(), [resource], []) + manager.add_resources( + component=component, + initializer=component.initialize_test_columns, + resources=resource, + required_resources=[], + ) @pytest.mark.parametrize( "resource_type, resource_creator", [ - ("column", lambda name, component: name), + ("column", lambda name, component: Column(name, component)), ("value", lambda name, component: Pipeline(name, component)), + ("attribute", lambda name, component: AttributePipeline(name, component)), ], ) def test_resource_manager_add_same_resource_twice( @@ -139,34 +265,73 @@ def test_resource_manager_add_same_resource_twice( c2 = ColumnCreatorAndRequirer() r1 = [resource_creator(str(i), c1) for i in range(5)] r2 = [resource_creator(str(i), c2) for i in range(5, 10)] + [resource_creator("1", c2)] - - manager.add_resources(c1, r1, []) + manager.add_resources( + component=c1, + initializer=c1.initialize_test_columns, + resources=r1, + required_resources=[], + ) error_message = ( f"Component '{c2.name}' is attempting to register resource" f" '{resource_type}.1' but it is already registered by '{c1.name}'." 
) with pytest.raises(ResourceError, match=error_message): - manager.add_resources(c2, r2, []) + manager.add_resources( + component=c2, + initializer=c2.initialize_test_column_4, + resources=r2, + required_resources=[], + ) def test_resource_manager_sorted_nodes_two_node_cycle( - manager: ResourceManager, randomness_stream: RandomnessStream + manager: ResourceManager, randomness_stream: RandomnessStream, mocker: MockerFixture ) -> None: - manager.add_resources(ColumnCreatorAndRequirer(), ["c_1"], [randomness_stream]) - manager.add_resources(randomness_stream.component, [randomness_stream], ["c_1"]) + component = ColumnCreatorAndRequirer() + column = Column("c_1", mocker.Mock()) + manager.add_resources( + component=component, + initializer=component.initialize_test_column_4, + resources=[column], + required_resources=[randomness_stream], + ) + manager.add_resources( + component=randomness_stream.component, + initializer=None, + resources=randomness_stream, + required_resources=[column], + ) with pytest.raises(ResourceError, match="The resource pool contains at least one cycle"): _ = manager.sorted_nodes def test_resource_manager_sorted_nodes_three_node_cycle( - manager: ResourceManager, randomness_stream: RandomnessStream + manager: ResourceManager, + randomness_stream: RandomnessStream, + mocker: MockerFixture, ) -> None: - pipeline = Pipeline("some_pipeline", ColumnRequirer()) - - manager.add_resources(ColumnCreatorAndRequirer(), ["c_1"], [randomness_stream]) - manager.add_resources(pipeline.component, [pipeline], ["c_1"]) - manager.add_resources(randomness_stream.component, [randomness_stream], [pipeline]) + pipeline = Pipeline("some_pipeline", mocker.Mock()) + component = ColumnCreatorAndRequirer() + column = Column("c_1", component) + manager.add_resources( + component=component, + initializer=component.initialize_test_column_4, + resources=[column], + required_resources=[randomness_stream], + ) + manager.add_resources( + component=pipeline.component, + 
initializer=None, + resources=pipeline, + required_resources=[column], + ) + manager.add_resources( + component=randomness_stream.component, + initializer=None, + resources=randomness_stream, + required_resources=[pipeline], + ) with pytest.raises(ResourceError, match="The resource pool contains at least one cycle"): _ = manager.sorted_nodes @@ -175,36 +340,53 @@ def test_resource_manager_sorted_nodes_three_node_cycle( def test_resource_manager_sorted_nodes_large_cycle(manager: ResourceManager) -> None: component = ColumnCreator() for i in range(10): - manager.add_resources(component, [f"c_{i}"], [f"c_{i % 10}"]) + resource = Resource("test", f"resource{i}", component) + dependency = Resource("test", f"resource{(i + 1) % 10}", component) + manager.add_resources( + component=component, + initializer=None, + resources=resource, + required_resources=[dependency], + ) - with pytest.raises(ResourceError, match="cycle"): + with pytest.raises(ResourceError, match="The resource pool contains at least one cycle"): _ = manager.sorted_nodes -def test_large_dependency_chain(manager: ResourceManager) -> None: +def test_large_dependency_chain(manager: ResourceManager, mocker: MockerFixture) -> None: component = ColumnCreator() for i in range(9, 0, -1): - manager.add_resources(component, [f"c_{i}"], [f"c_{i - 1}"]) - manager.add_resources(component, ["c_0"], []) + manager.add_resources( + component=component, + initializer=component.initialize_test_columns, + resources=AttributePipeline(f"c_{i}", component), + required_resources=[AttributePipeline(f"c_{i-1}", mocker.Mock())], + ) + manager.add_resources( + component=component, + initializer=component.initialize_test_columns, + resources=AttributePipeline("c_0", component), + required_resources=[], + ) for i, resource in enumerate(manager.sorted_nodes): - assert str(resource) == f"(column.c_{i})" + assert str(resource) == f"(attribute.c_{i})" def test_resource_manager_sorted_nodes_acyclic( manager_with_resources: ResourceManager, 
) -> None: - n = [str(node) for node in manager_with_resources.sorted_nodes] + nodes = [str(node) for node in manager_with_resources.sorted_nodes] - assert n.index("(column.A)") < n.index("(stream.B)") - assert n.index("(column.A)") < n.index("(value.C)") - assert n.index("(column.A)") < n.index("(column.D)") + assert nodes.index("(attribute.A)") < nodes.index("(stream.B)") + assert nodes.index("(attribute.A)") < nodes.index("(value.C)") + assert nodes.index("(attribute.A)") < nodes.index("(attribute.D)") - assert n.index("(stream.B)") < n.index("(column.D)") - assert n.index("(value.C)") < n.index("(column.D)") + assert nodes.index("(stream.B)") < nodes.index("(attribute.D)") + assert nodes.index("(value.C)") < nodes.index("(attribute.D)") - assert n.index("(stream.B)") < n.index(f"(null.0)") + assert nodes.index("(stream.B)") < nodes.index(f"(null.0)") def test_get_population_initializers( @@ -213,6 +395,6 @@ def test_get_population_initializers( initializers = manager_with_resources.get_population_initializers() assert len(initializers) == 3 - assert initializers[0] == resource_producers[0].on_initialize_simulants - assert resource_producers[3].on_initialize_simulants in initializers - assert resource_producers[4].on_initialize_simulants in initializers + assert initializers[0] == resource_producers[0].initialize_A + assert resource_producers[3].initialize_D in initializers + assert resource_producers[4].initialize_nothing in initializers diff --git a/tests/framework/resource/test_resource.py b/tests/framework/resource/test_resource.py index 9f06d37da..459769562 100644 --- a/tests/framework/resource/test_resource.py +++ b/tests/framework/resource/test_resource.py @@ -1,30 +1,7 @@ -from datetime import datetime - -import pytest - from tests.helpers import ColumnCreator -from vivarium.framework.randomness import RandomnessStream -from vivarium.framework.randomness.index_map import IndexMap from vivarium.framework.resource import Resource -from 
vivarium.framework.resource.resource import Column, NullResource -from vivarium.framework.values import Pipeline, ValueModifier, ValueSource def test_resource_id() -> None: resource = Resource("value_source", "test", ColumnCreator()) assert resource.resource_id == "value_source.test" - - -@pytest.mark.parametrize( - "resource, is_initialized", - [ - (Pipeline("foo"), False), - (ValueSource(Pipeline("bar"), lambda: 1, ColumnCreator()), False), - (ValueModifier(Pipeline("baz"), lambda: 1, ColumnCreator()), False), - (Column("foo", ColumnCreator()), True), - (RandomnessStream("bar", lambda: datetime.now(), 1, IndexMap()), False), - (NullResource(0, ColumnCreator()), True), - ], -) -def test_resource_is_initialized(resource: Resource, is_initialized: bool) -> None: - assert resource.is_initialized == is_initialized diff --git a/tests/framework/resource/test_resource_group.py b/tests/framework/resource/test_resource_group.py index 3f7975aff..9126c7012 100644 --- a/tests/framework/resource/test_resource_group.py +++ b/tests/framework/resource/test_resource_group.py @@ -1,34 +1,49 @@ from datetime import datetime import pytest +from pytest_mock import MockerFixture -from tests.helpers import ColumnCreator, ColumnRequirer +from tests.helpers import ColumnCreator +from vivarium.framework.population.manager import SimulantData from vivarium.framework.randomness import RandomnessStream from vivarium.framework.randomness.index_map import IndexMap from vivarium.framework.resource.exceptions import ResourceError from vivarium.framework.resource.group import ResourceGroup from vivarium.framework.resource.resource import Column, NullResource, Resource -from vivarium.framework.values import Pipeline, ValueModifier, ValueSource +from vivarium.framework.values import AttributePipeline, Pipeline, ValueModifier, ValueSource -def test_resource_group() -> None: +@pytest.mark.parametrize("resource_type", ["column", "resource"]) +def test_resource_group(resource_type: str, mocker: 
MockerFixture) -> None: component = ColumnCreator() - resources = [Column(str(i), component) for i in range(5)] - r_dependencies = [ - Column("an_interesting_column", None), + resources: list[Column] | Resource + if resource_type == "column": + resources = [Column(f"resource_{i}", component) for i in range(5)] + else: + resources = Resource("test", "some_resource", component) + r_required_resources = [ + AttributePipeline("an_interesting_attribute", None), Pipeline("baz"), - RandomnessStream("bar", lambda: datetime.now(), 1, IndexMap()), + RandomnessStream("bar", lambda: datetime.now(), 1, IndexMap(), mocker.Mock()), ValueSource(Pipeline("foo"), lambda: 1, None), ] - rg = ResourceGroup(resources, r_dependencies) + rg = ResourceGroup( + initialized_resources=resources, + required_resources=r_required_resources, + initializer=component.initialize_test_columns, + ) assert rg.component == component - assert rg.type == "column" - assert rg.names == [f"column.{i}" for i in range(5)] - assert rg.initializer == component.on_initialize_simulants - assert rg.dependencies == [ - "column.an_interesting_column", + assert rg.type == ("column" if resource_type == "column" else "test") + assert rg.names == ( + [f"column.{res.name}" for res in resources] + if isinstance(resources, list) + else ["test.some_resource"] + ) + assert rg.initializer == component.initialize_test_columns + assert rg.required_resources == [ + "attribute.an_interesting_attribute", "value.baz", "stream.bar", "value_source.foo", @@ -37,37 +52,49 @@ @pytest.mark.parametrize( - "resource, has_initializer", + "resource", [ - (Pipeline("foo"), False), - (ValueSource(Pipeline("bar"), lambda: 1, ColumnCreator()), False), - (ValueModifier(Pipeline("baz"), lambda: 1, ColumnCreator()), False), - (Column("foo", ColumnCreator()), True), - (RandomnessStream("bar", lambda: datetime.now(), 1, IndexMap()), False), - (NullResource(0, ColumnCreator()), True), + Pipeline("foo", ColumnCreator()), 
+ AttributePipeline("foo", ColumnCreator()), + ValueSource(Pipeline("bar", ColumnCreator()), lambda: 1, ColumnCreator()), + ValueModifier(Pipeline("baz", ColumnCreator()), lambda: 1, ColumnCreator()), + Column("foo", ColumnCreator()), + RandomnessStream("bar", lambda: datetime.now(), 1, IndexMap(), ColumnCreator()), + NullResource(0, ColumnCreator()), ], ) -def test_resource_group_is_initializer(resource: Resource, has_initializer: bool) -> None: - rg = ResourceGroup([resource], [Column("bar", None)]) - assert rg.is_initialized == has_initializer +@pytest.mark.parametrize("initialized", [True, False]) +def test_resource_group_is_initializer(resource: Resource, initialized: bool) -> None: + def some_initializer(pop_data: SimulantData) -> None: + pass + + rg = ResourceGroup( + resource, [Resource("test", "bar", None)], some_initializer if initialized else None + ) + assert rg.is_initialized == initialized def test_resource_group_with_no_resources() -> None: with pytest.raises(ResourceError, match="must have at least one resource"): - _ = ResourceGroup([], [Column("foo", None)]) + _ = ResourceGroup([], [Resource("test", "foo", None)], None) def test_resource_group_with_multiple_components() -> None: + # This test is not terribly relevant since ResourceGroup now only accepts a list + # of Columns or a single Resource. We keep it around in case that changes. resources = [ - ValueModifier(Pipeline("foo"), lambda: 1, ColumnCreator()), - ValueSource(Pipeline("bar"), lambda: 2, ColumnRequirer()), + Column("foo", ColumnCreator()), + Column("bar", ColumnCreator()), ] with pytest.raises(ResourceError, match="resources must have the same component"): - _ = ResourceGroup(resources, []) + _ = ResourceGroup(resources, [], None) def test_resource_group_with_multiple_resource_types() -> None: + + # This test is not terribly relevant since ResourceGroup now only accepts a list + # of Columns or a single Resource. We keep it around in case that changes. 
component = ColumnCreator() resources = [ ValueModifier(Pipeline("foo"), lambda: 1, component), @@ -75,4 +102,29 @@ def test_resource_group_with_multiple_resource_types() -> None: ] with pytest.raises(ResourceError, match="resources must be of the same type"): - _ = ResourceGroup(resources, []) + _ = ResourceGroup(resources, [], None) # type: ignore[arg-type] + + +def test_set_required_resources(mocker: MockerFixture) -> None: + some_attribute = AttributePipeline("some_attribute", mocker.Mock()) + some_other_attribute = AttributePipeline("some_other_attribute", mocker.Mock()) + resource = Resource("test", "some_resource", mocker.Mock()) + required_resources: list[AttributePipeline | str] = [ + some_attribute, + "some_other_attribute", + ] + + rg = ResourceGroup( + initialized_resources=resource, + required_resources=required_resources, + initializer=None, + ) + assert rg._required_resources == [some_attribute, "some_other_attribute"] + # Mock the attribute pipelines dict + attribute_pipelines = { + "some_attribute": some_attribute, + "some_other_attribute": some_other_attribute, + } + rg.set_required_resources(attribute_pipelines) + # Check that the 'some_other_attribute' string has been replaced by the pipeline + assert rg._required_resources == [some_attribute, some_other_attribute] diff --git a/tests/framework/results/helpers.py b/tests/framework/results/helpers.py index bb6dea638..61da2704a 100644 --- a/tests/framework/results/helpers.py +++ b/tests/framework/results/helpers.py @@ -12,7 +12,6 @@ from vivarium.framework.results import VALUE_COLUMN from vivarium.framework.results.observer import Observer from vivarium.framework.results.stratification import Stratification -from vivarium.framework.values import Pipeline from vivarium.types import ScalarMapper, VectorMapper NAME = "hogwarts_house" @@ -29,13 +28,12 @@ BIN_LABELS = ["meh", "somewhat", "very", "extra"] BIN_SILLY_BIN_EDGES = [0, 20, 40, 60, 90] -COL_NAMES = ["house", "familiar", "power_level", 
"tracked"] +COL_NAMES = ["house", "familiar", "power_level"] FAMILIARS = ["owl", "cat", "gecko", "banana_slug", "unladen_swallow"] POWER_LEVELS = [20, 40, 60, 80] POWER_LEVEL_BIN_EDGES = [0, 25, 50, 75, 100] POWER_LEVEL_GROUP_LABELS = ["low", "medium", "high", "very high"] -TRACKED_STATUSES = [True, False] -RECORDS = list(itertools.product(HOUSE_CATEGORIES, FAMILIARS, POWER_LEVELS, TRACKED_STATUSES)) +RECORDS = list(itertools.product(HOUSE_CATEGORIES, FAMILIARS, POWER_LEVELS)) BASE_POPULATION = pd.DataFrame(data=RECORDS, columns=COL_NAMES) HARRY_POTTER_CONFIG = { @@ -58,39 +56,36 @@ class Hogwarts(Component): - @property - def columns_created(self) -> list[str]: - return [ - "student_id", - "student_house", - "familiar", - "power_level", - "house_points", - "quidditch_wins", - "exam_score", - "spell_power", - "potion_power", - ] - def setup(self, builder: Builder) -> None: - self.grade = builder.value.register_value_producer( + builder.value.register_attribute_producer( "grade", - source=lambda index: self.population_view.get(index)["exam_score"].map( + source=lambda index: self.population_view.get_attributes(index, "exam_score").map( lambda x: x // 10 ), - requires_columns=["exam_score"], + required_resources=["exam_score"], ) - self.double_power = builder.value.register_value_producer( + builder.value.register_attribute_producer( "double_power", - source=lambda index: self.population_view.get(index)["power_level"] * 2, - requires_columns=["power_level"], + source=lambda index: self.population_view.get_attributes(index, "power_level") + * 2, + required_resources=["power_level"], + ) + builder.population.register_initializer( + initializer=self.initialize_hogwarts, + columns=[ + "student_id", + "student_house", + "familiar", + "power_level", + "house_points", + "quidditch_wins", + "exam_score", + "spell_power", + "potion_power", + ], ) - def grade_source(self, index: pd.Index[int]) -> pd.Series[str]: - pass_mask = self.population_view.get(index)["exam_score"] > 6 - 
return pass_mask.map({True: "pass", False: "fail"}) - - def on_initialize_simulants(self, pop_data: SimulantData) -> None: + def initialize_hogwarts(self, pop_data: SimulantData) -> None: size = len(pop_data.index) initialization_data = pd.DataFrame( { @@ -110,7 +105,15 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None: self.population_view.update(initialization_data) def on_time_step(self, pop_data: Event) -> None: - update = self.population_view.get(pop_data.index) + update = self.population_view.get_private_columns( + pop_data.index, + [ + "student_house", + "power_level", + "familiar", + "exam_score", + ], + ) update["house_points"] = 0 update["quidditch_wins"] = 0 # House points are stratified by 'student_house' and 'power_level_group'. @@ -141,9 +144,7 @@ def register_observations(self, builder: Builder) -> None: name="house_points", aggregator_sources=["house_points"], aggregator=lambda df: df.sum(), - requires_columns=[ - "house_points", - ], + requires_attributes=["house_points"], results_formatter=results_formatter, ) @@ -154,12 +155,10 @@ class FullyFilteredHousePointsObserver(Observer): def register_observations(self, builder: Builder) -> None: builder.results.register_adding_observation( name="house_points", - pop_filter="tracked==True & power_level=='one billion'", + pop_filter="power_level=='one billion'", aggregator_sources=["house_points"], aggregator=lambda df: df.sum(), - requires_columns=[ - "house_points", - ], + requires_attributes=["house_points"], ) @@ -173,9 +172,7 @@ def register_observations(self, builder: Builder) -> None: aggregator=lambda df: df.sum(), excluded_stratifications=["student_house", "power_level_group"], additional_stratifications=["familiar"], - requires_columns=[ - "quidditch_wins", - ], + requires_attributes=["quidditch_wins"], results_formatter=results_formatter, ) @@ -189,9 +186,7 @@ def register_observations(self, builder: Builder) -> None: aggregator_sources=["quidditch_wins"], aggregator=lambda 
df: df.sum(), excluded_stratifications=["student_house", "power_level_group"], - requires_columns=[ - "quidditch_wins", - ], + requires_attributes=["quidditch_wins"], results_formatter=results_formatter, ) @@ -221,7 +216,7 @@ class ExamScoreObserver(Observer): def register_observations(self, builder: Builder) -> None: builder.results.register_concatenating_observation( name="exam_score", - requires_columns=["student_id", "student_house", "exam_score"], + requires_attributes=["student_id", "student_house", "exam_score"], ) @@ -231,8 +226,8 @@ class CatBombObserver(Observer): def register_observations(self, builder: Builder) -> None: builder.results.register_stratified_observation( name="cat_bomb", - pop_filter="familiar=='cat' and tracked==True", - requires_columns=["familiar"], + pop_filter="familiar=='cat'", + requires_attributes=["familiar"], results_updater=self.update_cats, excluded_stratifications=["power_level_group"], aggregator_sources=["student_house"], @@ -259,7 +254,7 @@ def __init__(self) -> None: def register_observations(self, builder: Builder) -> None: builder.results.register_unstratified_observation( name="valedictorian", - requires_columns=["event_time", "student_id", "exam_score"], + requires_attributes=["event_time", "student_id", "exam_score"], results_gatherer=self.choose_valedictorian, # type: ignore [arg-type] results_updater=self.update_valedictorian, ) @@ -288,10 +283,10 @@ def setup(self, builder: Builder) -> None: builder.results.register_stratification( name="student_house", categories=list(STUDENT_HOUSES), - requires_columns=["student_house"], + requires_attributes=["student_house"], ) builder.results.register_stratification( - name="familiar", categories=FAMILIARS, requires_columns=["familiar"] + name="familiar", categories=FAMILIARS, requires_attributes=["familiar"] ) builder.results.register_binned_stratification( "power_level", @@ -344,8 +339,7 @@ def sorting_hat_bad_mapping(simulant_row: pd.Series[str]) -> str: def 
verify_stratification_added( stratifications: dict[str, Stratification], name: str, - requires_columns: list[str], - requires_values: list[Pipeline], + requires_attributes: list[str], categories: list[str], excluded_categories: list[str], mapper: VectorMapper | ScalarMapper, @@ -355,8 +349,7 @@ def verify_stratification_added( stratification = stratifications.get(name) if not stratification: return False - expected_value_names = sorted(pipeline.name for pipeline in requires_values) - actual_value_names = sorted(pipeline.name for pipeline in stratification.requires_values) + return ( stratification.name == name and sorted(stratification.categories) @@ -364,13 +357,5 @@ def verify_stratification_added( and sorted(stratification.excluded_categories) == sorted(excluded_categories) and stratification.mapper == mapper and stratification.is_vectorized == is_vectorized - and sorted(stratification.requires_columns) == sorted(requires_columns) - and actual_value_names == expected_value_names + and sorted(stratification.requires_attributes) == sorted(requires_attributes) ) - - -# Mock for get_value call for Pipelines, returns a str instead of a Pipeline -def mock_get_value(self: Builder, name: str) -> Pipeline: - if not isinstance(name, str): - raise TypeError("Passed a non-string type to mock get_value(), check your pipelines.") - return Pipeline(name) diff --git a/tests/framework/results/test_context.py b/tests/framework/results/test_context.py index 632b3ea13..376511880 100644 --- a/tests/framework/results/test_context.py +++ b/tests/framework/results/test_context.py @@ -4,7 +4,6 @@ from collections.abc import Callable from datetime import timedelta from typing import Any -from unittest.mock import Mock import numpy as np import pandas as pd @@ -28,9 +27,9 @@ from vivarium.framework.lifecycle import lifecycle_states from vivarium.framework.results import VALUE_COLUMN from vivarium.framework.results.context import ResultsContext +from vivarium.framework.results.interface 
import PopulationFilter from vivarium.framework.results.observation import AddingObservation, ConcatenatingObservation from vivarium.framework.results.stratification import Stratification, get_mapped_col_name -from vivarium.framework.values import Pipeline from vivarium.types import ScalarMapper, VectorMapper @@ -63,14 +62,12 @@ def test_add_stratification_mappers( ) -> None: ctx = ResultsContext() mocker.patch.object(ctx, "excluded_categories", {}) - pipeline = Pipeline("grade") assert NAME not in ctx.stratifications ctx.add_stratification( name=NAME, - requires_columns=NAME_COLUMNS, - requires_values=[pipeline], + requires_attributes=NAME_COLUMNS, categories=HOUSE_CATEGORIES, excluded_categories=None, mapper=mapper, @@ -79,8 +76,7 @@ def test_add_stratification_mappers( assert verify_stratification_added( stratifications=ctx.stratifications, name=NAME, - requires_columns=NAME_COLUMNS, - requires_values=[pipeline], + requires_attributes=NAME_COLUMNS, categories=HOUSE_CATEGORIES, excluded_categories=[], mapper=mapper, @@ -118,8 +114,7 @@ def test_add_stratification_excluded_categories( ctx.add_stratification( name=NAME, - requires_columns=NAME_COLUMNS, - requires_values=[], + requires_attributes=NAME_COLUMNS, categories=HOUSE_CATEGORIES, excluded_categories=excluded_categories, mapper=sorting_hat_vectorized, @@ -129,8 +124,7 @@ def test_add_stratification_excluded_categories( assert verify_stratification_added( stratifications=ctx.stratifications, name=NAME, - requires_columns=NAME_COLUMNS, - requires_values=[], + requires_attributes=NAME_COLUMNS, categories=HOUSE_CATEGORIES, excluded_categories=excluded_categories, mapper=sorting_hat_vectorized, @@ -185,8 +179,7 @@ def test_add_stratification_raises( # Register a stratification to test against duplicate stratifications ctx.add_stratification( name="duplicate_name", - requires_columns=["foo"], - requires_values=[], + requires_attributes=["foo"], categories=["bar"], excluded_categories=None, 
mapper=sorting_hat_serial, @@ -195,8 +188,7 @@ def test_add_stratification_raises( with pytest.raises(ValueError, match=re.escape(msg_match)): ctx.add_stratification( name=name, - requires_columns=NAME_COLUMNS, - requires_values=[], + requires_attributes=NAME_COLUMNS, categories=categories, excluded_categories=excluded_categories, mapper=sorting_hat_vectorized, @@ -209,27 +201,27 @@ def test_add_stratification_raises( [ { "name": "living_person_time", - "pop_filter": 'alive == "alive" and undead == False', - "requires_columns": ["alive", "undead"], + "population_filter": PopulationFilter("is_alive == True and undead == False"), + "requires_attributes": ["is_alive", "undead"], "when": lifecycle_states.COLLECT_METRICS, }, { "name": "undead_person_time", - "pop_filter": "undead == True", - "requires_columns": ["undead"], + "population_filter": PopulationFilter("undead == True"), + "requires_attributes": ["undead"], "when": lifecycle_states.TIME_STEP_PREPARE, }, ], ids=["valid_on_collect_metrics", "valid_on_time_step__prepare"], ) -def test_register_observation(kwargs: Any) -> None: +def test_register_observation(kwargs: dict[str, Any]) -> None: ctx = ResultsContext() assert len(ctx.grouped_observations) == 0 kwargs["results_formatter"] = lambda: None kwargs["stratifications"] = tuple() kwargs["aggregator_sources"] = [] kwargs["aggregator"] = len - kwargs["requires_values"] = [] + kwargs["requires_attributes"] = [] ctx.register_observation( observation_type=AddingObservation, **kwargs, @@ -242,10 +234,9 @@ def test_register_observation_duplicate_name_raises() -> None: ctx.register_observation( observation_type=AddingObservation, name="some-observation-name", - pop_filter="some-pop-filter", + population_filter=PopulationFilter("some-pop-filter"), when="some-when", - requires_columns=[], - requires_values=[], + requires_attributes=[], results_formatter=lambda df: df, stratifications=(), aggregator_sources=[], @@ -258,10 +249,9 @@ def 
test_register_observation_duplicate_name_raises() -> None: ctx.register_observation( observation_type=ConcatenatingObservation, name="some-observation-name", - pop_filter="some-other-pop-filter", + population_filter=PopulationFilter("some-other-pop-filter"), when="some-other-when", - requires_columns=[], - requires_values=[], + requires_attributes=[], stratifications=None, ) @@ -290,10 +280,12 @@ def test_adding_observation_gather_results( aggregator: Callable[..., int | float], stratifications: list[str], event: Event, + mocker: MockerFixture, ) -> None: """Test cases where every stratification is in gather_results. Checks for existence and correctness of results""" ctx = ResultsContext() + mocker.patch.object(ctx, "get_tracked_query", return_value="", create=True) # Generate population DataFrame population = BASE_POPULATION.copy() @@ -304,8 +296,7 @@ def test_adding_observation_gather_results( if "house" in stratifications: ctx.add_stratification( name="house", - requires_columns=["house"], - requires_values=[], + requires_attributes=["house"], categories=HOUSE_CATEGORIES, excluded_categories=None, mapper=None, @@ -314,20 +305,17 @@ def test_adding_observation_gather_results( if "familiar" in stratifications: ctx.add_stratification( name="familiar", - requires_columns=["familiar"], - requires_values=[], + requires_attributes=["familiar"], categories=FAMILIARS, excluded_categories=None, mapper=None, is_vectorized=True, ) - pop_filter = "tracked==True" observation = ctx.register_observation( observation_type=AddingObservation, name="foo", - pop_filter=pop_filter, - requires_columns=aggregator_sources, - requires_values=[], + population_filter=PopulationFilter(), + requires_attributes=aggregator_sources, aggregator_sources=aggregator_sources, aggregator=aggregator, stratifications=tuple(stratifications), @@ -335,8 +323,7 @@ def test_adding_observation_gather_results( results_formatter=lambda: None, ) - filtered_pop = population.query(pop_filter) - groups = 
filtered_pop.groupby(stratifications) + groups = population.groupby(stratifications) if aggregator == sum: power_level_sums = groups[aggregator_sources].sum().squeeze() assert len(power_level_sums.unique()) == 1 @@ -362,9 +349,12 @@ def test_adding_observation_gather_results( assert i == 1 -def test_concatenating_observation_gather_results(event: Event) -> None: +def test_concatenating_observation_gather_results( + event: Event, mocker: MockerFixture +) -> None: ctx = ResultsContext() + mocker.patch.object(ctx, "get_tracked_query", return_value="", create=True) # Generate population DataFrame population = BASE_POPULATION.copy() @@ -376,20 +366,19 @@ def test_concatenating_observation_gather_results(event: Event) -> None: ) lifecycle_state = lifecycle_states.COLLECT_METRICS - pop_filter = "house=='hufflepuff'" + population_filter = PopulationFilter(query="house=='hufflepuff'") included_cols = ["familiar", "house"] observation = ctx.register_observation( observation_type=ConcatenatingObservation, name="foo", - pop_filter=pop_filter, + population_filter=population_filter, when=lifecycle_state, - requires_columns=included_cols, - requires_values=[], + requires_attributes=included_cols, results_formatter=lambda _, __: pd.DataFrame(), stratifications=None, ) - filtered_pop = population.query(pop_filter) + filtered_pop = population.query(population_filter.query) i = 0 for result, _measure, _updater in ctx.gather_results( @@ -436,10 +425,12 @@ def test_gather_results_partial_stratifications_in_results( aggregator: Callable[..., int | float], stratifications: list[str], event: Event, + mocker: MockerFixture, ) -> None: """Test cases where not all stratifications are observed for gather_results. 
This looks for existence of unobserved stratifications and ensures their values are 0""" ctx = ResultsContext() + mocker.patch.object(ctx, "get_tracked_query", return_value="", create=True) # Generate population DataFrame population = BASE_POPULATION.copy() @@ -450,8 +441,7 @@ def test_gather_results_partial_stratifications_in_results( if "house" in stratifications: ctx.add_stratification( name="house", - requires_columns=["house"], - requires_values=[], + requires_attributes=["house"], categories=HOUSE_CATEGORIES, excluded_categories=None, mapper=None, @@ -463,8 +453,7 @@ def test_gather_results_partial_stratifications_in_results( if "familiar" in stratifications: ctx.add_stratification( name="familiar", - requires_columns=["familiar"], - requires_values=[], + requires_attributes=["familiar"], categories=FAMILIARS, excluded_categories=None, mapper=None, @@ -477,9 +466,8 @@ def test_gather_results_partial_stratifications_in_results( observation = ctx.register_observation( observation_type=AddingObservation, name=name, - pop_filter="tracked==True", - requires_columns=aggregator_sources, - requires_values=[], + population_filter=PopulationFilter(), + requires_attributes=aggregator_sources, aggregator_sources=aggregator_sources, aggregator=aggregator, stratifications=tuple(stratifications), @@ -496,11 +484,12 @@ def test_gather_results_partial_stratifications_in_results( assert (unladen_results[VALUE_COLUMN] == 0).all() -def test_gather_results_with_empty_pop_filter(event: Event) -> None: +def test_gather_results_with_empty_pop_filter(event: Event, mocker: MockerFixture) -> None: """Test case where pop_filter filters to an empty population. gather_results should return None. 
""" ctx = ResultsContext() + mocker.patch.object(ctx, "get_tracked_query", return_value="", create=True) # Generate population DataFrame population = BASE_POPULATION.copy() @@ -509,9 +498,8 @@ def test_gather_results_with_empty_pop_filter(event: Event) -> None: observation = ctx.register_observation( observation_type=AddingObservation, name="wizard_count", - pop_filter="house == 'durmstrang'", - requires_columns=["house"], - requires_values=[], + population_filter=PopulationFilter("house == 'durmstrang'"), + requires_attributes=["house"], aggregator_sources=[], aggregator=len, stratifications=tuple(), @@ -525,9 +513,10 @@ def test_gather_results_with_empty_pop_filter(event: Event) -> None: assert not result -def test_gather_results_with_no_stratifications(event: Event) -> None: +def test_gather_results_with_no_stratifications(event: Event, mocker: MockerFixture) -> None: """Test case where we have no stratifications. gather_results should return one value.""" ctx = ResultsContext() + mocker.patch.object(ctx, "get_tracked_query", return_value="", create=True) # Generate population DataFrame population = BASE_POPULATION.copy() @@ -536,9 +525,8 @@ def test_gather_results_with_no_stratifications(event: Event) -> None: observation = ctx.register_observation( observation_type=AddingObservation, name="wizard_count", - pop_filter="", - requires_columns=[], - requires_values=[], + population_filter=PopulationFilter(), + requires_attributes=[], aggregator_sources=None, aggregator=len, stratifications=tuple(), @@ -560,10 +548,11 @@ def test_gather_results_with_no_stratifications(event: Event) -> None: ) -def test_bad_aggregator_stratification(event: Event) -> None: +def test_bad_aggregator_stratification(event: Event, mocker: MockerFixture) -> None: """Test if an exception gets raised when a stratification that doesn't exist is attempted to be used, as expected.""" ctx = ResultsContext() + mocker.patch.object(ctx, "get_tracked_query", return_value="", create=True) # Generate 
population DataFrame population = BASE_POPULATION.copy() @@ -572,8 +561,7 @@ def test_bad_aggregator_stratification(event: Event) -> None: # Set up stratifications ctx.add_stratification( name="house", - requires_columns=["house"], - requires_values=[], + requires_attributes=["house"], categories=HOUSE_CATEGORIES, excluded_categories=None, mapper=None, @@ -581,8 +569,7 @@ def test_bad_aggregator_stratification(event: Event) -> None: ) ctx.add_stratification( name="familiar", - requires_columns=["familiar"], - requires_values=[], + requires_attributes=["familiar"], categories=FAMILIARS, excluded_categories=None, mapper=None, @@ -591,9 +578,8 @@ def test_bad_aggregator_stratification(event: Event) -> None: observation = ctx.register_observation( observation_type=AddingObservation, name="this_shouldnt_work", - pop_filter="", - requires_columns=[], - requires_values=[], + population_filter=PopulationFilter(), + requires_attributes=[], aggregator_sources=[], aggregator=sum, stratifications=("house", "height"), # `height` is not a stratification @@ -623,9 +609,8 @@ def test_get_observations( ctx = ResultsContext() register_observation_kwargs = { "observation_type": AddingObservation, - "pop_filter": "", - "requires_columns": [], - "requires_values": [], + "population_filter": PopulationFilter(), + "requires_attributes": [], "results_formatter": lambda: None, "stratifications": (), "aggregator_sources": None, @@ -652,7 +637,6 @@ def test_get_observations( assert [obs.name for obs in ctx.get_observations(event)] == expected_observations -@pytest.mark.parametrize("resource_type", ["columns", "values"]) @pytest.mark.parametrize( "observation_names, stratification_names, expected_resources", [ @@ -670,18 +654,21 @@ def test_get_observations( "neither", ], ) -def test_get_required_resources( - resource_type: str, +@pytest.mark.parametrize("include_untracked", [True, False]) +def test_get_required_attributes( observation_names: list[str], stratification_names: list[str], 
expected_resources: set[str], + include_untracked: bool, + mocker: MockerFixture, ) -> None: ctx = ResultsContext() + mocker.patch.object(ctx, "get_tracked_query", return_value='foo == "bar"', create=True) all_observations = {} register_observation_kwargs = { "observation_type": AddingObservation, - "pop_filter": "", + "population_filter": PopulationFilter(include_untracked=include_untracked), "when": lifecycle_states.COLLECT_METRICS, "results_formatter": lambda: None, "stratifications": (), @@ -689,32 +676,19 @@ def test_get_required_resources( "aggregator": len, } - def get_required_resources_kwargs( - resource_type: str, resources: list[str] - ) -> dict[str, list[str] | list[Pipeline]]: - if resource_type == "columns": - return {"requires_columns": resources, "requires_values": []} - elif resource_type == "values": - return { - "requires_values": [Pipeline(r) for r in resources], - "requires_columns": [], - } - else: - raise ValueError(f"Unknown resource_type: {resource_type}") - all_observations["obs1"] = ctx.register_observation( name="obs1", - **get_required_resources_kwargs(resource_type, ["x", "y"]), # type: ignore[arg-type] + requires_attributes=["x", "y"], **register_observation_kwargs, # type: ignore[arg-type] ) all_observations["obs2"] = ctx.register_observation( name="obs2", - **get_required_resources_kwargs(resource_type, ["y", "z"]), # type: ignore[arg-type] + requires_attributes=["y", "z"], **register_observation_kwargs, # type: ignore[arg-type] ) all_observations["obs3"] = ctx.register_observation( name="obs3", - **get_required_resources_kwargs(resource_type, ["w"]), # type: ignore[arg-type] + requires_attributes=["w"], **register_observation_kwargs, # type: ignore[arg-type] ) @@ -727,28 +701,54 @@ def get_required_resources_kwargs( } all_stratifications["strat1"] = Stratification( name="strat1", - **get_required_resources_kwargs(resource_type, ["x", "y"]), # type: ignore[arg-type] + requires_attributes=["x", "y"], **stratification_kwargs, # type: 
ignore[arg-type] ) all_stratifications["strat2"] = Stratification( name="strat2", - **get_required_resources_kwargs(resource_type, ["x", "v"]), # type: ignore[arg-type] + requires_attributes=["x", "v"], **stratification_kwargs, # type: ignore[arg-type] ) observations = [all_observations[name] for name in observation_names] stratifications = [all_stratifications[name] for name in stratification_names] - if resource_type == "columns": - actual_columns = ctx.get_required_columns(observations, stratifications) - assert set(actual_columns) == {"tracked"} | expected_resources - elif resource_type == "values": - actual_columns = [ - p.name for p in ctx.get_required_values(observations, stratifications) - ] - assert set(actual_columns) == expected_resources - else: - raise ValueError(f"Unknown resource_type: {resource_type}") + actual_columns = ctx.get_required_attributes(observations, stratifications) + if observations and not include_untracked: + expected_resources = expected_resources.union({"foo"}) + assert set(actual_columns) == expected_resources + + +def test_get_required_attributes_columns_from_query(mocker: MockerFixture) -> None: + """Tests that columns used in queries are required regardless of requires_attributes.""" + + ctx = ResultsContext() + mocker.patch.object( + ctx, "get_tracked_query", return_value="pet == 'cat' and lives < 9", create=True + ) + + observation = ctx.register_observation( + observation_type=AddingObservation, + name="obs_with_query", + population_filter=PopulationFilter( + "color in ['black', 'white'] or name == 'Garfield'" + ), + when=lifecycle_states.COLLECT_METRICS, + requires_attributes=["foo", "bar"], # does NOT include any of the query columns + results_formatter=lambda: None, + stratifications=(), + aggregator_sources=None, + aggregator=len, + ) + + assert set(ctx.get_required_attributes([observation], [])) == { + "pet", + "lives", + "color", + "name", + "foo", + "bar", + } @@ -756,16 +756,26 @@ def 
get_required_resources_kwargs( ['familiar=="cat"', 'familiar=="spaghetti_yeti"', ""], ids=["pop_filter", "pop_filter_empties_dataframe", "no_pop_filter"], ) -def test__filter_population(pop_filter: str) -> None: +@pytest.mark.parametrize("include_untracked", [True, False]) +@pytest.mark.parametrize("tracked_query", ["house=='hufflepuff'", "house=='whitehouse'", ""]) +def test__filter_population( + pop_filter: str, include_untracked: bool, tracked_query: str, mocker: MockerFixture +) -> None: population = BASE_POPULATION.copy() + ctx = ResultsContext() + mocker.patch.object(ctx, "get_tracked_query", return_value=tracked_query, create=True) - filtered_pop = ResultsContext()._filter_population( - population=population, pop_filter=pop_filter + filtered_pop = ctx._filter_population( + population=population, + population_filter=PopulationFilter(pop_filter, include_untracked=include_untracked), ) expected = population.copy() if pop_filter: familiar = pop_filter.split("==")[1].strip('"') expected = expected[expected["familiar"] == familiar] + if not include_untracked and tracked_query: + house = tracked_query.split("==")[1].strip("'") + expected = expected[expected["house"] == house] assert filtered_pop.equals(expected) @@ -781,7 +791,6 @@ def test__filter_population(pop_filter: str) -> None: def test__drop_na_stratifications(stratifications: tuple[str, ...]) -> None: population = BASE_POPULATION.copy() population["new_col1"] = "new_value1" - population.loc[population["tracked"] == True, "new_col1"] = np.nan population["new_col2"] = "new_value2" population.loc[population["new_col1"].notna(), "new_col2"] = np.nan # Add on the post-stratified columns diff --git a/tests/framework/results/test_interface.py b/tests/framework/results/test_interface.py index d4ee351cb..4037d16eb 100644 --- a/tests/framework/results/test_interface.py +++ b/tests/framework/results/test_interface.py @@ -12,10 +12,10 @@ from tests.framework.results.helpers import BASE_POPULATION, FAMILIARS from 
tests.framework.results.helpers import HOUSE_CATEGORIES as HOUSES -from tests.framework.results.helpers import mock_get_value from vivarium.framework.event import Event from vivarium.framework.lifecycle import lifecycle_states from vivarium.framework.results import ResultsInterface, ResultsManager +from vivarium.framework.results.interface import PopulationFilter from vivarium.framework.results.observation import ( ConcatenatingObservation, StratifiedObservation, @@ -37,9 +37,8 @@ def _silly_mapper(some_series: pd.Series[str]) -> str: return "this was pointless" builder = mocker.Mock() - # Set up mock builder with mocked get_value call for Pipelines - mocker.patch.object(builder, "value.get_value") - builder.value.get_value = MethodType(mock_get_value, builder) + # Set up mock builder with mocked get_attribute call for Pipelines + mocker.patch.object(builder, "value.get_attribute") mgr = ResultsManager() mgr.setup(builder) interface = ResultsInterface(mgr) @@ -53,8 +52,7 @@ def _silly_mapper(some_series: pd.Series[str]) -> str: excluded_categories=["some-unwanted-category"], mapper=_silly_mapper, is_vectorized=False, - requires_columns=["some-column", "some-other-column"], - requires_values=["some-value", "some-other-value"], + requires_attributes=["some-column", "some-other-column"], ) # Check stratification registration @@ -62,42 +60,33 @@ def _silly_mapper(some_series: pd.Series[str]) -> str: assert len(stratifications) == 1 stratification = stratifications["some-name"] - pipeline_names = [pipeline.name for pipeline in stratification.requires_values] assert stratification.name == "some-name" - assert stratification.requires_columns == ["some-column", "some-other-column"] - assert pipeline_names == ["some-value", "some-other-value"] + assert stratification.requires_attributes == ["some-column", "some-other-column"] assert stratification.categories == ["some-category", "some-other-category"] assert stratification.excluded_categories == ["some-unwanted-category"] 
assert stratification.mapper == _silly_mapper assert stratification.is_vectorized is False -@pytest.mark.parametrize( - "target, target_type", [("some-column", "column"), ("some-value", "value")] -) -def test_register_binned_stratification( - target: str, target_type: str, mocker: MockerFixture -) -> None: +def test_register_binned_stratification(mocker: MockerFixture) -> None: mgr = ResultsManager() mgr.logger = logger builder = mocker.Mock() - mocker.patch.object(builder, "value.get_value") - builder.value.get_value = MethodType(mock_get_value, builder) + mocker.patch.object(builder, "value.get_attribute") mgr.setup(builder) - # mgr._results_context.setup(builder) # Check pre-registration stratifications and manager required columns/values assert len(mgr._results_context.stratifications) == 0 + target = "some-attribute" mgr.register_binned_stratification( target=target, binned_column="new-binned-column", bin_edges=[1, 2, 3], labels=["1_to_2", "2_to_3"], excluded_categories=["2_to_3"], - target_type=target_type, some_kwarg="some-kwarg", some_other_kwarg="some-other-kwarg", ) @@ -105,13 +94,10 @@ def test_register_binned_stratification( # Check stratification registration stratifications = mgr._results_context.stratifications assert len(stratifications) == 1 - expected_column_names = [target] if target_type == "column" else [] - expected_value_names = [target] if target_type == "value" else [] stratification = stratifications["new-binned-column"] assert stratification.name == "new-binned-column" - assert stratification.requires_columns == expected_column_names - assert [value.name for value in stratification.requires_values] == expected_value_names + assert stratification.requires_attributes == [target] assert stratification.categories == ["1_to_2"] assert stratification.excluded_categories == ["2_to_3"] # Cannot access the mapper because it's in local scope, so check __repr__ @@ -156,9 +142,6 @@ def test_register_stratified_observation(mocker: MockerFixture) -> 
None: interface = ResultsInterface(mgr) builder = mocker.Mock() builder.configuration.stratification.default = ["default-stratification", "exclude-this"] - # Set up mock builder with mocked get_value call for Pipelines - mocker.patch.object(builder, "value.get_value") - builder.value.get_value = MethodType(mock_get_value, builder) mgr.setup(builder) for strat in [ "default-stratification", @@ -171,7 +154,7 @@ def test_register_stratified_observation(mocker: MockerFixture) -> None: categories=["a", "b", "c"], excluded_categories=[], is_vectorized=True, - requires_columns=[strat], + requires_attributes=[strat], ) assert len(interface._manager._results_context.grouped_observations) == 0 @@ -180,8 +163,7 @@ def test_register_stratified_observation(mocker: MockerFixture) -> None: name="some-name", pop_filter="some-filter", when="some-when", - requires_columns=["some-column", "some-other-column"], - requires_values=["some-value", "some-other-value"], + requires_attributes=["some-column", "some-other-column"], results_updater=lambda _, __: pd.DataFrame(), additional_stratifications=["some-stratification", "some-other-stratification"], excluded_stratifications=["exclude-this"], @@ -195,10 +177,11 @@ def test_register_stratified_observation(mocker: MockerFixture) -> None: grouped_observations = interface._manager._results_context.grouped_observations assert len(grouped_observations) == 1 - filter = list(grouped_observations["some-when"].keys())[0] - stratifications = list(grouped_observations["some-when"][filter])[0] - observations = grouped_observations["some-when"][filter][stratifications] - assert filter == "some-filter" + filter_info = list(grouped_observations["some-when"].keys())[0] + stratifications = list(grouped_observations["some-when"][filter_info])[0] + observations = grouped_observations["some-when"][filter_info][stratifications] + assert filter_info.query == "some-filter" + assert not filter_info.include_untracked assert isinstance(stratifications, tuple) # 
for mypy in following set(stratifications) assert set(stratifications) == { "default-stratification", @@ -209,7 +192,7 @@ def test_register_stratified_observation(mocker: MockerFixture) -> None: for observation in [observations_dict["some-name"], observations[0]]: assert observation.name == "some-name" - assert observation.pop_filter == "some-filter" + assert observation.population_filter.query == "some-filter" assert observation.when == "some-when" assert observation.results_gatherer is not None assert observation.results_updater is not None @@ -225,31 +208,28 @@ def test_register_unstratified_observation(mocker: MockerFixture) -> None: mgr = ResultsManager() interface = ResultsInterface(mgr) builder = mocker.Mock() - # Set up mock builder with mocked get_value call for Pipelines - mocker.patch.object(builder, "value.get_value") - builder.value.get_value = MethodType(mock_get_value, builder) mgr.setup(builder) assert len(interface._manager._results_context.grouped_observations) == 0 interface.register_unstratified_observation( name="some-name", pop_filter="some-filter", when="some-when", - requires_columns=["some-column", "some-other-column"], - requires_values=["some-value", "some-other-value"], + requires_attributes=["some-column", "some-other-column"], results_gatherer=lambda _: pd.DataFrame(), results_updater=lambda _, __: pd.DataFrame(), ) grouped_observations = interface._manager._results_context.grouped_observations assert len(grouped_observations) == 1 - filter = list(grouped_observations["some-when"].keys())[0] - stratifications = list(grouped_observations["some-when"][filter])[0] - observations = grouped_observations["some-when"][filter][stratifications] - assert filter == "some-filter" + filter_info = list(grouped_observations["some-when"].keys())[0] + stratifications = list(grouped_observations["some-when"][filter_info])[0] + observations = grouped_observations["some-when"][filter_info][stratifications] + assert filter_info.query == "some-filter" + 
assert not filter_info.include_untracked assert stratifications is None assert len(observations) == 1 obs = observations[0] assert obs.name == "some-name" - assert obs.pop_filter == "some-filter" + assert obs.population_filter.query == "some-filter" assert obs.when == "some-when" assert obs.results_gatherer is not None assert obs.results_updater is not None @@ -258,19 +238,18 @@ def test_register_unstratified_observation(mocker: MockerFixture) -> None: @pytest.mark.parametrize( ( - "name, pop_filter, aggregator_columns, aggregator, requires_columns, requires_values," + "name, pop_filter, aggregator_columns, aggregator, requires_attributes," " additional_stratifications, excluded_stratifications, when" ), [ ( "living_person_time", - 'alive == "alive" and undead == False', + "is_alive == True and undead == False", [], _silly_aggregator, [], [], [], - [], lifecycle_states.TIME_STEP_CLEANUP, ), ( @@ -281,7 +260,6 @@ def test_register_unstratified_observation(mocker: MockerFixture) -> None: [], [], [], - [], lifecycle_states.TIME_STEP_PREPARE, ), ( @@ -289,7 +267,6 @@ def test_register_unstratified_observation(mocker: MockerFixture) -> None: "undead == True", [], _silly_aggregator, - [], ["fake_pipeline", "another_fake_pipeline"], [], [], @@ -304,8 +281,7 @@ def test_register_adding_observation( pop_filter: str, aggregator_columns: list[str], aggregator: Callable[[pd.DataFrame], int | float | pd.Series[int | float]], - requires_columns: list[str], - requires_values: list[str], + requires_attributes: list[str], additional_stratifications: list[str], excluded_stratifications: list[str], when: str, @@ -314,9 +290,6 @@ def test_register_adding_observation( interface = ResultsInterface(mgr) builder = mocker.Mock() builder.configuration.stratification.default = [] - # Set up mock builder with mocked get_value call for Pipelines - mocker.patch.object(builder, "value.get_value") - builder.value.get_value = MethodType(mock_get_value, builder) mgr.setup(builder) assert 
len(interface._manager._results_context.grouped_observations) == 0 interface.register_adding_observation( @@ -327,8 +300,7 @@ def test_register_adding_observation( excluded_stratifications=excluded_stratifications, aggregator_sources=aggregator_columns, aggregator=aggregator, - requires_columns=requires_columns, - requires_values=requires_values, + requires_attributes=requires_attributes, ) assert len(interface._manager._results_context.grouped_observations) == 1 @@ -347,11 +319,12 @@ def test_register_multiple_adding_observations(mocker: MockerFixture) -> None: aggregator=_silly_aggregator, ) # Test observation gets added - assert len(interface._manager._results_context.grouped_observations) == 1 + grouped_observations = interface._manager._results_context.grouped_observations + assert len(grouped_observations) == 1 assert ( - interface._manager._results_context.grouped_observations[ - lifecycle_states.TIME_STEP_CLEANUP - ]["tracked==True"][()][0].name + grouped_observations[lifecycle_states.TIME_STEP_CLEANUP][PopulationFilter()][()][ + 0 + ].name == "living_person_time" ) @@ -362,23 +335,23 @@ def test_register_multiple_adding_observations(mocker: MockerFixture) -> None: aggregator=_silly_aggregator, ) # Test new observation gets added - assert len(interface._manager._results_context.grouped_observations) == 2 + grouped_observations = interface._manager._results_context.grouped_observations + assert len(grouped_observations) == 2 assert ( - interface._manager._results_context.grouped_observations[ - lifecycle_states.TIME_STEP_CLEANUP - ]["tracked==True"][()][0].name + grouped_observations[lifecycle_states.TIME_STEP_CLEANUP][PopulationFilter()][()][ + 0 + ].name == "living_person_time" ) assert ( - interface._manager._results_context.grouped_observations[ - lifecycle_states.TIME_STEP_PREPARE - ]["undead==True"][()][0].name + grouped_observations[lifecycle_states.TIME_STEP_PREPARE][ + PopulationFilter("undead==True") + ][()][0].name == "undead_person_time" ) 
-@pytest.mark.parametrize("resource_type", ["value", "column"]) -def test_unhashable_pipeline(mocker: MockerFixture, resource_type: str) -> None: +def test_unhashable_pipeline(mocker: MockerFixture) -> None: mgr = ResultsManager() interface = ResultsInterface(mgr) builder = mocker.Mock() @@ -386,13 +359,12 @@ def test_unhashable_pipeline(mocker: MockerFixture, resource_type: str) -> None: mgr.setup(builder) assert len(interface._manager._results_context.grouped_observations) == 0 - with pytest.raises(TypeError, match=f"All required {resource_type}s must be strings"): + with pytest.raises(TypeError, match=f"All required attributes must be strings"): interface.register_adding_observation( name="living_person_time", - pop_filter='alive == "alive" and undead == False', + pop_filter="is_alive == True and undead == False", when=lifecycle_states.TIME_STEP_CLEANUP, - requires_columns=[["bad", "unhashable", "thing"]] if resource_type == "column" else [], # type: ignore[list-item] - requires_values=[["bad", "unhashable", "thing"]] if resource_type == "value" else [], # type: ignore[list-item] + requires_attributes=[["bad", "unhashable", "thing"]], # type: ignore[list-item] additional_stratifications=[], excluded_stratifications=[], aggregator_sources=[], @@ -419,18 +391,17 @@ def test_register_adding_observation_when_options(when: str, mocker: MockerFixtu ) mgr.setup(builder) mgr.population_view = mocker.Mock() - mgr.population_view.subview.return_value = mgr.population_view # type: ignore[attr-defined] - mgr.population_view.get.return_value = BASE_POPULATION.copy() # type: ignore[attr-defined] + mgr.population_view.get_attributes.return_value = BASE_POPULATION.copy() # type: ignore[attr-defined] # register stratifications results_interface.register_stratification( - name="house", categories=HOUSES, is_vectorized=True, requires_columns=["house"] + name="house", categories=HOUSES, is_vectorized=True, requires_attributes=["house"] ) results_interface.register_stratification( 
name="familiar", categories=FAMILIARS, is_vectorized=True, - requires_columns=["familiar"], + requires_attributes=["familiar"], ) aggregator_map = { @@ -450,7 +421,7 @@ def test_register_adding_observation_when_options(when: str, mocker: MockerFixtu when=phase, additional_stratifications=["house", "familiar"], aggregator=aggregator, - requires_columns=["house", "familiar"], + requires_attributes=["house", "familiar"], ) for mock_aggregator in aggregator_map.values(): @@ -466,6 +437,7 @@ def test_register_adding_observation_when_options(when: str, mocker: MockerFixtu ) # Run on_post_setup to initialize the raw_results attribute with 0s and set stratifications mgr.on_post_setup(event) + mgr._results_context.get_tracked_query = mocker.Mock(return_value="") mgr.gather_results(event) for phase, aggregator in aggregator_map.items(): @@ -480,38 +452,35 @@ def test_register_concatenating_observation(mocker: MockerFixture) -> None: interface = ResultsInterface(mgr) builder = mocker.Mock() builder.configuration.stratification.default = [] - # Set up mock builder with mocked get_value call for Pipelines - mocker.patch.object(builder, "value.get_value") - builder.value.get_value = MethodType(mock_get_value, builder) + # Set up mock builder with mocked get_attribute call for Pipelines + mocker.patch.object(builder, "value.get_attribute") mgr.setup(builder) assert len(interface._manager._results_context.grouped_observations) == 0 interface.register_concatenating_observation( name="some-name", pop_filter="some-filter", when="some-when", - requires_columns=["some-column", "some-other-column"], - requires_values=["some-value", "some-other-value"], + requires_attributes=["some-column", "some-other-column"], results_formatter=lambda _, __: pd.DataFrame(), ) grouped_observations = interface._manager._results_context.grouped_observations assert len(grouped_observations) == 1 - filter = list(grouped_observations["some-when"].keys())[0] - stratifications = 
list(grouped_observations["some-when"][filter])[0] - observations = grouped_observations["some-when"][filter][stratifications] - assert filter == "some-filter" + filter_info = list(grouped_observations["some-when"].keys())[0] + stratifications = list(grouped_observations["some-when"][filter_info])[0] + observations = grouped_observations["some-when"][filter_info][stratifications] + assert filter_info.query == "some-filter" + assert not filter_info.include_untracked assert stratifications is None assert len(observations) == 1 obs = observations[0] assert obs.name == "some-name" - assert obs.pop_filter == "some-filter" + assert obs.population_filter.query == "some-filter" assert obs.when == "some-when" assert isinstance(obs, ConcatenatingObservation) - assert obs.included_columns == [ + assert obs.requires_attributes == [ "event_time", "some-column", "some-other-column", - "some-value", - "some-other-value", ] assert obs.results_gatherer is not None assert obs.results_updater is not None diff --git a/tests/framework/results/test_manager.py b/tests/framework/results/test_manager.py index 9f7c3621c..729a18337 100644 --- a/tests/framework/results/test_manager.py +++ b/tests/framework/results/test_manager.py @@ -1,5 +1,4 @@ import re -from types import MethodType import numpy as np import pandas as pd @@ -33,19 +32,20 @@ NoStratificationsQuidditchWinsObserver, QuidditchWinsObserver, ValedictorianObserver, - mock_get_value, sorting_hat_serial, sorting_hat_vectorized, verify_stratification_added, ) +from vivarium.framework.engine import Builder from vivarium.framework.event import Event from vivarium.framework.lifecycle import lifecycle_states from vivarium.framework.results import VALUE_COLUMN from vivarium.framework.results.context import ResultsContext +from vivarium.framework.results.interface import PopulationFilter from vivarium.framework.results.manager import ResultsManager from vivarium.framework.results.observation import AddingObservation, Observation +from 
vivarium.framework.results.observer import Observer from vivarium.framework.results.stratification import Stratification, get_mapped_col_name -from vivarium.framework.values import Pipeline from vivarium.interface.interactive import InteractiveContext from vivarium.types import ScalarMapper, VectorMapper @@ -107,7 +107,7 @@ def test__get_stratifications( ], ids=["vectorized_mapper", "non-vectorized_mapper", "excluded_categories"], ) -def test_register_stratification_no_pipelines( +def test_register_stratification( mocker: pytest_mock.MockFixture, excluded_categories: list[str], mapper: VectorMapper | ScalarMapper, @@ -125,14 +125,12 @@ def test_register_stratification_no_pipelines( excluded_categories=excluded_categories, mapper=mapper, is_vectorized=is_vectorized, - requires_columns=NAME_COLUMNS, - requires_values=[], + requires_attributes=NAME_COLUMNS, ) assert verify_stratification_added( stratifications=mgr._results_context.stratifications, name=NAME, - requires_columns=NAME_COLUMNS, - requires_values=[], + requires_attributes=NAME_COLUMNS, categories=HOUSE_CATEGORIES, excluded_categories=excluded_categories, mapper=mapper, @@ -140,91 +138,6 @@ def test_register_stratification_no_pipelines( ) -@pytest.mark.parametrize( - "mapper, is_vectorized", - [ - (sorting_hat_vectorized, True), - (sorting_hat_serial, False), - ], - ids=["vectorized_mapper", "non-vectorized_mapper"], -) -def test_register_stratification_with_pipelines( - mocker: pytest_mock.MockFixture, mapper: VectorMapper | ScalarMapper, is_vectorized: bool -) -> None: - mgr = ResultsManager() - builder = mocker.Mock() - builder.configuration.stratification = LayeredConfigTree( - {"default": [], "excluded_categories": {}} - ) - # Set up mock builder with mocked get_value call for Pipelines - mocker.patch.object(builder, "value.get_value") - builder.value.get_value = MethodType(mock_get_value, builder) - mgr.setup(builder) - mgr.register_stratification( - name=NAME, - categories=HOUSE_CATEGORIES, - 
excluded_categories=None, - mapper=mapper, - is_vectorized=is_vectorized, - requires_columns=[], - requires_values=NAME_COLUMNS, - ) - - assert verify_stratification_added( - stratifications=mgr._results_context.stratifications, - name=NAME, - requires_columns=[], - requires_values=[Pipeline(name) for name in NAME_COLUMNS], - categories=HOUSE_CATEGORIES, - excluded_categories=[], - mapper=mapper, - is_vectorized=is_vectorized, - ) - - -@pytest.mark.parametrize( - "mapper, is_vectorized", - [ - (sorting_hat_vectorized, True), - (sorting_hat_serial, False), - ], - ids=["vectorized_mapper", "non-vectorized_mapper"], -) -def test_register_stratification_with_column_and_pipelines( - mocker: pytest_mock.MockFixture, mapper: VectorMapper | ScalarMapper, is_vectorized: bool -) -> None: - mgr = ResultsManager() - builder = mocker.Mock() - builder.configuration.stratification = LayeredConfigTree( - {"default": [], "excluded_categories": {}} - ) - # Set up mock builder with mocked get_value call for Pipelines - mocker.patch.object(builder, "value.get_value") - builder.value.get_value = MethodType(mock_get_value, builder) - mgr.setup(builder) - mocked_column_name = "silly_column" - mgr.register_stratification( - name=NAME, - categories=HOUSE_CATEGORIES, - excluded_categories=None, - mapper=mapper, - is_vectorized=is_vectorized, - requires_columns=[mocked_column_name], - requires_values=NAME_COLUMNS, - ) - - assert verify_stratification_added( - stratifications=mgr._results_context.stratifications, - name=NAME, - requires_columns=[mocked_column_name], - requires_values=[Pipeline(name) for name in NAME_COLUMNS], - categories=HOUSE_CATEGORIES, - excluded_categories=[], - mapper=mapper, - is_vectorized=is_vectorized, - ) - - ############################################## # Tests for `register_binned_stratification` # ############################################## @@ -265,7 +178,6 @@ def test_binned_stratification_mapper() -> None: bin_edges=BIN_SILLY_BIN_EDGES, labels=BIN_LABELS, 
excluded_categories=None, - target_type="column", ) strat = mgr._results_context.stratifications[BIN_BINNED_COLUMN] data = pd.DataFrame([-np.inf] + BIN_SILLY_BIN_EDGES + [np.inf]) @@ -305,11 +217,10 @@ def test_add_observation_nop_stratifications( mgr.register_observation( observation_type=AddingObservation, name="name", - pop_filter='alive == "alive"', + population_filter=PopulationFilter("is_alive == True"), aggregator_sources=[], aggregator=lambda: None, - requires_columns=[], - requires_values=[], + requires_attributes=[], additional_stratifications=additional, excluded_stratifications=excluded, when=lifecycle_states.COLLECT_METRICS, @@ -467,7 +378,6 @@ def test_gather_results_with_no_observations(mocker: pytest_mock.MockerFixture) mgr.gather_results(event) mgr._results_context.get_observations.assert_called_once_with(event) # type: ignore[attr-defined] - mgr.population_view.subview.assert_not_called() # type: ignore[attr-defined] mgr._results_context.gather_results.assert_not_called() # type: ignore[attr-defined] @@ -490,7 +400,6 @@ def test_gather_results_with_empty_index(mocker: pytest_mock.MockerFixture) -> N mgr.gather_results(event) mgr._results_context.get_observations.assert_called_once_with(event) # type: ignore[attr-defined] - mgr.population_view.subview.assert_not_called() # type: ignore[attr-defined] mgr._results_context.gather_results.assert_not_called() # type: ignore[attr-defined] @@ -514,6 +423,46 @@ def test_gather_results_with_different_stratifications_and_to_observes() -> None ).all() +def test_gather_results_different_include_untracked_observations() -> None: + class SimulantCountObserver(Observer): + def register_observations(self, builder: Builder) -> None: + builder.results.register_unstratified_observation( + name="simulant_counter", + requires_attributes=["student_house"], + results_gatherer=lambda df: df, # type: ignore [arg-type, return-value] + results_updater=lambda _existing_df, new_df: new_df.groupby("student_house") + .size() + 
.reset_index(name=VALUE_COLUMN), + ) + builder.results.register_unstratified_observation( + name="simulant_counter_include_untracked", + include_untracked=True, + requires_attributes=["student_house"], + results_gatherer=lambda df: df, # type: ignore [arg-type, return-value] + results_updater=lambda _existing_df, new_df: new_df.groupby("student_house") + .size() + .reset_index(name=VALUE_COLUMN), + ) + + components = [ + Hogwarts(), + SimulantCountObserver(), + ] + sim = InteractiveContext(configuration=HARRY_POTTER_CONFIG, components=components) + pop_mgr = sim._population + pop_mgr.tracked_queries = ['student_house != "slytherin"'] + sim.step() + mgr = sim._results + results = mgr.get_results() + exclude_untracked = results["simulant_counter"] + include_untracked = results["simulant_counter_include_untracked"] + assert "slytherin" not in exclude_untracked["student_house"].values + assert "slytherin" in include_untracked["student_house"].values + assert exclude_untracked.equals( + include_untracked[include_untracked["student_house"] != "slytherin"] + ) + + @pytest.fixture(scope="module") def prepare_population_sim() -> InteractiveContext: return InteractiveContext(configuration=HARRY_POTTER_CONFIG, components=[Hogwarts()]) @@ -590,10 +539,9 @@ def test_prepare_population( observations: list[Observation] = [ AddingObservation( name=f"test_observation_{i}", - pop_filter="", + population_filter=PopulationFilter(), when=lifecycle_states.COLLECT_METRICS, - requires_columns=columns, - requires_values=[prepare_population_sim.get_value(value) for value in values], + requires_attributes=columns + values, results_formatter=lambda *_: pd.DataFrame(), aggregator_sources=[], aggregator=lambda *_: pd.Series(), @@ -605,8 +553,7 @@ def test_prepare_population( name=f"strat_{i}", categories=["a", "b", "c"], excluded_categories=[], - requires_columns=columns, - requires_values=[prepare_population_sim.get_value(value) for value in values], + requires_attributes=columns + values, 
mapper=lambda x: pd.Series("a", index=x.index), is_vectorized=True, ) @@ -615,7 +562,7 @@ def test_prepare_population( event = Event( name=lifecycle_states.COLLECT_METRICS, - index=prepare_population_sim.get_population().index, + index=prepare_population_sim.get_population_index(), user_data={ "train": "Hogwarts Express", "headmaster": "Albus Dumbledore", @@ -627,7 +574,7 @@ def test_prepare_population( population = mgr._prepare_population(event, observations, stratifications) - assert set(population.columns) == set(["tracked"] + expected_columns) + assert set(population.columns) == set(expected_columns) if "current_time" in expected_columns: assert (population["current_time"] == prepare_population_sim._clock.time).all() if "event_time" in expected_columns: @@ -639,10 +586,79 @@ def test_prepare_population( for strat in stratifications: assert ( population[get_mapped_col_name(strat.name)] - == strat.stratify(population[strat._sources]) + == strat.stratify(population[strat.requires_attributes]) ).all() +def test_prepare_population_all_untracked( + prepare_population_sim: InteractiveContext, mocker: pytest_mock.MockerFixture +) -> None: + mgr = prepare_population_sim._results + observation1 = AddingObservation( + name="familiar", + population_filter=PopulationFilter(include_untracked=True), # allow untracked + when=lifecycle_states.COLLECT_METRICS, + requires_attributes=["familiar"], + results_formatter=lambda *_: pd.DataFrame(), + aggregator_sources=[], + aggregator=lambda *_: pd.Series(), + ) + observation2 = AddingObservation( + name="house_points", + population_filter=PopulationFilter(), + when=lifecycle_states.COLLECT_METRICS, + requires_attributes=["house_points"], + results_formatter=lambda *_: pd.DataFrame(), + aggregator_sources=[], + aggregator=lambda *_: pd.Series(), + ) + + index = prepare_population_sim.get_population_index() + event = Event( + name=lifecycle_states.COLLECT_METRICS, + index=index, + user_data={}, + time=prepare_population_sim._clock.time 
+ prepare_population_sim._clock.step_size, # type: ignore [operator] + step_size=prepare_population_sim._clock.step_size, + ) + + # Add an untracking query + pop_mgr = prepare_population_sim._population + pop_mgr.tracked_queries = ['student_house != "slytherin"'] + # Change lifecycle phase to ensure tracked queries are applied appropriately + mocker.patch.object(pop_mgr, "get_current_state", lambda: "on_time_step") + + # Check that the exclusion is not applied since one of the observers allows untracked + private_columns = pop_mgr._private_columns + assert isinstance(private_columns, pd.DataFrame) + population = mgr._prepare_population( + event, observations=[observation1, observation2], stratifications=[] + ) + # Check that 'student_house' is included since it is needed to apply the tracking + # query in observation2 + assert set(population.columns) == {"student_house", "familiar", "house_points"} + assert population.equals(private_columns[population.columns]) + assert "slytherin" in population["student_house"].values + + # Now set both observers to exclude untracked + observation3 = AddingObservation( + # identical to observation1 except excluding untracked + name="familiar", + population_filter=PopulationFilter(), + when=lifecycle_states.COLLECT_METRICS, + requires_attributes=["familiar"], + results_formatter=lambda *_: pd.DataFrame(), + aggregator_sources=[], + aggregator=lambda *_: pd.Series(), + ) + population = mgr._prepare_population( + event, observations=[observation3, observation2], stratifications=[] + ) + slytherin_mask = private_columns["student_house"] == "slytherin" + expected = private_columns.loc[~slytherin_mask, list(population.columns)] + assert population.equals(expected) + + def test_stratified_observation_results() -> None: components = [ Hogwarts(), @@ -652,7 +668,11 @@ def test_stratified_observation_results() -> None: sim = InteractiveContext(configuration=HARRY_POTTER_CONFIG, components=components) assert 
(sim.get_results()["cat_bomb"]["value"] == 0.0).all() sim.step() - num_familiars = sim.get_population().groupby(["familiar", "student_house"]).apply(len) + num_familiars = ( + sim.get_population(["familiar", "student_house"]) + .groupby(["familiar", "student_house"]) + .apply(len) + ) expected = num_familiars.loc["cat"] ** 1.0 expected.name = "value" expected = expected.sort_values().reset_index() @@ -663,7 +683,11 @@ def test_stratified_observation_results() -> None: sim.get_results()["cat_bomb"].sort_values("value").reset_index(drop=True) ) sim.step() - num_familiars = sim.get_population().groupby(["familiar", "student_house"]).apply(len) + num_familiars = ( + sim.get_population(["familiar", "student_house"]) + .groupby(["familiar", "student_house"]) + .apply(len) + ) expected = num_familiars.loc["cat"] ** 2.0 expected.name = "value" expected = expected.sort_values().reset_index() @@ -764,12 +788,18 @@ def _check_quidditch_wins(pop: pd.DataFrame, step_number: int) -> None: ] sim = InteractiveContext(configuration=HARRY_POTTER_CONFIG, components=components) sim.step() - pop = sim.get_population() + pop = sim.get_population( + ["house_points", "quidditch_wins", "student_house", "power_level", "familiar"] + ) + assert isinstance(pop, pd.DataFrame) _check_house_points(pop, step_number=1) _check_quidditch_wins(pop, step_number=1) sim.step() - pop = sim.get_population() + pop = sim.get_population( + ["house_points", "quidditch_wins", "student_house", "power_level", "familiar"] + ) + assert isinstance(pop, pd.DataFrame) _check_house_points(pop, step_number=2) _check_quidditch_wins(pop, step_number=2) _assert_standard_index(sim.get_results()["house_points"]) @@ -816,13 +846,13 @@ def test_update__raw_results_no_stratifications() -> None: components = [Hogwarts(), NoStratificationsQuidditchWinsObserver()] sim = InteractiveContext(configuration=HARRY_POTTER_CONFIG, components=components) sim.step() - pop = sim.get_population() + wins = sim.get_population("quidditch_wins") 
raw_results = sim._results._raw_results["no_stratifications_quidditch_wins"] - assert raw_results.loc["all"][VALUE_COLUMN] == pop["quidditch_wins"].sum() + assert raw_results.loc["all"][VALUE_COLUMN] == wins.sum() sim.step() - pop = sim.get_population() + wins = sim.get_population("quidditch_wins") raw_results = sim._results._raw_results["no_stratifications_quidditch_wins"] - assert raw_results.loc["all"][VALUE_COLUMN] == pop["quidditch_wins"].sum() * 2 + assert raw_results.loc["all"][VALUE_COLUMN] == wins.sum() * 2 def test_update__raw_results_extra_columns() -> None: diff --git a/tests/framework/results/test_observation.py b/tests/framework/results/test_observation.py index ce349d493..d9304351f 100644 --- a/tests/framework/results/test_observation.py +++ b/tests/framework/results/test_observation.py @@ -10,6 +10,7 @@ from tests.framework.results.helpers import BASE_POPULATION, FAMILIARS, HOUSE_CATEGORIES from vivarium.framework.results import VALUE_COLUMN from vivarium.framework.results.context import ResultsContext +from vivarium.framework.results.interface import PopulationFilter from vivarium.framework.results.observation import ( AddingObservation, ConcatenatingObservation, @@ -23,10 +24,9 @@ def stratified_observation() -> StratifiedObservation: return StratifiedObservation( name="stratified_observation_name", - pop_filter="", + population_filter=PopulationFilter(), when="whenevs", - requires_columns=[], - requires_values=[], + requires_attributes=[], results_updater=lambda _, __: pd.DataFrame(), results_formatter=lambda _, __: pd.DataFrame(), aggregator_sources=None, @@ -38,10 +38,9 @@ def stratified_observation() -> StratifiedObservation: def concatenating_observation() -> ConcatenatingObservation: return ConcatenatingObservation( name="concatenating_observation_name", - pop_filter="", + population_filter=PopulationFilter(), when="whenevs", - requires_columns=["some-col", "some-other-col"], - requires_values=[], + requires_attributes=["some-col", 
"some-other-col"], results_formatter=lambda _, __: pd.DataFrame(), ) @@ -70,9 +69,9 @@ def test_is_stratified(observation_type: type[Observation], is_stratified: bool) ((), ["power_level"], len), ((), [], len), # Multiple-column dataframe return - (("familiar",), ["power_level", "tracked"], sum), - (("familiar", "house"), ["power_level", "tracked"], sum), - ((), ["power_level", "tracked"], sum), + (("familiar",), ["power_level"], sum), + (("familiar", "house"), ["power_level"], sum), + ((), ["power_level"], sum), ], ) def test_stratified_observation__aggregate( @@ -113,8 +112,8 @@ def test_stratified_observation__aggregate( assert len(aggregates.values) == 1 assert aggregates.values[0] == len(BASE_POPULATION) else: # sum aggregator - assert aggregates.shape[1] == 2 - expected = BASE_POPULATION[["power_level", "tracked"]].sum() / groups.ngroups + assert aggregates.shape[1] == 1 + expected = BASE_POPULATION[["power_level"]].sum() / groups.ngroups if stratifications: stratification_idx = ( set(itertools.product(*(FAMILIARS, HOUSE_CATEGORIES))) @@ -127,7 +126,7 @@ def test_stratified_observation__aggregate( assert final.equals(expected) else: assert len(aggregates.values) == 1 - for col in ["power_level", "tracked"]: + for col in ["power_level"]: assert aggregates.loc["all", col] == expected[col] @@ -237,10 +236,9 @@ def test_adding_observation_results_updater(new_observations: pd.DataFrame) -> N existing_results = pd.DataFrame({"value": [0.0, 0.0]}) obs = AddingObservation( name="adding_observation_name", - pop_filter="", + population_filter=PopulationFilter(), when="whenevs", - requires_columns=[], - requires_values=[], + requires_attributes=[], results_formatter=lambda _, __: pd.DataFrame(), aggregator_sources=None, aggregator=lambda _: 0.0, diff --git a/tests/framework/results/test_stratification.py b/tests/framework/results/test_stratification.py index 54c2564dc..d2b0d60e8 100644 --- a/tests/framework/results/test_stratification.py +++ 
b/tests/framework/results/test_stratification.py @@ -24,7 +24,6 @@ get_mapped_col_name, get_original_col_name, ) -from vivarium.framework.values import Pipeline @pytest.mark.parametrize( @@ -47,8 +46,7 @@ def test_stratification( ) -> None: my_stratification = Stratification( name=NAME, - requires_columns=NAME_COLUMNS, - requires_values=[], + requires_attributes=NAME_COLUMNS, categories=HOUSE_CATEGORIES, excluded_categories=[], mapper=mapper, @@ -60,50 +58,45 @@ def test_stratification( @pytest.mark.parametrize( - "requires_columns, requires_values, categories, mapper, msg_match", + "requires_attributes, categories, mapper, msg_match", [ ( - [], [], HOUSE_CATEGORIES, None, ( - f"No mapper but 0 stratification sources are provided for stratification {NAME}. " - "The list of sources must be of length 1 if no mapper is provided." + f"No mapper but 0 required attributes are provided for stratification {NAME}. " + "The list of required attributes must be of length 1 if no mapper is provided." ), ), ( NAME_COLUMNS, - [], HOUSE_CATEGORIES, None, ( - f"No mapper but {len(NAME_COLUMNS)} stratification sources are provided for " - f"stratification {NAME}. The list of sources must be of length 1 if no mapper is " - "provided." + f"No mapper but {len(NAME_COLUMNS)} required attributes are provided for " + f"stratification {NAME}. The list of required attributes must be of " + "length 1 if no mapper is provided." ), ), ( - ["house"], - ["grade"], + ["house", "grade"], HOUSE_CATEGORIES, None, ( - f"No mapper but 2 stratification sources are provided for stratification {NAME}. " - "The list of sources must be of length 1 if no mapper is provided." + f"No mapper but 2 required attributes are provided for stratification {NAME}. " + "The list of required attributes must be of length 1 if no mapper is provided." 
), ), ( - [], [], HOUSE_CATEGORIES, sorting_hat_vectorized, - "The sources argument must be non-empty.", + "The requires_attributes argument must be non-empty.", ), ( NAME_COLUMNS, [], - [], sorting_hat_vectorized, "The categories argument must be non-empty.", ), @@ -117,19 +110,17 @@ def test_stratification( ], ) def test_stratification_init_raises( - requires_columns: list[str], - requires_values: list[str], + requires_attributes: list[str], categories: list[str], mapper: Callable[[pd.DataFrame], pd.Series[str]] | Callable[[pd.Series[str]], str], msg_match: str, ) -> None: - pipelines = [Pipeline(name) for name in requires_values] with pytest.raises(ValueError, match=re.escape(msg_match)): - Stratification(NAME, requires_columns, pipelines, categories, [], mapper, True) + Stratification(NAME, requires_attributes, categories, [], mapper, True) @pytest.mark.parametrize( - "requires_columns, mapper, is_vectorized, expected_exception, error_match", + "requires_attributes, mapper, is_vectorized, expected_exception, error_match", [ ( NAME_COLUMNS, @@ -176,7 +167,7 @@ def test_stratification_init_raises( ], ) def test_stratification_call_raises( - requires_columns: list[str], + requires_attributes: list[str], mapper: Callable[[pd.DataFrame], pd.Series[str]] | Callable[[pd.Series[str]], str], is_vectorized: bool, expected_exception: type[Exception], @@ -184,8 +175,7 @@ def test_stratification_call_raises( ) -> None: my_stratification = Stratification( name=NAME, - requires_columns=requires_columns, - requires_values=[], + requires_attributes=requires_attributes, categories=HOUSE_CATEGORIES, excluded_categories=[], mapper=mapper, diff --git a/tests/framework/test_configuration.py b/tests/framework/test_configuration.py index 642f078f7..5e8b2c4b5 100644 --- a/tests/framework/test_configuration.py +++ b/tests/framework/test_configuration.py @@ -23,7 +23,7 @@ def test_get_default_specification_user_config( default_spec = _get_default_specification() - assert 
expand_user_mock.called_once_with("~/vivarium.yaml") + expand_user_mock.assert_called_once_with() with test_user_config.open() as f: data = {"configuration": yaml.full_load(f)} @@ -44,7 +44,7 @@ def test_get_default_specification_no_user_config( default_spec = _get_default_specification() - assert expand_user_mock.called_once_with("~/vivarium.yaml") + expand_user_mock.assert_called_once_with() data: dict[str, dict[Any, Any]] = {"components": {}, "configuration": {}} data.update(DEFAULT_PLUGINS) @@ -82,7 +82,7 @@ def test_build_simulation_configuration( config = build_simulation_configuration() - assert expand_user_mock.called_once_with("~/vivarium.yaml") + expand_user_mock.assert_called_once_with() with test_user_config.open() as f: data = yaml.full_load(f) diff --git a/tests/framework/test_engine.py b/tests/framework/test_engine.py index ce5c697da..564056b22 100644 --- a/tests/framework/test_engine.py +++ b/tests/framework/test_engine.py @@ -24,8 +24,14 @@ NoStratificationsQuidditchWinsObserver, QuidditchWinsObserver, ) -from tests.helpers import Listener, MockComponentA, MockComponentB -from vivarium import Component +from tests.helpers import ( + AttributePipelineCreator, + ColumnCreator, + Listener, + MockComponentA, + MockComponentB, +) +from vivarium import Component, InteractiveContext from vivarium.framework.artifact import ArtifactInterface, ArtifactManager from vivarium.framework.components import ( ComponentConfigError, @@ -269,9 +275,8 @@ def test_SimulationContext_initialize_simulants( sim.setup() pop_size = sim.configuration.population.population_size current_time = sim._clock.time - assert sim._population.get_population(True).empty sim.initialize_simulants() - pop = sim._population.get_population(True) + pop = sim._population.get_population("all") assert len(pop) == pop_size assert sim._clock.time == current_time @@ -525,6 +530,22 @@ def test_SimulationContext_load_from_backup( assert isinstance(sim_backup, SimulationContext) +def 
test_private_columns_get_registered() -> None: + component1 = ColumnCreator() + component2 = AttributePipelineCreator() + sim = InteractiveContext(components=[component1, component2], setup=False) + assert sim._population._private_column_metadata == {} + sim.setup() + metadata = sim._population._private_column_metadata + assert metadata == { + component1.name: ["test_column_1", "test_column_2", "test_column_3"], + # The datetime clock does not have private columns but does register an initializer + "datetime_clock": [], + } + # Check that there are indeed other attributes registered besides via column_created + assert len(sim.get_population().columns) > 3 + + #################### # HELPER FUNCTIONS # #################### diff --git a/tests/framework/test_state_machine.py b/tests/framework/test_state_machine.py index f2d4b2c2d..5060394bb 100644 --- a/tests/framework/test_state_machine.py +++ b/tests/framework/test_state_machine.py @@ -9,20 +9,20 @@ from tests.helpers import ColumnCreator from vivarium import InteractiveContext from vivarium.framework.configuration import build_simulation_configuration +from vivarium.framework.engine import Builder from vivarium.framework.population import SimulantData -from vivarium.framework.resource import Resource from vivarium.framework.state_machine import Machine, State, Transition from vivarium.types import ClockTime, DataInput def test_initialize_allowing_self_transition() -> None: - self_transitions = State("self-transitions", allow_self_transition=True) + self_transitions = State("self-transitions") no_self_transitions = State("no-self-transitions", allow_self_transition=False) undefined_self_transitions = State("self-transitions") - assert self_transitions.transition_set.allow_null_transition - assert not no_self_transitions.transition_set.allow_null_transition - assert not undefined_self_transitions.transition_set.allow_null_transition + assert self_transitions.transition_set.allow_self_transition + assert not 
no_self_transitions.transition_set.allow_self_transition + assert undefined_self_transitions.transition_set.allow_self_transition def test_initialize_with_initial_state() -> None: @@ -30,7 +30,7 @@ def test_initialize_with_initial_state() -> None: other_state = State("other") machine = Machine("state", states=[start_state, other_state], initial_state=start_state) simulation = InteractiveContext(components=[machine]) - assert simulation.get_population()["state"].unique() == ["start"] + assert all(simulation.get_population("state") == "start") @pytest.mark.parametrize("weights_type", ["artifact", "callable", "scalar"]) @@ -66,8 +66,8 @@ def initialization_weights(key: str) -> DataInput: ) simulation.setup() - state = simulation.get_population()["state"] - assert np.all(simulation.get_population().state != "start") + state = simulation.get_population("state") + assert np.all(simulation.get_population("state") != "start") assert round((state == "a").mean(), 1) == 0.2 assert round((state == "b").mean(), 1) == 0.8 @@ -93,17 +93,7 @@ def mock_load(key: str) -> pd.DataFrame: {"population": {"population_size": 10000}, "randomness ": {"key_columns": []}} ) - class TestMachine(Machine): - @property - def initialization_requirements(self) -> list[str | Resource]: - # FIXME - MIC-5408: We shouldn't need to specify the columns in the - # lookup tables here, since the component can't know what will be - # specified by the states or the configuration. 
- return ["test_column_1"] - - def initialization_weights( - key: str, - ) -> DataInput: + def initialization_weights(key: str) -> DataInput: weights = { "artifact": key, "callable": lambda _: state_weights[key], @@ -114,7 +104,7 @@ def initialization_weights( state_a = State("a", initialization_weights=initialization_weights("state_a.weights")) state_b = State("b", initialization_weights=initialization_weights("state_b.weights")) - machine = TestMachine("state", states=[state_a, state_b]) + machine = Machine("state", states=[state_a, state_b]) simulation = InteractiveContext( components=[machine, ColumnCreator()], configuration=config, setup=False ) @@ -123,7 +113,7 @@ def initialization_weights( ) simulation.setup() - pop = simulation.get_population()[["state", "test_column_1"]] + pop = simulation.get_population(["state", "test_column_1"]) state_a_weights = state_weights["state_a.weights"] state_b_weights = state_weights["state_b.weights"] for i in range(3): @@ -135,13 +125,16 @@ def initialization_weights( def test_error_if_initialize_with_both_initial_state_and_initialization_weights() -> None: start_state = State("start") other_state = State("other", initialization_weights=lambda _: 0.8) + machine = Machine("state", states=[start_state, other_state], initial_state=start_state) + with pytest.raises(ValueError, match="Cannot specify both"): - Machine("state", states=[start_state, other_state], initial_state=start_state) + InteractiveContext(components=[machine]) def test_error_if_initialize_with_neither_initial_state_nor_initialization_weights() -> None: + machine = Machine("state", states=[State("a"), State("b")]) with pytest.raises(ValueError, match="Must specify either"): - Machine("state", states=[State("a"), State("b")]) + InteractiveContext(components=[machine]) @pytest.mark.parametrize("population_size", [1, 100]) @@ -155,8 +148,8 @@ def test_transition( "randomness": {"key_columns": []}, } ) + start_state = State("start", allow_self_transition=False) 
done_state = State("done") - start_state = State("start") if use_transition_arg: start_state.add_transition(Transition(start_state, done_state)) else: @@ -164,18 +157,18 @@ def test_transition( machine = Machine("state", states=[start_state, done_state], initial_state=start_state) simulation = InteractiveContext(components=[machine], configuration=base_config) - assert np.all(simulation.get_population().state == "start") + assert np.all(simulation.get_population("state") == "start") simulation.step() - assert np.all(simulation.get_population().state == "done") + assert np.all(simulation.get_population("state") == "done") def test_no_null_transition(base_config: LayeredConfigTree) -> None: base_config.update( {"population": {"population_size": 10000}, "randomness": {"key_columns": []}} ) - a_state = State("a") - b_state = State("b") - start_state = State("start") + a_state = State("a", allow_self_transition=False) + b_state = State("b", allow_self_transition=False) + start_state = State("start", allow_self_transition=False) start_state.add_transition( output_state=a_state, probability_function=lambda index: pd.Series(0.4, index=index) ) @@ -187,12 +180,12 @@ def test_no_null_transition(base_config: LayeredConfigTree) -> None: ) simulation = InteractiveContext(components=[machine], configuration=base_config) - assert np.all(simulation.get_population().state == "start") + assert np.all(simulation.get_population("state") == "start") simulation.step() - state = simulation.get_population()["state"] - assert np.all(simulation.get_population().state != "start") + state = simulation.get_population("state") + assert np.all(state != "start") assert round((state == "a").mean(), 1) == 0.4 assert round((state == "b").mean(), 1) == 0.6 @@ -201,8 +194,8 @@ def test_null_transition(base_config: LayeredConfigTree) -> None: base_config.update( {"population": {"population_size": 10000}, "randomness": {"key_columns": []}} ) - a_state = State("a") - start_state = State("start", 
allow_self_transition=True) + a_state = State("a", allow_self_transition=False) + start_state = State("start") start_state.add_transition( output_state=a_state, probability_function=lambda index: pd.Series(0.4, index=index) ) @@ -211,25 +204,27 @@ def test_null_transition(base_config: LayeredConfigTree) -> None: simulation = InteractiveContext(components=[machine], configuration=base_config) simulation.step() - state = simulation.get_population()["state"] + state = simulation.get_population("state") assert round((state == "a").mean(), 1) == 0.4 def test_side_effects() -> None: class CountingState(State): - @property - def columns_created(self) -> list[str]: - return ["count"] + def setup(self, builder: Builder) -> None: + super().setup(builder) + builder.population.register_initializer( + initializer=self.initialize_count, columns="count" + ) - def on_initialize_simulants(self, pop_data: SimulantData) -> None: + def initialize_count(self, pop_data: SimulantData) -> None: self.population_view.update(pd.Series(0, index=pop_data.index, name="count")) def transition_side_effect(self, index: pd.Index[int], _: ClockTime) -> None: - pop = self.population_view.get(index) - self.population_view.update(pop["count"] + 1) + pop = self.population_view.get_attributes(index, "count") + self.population_view.update(pop + 1) - counting_state = CountingState("counting") - start_state = State("start") + counting_state = CountingState("counting", allow_self_transition=False) + start_state = State("start", allow_self_transition=False) start_state.add_transition(output_state=counting_state) counting_state.add_transition(output_state=start_state) @@ -237,16 +232,16 @@ def transition_side_effect(self, index: pd.Index[int], _: ClockTime) -> None: "state", states=[start_state, counting_state], initial_state=start_state ) simulation = InteractiveContext(components=[machine]) - assert np.all(simulation.get_population()["count"] == 0) + assert np.all(simulation.get_population("count") == 0) # 
transitioning to counting state simulation.step() - assert np.all(simulation.get_population()["count"] == 1) + assert np.all(simulation.get_population("count") == 1) # transitioning back to start state simulation.step() - assert np.all(simulation.get_population()["count"] == 1) + assert np.all(simulation.get_population("count") == 1) # transitioning to counting state again simulation.step() - assert np.all(simulation.get_population()["count"] == 2) + assert np.all(simulation.get_population("count") == 2) diff --git a/tests/framework/test_time.py b/tests/framework/test_time.py index 2bb85f7a4..c93a764b3 100644 --- a/tests/framework/test_time.py +++ b/tests/framework/test_time.py @@ -1,7 +1,7 @@ from __future__ import annotations import math -from typing import Any, Generator +from typing import Any from unittest.mock import MagicMock import numpy as np @@ -14,11 +14,10 @@ from vivarium.component import Component from vivarium.framework.engine import Builder, SimulationContext from vivarium.framework.event import Event -from vivarium.framework.results.observer import Observer -from vivarium.framework.time import SimulationClock, get_time_stamp +from vivarium.framework.time.manager import SimulationClock, get_time_stamp from vivarium.framework.utilities import from_yearly from vivarium.framework.values import ValuesManager, rescale_post_processor -from vivarium.types import ClockStepSize, ClockTime, NumberLike +from vivarium.types import ClockStepSize @pytest.fixture @@ -42,10 +41,9 @@ def components() -> list[Component | Listener]: def validate_step_column_is_pipeline(sim: SimulationContext) -> None: """Ensure that the pipeline and column step sizes are aligned""" - step_pipeline = sim._values.get_value("simulant_step_size")(sim.get_population().index) - assert sim._population._population is not None - step_column = sim._population._population.step_size - assert np.all(step_pipeline == step_column) + step_pipeline = 
sim._values.get_value("simulant_step_size")(sim.get_population_index()) + assert sim._clock._individual_clocks is not None + assert np.all(step_pipeline == sim._clock._individual_clocks["step_size"]) def validate_index_aligned( @@ -75,7 +73,7 @@ def take_step(sim: SimulationContext) -> ClockStepSize: def get_full_pop_index(sim: SimulationContext) -> pd.Index[int]: - return sim.get_population().index + return sim.get_population_index() def get_index_by_parity(index: pd.Index[int], parity: str) -> pd.Index[int]: @@ -88,7 +86,7 @@ def get_index_by_parity(index: pd.Index[int], parity: str) -> pd.Index[int]: def get_pop_by_parity(sim: SimulationContext, parity: str) -> pd.DataFrame: - pop = sim.get_population() + pop = pd.DataFrame(sim.get_population()) return pop.loc[get_index_by_parity(pop.index, parity)] @@ -175,34 +173,20 @@ def __init__( modified_simulants: str = "all", ) -> None: super().__init__(name, step_modifier_even, step_modifier_odd, modified_simulants) - self.ts_pipeline_value = None + self.rate_pipeline = f"test_rate_{self.name}" def setup(self, builder: Builder) -> None: super().setup(builder) - self.rate_pipeline = builder.value.register_value_producer( - f"test_rate_{self.name}", + builder.value.register_attribute_producer( + self.rate_pipeline, source=lambda idx: pd.Series(1.75, index=idx), preferred_post_processor=rescale_post_processor, ) def on_time_step(self, event: Event) -> None: - self.ts_pipeline_value = self.rate_pipeline(event.index) - - -class StepModifierWithUntracking(StepModifierWithRatePipeline): - """Add an event step that untracks/tracks even simulants every timestep""" - - @property - def columns_required(self) -> list[str]: - return ["tracked"] - - def on_time_step(self, event: Event) -> None: - super().on_time_step(event) - evens = self.population_view.get(event.index).loc[ - get_index_by_parity(event.index, "evens") - ] - evens["tracked"] = False - self.population_view.update(evens) + self.ts_pipeline_value = 
self.population_view.get_attributes( + event.index, self.rate_pipeline + ) class StepModifierWithMovement(StepModifierWithRatePipeline): @@ -235,9 +219,9 @@ def test_basic_iteration( full_pop_index = get_full_pop_index(sim) assert sim._clock.time == get_time_stamp(sim.configuration.time.start) assert sim._clock.step_size == pd.Timedelta(days=1) - ## Ensure that we don't have a pop view (and by extension, don't vary clocks) - ## If no components modify the step size. - assert bool(sim._clock._individual_clocks) == varied_step_size + # Ensure that we don't have a pop view (and by extension, don't vary clocks) + # if no components modify the step size. + assert (sim._clock._individual_clocks is not None) == varied_step_size for _ in range(2): # After initialization, all simulants should be aligned to event times @@ -271,8 +255,8 @@ def test_empty_active_pop( ## Force a next event time update without updating step sizes. ## This ensures (against the current implementation) that we will have a timestep ## that has no simulants aligned. Check that we do the minimum timestep update. 
- assert sim._population._population is not None - sim._population._population.next_event_time += pd.Timedelta(days=1) + assert sim._clock._individual_clocks is not None + sim._clock._individual_clocks["next_event_time"] += pd.Timedelta(days=1) ## First Step validate_step_column_is_pipeline(sim) take_step_and_validate( @@ -471,27 +455,6 @@ def test_multiple_modifiers(base_config: LayeredConfigTree) -> None: ) -def test_untracked_simulants(base_config: LayeredConfigTree) -> None: - """Test that untracked simulants are always included in event indices, and are - basically treated the same as any other simulant.""" - base_config.update({"configuration": {"time": {"standard_step_size": 7}}}) - listener = Listener("listener") - step_modifier_component = StepModifierWithUntracking("step_modifier", 3) - sim = SimulationContext( - base_config, - [step_modifier_component, listener], - ) - - sim.setup() - sim.initialize_simulants() - full_pop_index = get_full_pop_index(sim) - - for _ in range(2): - take_step_and_validate(sim, listener, full_pop_index, expected_step_size_days=3) - assert step_modifier_component.ts_pipeline_value is not None - assert step_modifier_component.ts_pipeline_value.index.equals(full_pop_index) - - def test_move_simulants_to_end(base_config: LayeredConfigTree) -> None: """Ensure that we move simulants' next event time to the end of the simulation, if they are even.""" base_config.update({"configuration": {"time": {"standard_step_size": 7}}}) @@ -512,7 +475,7 @@ def test_move_simulants_to_end(base_config: LayeredConfigTree) -> None: assert step_modifier_component.ts_pipeline_value.index.equals(full_pop_index) assert np.all( sim._clock.simulant_next_event_times(evens) - == sim._clock.stop_time + sim._clock.minimum_step_size + == sim._clock.stop_time + sim._clock.minimum_step_size # type: ignore [operator] ) for _ in range(2): @@ -539,8 +502,8 @@ def test_step_size_post_processor(builder: MagicMock) -> None: ## Add modifier that sets the step size to 9 
for all simulants clock.register_step_modifier(lambda idx: pd.Series(pd.Timedelta(days=9), index=idx)) value = clock._step_size_pipeline(index) - evens = value.iloc[lambda x: x.index % 2 == 0] - odds = value.iloc[lambda x: x.index % 2 == 1] + evens = value[value.index % 2 == 0] + odds = value[value.index % 2 == 1] ## The second modifier shouldn't have an effect assert np.all(evens == pd.Timedelta(days=6)) diff --git a/tests/framework/test_values.py b/tests/framework/test_values.py index f14e99a00..ae4c762c0 100644 --- a/tests/framework/test_values.py +++ b/tests/framework/test_values.py @@ -1,14 +1,21 @@ from __future__ import annotations +import re from collections.abc import Callable +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd import pytest -from pytest_mock import MockFixture +from pytest_mock import MockerFixture, MockFixture +from tests.helpers import ColumnCreator +from vivarium import Component +from vivarium.framework.lifecycle import lifecycle_states +from vivarium.framework.resource import Column, Resource from vivarium.framework.utilities import from_yearly from vivarium.framework.values import ( + AttributePipeline, DynamicValueError, Pipeline, ValuesManager, @@ -16,6 +23,106 @@ rescale_post_processor, union_post_processor, ) +from vivarium.framework.values.pipeline import ( + AttributesValueSource, + PrivateColumnValueSource, + ValueSource, +) +from vivarium.interface import InteractiveContext +from vivarium.types import NumberLike + +if TYPE_CHECKING: + from vivarium.framework.engine import Builder + from vivarium.framework.population import SimulantData + from vivarium.framework.values import AttributePostProcessor + + +INDEX = pd.Index([4, 8, 15, 16, 23, 42]) + + +def test_configure_pipeline_calls_methods_correctly(mocker: MockerFixture) -> None: + """Test that _configure_pipeline orchestrates calls to helper methods correctly.""" + # Setup + manager = ValuesManager() + test_component = Component() + test_pipeline = 
mocker.Mock() + test_required_resources = ["resource1", "resource2"] + test_combiner = mocker.Mock() + test_post_processor = mocker.Mock() + + # Inject mocks into the manager + manager._get_current_component = mocker.Mock(return_value=test_component) + manager._add_resources = mocker.Mock() + manager._add_constraint = mocker.Mock() + + # Execute + manager._configure_pipeline( + test_pipeline, + lambda idx: pd.Series(1, index=idx), + test_required_resources, + test_combiner, + test_post_processor, + source_is_private_column=False, + ) + + # Assert pipeline.set_attributes was called + test_pipeline.set_attributes.assert_called_once() + call_args = test_pipeline.set_attributes.call_args + assert call_args[1]["component"] == test_component + assert isinstance(call_args[1]["source"], ValueSource) + assert call_args[1]["combiner"] == test_combiner + assert call_args[1]["post_processor"] == test_post_processor + assert call_args[1]["manager"] == manager + + # Assert _add_resources was called with correct arguments + manager._add_resources.assert_called_once_with( # type: ignore[attr-defined] + component=test_pipeline.component, + resources=test_pipeline.source, + required_resources=test_pipeline.source.required_resources, + ) + + # Assert _add_constraint was called with correct arguments + manager._add_constraint.assert_called_once() # type: ignore[attr-defined] + call_args = manager._add_constraint.call_args # type: ignore[attr-defined] + assert call_args[0][0] == test_pipeline._call + assert call_args[1]["restrict_during"] == [ + lifecycle_states.INITIALIZATION, + lifecycle_states.SETUP, + lifecycle_states.POST_SETUP, + ] + + +def test_configure_modifier_calls_methods_correctly(mocker: MockerFixture) -> None: + """Test that _configure_modifier orchestrates calls to helper methods correctly.""" + # Setup + manager = ValuesManager() + test_component = Component() + test_pipeline = mocker.Mock() + test_modifier = lambda idx, val: val + 1 + test_required_resources = 
["resource1", "resource2"] + + # Set up a mock value modifier + mock_value_modifier = mocker.Mock() + mock_value_modifier.name = "test_modifier" + test_pipeline.get_value_modifier.return_value = mock_value_modifier + + # Inject mocks into the manager + manager._get_current_component = mocker.Mock(return_value=test_component) + manager._add_resources = mocker.Mock() + manager.logger = mocker.Mock() + + # Execute + manager._configure_modifier(test_pipeline, test_modifier, test_required_resources) + + # Assert pipeline.get_value_modifier was called with correct arguments + test_pipeline.get_value_modifier.assert_called_once_with(test_modifier, test_component) + + # Assert _add_resources was called with correct arguments + manager._add_resources.assert_called_once_with( # type: ignore[attr-defined] + component=test_component, + resources=mock_value_modifier, + required_resources=test_required_resources, + ) @pytest.fixture @@ -50,7 +157,7 @@ def manager_with_step_size( return manager -def test_replace_combiner(manager: ValuesManager) -> None: +def test_replace_combiner(manager: ValuesManager, mocker: MockFixture) -> None: value = manager.register_value_producer("test", source=lambda: 1) assert value() == 1 @@ -64,21 +171,25 @@ def test_replace_combiner(manager: ValuesManager) -> None: def test_joint_value(manager: ValuesManager) -> None: # This is the normal configuration for PAF and disability weight type values - index = pd.Index(range(10)) - value = manager.register_value_producer( + manager.register_attribute_producer( "test", source=lambda idx: [pd.Series(0.0, index=idx)], preferred_combiner=list_combiner, - preferred_post_processor=union_post_processor, # type: ignore [arg-type] + preferred_post_processor=union_post_processor, ) - assert np.all(value(index) == 0) + value = manager.get_attribute_pipelines()["test"] + assert np.all(value(INDEX) == 0) - manager.register_value_modifier("test", modifier=lambda idx: pd.Series(0.5, index=idx)) - assert 
np.all(value(index) == 0.5) + manager.register_attribute_modifier( + "test", modifier=lambda idx: pd.Series(0.5, index=idx) + ) + assert np.all(value(INDEX) == 0.5) - manager.register_value_modifier("test", modifier=lambda idx: pd.Series(0.5, index=idx)) - assert np.all(value(index) == 0.75) + manager.register_attribute_modifier( + "test", modifier=lambda idx: pd.Series(0.5, index=idx) + ) + assert np.all(value(INDEX) == 0.75) def test_contains(manager: ValuesManager) -> None: @@ -95,45 +206,724 @@ def test_contains(manager: ValuesManager) -> None: def test_returned_series_name(manager: ValuesManager) -> None: value = manager.register_value_producer( - "test", - source=lambda idx: pd.Series(0.0, index=idx), + "test", source=lambda idx: pd.Series(0.0, index=idx) ) - assert value(pd.Index(range(10))).name == "test" + assert value(INDEX).name == "test" @pytest.mark.parametrize("manager_with_step_size", ["static_step"], indirect=True) def test_rescale_post_processor_static(manager_with_step_size: ValuesManager) -> None: - index = pd.Index(range(10)) - pipeline = manager_with_step_size.register_value_producer( + manager_with_step_size.register_attribute_producer( "test", source=lambda idx: pd.Series(0.75, index=idx), preferred_post_processor=rescale_post_processor, ) - assert np.all(pipeline(index) == from_yearly(0.75, pd.Timedelta(days=6))) + pipeline = manager_with_step_size.get_attribute_pipelines()["test"] + assert np.all(pipeline(INDEX) == from_yearly(0.75, pd.Timedelta(days=6))) @pytest.mark.parametrize("manager_with_step_size", ["variable_step"], indirect=True) def test_rescale_post_processor_variable(manager_with_step_size: ValuesManager) -> None: - index = pd.Index(range(10)) - pipeline = manager_with_step_size.register_value_producer( + manager_with_step_size.register_attribute_producer( "test", source=lambda idx: pd.Series(0.5, index=idx), preferred_post_processor=rescale_post_processor, ) - value = pipeline(index) - evens = value.iloc[lambda x: x.index % 2 
== 0] - odds = value.iloc[lambda x: x.index % 2 == 1] + pipeline = manager_with_step_size.get_attribute_pipelines()["test"] + value = pipeline(INDEX) + evens = value[INDEX % 2 == 0] + odds = value[INDEX % 2 == 1] assert np.all(evens == from_yearly(0.5, pd.Timedelta(days=3))) assert np.all(odds == from_yearly(0.5, pd.Timedelta(days=5))) -def test_unsourced_pipeline() -> None: - pipeline = Pipeline("some_name") - assert pipeline.source.resource_id == "missing_value_source.some_name" +@pytest.mark.parametrize("manager_with_step_size", ["static_step"], indirect=True) +@pytest.mark.parametrize( + "source, expected", + [ + ( + lambda idx: pd.Series(0.75, index=idx), + pd.Series(from_yearly(0.75, pd.Timedelta(days=6)), index=INDEX), + ), + ( + lambda idx: 0.75, + pd.Series(from_yearly(0.75, pd.Timedelta(days=6)), index=INDEX), + ), + ( + lambda idx: pd.Series(10, index=idx), + pd.Series(from_yearly(10, pd.Timedelta(days=6)), index=INDEX), + ), + ( + lambda idx: np.array([0.75] * len(idx)), + pd.Series(from_yearly(0.75, pd.Timedelta(days=6)), index=INDEX), + ), + ( + lambda idx: np.array([[0.75, 0.1, 0.04]] * len(idx)), + pd.DataFrame( + { + 0: from_yearly(0.75, pd.Timedelta(days=6)), + 1: from_yearly(0.1, pd.Timedelta(days=6)), + 2: from_yearly(0.04, pd.Timedelta(days=6)), + }, + index=INDEX, + ), + ), + (lambda idx: np.array([[[0.75], [0.1], [0.04]]] * len(idx)), None), # should raise + ], +) +def test_rescale_post_processor_types( + source: Callable[[pd.Index[int]], pd.Series[float] | pd.Series[int] | pd.DataFrame], + expected: pd.Series[int] | pd.Series[float] | pd.DataFrame | None, + manager_with_step_size: ValuesManager, +) -> None: + + manager_with_step_size.register_attribute_producer( + "test", + source=source, + preferred_post_processor=rescale_post_processor, + ) + pipeline = manager_with_step_size.get_attribute_pipelines()["test"] + if expected is not None: + attributes = pipeline(INDEX) + if isinstance(expected, pd.DataFrame): + assert isinstance(attributes, 
pd.DataFrame) + pd.testing.assert_frame_equal(attributes, expected) + else: + assert isinstance(expected, pd.Series) + assert attributes.equals(expected) + else: + with pytest.raises( + DynamicValueError, + match=re.escape( + "Numpy arrays with 3 dimensions are not supported. Only 1D and 2D arrays are allowed." + ), + ): + pipeline(INDEX) + + +# Tests for union_post_processor + + +def test_union_post_processor_not_list(manager: ValuesManager) -> None: + """Test that union_post_processor raises an error when value is not a list.""" + with pytest.raises( + DynamicValueError, + match=re.escape("The union post processor requires a list of values."), + ): + union_post_processor(INDEX, 0.5, manager) # type: ignore[arg-type] + + +@pytest.mark.parametrize("invalid_value", [[0.5, "string"], [pd.Series([0.5]), None]]) +def test_union_post_processor_invalid_element_type( + invalid_value: list[Any], manager: ValuesManager +) -> None: + """Test that union_post_processor raises an error for invalid element types.""" + with pytest.raises( + DynamicValueError, + match=re.escape( + "The union post processor only supports numeric types, " + "pandas Series/DataFrames, and numpy ndarrays." + ), + ): + union_post_processor(INDEX, invalid_value, manager) + + +def test_union_post_processor_3d_array(manager: ValuesManager) -> None: + """Test that union_post_processor raises an error for 3D numpy arrays.""" + value: list[NumberLike] = [np.array([[[0.5], [0.3]], [[0.2], [0.1]]])] + with pytest.raises( + DynamicValueError, + match=re.escape( + "Numpy arrays with 3 dimensions are not supported. Only 1D and 2D arrays are allowed." 
+ ), + ): + union_post_processor(INDEX, value, manager) + + +@pytest.mark.parametrize( + "value, expected_type", + [ + ([0.5], pd.Series), + # 1D numpy array + ([np.array([0.3, 0.4, 0.5, 0.6, 0.7, 0.8])], pd.Series), + # 2D numpy array + ( + [ + np.array( + [[0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5], [0.5, 0.6], [0.6, 0.7]] + ) + ], + pd.DataFrame, + ), + # pandas Series + ([pd.Series([0.2, 0.3, 0.4, 0.5, 0.6, 0.7], index=INDEX)], pd.Series), + # pandas DataFrame + ( + [ + pd.DataFrame( + { + "a": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "b": [0.2, 0.3, 0.4, 0.5, 0.6, 0.7], + }, + index=INDEX, + ) + ], + pd.DataFrame, + ), + ], +) +def test_union_post_processor_single_element( + value: list[NumberLike], + expected_type: type[pd.Series[Any] | pd.DataFrame], + manager: ValuesManager, +) -> None: + """Test that union_post_processor returns the single element correctly formatted.""" + result = union_post_processor(INDEX, value, manager) + if expected_type is pd.DataFrame: + assert isinstance(result, pd.DataFrame) + pd.testing.assert_frame_equal(result, pd.DataFrame(value[0], index=INDEX)) # type: ignore[arg-type] + else: + assert isinstance(result, pd.Series) + pd.testing.assert_series_equal(result, pd.Series(value[0], index=INDEX)) + + +@pytest.mark.parametrize( + "value, expected_value", + [ + # Two scalars: 1 - (1-0.5)*(1-0.3) = 1 - 0.35 = 0.65 + ([0.5, 0.3], pd.Series(0.65, index=INDEX[:2])), + # Three scalars: 1 - (1-0.5)*(1-0.3)*(1-0.2) = 1 - 0.28 = 0.72 + ([0.5, 0.3, 0.2], pd.Series(0.72, index=INDEX[:2])), + # Multiple 1D arrays + ( + [np.array([0.1, 0.2]), np.array([0.2, 0.3])], + ( + pd.Series( + [1 - (1 - 0.1) * (1 - 0.2), 1 - (1 - 0.2) * (1 - 0.3)], index=INDEX[:2] + ) + ), + ), + # Multiple Series + ( + [pd.Series([0.1, 0.2], index=INDEX[:2]), pd.Series([0.3, 0.4], index=INDEX[:2])], + pd.Series( + [1 - (1 - 0.1) * (1 - 0.3), 1 - (1 - 0.2) * (1 - 0.4)], index=INDEX[:2] + ), + ), + # Multiple DataFrames + ( + [ + pd.DataFrame({"a": [0.1, 0.2], "b": [0.5, 
0.6]}, index=INDEX[:2]), + pd.DataFrame({"a": [0.3, 0.4], "b": [0.7, 0.8]}, index=INDEX[:2]), + ], + pd.DataFrame( + { + "a": [1 - (1 - 0.1) * (1 - 0.3), 1 - (1 - 0.2) * (1 - 0.4)], + "b": [1 - (1 - 0.5) * (1 - 0.7), 1 - (1 - 0.6) * (1 - 0.8)], + }, + index=INDEX[:2], + ), + ), + ], +) +def test_union_post_processor_multiple_same_type( + value: list[NumberLike], + expected_value: pd.Series[float] | pd.DataFrame, + manager: ValuesManager, +) -> None: + """Test union_post_processor with multiple elements of the same type.""" + index = INDEX[:2] + result = union_post_processor(index, value, manager) + + if isinstance(expected_value, pd.DataFrame): + # DataFrame result + assert isinstance(result, pd.DataFrame) + pd.testing.assert_frame_equal(result, expected_value) + else: + # Series result + assert isinstance(result, pd.Series) + pd.testing.assert_series_equal(result, expected_value) + + +@pytest.mark.parametrize( + "value, expected", + [ + # Scalar and 1D array: 1 - (1-0.5)*(1-[0.2, 0.3]) = [0.6, 0.65] + ( + [0.5, np.array([0.2, 0.3])], + pd.Series( + [1 - (1 - 0.5) * (1 - 0.2), 1 - (1 - 0.5) * (1 - 0.3)], index=INDEX[:2] + ), + ), + # Scalar and Series + ( + [0.4, pd.Series([0.3, 0.2], index=INDEX[:2])], + pd.Series( + [1 - (1 - 0.4) * (1 - 0.3), 1 - (1 - 0.4) * (1 - 0.2)], index=INDEX[:2] + ), + ), + # 1D array and Series + ( + [np.array([0.1, 0.2]), pd.Series([0.3, 0.4], index=INDEX[:2])], + pd.Series( + [1 - (1 - 0.1) * (1 - 0.3), 1 - (1 - 0.2) * (1 - 0.4)], index=INDEX[:2] + ), + ), + ], +) +def test_union_post_processor_mixed_types_1d( + value: list[NumberLike], expected: pd.Series[float], manager: ValuesManager +) -> None: + """Test union_post_processor with mixed types that result in 1D output.""" + result = union_post_processor(INDEX[:2], value, manager) + assert isinstance(result, pd.Series) + pd.testing.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "value, expected", + [ + # Scalar and 2D array + ( + [0.5, np.array([[0.2, 0.3], 
[0.4, 0.5]])], + pd.DataFrame( + { + 0: [1 - (1 - 0.5) * (1 - 0.2), 1 - (1 - 0.5) * (1 - 0.4)], + 1: [1 - (1 - 0.5) * (1 - 0.3), 1 - (1 - 0.5) * (1 - 0.5)], + }, + index=INDEX[:2], + ), + ), + # 2D array and DataFrame + ( + [ + np.array([[0.1, 0.2], [0.3, 0.4]]), + pd.DataFrame({"a": [0.5, 0.6], "b": [0.7, 0.8]}, index=INDEX[:2]), + ], + pd.DataFrame( + { + "a": [1 - (1 - 0.1) * (1 - 0.5), 1 - (1 - 0.3) * (1 - 0.6)], + "b": [1 - (1 - 0.2) * (1 - 0.7), 1 - (1 - 0.4) * (1 - 0.8)], + }, + index=INDEX[:2], + ), + ), + ], +) +def test_union_post_processor_mixed_types_2d( + value: list[NumberLike], expected: pd.DataFrame, manager: ValuesManager +) -> None: + """Test union_post_processor with mixed types that result in 2D output.""" + result = union_post_processor(INDEX[:2], value, manager) + assert isinstance(result, pd.DataFrame) + pd.testing.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("pipeline_type", [Pipeline, AttributePipeline]) +def test_unsourced_pipeline(pipeline_type: Pipeline) -> None: + pipeline = pipeline_type("some_name") + value_type = "attribute" if isinstance(pipeline, AttributePipeline) else "value" + assert pipeline.source.resource_id == f"missing_{value_type}_source.some_name" with pytest.raises( DynamicValueError, match=f"The dynamic value pipeline for {pipeline.name} has no source.", ): - pipeline() + pipeline(index=INDEX) + + +def test_attribute_pipeline_creation() -> None: + """Test that AttributePipeline can be created and has correct attributes.""" + pipeline = AttributePipeline("test_attribute") + assert pipeline.name == "test_attribute" + assert pipeline.resource_type == "attribute" + assert isinstance(pipeline.source, ValueSource) + assert pipeline.source.resource_id == "missing_attribute_source.test_attribute" + + +def test_attribute_pipeline_register_producer(manager: ValuesManager) -> None: + """Test registering an attribute producer through ValuesManager.""" + # Create a simple attribute source + def 
age_source(index: pd.Index[int]) -> pd.DataFrame: + return pd.DataFrame( + { + "age": [25.0, 30.0, 35.0, 40.0, 45.0][: len(index)], + "birth_year": [1999, 1994, 1989, 1984, 1979][: len(index)], + }, + index=index, + ) + + # Register the attribute producer + manager.register_attribute_producer("age", source=age_source) + pipeline = manager.get_attribute_pipelines()["age"] + + # Verify it returns an AttributePipeline + assert isinstance(pipeline, AttributePipeline) + assert pipeline.name == "age" + + # Test calling the pipeline + index = pd.Index([0, 1, 2]) + result = pipeline(index) + + assert isinstance(result, pd.DataFrame) + assert result.index.equals(index) + assert list(result.columns) == ["age", "birth_year"] + assert all(result["age"] == [25.0, 30.0, 35.0]) + assert all(result["birth_year"] == [1999, 1994, 1989]) + + +@pytest.mark.parametrize("use_postprocessor", [True, False]) +def test_attribute_pipeline_usage(use_postprocessor: bool, manager: ValuesManager) -> None: + + # Create initialized dataframe + data = pd.DataFrame({"col1": [0.0] * (max(INDEX) + 5), "col2": [0.0] * (max(INDEX) + 5)}) + + def attribute_source(index: pd.Index[int]) -> pd.DataFrame: + df = data.loc[index].copy() + df["col1"] = 1.0 + df["col2"] = 2.0 + return df + + def attribute_post_processor( + index: pd.Index[int], value: pd.DataFrame, manager: ValuesManager + ) -> pd.DataFrame: + return value * 10 + + def attribute_modifier1(index: pd.Index[int], value: pd.DataFrame) -> pd.DataFrame: + """modify col1 only""" + df = value.copy() + df["col1"] += 1.0 + return df + + def attribute_modifier2(index: pd.Index[int], value: pd.DataFrame) -> pd.DataFrame: + """modify col2 only""" + df = value.copy() + df["col2"] += 2.0 + return df + + manager.register_attribute_producer( + "test_attribute", + source=attribute_source, + preferred_post_processor=attribute_post_processor if use_postprocessor else None, + ) + pipeline = manager.get_attribute_pipelines()["test_attribute"] + 
manager.register_attribute_modifier("test_attribute", modifier=attribute_modifier1) + manager.register_attribute_modifier("test_attribute", modifier=attribute_modifier2) + + result = pipeline(INDEX) + + assert isinstance(result, pd.DataFrame) + assert result.index.equals(INDEX) + assert set(result.columns) == {"col1", "col2"} + assert all(result["col1"] == (20 if use_postprocessor else 2.0)) + assert all(result["col2"] == (40 if use_postprocessor else 4.0)) + + +def test_attribute_pipeline_raises_returns_different_index(manager: ValuesManager) -> None: + """Test than an error is raised when the index returned is different than was passed in.""" + + def bad_attribute_source(index: pd.Index[int]) -> pd.DataFrame: + index += 1 + return pd.DataFrame( + {"col1": [1.0] * len(index), "col2": [2.0] * len(index)}, index=index + ) + + manager.register_attribute_producer("test_attribute", source=bad_attribute_source) + pipeline = manager.get_attribute_pipelines()["test_attribute"] + + with pytest.raises( + DynamicValueError, + match=f"The dynamic attribute pipeline for {pipeline.name} returned a series " + "or dataframe with a different index than was passed in.", + ): + pipeline(INDEX) + + +def test_attribute_pipeline_return_types(manager: ValuesManager) -> None: + def series_attribute_source(index: pd.Index[int]) -> pd.Series[float]: + return pd.Series([1.0] * len(index), index=index) + + def dataframe_attribute_source(index: pd.Index[int]) -> pd.DataFrame: + return pd.DataFrame({"col1": [1.0] * len(index)}, index=index) + + def str_attribute_source(index: pd.Index[int]) -> str: + return "foo" + + manager.register_attribute_producer( + "test_series_attribute", source=series_attribute_source + ) + series_pipeline = manager.get_attribute_pipelines()["test_series_attribute"] + manager.register_attribute_producer( + "test_dataframe_attribute", source=dataframe_attribute_source + ) + dataframe_pipeline = manager.get_attribute_pipelines()["test_dataframe_attribute"] + 
manager.register_attribute_producer( + "test_series_attribute_with_str_source", + source=str_attribute_source, + preferred_post_processor=lambda idx, val, mgr: pd.Series(val, index=idx), + ) + series_pipeline_with_str_source = manager.get_attribute_pipelines()[ + "test_series_attribute_with_str_source" + ] + + assert isinstance(series_pipeline(INDEX), pd.Series) + assert series_pipeline(INDEX).index.equals(INDEX) + + assert isinstance(dataframe_pipeline(INDEX), pd.DataFrame) + assert dataframe_pipeline(INDEX).index.equals(INDEX) + + assert isinstance(series_pipeline_with_str_source(INDEX), pd.Series) + assert series_pipeline_with_str_source(INDEX).index.equals(INDEX) + + # Register the string source w/ no post-processors, i.e. calling will return str + manager.register_attribute_producer("test_bad_attribute", source=str_attribute_source) + bad_pipeline = manager.get_attribute_pipelines()["test_bad_attribute"] + + with pytest.raises( + DynamicValueError, + match=( + f"The dynamic attribute pipeline for {bad_pipeline.name} returned a {type('foo')} " + "but pd.Series' or pd.DataFrames are expected for attribute pipelines." 
+ ), + ): + bad_pipeline(INDEX) + + +@pytest.mark.parametrize("skip_post_processor", [True, False]) +def test_attribute_pipeline_with_post_processor( + skip_post_processor: bool, manager: ValuesManager +) -> None: + """Test that AttributePipeline works with AttributePostProcessor.""" + + # Create a source that returns a DataFrame + def attribute_source(index: pd.Index[int]) -> pd.DataFrame: + return pd.DataFrame({"value": [10.0] * len(index)}, index=index) + + # Create a post-processor that doubles values + def double_post_processor( + index: pd.Index[int], value: pd.DataFrame, manager: ValuesManager + ) -> pd.DataFrame: + result = value.copy() + result["value"] = result["value"] * 2 + return result + + manager.register_attribute_producer( + "test_attribute", + source=attribute_source, + preferred_post_processor=double_post_processor, + ) + pipeline = manager.get_attribute_pipelines()["test_attribute"] + + result = pipeline(INDEX, skip_post_processor=skip_post_processor) + + # Verify post-processor was applied + assert isinstance(result, pd.DataFrame) + assert result.index.equals(INDEX) + assert all(result["value"] == (20.0 if not skip_post_processor else 10.0)) + + +def test_get_attribute(manager: ValuesManager) -> None: + """Test that ValuesManager.get_attribute returns AttributePipeline.""" + + # Test getting an attribute that doesn't exist yet + pipeline = manager.get_attribute("test_attribute") + assert isinstance(pipeline, AttributePipeline) + assert pipeline.name == "test_attribute" + + # Test getting the same attribute again returns the same pipeline + pipeline2 = manager.get_attribute("test_attribute") + assert pipeline is pipeline2 + + +def test_duplicate_names_raise(manager: ValuesManager) -> None: + """Tests that we raise if we try to register a value and attribute producer with the same name.""" + name = "test1" + manager.register_value_producer(name, source=lambda: 1) + with pytest.raises( + DynamicValueError, + match=re.escape(f"'{name}' is already 
registered as a value pipeline."), + ): + manager.register_attribute_producer(name, source=lambda idx: pd.DataFrame()) + + # switch order + name = "test2" + manager.register_attribute_producer(name, source=lambda idx: pd.DataFrame()) + with pytest.raises( + DynamicValueError, + match=re.escape(f"'{name}' is already registered as an attribute pipeline."), + ): + manager.register_value_producer(name, source=lambda: 1) + + +@pytest.mark.parametrize( + "source, expected_return", + [ + (lambda idx: pd.Series(1.0, index=idx), pd.Series(1.0, index=INDEX)), + (["attr1", "attr2"], pd.DataFrame({"attr1": [10.0], "attr2": [20.0]}, index=INDEX)), + (["attr2"], pd.Series(20.0, index=INDEX, name="attr2")), + ], +) +def test_source_callable( + source: pd.Series[float] | list[str] | int, + expected_return: pd.Series[float] | pd.DataFrame | None, +) -> None: + """Test that the source is correctly converted to a callable if needed.""" + + class SomeComponent(Component): + def setup(self, builder: Builder) -> None: + builder.value.register_attribute_producer( + "some-attribute", + source=source, # type: ignore [arg-type] # we are testing invalid types too + ) + builder.population.register_initializer( + initializer=self.initialize_attr1_attr2, columns=["attr1", "attr2"] + ) + + def initialize_attr1_attr2(self, pop_data: SimulantData) -> None: + update = pd.DataFrame({"attr1": [10.0], "attr2": [20.0]}, index=pop_data.index) + self.population_view.update(update) + + sim = InteractiveContext(components=[SomeComponent()]) + attribute = sim.get_population("some-attribute") + assert type(attribute) == type(expected_return) + if isinstance(expected_return, pd.DataFrame) and isinstance(attribute, pd.DataFrame): + pd.testing.assert_frame_equal(attribute.loc[INDEX, :], expected_return) + elif isinstance(expected_return, pd.Series) and isinstance(attribute, pd.Series): + assert attribute[INDEX].equals(expected_return) + + +@pytest.mark.parametrize( + "source, post_processor, is_private_column, 
expected_is_simple", + [ + (["col1"], None, True, True), + (["col1"], None, False, False), + (lambda idx: pd.DataFrame({"col1": [1.0] * len(idx)}), None, False, False), + (["col1"], lambda idx, val, mgr: val * 2, True, False), + ], +) +def test_attribute_pipeline_is_simple( + source: list[str] | Callable[[pd.Index[int]], pd.DataFrame], + post_processor: AttributePostProcessor | None, + is_private_column: bool, + expected_is_simple: bool, + manager: ValuesManager, +) -> None: + """Test the is_simple property of AttributePipeline.""" + manager.register_attribute_producer( + "test_attribute", + source=source, + preferred_post_processor=post_processor, + source_is_private_column=is_private_column, + ) + pipeline = manager.get_attribute_pipelines()["test_attribute"] + assert pipeline.is_simple == expected_is_simple + manager.register_attribute_modifier("test_attribute", modifier=lambda idx, val: val + 1) + assert pipeline.is_simple is False + + +class TestConfigurePipeline: + """Test class for _configure_pipeline resource handling.""" + + @pytest.fixture + def component(self) -> Component: + return ColumnCreator() + + @pytest.fixture + def manager(self, mocker: MockerFixture, component: Component) -> ValuesManager: + manager = ValuesManager() + manager._add_resources = mocker.Mock() + manager._add_constraint = mocker.Mock() + manager.logger = mocker.Mock() + manager._get_current_component = lambda: component + return manager + + @pytest.fixture + def pipeline(self) -> AttributePipeline: + return AttributePipeline("test_pipeline") + + @pytest.fixture + def required_resources(self) -> list[Resource]: + return [Resource("test", "resource_1", None)] + + @staticmethod + def callable_source(idx: pd.Index[int]) -> pd.Series[float]: + return pd.Series(1.0, index=idx) + + def test__configure_pipeline_with_callable_source( + self, + manager: ValuesManager, + pipeline: AttributePipeline, + required_resources: list[Resource], + ) -> None: + """Test that _configure_pipeline handles 
callable source correctly.""" + manager._configure_pipeline( + pipeline=pipeline, + source=self.callable_source, + required_resources=required_resources, + ) + # Check that pipeline.set_attributes was called correctly + assert isinstance(pipeline.source, ValueSource) + assert pipeline.source._source == self.callable_source + assert pipeline.source.required_resources == required_resources + + def test__configure_pipeline_with_private_column_source( + self, + manager: ValuesManager, + pipeline: AttributePipeline, + component: Component, + required_resources: list[Resource], + ) -> None: + """Test that _configure_pipeline handles private column source correctly.""" + manager._configure_pipeline( + pipeline=pipeline, + source=["col1"], + source_is_private_column=True, + required_resources=required_resources, + ) + assert isinstance(pipeline.source, PrivateColumnValueSource) + assert pipeline.source.column.name == "col1" + assert pipeline.source.required_resources == [ + Column("col1", component), + *required_resources, + ] + + def test__configure_pipeline_with_attribute_column_source( + self, + manager: ValuesManager, + pipeline: AttributePipeline, + required_resources: list[Resource], + ) -> None: + """Test that _configure_pipeline handles attribute column source correctly.""" + manager._configure_pipeline( + pipeline=pipeline, + source=["col1", "col2"], + required_resources=required_resources, + ) + # Check that pipeline.set_attributes was called correctly + assert isinstance(pipeline.source, AttributesValueSource) + assert pipeline.source.attributes == ["col1", "col2"] + assert pipeline.source.required_resources == ["col1", "col2", *required_resources] + + @pytest.mark.parametrize( + "source, error_msg", + [ + (callable_source, "Got `source` type"), + (["col1", "col2"], "Got 2 names instead."), + ], + ) + def test__configure_pipeline_raises( + self, + mocker: MockerFixture, + source: Callable[[pd.Index[int]], pd.Series[float]] | list[str], + error_msg: str, + ) -> 
None: + manager = ValuesManager() + manager._get_current_component = mocker.Mock() + pipeline = AttributePipeline("test_callable") + with pytest.raises(ValueError, match=error_msg): + manager._configure_pipeline( + pipeline=pipeline, + source=source, + source_is_private_column=True, + ) diff --git a/tests/helpers.py b/tests/helpers.py index a24236fee..c354317c2 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -3,14 +3,12 @@ from typing import Any import pandas as pd -from layered_config_tree import ConfigurationError from vivarium import Component, Observer from vivarium.framework.engine import Builder from vivarium.framework.event import Event from vivarium.framework.lifecycle import lifecycle_states from vivarium.framework.population import SimulantData -from vivarium.framework.resource import Resource from vivarium.manager import Manager @@ -96,6 +94,9 @@ def setup(self, builder: Builder) -> None: def __eq__(self, other: Any) -> bool: return type(self) == type(other) and self.name == other.name + def __hash__(self) -> int: + return super().__hash__() + class Listener(MockComponentB): def __init__(self, *args: Any, name: str = "test_listener"): @@ -138,23 +139,114 @@ def on_simulation_end(self, event: Event) -> None: class ColumnCreator(Component): - @property - def columns_created(self) -> list[str]: - return ["test_column_1", "test_column_2", "test_column_3"] - def setup(self, builder: Builder) -> None: - builder.value.register_value_producer("pipeline_1", lambda x: x) + builder.population.register_initializer( + initializer=self.initialize_test_columns, + columns=["test_column_1", "test_column_2", "test_column_3"], + ) - def on_initialize_simulants(self, pop_data: SimulantData) -> None: + def initialize_test_columns(self, pop_data: SimulantData) -> None: self.population_view.update(self.get_initial_state(pop_data.index)) def get_initial_state(self, index: pd.Index[int]) -> pd.DataFrame: return pd.DataFrame( - {column: [i % 3 for i in index] for column 
in self.columns_created}, index=index + { + column: [i % 3 for i in index] + for column in ["test_column_1", "test_column_2", "test_column_3"] + }, + index=index, + ) + + +class SingleColumnCreator(ColumnCreator): + def setup(self, builder: Builder) -> None: + builder.population.register_initializer( + initializer=self.initialize_test_columns, columns=["test_column_1"] + ) + + def get_initial_state(self, index: pd.Index[int]) -> pd.DataFrame: + return pd.DataFrame( + {"test_column_1": [i % 3 for i in index]}, + index=index, + ) + + +class MultiLevelSingleColumnCreator(Component): + def setup(self, builder: Builder) -> None: + builder.value.register_attribute_producer( + "some_attribute", + lambda idx: pd.DataFrame({"some_column": [i % 3 for i in idx]}, index=idx), + ) + + +class MultiLevelMultiColumnCreator(Component): + def setup(self, builder: Builder) -> None: + builder.value.register_attribute_producer( + "some_attribute", + lambda idx: pd.DataFrame( + {"column_1": [i % 3 for i in idx], "column_2": [i % 3 for i in idx]}, + index=idx, + ), + ) + builder.value.register_attribute_producer( + "some_other_attribute", + lambda idx: pd.DataFrame({"column_3": [i % 3 for i in idx]}, index=idx), + ) + + +class AttributePipelineCreator(Component): + """A helper class to register different types of attribute pipelines. + + It does NOT create any private columns; use the ColumnCreator class for that. 
+ + """ + + def setup(self, builder: Builder) -> None: + + # Simple attributes + builder.value.register_attribute_producer( + "attribute_generating_columns_4_5", + lambda idx: pd.DataFrame( + { + "test_column_4": [i % 3 for i in idx], + "test_column_5": [i % 3 for i in idx], + }, + index=idx, + ), + ) + builder.value.register_attribute_producer( + "attribute_generating_column_8", + lambda idx: pd.DataFrame({"test_column_8": [i % 3 for i in idx]}, index=idx), + ) + + # Non-simple attributes + # For this test, we make them non-simple by registering a modifer that doesn't actually modify anything + builder.value.register_attribute_producer( + "test_attribute", + lambda idx: pd.Series([i % 3 for i in idx], index=idx), + ) + builder.value.register_attribute_producer( + "attribute_generating_columns_6_7", + lambda idx: pd.DataFrame( + { + "test_column_6": [i % 3 for i in idx], + "test_column_7": [i % 3 for i in idx], + }, + index=idx, + ), + ) + builder.value.register_attribute_modifier( + "test_attribute", + lambda index, series: series, + ) + builder.value.register_attribute_modifier( + "attribute_generating_columns_6_7", + lambda index, df: df, ) class LookupCreator(ColumnCreator): + CONFIGURATION_DEFAULTS = { "lookup_creator": { "data_sources": { @@ -162,26 +254,26 @@ class LookupCreator(ColumnCreator): "favorite_scalar": 0.4, "favorite_color": "simulants.favorite_color", "favorite_number": "simulants.favorite_number", + "favorite_list": [9, 4], "baking_time": "self::load_baking_time", "cooling_time": "tests.framework.components.test_component::load_cooling_time", }, - "alternate_data_sources": { - "favorite_list": [9, 4], - }, } } - def build_all_lookup_tables(self, builder: "Builder") -> None: - super().build_all_lookup_tables(builder) - if not self.configuration: - raise ConfigurationError( - "Configuration not set. This may break tests using the lookup table creator helper." 
- ) - self.lookup_tables["favorite_list"] = self.build_lookup_table( - builder, - self.configuration["alternate_data_sources"]["favorite_list"], - ["column_1", "column_2"], + def setup(self, builder: Builder) -> None: + super().setup(builder) + self.favorite_team_table = self.build_lookup_table(builder, "favorite_team") + self.favorite_scalar_table = self.build_lookup_table( + builder, "favorite_scalar", value_columns="scalar" ) + self.favorite_color_table = self.build_lookup_table(builder, "favorite_color") + self.favorite_number_table = self.build_lookup_table(builder, "favorite_number") + self.favorite_list_table = self.build_lookup_table( + builder, "favorite_list", value_columns=["column_1", "column_2"] + ) + self.baking_time_table = self.build_lookup_table(builder, "baking_time") + self.cooling_time_table = self.build_lookup_table(builder, "cooling_time") @staticmethod def load_baking_time(_builder: Builder) -> float: @@ -189,15 +281,47 @@ def load_baking_time(_builder: Builder) -> float: class SingleLookupCreator(ColumnCreator): - pass + CONFIGURATION_DEFAULTS = { + "single_lookup_creator": { + "data_sources": {"favorite_color": "simulants.favorite_color"} + } + } + + def setup(self, builder: Builder) -> None: + super().setup(builder) + self.favorite_color_table = self.build_lookup_table(builder, "favorite_color") class OrderedColumnsLookupCreator(Component): + VALUE_COLUMNS = ["one", "two", "three", "four", "five", "six", "seven"] + ORDERED_COLUMNS = pd.DataFrame( + [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14]], + columns=VALUE_COLUMNS, + ) + @property - def columns_created(self) -> list[str]: - return ["foo", "bar"] + def configuration_defaults(self) -> dict[str, Any]: + return { + self.name: { + "data_sources": { + "categorical": self._get_ordered_columns_categorical(), + "interpolated": self._get_ordered_columns_interpolated(), + }, + } + } + + def setup(self, builder: Builder) -> None: + self.categorical_table = self.build_lookup_table( + builder, 
"categorical", value_columns=self.VALUE_COLUMNS + ) + self.interpolated_table = self.build_lookup_table( + builder, "interpolated", value_columns=self.VALUE_COLUMNS + ) + builder.population.register_initializer( + initializer=self.initialize_foo_bar, columns=["foo", "bar"] + ) - def on_initialize_simulants(self, pop_data: SimulantData) -> None: + def initialize_foo_bar(self, pop_data: SimulantData) -> None: initialization_data = pd.DataFrame( { "foo": "key1", @@ -207,76 +331,34 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None: ) self.population_view.update(initialization_data) - def build_all_lookup_tables(self, builder: "Builder") -> None: - value_columns = ["one", "two", "three", "four", "five", "six", "seven"] - ordered_columns = pd.DataFrame( - [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14]], - columns=value_columns, - ) - ordered_columns_categorical = ordered_columns.copy() - ordered_columns_categorical["foo"] = ["key1", "key2"] - ordered_columns_interpolated = ordered_columns.copy() - ordered_columns_interpolated["foo"] = ["key1", "key1"] - ordered_columns_interpolated["bar_start"] = [10, 20] - ordered_columns_interpolated["bar_end"] = [20, 30] - self.lookup_tables["ordered_columns_categorical"] = self.build_lookup_table( - builder, - ordered_columns_categorical, - value_columns, - ) - self.lookup_tables["ordered_columns_interpolated"] = self.build_lookup_table( - builder, - ordered_columns_interpolated, - value_columns, - ) - + def _get_ordered_columns_categorical(self) -> pd.DataFrame: + df = self.ORDERED_COLUMNS.copy() + df["foo"] = ["key1", "key2"] + return df -class ColumnRequirer(Component): - @property - def columns_required(self) -> list[str]: - return ["test_column_1", "test_column_2"] + def _get_ordered_columns_interpolated(self) -> pd.DataFrame: + df = self.ORDERED_COLUMNS.copy() + df["foo"] = ["key1", "key1"] + df["bar_start"] = [10, 20] + df["bar_end"] = [20, 30] + return df class ColumnCreatorAndRequirer(Component): - 
@property - def columns_required(self) -> list[str]: - return ["test_column_1", "test_column_2"] - - @property - def columns_created(self) -> list[str]: - return ["test_column_4"] - - @property - def initialization_requirements(self) -> list[str | Resource]: - return ["test_column_2", self.pipeline, self.randomness] - def setup(self, builder: Builder) -> None: self.pipeline = builder.value.get_value("pipeline_1") self.randomness = builder.randomness.get_stream("stream_1") + builder.population.register_initializer( + initializer=self.initialize_test_column_4, + columns="test_column_4", + required_resources=["test_column_2", self.pipeline, self.randomness], + ) - def on_initialize_simulants(self, pop_data: SimulantData) -> None: + def initialize_test_column_4(self, pop_data: SimulantData) -> None: initialization_data = pd.DataFrame({"test_column_4": 8}, index=pop_data.index) self.population_view.update(initialization_data) -class ColumnCreatorAndAllRequirer(ColumnCreatorAndRequirer): - @property - def columns_required(self) -> list[str]: - return [] - - -class AllColumnsRequirer(Component): - @property - def columns_required(self) -> list[str]: - return [] - - -class FilteredPopulationView(ColumnRequirer): - @property - def population_view_query(self) -> str: - return "test_column_1 == 5" - - class NoPopulationView(Component): pass diff --git a/tests/interface/test_interactive.py b/tests/interface/test_interactive.py index 180b68363..232eaf6c3 100644 --- a/tests/interface/test_interactive.py +++ b/tests/interface/test_interactive.py @@ -1,6 +1,18 @@ import pandas as pd import pytest +from tests.framework.population.helpers import ( + assert_squeezing_multi_level_single_outer_multi_inner, + assert_squeezing_multi_level_single_outer_single_inner, + assert_squeezing_single_level_single_col, +) +from tests.helpers import ( + AttributePipelineCreator, + ColumnCreator, + MultiLevelMultiColumnCreator, + MultiLevelSingleColumnCreator, + SingleColumnCreator, +) from vivarium 
import InteractiveContext from vivarium.framework.values import Pipeline @@ -25,3 +37,79 @@ def test_run_for_duration() -> None: sim.run_for("5 days") assert sim._clock.time == initial_time + pd.Timedelta("15 days") # type: ignore[operator] + + +def test_get_attribute_names() -> None: + sim = InteractiveContext( + components=[MultiLevelMultiColumnCreator(), AttributePipelineCreator()] + ) + assert set(sim.get_attribute_names()) == set( + [ + # MultiLevelMultiColumnCreator attributes + "some_attribute", + "some_other_attribute", + # AttributePipelineCreator attributes + "attribute_generating_columns_4_5", + "attribute_generating_column_8", + "test_attribute", + "attribute_generating_columns_6_7", + ] + ) + # Make sure there's nothing unexpected compared to the actual population df + assert set(sim.get_attribute_names()) == set( + sim.get_population().columns.get_level_values(0) + ) + + +def test_get_population_squeezing() -> None: + + # Single-level, single-column -> series + sim = InteractiveContext(components=[SingleColumnCreator()]) + unsqueezed = sim.get_population(["test_column_1"]) + squeezed = sim.get_population("test_column_1") + assert_squeezing_single_level_single_col(unsqueezed, squeezed, "test_column_1") + default = sim.get_population() + assert isinstance(default, pd.Series) + assert isinstance(squeezed, pd.Series) + assert default.equals(squeezed) + + # Single-level, multiple-column -> dataframe + component = ColumnCreator() + sim = InteractiveContext(components=[component], setup=True) + # There's no way to request a squeezed dataframe here. 
+ df = sim.get_population(["test_column_1", "test_column_2", "test_column_3"]) + assert isinstance(df, pd.DataFrame) + assert not isinstance(df.columns, pd.MultiIndex) + default = sim.get_population() + assert default.equals(df) # type: ignore[arg-type] + + # Multi-level, single outer, single inner -> series + sim = InteractiveContext(components=[MultiLevelSingleColumnCreator()], setup=True) + unsqueezed = sim.get_population(["some_attribute"]) + squeezed = sim.get_population("some_attribute") + assert_squeezing_multi_level_single_outer_single_inner( + unsqueezed, squeezed, ("some_attribute", "some_column") + ) + default = sim.get_population() + assert isinstance(default, pd.Series) + assert isinstance(squeezed, pd.Series) + assert default.equals(squeezed) + + # Multi-level, single outer, multiple inner -> inner dataframe + sim = InteractiveContext(components=[MultiLevelMultiColumnCreator()], setup=True) + sim._population._attribute_pipelines.pop("some_other_attribute") + unsqueezed = sim.get_population(["some_attribute"]) + squeezed = sim.get_population("some_attribute") + assert_squeezing_multi_level_single_outer_multi_inner(unsqueezed, squeezed) + default = sim.get_population() + assert isinstance(default, pd.DataFrame) + assert default.equals(squeezed) + + # Multi-level, multiple outer -> full unsqueezed multi-level dataframe + sim = InteractiveContext(components=[MultiLevelMultiColumnCreator()], setup=True) + # There's no way to request a squeezed dataframe here. + df = sim.get_population(["some_attribute", "some_other_attribute"]) + assert isinstance(df, pd.DataFrame) + assert isinstance(df.columns, pd.MultiIndex) + default = sim.get_population() + assert default.equals(df) # type: ignore[arg-type]