@@ -21,8 +21,9 @@ def __init__(self, operand):
2121
2222class AndList (list ):
2323 """
24- A list of conditions to by applied to a query expression by logical conjunction: the conditions are AND-ed.
25- All other collections (lists, sets, other entity sets, etc) are applied by logical disjunction (OR).
24+ A list of conditions to by applied to a query expression by logical conjunction: the
25+ conditions are AND-ed. All other collections (lists, sets, other entity sets, etc) are
26+ applied by logical disjunction (OR).
2627
2728 Example:
2829 expr2 = expr & dj.AndList((cond1, cond2, cond3))
@@ -49,14 +50,16 @@ def assert_join_compatibility(expr1, expr2):
4950 the matching attributes in the two expressions must be in the primary key of one or the
5051 other expression.
5152 Raises an exception if not compatible.
53+
5254 :param expr1: A QueryExpression object
5355 :param expr2: A QueryExpression object
5456 """
5557 from .expression import QueryExpression , U
5658
5759 for rel in (expr1 , expr2 ):
5860 if not isinstance (rel , (U , QueryExpression )):
59- raise DataJointError ('Object %r is not a QueryExpression and cannot be joined.' % rel )
61+ raise DataJointError (
62+ 'Object %r is not a QueryExpression and cannot be joined.' % rel )
6063 if not isinstance (expr1 , U ) and not isinstance (expr2 , U ): # dj.U is always compatible
6164 try :
6265 raise DataJointError (
@@ -70,9 +73,11 @@ def assert_join_compatibility(expr1, expr2):
7073def make_condition (query_expression , condition , columns ):
7174 """
7275 Translate the input condition into the equivalent SQL condition (a string)
76+
7377 :param query_expression: a dj.QueryExpression object to apply condition
7478 :param condition: any valid restriction object.
75- :param columns: a set passed by reference to collect all column names used in the condition.
79+ :param columns: a set passed by reference to collect all column names used in the
80+ condition.
7681 :return: an SQL condition string or a boolean value.
7782 """
7883 from .expression import QueryExpression , Aggregation , U
@@ -102,12 +107,13 @@ def prep_value(k, v):
102107 # restrict by string
103108 if isinstance (condition , str ):
104109 columns .update (extract_column_names (condition ))
105- return template % condition .strip ().replace ("%" , "%%" ) # escape % in strings , see issue #376
110+ return template % condition .strip ().replace ("%" , "%%" ) # escape %, see issue #376
106111
107112 # restrict by AndList
108113 if isinstance (condition , AndList ):
109114 # omit all conditions that evaluate to True
110- items = [item for item in (make_condition (query_expression , cond , columns ) for cond in condition )
115+ items = [item for item in (make_condition (query_expression , cond , columns )
116+ for cond in condition )
111117 if item is not True ]
112118 if any (item is False for item in items ):
113119 return negate # if any item is False, the whole thing is False
@@ -123,18 +129,21 @@ def prep_value(k, v):
123129 if isinstance (condition , bool ):
124130 return negate != condition
125131
126- # restrict by a mapping such as a dict -- convert to an AndList of string equality conditions
132+ # restrict by a mapping/ dict -- convert to an AndList of string equality conditions
127133 if isinstance (condition , collections .abc .Mapping ):
128134 common_attributes = set (condition ).intersection (query_expression .heading .names )
129135 if not common_attributes :
130136 return not negate # no matching attributes -> evaluates to True
131137 columns .update (common_attributes )
132138 return template % ('(' + ') AND (' .join (
133- '`%s`=%s' % (k , prep_value (k , condition [k ])) for k in common_attributes ) + ')' )
139+ '`%s`%s' % (k , ' IS NULL' if condition [k ] is None
140+ else f'={ prep_value (k , condition [k ])} ' )
141+ for k in common_attributes ) + ')' )
134142
135143 # restrict by a numpy record -- convert to an AndList of string equality conditions
136144 if isinstance (condition , numpy .void ):
137- common_attributes = set (condition .dtype .fields ).intersection (query_expression .heading .names )
145+ common_attributes = set (condition .dtype .fields ).intersection (
146+ query_expression .heading .names )
138147 if not common_attributes :
139148 return not negate # no matching attributes -> evaluate to True
140149 columns .update (common_attributes )
@@ -154,7 +163,8 @@ def prep_value(k, v):
154163 if isinstance (condition , QueryExpression ):
155164 if check_compatibility :
156165 assert_join_compatibility (query_expression , condition )
157- common_attributes = [q for q in condition .heading .names if q in query_expression .heading .names ]
166+ common_attributes = [q for q in condition .heading .names
167+ if q in query_expression .heading .names ]
158168 columns .update (common_attributes )
159169 if isinstance (condition , Aggregation ):
160170 condition = condition .make_subquery ()
@@ -176,15 +186,17 @@ def prep_value(k, v):
176186 except TypeError :
177187 raise DataJointError ('Invalid restriction type %r' % condition )
178188 else :
179- or_list = [item for item in or_list if item is not False ] # ignore all False conditions
180- if any (item is True for item in or_list ): # if any item is True, the whole thing is True
189+ or_list = [item for item in or_list if item is not False ] # ignore False conditions
190+ if any (item is True for item in or_list ): # if any item is True, entirely True
181191 return not negate
182- return template % ('(%s)' % ' OR ' .join (or_list )) if or_list else negate # an empty or list is False
192+ return template % ('(%s)' % ' OR ' .join (or_list )) if or_list else negate
183193
184194
185195def extract_column_names (sql_expression ):
186196 """
187- extract all presumed column names from an sql expression such as the WHERE clause, for example.
197+ extract all presumed column names from an sql expression such as the WHERE clause,
198+ for example.
199+
188200 :param sql_expression: a string containing an SQL expression
189201 :return: set of extracted column names
190202 This may be MySQL-specific for now.
@@ -206,5 +218,8 @@ def extract_column_names(sql_expression):
206218 s = re .sub (r"(\b[a-z][a-z_0-9]*)\(" , "(" , s )
207219 remaining_tokens = set (re .findall (r"\b[a-z][a-z_0-9]*\b" , s ))
208220 # update result removing reserved words
209- result .update (remaining_tokens - {"is" , "in" , "between" , "like" , "and" , "or" , "null" , "not" })
221+ result .update (remaining_tokens - {"is" , "in" , "between" , "like" , "and" , "or" , "null" ,
222+ "not" , "interval" , "second" , "minute" , "hour" , "day" ,
223+ "month" , "week" , "year"
224+ })
210225 return result
0 commit comments