Skip to content

Commit f10e768

Browse files
authored
Merge pull request #40 from SpatLyu/dev
export surd utility
2 parents 6cbd6ee + f7e0b0c commit f10e768

3 files changed

Lines changed: 194 additions & 1 deletion

File tree

R/RcppExports.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ RcppContTE <- function(mat, target, agent, lag_p = 3L, lag_q = 3L, k = 3L, alg =
6565
.Call(`_infoxtr_RcppContTE`, mat, target, agent, lag_p, lag_q, k, alg, base, normalize, lag_single)
6666
}
6767

68+
RcppDiscSURD <- function(mat, max_order = 3L, threads = 1L, base = 2.0, normalize = TRUE) {
69+
.Call(`_infoxtr_RcppDiscSURD`, mat, max_order, threads, base, normalize)
70+
}
71+
72+
RcppContSURD <- function(mat, max_order = 3L, k = 3L, alg = 0L, threads = 1L, base = 2.0, normalize = TRUE) {
73+
.Call(`_infoxtr_RcppContSURD`, mat, max_order, k, alg, threads, base, normalize)
74+
}
75+
6876
RcppGenLatticeLag <- function(mat, nb, lag = 1L) {
6977
.Call(`_infoxtr_RcppGenLatticeLag`, mat, nb, lag)
7078
}

src/InfotheoExps.cpp

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ double RcppDiscTE(SEXP mat,
440440
bool normalize = false,
441441
bool lag_single = false)
442442
{
443-
infoxtr::infotheo::Matrix m = pat_r2std(mat,false);
443+
infoxtr::infotheo::Matrix m = pat_r2std(mat, false);
444444

445445
std::vector<size_t> tg = Rcpp::as<std::vector<size_t>>(target);
446446
std::vector<size_t> ag = Rcpp::as<std::vector<size_t>>(agent);
@@ -513,3 +513,156 @@ double RcppContTE(const Rcpp::NumericMatrix& mat,
513513
static_cast<size_t>(std::abs(alg)),
514514
base, normalize, lag_single);
515515
}
516+
517+
// Wrapper function to preform SURD decomposition for discrete data
518+
// [[Rcpp::export(rng = false)]]
519+
Rcpp::List RcppDiscSURD(SEXP mat,
520+
int max_order = 3,
521+
int threads = 1,
522+
double base = 2.0,
523+
bool normalize = true)
524+
{
525+
infoxtr::surd::DiscMat m = pat_r2std(mat, false);
526+
527+
infoxtr::surd::SURDRes res = infoxtr::surd::surd(
528+
m, static_cast<size_t>(std::abs(max_order)),
529+
static_cast<size_t>(std::abs(threads)), base, normalize);
530+
531+
const size_t k = res.size();
532+
533+
Rcpp::NumericVector values(k);
534+
Rcpp::CharacterVector types(k);
535+
Rcpp::CharacterVector names(k);
536+
537+
for (size_t i = 0; i < k; ++i)
538+
{
539+
values[i] = res.values[i];
540+
541+
// variable name
542+
if (res.types[i] == 3)
543+
{
544+
// InfoLeak uses all sources
545+
std::string nm = "InfoLeak";
546+
names[i] = nm;
547+
types[i] = "InfoLeak";
548+
continue;
549+
}
550+
551+
const auto& vars = res.var_indices[i];
552+
553+
std::string nm;
554+
555+
for (size_t j = 0; j < vars.size(); ++j)
556+
{
557+
if (j > 0)
558+
nm += "_";
559+
560+
nm += "V";
561+
nm += std::to_string(vars[j]);
562+
}
563+
564+
names[i] = nm;
565+
566+
switch (res.types[i])
567+
{
568+
case 0:
569+
types[i] = "R";
570+
break;
571+
case 1:
572+
types[i] = "U";
573+
break;
574+
case 2:
575+
types[i] = "S";
576+
break;
577+
default:
578+
types[i] = "Unknown";
579+
}
580+
}
581+
582+
// values.attr("names") = names;
583+
584+
return Rcpp::List::create(
585+
Rcpp::Named("vars") = names,
586+
Rcpp::Named("types") = types,
587+
Rcpp::Named("values") = values
588+
);
589+
}
590+
591+
// Wrapper function to preform SURD decomposition for continuous data
592+
// [[Rcpp::export(rng = false)]]
593+
Rcpp::List RcppContSURD(SEXP mat,
594+
int max_order = 3,
595+
int k = 3,
596+
int alg = 0,
597+
int threads = 1,
598+
double base = 2.0,
599+
bool normalize = true)
600+
{
601+
std::vector<std::vector<double>> m = mat_r2std(mat, false);
602+
603+
infoxtr::surd::SURDRes res = infoxtr::surd::surd(
604+
m, static_cast<size_t>(std::abs(max_order)),
605+
static_cast<size_t>(std::abs(k)),
606+
static_cast<size_t>(std::abs(alg)),
607+
static_cast<size_t>(std::abs(threads)),
608+
base, normalize);
609+
610+
const size_t n_vals = res.size();
611+
612+
Rcpp::NumericVector values(n_vals);
613+
Rcpp::CharacterVector types(n_vals);
614+
Rcpp::CharacterVector names(n_vals);
615+
616+
for (size_t i = 0; i < n_vals; ++i)
617+
{
618+
values[i] = res.values[i];
619+
620+
// variable name
621+
if (res.types[i] == 3)
622+
{
623+
// InfoLeak uses all sources
624+
std::string nm = "InfoLeak";
625+
names[i] = nm;
626+
types[i] = "InfoLeak";
627+
continue;
628+
}
629+
630+
const auto& vars = res.var_indices[i];
631+
632+
std::string nm;
633+
634+
for (size_t j = 0; j < vars.size(); ++j)
635+
{
636+
if (j > 0)
637+
nm += "_";
638+
639+
nm += "V";
640+
nm += std::to_string(vars[j]);
641+
}
642+
643+
names[i] = nm;
644+
645+
switch (res.types[i])
646+
{
647+
case 0:
648+
types[i] = "R";
649+
break;
650+
case 1:
651+
types[i] = "U";
652+
break;
653+
case 2:
654+
types[i] = "S";
655+
break;
656+
default:
657+
types[i] = "Unknown";
658+
}
659+
}
660+
661+
// values.attr("names") = names;
662+
663+
return Rcpp::List::create(
664+
Rcpp::Named("vars") = names,
665+
Rcpp::Named("types") = types,
666+
Rcpp::Named("values") = values
667+
);
668+
}

src/RcppExports.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,36 @@ BEGIN_RCPP
253253
return rcpp_result_gen;
254254
END_RCPP
255255
}
256+
// RcppDiscSURD
257+
Rcpp::List RcppDiscSURD(SEXP mat, int max_order, int threads, double base, bool normalize);
258+
RcppExport SEXP _infoxtr_RcppDiscSURD(SEXP matSEXP, SEXP max_orderSEXP, SEXP threadsSEXP, SEXP baseSEXP, SEXP normalizeSEXP) {
259+
BEGIN_RCPP
260+
Rcpp::RObject rcpp_result_gen;
261+
Rcpp::traits::input_parameter< SEXP >::type mat(matSEXP);
262+
Rcpp::traits::input_parameter< int >::type max_order(max_orderSEXP);
263+
Rcpp::traits::input_parameter< int >::type threads(threadsSEXP);
264+
Rcpp::traits::input_parameter< double >::type base(baseSEXP);
265+
Rcpp::traits::input_parameter< bool >::type normalize(normalizeSEXP);
266+
rcpp_result_gen = Rcpp::wrap(RcppDiscSURD(mat, max_order, threads, base, normalize));
267+
return rcpp_result_gen;
268+
END_RCPP
269+
}
270+
// RcppContSURD
271+
Rcpp::List RcppContSURD(SEXP mat, int max_order, int k, int alg, int threads, double base, bool normalize);
272+
RcppExport SEXP _infoxtr_RcppContSURD(SEXP matSEXP, SEXP max_orderSEXP, SEXP kSEXP, SEXP algSEXP, SEXP threadsSEXP, SEXP baseSEXP, SEXP normalizeSEXP) {
273+
BEGIN_RCPP
274+
Rcpp::RObject rcpp_result_gen;
275+
Rcpp::traits::input_parameter< SEXP >::type mat(matSEXP);
276+
Rcpp::traits::input_parameter< int >::type max_order(max_orderSEXP);
277+
Rcpp::traits::input_parameter< int >::type k(kSEXP);
278+
Rcpp::traits::input_parameter< int >::type alg(algSEXP);
279+
Rcpp::traits::input_parameter< int >::type threads(threadsSEXP);
280+
Rcpp::traits::input_parameter< double >::type base(baseSEXP);
281+
Rcpp::traits::input_parameter< bool >::type normalize(normalizeSEXP);
282+
rcpp_result_gen = Rcpp::wrap(RcppContSURD(mat, max_order, k, alg, threads, base, normalize));
283+
return rcpp_result_gen;
284+
END_RCPP
285+
}
256286
// RcppGenLatticeLag
257287
Rcpp::NumericMatrix RcppGenLatticeLag(const Rcpp::NumericMatrix& mat, const Rcpp::List& nb, int lag);
258288
RcppExport SEXP _infoxtr_RcppGenLatticeLag(SEXP matSEXP, SEXP nbSEXP, SEXP lagSEXP) {
@@ -362,6 +392,8 @@ static const R_CallMethodDef CallEntries[] = {
362392
{"_infoxtr_RcppContCMI", (DL_FUNC) &_infoxtr_RcppContCMI, 8},
363393
{"_infoxtr_RcppDiscTE", (DL_FUNC) &_infoxtr_RcppDiscTE, 9},
364394
{"_infoxtr_RcppContTE", (DL_FUNC) &_infoxtr_RcppContTE, 10},
395+
{"_infoxtr_RcppDiscSURD", (DL_FUNC) &_infoxtr_RcppDiscSURD, 5},
396+
{"_infoxtr_RcppContSURD", (DL_FUNC) &_infoxtr_RcppContSURD, 7},
365397
{"_infoxtr_RcppGenLatticeLag", (DL_FUNC) &_infoxtr_RcppGenLatticeLag, 3},
366398
{"_infoxtr_RcppGenGridLag", (DL_FUNC) &_infoxtr_RcppGenGridLag, 3},
367399
{"_infoxtr_RcppGenTSLag", (DL_FUNC) &_infoxtr_RcppGenTSLag, 2},

0 commit comments

Comments
 (0)