@@ -140,7 +140,7 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
140140 _sentinel = f"__clinops_pos_{ uuid .uuid4 ().hex } __"
141141 try :
142142 df [_sentinel ] = np .arange (len (df ))
143- df = self ._fill_with_gap_mask (df , numeric_cols , forward = True )
143+ df = self ._fill_with_gap_mask (df , numeric_cols , self . time_col , forward = True )
144144 df = df .sort_values (_sentinel ).reset_index (drop = True )
145145 finally :
146146 df = df .drop (columns = [_sentinel ], errors = "ignore" )
@@ -158,7 +158,7 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
158158 _sentinel = f"__clinops_pos_{ uuid .uuid4 ().hex } __"
159159 try :
160160 df [_sentinel ] = np .arange (len (df ))
161- df = self ._fill_with_gap_mask (df , numeric_cols , forward = False )
161+ df = self ._fill_with_gap_mask (df , numeric_cols , self . time_col , forward = False )
162162 df = df .sort_values (_sentinel ).reset_index (drop = True )
163163 finally :
164164 df = df .drop (columns = [_sentinel ], errors = "ignore" )
@@ -201,7 +201,7 @@ def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
201201 # ------------------------------------------------------------------
202202
203203 def _fill_with_gap_mask (
204- self , df : pd .DataFrame , numeric_cols : list [str ], forward : bool
204+ self , df : pd .DataFrame , numeric_cols : list [str ], time_col : str , forward : bool
205205 ) -> pd .DataFrame :
206206 """
207207 Apply ffill/bfill with gap masking.
@@ -210,11 +210,10 @@ def _fill_with_gap_mask(
210210 entity group, preventing values from propagating across entity
211211 boundaries (e.g. across patients or admissions).
212212 """
213- assert self .time_col is not None # callers guard this
214213 if self .id_col and self .id_col in df .columns :
215214 parts = []
216215 for _ , grp in df .groupby (self .id_col , sort = False ):
217- grp = grp .sort_values (self . time_col )
216+ grp = grp .sort_values (time_col )
218217 original_nulls = grp [numeric_cols ].isna ()
219218 grp [numeric_cols ] = (
220219 grp [numeric_cols ].ffill () if forward else grp [numeric_cols ].bfill ()
@@ -225,7 +224,7 @@ def _fill_with_gap_mask(
225224 parts .append (grp )
226225 return pd .concat (parts )
227226 else :
228- df = df .sort_values (self . time_col )
227+ df = df .sort_values (time_col )
229228 original_nulls = df [numeric_cols ].isna ()
230229 df [numeric_cols ] = df [numeric_cols ].ffill () if forward else df [numeric_cols ].bfill ()
231230 return self ._mask_large_gaps (
0 commit comments