diff --git a/+dj/+internal/GeneralRelvar.m b/+dj/+internal/GeneralRelvar.m index 24be944b..bea4d4c4 100644 --- a/+dj/+internal/GeneralRelvar.m +++ b/+dj/+internal/GeneralRelvar.m @@ -171,8 +171,12 @@ function clip(self) function n = count(self) % COUNT - the number of tuples in the relation. - [~, sql_] = self.compile(3); - n = self.conn.query(sprintf('SELECT count(*) as n FROM %s', sql_)); + [header_, sql_] = self.compile(3); + if header_.distinct + n = self.conn.query(sprintf('SELECT count(*) as n FROM (SELECT %s FROM %s) as `counted`', header_.sql, sql_)); + else + n = self.conn.query(sprintf('SELECT count(*) as n FROM %s', sql_)); + end n = double(n.n); end @@ -632,6 +636,8 @@ function restrict(self, varargin) % get reference to the connection object from the first table if strcmp(self.operator, 'table') conn = self.operands{1}.schema.conn; + elseif strcmp(self.operator, 'U') && isempty(self.operands) + conn = dj.conn; else conn = self.operands{1}.getConn; end @@ -665,9 +671,58 @@ function restrict(self, varargin) % apply relational operators recursively switch self.operator case 'union' - throwAsCaller(MException('DataJoint:invalidOperator', ... - 'The union operator must be used in a restriction')) + assert(... + length(union(self.operands{1}.primaryKey, self.operands{2}.primaryKey)) == ... + length(intersect(self.operands{1}.primaryKey, self.operands{2}.primaryKey)), ... + 'DataJoint:invalidUnion','Union operands must have the same primary key.'); + + assert(isempty(intersect(self.operands{1}.nonKeyFields, self.operands{2}.nonKeyFields)), ... + 'DataJoint:invalidUnion','Union operands may not have any common non-key attributes.'); + + % join the operands, which will form part of the query + [header,sql1] = compile(self.operands{1} * self.operands{2}, 0); + if isempty(self.operands{1}.nonKeyFields) && isempty(self.operands{2}.nonKeyFields) + % Only PKs: simple UNION + % We just need to make sure the attributes are in + % the same order (assumed by MySQL) + [header2,sql2] = compile(self.operands{1}, 0); + [~,fieldOrder] = ismember(header.names, header2.names); + header2.reorderFields(fieldOrder); + + [header3,sql3] = compile(self.operands{2}, 0); + [~,fieldOrder] = ismember(header.names, header3.names); + header3.reorderFields(fieldOrder); + sql = sprintf('((SELECT %s FROM %s) UNION (SELECT %s FROM %s)) as `$s%x`',... + header2.sql,sql2,header3.sql,sql3,aliasCount); + else + % With dependent fields, we first want to join any + % matching PKs, then union with the antijoin, + % (a-b)|(b-a). We will append NULL to the missing + % columns in the antijoin to make the query valid. + + fields = header.dependentFields; + nk = ismember(fields, self.operands{2}.nonKeyFields); + fields(nk) = cellfun(@(s) sprintf('NULL -> %s',s), fields(nk), 'UniformOutput', false); + + [header2,sql2] = compile(proj(self.operands{1} - self.operands{2}, fields{:}), 0); + [~,fieldOrder] = ismember(header.names, header2.names); + header2.reorderFields(fieldOrder); + + fields = header.dependentFields; + nk = ismember(fields, self.operands{1}.nonKeyFields); + fields(nk) = cellfun(@(s) sprintf('NULL -> %s',s), fields(nk), 'UniformOutput', false); + + [header3,sql3] = compile(proj(self.operands{2} - self.operands{1}, fields{:}), 0); + [~,fieldOrder] = ismember(header.names, header3.names); + header3.reorderFields(fieldOrder); + + sql = sprintf(... + '((SELECT %s FROM %s) UNION (SELECT %s FROM %s) UNION (SELECT %s FROM %s)) as `$s%x`',... + header.sql,sql1,header2.sql,sql2,header3.sql,sql3, aliasCount); + + end + case 'not' throwAsCaller(MException('DataJoint:invalidOperator', ... 'The NOT operator must be used in a restriction')) @@ -699,7 +754,10 @@ function restrict(self, varargin) sql = sprintf('%s NATURAL JOIN %s', sql1, sql2); header = join(header1,header2); clear header1 header2 sql1 sql2 - + + case 'U' + [header, sql] = compile(self.operands{1},2); + header.promote(self.operands{3}, self.operands{2}.primaryKey{:}); otherwise error 'unknown relational operator' end diff --git a/+dj/+internal/Header.m b/+dj/+internal/Header.m index e41a4999..a910e359 100644 --- a/+dj/+internal/Header.m +++ b/+dj/+internal/Header.m @@ -6,6 +6,7 @@ properties(SetAccess=private) info % table info attributes % array of attributes + distinct=false% whether to select all the elements or only distinct ones end properties(Access=private,Constant) @@ -204,7 +205,8 @@ function project(self, params) else % process a regular attribute ix = find(strcmp(params{iAttr},self.names)); - assert(~isempty(ix), 'Attribute `%s` does not exist', ... + assert(~isempty(ix),'DataJoint:missingAttributes',... + 'Attribute `%s` does not exist', ... params{iAttr}) end end @@ -242,7 +244,8 @@ function project(self, params) function sql = sql(self) % make an SQL list of attributes for header sql = ''; - assert(~isempty(self.attributes)) + assert(~isempty(self.attributes),... + 'DataJoint:missingAttributes','Relation has no attributes'); for i = 1:length(self.attributes) if isempty(self.attributes(i).alias) % if strcmp(self.attributes(i).type,'float') @@ -266,7 +269,11 @@ function project(self, params) end end end - sql = sql(2:end); % strip leading comma + sql = sql(2:end); % strip leading comma + + if self.distinct + sql = sprintf('DISTINCT %s', sql); + end end @@ -276,4 +283,44 @@ function stripAliases(self) end end end + + methods (Access = {?dj.internal.GeneralRelvar}) + function reorderFields(self, order) + assert(length(order) == length(self.names)); + self.attributes = self.attributes(order); + end + + function promote(self, keep, varargin) + if ~keep + [self.attributes(:).iskey] = deal(false); + self.distinct = true; + self.project(varargin); % do the projection + else + self.project([varargin, '*']); + end + + % promote the keys + for iAttr = 1:numel(varargin) + %renamed attribute + toks = regexp(varargin{iAttr}, ... + '^([a-z]\w*)\s*->\s*(\w+)', 'tokens'); + if ~isempty(toks) + name = toks{1}{2}; + else + %computed attribute + toks = regexp(varargin{iAttr}, '(.*\S)\s*->\s*(\w+)', 'tokens'); + if ~isempty(toks) + name = toks{1}{2}; + else + %regular attribute + name = varargin{iAttr}; + end + end + ix = find(strcmp(name, self.names)); + assert(~isempty(ix), 'DataJoint:missingAttributes', 'Attribute `%s` does not exist', ... + name) + self.attributes(ix).iskey = true; + end + end + end end diff --git a/+dj/U.m b/+dj/U.m new file mode 100644 index 00000000..d6c9212e --- /dev/null +++ b/+dj/U.m @@ -0,0 +1,57 @@ +classdef U +properties (SetAccess=private, Hidden) + primaryKey +end + +methods + function self = U(varargin) + % UNIVERSAL SET - a set representing all possible values of the + % supplied attributes + % Can be queried in combination with other relations to alter their + % primary key structure. + + self.primaryKey = varargin; + % self.init('U', {}); % general relvar node + end + + function ret = and(self, arg) + ret = self.restrict(arg); + end + + function ret = restrict(self, arg) + % RESTRICT - relational restriction + % dj.U(varargin) & A returns the unique combinations of the keys in + % varargin that appear in A. + + % for dj.U(), only support restricting by a relvar + assert(isa(arg, 'dj.internal.GeneralRelvar'),... + 'restriction requires a relvar as operand'); + +% self = init(dj.internal.GeneralRelvar, 'U', {self, arg, 0}); + ret = init(dj.internal.GeneralRelvar, 'U', {arg, self, 0}); + end + + function ret = mtimes(self, arg) + % MTIMES - relational natural join. + % dj.U(varargin) * A promotes the keys in varargin to the primary + % key of A and returns the resulting relation. + + assert(isa(arg, 'dj.internal.GeneralRelvar'), ... + 'mtimes requires another relvar as operand') + ret = init(dj.internal.GeneralRelvar, 'U', {arg, self, 1}); + end + + function ret = aggr(self, other, varargin) + % AGGR -- relational aggregation operator. + % dj.U(varargin).aggr(A,...) Allows grouping by arbitrary + % combinations of the keys in A. + + assert(iscellstr(varargin), ... + 'proj() requires a list of strings as attribute args') + ret = init(dj.internal.GeneralRelvar, 'aggregate', ... + [{self & other, self * other}, varargin]); + %Note: join is required here to make projection semantics work. + end +end + +end \ No newline at end of file diff --git a/tests/Main.m b/tests/Main.m index cda5ad9f..7264b209 100644 --- a/tests/Main.m +++ b/tests/Main.m @@ -13,5 +13,6 @@ TestSchema & ... TestTls & ... TestUuid & ... - TestBlob + TestBlob & ... + TestRelationalOperator end \ No newline at end of file diff --git a/tests/TestRelationalOperator.m b/tests/TestRelationalOperator.m new file mode 100644 index 00000000..8f002f40 --- /dev/null +++ b/tests/TestRelationalOperator.m @@ -0,0 +1,92 @@ +classdef TestRelationalOperator < Prep + methods (TestClassSetup) + function init(testCase) + init@Prep(testCase); + package = 'University'; + dj.createSchema(package,[testCase.test_root '/test_schemas'], ... + [testCase.PREFIX '_university']); + University.Student().insert(struct(... + 'student_id', {1,2,3,4},... + 'first_name', {'John','Paul','George','Ringo'},... + 'last_name',{'Lennon','McCartney','Harrison','Starr'},... + 'enrolled',{'1960-01-01','1960-01-01','1960-01-01','1960-01-01'}... + )); + + University.A().insert({1, 'test', '1960-01-01 00:00:00','1960-01-01', 1.234, struct()}); + end + + end + methods (Test) + function TestRelationalOperator_testUnion(testCase) + st = dbstack; + disp(['---------------' st(1).name '---------------']); + + + % Unions may only share primary, but not secondary, keys + testCase.verifyError(@() count(University.Student() | University.Student()), 'DataJoint:invalidUnion'); + % The primary key of each relation must be the same + testCase.verifyError(@() count(University.A() | University.Student()), 'DataJoint:invalidUnion'); + + % A basic union + testCase.verifyEqual(count(... + proj(University.Student() & 'student_id<2') | proj(University.Student() & 'student_id>3')),... + 2); + + % Unions with overlapping primary keys are merged + testCase.verifyEqual(count(... + proj(University.Student() & 'student_id<3') | proj(University.Student() & 'student_id>1 AND student_id<4')),... + 3); + + % Unions with disjoint secondary keys are also merged and filled with NULL + a = University.Student & 'student_id<4'; + b = proj(University.Student() & 'student_id>1','"test_val"->test_col'); + c = fetch(a | b, '*'); + testCase.verifyEqual(length(c), 4); + testCase.verifyEqual(nnz(cellfun(@isempty,{c(:).first_name})), 1); + testCase.verifyEqual(nnz(cellfun(@isempty,{c(:).test_col})), 1); + testCase.verifyEqual(nnz(cellfun(@isempty,{c(:).first_name}) & cellfun(@isempty,{c(:).test_col})), 0); + + end + + function TestRelationalOperator_testUniversalSet(testCase) + st = dbstack; + disp(['---------------' st(1).name '---------------']); + + % dj.U() & rel has no attributes + a = dj.U() & University.Student(); + testCase.verifyError(@() a.header.sql, 'DataJoint:missingAttributes'); + + % dj.U(c) * rel is invalid if c is not an attribute of rel + testCase.verifyError(@() count(dj.U('bad_attribute') * University.Student()), 'DataJoint:missingAttributes'); + + % rel = dj.U(c) * rel promotes c to a primary key of rel + a = dj.U('first_name') * University.Student(); + testCase.verifyTrue(ismember('first_name', a.primaryKey)); + testCase.verifyEqual(length(a.primaryKey), 2); + + % dj.U(c) & rel returns the unique combinations of c in rel + a = dj.U('enrolled') & University.Student(); + testCase.verifyEqual(count(a), 1); + testCase.verifyEqual(length(a.header.attributes), 1); + a = dj.U('last_name','enrolled') & University.Student(); + testCase.verifyEqual(count(a), 4); + testCase.verifyEqual(length(a.header.attributes), 2); + + % dj.U(c).aggr(rel, ...) aggregates into the groupings in c that exist in rel + a = dj.U('last_name').aggr(University.Student(), 'length(min(first_name))->n_chars'); + testCase.verifyEqual(length(a.primaryKey),1); + testCase.verifyTrue(strcmp(a.primaryKey{1}, 'last_name')); + testCase.verifyEqual(count(a), 4); + testCase.verifyEqual(length(a.nonKeyFields), 1); + testCase.verifyTrue(strcmp(a.nonKeyFields{1}, 'n_chars')); + + % dj.U(c) supports projection semantics on c + a = dj.U('left(first_name,1)->first_initial') & University.Student(); + testCase.verifyEqual(length(a.primaryKey), 1); + testCase.verifyTrue(strcmp(a.primaryKey{1}, 'first_initial')); + testCase.verifyEqual(count(a), 4); + + + end + end +end \ No newline at end of file