@@ -440,7 +440,7 @@ double RcppDiscTE(SEXP mat,
440440 bool normalize = false ,
441441 bool lag_single = false )
442442{
443- infoxtr::infotheo::Matrix m = pat_r2std (mat,false );
443+ infoxtr::infotheo::Matrix m = pat_r2std (mat, false );
444444
445445 std::vector<size_t > tg = Rcpp::as<std::vector<size_t >>(target);
446446 std::vector<size_t > ag = Rcpp::as<std::vector<size_t >>(agent);
@@ -513,3 +513,156 @@ double RcppContTE(const Rcpp::NumericMatrix& mat,
513513 static_cast <size_t >(std::abs (alg)),
514514 base, normalize, lag_single);
515515}
516+
517+ // Wrapper function to preform SURD decomposition for discrete data
518+ // [[Rcpp::export(rng = false)]]
519+ Rcpp::List RcppDiscSURD (SEXP mat,
520+ int max_order = 3 ,
521+ int threads = 1 ,
522+ double base = 2.0 ,
523+ bool normalize = true )
524+ {
525+ infoxtr::surd::DiscMat m = pat_r2std (mat, false );
526+
527+ infoxtr::surd::SURDRes res = infoxtr::surd::surd (
528+ m, static_cast <size_t >(std::abs (max_order)),
529+ static_cast <size_t >(std::abs (threads)), base, normalize);
530+
531+ const size_t k = res.size ();
532+
533+ Rcpp::NumericVector values (k);
534+ Rcpp::CharacterVector types (k);
535+ Rcpp::CharacterVector names (k);
536+
537+ for (size_t i = 0 ; i < k; ++i)
538+ {
539+ values[i] = res.values [i];
540+
541+ // variable name
542+ if (res.types [i] == 3 )
543+ {
544+ // InfoLeak uses all sources
545+ std::string nm = " InfoLeak" ;
546+ names[i] = nm;
547+ types[i] = " InfoLeak" ;
548+ continue ;
549+ }
550+
551+ const auto & vars = res.var_indices [i];
552+
553+ std::string nm;
554+
555+ for (size_t j = 0 ; j < vars.size (); ++j)
556+ {
557+ if (j > 0 )
558+ nm += " _" ;
559+
560+ nm += " V" ;
561+ nm += std::to_string (vars[j]);
562+ }
563+
564+ names[i] = nm;
565+
566+ switch (res.types [i])
567+ {
568+ case 0 :
569+ types[i] = " R" ;
570+ break ;
571+ case 1 :
572+ types[i] = " U" ;
573+ break ;
574+ case 2 :
575+ types[i] = " S" ;
576+ break ;
577+ default :
578+ types[i] = " Unknown" ;
579+ }
580+ }
581+
582+ // values.attr("names") = names;
583+
584+ return Rcpp::List::create (
585+ Rcpp::Named (" vars" ) = names,
586+ Rcpp::Named (" types" ) = types,
587+ Rcpp::Named (" values" ) = values
588+ );
589+ }
590+
591+ // Wrapper function to preform SURD decomposition for continuous data
592+ // [[Rcpp::export(rng = false)]]
593+ Rcpp::List RcppContSURD (SEXP mat,
594+ int max_order = 3 ,
595+ int k = 3 ,
596+ int alg = 0 ,
597+ int threads = 1 ,
598+ double base = 2.0 ,
599+ bool normalize = true )
600+ {
601+ std::vector<std::vector<double >> m = mat_r2std (mat, false );
602+
603+ infoxtr::surd::SURDRes res = infoxtr::surd::surd (
604+ m, static_cast <size_t >(std::abs (max_order)),
605+ static_cast <size_t >(std::abs (k)),
606+ static_cast <size_t >(std::abs (alg)),
607+ static_cast <size_t >(std::abs (threads)),
608+ base, normalize);
609+
610+ const size_t n_vals = res.size ();
611+
612+ Rcpp::NumericVector values (n_vals);
613+ Rcpp::CharacterVector types (n_vals);
614+ Rcpp::CharacterVector names (n_vals);
615+
616+ for (size_t i = 0 ; i < n_vals; ++i)
617+ {
618+ values[i] = res.values [i];
619+
620+ // variable name
621+ if (res.types [i] == 3 )
622+ {
623+ // InfoLeak uses all sources
624+ std::string nm = " InfoLeak" ;
625+ names[i] = nm;
626+ types[i] = " InfoLeak" ;
627+ continue ;
628+ }
629+
630+ const auto & vars = res.var_indices [i];
631+
632+ std::string nm;
633+
634+ for (size_t j = 0 ; j < vars.size (); ++j)
635+ {
636+ if (j > 0 )
637+ nm += " _" ;
638+
639+ nm += " V" ;
640+ nm += std::to_string (vars[j]);
641+ }
642+
643+ names[i] = nm;
644+
645+ switch (res.types [i])
646+ {
647+ case 0 :
648+ types[i] = " R" ;
649+ break ;
650+ case 1 :
651+ types[i] = " U" ;
652+ break ;
653+ case 2 :
654+ types[i] = " S" ;
655+ break ;
656+ default :
657+ types[i] = " Unknown" ;
658+ }
659+ }
660+
661+ // values.attr("names") = names;
662+
663+ return Rcpp::List::create (
664+ Rcpp::Named (" vars" ) = names,
665+ Rcpp::Named (" types" ) = types,
666+ Rcpp::Named (" values" ) = values
667+ );
668+ }
0 commit comments