@@ -39,71 +39,56 @@ def samples_transfer_ratio_func(system_metric_info: dict):
3939
4040 """
4141
42- info = system_metric_info .get (SystemMetricType .SAMPLES_TRANSFER_RATIO .value )
42+ info = system_metric_info .get (
43+ SystemMetricType .SAMPLES_TRANSFER_RATIO .value )
4344 inference_num = 0
4445 transfer_num = 0
4546 for inference_data , transfer_data in info :
4647 inference_num += len (inference_data )
4748 transfer_num += len (transfer_data )
4849 return round (float (transfer_num ) / (inference_num + 1 ), 4 )
4950
51+
5052def compute (key , matrix ):
5153 """
5254 Compute BWT and FWT scores for a given matrix.
5355 """
54- # pylint: disable=C0103
55- # pylint: disable=C0301
56- # pylint: disable=C0303
57- # pylint: disable=R0912
56+ print (
57+ f"compute function: key={ key } , matrix={ matrix } , type(matrix)={ type (matrix )} " )
5858
59- print (f"compute function: key={ key } , matrix={ matrix } , type(matrix)={ type (matrix )} " )
60-
6159 length = len (matrix )
6260 accuracy = 0.0
63- BWT_score = 0.0
64- FWT_score = 0.0
61+ bwt_score = 0.0
62+ fwt_score = 0.0
6563 flag = True
6664
67- if key == 'all' :
68- for i in range (length - 1 , 0 , - 1 ):
69- sum_before_i = sum (item ['accuracy' ] for item in matrix [i ][:i ])
70- sum_after_i = sum (item ['accuracy' ] for item in matrix [i ][- (length - i - 1 ):])
71- if i == 0 :
72- seen_class_accuracy = 0.0
73- else :
74- seen_class_accuracy = sum_before_i / i
75- if length - 1 - i == 0 :
76- unseen_class_accuracy = 0.0
77- else :
78- unseen_class_accuracy = sum_after_i / (length - 1 - i )
79- print (f"round { i } : unseen class accuracy is { unseen_class_accuracy } , seen class accuracy is { seen_class_accuracy } " )
80-
8165 for row in matrix :
8266 if not isinstance (row , list ) or len (row ) != length - 1 :
8367 flag = False
8468 break
8569
8670 if not flag :
87- BWT_score = np .nan
88- FWT_score = np .nan
89- return BWT_score , FWT_score
71+ bwt_score = np .nan
72+ fwt_score = np .nan
73+ return bwt_score , fwt_score
9074
9175 for i in range (length - 1 ):
9276 for j in range (length - 1 ):
9377 if 'accuracy' in matrix [i + 1 ][j ] and 'accuracy' in matrix [i ][j ]:
9478 accuracy += matrix [i + 1 ][j ]['accuracy' ]
95- BWT_score += matrix [i + 1 ][j ]['accuracy' ] - matrix [i ][j ]['accuracy' ]
96-
79+ bwt_score += matrix [i + 1 ][j ]['accuracy' ] - \
80+ matrix [i ][j ]['accuracy' ]
81+
9782 for i in range (0 , length - 1 ):
9883 if 'accuracy' in matrix [i ][i ] and 'accuracy' in matrix [0 ][i ]:
99- FWT_score += matrix [i ][i ]['accuracy' ] - matrix [0 ][i ]['accuracy' ]
84+ fwt_score += matrix [i ][i ]['accuracy' ] - matrix [0 ][i ]['accuracy' ]
10085
10186 accuracy = accuracy / ((length - 1 ) * (length - 1 ))
102- BWT_score = BWT_score / ((length - 1 ) * (length - 1 ))
103- FWT_score = FWT_score / (length - 1 )
87+ bwt_score = bwt_score / ((length - 1 ) * (length - 1 ))
88+ fwt_score = fwt_score / (length - 1 )
10489
105- print (f"{ key } BWT_score: { BWT_score } " )
106- print (f"{ key } FWT_score: { FWT_score } " )
90+ print (f"{ key } BWT_score: { bwt_score } " )
91+ print (f"{ key } FWT_score: { fwt_score } " )
10792
10893 my_matrix = []
10994 for i in range (length - 1 ):
@@ -112,48 +97,53 @@ def compute(key, matrix):
11297 if 'accuracy' in matrix [i + 1 ][j ]:
11398 my_matrix [i ].append (matrix [i + 1 ][j ]['accuracy' ])
11499
115- return my_matrix , BWT_score , FWT_score
100+ return my_matrix , bwt_score , fwt_score
101+
116102
117103def bwt_func (system_metric_info : dict ):
118104 """
119105 compute BWT
120106 """
121107 # pylint: disable=C0103
122108 # pylint: disable=W0632
123- info = system_metric_info .get (SystemMetricType .Matrix .value )
109+ info = system_metric_info .get (SystemMetricType .MATRIX .value )
124110 _ , BWT_score , _ = compute ("all" , info ["all" ])
125111 return BWT_score
126112
113+
127114def fwt_func (system_metric_info : dict ):
128115 """
129116 compute FWT
130117 """
131118 # pylint: disable=C0103
132119 # pylint: disable=W0632
133- info = system_metric_info .get (SystemMetricType .Matrix .value )
120+ info = system_metric_info .get (SystemMetricType .MATRIX .value )
134121 _ , _ , FWT_score = compute ("all" , info ["all" ])
135122 return FWT_score
136123
124+
137125def matrix_func (system_metric_info : dict ):
138126 """
139127 compute FWT
140128 """
141129 # pylint: disable=C0103
142130 # pylint: disable=W0632
143- info = system_metric_info .get (SystemMetricType .Matrix .value )
131+ info = system_metric_info .get (SystemMetricType .MATRIX .value )
144132 my_dict = {}
145133 for key in info .keys ():
146134 my_matrix , _ , _ = compute (key , info [key ])
147135 my_dict [key ] = my_matrix
148136 return my_dict
149137
138+
150139def task_avg_acc_func (system_metric_info : dict ):
151140 """
152- compute Task_Avg_Acc
141+ compute task average accuracy
153142 """
154- info = system_metric_info .get (SystemMetricType .Task_Avg_Acc .value )
143+ info = system_metric_info .get (SystemMetricType .TASK_AVG_ACC .value )
155144 return info ["accuracy" ]
156145
146+
157147def get_metric_func (metric_dict : dict ):
158148 """
159149 get metric func by metric info
@@ -175,9 +165,11 @@ def get_metric_func(metric_dict: dict):
175165 if url :
176166 try :
177167 load_module (url )
178- metric_func = ClassFactory .get_cls (type_name = ClassType .GENERAL , t_cls_name = name )
168+ metric_func = ClassFactory .get_cls (
169+ type_name = ClassType .GENERAL , t_cls_name = name )
179170 return name , metric_func
180171 except Exception as err :
181- raise RuntimeError (f"get metric func(url={ url } ) failed, error: { err } ." ) from err
172+ raise RuntimeError (
173+ f"get metric func(url={ url } ) failed, error: { err } ." ) from err
182174
183175 return name , getattr (sys .modules [__name__ ], str .lower (name ) + "_func" )
0 commit comments