77Array2D = NDArray [np .float64 ]
88
99def calculate_scaling_factors (X_sparse : Array1D | Array2D ) -> Array1D :
10+ """
11+ Calculate row-based scaling factors from the sparse vote matrix.
12+
13+ (Outside estimator so available for re-use.)
14+ """
1015 # This allows the function to work for both 2D (full vote_matrix) and 1D (participant_votes) input.
1116 # It essentially nests a 1D array inside a 2D one.
1217 X_sparse = np .atleast_2d (X_sparse )
@@ -29,14 +34,16 @@ def calculate_scaling_factors(X_sparse: Array1D | Array2D) -> Array1D:
2934
3035class SparsityAwareScaler (BaseEstimator , TransformerMixin ):
3136 """
32- Scale projected points (participant/ statements) based on sparsity of vote
37+ Scale projected points (participant or statements) based on sparsity of vote
3338 matrix, to account for any small number of votes by a participant and
3439 prevent those participants from bunching up in the center.
3540
3641 Attributes:
42+ capture_step (str | int | None): Name or index of the capture step in the pipeline.
3743 X_sparse (np.ndarray | None): A sparse array with shape (n_features,)
3844 """
def __init__(
    self,
    capture_step: Optional[str | int] = None,
    X_sparse: Optional[Array1D | Array2D] = None,
):
    """
    Keep the constructor arguments on the instance exactly as they were given.

    Args:
        capture_step (str | int | None): Name or index of the capture step in the pipeline.
        X_sparse (np.ndarray | None): Sparse vote data passed in directly, if any.
    """
    # Plain assignment only: no validation or conversion happens at construction time.
    self.capture_step = capture_step
    self.X_sparse = X_sparse
4148
4249 # See: https://scikit-learn.org/stable/modules/generated/sklearn.utils.Tags.html#sklearn.utils.Tags
@@ -57,10 +64,53 @@ def inverse_transform(self, X):
5764 scaling_factors = self ._calculate_scaling_factors ()
5865 return X / scaling_factors [:, np .newaxis ]
5966
60- def _calculate_scaling_factors (self ):
61- if self .X_sparse is None :
67+
def _get_pipeline_step(self, step):
    """
    Fetch a step from the parent pipeline, which is looked up on `self._parent_pipeline`.

    Args:
        step (str | int): Step name (looked up in the parent's `named_steps`) or
            positional index (looked up in the parent's `steps`).

    Returns:
        The object stored for that step in the parent pipeline.

    Raises:
        RuntimeError: If this instance has no `_parent_pipeline` (or it is None).
        ValueError: If `step` is neither a string nor an integer index.
    """
    # NOTE(review): `_parent_pipeline` is presumably attached by `PatchedPipeline`
    # (per the error message below) -- confirm against that class.
    parent = getattr(self, "_parent_pipeline", None)
    if parent is None:
        raise RuntimeError(
            f"{self.__class__.__name__} cannot resolve `capture_step={step}` "
            "because it is not being used inside a `PatchedPipeline`. "
            "Either use a `PatchedPipeline` or pass `X_sparse` directly."
        )

    if isinstance(step, str):
        return parent.named_steps[step]

    # bool is a subclass of int; reject it explicitly so that True/False are
    # not silently interpreted as step indices 1/0.
    if isinstance(step, int) and not isinstance(step, bool):
        return parent.steps[step][1]

    raise ValueError("`capture_step` must be a string (name) or int (index).")
85+
def _resolve_X_sparse(self):
    """
    Return the sparse vote matrix to use.

    An explicitly supplied `X_sparse` wins; otherwise the value is read from the
    `X_captured_` attribute of the pipeline step identified by `capture_step`.

    Raises:
        AttributeError: If the capture step has no `X_captured_` attribute.
    """
    explicit = self.X_sparse
    if explicit is not None:
        return explicit

    capture_step_obj = self._get_pipeline_step(self.capture_step)

    # A private sentinel distinguishes "attribute absent" from any stored value.
    missing = object()
    captured = getattr(capture_step_obj, "X_captured_", missing)
    if captured is missing:
        raise AttributeError(
            f"Step '{self.capture_step}' does not contain `.X_captured_`. "
            f"Did you run `fit/transform` on the pipeline?"
        )
    return captured
65100
66- return calculate_scaling_factors (X_sparse = self .X_sparse )
def _calculate_scaling_factors(self):
    """
    Compute scaling factors from the sparse vote matrix resolved for this instance.
    """
    # Resolution (explicit argument vs. captured step) happens in `_resolve_X_sparse`;
    # the arithmetic itself lives in the module-level `calculate_scaling_factors`.
    return calculate_scaling_factors(X_sparse=self._resolve_X_sparse())
104+
class SparsityAwareCapturer(BaseEstimator, TransformerMixin):
    """
    Passthrough transformer that remembers the X most recently given to `transform`.

    The input is kept on `self.X_captured_` and returned unchanged, which lets a
    later pipeline step look up this intermediate result.
    """

    def fit(self, X, y=None):
        """Do nothing and return the instance itself."""
        return self

    def transform(self, X):
        """Record X on `self.X_captured_` and return the same object."""
        # The reference to the input itself is stored; no copy is made.
        captured = X
        self.X_captured_ = captured
        return captured
0 commit comments