Skip to content

Commit 129dd91

Browse files
committed
added a few more comments
1 parent 81c6edb commit 129dd91

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

dim_reduction/hdr.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ def dimReduction(img, Parameters=None):
3838
Parameters = dimReductionParameters()
3939

4040
type = Parameters.type # Type of Hierarchy
41-
showH = 1 # Parameters.showH # Set to 1 to show clustering, 0 otherwise
41+
showH = Parameters.showH # Set to 1 to show clustering, 0 otherwise
4242
maxNumClusters = Parameters.numBands
4343
NumCenters = Parameters.NumCenters
4444

45-
InputData = np.reshape(img, (numRows * numCols, numDims), order='F')
45+
InputData = np.reshape(img, (numRows * numCols, numDims))
4646
_, KLDivergencesList, _ = computeKLDivergencesBetweenBands(InputData, NumCenters);
4747

4848
Hierarchy = sch.linkage(KLDivergencesList, type)
@@ -58,18 +58,15 @@ def dimReduction(img, Parameters=None):
5858
for i in range(1, maxNumClusters+1):
5959
mergedData[i-1, :] = np.mean(InputData[:, band_clusters == i], 1)
6060

61-
mergedData = np.reshape(mergedData.T, (numRows, numCols, maxNumClusters), order='F')
61+
mergedData = np.reshape(mergedData.T, (numRows, numCols, maxNumClusters))
6262

6363
return mergedData
6464

6565

6666
def computeKLDivergencesBetweenBands(InputData, NumCenters):
6767

68-
# TESTED (keeping in mind that MATLAB and python reshape are different)
6968
DataList = InputData / InputData.max(1).max(0)
70-
# print('Datalist data: ', DataList[1,1])
7169

72-
# TESTED
7370
# compute the histograms
7471
Centers = np.arange(1/(2*NumCenters), 1 + 1/NumCenters, 1/NumCenters)
7572

@@ -78,19 +75,17 @@ def computeKLDivergencesBetweenBands(InputData, NumCenters):
7875
for count in range(DataList.shape[0]):
7976
hists[:, count], t = np.histogram(DataList.T[:, count], Centers)
8077

78+
# Add an epsilon term to the histograms
8179
hists = hists + np.spacing(1)
8280

8381
# compute KL Divergence
8482
lim = InputData.shape[1]
8583
KLDivergences = np.zeros((lim, lim))
86-
8784
for i in range(DataList.shape[1]):
8885
for j in range(DataList.shape[1]):
8986
KLDivergences[i, j] = (hists[i, :] * np.log(hists[i, :] / hists[j, :])).sum() \
9087
+ (hists[j, :] * np.log(hists[j, :] / hists[j, :])).sum()
9188

92-
plt.subplot(132)
93-
plt.plot(KLDivergences)
9489
temp = KLDivergences - np.diag(np.diag(KLDivergences))
9590
KLDivergencesList = pdist(temp)
9691

0 commit comments

Comments
 (0)