mainfile.py
# -*- coding: utf-8 -*-
"""
PROJECT: Clustering data using the von Mises-Fisher distribution
Reference: http://jmlr.csail.mit.edu/papers/v6/banerjee05a.html
GUIDE:
Prof. Anand A Joshi - [email protected]
TEAM:
Bhavana Ganesh - [email protected]
Mahathi Vatsal Salopanthula - [email protected]
Sayali Ghume - [email protected]
Contact any of the team members with questions or bug reports.
"""
import vonmisesGenerate as vmg
import sphericalclustering as snn
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import itertools
import time
import multiprocessing as mp
from multiprocessing import Pool
def clust(params):
    """ Returns the cluster labels and the extracted raw data.

    This function loads the data set from the file whose path is given in
    the parameters, draws a sample from the von Mises-Fisher distribution
    with each row of the data set as the mean vector, and clusters the
    samples using spherical k-means clustering.

    The function ``clust`` must have the following signature::

        output = clust(params)

    where params is a (file path, number of clusters) pair. In the main
    block such pairs are built as::

        fpath = ['.../data.npz']
        par2 = [no_clusters]
        params = zip(fpath, par2)

    and distributed to ``clust`` via pool.map. A usage sketch follows the
    function definition below.

    Dependencies::

        The function requires the following libraries to be installed:
        1. Numpy
        2. Mpmath

        The modules ``vonmisesGenerate.py`` and ``sphericalclustering.py``
        must be on the Python path.
    """
    fileadd, no_clusters = params
    dat = np.load(fileadd)
    brain = dat.f.data                #Comment while debugging
    #brain2 = dat.f.data              #Uncomment while debugging
    #brain = brain2[1:3,:]            #Uncomment while debugging
    del dat
    H = len(brain)                    #number of observations (rows)
    W = len(brain[0])                 #number of dimensions (columns)
    data = np.zeros([H, W])
    #Generating random samples using the VMF distribution
    for i in range(0, H):
        mu = brain[i, :]
        val = np.linalg.norm(mu)
        #Normalise the mean direction to unit length before sampling
        data[i, :] = vmg.randVMF(1, mu / val, 1)
    #Clustering using the sphericalclustering function
    clusters = snn.sphericalknn(data, no_clusters)
    return (clusters, brain)
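# Usage sketch (illustrative only): calling clust() directly for a single
# dataset, without the multiprocessing Pool used in the main block. The file
# name 'example_data.npz' and the cluster count 10 are hypothetical; the .npz
# file is assumed to store its array under the key 'data':
#
#     labels, raw = clust(('example_data.npz', 10))
#
# 'labels' holds one cluster index per row of the loaded array and 'raw' is
# the array itself, ready to be passed to clustplot() below.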
def clustplot(clusters, brain, no_clusters):
    """ Plots the clusters in 3D.

    This function creates two plots: one displaying the cluster label
    assigned to each observation, and another a 3D scatter of three
    dimensions (columns 147-149) of the extracted data, with markers
    coloured according to the clustering.

    The function ``clustplot`` must have the following signature::

        clustplot(clusters, brain, no_clusters)

    A usage sketch follows the function definition below.

    Dependencies::

        The function requires the following libraries to be installed:
        1. Matplotlib
        2. mpl_toolkits
        3. itertools
    """
    H = len(brain)
    W = len(brain[0])
    #Displaying the cluster label assigned to each observation
    plt.plot(clusters, marker='.', linestyle='')
    #3D scatter of three selected columns, one colour/marker per cluster
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    colors = itertools.cycle(["r", "b", "g", "k", "m", "y"])
    marker = itertools.cycle(["o", "*", "+", "x", "D", "s", "^"])
    for i in range(0, no_clusters):
        a = np.zeros([1, W])
        #Collect the rows assigned to cluster i
        for p in range(0, H):
            if clusters[p] == i:
                a = np.vstack((a, brain[p, :]))
        a = np.delete(a, (0), axis=0)   #Drop the all-zero initialisation row
        ax.scatter(a[:, 147], a[:, 148], a[:, 149], c=next(colors), marker=next(marker))
    plt.show()
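# Usage sketch (illustrative only): plotting the result of clust() for a
# hypothetical dataset clustered into 10 groups:
#
#     labels, raw = clust(('example_data.npz', 10))
#     clustplot(labels, raw, 10)
#
# Note that the 3D scatter uses columns 147-149 of the data, so the input
# array is assumed to have at least 150 columns.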
################################ MAIN FUNCTION #############################
if __name__ == '__main__':
    """ Processes multiple datasets in parallel using the multiprocessing library.

    The main block can process several datasets at the same time by
    distributing the calls to clust across worker processes, and then plots
    the first dataset. A sketch for running several files at once follows at
    the end of this module.
    """
    start_time = time.clock()           #Start clock
    #Define parameters
    fpath = ['D:/Codes/vonmises/data.npz']
    no_clusters = 200
    par2 = [no_clusters]
    params = zip(fpath, par2)           #Zip the file paths with the number of clusters to be formed
    pool = Pool()
    '''
    Pool is a convenient means of parallelizing the execution of a function
    across multiple input values, distributing the input data across processes
    (data parallelism).
    Source: https://docs.python.org/2/library/multiprocessing.html
    '''
    #print(mp.current_process())
    out = pool.map(clust, params)
    pool.close()
    pool.join()
    #Plotting the first dataset
    label, brain = out[0]               #Cluster labels and raw data of the first dataset
    clustplot(label, brain, no_clusters)
    print((time.clock() - start_time)/60, "minutes")    #Print elapsed time
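# Parallel-processing sketch (illustrative only): the Pool in the main block
# parallelises over datasets, so to cluster several files at once supply one
# path and one cluster count per dataset (the paths below are hypothetical):
#
#     fpath = ['D:/Codes/vonmises/data1.npz', 'D:/Codes/vonmises/data2.npz']
#     par2 = [200, 200]
#     params = zip(fpath, par2)
#     out = pool.map(clust, params)
#
# pool.map returns one (clusters, brain) tuple per file, in the same order as
# the paths in fpath.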