Skip to content

Commit 21d2acc

Browse files
authored
Merge pull request #61 from UDST/large_mnl_sim_w_interactions
latest updates to segmented MNL
2 parents 9baf3e4 + 561aaa2 commit 21d2acc

File tree

5 files changed

+119
-12
lines changed

5 files changed

+119
-12
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name='urbansim_templates',
5-
version='0.1.dev16',
5+
version='0.1.dev17',
66
description='UrbanSim extension for managing model steps',
77
author='UrbanSim Inc.',
88
author_email='[email protected]',

urbansim_templates/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
version = __version__ = '0.1.dev16'
1+
version = __version__ = '0.1.dev17'

urbansim_templates/models/large_multinomial_logit.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -446,13 +446,34 @@ def fit(self, mct=None):
446446
self.mergedchoicetable = mct
447447

448448

449-
def run(self):
449+
def run(self, chooser_batch_size=None, interaction_terms=None):
450450
"""
451451
Run the model step: simulate choices and use them to update an Orca column.
452452
453453
The simulated choices are saved to the class object for diagnostics ('choices').
454454
If choices are unconstrained, the probabilities of sampled alternatives are saved
455455
as well ('probabilities').
456+
457+
Parameters
458+
----------
459+
chooser_batch_size : int, optional
460+
This parameter gets passed to
461+
choicemodels.tools.simulation.iterative_lottery_choices and is a temporary
462+
workaround for dealing with memory issues that arise from generating massive
463+
merged choice tables for simulations that involve large numbers of choosers,
464+
large numbers of alternatives, and large numbers of predictors. It allows the
465+
user to specify a batch size for simulating choices one chunk at a time.
466+
467+
interaction_terms : pandas.Series, pandas.DataFrame, or list of either, optional
468+
Additional column(s) of interaction terms whose values depend on the combination
469+
of observation and alternative, to be merged onto the final data table. If passed
470+
as a Series or DataFrame, it should include a two-level MultiIndex. One level's
471+
name and values should match an index or column from the observations table, and
472+
the other should match an index or column from the alternatives table.
473+
474+
Returns
475+
-------
476+
None
456477
457478
"""
458479
obs = self._get_df(tables=self.out_choosers, fallback_tables=self.choosers,
@@ -465,15 +486,18 @@ def run(self):
465486
fitted_parameters = self.fitted_parameters)
466487

467488
def mct(obs, alts):
468-
return MergedChoiceTable(obs, alts, sample_size=self.alt_sample_size)
489+
return MergedChoiceTable(
490+
obs, alts, sample_size=self.alt_sample_size,
491+
interaction_terms=interaction_terms)
469492

470493
def probs(mct):
471494
return model.probabilities(mct)
472495

473496
if (self.constrained_choices == True):
474497
choices = iterative_lottery_choices(obs, alts, mct_callable=mct,
475498
probs_callable=probs, alt_capacity=self.alt_capacity,
476-
chooser_size=self.chooser_size, max_iter=self.max_iter)
499+
chooser_size=self.chooser_size, max_iter=self.max_iter,
500+
chooser_batch_size=chooser_batch_size)
477501

478502
else:
479503
probs = probs(mct(obs, alts))

urbansim_templates/models/segmented_large_multinomial_logit.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
from ..utils import update_name
1313
from .. import modelmanager
1414
from . import LargeMultinomialLogitStep
15+
from .shared import TemplateStep
1516

1617

1718
@modelmanager.template
18-
class SegmentedLargeMultinomialLogitStep():
19+
class SegmentedLargeMultinomialLogitStep(TemplateStep):
1920
"""
2021
This template automatically generates a set of LargeMultinomialLogitStep submodels
2122
corresponding to "segments" or categories of choosers. The submodels can be directly
@@ -132,13 +133,15 @@ def get_segmentation_column(self):
132133
# TO DO - this doesn't filter for columns in the model expression; is there
133134
# centralized functionality for this merge that we should be using instead?
134135

135-
obs = orca.get_table(self.defaults.choosers).to_frame()
136-
obs = apply_filter_query(obs, self.defaults.chooser_filters)
137-
138-
alts = orca.get_table(self.defaults.alternatives).to_frame()
139-
alts = apply_filter_query(alts, self.defaults.alt_filters)
136+
obs = self._get_df(
137+
tables=self.defaults.choosers,
138+
filters=self.defaults.chooser_filters)
139+
140+
alts = self._get_df(
141+
tables=self.defaults.alternatives,
142+
filters=self.defaults.alt_filters)
140143

141-
df = pd.merge(obs, alts, how='inner',
144+
df = pd.merge(obs, alts, how='inner',
142145
left_on=self.defaults.choice_column, right_index=True)
143146

144147
return df[self.segmentation_column]
@@ -222,6 +225,8 @@ def fit_all(self):
222225
self.build_submodels()
223226

224227
for k, m in self.submodels.items():
228+
print(' SEGMENT: {0} = {1} '.format(
229+
self.segmentation_column, str(k)).center(70, '#'))
225230
m.fit()
226231

227232
self.name = update_name(self.template, self.name)

urbansim_templates/tests/test_segmented_large_multinomial_logit.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,51 @@ def orca_session():
3232
orca.add_table('alts', alts)
3333

3434

35+
@pytest.fixture
36+
def orca_session_alts_as_list():
37+
"""
38+
Set up an Orca session with a chooser table and two alternatives tables linked by a broadcast.
39+
40+
"""
41+
d1 = {'oid': np.arange(100),
42+
'group': np.random.choice(['A', 'B', 'C'], size=100),
43+
'int_group': np.random.choice([3, 4], size=100),
44+
'obsval': np.random.random(100),
45+
'choice': np.random.choice(np.arange(20), size=100)}
46+
47+
d2 = {'aid': np.arange(20),
48+
'altval': np.random.random(20)}
49+
50+
d3 = {'aid': np.arange(20),
51+
'altval_2': np.random.random(20)}
52+
53+
obs = pd.DataFrame(d1).set_index('oid')
54+
orca.add_table('obs', obs)
55+
56+
d2_df = pd.DataFrame(d2).set_index('aid')
57+
orca.add_table('d2', d2_df)
58+
59+
d3_df = pd.DataFrame(d3).set_index('aid')
60+
orca.add_table('d3', d3_df)
61+
62+
orca.broadcast('d3', 'd2', cast_index=True, onto_index=True)
63+
64+
65+
@pytest.fixture
66+
def m_alts_as_list(orca_session_alts_as_list):
67+
"""
68+
Set up a partially configured model step with multiple
69+
tables of alternatives (specified as a list of table names).
70+
"""
71+
m = SegmentedLargeMultinomialLogitStep()
72+
m.defaults.choosers = 'obs'
73+
m.defaults.alternatives = ['d2', 'd3']
74+
m.defaults.choice_column = 'choice'
75+
m.defaults.model_expression = 'obsval + altval + altval_2'
76+
m.segmentation_column = 'group'
77+
return m
78+
79+
3580
@pytest.fixture
3681
def m(orca_session):
3782
"""
@@ -55,6 +100,25 @@ def test_template_validity():
55100
assert validate_template(SegmentedLargeMultinomialLogitStep)
56101

57102

103+
def test_basic_operation(m):
104+
"""
105+
Test basic operation of the template.
106+
107+
"""
108+
m.fit_all()
109+
m.to_dict()
110+
assert len(m.submodels) == 3
111+
112+
def test_basic_operation_alts_as_list(m_alts_as_list):
113+
"""
114+
Test basic operation of the template when alternatives are specified as a list of tables.
115+
116+
"""
117+
m = m_alts_as_list
118+
m.fit_all()
119+
m.to_dict()
120+
assert len(m.submodels) == 3
121+
58122
def test_basic_operation(m):
59123
"""
60124
Test basic operation of the template.
@@ -103,6 +167,20 @@ def test_alternative_filters(m):
103167
assert len1 == len2
104168

105169

170+
def test_alternative_filters_for_alts_as_list(m_alts_as_list):
171+
"""
172+
Test that the default alternative filters generate the correct data subset when alternatives are specified as a list of tables.
173+
174+
"""
175+
m = m_alts_as_list
176+
m.defaults.alt_filters = 'altval_2 < 0.5'
177+
178+
m.build_submodels()
179+
for k, v in m.submodels.items():
180+
alts = v._get_df(tables = v.alternatives, filters = v.alt_filters)
181+
assert alts['altval_2'].max() < 0.5
182+
183+
106184
def test_submodel_filters(m):
107185
"""
108186
Test that submodel filters generate the correct data subset.

0 commit comments

Comments
 (0)