@@ -114,12 +114,6 @@ def __init__(self, observations, alternatives, chosen_alternatives=None,
114114 raise ValueError ("Cannot sample without replacement with sample_size {} "
115115 "and n_alts {}" .format (sample_size , alternatives .shape [0 ]))
116116
117- if (observations .index .name == None ):
118- observations .index .name = 'obs_id'
119-
120- if (alternatives .index .name == None ):
121- alternatives .index .name = 'alt_id'
122-
123117 # TO DO - check that dfs have unique indexes
124118 # TO DO - check that chosen_alternatives correspond correctly to other dfs
125119 # TO DO - same with weights (could join onto other tables and then split off)
@@ -130,14 +124,25 @@ def __init__(self, observations, alternatives, chosen_alternatives=None,
130124 observations = observations .drop (chosen_alternatives .name , axis = 'columns' )
131125 chosen_alternatives .name = '_' + alternatives .index .name # avoids conflicts
132126
133- # Check for duplicate column names
134- obs_cols = list (observations .columns ) + list (observations .index .names )
135- alt_cols = list (alternatives .columns ) + list (alternatives .index .names )
136- dupes = set (obs_cols ) & set (alt_cols )
127+ # Allow missing obs and alts, to support .from_df() constructor
128+ if (observations is not None ):
129+
130+ # Provide default names for observation and alternatives id's
131+
132+ if (observations .index .name == None ):
133+ observations .index .name = 'obs_id'
134+
135+ if (alternatives .index .name == None ):
136+ alternatives .index .name = 'alt_id'
137+
138+ # Check for duplicate column names
139+ obs_cols = list (observations .columns ) + list (observations .index .names )
140+ alt_cols = list (alternatives .columns ) + list (alternatives .index .names )
141+ dupes = set (obs_cols ) & set (alt_cols )
137142
138- if len (dupes ) > 0 :
139- raise ValueError ("Both input tables contain column {}. Please ensure "
140- "column names are unique before merging" .format (dupes ))
143+ if len (dupes ) > 0 :
144+ raise ValueError ("Both input tables contain column {}. Please ensure "
145+ "column names are unique before merging" .format (dupes ))
141146
142147 # Normalize weights to a pd.Series
143148 if (weights is not None ) & isinstance (weights , str ):
@@ -172,17 +177,48 @@ def __init__(self, observations, alternatives, chosen_alternatives=None,
172177 self .weights_2d = weights_2d
173178
174179 # Build choice table...
180+ # Allow missing obs and alts, to support .from_df() constructor
181+ if (observations is not None ):
175182
176- if (len (observations ) == 0 ) or (len (alternatives ) == 0 ):
177- self ._merged_table = pd .DataFrame ()
183+ if (len (observations ) == 0 ) or (len (alternatives ) == 0 ):
184+ self ._merged_table = pd .DataFrame ()
178185
179- elif (sample_size is None ):
180- self ._merged_table = self ._build_table_without_sampling ()
186+ elif (sample_size is None ):
187+ self ._merged_table = self ._build_table_without_sampling ()
181188
182- else :
183- self ._merged_table = self ._build_table ()
189+ else :
190+ self ._merged_table = self ._build_table ()
184191
185192
193+ @classmethod
194+ def from_df (cls , df ):
195+ """
196+ Create a MergedChoiceTable instance from a pre-generated DataFrame.
197+
198+ Each chooser's rows should be contiguous. If applicable, the chosen alternative
199+ should be listed first. This ordering is used by MergedChoiceTable.to_frame(),
200+ and appears to be an undocumented requirement of the legacy MNL code.
201+
202+ Parameters
203+ ----------
204+ df : pandas.DataFrame
205+ Table with a two-level MultiIndex where the first level corresponds to the
206+ index of the observations and the second to the index of the alternatives.
207+ May include a binary column named 'chosen' indicating observed choices.
208+
209+ Returns
210+ -------
211+ MergedChoiceTable
212+
213+ """
214+ obj = cls (observations = None , alternatives = None )
215+ obj ._merged_table = df
216+
217+ # TO DO: sort the dataframe so that rows are automatically in a consistent order
218+
219+ return obj
220+
221+
186222 def _merge_interaction_terms (self , df ):
187223 """
188224 Merges interaction terms (if they exist) onto the input DataFrame.
@@ -436,7 +472,7 @@ def observation_id_col(self):
436472 str
437473
438474 """
439- return self .observations .index .name
475+ return self ._merged_table .index .names [ 0 ]
440476
441477
442478 @property
@@ -450,7 +486,7 @@ def alternative_id_col(self):
450486 str
451487
452488 """
453- return self .alternatives .index .name
489+ return self ._merged_table .index .names [ 1 ]
454490
455491
456492 @property
@@ -464,7 +500,7 @@ def choice_col(self):
464500 str or None
465501
466502 """
467- if (self . chosen_alternatives is not None ):
503+ if ('chosen' in self . _merged_table . columns ):
468504 return 'chosen'
469505
470506 else :
0 commit comments