Skip to content

Commit 83acdc5

Browse files
Merge pull request #79 from wenyushi/master
pull request
2 parents 1346e11 + 8960da3 commit 83acdc5

File tree

5 files changed

+76
-18
lines changed

5 files changed

+76
-18
lines changed
1.9 MB
Binary file not shown.
5.2 MB
Binary file not shown.

dlpy/tests/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,7 @@ def test_imagescaler1(self):
867867
self.assertTrue(l1.type == 'input')
868868
self.assertTrue(l1.config['offsets'] == [0., 0., 0.])
869869
self.assertTrue(l1.config['scale'] == 1.)
870-
870+
871871
def test_imagescaler2(self):
872872
# test export model with imagescaler
873873
try:

dlpy/tests/test_utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_camelcase_to_underscore(self):
7373
underscore = camelcase_to_underscore('includeBias')
7474
self.assertTrue('include_bias' == underscore)
7575

76-
def test_camelcase_to_underscore(self):
76+
def test_underscore_to_camelcase(self):
7777
underscore = underscore_to_camelcase('include_bias')
7878
self.assertTrue('includeBias' == underscore)
7979

@@ -100,6 +100,16 @@ def test_create_object_detection_table(self):
100100

101101
create_object_detection_table(self.s, data_path=self.data_dir+'dlpy_obj_det_test',
102102
coord_type='coco', output='output')
103+
self.assertTrue(self.s.fetch('output', fetchvars='_nObjects_').Fetch['_nObjects_'].tolist() == [3.0]*10)
104+
105+
def test_create_object_detection_table_2(self):
106+
# make sure that txt files are already in self.data_dir + 'dlpy_obj_det_test', otherwise the test will fail.
107+
if self.data_dir is None:
108+
unittest.TestCase.skipTest(self, "DLPY_DATA_DIR is not set in the environment variables")
109+
create_object_detection_table(self.s, data_path = self.data_dir + 'dlpy_obj_det_test',
110+
coord_type = 'yolo',
111+
output = 'output')
112+
self.assertTrue(self.s.fetch('output', fetchvars='_nObjects_').Fetch['_nObjects_'].tolist() == [3.0]*10)
103113

104114
def test_get_anchors(self):
105115
if platform.system().startswith('Win'):
@@ -119,3 +129,8 @@ def test_get_anchors(self):
119129
data_path=self.data_dir+'dlpy_obj_det_test')
120130

121131
get_anchors(self.s, coord_type='yolo', data='output')
132+
133+
def test_get_txt_annotation(self):
134+
if self.data_dir_local is None:
135+
unittest.TestCase.skipTest(self, "DLPY_DATA_DIR_LOCAL is not set in the environment variables")
136+
get_txt_annotation(self.data_dir_local+'dlpy_obj_det_test', 'yolo', 416)

dlpy/utils.py

Lines changed: 59 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import xml.etree.ElementTree as ET
3333
from swat.cas.table import CASTable
3434
from PIL import Image
35+
import warnings
3536

3637

3738
def random_name(name='ImageData', length=6):
@@ -1025,6 +1026,48 @@ def convert_txt_to_xml(path):
10251026
os.chdir(cwd)
10261027

10271028

1029+
def get_txt_annotation(local_path, coord_type, image_size = 416, label_files = None):
1030+
'''
1031+
Parse object detection annotation files based on Pascal VOC format and save as txt files.
1032+
1033+
Parameters
1034+
----------
1035+
local_path : string
1036+
Local_path points to the directory where xml files are stored.
1037+
The generated txt files will be stored under the directory.
1038+
coord_type : string
1039+
Specifies the type of coordinate to convert into.
1040+
'yolo' specifies x, y, width and height, x, y is the center
1041+
location of the object in the grid cell. x, y, are between 0
1042+
and 1 and are relative to that grid cell. x, y = 0,0 corresponds
1043+
to the top left pixel of the grid cell.
1044+
'coco' specifies xmin, ymin, xmax, ymax that are borders of a
1045+
bounding box.
1046+
The values are relative to parameter image_size.
1047+
Valid Values: yolo, coco
1048+
image_size : integer, optional
1049+
Specifies the size of images to resize.
1050+
Default: 416
1051+
label_files : list, optional
1052+
Specifies the list of filename with XML extension under local_path to be parsed.
1053+
If label_files is not specified, all XML files under local_path will be parsed.
1054+
Default: None
1055+
1056+
'''
1057+
1058+
cwd = os.getcwd()
1059+
os.chdir(local_path)
1060+
# if label_files = None, that means we call it directly and parse annotation files.
1061+
if label_files is None:
1062+
label_files = os.listdir(local_path)
1063+
label_files = [x for x in label_files if x.endswith('.xml')]
1064+
if len(label_files) == 0:
1065+
raise DLPyError('Can not find any xml file under data_path')
1066+
for idx, filename in enumerate(label_files):
1067+
convert_xml_annotation(filename, coord_type, image_size)
1068+
os.chdir(cwd)
1069+
1070+
10281071
def create_object_detection_table(conn, data_path, coord_type, output,
10291072
local_path=None, image_size=416):
10301073
'''
@@ -1053,8 +1096,9 @@ def create_object_detection_table(conn, data_path, coord_type, output,
10531096
local_path : string, optional
10541097
Local_path and data_path point to the same location.
10551098
The parameter local_path will be optional (default=None) if the
1056-
Python client has the same OS as CAS server. Otherwise, the path that
1057-
depends on the Python client OS needs to be specified.
1099+
Python client has the same OS as CAS server or annotation files
1100+
in TXT format are placed in data_path.
1101+
Otherwise, the path that depends on the Python client OS needs to be specified.
10581102
For example:
10591103
Windows client with linux CAS server:
10601104
data_path=/path/to/data/path
@@ -1077,10 +1121,11 @@ def create_object_detection_table(conn, data_path, coord_type, output,
10771121
unix_type = server_type.startswith("lin") or server_type.startswith("osx")
10781122
# check if local and server are same type of OS
10791123
# in different os
1124+
need_to_parse = True
10801125
if (unix_type and local_os_type.startswith('Win')) or not (unix_type or local_os_type.startswith('Win')):
10811126
if local_path is None:
1082-
raise ValueError('local_path must be specified when your server is on {} OS and local '
1083-
'python is on {} OS'.format(server_type.split('.')[0].capitalize(), local_os_type))
1127+
print('The txt files in data_path are used as annotation files.')
1128+
need_to_parse = False
10841129
else:
10851130
local_path = data_path
10861131

@@ -1142,24 +1187,21 @@ def create_object_detection_table(conn, data_path, coord_type, output,
11421187
# find all of annotation files under the directory
11431188
a = conn.fileinfo(caslib = caslib, allfiles = True)
11441189
label_files = conn.fileinfo(caslib = caslib, allfiles = True).FileInfo['Name'].values
1145-
label_files = [x for x in label_files if x.endswith('.xml') or x.endswith('.json')]
1146-
if len(label_files) == 0:
1147-
raise ValueError('Can not find any annotation file under data_path')
1190+
# label_files = [x for x in label_files if x.endswith('.xml') or x.endswith('.json')]
11481191

1192+
# If the client and server run on different types of operating system and local_path is not given, we assume the
1193+
# user has already parsed the XML files and put the TXT files in the data_path folder, so get_txt_annotation() is skipped.
11491194
# parse xml or json files and create txt files
1150-
cwd = os.getcwd()
1151-
os.chdir(local_path)
1152-
for idx, filename in enumerate(label_files):
1153-
if filename.endswith('.xml'):
1154-
convert_xml_annotation(filename, coord_type, image_size)
1155-
# elif filename.endswith('.json'):
1156-
# convert_json_annotation(filename)
1157-
os.chdir(cwd)
1195+
if need_to_parse:
1196+
get_txt_annotation(local_path, coord_type, image_size, label_files)
1197+
11581198
label_tbl_name = random_name('obj_det')
11591199
# load all of txt files into cas server
11601200
label_files = conn.fileinfo(caslib = caslib, allfiles = True).FileInfo['Name'].values
11611201
label_files = [x for x in label_files if x.endswith('.txt')]
1162-
idjoin_format_length = len(max(label_files, key=len)) - 4 # 4 is lenght of '.txt'
1202+
if len(label_files) == 0:
1203+
raise DLPyError('Can not find any txt file under data_path.')
1204+
idjoin_format_length = len(max(label_files, key=len)) - 4 # 4 is length of '.txt'
11631205
with sw.option_context(print_messages = False):
11641206
for idx, filename in enumerate(label_files):
11651207
tbl_name = '{}_{}'.format(label_tbl_name, idx)
@@ -1183,6 +1225,7 @@ def create_object_detection_table(conn, data_path, coord_type, output,
11831225
'''.format(output, string_input_tbl_name)
11841226
conn.runcode(code = fmt_code, _messagelevel = 'error')
11851227
cls_col_format_length = conn.columninfo(output).ColumnInfo.loc[0][3]
1228+
cls_col_format_length = cls_col_format_length if cls_col_format_length >= len('NoObject') else len('NoObject')
11861229

11871230
conn.altertable(name = output, columns = [dict(name = 'Var1', rename = var_name[0]),
11881231
dict(name = 'Var2', rename = var_name[1]),

0 commit comments

Comments
 (0)