22from typing import Any , Callable , Dict
33
44import pandas as pd
5+ from loguru import logger
56from vivarium import ConfigTree
67from vivarium .framework .randomness import RandomnessStream
78
@@ -53,6 +54,7 @@ class ColumnNoiseType:
5354 noise_function : Callable [[pd .Series , ConfigTree , RandomnessStream , Any ], pd .Series ]
5455 row_noise_level : float = 0.01
5556 token_noise_level : float = 0.1
57+ noise_level_scaling_function : Callable [[str ], float ] = lambda x : 1.0
5658 additional_parameters : Dict [str , Any ] = None
5759
5860 def __call__ (
@@ -62,18 +64,27 @@ def __call__(
6264 randomness_stream : RandomnessStream ,
6365 additional_key : Any ,
6466 ) -> pd .Series :
65- # TODO: this is a temporary hack to account for all string columns having been made categorical
66- # We should record expected output dtype in the columns data structure
67- if column .dtype .name == "category" :
68- column = column .astype (str )
69- else :
70- column = column .copy ()
71- noise_level = configuration .row_noise_level
67+ column = column .copy ()
68+ noise_level = configuration .row_noise_level * self .noise_level_scaling_function (
69+ column .name
70+ )
7271 to_noise_idx = get_index_to_noise (
7372 column , noise_level , randomness_stream , f"{ self .name } _{ additional_key } "
7473 )
74+ if to_noise_idx .empty :
75+ logger .debug (
76+ f"No cells chosen to noise for noise function { self .name } on column { column .name } . "
77+ "This is likely due to a combination of the configuration noise levels and the input data."
78+ )
79+ return column
7580 noised_data = self .noise_function (
7681 column .loc [to_noise_idx ], configuration , randomness_stream , additional_key
7782 )
83+
84+ # Coerce noised column dtype back to original column's if it has changed
85+ if noised_data .dtype .name != column .dtype .name :
86+ noised_data = noised_data .astype (column .dtype )
87+
7888 column .loc [to_noise_idx ] = noised_data
89+
7990 return column
0 commit comments