1212from precon ._validation import _handle_axis , _list_convert
1313from precon .index_methods import calculate_index
1414from precon .helpers import (
15- flip ,
15+ axis_slice ,
1616 axis_vals_as_frame ,
17+ flip ,
1718 subset_shared_axis ,
1819)
1920from precon .weights import reindex_weights_to_indices
@@ -69,13 +70,13 @@ def impute_base_prices(
6970 axis = _handle_axis (axis )
7071
7172 # Subset the metadata axis to match those of indices, for quicker
72- # handling of function when applied by groupby.
73+ # handling of function when applied by groupby.
7374 to_impute = subset_shared_axis (to_impute , prices , flip (axis ))
7475 if weights is not None :
7576 weights = subset_shared_axis (weights , prices , flip (axis ))
7677 if adjustments is not None :
7778 adjustments = subset_shared_axis (adjustments , prices , flip (axis ))
78-
79+
7980 # Ensure the weights are in the same shape as the prices and
8081 # exclude the prices to impute from the imputation index
8182 # calculation by setting weights to zero.
@@ -84,7 +85,12 @@ def impute_base_prices(
8485 weights = weights .mask (to_impute , 0 )
8586
8687 # Get the base prices to start with from given base period.
87- start_prices = get_base_prices (prices , base_period , axis = axis , ffill_shift = False )
88+ start_prices = get_base_prices (
89+ prices ,
90+ base_period ,
91+ axis = axis ,
92+ fill_shift = False ,
93+ )
8894 base_prices = start_prices .copy ()
8995
9096 if not shift_imputed_values :
@@ -153,27 +159,22 @@ def impute_base_prices(
153159
154160
155161def get_base_prices (
156- prices : pd .DataFrame ,
157- base_period : Union [int , Sequence [int ]] = 1 ,
158- axis : pd ._typing .Axis = 0 ,
159- ffill_shift : bool = True ,
160- ) -> pd .DataFrame :
161- """Return prices at base month with optional ffill and shift.
162-
163- Default behaviour is to fill forward values within the year and
164- shift one period, since base prices usually start being used in
165- the following period up to the next base period. Will return
166- NaNs in non-base month if ffill=False.
162+ prices : pd .DataFrame ,
163+ base_period : Union [int , Sequence [int ]] = 1 ,
164+ axis : pd ._typing .Axis = 0 ,
165+ fill_shift : bool = True ,
166+ ) -> pd .DataFrame :
167+ """Return base prices with optional fill and shift.
167168
168169 Parameters
169170 ----------
170171 prices : DataFrame
171- base_period : int, or list of ints
172- The base periods to select base prices from.
172+ base_period : int, or list of ints, defaults to 1
173+ Base period/s to select base prices from.
173174 axis : {0, 1} int, defaults to 0
174175 Fill and shift direction.
175- ffill_shift : bool, defaults to True
176- Switch to forward fill values within the year and shift by one
176+ fill_shift : bool, defaults to True
177+ Switch to forward fill base prices within year and shift one
177178 period.
178179
179180 Returns
@@ -183,35 +184,37 @@ def get_base_prices(
183184
184185 Notes
185186 -----
186- The base prices are forward filled within each year so that base
187- prices are not filled when prices have stopped being collected.
188- When shifting, base prices are shifted by one period. So for a base
189- period of Jan (int=1) base prices are shifted on to the Feb-Jan+1
190- time delta in which they apply. A base price is needed for the Jan
191- period at the start of the series, so the function fills the
192- shifted values with the unshifted values to achieve this
193-
194- TODO: Make this work for any base period.
195-
187+ When using fill_shift, the base prices are forward filled within
188+ each year so that base prices are not filled when prices have
189+ stopped being collected. Base prices are also shifted by one period,
190+ so for a base period of Jan (int=1) base prices are shifted on to
191+ the Feb-Jan+1 time delta in which they apply.
192+
196193 """
197194 base_period = _list_convert (base_period )
198-
195+
199196 # Only prices in the base periods are not NaN.
200197 months = axis_vals_as_frame (prices , axis , converter = lambda x : x .month )
201198 base_prices = prices .where (months .isin (base_period ))
202199
203- if ffill_shift :
204- # Fill base prices forward within the year and shift one.
205- return base_price_fill_shift (base_prices , axis )
206-
207- return base_prices
200+ # Ensure the prices in the first period are taken as base prices
201+ # even if not a period given by base_period parameter.
202+ first_period = axis_slice (0 , axis )
208203
204+ if not base_prices .iloc [first_period ].isna ().all ():
205+ base_prices .iloc [first_period ] = prices .iloc [first_period ]
209206
210- def base_price_fill_shift (
207+ if fill_shift :
208+ return ffill_shift (base_prices , axis )
209+ else :
210+ return base_prices
211+
212+
213+ def ffill_shift (
211214 base_prices : pd .DataFrame ,
212215 axis : int = 0
213216) -> pd .DataFrame :
214- """Fill forward base prices and shift one period.
217+ """Fill forward base prices within year and shift one period.
215218
216219 Parameters
217220 ----------
@@ -233,8 +236,6 @@ def base_price_fill_shift(
233236 period at the start of the series, so the function fills the
234237 shifted values with the unshifted values to achieve this.
235238
236- TODO: Make this work for any base period.
237-
238239 """
239240 return (
240241 base_prices .groupby (lambda x : x .year , axis = axis )
0 commit comments