@@ -249,10 +249,13 @@ def entropy(self, membership=None, idx=None):
249249 mbs = np .array (self .predict (what = "samples" , prob = False , idx = idx )).squeeze ()
250250 dmbs , dmembership = {}, {}
251251 [dmbs .setdefault (mbs [i ], set ()).add (i ) for i in range (len (mbs ))]
252- [dmembership .setdefault (membership [i ], set ()).add (i )
253- for i in range (len (membership ))]
254- return - 1. / (n * log (len (dmembership ), 2 )) * sum (sum (len (dmbs [k ].intersection (dmembership [j ])) *
255- log (len (dmbs [k ].intersection (dmembership [j ])) / float (len (dmbs [k ])), 2 ) for j in dmembership ) for k in dmbs )
252+ [dmembership .setdefault (membership [i ], set ()).add (i ) for i in range (len (membership ))]
253+ entropy = 0.
254+ for k in dmbs :
255+ for j in dmembership :
256+ entropy += len (dmbs [k ].intersection (dmembership [j ])) * np .log2 (len (dmbs [k ].intersection (dmembership [j ])) / float (len (dmbs [k ])))
257+ entropy *= - 1. / (n * np .log2 (len (dmembership )))
258+ return entropy
256259
257260 def predict (self , what = 'samples' , prob = False , idx = None ):
258261 """
@@ -386,9 +389,8 @@ def purity(self, membership=None, idx=None):
386389 mbs = np .array (self .predict (what = "samples" , prob = False , idx = idx )).squeeze ()
387390 dmbs , dmembership = {}, {}
388391 [dmbs .setdefault (mbs [i ], set ()).add (i ) for i in range (len (mbs ))]
389- [dmembership .setdefault (membership [i ], set ()).add (i )
390- for i in range (len (membership ))]
391- return 1. / n * sum (max (len (dmbs [k ].intersection (dmembership [j ])) for j in dmembership ) for k in dmbs )
392+ [dmembership .setdefault (membership [i ], set ()).add (i ) for i in range (len (membership ))]
393+ return 1. / n * sum (np .max ([len (dmbs [k ].intersection (dmembership [j ])) for j in dmembership ]) for k in dmbs )
392394
393395 def rss (self , idx = None ):
394396 """
0 commit comments