-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathFN2_Train_label_classifier.py
More file actions
171 lines (131 loc) · 5.09 KB
/
FN2_Train_label_classifier.py
File metadata and controls
171 lines (131 loc) · 5.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""
******************************************************************************************
Written by Jim Rowe (U of Sheffield)
Started: 2026-02-19
@BotanicalJim
james.rowe at sheffield.ac.uk
******************************************************************************************
"""
from ij import IJ
from ij.measure import ResultsTable
from fiji.util.gui import GenericDialogPlus
from java.awt import GridLayout, Font, Color
from ij import WindowManager as WM
from datetime import datetime
import os
import json
from jarray import array
from java.util import ArrayList, Random
from weka.core import SerializationHelper, Attribute, Instances, DenseInstance
from weka.classifiers.trees import RandomForest
def errorDialog(message):
    """Display an error to the end user.

    message -- text shown in a GenericDialogPlus window titled "Error".
    Returns None once the dialog has been dismissed.
    """
    dialog = GenericDialogPlus("Error")
    dialog.addMessage(message)
    dialog.showDialog()
def folderSelectDialog():
    """Ask the user for a training CSV file or a folder of training CSV files.

    Returns the chosen path as entered in the dialog's path field.
    If the dialog is cancelled, IJ.exit() is called.
    """
    # (The previous version also fetched the open image titles via
    # WindowManager but never used them; that dead call is removed.)
    gd = GenericDialogPlus("Select training file or folder of training files")
    gd.addDirectoryOrFileField("Choose path", "")
    gd.showDialog()
    if gd.wasCanceled():
        IJ.exit()
    return gd.getNextString()
def chooseColumns(columns, headingList):
    """Let the user tick which table columns are used as training parameters.

    columns     -- list of booleans: initial tick state per heading. It is
                   updated in place with the user's choices, and padded with
                   False first if it is shorter than headingList.
    headingList -- column headings, one checkbox per heading, in table order.

    Returns (choices, columns): the indices of the ticked headings, and the
    updated columns list. Calls IJ.exit() if the dialog is cancelled.
    """
    # The default selection is a fixed-length list; pad it so tables with more
    # columns than expected do not raise IndexError below.
    if len(columns) < len(headingList):
        columns.extend([False] * (len(headingList) - len(columns)))

    gd = GenericDialogPlus("Choose parameters for training")
    for heading, ticked in zip(headingList, columns):
        gd.addCheckbox(heading, ticked)
    gd.setLayout(GridLayout(0, 3))
    gd.showDialog()
    # Check for cancellation before reading any values back from the dialog.
    if gd.wasCanceled():
        IJ.exit()

    choices = []
    # Exactly one checkbox was added per heading. Iterating over headingList
    # also avoids gd.getCheckboxes(), which returns None when none were added.
    for i in range(len(headingList)):
        ticked = gd.getNextBoolean()
        if ticked:
            choices.append(i)
        columns[i] = ticked
    return choices, columns
def filteredResultsTable(rtc, columnChoices):
    """Return a new ResultsTable containing only the selected columns of rtc.

    rtc           -- source ij.measure.ResultsTable (not modified).
    columnChoices -- column indices of rtc to keep, in output order.

    Every row of rtc is copied; each kept column keeps its original heading.
    """
    rtc2 = ResultsTable()
    # Resolve each heading once, instead of once per row.
    selected = [(col, rtc.getColumnHeading(col)) for col in columnChoices]
    for row in range(rtc.size()):
        rtc2.addRow()
        for col, heading in selected:
            # Read a single cell directly. The former rtc.getColumn(col)[row]
            # copied the entire column (as 32-bit floats) for every cell,
            # which was quadratic in the row count and lost precision.
            rtc2.addValue(heading, rtc.getValueAsDouble(col, row))
    return rtc2
def concatResultsTable(rt1, rtc):
    """Append every row of rtc to rt1 (in place) and return rt1.

    Values are added by heading name, so columns are matched between the two
    tables by heading; headings new to rt1 become new columns.

    NOTE(review): column i is paired with the i-th whitespace-separated token
    of getColumnHeadings(). This assumes no "Label" column and no spaces in
    headings; confirm for the CSVs in use.
    """
    rtcHeadings = rtc.getColumnHeadings().split()
    for row in range(rtc.size()):
        rt1.addRow()
        for col, heading in enumerate(rtcHeadings):
            # Single-cell read. The former rtc.getColumn(col)[row] copied the
            # whole column as float[] for every cell: quadratic, and it
            # truncated values to 32-bit floats.
            rt1.addValue(heading, rtc.getValueAsDouble(col, row))
    return rt1
def convertTableToInstances(rt2, n_classes=12):
    """Convert a ResultsTable into a Weka Instances training set.

    Every column except the last becomes a numeric attribute named after its
    heading. The last column becomes a nominal attribute "class" with the
    values "0" .. str(n_classes - 1).

    rt2       -- ij.measure.ResultsTable whose last column holds the class
                 labels as integers 0 .. n_classes-1 (NaN marks a missing label).
    n_classes -- number of nominal class values (default 12, the previously
                 hard-coded count).

    Returns (input_data, attributes): the Instances object (class index set
    to the last attribute) and the java.util.ArrayList of Attributes.

    Raises ValueError if rt2 has no columns, or if a class label is neither
    NaN nor an integer in [0, n_classes).
    """
    import math

    headings = rt2.getColumnHeadings().split()
    if not headings:
        raise ValueError("No columns selected for training")

    classList = [str(c) for c in range(n_classes)]
    attributes = ArrayList([Attribute(h) for h in headings[:-1]])
    attributes.add(Attribute("class", classList))

    samples = []
    for row in range(rt2.size()):
        sample = [rt2.getValue(h, row) for h in headings]
        label = sample[-1]
        # For a nominal attribute Weka stores the index of the value, so the
        # label must be a valid index. NaN is Weka's missing-value marker.
        # The range test runs first so int() is never called on an infinity.
        if not math.isnan(label) and (not (0 <= label < n_classes)
                                      or label != int(label)):
            raise ValueError(
                "Row {0}: class label {1!r} is not an integer in [0, {2})"
                .format(row, label, n_classes))
        samples.append(array(sample, 'd'))

    input_data = Instances("training", attributes, len(samples))
    input_data.setClassIndex(len(attributes) - 1)  # the last one is the class
    for vector in samples:
        input_data.add(DenseInstance(1.0, vector))
    return input_data, attributes
try:
    from net.haesleinhuepf.clijx import CLIJx
except ImportError:
    errorDialog("""This plugin requires clij2 to function.
To install please follow these instructions:
1. Click Help>Update> Manage update sites
2. Make sure the "clij" and "clij2" update sites are selected.
3. Click Close> Apply changes.
4. Close and reopen ImageJ""")
    # Stop here. Without the import, CLIJx.getInstance() below would
    # otherwise fail with an unrelated-looking NameError.
    raise
clij2 = CLIJx.getInstance()
# *****************************body of code starts****************************************
if __name__ == "__main__":
    clij2.clear()
    filePath = folderSelectDialog()

    # Gather every training CSV into one table.
    rt = ResultsTable()
    if os.path.isfile(filePath):
        # A single CSV was chosen. Previously it was opened but never merged
        # into rt (training ran on an empty table), and outputs were written
        # under "<file>.csv/", which is not a directory.
        outputDir = os.path.dirname(filePath)
        csvPaths = [filePath]
    else:
        outputDir = filePath
        # Case-insensitive so ".CSV" files are included; sorted so the row
        # order does not depend on the file system's listing order.
        csvPaths = [os.path.join(filePath, name)
                    for name in sorted(os.listdir(filePath))
                    if name.lower().endswith(".csv")]
    for csvPath in csvPaths:
        concatResultsTable(rt, ResultsTable.open(csvPath))
    if rt.size() == 0:
        errorDialog("No training data found in:\n" + filePath)
        raise RuntimeError("No training data found in " + filePath)

    # Default tick state per column, offered to the user in chooseColumns().
    rtheadingsList = rt.getColumnHeadings().split()
    columns = [False, False, False, False, False, False, False, True, True, True, False, False, False, False, False, True, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, False, False, False, False, True, True, True, False, False, False, False, True, True, True, False, False, False, False, False, False, False, False, False, True, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]
    # Pad the fixed-length default so tables with extra columns do not index
    # past the end of the list.
    if len(columns) < len(rtheadingsList):
        columns.extend([False] * (len(rtheadingsList) - len(columns)))
    columnChoices, columns = chooseColumns(columns, rtheadingsList)

    rt2 = filteredResultsTable(rt, columnChoices)
    training_data, attributes = convertTableToInstances(rt2)

    date = datetime.now().strftime("%Y-%m-%d-%H.%M")
    classifier = RandomForest()
    classifier.setNumIterations(200)
    classifier.setNumFeatures(5)
    # NOTE(review): Weka documents maxDepth 0 as "unlimited"; 1 restricts every
    # tree to a single split. Confirm this shallow depth is intended.
    classifier.setMaxDepth(1)
    classifier.buildClassifier(training_data)

    SerializationHelper.write(
        os.path.join(outputDir, date + " classifier.model"), classifier)
    # Save the per-column tick state next to the model. Presumably a
    # companion script reads it to select the same columns; confirm.
    with open(os.path.join(outputDir, date + ' classifier_fileheaders.json'), 'w') as f:
        json.dump(columns, f)