@@ -245,6 +245,93 @@ TEST_CASE("conditional_sampling") {
245245 }
246246 }
247247 }
248+ SECTION (" count_conditional" ) {
249+ const auto comp_val = 0.001 ;
250+
251+ // Test count conditional sampling - count grid points where condition is met
252+ params.clear ();
253+ params.set (" grid_name" , grid->name ());
254+ params.set <std::string>(" input_field" , " count" );
255+ params.set <std::string>(" condition_field" , " qc" );
256+ params.set <std::string>(" condition_operator" , " gt" );
257+ params.set <std::string>(" condition_value" , std::to_string (comp_val));
258+
259+ // Set time for qc and randomize its values
260+ qc11.get_header ().get_tracking ().update_time_stamp (t0);
261+ qc12.get_header ().get_tracking ().update_time_stamp (t0);
262+ qc21.get_header ().get_tracking ().update_time_stamp (t0);
263+ randomize (qc11, engine, pdf);
264+ randomize (qc12, engine, pdf);
265+ randomize (qc21, engine, pdf);
266+
267+ // Create and set up the diagnostic for count
268+ auto count_diag11 = diag_factory.create (" ConditionalSampling" , comm, params);
269+ auto count_diag12 = diag_factory.create (" ConditionalSampling" , comm, params);
270+ auto count_diag21 = diag_factory.create (" ConditionalSampling" , comm, params);
271+ count_diag11->set_grids (gm);
272+ count_diag12->set_grids (gm);
273+ count_diag21->set_grids (gm);
274+
275+ // Set the fields for each diagnostic
276+ count_diag11->set_required_field (qc11);
277+ count_diag11->initialize (t0, RunType::Initial);
278+ count_diag11->compute_diagnostic ();
279+ auto count_diag11_f = count_diag11->get_diagnostic ();
280+ count_diag11_f.sync_to_host ();
281+ auto count_diag11_v = count_diag11_f.get_view <const Real *, Host>();
282+
283+ count_diag12->set_required_field (qc12);
284+ count_diag12->initialize (t0, RunType::Initial);
285+ count_diag12->compute_diagnostic ();
286+ auto count_diag12_f = count_diag12->get_diagnostic ();
287+ count_diag12_f.sync_to_host ();
288+ auto count_diag12_v = count_diag12_f.get_view <const Real *, Host>();
289+
290+ count_diag21->set_required_field (qc21);
291+ count_diag21->initialize (t0, RunType::Initial);
292+ count_diag21->compute_diagnostic ();
293+ auto count_diag21_f = count_diag21->get_diagnostic ();
294+ count_diag21_f.sync_to_host ();
295+ auto count_diag21_v = count_diag21_f.get_view <const Real **, Host>();
296+
297+ auto qc11_v = qc11.get_view <const Real *, Host>();
298+ auto qc12_v = qc12.get_view <const Real *, Host>();
299+ auto qc21_v = qc21.get_view <const Real **, Host>();
300+
301+ // Check the results - count should be 1.0 where condition is met, fill_value otherwise
302+ for (int ilev = 0 ; ilev < nlevs; ++ilev) {
303+ // check count for qc12
304+ if (qc12_v (ilev) > comp_val) {
305+ REQUIRE (count_diag12_v (ilev) == 1.0 );
306+ } else {
307+ REQUIRE (count_diag12_v (ilev) == fill_value);
308+ }
309+ }
310+
311+ for (int icol = 0 ; icol < ngcols; ++icol) {
312+ // Check count for qc11
313+ if (qc11_v (icol) > comp_val) {
314+ REQUIRE (count_diag11_v (icol) == 1.0 );
315+ } else {
316+ REQUIRE (count_diag11_v (icol) == fill_value);
317+ }
318+
319+ for (int ilev = 0 ; ilev < nlevs; ++ilev) {
320+ // check count for qc21
321+ if (qc21_v (icol, ilev) > comp_val) {
322+ REQUIRE (count_diag21_v (icol, ilev) == 1.0 );
323+ } else {
324+ REQUIRE (count_diag21_v (icol, ilev) == fill_value);
325+ }
326+ // check count again, but the negative
327+ if (qc21_v (icol, ilev) <= comp_val) {
328+ REQUIRE_FALSE (count_diag21_v (icol, ilev) == 1.0 );
329+ } else {
330+ REQUIRE_FALSE (count_diag21_v (icol, ilev) == fill_value);
331+ }
332+ }
333+ }
334+ }
248335}
249336
250337} // namespace scream
0 commit comments