11#include < chrono>
22#include < cmath>
33#include < cstring>
4+ #include < functional>
45#include < future>
56#include < iostream>
67#include < map>
4041
4142namespace tiledbpy {
4243
43- using namespace std ;
4444using namespace tiledb ;
4545namespace py = pybind11;
4646using namespace pybind11 ::literals;
@@ -297,18 +297,260 @@ uint64_t count_zeros(py::array_t<uint8_t> a) {
297297 return count;
298298}
299299
300+ class PyAgg {
301+
302+ using ByteBuffer = py::array_t <uint8_t >;
303+ using AggToBufferMap = std::map<std::string, ByteBuffer>;
304+ using AttrToAggsMap = std::map<std::string, AggToBufferMap>;
305+
306+ private:
307+ Context ctx_;
308+ std::shared_ptr<tiledb::Array> array_;
309+ std::shared_ptr<tiledb::Query> query_;
310+ AttrToAggsMap result_buffers_;
311+ AttrToAggsMap validity_buffers_;
312+
313+ py::dict original_input_;
314+ std::vector<std::string> attrs_;
315+
316+ public:
317+ PyAgg () = delete ;
318+
319+ PyAgg (const Context &ctx, py::object py_array, py::object py_layout,
320+ py::dict attr_to_aggs_input)
321+ : ctx_(ctx), original_input_(attr_to_aggs_input) {
322+ tiledb_array_t *c_array_ = (py::capsule)py_array.attr (" __capsule__" )();
323+
324+ // We never own this pointer; pass own=false
325+ array_ = std::make_shared<tiledb::Array>(ctx_, c_array_, false );
326+ query_ = std::make_shared<tiledb::Query>(ctx_, *array_, TILEDB_READ);
327+
328+ bool issparse = array_->schema ().array_type () == TILEDB_SPARSE;
329+ tiledb_layout_t layout = (tiledb_layout_t )py_layout.cast <int32_t >();
330+ if (!issparse && layout == TILEDB_UNORDERED) {
331+ TPY_ERROR_LOC (" TILEDB_UNORDERED read is not supported for dense arrays" )
332+ }
333+ query_->set_layout (layout);
334+
335+ // Iterate through the requested attributes
336+ for (auto attr_to_aggs : attr_to_aggs_input) {
337+ auto attr_name = attr_to_aggs.first .cast <std::string>();
338+ auto aggs = attr_to_aggs.second .cast <std::vector<std::string>>();
339+
340+ tiledb::Attribute attr = array_->schema ().attribute (attr_name);
341+ attrs_.push_back (attr_name);
342+
343+ // For non-nullable attributes, applying max and min to the empty set is
344+ // undefined. To check for this, we need to also run the count aggregate
345+ // to make sure count != 0
346+ bool requested_max =
347+ std::find (aggs.begin (), aggs.end (), " max" ) != aggs.end ();
348+ bool requested_min =
349+ std::find (aggs.begin (), aggs.end (), " min" ) != aggs.end ();
350+ if (!attr.nullable () && (requested_max || requested_min)) {
351+ // If the user already also requested count, then we don't need to
352+ // request it again
353+ if (std::find (aggs.begin (), aggs.end (), " count" ) == aggs.end ()) {
354+ aggs.push_back (" count" );
355+ }
356+ }
357+
358+ // Iterate through the aggreate operations to apply on the given attribute
359+ for (auto agg_name : aggs) {
360+ _apply_agg_operator_to_attr (agg_name, attr_name);
361+
362+ // Set the result data buffers
363+ auto *res_buf = &result_buffers_[attr_name][agg_name];
364+ if (" count" == agg_name || " null_count" == agg_name ||
365+ " mean" == agg_name) {
366+ // count and null_count use uint64 and mean uses float64
367+ *res_buf = py::array (py::dtype (" uint8" ), 8 );
368+ } else {
369+ // max, min, and sum use the dtype of the attribute
370+ py::dtype dt (tiledb_dtype (attr.type (), attr.cell_size ()));
371+ *res_buf = py::array (py::dtype (" uint8" ), dt.itemsize ());
372+ }
373+ query_->set_data_buffer (attr_name + agg_name, (void *)res_buf->data (),
374+ 1 );
375+
376+ if (attr.nullable ()) {
377+ // For nullable attributes, if the input set for the aggregation
378+ // contains all NULL values, we will not get an aggregate value back
379+ // as this operation is undefined. We need to check the validity
380+ // buffer beforehand to see if we had a valid result
381+ if (!(" count" == agg_name || " null_count" == agg_name)) {
382+ auto *val_buf = &validity_buffers_[attr.name ()][agg_name];
383+ *val_buf = py::array (py::dtype (" uint8" ), 1 );
384+ query_->set_validity_buffer (attr_name + agg_name,
385+ (uint8_t *)val_buf->data (), 1 );
386+ }
387+ }
388+ }
389+ }
390+ }
391+
392+ void _apply_agg_operator_to_attr (const std::string &op_label,
393+ const std::string &attr_name) {
394+ using AggregateFunc =
395+ std::function<ChannelOperation (const Query &, const std::string &)>;
396+
397+ std::unordered_map<std::string, AggregateFunc> label_to_agg_func = {
398+ {" sum" , QueryExperimental::create_unary_aggregate<SumOperator>},
399+ {" min" , QueryExperimental::create_unary_aggregate<MinOperator>},
400+ {" max" , QueryExperimental::create_unary_aggregate<MaxOperator>},
401+ {" mean" , QueryExperimental::create_unary_aggregate<MeanOperator>},
402+ {" null_count" ,
403+ QueryExperimental::create_unary_aggregate<NullCountOperator>},
404+ };
405+
406+ QueryChannel default_channel =
407+ QueryExperimental::get_default_channel (*query_);
408+
409+ if (label_to_agg_func.find (op_label) != label_to_agg_func.end ()) {
410+ AggregateFunc create_unary_aggregate = label_to_agg_func.at (op_label);
411+ ChannelOperation op = create_unary_aggregate (*query_, attr_name);
412+ default_channel.apply_aggregate (attr_name + op_label, op);
413+ } else if (" count" == op_label) {
414+ default_channel.apply_aggregate (attr_name + op_label, CountOperation ());
415+ } else {
416+ TPY_ERROR_LOC (" Invalid channel operation " + op_label +
417+ " passed to apply_aggregate." );
418+ }
419+ }
420+
421+ py::dict get_aggregate () {
422+ query_->submit ();
423+
424+ // Cast the results to the correct dtype and output this as a Python dict
425+ py::dict output;
426+ for (auto attr_to_agg : original_input_) {
427+ // Be clear in our variable names for strings as py::dict uses py::str
428+ // keys whereas std::map uses std::string keys
429+ std::string attr_cpp_name = attr_to_agg.first .cast <string>();
430+
431+ py::str attr_py_name (attr_cpp_name);
432+ output[attr_py_name] = py::dict ();
433+
434+ tiledb::Attribute attr = array_->schema ().attribute (attr_cpp_name);
435+
436+ for (auto agg_py_name : original_input_[attr_py_name]) {
437+ std::string agg_cpp_name = agg_py_name.cast <string>();
438+
439+ if (_is_invalid (attr, agg_cpp_name)) {
440+ output[attr_py_name][agg_py_name] =
441+ _is_integer_dtype (attr) ? py::none () : py::cast (NAN);
442+ } else {
443+ output[attr_py_name][agg_py_name] = _set_result (attr, agg_cpp_name);
444+ }
445+ }
446+ }
447+ return output;
448+ }
449+
450+ bool _is_invalid (tiledb::Attribute attr, std::string agg_name) {
451+ if (attr.nullable ()) {
452+ if (" count" == agg_name || " null_count" == agg_name)
453+ return false ;
454+
455+ // For nullable attributes, check if the validity buffer returned false
456+ const void *val_buf = validity_buffers_[attr.name ()][agg_name].data ();
457+ return *((uint8_t *)(val_buf)) == 0 ;
458+ } else {
459+ // For non-nullable attributes, max and min are undefined for the empty
460+ // set, so we must check the count == 0
461+ if (" max" == agg_name || " min" == agg_name) {
462+ const void *count_buf = result_buffers_[attr.name ()][" count" ].data ();
463+ return *((uint64_t *)(count_buf)) == 0 ;
464+ }
465+ return false ;
466+ }
467+ }
468+
469+ bool _is_integer_dtype (tiledb::Attribute attr) {
470+ switch (attr.type ()) {
471+ case TILEDB_INT8:
472+ case TILEDB_INT16:
473+ case TILEDB_UINT8:
474+ case TILEDB_INT32:
475+ case TILEDB_INT64:
476+ case TILEDB_UINT16:
477+ case TILEDB_UINT32:
478+ case TILEDB_UINT64:
479+ return true ;
480+ default :
481+ return false ;
482+ }
483+ }
484+
485+ py::object _set_result (tiledb::Attribute attr, std::string agg_name) {
486+ const void *agg_buf = result_buffers_[attr.name ()][agg_name].data ();
487+
488+ if (" mean" == agg_name)
489+ return py::cast (*((double *)agg_buf));
490+
491+ if (" count" == agg_name || " null_count" == agg_name)
492+ return py::cast (*((uint64_t *)agg_buf));
493+
494+ switch (attr.type ()) {
495+ case TILEDB_FLOAT32:
496+ return py::cast (" sum" == agg_name ? *((double *)agg_buf)
497+ : *((float *)agg_buf));
498+ case TILEDB_FLOAT64:
499+ return py::cast (*((double *)agg_buf));
500+ case TILEDB_INT8:
501+ return py::cast (*((int8_t *)agg_buf));
502+ case TILEDB_UINT8:
503+ return py::cast (*((uint8_t *)agg_buf));
504+ case TILEDB_INT16:
505+ return py::cast (*((int16_t *)agg_buf));
506+ case TILEDB_UINT16:
507+ return py::cast (*((uint16_t *)agg_buf));
508+ case TILEDB_UINT32:
509+ return py::cast (*((uint32_t *)agg_buf));
510+ case TILEDB_INT32:
511+ return py::cast (*((int32_t *)agg_buf));
512+ case TILEDB_INT64:
513+ return py::cast (*((int64_t *)agg_buf));
514+ case TILEDB_UINT64:
515+ return py::cast (*((uint64_t *)agg_buf));
516+ default :
517+ TPY_ERROR_LOC (
518+ " [_cast_agg_result] Invalid tiledb dtype for aggregation result" )
519+ }
520+ }
521+
522+ void set_subarray (py::object py_subarray) {
523+ query_->set_subarray (*py_subarray.cast <tiledb::Subarray *>());
524+ }
525+
526+ void set_cond (py::object cond) {
527+ py::object init_pyqc = cond.attr (" init_query_condition" );
528+
529+ try {
530+ init_pyqc (array_->uri (), attrs_, ctx_);
531+ } catch (tiledb::TileDBError &e) {
532+ TPY_ERROR_LOC (e.what ());
533+ } catch (py::error_already_set &e) {
534+ TPY_ERROR_LOC (e.what ());
535+ }
536+ auto pyqc = (cond.attr (" c_obj" )).cast <PyQueryCondition>();
537+ auto qc = pyqc.ptr ().get ();
538+ query_->set_condition (*qc);
539+ }
540+ };
541+
300542class PyQuery {
301543
302544private:
303545 Context ctx_;
304- shared_ptr<tiledb::Domain> domain_;
305- shared_ptr<tiledb::ArraySchema> array_schema_;
306- shared_ptr<tiledb::Array> array_;
307- shared_ptr<tiledb::Query> query_;
546+ std:: shared_ptr<tiledb::Domain> domain_;
547+ std:: shared_ptr<tiledb::ArraySchema> array_schema_;
548+ std:: shared_ptr<tiledb::Array> array_;
549+ std:: shared_ptr<tiledb::Query> query_;
308550 std::vector<std::string> attrs_;
309551 std::vector<std::string> dims_;
310- map<string, BufferInfo> buffers_;
311- vector<string> buffers_order_;
552+ std:: map<std:: string, BufferInfo> buffers_;
553+ std:: vector<std:: string> buffers_order_;
312554
313555 bool deduplicate_ = true ;
314556 bool use_arrow_ = false ;
@@ -320,9 +562,7 @@ class PyQuery {
320562 tiledb_layout_t layout_ = TILEDB_ROW_MAJOR;
321563
322564 // label buffer list
323- std::unordered_map<string, uint64_t > label_input_buffer_data_;
324-
325- std::string uri_;
565+ unordered_map<string, uint64_t > label_input_buffer_data_;
326566
327567public:
328568 tiledb_ctx_t *c_ctx_;
@@ -347,15 +587,11 @@ class PyQuery {
347587 tiledb_array_t *c_array_ = (py::capsule)array.attr (" __capsule__" )();
348588
349589 // we never own this pointer, pass own=false
350- array_ = std::shared_ptr<tiledb::Array>(new Array (ctx_, c_array_, false ));
351-
352- array_schema_ =
353- std::shared_ptr<tiledb::ArraySchema>(new ArraySchema (array_->schema ()));
590+ array_ = std::make_shared<tiledb::Array>(ctx_, c_array_, false );
354591
355- domain_ =
356- std::shared_ptr<tiledb::Domain>(new Domain (array_schema_->domain ()));
592+ array_schema_ = std::make_shared<tiledb::ArraySchema>(array_->schema ());
357593
358- uri_ = array. attr ( " uri " ). cast < std::string>( );
594+ domain_ = std::make_shared<tiledb::Domain>(array_schema_-> domain () );
359595
360596 bool issparse = array_->schema ().array_type () == TILEDB_SPARSE;
361597
@@ -398,8 +634,7 @@ class PyQuery {
398634 }
399635 }
400636
401- query_ =
402- std::shared_ptr<tiledb::Query>(new Query (ctx_, *array_, query_mode));
637+ query_ = std::make_shared<tiledb::Query>(ctx_, *array_, query_mode);
403638 // [](Query* p){} /* note: no deleter*/);
404639
405640 if (query_mode == TILEDB_READ) {
@@ -424,8 +659,7 @@ class PyQuery {
424659 }
425660
426661 void set_subarray (py::object py_subarray) {
427- tiledb::Subarray *subarray = py_subarray.cast <tiledb::Subarray *>();
428- query_->set_subarray (*subarray);
662+ query_->set_subarray (*py_subarray.cast <tiledb::Subarray *>());
429663 }
430664
431665#if defined(TILEDB_SERIALIZATION)
@@ -456,7 +690,7 @@ class PyQuery {
456690 py::object init_pyqc = cond.attr (" init_query_condition" );
457691
458692 try {
459- init_pyqc (uri_ , attrs_, ctx_);
693+ init_pyqc (array_-> uri () , attrs_, ctx_);
460694 } catch (tiledb::TileDBError &e) {
461695 TPY_ERROR_LOC (e.what ());
462696 } catch (py::error_already_set &e) {
@@ -1538,6 +1772,18 @@ void init_core(py::module &m) {
15381772 &PyQuery::_test_alloc_max_bytes)
15391773 .def_readonly (" retries" , &PyQuery::retries_);
15401774
1775+ py::class_<PyAgg>(m, " PyAgg" )
1776+ .def (py::init<const Context &, py::object, py::object, py::dict>(),
1777+ " ctx" _a, " py_array" _a, " py_layout" _a, " attr_to_aggs_input" _a)
1778+ .def (" set_subarray" , &PyAgg::set_subarray)
1779+ .def (" set_cond" , &PyAgg::set_cond)
1780+ .def (" get_aggregate" , &PyAgg::get_aggregate);
1781+
1782+ py::class_<PAPair>(m, " PAPair" )
1783+ .def (py::init ())
1784+ .def (" get_array" , &PAPair::get_array)
1785+ .def (" get_schema" , &PAPair::get_schema);
1786+
15411787 m.def (" array_to_buffer" , &convert_np);
15421788
15431789 m.def (" init_stats" , &init_stats);
@@ -1548,11 +1794,6 @@ void init_core(py::module &m) {
15481794 m.def (" get_stats" , &get_stats);
15491795 m.def (" use_stats" , &use_stats);
15501796
1551- py::class_<PAPair>(m, " PAPair" )
1552- .def (py::init ())
1553- .def (" get_array" , &PAPair::get_array)
1554- .def (" get_schema" , &PAPair::get_schema);
1555-
15561797 /*
15571798 We need to make sure C++ TileDBError is translated to a correctly-typed py
15581799 error. Note that using py::exception(..., "TileDBError") creates a new
0 commit comments