@@ -129,7 +129,7 @@ class UnitScaler(_Transformer[DType]):
129
129
130
130
lower_vectorized : f64 = field (init = False )
131
131
upper_vectorized : f64 = field (init = False )
132
- _scale_vec_to_int : DType = field (init = False )
132
+ _size : DType = field (init = False )
133
133
134
134
def __post_init__ (self ) -> None :
135
135
if self .lower_value >= self .upper_value :
@@ -147,7 +147,7 @@ def __post_init__(self) -> None:
147
147
self .lower_vectorized = f64 (0 )
148
148
self .upper_vectorized = f64 (1 )
149
149
150
- self ._scale_vec_to_int = self .upper_value - self .lower_value
150
+ self ._size = self .upper_value - self .lower_value
151
151
152
152
def to_value (self , vector : Array [f64 ]) -> Array [DType ]:
153
153
"""Transform a value from the unit interval to the range."""
@@ -169,9 +169,7 @@ def _unsafe_to_value_single(self, vector: f64) -> f64:
169
169
scaled = vector * (_u - _l ) + _l
170
170
return np .exp (scaled ) # type: ignore
171
171
172
- _l = self .lower_value
173
- _u = self .upper_value
174
- return vector * (_u - _l ) + _l # type: ignore
172
+ return vector * self ._size + _l # type: ignore
175
173
176
174
def _unsafe_to_value (self , vector : Array [f64 ]) -> Array [f64 ]:
177
175
# NOTE: Unsafe as it does not check boundaries, clip or integer'ness
@@ -203,7 +201,7 @@ def vectorize_size(self, size: f64) -> f64:
203
201
/ (np .log (self .upper_value ) - np .log (self .lower_value )),
204
202
)
205
203
206
- return f64 (size / ( self .upper_value - self . lower_value ) )
204
+ return f64 (size / self ._size )
207
205
208
206
def legal_value (self , value : Array [Any ]) -> Mask :
209
207
# If we have a non numeric dtype, we have to unfortunatly go through but by bit
@@ -253,7 +251,10 @@ def legal_vector_single(self, vector: np.number) -> bool:
253
251
254
252
if not self .log :
255
253
inbounds = bool (self .lower_vectorized <= vector <= self .upper_vectorized )
256
- scaled = vector * self ._scale_vec_to_int
254
+ scaled = vector * self ._size
255
+ print ("SCALED" , scaled )
256
+ print ("is_close" , is_close_to_integer_single (scaled , atol = ATOL ))
257
+ print ("inbounds" , inbounds )
257
258
return bool (is_close_to_integer_single (scaled , atol = ATOL ) and inbounds )
258
259
259
260
value = self ._unsafe_to_value_single (vector ) # type: ignore
0 commit comments