|  | 
| 17 | 17 |     typing.Tuple[sympy.core.symbol.Symbol, sympy.core.expr.Expr]]] | 
| 18 | 18 | 
 | 
| 19 | 19 | 
 | 
|  | 20 | +def _which_side(vs: SetOfPoints, p: PointType, q: PointType) -> typing.Optional[int]: | 
|  | 21 | +    """Check which side of a line or plane a set of points are. | 
|  | 22 | +
 | 
|  | 23 | +    Args: | 
|  | 24 | +        vs: The set of points | 
|  | 25 | +        p: A point on the line or plane | 
|  | 26 | +        q: Another point on the line (2D) or the normal to the plane (3D) | 
|  | 27 | +
 | 
|  | 28 | +    Returns: | 
|  | 29 | +        2 if the points are all to the left, 1 if the points are all to the left or on the line, | 
|  | 30 | +        0 if the points are all on the line, -1 if the points are all to the right or on the line, | 
|  | 31 | +        -1 if the points are all to the right, None if there are some points on either side. | 
|  | 32 | +    """ | 
|  | 33 | +    sides = [] | 
|  | 34 | +    for v in vs: | 
|  | 35 | +        if len(q) == 2: | 
|  | 36 | +            cross = (v[0] - p[0]) * (q[1] - p[1]) - (v[1] - p[1]) * (q[0] - p[0]) | 
|  | 37 | +        elif len(q) == 3: | 
|  | 38 | +            cross = (v[0] - p[0]) * q[0] + (v[1] - p[1]) * q[1] + (v[2] - p[2]) * q[2] | 
|  | 39 | +        else: | 
|  | 40 | +            return None | 
|  | 41 | +        if cross == 0: | 
|  | 42 | +            sides.append(0) | 
|  | 43 | +        elif cross > 0: | 
|  | 44 | +            sides.append(1) | 
|  | 45 | +        else: | 
|  | 46 | +            sides.append(-1) | 
|  | 47 | + | 
|  | 48 | +    if -1 in sides and 1 in sides: | 
|  | 49 | +        return None | 
|  | 50 | +    if 1 in sides: | 
|  | 51 | +        if 0 in sides: | 
|  | 52 | +            return 1 | 
|  | 53 | +        else: | 
|  | 54 | +            return 2 | 
|  | 55 | +    if -1 in sides: | 
|  | 56 | +        if 0 in sides: | 
|  | 57 | +            return -1 | 
|  | 58 | +        else: | 
|  | 59 | +            return -2 | 
|  | 60 | +    return 0 | 
|  | 61 | + | 
|  | 62 | + | 
| 20 | 63 | def _vsub(v: PointTypeInput, w: PointTypeInput) -> PointType: | 
| 21 | 64 |     """Subtract. | 
| 22 | 65 | 
 | 
| @@ -155,6 +198,45 @@ def clockwise_vertices(self) -> SetOfPoints: | 
| 155 | 198 |         """ | 
| 156 | 199 |         return self.vertices | 
| 157 | 200 | 
 | 
|  | 201 | +    def intersection(self, other: Reference) -> typing.Optional[Reference]: | 
|  | 202 | +        """Get the intersection of two references. | 
|  | 203 | +
 | 
|  | 204 | +        Returns: | 
|  | 205 | +            A reference element that is the intersection | 
|  | 206 | +        """ | 
|  | 207 | +        if self.gdim != other.gdim: | 
|  | 208 | +            raise ValueError("Incompatible cell dimensions") | 
|  | 209 | + | 
|  | 210 | +        for cell1, cell2 in [(self, other), (other, self)]: | 
|  | 211 | +            try: | 
|  | 212 | +                for v in cell1.vertices: | 
|  | 213 | +                    if not cell2.contains(v): | 
|  | 214 | +                        break | 
|  | 215 | +                else: | 
|  | 216 | +                    return cell1 | 
|  | 217 | +            except NotImplementedError: | 
|  | 218 | +                pass | 
|  | 219 | +        for cell1, cell2 in [(self, other), (other, self)]: | 
|  | 220 | +            if cell1.gdim == 2: | 
|  | 221 | +                for e in cell1.edges: | 
|  | 222 | +                    p = cell1.vertices[e[0]] | 
|  | 223 | +                    q = cell1.vertices[e[1]] | 
|  | 224 | +                    dir1 = _which_side(cell1.vertices, p, q) | 
|  | 225 | +                    dir2 = _which_side(cell2.vertices, p, q) | 
|  | 226 | +                    if dir1 is not None and dir2 is not None and dir1 * dir2 < 0: | 
|  | 227 | +                        return None | 
|  | 228 | +            if cell1.gdim == 3: | 
|  | 229 | +                for i in range(cell1.sub_entity_count(2)): | 
|  | 230 | +                    face = cell1.sub_entity(2, i) | 
|  | 231 | +                    p = face.midpoint() | 
|  | 232 | +                    n = face.normal() | 
|  | 233 | +                    dir1 = _which_side(cell1.vertices, p, n) | 
|  | 234 | +                    dir2 = _which_side(cell2.vertices, p, n) | 
|  | 235 | +                    if dir1 is not None and dir2 is not None and dir1 * dir2 < 0: | 
|  | 236 | +                        return None | 
|  | 237 | + | 
|  | 238 | +        raise NotImplementedError("Intersection of these elements is not yet supported") | 
|  | 239 | + | 
| 158 | 240 |     @abstractmethod | 
| 159 | 241 |     def default_reference(self) -> Reference: | 
| 160 | 242 |         """Get the default reference for this cell type. | 
| @@ -1028,9 +1110,17 @@ def contains(self, point: PointType) -> bool: | 
| 1028 | 1110 |         Returns: | 
| 1029 | 1111 |             Is the point contained in the reference? | 
| 1030 | 1112 |         """ | 
| 1031 |  | -        if self.vertices != self.reference_vertices: | 
| 1032 |  | -            raise NotImplementedError() | 
| 1033 |  | -        return 0 <= point[0] and 0 <= point[1] and sum(point) <= 1 | 
|  | 1113 | +        if self.vertices == self.reference_vertices: | 
|  | 1114 | +            return 0 <= point[0] and 0 <= point[1] and sum(point) <= 1 | 
|  | 1115 | +        elif self.gdim == 2: | 
|  | 1116 | +            po = _vsub(point, self.origin) | 
|  | 1117 | +            det = self.axes[0][0] * self.axes[1][1] - self.axes[0][1] * self.axes[1][0] | 
|  | 1118 | +            t0 = (self.axes[1][1] * po[0] - self.axes[1][0] * po[1]) / det | 
|  | 1119 | +            t1 = (self.axes[0][0] * po[1] - self.axes[0][1] * po[0]) / det | 
|  | 1120 | +            print(self.origin, self.axes, point) | 
|  | 1121 | +            print(t0, t1) | 
|  | 1122 | +            return 0 <= t0 and 0 <= t1 and t0 + t1 <= 1 | 
|  | 1123 | +        raise NotImplementedError() | 
| 1034 | 1124 | 
 | 
| 1035 | 1125 | 
 | 
| 1036 | 1126 | class Tetrahedron(Reference): | 
| @@ -1203,9 +1293,15 @@ def contains(self, point: PointType) -> bool: | 
| 1203 | 1293 |         Returns: | 
| 1204 | 1294 |             Is the point contained in the reference? | 
| 1205 | 1295 |         """ | 
| 1206 |  | -        if self.vertices != self.reference_vertices: | 
| 1207 |  | -            raise NotImplementedError() | 
| 1208 |  | -        return 0 <= point[0] and 0 <= point[1] and 0 <= point[2] and sum(point) <= 1 | 
|  | 1296 | +        if self.vertices == self.reference_vertices: | 
|  | 1297 | +            return 0 <= point[0] and 0 <= point[1] and 0 <= point[2] and sum(point) <= 1 | 
|  | 1298 | +        else: | 
|  | 1299 | +            po = _vsub(point, self.origin) | 
|  | 1300 | +            minv = sympy.Matrix([[a[i] for a in self.axes] for i in range(3)]).inv() | 
|  | 1301 | +            t0 = (minv[0, 0] * po[0] + minv[0, 1] * po[1] + minv[0, 2] * po[2]) | 
|  | 1302 | +            t1 = (minv[1, 0] * po[0] + minv[1, 1] * po[1] + minv[1, 2] * po[2]) | 
|  | 1303 | +            t2 = (minv[2, 0] * po[0] + minv[2, 1] * po[1] + minv[2, 2] * po[2]) | 
|  | 1304 | +            return 0 <= t0 and 0 <= t1 and 0 >= t2 and t0 + t1 + t2 <= 1 | 
| 1209 | 1305 | 
 | 
| 1210 | 1306 | 
 | 
| 1211 | 1307 | class Quadrilateral(Reference): | 
|  | 
0 commit comments