Skip to content

Commit 208c789

Browse files
authored
Merge pull request #419 from zfj1/universal_set
Add support for universal sets and unions
2 parents 77dec54 + 0987b3c commit 208c789

File tree

5 files changed

+264
-9
lines changed

5 files changed

+264
-9
lines changed

+dj/+internal/GeneralRelvar.m

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,12 @@ function clip(self)
171171

172172
function n = count(self)
173173
% COUNT - the number of tuples in the relation.
%   n = count(self) compiles the relation with self.compile(3), runs a
%   `SELECT count(*)` query through self.conn.query, and returns the
%   result's field n converted to double.
%   When the compiled header has its `distinct` flag set, the count is
%   taken over a subquery `(SELECT <header_.sql> FROM <sql_>)` aliased
%   `counted`, rather than directly over the compiled source; the header's
%   sql text is what carries the DISTINCT keyword in that case.
%   NOTE(review): the meaning of the argument 3 passed to compile is not
%   visible in this span -- confirm against compile's definition.
174-
[~, sql_] = self.compile(3);
175-
n = self.conn.query(sprintf('SELECT count(*) as n FROM %s', sql_));
174+
[header_, sql_] = self.compile(3);
175+
if header_.distinct
176+
n = self.conn.query(sprintf('SELECT count(*) as n FROM (SELECT %s FROM %s) as `counted`', header_.sql, sql_));
177+
else
178+
n = self.conn.query(sprintf('SELECT count(*) as n FROM %s', sql_));
179+
end
176180
n = double(n.n);
177181
end
178182

@@ -632,6 +636,8 @@ function restrict(self, varargin)
632636
% get reference to the connection object from the first table
633637
if strcmp(self.operator, 'table')
634638
conn = self.operands{1}.schema.conn;
639+
elseif strcmp(self.operator, 'U') && isempty(self.operands)
640+
conn = dj.conn;
635641
else
636642
conn = self.operands{1}.getConn;
637643
end
@@ -665,9 +671,58 @@ function restrict(self, varargin)
665671
% apply relational operators recursively
666672
switch self.operator
667673
case 'union'
668-
throwAsCaller(MException('DataJoint:invalidOperator', ...
669-
'The union operator must be used in a restriction'))
674+
assert(...
675+
length(union(self.operands{1}.primaryKey, self.operands{2}.primaryKey)) == ...
676+
length(intersect(self.operands{1}.primaryKey, self.operands{2}.primaryKey)), ...
677+
'DataJoint:invalidUnion','Union operands must have the same primary key.');
678+
679+
assert(isempty(intersect(self.operands{1}.nonKeyFields, self.operands{2}.nonKeyFields)), ...
680+
'DataJoint:invalidUnion','Union operands may not have any common non-key attributes.');
681+
682+
% join the operands, which will form part of the query
683+
[header,sql1] = compile(self.operands{1} * self.operands{2}, 0);
670684

685+
if isempty(self.operands{1}.nonKeyFields) && isempty(self.operands{2}.nonKeyFields)
686+
% Only PKs: simple UNION
687+
% We just need to make sure the attributes are in
688+
% the same order (assumed by MySQL)
689+
[header2,sql2] = compile(self.operands{1}, 0);
690+
[~,fieldOrder] = ismember(header.names, header2.names);
691+
header2.reorderFields(fieldOrder);
692+
693+
[header3,sql3] = compile(self.operands{2}, 0);
694+
[~,fieldOrder] = ismember(header.names, header3.names);
695+
header3.reorderFields(fieldOrder);
696+
sql = sprintf('((SELECT %s FROM %s) UNION (SELECT %s FROM %s)) as `$s%x`',...
697+
header2.sql,sql2,header3.sql,sql3,aliasCount);
698+
else
699+
% With dependent fields, we first want to join any
700+
% matching PKs, then union with the antijoin,
701+
% (a-b)|(b-a). We will append NULL to the missing
702+
% columns in the antijoin to make the query valid.
703+
704+
fields = header.dependentFields;
705+
nk = ismember(fields, self.operands{2}.nonKeyFields);
706+
fields(nk) = cellfun(@(s) sprintf('NULL -> %s',s), fields(nk), 'UniformOutput', false);
707+
708+
[header2,sql2] = compile(proj(self.operands{1} - self.operands{2}, fields{:}), 0);
709+
[~,fieldOrder] = ismember(header.names, header2.names);
710+
header2.reorderFields(fieldOrder);
711+
712+
fields = header.dependentFields;
713+
nk = ismember(fields, self.operands{1}.nonKeyFields);
714+
fields(nk) = cellfun(@(s) sprintf('NULL -> %s',s), fields(nk), 'UniformOutput', false);
715+
716+
[header3,sql3] = compile(proj(self.operands{2} - self.operands{1}, fields{:}), 0);
717+
[~,fieldOrder] = ismember(header.names, header3.names);
718+
header3.reorderFields(fieldOrder);
719+
720+
sql = sprintf(...
721+
'((SELECT %s FROM %s) UNION (SELECT %s FROM %s) UNION (SELECT %s FROM %s)) as `$s%x`',...
722+
header.sql,sql1,header2.sql,sql2,header3.sql,sql3, aliasCount);
723+
724+
end
725+
671726
case 'not'
672727
throwAsCaller(MException('DataJoint:invalidOperator', ...
673728
'The NOT operator must be used in a restriction'))
@@ -699,7 +754,10 @@ function restrict(self, varargin)
699754
sql = sprintf('%s NATURAL JOIN %s', sql1, sql2);
700755
header = join(header1,header2);
701756
clear header1 header2 sql1 sql2
702-
757+
758+
case 'U'
759+
[header, sql] = compile(self.operands{1},2);
760+
header.promote(self.operands{3}, self.operands{2}.primaryKey{:});
703761
otherwise
704762
error 'unknown relational operator'
705763
end

+dj/+internal/Header.m

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
properties(SetAccess=private)
77
info % table info
88
attributes % array of attributes
9+
distinct=false% whether to select all the elements or only distinct ones
910
end
1011

1112
properties(Access=private,Constant)
@@ -204,7 +205,8 @@ function project(self, params)
204205
else
205206
% process a regular attribute
206207
ix = find(strcmp(params{iAttr},self.names));
207-
assert(~isempty(ix), 'Attribute `%s` does not exist', ...
208+
assert(~isempty(ix),'DataJoint:missingAttributes',...
209+
'Attribute `%s` does not exist', ...
208210
params{iAttr})
209211
end
210212
end
@@ -242,7 +244,8 @@ function project(self, params)
242244
function sql = sql(self)
243245
% make an SQL list of attributes for header
244246
sql = '';
245-
assert(~isempty(self.attributes))
247+
assert(~isempty(self.attributes),...
248+
'DataJoint:missingAttributes','Relation has no attributes');
246249
for i = 1:length(self.attributes)
247250
if isempty(self.attributes(i).alias)
248251
% if strcmp(self.attributes(i).type,'float')
@@ -266,7 +269,11 @@ function project(self, params)
266269
end
267270
end
268271
end
269-
sql = sql(2:end); % strip leading comma
272+
sql = sql(2:end); % strip leading comma
273+
274+
if self.distinct
275+
sql = sprintf('DISTINCT %s', sql);
276+
end
270277
end
271278

272279

@@ -276,4 +283,44 @@ function stripAliases(self)
276283
end
277284
end
278285
end
286+
287+
methods (Access = {?dj.internal.GeneralRelvar})
288+
function reorderFields(self, order)
    % REORDERFIELDS - rearrange the attributes of this header.
    %   order - index vector applied to self.attributes; it must contain
    %           exactly as many entries as the header has names.
    nFields = length(self.names);
    assert(length(order) == nFields);
    self.attributes = self.attributes(order);
end
292+
293+
function promote(self, keep, varargin)
    % PROMOTE - project onto the given attributes and flag them as keys.
    %   keep     - when ~keep is true, every existing attribute first loses
    %              its key flag, the header is marked distinct, and only the
    %              attributes in varargin are projected. Otherwise varargin
    %              is projected together with '*'.
    %   varargin - attribute specifications: a plain attribute name, or an
    %              'expression -> alias' form, in which case the alias is
    %              the name that gets promoted.
    %   Raises DataJoint:missingAttributes if a resolved name is not found
    %   among self.names after the projection.
    if ~keep
        [self.attributes(:).iskey] = deal(false);
        self.distinct = true;
        self.project(varargin); % do the projection
    else
        self.project([varargin, '*']);
    end

    % Alias patterns, tried in this order: a renamed attribute first,
    % then a computed expression. The first pattern that matches wins.
    aliasPatterns = { ...
        '^([a-z]\w*)\s*->\s*(\w+)', ...
        '(.*\S)\s*->\s*(\w+)'};

    % promote the keys
    for iAttr = 1:numel(varargin)
        spec = varargin{iAttr};
        name = spec; % regular attribute unless an alias pattern matches
        for iPattern = 1:numel(aliasPatterns)
            toks = regexp(spec, aliasPatterns{iPattern}, 'tokens');
            if ~isempty(toks)
                name = toks{1}{2};
                break
            end
        end
        ix = find(strcmp(name, self.names));
        assert(~isempty(ix), 'DataJoint:missingAttributes', 'Attribute `%s` does not exist', ...
            name)
        self.attributes(ix).iskey = true;
    end
end
325+
end
279326
end

+dj/U.m

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
classdef U
    % U - universal set over a list of attributes.
    %   dj.U(attr1, attr2, ...) represents all possible values of the
    %   supplied attributes. It is not a relation by itself; combining it
    %   with a relvar builds a 'U' (or 'aggregate') GeneralRelvar node:
    %
    %     dj.U(...) & A        unique combinations of the attributes found in A
    %     dj.U(...) * A        A with the attributes promoted into its primary key
    %     dj.U(...).aggr(A, …) aggregation of A grouped by the attributes
    %
    %   Attribute arguments are stored as given in primaryKey.

    properties (SetAccess=private, Hidden)
        primaryKey  % cell array holding the constructor arguments
    end

    methods
        function self = U(varargin)
            % UNIVERSAL SET - a set representing all possible values of the
            % supplied attributes.
            % Can be queried in combination with other relations to alter
            % their primary key structure.

            self.primaryKey = varargin;
        end

        function ret = and(self, arg)
            % AND - operator form of restrict: dj.U(...) & A.
            ret = self.restrict(arg);
        end

        function ret = restrict(self, arg)
            % RESTRICT - relational restriction
            % dj.U(varargin) & A returns the unique combinations of the keys
            % in varargin that appear in A.

            % for dj.U(), only support restricting by a relvar
            assert(isa(arg, 'dj.internal.GeneralRelvar'),...
                'restriction requires a relvar as operand');

            % operands: {relation, universal set, keep-flag}; 0 = keep only
            % the listed attributes
            ret = init(dj.internal.GeneralRelvar, 'U', {arg, self, 0});
        end

        function ret = mtimes(self, arg)
            % MTIMES - relational natural join.
            % dj.U(varargin) * A promotes the keys in varargin to the primary
            % key of A and returns the resulting relation.

            assert(isa(arg, 'dj.internal.GeneralRelvar'), ...
                'mtimes requires another relvar as operand')

            % operands: {relation, universal set, keep-flag}; 1 = keep the
            % relation's other attributes as well
            ret = init(dj.internal.GeneralRelvar, 'U', {arg, self, 1});
        end

        function ret = aggr(self, other, varargin)
            % AGGR -- relational aggregation operator.
            % dj.U(varargin).aggr(A,...) Allows grouping by arbitrary
            % combinations of the keys in A.
            %   other    - the relvar being aggregated
            %   varargin - attribute expressions, each a character vector

            % the message names aggr() so the failing call is identifiable
            assert(iscellstr(varargin), ...
                'aggr() requires a list of strings as attribute args')
            ret = init(dj.internal.GeneralRelvar, 'aggregate', ...
                [{self & other, self * other}, varargin]);
            %Note: join is required here to make projection semantics work.
        end
    end

end

tests/Main.m

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@
1313
TestSchema & ...
1414
TestTls & ...
1515
TestUuid & ...
16-
TestBlob
16+
TestBlob & ...
17+
TestRelationalOperator
1718
end

tests/TestRelationalOperator.m

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
classdef TestRelationalOperator < Prep
    % Tests for the relational union operator (|) and the universal set dj.U.

    methods (TestClassSetup)
        function init(testCase)
            % Create the University schema and insert the rows used below.
            init@Prep(testCase);
            schemaDir = [testCase.test_root '/test_schemas'];
            databaseName = [testCase.PREFIX '_university'];
            dj.createSchema('University', schemaDir, databaseName);

            students = struct(...
                'student_id', {1,2,3,4},...
                'first_name', {'John','Paul','George','Ringo'},...
                'last_name',{'Lennon','McCartney','Harrison','Starr'},...
                'enrolled',{'1960-01-01','1960-01-01','1960-01-01','1960-01-01'}...
                );
            University.Student().insert(students);

            University.A().insert({1, 'test', '1960-01-01 00:00:00','1960-01-01', 1.234, struct()});
        end
    end

    methods (Test)
        function TestRelationalOperator_testUnion(testCase)
            stack = dbstack;
            disp(['---------------' stack(1).name '---------------']);

            % Operands may share primary key attributes but no secondary ones
            sharedSecondary = @() count(University.Student() | University.Student());
            testCase.verifyError(sharedSecondary, 'DataJoint:invalidUnion');

            % Both operands must have the same primary key
            differentKeys = @() count(University.A() | University.Student());
            testCase.verifyError(differentKeys, 'DataJoint:invalidUnion');

            % A basic union of two disjoint key sets
            lowIds = proj(University.Student() & 'student_id<2');
            highIds = proj(University.Student() & 'student_id>3');
            testCase.verifyEqual(count(lowIds | highIds), 2);

            % Overlapping primary keys are merged rather than duplicated
            firstTwo = proj(University.Student() & 'student_id<3');
            middleTwo = proj(University.Student() & 'student_id>1 AND student_id<4');
            testCase.verifyEqual(count(firstTwo | middleTwo), 3);

            % Disjoint secondary attributes are merged and padded with NULL
            withNames = University.Student & 'student_id<4';
            withTestCol = proj(University.Student() & 'student_id>1','"test_val"->test_col');
            rows = fetch(withNames | withTestCol, '*');
            testCase.verifyEqual(length(rows), 4);
            testCase.verifyEqual(nnz(cellfun(@isempty,{rows(:).first_name})), 1);
            testCase.verifyEqual(nnz(cellfun(@isempty,{rows(:).test_col})), 1);
            missingBoth = cellfun(@isempty,{rows(:).first_name}) & cellfun(@isempty,{rows(:).test_col});
            testCase.verifyEqual(nnz(missingBoth), 0);
        end

        function TestRelationalOperator_testUniversalSet(testCase)
            stack = dbstack;
            disp(['---------------' stack(1).name '---------------']);

            % dj.U() & rel has no attributes
            emptyU = dj.U() & University.Student();
            testCase.verifyError(@() emptyU.header.sql, 'DataJoint:missingAttributes');

            % dj.U(c) * rel is invalid if c is not an attribute of rel
            badJoin = @() count(dj.U('bad_attribute') * University.Student());
            testCase.verifyError(badJoin, 'DataJoint:missingAttributes');

            % dj.U(c) * rel promotes c into the primary key of rel
            promoted = dj.U('first_name') * University.Student();
            testCase.verifyTrue(ismember('first_name', promoted.primaryKey));
            testCase.verifyEqual(length(promoted.primaryKey), 2);

            % dj.U(c) & rel returns the unique combinations of c in rel
            byEnrolled = dj.U('enrolled') & University.Student();
            testCase.verifyEqual(count(byEnrolled), 1);
            testCase.verifyEqual(length(byEnrolled.header.attributes), 1);
            byNameAndEnrolled = dj.U('last_name','enrolled') & University.Student();
            testCase.verifyEqual(count(byNameAndEnrolled), 4);
            testCase.verifyEqual(length(byNameAndEnrolled.header.attributes), 2);

            % dj.U(c).aggr(rel, ...) aggregates over the groupings of c present in rel
            grouped = dj.U('last_name').aggr(University.Student(), 'length(min(first_name))->n_chars');
            testCase.verifyEqual(length(grouped.primaryKey),1);
            testCase.verifyTrue(strcmp(grouped.primaryKey{1}, 'last_name'));
            testCase.verifyEqual(count(grouped), 4);
            testCase.verifyEqual(length(grouped.nonKeyFields), 1);
            testCase.verifyTrue(strcmp(grouped.nonKeyFields{1}, 'n_chars'));

            % dj.U(c) supports projection semantics on c
            initials = dj.U('left(first_name,1)->first_initial') & University.Student();
            testCase.verifyEqual(length(initials.primaryKey), 1);
            testCase.verifyTrue(strcmp(initials.primaryKey{1}, 'first_initial'));
            testCase.verifyEqual(count(initials), 4);
        end
    end
end

0 commit comments

Comments
 (0)