Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 63 additions & 5 deletions +dj/+internal/GeneralRelvar.m
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,12 @@ function clip(self)

function n = count(self)
% COUNT - the number of tuples in the relation.
[~, sql_] = self.compile(3);
n = self.conn.query(sprintf('SELECT count(*) as n FROM %s', sql_));
[header_, sql_] = self.compile(3);
if header_.distinct
n = self.conn.query(sprintf('SELECT count(*) as n FROM (SELECT %s FROM %s) as `counted`', header_.sql, sql_));
else
n = self.conn.query(sprintf('SELECT count(*) as n FROM %s', sql_));
end
n = double(n.n);
end

Expand Down Expand Up @@ -632,6 +636,8 @@ function restrict(self, varargin)
% get reference to the connection object from the first table
if strcmp(self.operator, 'table')
conn = self.operands{1}.schema.conn;
elseif strcmp(self.operator, 'U') && isempty(self.operands)
conn = dj.conn;
else
conn = self.operands{1}.getConn;
end
Expand Down Expand Up @@ -665,9 +671,58 @@ function restrict(self, varargin)
% apply relational operators recursively
switch self.operator
case 'union'
throwAsCaller(MException('DataJoint:invalidOperator', ...
'The union operator must be used in a restriction'))
assert(...
length(union(self.operands{1}.primaryKey, self.operands{2}.primaryKey)) == ...
length(intersect(self.operands{1}.primaryKey, self.operands{2}.primaryKey)), ...
'DataJoint:invalidUnion','Union operands must have the same primary key.');

assert(isempty(intersect(self.operands{1}.nonKeyFields, self.operands{2}.nonKeyFields)), ...
'DataJoint:invalidUnion','Union operands may not have any common non-key attributes.');

% join the operands, which will form part of the query
[header,sql1] = compile(self.operands{1} * self.operands{2}, 0);

if isempty(self.operands{1}.nonKeyFields) && isempty(self.operands{2}.nonKeyFields)
% Only PKs: simple UNION
% We just need to make sure the attributes are in
% the same order (assumed by MySQL)
[header2,sql2] = compile(self.operands{1}, 0);
[~,fieldOrder] = ismember(header.names, header2.names);
header2.reorderFields(fieldOrder);

[header3,sql3] = compile(self.operands{2}, 0);
[~,fieldOrder] = ismember(header.names, header3.names);
header3.reorderFields(fieldOrder);
sql = sprintf('((SELECT %s FROM %s) UNION (SELECT %s FROM %s)) as `$s%x`',...
header2.sql,sql2,header3.sql,sql3,aliasCount);
else
% With dependent fields, we first want to join any
% matching PKs, then union with the antijoin,
% (a-b)|(b-a). We will append NULL to the missing
% columns in the antijoin to make the query valid.

fields = header.dependentFields;
nk = ismember(fields, self.operands{2}.nonKeyFields);
fields(nk) = cellfun(@(s) sprintf('NULL -> %s',s), fields(nk), 'UniformOutput', false);

[header2,sql2] = compile(proj(self.operands{1} - self.operands{2}, fields{:}), 0);
[~,fieldOrder] = ismember(header.names, header2.names);
header2.reorderFields(fieldOrder);

fields = header.dependentFields;
nk = ismember(fields, self.operands{1}.nonKeyFields);
fields(nk) = cellfun(@(s) sprintf('NULL -> %s',s), fields(nk), 'UniformOutput', false);

[header3,sql3] = compile(proj(self.operands{2} - self.operands{1}, fields{:}), 0);
[~,fieldOrder] = ismember(header.names, header3.names);
header3.reorderFields(fieldOrder);

sql = sprintf(...
'((SELECT %s FROM %s) UNION (SELECT %s FROM %s) UNION (SELECT %s FROM %s)) as `$s%x`',...
header.sql,sql1,header2.sql,sql2,header3.sql,sql3, aliasCount);

end

case 'not'
throwAsCaller(MException('DataJoint:invalidOperator', ...
'The NOT operator must be used in a restriction'))
Expand Down Expand Up @@ -699,7 +754,10 @@ function restrict(self, varargin)
sql = sprintf('%s NATURAL JOIN %s', sql1, sql2);
header = join(header1,header2);
clear header1 header2 sql1 sql2


case 'U'
[header, sql] = compile(self.operands{1},2);
header.promote(self.operands{3}, self.operands{2}.primaryKey{:});
otherwise
error 'unknown relational operator'
end
Expand Down
53 changes: 50 additions & 3 deletions +dj/+internal/Header.m
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
properties(SetAccess=private)
info % table info
attributes % array of attributes
distinct=false% whether to select all the elements or only distinct ones
Comment on lines 7 to +9
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
info % table info
attributes % array of attributes
distinct=false% whether to select all the elements or only distinct ones
info % table info
attributes % array of attributes
distinct=false % whether to select all the elements or only distinct ones

end

properties(Access=private,Constant)
Expand Down Expand Up @@ -204,7 +205,8 @@ function project(self, params)
else
% process a regular attribute
ix = find(strcmp(params{iAttr},self.names));
assert(~isempty(ix), 'Attribute `%s` does not exist', ...
assert(~isempty(ix),'DataJoint:missingAttributes',...
'Attribute `%s` does not exist', ...
params{iAttr})
end
end
Expand Down Expand Up @@ -242,7 +244,8 @@ function project(self, params)
function sql = sql(self)
% make an SQL list of attributes for header
sql = '';
assert(~isempty(self.attributes))
assert(~isempty(self.attributes),...
'DataJoint:missingAttributes','Relation has no attributes');
for i = 1:length(self.attributes)
if isempty(self.attributes(i).alias)
% if strcmp(self.attributes(i).type,'float')
Expand All @@ -266,7 +269,11 @@ function project(self, params)
end
end
end
sql = sql(2:end); % strip leading comma
sql = sql(2:end); % strip leading comma

if self.distinct
sql = sprintf('DISTINCT %s', sql);
end
end


Expand All @@ -276,4 +283,44 @@ function stripAliases(self)
end
end
end

methods (Access = {?dj.internal.GeneralRelvar})
function reorderFields(self, order)
    % REORDERFIELDS - permute the header's attributes in place.
    %   order   index vector; element k of the reordered attribute array
    %           is taken from self.attributes(order(k)).
    %
    % Raises DataJoint:invalidFieldOrder unless `order` is a permutation
    % of 1:N, where N is the number of attribute names in the header.
    %
    % Index vectors produced by ismember contain 0 for names that were not
    % found, and a vector of the right length may still repeat an index.
    % Checking for a true permutation up front turns both cases into a
    % clear error instead of a cryptic indexing failure or a header with
    % silently duplicated/dropped attributes.
    nAttr = length(self.names);
    assert(length(order) == nAttr && isequal(sort(order(:).'), 1:nAttr), ...
        'DataJoint:invalidFieldOrder', ...
        'Field order must be a permutation of 1:%d.', nAttr);
    self.attributes = self.attributes(order);
end

function promote(self, keep, varargin)
if ~keep
[self.attributes(:).iskey] = deal(false);
self.distinct = true;
self.project(varargin); % do the projection
else
self.project([varargin, '*']);
end

% promote the keys
for iAttr = 1:numel(varargin)
%renamed attribute
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
%renamed attribute
% renamed attribute

toks = regexp(varargin{iAttr}, ...
'^([a-z]\w*)\s*->\s*(\w+)', 'tokens');
if ~isempty(toks)
name = toks{1}{2};
else
%computed attribute
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
%computed attribute
% computed attribute

toks = regexp(varargin{iAttr}, '(.*\S)\s*->\s*(\w+)', 'tokens');
if ~isempty(toks)
name = toks{1}{2};
else
%regular attribute
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
%regular attribute
% regular attribute

name = varargin{iAttr};
end
end
ix = find(strcmp(name, self.names));
assert(~isempty(ix), 'DataJoint:missingAttributes', 'Attribute `%s` does not exist', ...
name)
self.attributes(ix).iskey = true;
end
end
end
end
57 changes: 57 additions & 0 deletions +dj/U.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
classdef U
properties (SetAccess=private, Hidden)
primaryKey
end

methods
function self = U(varargin)
% U - universal set: a set representing all possible values of the
% supplied attributes.
% Can be queried in combination with other relations to alter their
% primary key structure.
%
% varargin is stored unchanged as the primaryKey property; no validation
% of the arguments happens here.

self.primaryKey = varargin;
end

function ret = and(self, arg)
% AND - overload of the & operator; delegates to restrict(self, arg).
ret = self.restrict(arg);
end

function ret = restrict(self, arg)
% RESTRICT - relational restriction
% dj.U(varargin) & A returns the unique combinations of the keys in
% varargin that appear in A.

% for dj.U(), only support restricting by a relvar
assert(isa(arg, 'dj.internal.GeneralRelvar'),...
'restriction requires a relvar as operand');

% self = init(dj.internal.GeneralRelvar, 'U', {self, arg, 0});
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove all instances of commented-out code.

Suggested change
% self = init(dj.internal.GeneralRelvar, 'U', {self, arg, 0});

ret = init(dj.internal.GeneralRelvar, 'U', {arg, self, 0});
end

function ret = mtimes(self, arg)
    % MTIMES - overload of the * operator for dj.U.
    % dj.U(varargin) * A promotes the attributes listed in varargin into
    % the primary key of A and returns the resulting relation.

    % Only relvars can be combined with a universal set.
    isRelvar = isa(arg, 'dj.internal.GeneralRelvar');
    assert(isRelvar, 'mtimes requires another relvar as operand')

    % Operand order is {relation, universal set, flag}; the trailing 1
    % distinguishes this node from the one built by restrict, which
    % passes 0.
    operands = {arg, self, 1};
    ret = init(dj.internal.GeneralRelvar, 'U', operands);
end

function ret = aggr(self, other, varargin)
    % AGGR - relational aggregation operator.
    % dj.U(varargin).aggr(A, ...) allows grouping by arbitrary
    % combinations of the keys in A.
    %   other      relation to aggregate over
    %   varargin   attribute arguments; must all be character vectors
    %
    % Returns an 'aggregate' GeneralRelvar node whose operands are
    % (self & other), (self * other) and the attribute arguments.

    % The message names aggr() so a failure points at the method the
    % caller actually invoked.
    assert(iscellstr(varargin), ...
        'aggr() requires a list of strings as attribute args')

    % Note: join is required here to make projection semantics work.
    ret = init(dj.internal.GeneralRelvar, 'aggregate', ...
        [{self & other, self * other}, varargin]);
end
end

end
3 changes: 2 additions & 1 deletion tests/Main.m
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@
TestSchema & ...
TestTls & ...
TestUuid & ...
TestBlob
TestBlob & ...
TestRelationalOperator
end
92 changes: 92 additions & 0 deletions tests/TestRelationalOperator.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
classdef TestRelationalOperator < Prep
methods (TestClassSetup)
function init(testCase)
% INIT - class-level setup for the relational operator tests.
% Runs the Prep setup, creates the University schema and inserts the
% rows that the tests in this class query.
init@Prep(testCase);
package = 'University';
% NOTE(review): arguments look like (package name, parent directory,
% database name) - confirm against dj.createSchema.
dj.createSchema(package,[testCase.test_root '/test_schemas'], ...
[testCase.PREFIX '_university']);
% Cell-array field values make struct() return a 1x4 struct array, so
% this inserts four students (ids 1-4) sharing one enrollment date.
University.Student().insert(struct(...
'student_id', {1,2,3,4},...
'first_name', {'John','Paul','George','Ringo'},...
'last_name',{'Lennon','McCartney','Harrison','Starr'},...
'enrolled',{'1960-01-01','1960-01-01','1960-01-01','1960-01-01'}...
));

% A single row for University.A, passed as a cell array of values.
University.A().insert({1, 'test', '1960-01-01 00:00:00','1960-01-01', 1.234, struct()});
end

end
methods (Test)
function TestRelationalOperator_testUnion(testCase)
st = dbstack;
disp(['---------------' st(1).name '---------------']);


% Unions may only share primary, but not secondary, keys
testCase.verifyError(@() count(University.Student() | University.Student()), 'DataJoint:invalidUnion');
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
testCase.verifyError(@() count(University.Student() | University.Student()), 'DataJoint:invalidUnion');
testCase.verifyError(@() count(University.Student() + University.Student()), 'DataJoint:invalidUnion');

Switch | to +. We are aware that using the + operator raises a warning suggesting to use | instead. This was due to a convention that no longer applies. We are working on officially switching to +.

% The primary key of each relation must be the same
testCase.verifyError(@() count(University.A() | University.Student()), 'DataJoint:invalidUnion');
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
testCase.verifyError(@() count(University.A() | University.Student()), 'DataJoint:invalidUnion');
testCase.verifyError(@() count(University.A() + University.Student()), 'DataJoint:invalidUnion');


% A basic union
testCase.verifyEqual(count(...
proj(University.Student() & 'student_id<2') | proj(University.Student() & 'student_id>3')),...
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
proj(University.Student() & 'student_id<2') | proj(University.Student() & 'student_id>3')),...
proj(University.Student() & 'student_id<2') + proj(University.Student() & 'student_id>3')),...

2);

% Unions with overlapping primary keys are merged
testCase.verifyEqual(count(...
proj(University.Student() & 'student_id<3') | proj(University.Student() & 'student_id>1 AND student_id<4')),...
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
proj(University.Student() & 'student_id<3') | proj(University.Student() & 'student_id>1 AND student_id<4')),...
proj(University.Student() & 'student_id<3') + proj(University.Student() & 'student_id>1 AND student_id<4')),...

3);

% Unions with disjoint secondary keys are also merged and filled with NULL
a = University.Student & 'student_id<4';
b = proj(University.Student() & 'student_id>1','"test_val"->test_col');
c = fetch(a | b, '*');
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
c = fetch(a | b, '*');
c = fetch(a + b, '*');

testCase.verifyEqual(length(c), 4);
testCase.verifyEqual(nnz(cellfun(@isempty,{c(:).first_name})), 1);
testCase.verifyEqual(nnz(cellfun(@isempty,{c(:).test_col})), 1);
testCase.verifyEqual(nnz(cellfun(@isempty,{c(:).first_name}) & cellfun(@isempty,{c(:).test_col})), 0);

end

function TestRelationalOperator_testUniversalSet(testCase)
% Exercises dj.U: restriction (&), join (*), aggr and projection-style
% arguments, using the four University.Student rows inserted in the
% class setup.

% Print a banner with the name of the running test.
st = dbstack;
disp(['---------------' st(1).name '---------------']);

% dj.U() & rel has no attributes
a = dj.U() & University.Student();
testCase.verifyError(@() a.header.sql, 'DataJoint:missingAttributes');

% dj.U(c) * rel is invalid if c is not an attribute of rel
testCase.verifyError(@() count(dj.U('bad_attribute') * University.Student()), 'DataJoint:missingAttributes');

% rel = dj.U(c) * rel promotes c to a primary key of rel
a = dj.U('first_name') * University.Student();
testCase.verifyTrue(ismember('first_name', a.primaryKey));
testCase.verifyEqual(length(a.primaryKey), 2);

% dj.U(c) & rel returns the unique combinations of c in rel
% (all four students share one enrollment date, hence a count of 1)
a = dj.U('enrolled') & University.Student();
testCase.verifyEqual(count(a), 1);
testCase.verifyEqual(length(a.header.attributes), 1);
a = dj.U('last_name','enrolled') & University.Student();
testCase.verifyEqual(count(a), 4);
testCase.verifyEqual(length(a.header.attributes), 2);

% dj.U(c).aggr(rel, ...) aggregates into the groupings in c that exist in rel
a = dj.U('last_name').aggr(University.Student(), 'length(min(first_name))->n_chars');
testCase.verifyEqual(length(a.primaryKey),1);
testCase.verifyTrue(strcmp(a.primaryKey{1}, 'last_name'));
testCase.verifyEqual(count(a), 4);
testCase.verifyEqual(length(a.nonKeyFields), 1);
testCase.verifyTrue(strcmp(a.nonKeyFields{1}, 'n_chars'));

% dj.U(c) supports projection semantics on c
% (the inserted first names start with J, P, G and R: four distinct initials)
a = dj.U('left(first_name,1)->first_initial') & University.Student();
testCase.verifyEqual(length(a.primaryKey), 1);
testCase.verifyTrue(strcmp(a.primaryKey{1}, 'first_initial'));
testCase.verifyEqual(count(a), 4);


end
end
end