Skip to content

Commit aa54734

Browse files
authored
Merge pull request #534 from dimitri-yatsenko/master
Add pandas support and support order_by "KEY" (issues #459, #537, #538, #541)
2 parents 19d2ff9 + f6bfe97 commit aa54734

File tree

13 files changed

+135
-35
lines changed

13 files changed

+135
-35
lines changed

datajoint/declare.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def compile_attribute(line, in_key, foreign_key_sql):
272272
match = {k: v.strip() for k, v in match.items()}
273273
match['nullable'] = match['default'].lower() == 'null'
274274
acceptable_datatype_pattern = r'^(time|date|year|enum|(var)?char|float|double|decimal|' \
275-
r'(tiny|small|medium|big)?int|' \
275+
r'(tiny|small|medium|big)?int|bool(ean)?|' \
276276
r'(tiny|small|medium|long)?blob|external|attach)'
277277
if re.match(acceptable_datatype_pattern, match['type']) is None:
278278
raise DataJointError('DataJoint does not support datatype "{type}"'.format(**match))

datajoint/errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
'lost connection': 2013,
1515
}
1616

17+
1718
def is_connection_error(e):
1819
"""
1920
Checks if error e pertains to a connection issue
@@ -22,7 +23,6 @@ def is_connection_error(e):
2223
(isinstance(e, err.OperationalError) and e.args[0] in operation_error_codes.values())
2324

2425

25-
2626
class DataJointError(Exception):
2727
"""
2828
Base class for errors specific to DataJoint internal operation.

datajoint/expression.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import re
77
import datetime
88
import decimal
9+
import pandas
910
from .settings import config
1011
from .errors import DataJointError
1112
from .fetch import Fetch, Fetch1
@@ -117,7 +118,7 @@ def _make_condition(self, arg):
117118
"""
118119
Translate the input arg into the equivalent SQL condition (a string)
119120
:param arg: any valid restriction object.
120-
:return: an SQL condition string. It may also be a boolean that is intended to be treated as a string.
121+
:return: an SQL condition string or a boolean value.
121122
"""
122123
def prep_value(v):
123124
return str(v) if isinstance(v, (datetime.date, datetime.datetime, datetime.time, decimal.Decimal)) else v
@@ -176,6 +177,10 @@ def prep_value(v):
176177
not_="not " if negate else "",
177178
subquery=arg.make_sql(common_attributes)))
178179

180+
# restrict by pandas.DataFrames
181+
if isinstance(arg, pandas.DataFrame):
182+
arg = arg.to_records() # convert to np.recarray
183+
179184
# if iterable (but not a string, a QueryExpression, or an AndList), treat as an OrList
180185
try:
181186
or_list = [self._make_condition(q) for q in arg]
@@ -289,8 +294,7 @@ def restrict(self, restriction):
289294
rel.restrict(restriction) is equivalent to rel = rel & restriction or rel &= restriction
290295
rel.restrict(Not(restriction)) is equivalent to rel = rel - restriction or rel -= restriction
291296
The primary key of the result is unaffected.
292-
Successive restrictions are combined using the logical AND.
293-
The AndList class is provided to play the role of successive restrictions.
297+
Successive restrictions are combined as logical AND: r & a & b is equivalent to r & AndList((a, b))
294298
Any QueryExpression, collection, or sequence other than an AndList are treated as OrLists
295299
(logical disjunction of conditions)
296300
Inverse restriction is accomplished by either using the subtraction operator or the Not class.
@@ -342,6 +346,26 @@ def fetch1(self):
342346
def fetch(self):
343347
return Fetch(self)
344348

349+
def head(self, limit=25, **fetch_kwargs):
    """
    Fetch the first few entries of the query expression, ordered by primary key.

    Shortcut for fetch(order_by="KEY", limit=limit, **fetch_kwargs).

    :param limit: maximum number of entries to return (25 by default)
    :param fetch_kwargs: additional keyword arguments passed through to fetch
    :return: the result of the fetch call
    """
    first_entries = self.fetch(order_by="KEY", limit=limit, **fetch_kwargs)
    return first_entries
358+
359+
def tail(self, limit=25, **fetch_kwargs):
    """
    Fetch the last few entries of the query expression, ordered by primary key.

    Shortcut for fetch(order_by="KEY DESC", limit=limit, **fetch_kwargs)[::-1]:
    the entries are fetched in descending key order and then reversed.

    :param limit: maximum number of entries to return (25 by default)
    :param fetch_kwargs: additional keyword arguments passed through to fetch
    :return: the reversed result of the fetch call
    """
    descending = self.fetch(order_by="KEY DESC", limit=limit, **fetch_kwargs)
    return descending[::-1]
368+
345369
def attributes_in_restriction(self):
346370
"""
347371
:return: list of attributes that are probably used in the restriction.
@@ -365,7 +389,7 @@ def preview(self, limit=None, width=None):
365389
limit = config['display.limit']
366390
if width is None:
367391
width = config['display.width']
368-
tuples = rel.fetch(limit=limit+1)
392+
tuples = rel.fetch(limit=limit+1, format="array")
369393
has_more = len(tuples) > limit
370394
tuples = tuples[:limit]
371395
columns = heading.names
@@ -378,13 +402,13 @@ def preview(self, limit=None, width=None):
378402
'\n'.join(' '.join(templates[f] % (tup[f] if f in tup.dtype.names else '=BLOB=')
379403
for f in columns) for tup in tuples) +
380404
('\n ...\n' if has_more else '\n') +
381-
(' (%d tuples)\n' % len(rel) if config['display.show_tuple_count'] else ''))
405+
(' (Total: %d)\n' % len(rel) if config['display.show_tuple_count'] else ''))
382406

383407
def _repr_html_(self):
384408
heading = self.heading
385409
rel = self.proj(*heading.non_blobs)
386410
info = heading.table_info
387-
tuples = rel.fetch(limit=config['display.limit']+1)
411+
tuples = rel.fetch(limit=config['display.limit']+1, format='array')
388412
has_more = len(tuples) > config['display.limit']
389413
tuples = tuples[0:config['display.limit']]
390414

@@ -464,7 +488,7 @@ def _repr_html_(self):
464488
['\n'.join(['<td>%s</td>' % (tup[name] if name in tup.dtype.names else '=BLOB=')
465489
for name in heading.names])
466490
for tup in tuples]),
467-
count=('<p>%d tuples</p>' % len(rel)) if config['display.show_tuple_count'] else '')
491+
count=('<p>Total: %d</p>' % len(rel)) if config['display.show_tuple_count'] else '')
468492

469493
def make_sql(self, select_fields=None):
470494
return 'SELECT {fields} FROM {from_}{where}'.format(

datajoint/fetch.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from collections import OrderedDict
22
from functools import partial
3+
import warnings
4+
import pandas
5+
import re
36
import numpy as np
47
from .blob import unpack
58
from .errors import DataJointError
6-
import warnings
9+
from .settings import config
710

811

912
class key:
@@ -24,6 +27,16 @@ def to_dicts(recarray):
2427
yield dict(zip(recarray.dtype.names, rec.tolist()))
2528

2629

30+
def _flatten_attribute_list(primary_key, attr):
31+
for a in attr:
32+
if re.match(r'^\s*KEY\s*(ASC\s*)?$', a):
33+
yield from primary_key
34+
elif re.match(r'^\s*KEY\s*DESC\s*$', a):
35+
yield from (q + ' DESC' for q in primary_key)
36+
else:
37+
yield a
38+
39+
2740
class Fetch:
2841
"""
2942
A fetch object that handles retrieving elements from the table expression.
@@ -33,36 +46,59 @@ class Fetch:
3346
def __init__(self, expression):
3447
self._expression = expression
3548

36-
def __call__(self, *attrs, offset=None, limit=None, order_by=None, as_dict=False, squeeze=False):
49+
def __call__(self, *attrs, offset=None, limit=None, order_by=None, format=None, as_dict=False, squeeze=False):
3750
"""
3851
Fetches the expression results from the database into an np.array or list of dictionaries and unpacks blob attributes.
3952
4053
:param attrs: zero or more attributes to fetch. If not provided, the call will return
4154
all attributes of this relation. If provided, returns tuples with an entry for each attribute.
4255
:param offset: the number of tuples to skip in the returned result
4356
:param limit: the maximum number of tuples to return
44-
:param order_by: the list of attributes to order the results. No ordering should be assumed if order_by=None.
57+
:param order_by: a single attribute or the list of attributes to order the results.
58+
No ordering should be assumed if order_by=None.
59+
To reverse the order, add DESC to the attribute name or names: e.g. ("age DESC", "frequency")
60+
To order by primary key, use "KEY" or "KEY DESC"
61+
:param format: Effective when as_dict=False and when attrs is empty
62+
None: default from config['fetch_format'] or 'array' if not configured
63+
"array": use numpy.key_array
64+
"frame": output pandas.DataFrame. .
4565
:param as_dict: returns a list of dictionaries instead of a record array
4666
:param squeeze: if True, remove extra dimensions from arrays
4767
:return: the contents of the relation in the form of a structured numpy.array or a dict list
4868
"""
4969

50-
# if 'order_by' passed in a string, make into list
51-
if isinstance(order_by, str):
52-
order_by = [order_by]
70+
if order_by is not None:
71+
# if 'order_by' passed in a string, make into list
72+
if isinstance(order_by, str):
73+
order_by = [order_by]
74+
# expand "KEY" or "KEY DESC"
75+
order_by = list(_flatten_attribute_list(self._expression.primary_key, order_by))
5376

5477
# if attrs are specified then as_dict cannot be true
5578
if attrs and as_dict:
5679
raise DataJointError('Cannot specify attributes to return when as_dict=True. '
57-
'Use proj() to select attributes or set as_dict=False')
80+
'Use '
81+
'proj() to select attributes or set as_dict=False')
82+
# format should not be specified with attrs or is_dict=True
83+
if format is not None and (as_dict or attrs):
84+
raise DataJointError('Cannot specify output format when as_dict=True or '
85+
'when attributes are selected to be fetched separately.')
86+
87+
if format not in {None, "array", "frame"}:
88+
raise DataJointError('Fetch output format must be in {{"array", "frame"}} but "{}" was given'.format(format))
89+
90+
if not (attrs or as_dict) and format is None:
91+
format = config['fetch_format'] # default to array
92+
if format not in {"array", "frame"}:
93+
raise DataJointError('Invalid entry "{}" in datajoint.config["fetch_format"]: use "array" or "frame"'.format(format))
5894

5995
if limit is None and offset is not None:
6096
warnings.warn('Offset set, but no limit. Setting limit to a large number. '
6197
'Consider setting a limit explicitly.')
6298
limit = 2 * len(self._expression)
6399

64100
if not attrs:
65-
# fetch all attributes
101+
# fetch all attributes as a numpy.record_array or pandas.DataFrame
66102
cur = self._expression.cursor(as_dict=as_dict, limit=limit, offset=offset, order_by=order_by)
67103
heading = self._expression.heading
68104
if as_dict:
@@ -78,6 +114,8 @@ def __call__(self, *attrs, offset=None, limit=None, order_by=None, as_dict=False
78114
ret[name] = list(map(external_table.get, ret[name]))
79115
elif heading[name].is_blob:
80116
ret[name] = list(map(partial(unpack, squeeze=squeeze), ret[name]))
117+
if format == "frame":
118+
ret = pandas.DataFrame(ret).set_index(heading.primary_key)
81119
else: # if list of attributes provided
82120
attributes = [a for a in attrs if not is_key(a)]
83121
result = self._expression.proj(*attributes).fetch(

datajoint/jobs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from decimal import Decimal
21
from .hash import key_hash
32
import os
43
import platform

datajoint/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
'connection.charset': '', # pymysql uses '' as default
3838
'loglevel': 'INFO',
3939
'safemode': True,
40+
'fetch_format': 'array',
4041
'display.limit': 12,
4142
'display.width': 14,
4243
'display.show_tuple_count': True

datajoint/table.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import inspect
44
import platform
55
import numpy as np
6+
import pandas
67
import pymysql
78
import logging
89
import warnings
@@ -146,13 +147,12 @@ def insert1(self, row, **kwargs):
146147
"""
147148
self.insert((row,), **kwargs)
148149

149-
def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields=False, ignore_errors=False,
150-
allow_direct_insert=None):
150+
def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields=False, allow_direct_insert=None):
151151
"""
152152
Insert a collection of rows.
153153
154-
:param rows: An iterable where an element is a numpy record, a dict-like object, or an ordered sequence.
155-
rows may also be another relation with the same heading.
154+
:param rows: An iterable where an element is a numpy record, a dict-like object, a pandas.DataFrame, a sequence,
155+
or a query expression with the same heading as table self.
156156
:param replace: If True, replaces the existing tuple.
157157
:param skip_duplicates: If True, silently skip duplicate inserts.
158158
:param ignore_extra_fields: If False, fields that are not in the heading raise error.
@@ -164,9 +164,8 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
164164
>>> dict(subject_id=8, species="mouse", date_of_birth="2014-09-02")])
165165
"""
166166

167-
if ignore_errors:
168-
warnings.warn('Use of `ignore_errors` in `insert` and `insert1` is deprecated. Use try...except... '
169-
'to explicitly handle any errors', stacklevel=2)
167+
if isinstance(rows, pandas.DataFrame):
168+
rows = rows.to_records()
170169

171170
# prohibit direct inserts into auto-populated tables
172171
if not (allow_direct_insert or getattr(self, '_allow_insert', True)): # _allow_insert is only present in AutoPopulate

datajoint/user_tables.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212

1313
# attributes that trigger instantiation of user classes
1414
supported_class_attrs = {
15-
'key_source', 'describe', 'populate', 'progress', 'primary_key', 'proj', 'aggr', 'heading', 'fetch', 'fetch1',
15+
'key_source', 'describe', 'heading', 'populate', 'progress', 'primary_key', 'proj', 'aggr',
16+
'fetch', 'fetch1','head', 'tail',
1617
'insert', 'insert1', 'drop', 'drop_quick', 'delete', 'delete_quick'}
1718

1819

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ numpy
22
pymysql>=0.7.2
33
pyparsing
44
ipython
5+
pandas
56
tqdm
67
networkx
78
pydot

tests/schema_simple.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class C(dj.Part):
5858
value :float # normally distributed variables according to parameters in B
5959
"""
6060

61-
def _make_tuples(self, key):
61+
def make(self, key):
6262
random.seed(str(key))
6363
sub = B.C()
6464
for i in range(4):
@@ -113,7 +113,7 @@ class F(dj.Part):
113113
-> B.C
114114
"""
115115

116-
def _make_tuples(self, key):
116+
def make(self, key):
117117
random.seed(str(key))
118118
self.insert1(dict(key, **random.choice(list(L().fetch('KEY')))))
119119
sub = E.F()

0 commit comments

Comments
 (0)