@@ -42,6 +42,10 @@ class OutputsSettings:
42
42
Flag to plot convergence data every N iterations
43
43
If None, no plots will be saved.
44
44
Note that you can not plot convergence data without saving data (and not more frequently than these saves!)
45
+ * plot_sourcewise : bool
46
+ Flag to plot source based multidimensional parameters such as mixing_matrix for each source.
47
+ Otherwise they will be plotted according to the other dimension such as feature.
48
+ Default=False
45
49
* overwrite_logs_folder : bool
46
50
Flag to remove all previous logs if existing (default False)
47
51
@@ -61,6 +65,7 @@ def __init__(self, settings):
61
65
self .print_periodicity = None
62
66
self .plot_periodicity = None
63
67
self .save_periodicity = 50
68
+ self .plot_sourcewise = False
64
69
self .nb_of_patients_to_plot = 5
65
70
66
71
self .root_path = None
@@ -71,6 +76,8 @@ def __init__(self, settings):
71
76
self ._set_print_periodicity (settings )
72
77
self ._set_save_periodicity (settings )
73
78
self ._set_plot_periodicity (settings )
79
+ self ._set_nb_of_patients_to_plot (settings )
80
+ self ._set_plot_sourcewise (settings )
74
81
75
82
# only create folders if the user want to save data or plots and provided a valid path!
76
83
self ._create_root_folder (settings )
@@ -97,6 +104,12 @@ def _set_param_as_int_or_ignore(self, settings: dict, param: str):
97
104
# Update the attribute of self in-place
98
105
setattr (self , param , val )
99
106
107
+ def _set_plot_sourcewise (self , settings : dict ):
108
+ setattr (self , "plot_sourcewise" , settings ["plot_sourcewise" ])
109
+
110
+ def _set_nb_of_patients_to_plot (self , settings : dict ):
111
+ self ._set_param_as_int_or_ignore (settings , "nb_of_patients_to_plot" )
112
+
100
113
def _set_print_periodicity (self , settings : dict ):
101
114
self ._set_param_as_int_or_ignore (settings , "print_periodicity" )
102
115
@@ -516,6 +529,8 @@ def set_logs(self, path: Optional[Union[str, Path]] = None, **kwargs):
516
529
Note that:
517
530
* it should be a multiple of save_periodicity
518
531
* setting a too low value (frequent) we seriously slow down you fit
532
+ * plot_sourcewise : bool, optional, default False
533
+ Set this to True to plot the source-based parameters sourcewise.
519
534
* overwrite_logs_folder: bool, optional, default False
520
535
Set it to ``True`` to overwrite the content of the folder in ``path``.
521
536
* nb_of_patients_to_plot: int, optional default 5
@@ -538,6 +553,7 @@ def set_logs(self, path: Optional[Union[str, Path]] = None, **kwargs):
538
553
"print_periodicity" : None ,
539
554
"save_periodicity" : 10 ,
540
555
"plot_periodicity" : 50 ,
556
+ "plot_sourcewise" : False ,
541
557
"overwrite_logs_folder" : False ,
542
558
"nb_of_patients_to_plot" : 5 ,
543
559
}
@@ -548,6 +564,7 @@ def set_logs(self, path: Optional[Union[str, Path]] = None, **kwargs):
548
564
"plot_periodicity" ,
549
565
"save_periodicity" ,
550
566
"nb_of_patients_to_plot" ,
567
+ "plot_sourcewise" ,
551
568
):
552
569
if v is not None and not isinstance (v , int ):
553
570
raise LeaspyAlgoInputError (
0 commit comments