@@ -275,21 +275,26 @@ def test_stream_rate_conversion_config(
275275 assert sim ._randomness ._rate_conversion_type == rate_conversion
276276
277277
278- def test_filter_for_probability_error_with_null_values () -> None :
278+ @pytest .mark .parametrize (
279+ "probs" ,
280+ [
281+ [0.3 , np .nan , 0.5 , np .nan , 0.7 ],
282+ np .nan ,
283+ pd .Series ([0.2 , 0.3 , np .nan , 0.4 , 0.5 ]),
284+ ],
285+ )
286+ def test_filter_for_probability_error_with_null_values (
287+ probs : float | list [float ] | pd .Series [float ],
288+ ) -> None :
279289 randomness_stream = RandomnessStream (
280290 "test" , lambda : pd .Timestamp (2020 , 1 , 1 ), 1 , IndexMap ()
281291 )
282292 pop = pd .DataFrame ({"age" : [10 , 11 , 12 , 13 , 14 ], "id" : [1 , 2 , 3 , 4 , 5 ]}).set_index ("id" )
283293 with pytest .raises (ValueError , match = "Probabilities contain null values" ):
284- randomness_stream .filter_for_probability (pop , [ 0.3 , np . nan , 0.5 , np . nan , 0.7 ] )
294+ randomness_stream .filter_for_probability (pop , probs )
285295
286296 with pytest .raises (ValueError , match = "Probabilities contain null values" ):
287- randomness_stream .filter_for_probability (pop , np . nan )
297+ randomness_stream .filter_for_probability (pop , probs )
288298
289299 with pytest .raises (ValueError , match = "Probabilities contain null values" ):
290- randomness_stream .filter_for_probability (
291- pop , pd .Series ([0.2 , 0.3 , np .nan , 0.4 , 0.5 ], index = pop .index )
292- )
293-
294- # Doesn't raise
295- randomness_stream .filter_for_probability (pop , [0.3 , 0.4 , 0.5 , 0.6 , 0.7 ])
300+ randomness_stream .filter_for_probability (pop , probs )
0 commit comments