@@ -171,8 +171,12 @@ function clip(self)
171171
172172 function n = count(self )
173173 % COUNT - the number of tuples in the relation.
174- [~ , sql_ ] = self .compile(3 );
175- n = self .conn .query(sprintf(' SELECT count(*) as n FROM %s ' , sql_ ));
174+ [header_ , sql_ ] = self .compile(3 );
175+ if header_ .distinct
176+ n = self .conn .query(sprintf(' SELECT count(*) as n FROM (SELECT %s FROM %s ) as `counted`' , header_ .sql , sql_ ));
177+ else
178+ n = self .conn .query(sprintf(' SELECT count(*) as n FROM %s ' , sql_ ));
179+ end
176180 n = double(n .n );
177181 end
178182
@@ -632,6 +636,8 @@ function restrict(self, varargin)
632636 % get reference to the connection object from the first table
633637 if strcmp(self .operator , ' table' )
634638 conn = self.operands{1 }.schema.conn;
639+ elseif strcmp(self .operator , ' U' ) && isempty(self .operands )
640+ conn = dj .conn ;
635641 else
636642 conn = self.operands{1 }.getConn;
637643 end
@@ -665,9 +671,58 @@ function restrict(self, varargin)
665671 % apply relational operators recursively
666672 switch self .operator
667673 case ' union'
668- throwAsCaller(MException(' DataJoint:invalidOperator' , ...
669- ' The union operator must be used in a restriction' ))
674+ assert(...
675+ length(union(self.operands{1 }.primaryKey, self.operands{2 }.primaryKey)) == ...
676+ length(intersect(self.operands{1 }.primaryKey, self.operands{2 }.primaryKey)), ...
677+ ' DataJoint:invalidUnion' ,' Union operands must have the same primary key.' );
678+
679+ assert(isempty(intersect(self.operands{1 }.nonKeyFields, self.operands{2 }.nonKeyFields)), ...
680+ ' DataJoint:invalidUnion' ,' Union operands may not have any common non-key attributes.' );
681+
682+ % join the operands, which will form part of the query
683+ [header ,sql1 ] = compile(self.operands{1 } * self.operands{2 }, 0 );
670684
685+ if isempty(self.operands{1 }.nonKeyFields) && isempty(self.operands{2 }.nonKeyFields)
686+ % Only PKs: simple UNION
687+ % We just need to make sure the attributes are in
688+ % the same order (assumed by MySQL)
689+ [header2 ,sql2 ] = compile(self.operands{1 }, 0 );
690+ [~ ,fieldOrder ] = ismember(header .names , header2 .names );
691+ header2 .reorderFields(fieldOrder );
692+
693+ [header3 ,sql3 ] = compile(self.operands{2 }, 0 );
694+ [~ ,fieldOrder ] = ismember(header .names , header3 .names );
695+ header3 .reorderFields(fieldOrder );
696+ sql = sprintf(' ((SELECT %s FROM %s ) UNION (SELECT %s FROM %s )) as `$s%x `' ,...
697+ header2 .sql ,sql2 ,header3 .sql ,sql3 ,aliasCount );
698+ else
699+ % With dependent fields, we first want to join any
700+ % matching PKs, then union with the antijoin,
701+ % (a-b)|(b-a). We will append NULL to the missing
702+ % columns in the antijoin to make the query valid.
703+
704+ fields = header .dependentFields ;
705+ nk = ismember(fields , self.operands{2 }.nonKeyFields);
706+ fields(nk ) = cellfun(@(s ) sprintf(' NULL -> %s ' ,s ), fields(nk ), ' UniformOutput' , false );
707+
708+ [header2 ,sql2 ] = compile(proj(self.operands{1 } - self.operands{2 }, fields{: }), 0 );
709+ [~ ,fieldOrder ] = ismember(header .names , header2 .names );
710+ header2 .reorderFields(fieldOrder );
711+
712+ fields = header .dependentFields ;
713+ nk = ismember(fields , self.operands{1 }.nonKeyFields);
714+ fields(nk ) = cellfun(@(s ) sprintf(' NULL -> %s ' ,s ), fields(nk ), ' UniformOutput' , false );
715+
716+ [header3 ,sql3 ] = compile(proj(self.operands{2 } - self.operands{1 }, fields{: }), 0 );
717+ [~ ,fieldOrder ] = ismember(header .names , header3 .names );
718+ header3 .reorderFields(fieldOrder );
719+
720+ sql = sprintf(...
721+ ' ((SELECT %s FROM %s ) UNION (SELECT %s FROM %s ) UNION (SELECT %s FROM %s )) as `$s%x `' ,...
722+ header .sql ,sql1 ,header2 .sql ,sql2 ,header3 .sql ,sql3 , aliasCount );
723+
724+ end
725+
671726 case ' not'
672727 throwAsCaller(MException(' DataJoint:invalidOperator' , ...
673728 ' The NOT operator must be used in a restriction' ))
@@ -699,7 +754,10 @@ function restrict(self, varargin)
699754 sql = sprintf(' %s NATURAL JOIN %s ' , sql1 , sql2 );
700755 header = join(header1 ,header2 );
701756 clear header1 header2 sql1 sql2
702-
757+
758+ case ' U'
759+ [header , sql ] = compile(self.operands{1 },2 );
760+ header .promote(self.operands{3 }, self.operands{2 }.primaryKey{: });
703761 otherwise
704762 error ' unknown relational operator'
705763 end
0 commit comments