@@ -117,6 +117,61 @@ def compute_aucroc(self, model_outputs, sample_weights=None):
117117 metric = update if metric is None else metric .merge (update )
118118 return metric .compute ()
119119
120+ def test_mse_empty (self ):
121+ """Tests the `empty` method of the `MSE` class."""
122+ m = metrax .MSE .empty ()
123+ self .assertEqual (m .total , jnp .array (0 , jnp .float32 ))
124+ self .assertEqual (m .count , jnp .array (0 , jnp .int32 ))
125+
126+ def test_rmse_empty (self ):
127+ """Tests the `empty` method of the `RMSE` class."""
128+ m = metrax .RMSE .empty ()
129+ self .assertEqual (m .total , jnp .array (0 , jnp .float32 ))
130+ self .assertEqual (m .count , jnp .array (0 , jnp .int32 ))
131+
132+ def test_rsquared_empty (self ):
133+ """Tests the `empty` method of the `RSQUARED` class."""
134+ m = metrax .RSQUARED .empty ()
135+ self .assertEqual (m .total , jnp .array (0 , jnp .float32 ))
136+ self .assertEqual (m .count , jnp .array (0 , jnp .float32 ))
137+ self .assertEqual (m .sum_of_squared_error , jnp .array (0 , jnp .float32 ))
138+ self .assertEqual (m .sum_of_squared_label , jnp .array (0 , jnp .float32 ))
139+
140+ def test_precision_empty (self ):
141+ """Tests the `empty` method of the `Precision` class."""
142+ m = metrax .Precision .empty ()
143+ self .assertEqual (m .true_positives , jnp .array (0 , jnp .float32 ))
144+ self .assertEqual (m .false_positives , jnp .array (0 , jnp .float32 ))
145+
146+ def test_recall_empty (self ):
147+ """Tests the `empty` method of the `Recall` class."""
148+ m = metrax .Recall .empty ()
149+ self .assertEqual (m .true_positives , jnp .array (0 , jnp .float32 ))
150+ self .assertEqual (m .false_negatives , jnp .array (0 , jnp .float32 ))
151+
152+ def test_aucpr_empty (self ):
153+ """Tests the `empty` method of the `AUCPR` class."""
154+ m = metrax .AUCPR .empty ()
155+ self .assertEqual (m .true_positives , jnp .array (0 , jnp .float32 ))
156+ self .assertEqual (m .false_positives , jnp .array (0 , jnp .float32 ))
157+ self .assertEqual (m .false_negatives , jnp .array (0 , jnp .float32 ))
158+ self .assertEqual (m .num_thresholds , 0 )
159+
160+ def test_aucroc_empty (self ):
161+ """Tests the `empty` method of the `AUCROC` class."""
162+ m = metrax .AUCROC .empty ()
163+ self .assertEqual (m .true_positives , jnp .array (0 , jnp .float32 ))
164+ self .assertEqual (m .true_negatives , jnp .array (0 , jnp .float32 ))
165+ self .assertEqual (m .false_positives , jnp .array (0 , jnp .float32 ))
166+ self .assertEqual (m .false_negatives , jnp .array (0 , jnp .float32 ))
167+ self .assertEqual (m .num_thresholds , 0 )
168+
169+ def test_perplexity_empty (self ):
170+ """Tests the `empty` method of the `Perplexity` class."""
171+ m = metrax .Perplexity .empty ()
172+ self .assertEqual (m .aggregate_crossentropy , jnp .array (0 , jnp .float32 ))
173+ self .assertEqual (m .num_samples , jnp .array (0 , jnp .float32 ))
174+
120175 @parameterized .named_parameters (
121176 ('basic' , OUTPUT_LABELS , OUTPUT_PREDS , 0.5 ),
122177 ('high_threshold' , OUTPUT_LABELS , OUTPUT_PREDS , 0.7 ),
0 commit comments