Skip to content

Commit 518b882

Browse files
authored
Merge pull request #776 from ixcat/issue-666
datajoint/table.py: smarter dataframe conversion (#666)
2 parents d54c250 + 18beaf2 commit 518b882

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

datajoint/table.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,11 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
193193
"""
194194

195195
if isinstance(rows, pandas.DataFrame):
196-
rows = rows.to_records()
196+
# drop 'extra' synthetic index for 1-field index case -
197+
# frames with more advanced indices should be prepared by user.
198+
rows = rows.reset_index(
199+
drop=len(rows.index.names) == 1 and not rows.index.names[0]
200+
).to_records(index=False)
197201

198202
# prohibit direct inserts into auto-populated tables
199203
if not allow_direct_insert and not getattr(self, '_allow_insert', True): # allow_insert is only used in AutoPopulate

tests/test_relation.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ def test_insert_select(self):
103103
'real_id', 'date_of_birth', 'subject_notes', subject_id='subject_id+1000', species='"human"'))
104104
assert_equal(len(self.subject), 2*original_length)
105105

106-
def test_insert_pandas(self):
106+
def test_insert_pandas_roundtrip(self):
107+
''' ensure fetched frames can be inserted '''
107108
schema.TTest2.delete()
108109
n = len(schema.TTest())
109110
assert_true(n > 0)
@@ -113,6 +114,20 @@ def test_insert_pandas(self):
113114
schema.TTest2.insert(df)
114115
assert_equal(len(schema.TTest2()), n)
115116

117+
def test_insert_pandas_userframe(self):
118+
'''
119+
ensure simple user-created frames (1 field, non-custom index)
120+
can be inserted without extra index adjustment
121+
'''
122+
schema.TTest2.delete()
123+
n = len(schema.TTest())
124+
assert_true(n > 0)
125+
df = pandas.DataFrame(schema.TTest.fetch())
126+
assert_true(isinstance(df, pandas.DataFrame))
127+
assert_equal(len(df), n)
128+
schema.TTest2.insert(df)
129+
assert_equal(len(schema.TTest2()), n)
130+
116131
@raises(dj.DataJointError)
117132
def test_insert_select_ignore_extra_fields0(self):
118133
""" need ignore extra fields for insert select """

0 commit comments

Comments
 (0)