diff --git a/deepxde/backend/backend.py b/deepxde/backend/backend.py index b744f9174..883967d11 100644 --- a/deepxde/backend/backend.py +++ b/deepxde/backend/backend.py @@ -1,3 +1,9 @@ +from __future__ import annotations + +from numbers import Number +from typing import Sequence, overload +from ..types import Tensor, dtype, SparseTensor, TensorOrTensors +from numpy.typing import NDArray, ArrayLike """This file defines the unified tensor framework interface required by DeepXDE. The principles of this interface: @@ -26,7 +32,7 @@ # Tensor, data type and context interfaces -def data_type_dict(): +def data_type_dict() -> dict[str, object]: """Returns a dictionary from data type string to the data type. The dictionary should include at least: @@ -58,7 +64,7 @@ def data_type_dict(): """ -def is_gpu_available(): +def is_gpu_available() -> bool: """Returns a bool indicating if GPU is currently available. Returns: @@ -66,11 +72,11 @@ def is_gpu_available(): """ -def is_tensor(obj): +def is_tensor(obj: object) -> bool: """Returns True if `obj` is a backend-native type tensor.""" -def shape(input_tensor): +def shape(input_tensor: Tensor) -> Sequence[int]: """Return the shape of the tensor. Args: @@ -81,7 +87,7 @@ def shape(input_tensor): """ -def size(input_tensor): +def size(input_tensor: Tensor) -> int: """Return the total number of elements in the input tensor. Args: @@ -92,7 +98,7 @@ def size(input_tensor): """ -def ndim(input_tensor): +def ndim(input_tensor: Tensor) -> int: """Returns the number of dimensions of the tensor. Args: @@ -103,7 +109,7 @@ def ndim(input_tensor): """ -def transpose(tensor, axes=None): +def transpose(tensor: Tensor, axes: Sequence[int] | int | None = None) -> Tensor: """Reverse or permute the axes of a tensor; returns the modified array. For a tensor with two axes, transpose gives the matrix transpose. @@ -117,7 +123,7 @@ def transpose(tensor, axes=None): """ -def reshape(tensor, shape): +def reshape(tensor: Tensor, shape: Sequence[int]) -> Tensor: """Gives a new shape to a tensor without changing its data. Args: @@ -130,7 +136,7 @@ def reshape(tensor, shape): """ -def Variable(initial_value, dtype=None): +def Variable(initial_value: Number, dtype: dtype = None) -> Tensor: """Return a trainable variable. Args: @@ -140,7 +146,7 @@ def Variable(initial_value, dtype=None): """ -def as_tensor(data, dtype=None): +def as_tensor(data: ArrayLike, dtype: dtype = None) -> Tensor: """Convert the data to a Tensor. If the data is already a tensor and has the same dtype, directly return. @@ -155,7 +161,7 @@ def as_tensor(data, dtype=None): """ -def sparse_tensor(indices, values, shape): +def sparse_tensor(indices: Sequence[Sequence[Number, Number]], values: Tensor, shape: Sequence[int]) -> SparseTensor: """Construct a sparse tensor based on given indices, values and shape. Args: @@ -170,7 +176,7 @@ def sparse_tensor(indices, values, shape): """ -def from_numpy(np_array): +def from_numpy(np_array: NDArray) -> Tensor: """Create a tensor that shares the underlying numpy array memory, if possible. Args: @@ -181,7 +187,7 @@ def from_numpy(np_array): """ -def to_numpy(input_tensor): +def to_numpy(input_tensor: Tensor) -> NDArray: """Create a numpy ndarray that shares the same underlying storage, if possible. Args: @@ -192,7 +198,7 @@ def to_numpy(input_tensor): """ -def concat(values, axis): +def concat(values: TensorOrTensors, axis: int) -> Tensor: """Returns the concatenation of the input tensors along the given dim. Args: @@ -204,7 +210,7 @@ def concat(values, axis): """ -def stack(values, axis): +def stack(values: TensorOrTensors, axis: int) -> Tensor: """Returns the stack of the input tensors along the given dim. Args: @@ -216,7 +222,7 @@ def stack(values, axis): """ -def expand_dims(tensor, axis): +def expand_dims(tensor: Tensor, axis: int) -> Tensor: """Expand dim for tensor along given axis. Args: @@ -228,7 +234,7 @@ def expand_dims(tensor, axis): """ -def reverse(tensor, axis): +def reverse(tensor: Tensor, axis: int) -> Tensor: """Reverse the order of elements along the given axis. Args: @@ -240,7 +246,7 @@ def reverse(tensor, axis): """ -def roll(tensor, shift, axis): +def roll(tensor: Tensor, shift: int | Sequence[int], axis: int | Sequence[int]) -> Tensor: """Roll the tensor along the given axis (axes). Args: @@ -261,67 +267,67 @@ def roll(tensor, shift, axis): # implementation in each framework. -def lgamma(x): +def lgamma(x: Tensor) -> Tensor: """Computes the natural logarithm of the absolute value of the gamma function of x element-wise. """ -def elu(x): +def elu(x: Tensor) -> Tensor: """Computes the exponential linear function.""" -def relu(x): +def relu(x: Tensor) -> Tensor: """Applies the rectified linear unit activation function.""" -def gelu(x): +def gelu(x: Tensor) -> Tensor: """Computes Gaussian Error Linear Unit function.""" -def selu(x): +def selu(x: Tensor) -> Tensor: """Computes scaled exponential linear.""" -def sigmoid(x): +def sigmoid(x: Tensor) -> Tensor: """Computes sigmoid of x element-wise.""" -def silu(x): +def silu(x: Tensor) -> Tensor: """Sigmoid Linear Unit (SiLU) function, also known as the swish function. silu(x) = x * sigmoid(x). """ -def sin(x): +def sin(x: Tensor) -> Tensor: """Computes sine of x element-wise.""" -def cos(x): +def cos(x: Tensor) -> Tensor: """Computes cosine of x element-wise.""" -def exp(x): +def exp(x: Tensor) -> Tensor: """Computes exponential of x element-wise.""" -def square(x): +def square(x: Tensor) -> Tensor: """Returns the square of the elements of input.""" -def abs(x): +def abs(x: Tensor) -> Tensor: """Computes the absolute value element-wise.""" -def minimum(x, y): +def minimum(x: Tensor, y: Tensor) -> Tensor: """Returns the minimum of x and y (i.e. x < y ? x : y) element-wise.""" -def tanh(x): +def tanh(x: Tensor) -> Tensor: """Computes hyperbolic tangent of x element-wise.""" -def pow(x, y): +def pow(x: Tensor, y: Number | Tensor) -> Tensor: """Computes the power of one value to another: x ^ y.""" @@ -332,15 +338,15 @@ def pow(x, y): # implementation in each framework. -def mean(input_tensor, dim, keepdims=False): +def mean(input_tensor: Tensor, dim: int | Sequence[int], keepdims: bool = False) -> Tensor: """Returns the mean value of the input tensor in the given dimension dim.""" -def reduce_mean(input_tensor): +def reduce_mean(input_tensor: Tensor) -> Tensor: """Returns the mean value of all elements in the input tensor.""" -def sum(input_tensor, dim, keepdims=False): +def sum(input_tensor: Tensor, dim: int | Sequence[int], keepdims: Tensor = False): """Returns the sum of the input tensor along the given dim. Args: @@ -353,7 +359,7 @@ def sum(input_tensor, dim, keepdims=False): """ -def reduce_sum(input_tensor): +def reduce_sum(input_tensor: Tensor) -> Tensor: """Returns the sum of all elements in the input tensor. Args: @@ -364,7 +370,7 @@ def reduce_sum(input_tensor): """ -def prod(input_tensor, dim, keepdims=False): +def prod(input_tensor: Tensor, dim: int | Sequence[int], keepdims: bool = False) -> Tensor: """Returns the product of the input tensor along the given dim. Args: @@ -377,7 +383,7 @@ def prod(input_tensor, dim, keepdims=False): """ -def reduce_prod(input_tensor): +def reduce_prod(input_tensor: Tensor) -> Tensor: """Returns the product of all elements in the input tensor. Args: @@ -388,7 +394,7 @@ def reduce_prod(input_tensor): """ -def min(input_tensor, dim, keepdims=False): +def min(input_tensor: Tensor, dim: int | Sequence[int], keepdims: bool = False) -> Tensor: """Returns the minimum of the input tensor along the given dim. Args: @@ -401,7 +407,7 @@ def min(input_tensor, dim, keepdims=False): """ -def reduce_min(input_tensor): +def reduce_min(input_tensor: Tensor) -> Tensor: """Returns the minimum of all elements in the input tensor. Args: @@ -412,7 +418,7 @@ def reduce_min(input_tensor): """ -def max(input_tensor, dim, keepdims=False): +def max(input_tensor: Tensor, dim: int | Sequence[int], keepdims: bool = False) -> Tensor: """Returns the maximum of the input tensor along the given dim. Args: @@ -425,7 +431,7 @@ def max(input_tensor, dim, keepdims=False): """ -def reduce_max(input_tensor): +def reduce_max(input_tensor: Tensor) -> Tensor: """Returns the maximum of all elements in the input tensor. Args: @@ -436,7 +442,7 @@ def reduce_max(input_tensor): """ -def norm(tensor, ord=None, axis=None, keepdims=False): +def norm(tensor: Tensor, ord: Number | None = None, axis: int | None = None, keepdims: bool = False) -> Tensor: """Computes a vector norm. Due to the incompatibility of different backends, only some vector norms are @@ -457,7 +463,7 @@ def norm(tensor, ord=None, axis=None, keepdims=False): """ -def zeros(shape, dtype): +def zeros(shape: Sequence[int], dtype: dtype) -> Tensor: """Creates a tensor with all elements set to zero. Args: @@ -469,7 +475,7 @@ def zeros(shape, dtype): """ -def zeros_like(input_tensor): +def zeros_like(input_tensor: Tensor) -> Tensor: """Create a zero tensor with the same shape, dtype and context of the given tensor. Args: @@ -480,7 +486,7 @@ def zeros_like(input_tensor): """ -def matmul(x, y): +def matmul(x: Tensor, y: Tensor) -> Tensor: """Compute matrix multiplication for two matrices x and y. Args: @@ -492,6 +498,14 @@ def matmul(x, y): """ +@overload +def sparse_dense_matmul(x: SparseTensor, y: Tensor) -> Tensor: ... + + +@overload +def sparse_dense_matmul(x: SparseTensor, y: SparseTensor) -> SparseTensor: ... + + def sparse_dense_matmul(x, y): """Compute matrix multiplication of a sparse matrix x and a sparse/dense matrix y. diff --git a/deepxde/geometry/csg.py b/deepxde/geometry/csg.py index 7bbcf7b44..e580c6ee4 100644 --- a/deepxde/geometry/csg.py +++ b/deepxde/geometry/csg.py @@ -1,13 +1,12 @@ import numpy as np -from . import geometry +from .geometry import Geometry from .. import config - -class CSGUnion(geometry.Geometry): +class CSGUnion(Geometry): """Construct an object by CSG Union.""" - def __init__(self, geom1, geom2): + def __init__(self, geom1: Geometry, geom2: Geometry): if geom1.dim != geom2.dim: raise ValueError( "{} | {} failed (dimensions do not match).".format( @@ -101,10 +100,10 @@ def periodic_point(self, x, component): return x -class CSGDifference(geometry.Geometry): +class CSGDifference(Geometry): """Construct an object by CSG Difference.""" - def __init__(self, geom1, geom2): + def __init__(self, geom1: Geometry, geom2: Geometry): if geom1.dim != geom2.dim: raise ValueError( "{} - {} failed (dimensions do not match).".format( @@ -183,10 +182,10 @@ def periodic_point(self, x, component): return x -class CSGIntersection(geometry.Geometry): +class CSGIntersection(Geometry): """Construct an object by CSG Intersection.""" - def __init__(self, geom1, geom2): + def __init__(self, geom1: Geometry, geom2: Geometry): if geom1.dim != geom2.dim: raise ValueError( "{} & {} failed (dimensions do not match).".format( diff --git a/deepxde/geometry/geometry.py b/deepxde/geometry/geometry.py index 7564921bd..03f0374cf 100644 --- a/deepxde/geometry/geometry.py +++ b/deepxde/geometry/geometry.py @@ -1,37 +1,42 @@ +from __future__ import annotations + import abc -from typing import Literal +from typing import Callable, Literal +from numbers import Number import numpy as np +from numpy.typing import NDArray +from ..types import Tensor class Geometry(abc.ABC): - def __init__(self, dim, bbox, diam): + def __init__(self, dim: int, bbox: NDArray[np.float_], diam: Number): self.dim = dim self.bbox = bbox self.diam = min(diam, np.linalg.norm(bbox[1] - bbox[0])) self.idstr = type(self).__name__ @abc.abstractmethod - def inside(self, x): + def inside(self, x: NDArray[np.float_]) -> NDArray[np.bool_]: """Check if x is inside the geometry (including the boundary).""" @abc.abstractmethod - def on_boundary(self, x): + def on_boundary(self, x: NDArray[np.float_]) -> NDArray[np.bool_]: """Check if x is on the geometry boundary.""" - def distance2boundary(self, x, dirn): + def distance2boundary(self, x: NDArray[np.float_], dirn: Number) -> NDArray[np.float_]: raise NotImplementedError( "{}.distance2boundary to be implemented".format(self.idstr) ) - def mindist2boundary(self, x): + def mindist2boundary(self, x: NDArray[np.float_]) -> NDArray[np.float_]: raise NotImplementedError( "{}.mindist2boundary to be implemented".format(self.idstr) ) def boundary_constraint_factor( - self, x, smoothness: Literal["C0", "C0+", "Cinf"] = "C0+" - ): + self, x: NDArray[np.float_], smoothness: Literal["C0", "C0+", "Cinf"] = "C0+" + ) -> Tensor: """Compute the hard constraint factor at x for the boundary. This function is used for the hard-constraint methods in Physics-Informed Neural Networks (PINNs). @@ -74,13 +79,13 @@ def boundary_constraint_factor( "{}.boundary_constraint_factor to be implemented".format(self.idstr) ) - def boundary_normal(self, x): + def boundary_normal(self, x: NDArray[np.float_]) -> NDArray[np.float_]: """Compute the unit normal at x for Neumann or Robin boundary conditions.""" raise NotImplementedError( "{}.boundary_normal to be implemented".format(self.idstr) ) - def uniform_points(self, n, boundary=True): + def uniform_points(self, n: int, boundary: bool = True)-> NDArray[np.float_]: """Compute the equispaced point locations in the geometry.""" print( "Warning: {}.uniform_points not implemented. Use random_points instead.".format( @@ -90,10 +95,10 @@ def uniform_points(self, n, boundary=True): return self.random_points(n) @abc.abstractmethod - def random_points(self, n, random="pseudo"): + def random_points(self, n: int, random: str = "pseudo") -> NDArray[np.float_]: """Compute the random point locations in the geometry.""" - def uniform_boundary_points(self, n): + def uniform_boundary_points(self, n: int) -> NDArray[np.float_]: """Compute the equispaced point locations on the boundary.""" print( "Warning: {}.uniform_boundary_points not implemented. Use random_boundary_points instead.".format( @@ -103,51 +108,51 @@ def uniform_boundary_points(self, n): return self.random_boundary_points(n) @abc.abstractmethod - def random_boundary_points(self, n, random="pseudo"): + def random_boundary_points(self, n: int, random: str = "pseudo") -> NDArray[np.float_]: """Compute the random point locations on the boundary.""" - def periodic_point(self, x, component): + def periodic_point(self, x: NDArray[np.float_], component: int | list[int]) -> NDArray[np.float_]: """Compute the periodic image of x for periodic boundary condition.""" raise NotImplementedError( "{}.periodic_point to be implemented".format(self.idstr) ) - def background_points(self, x, dirn, dist2npt, shift): + def background_points(self, x: NDArray[np.float_], dirn: Number, dist2npt: Callable[[NDArray[np.float_]], int], shift: int) -> NDArray[np.float_]: raise NotImplementedError( "{}.background_points to be implemented".format(self.idstr) ) - def union(self, other): + def union(self, other: Geometry) -> Geometry: """CSG Union.""" from . import csg return csg.CSGUnion(self, other) - def __or__(self, other): + def __or__(self, other: Geometry) -> Geometry: """CSG Union.""" from . import csg return csg.CSGUnion(self, other) - def difference(self, other): + def difference(self, other: Geometry) -> Geometry: """CSG Difference.""" from . import csg return csg.CSGDifference(self, other) - def __sub__(self, other): + def __sub__(self, other: Geometry) -> Geometry: """CSG Difference.""" from . import csg return csg.CSGDifference(self, other) - def intersection(self, other): + def intersection(self, other: Geometry) -> Geometry: """CSG Intersection.""" from . import csg return csg.CSGIntersection(self, other) - def __and__(self, other): + def __and__(self, other: Geometry) -> Geometry: """CSG Intersection.""" from . import csg diff --git a/deepxde/geometry/geometry_1d.py b/deepxde/geometry/geometry_1d.py index 98e0eccd1..ac57f39f8 100644 --- a/deepxde/geometry/geometry_1d.py +++ b/deepxde/geometry/geometry_1d.py @@ -1,4 +1,5 @@ -from typing import Literal, Union +from numbers import Number +from typing import Literal import numpy as np @@ -10,7 +11,7 @@ class Interval(Geometry): - def __init__(self, l, r): + def __init__(self, l: Number, r: Number): super().__init__(1, (np.array([l]), np.array([r])), r - l) self.l, self.r = l, r @@ -30,7 +31,7 @@ def boundary_constraint_factor( self, x, smoothness: Literal["C0", "C0+", "Cinf"] = "C0+", - where: Union[None, Literal["left", "right"]] = None, + where: Literal["left", "right"] | None = None, ): """Compute the hard constraint factor at x for the boundary. @@ -115,7 +116,7 @@ def uniform_points(self, n, boundary=True): self.l, self.r, num=n + 1, endpoint=False, dtype=config.real(np) )[1:, None] - def log_uniform_points(self, n, boundary=True): + def log_uniform_points(self, n: int, boundary: bool = True): eps = 0 if self.l > 0 else np.finfo(config.real(np)).eps l = np.log(self.l + eps) r = np.log(self.r + eps) diff --git a/deepxde/geometry/geometry_2d.py b/deepxde/geometry/geometry_2d.py index 2a36e7fe8..1db2bd33f 100644 --- a/deepxde/geometry/geometry_2d.py +++ b/deepxde/geometry/geometry_2d.py @@ -1,8 +1,11 @@ -__all__ = ["Disk", "Ellipse", "Polygon", "Rectangle", "StarShaped", "Triangle"] +from __future__ import annotations -from typing import Union, Literal +__all__ = ["Disk", "Ellipse", "Polygon", "Rectangle", "StarShaped", "Triangle"] +from numbers import Number +from typing import Any, Literal import numpy as np +from numpy.typing import ArrayLike, NDArray from scipy import spatial from .geometry import Geometry @@ -10,10 +13,13 @@ from .sampler import sample from .. import backend as bkd from .. import config +from ..types import Tensor from ..utils import isclose, vectorize - class Disk(Hypersphere): + def __init__(self, center, radius): + super().__init__(center, radius) + def inside(self, x): return np.linalg.norm(x - self.center, axis=-1) <= self.radius @@ -84,7 +90,7 @@ class Ellipse(Geometry): counterclockwise about the center. """ - def __init__(self, center, semimajor, semiminor, angle=0): + def __init__(self, center: ArrayLike, semimajor: Number, semiminor: Number, angle: Number = 0): self.center = np.array(center, dtype=config.real(np)) self.semimajor = semimajor self.semiminor = semiminor @@ -272,10 +278,10 @@ def random_boundary_points(self, n, random="pseudo"): def _boundary_constraint_factor_inside( self, - x, - where: Union[None, Literal["left", "right", "bottom", "top"]] = None, + x: NDArray[np.float_], + where: Literal["left", "right", "bottom", "top"] | None = None, smoothness: Literal["C0", "C0+", "Cinf"] = "C0+", - ): + ) -> Tensor: """(Internal use only) Compute the hard constraint factor at `x` for the boundary. The points in `x` are assumed to live inside the geometry. @@ -316,9 +322,9 @@ def boundary_constraint_factor( self, x, smoothness: Literal["C0", "C0+", "Cinf"] = "C0+", - where: Union[None, Literal["left", "right", "bottom", "top"]] = None, + where: Literal["left", "right", "bottom", "top"] | None = None, inside: bool = True, - ): + ) -> Tensor: """Compute the hard constraint factor at x for the boundary. This function is used for the hard-constraint methods in Physics-Informed Neural Networks (PINNs). @@ -380,7 +386,7 @@ def boundary_constraint_factor( self.x12_tensor = bkd.as_tensor([self.xmin[0], self.xmax[1]]) self.x21_tensor = bkd.as_tensor([self.xmax[0], self.xmin[1]]) - dist_left = dist_right = dist_bottom = dist_top = None + dist_left = dist_right = dist_bottom = dist_top = 0.0 if where is None or where == "left": dist_left = bkd.abs( bkd.norm(x - self.x11_tensor, axis=-1, keepdims=True) @@ -421,7 +427,7 @@ def boundary_constraint_factor( return dist_left * dist_right * dist_bottom * dist_top @staticmethod - def is_valid(vertices): + def is_valid(vertices: NDArray[np.float_]) -> bool: """Check if the geometry is a Rectangle.""" return ( len(vertices) == 4 @@ -451,7 +457,7 @@ class StarShaped(Geometry): coeffs_sin: i-th order coefficients for the i-th sin term (b_i). """ - def __init__(self, center, radius, coeffs_cos, coeffs_sin): + def __init__(self, center: NDArray[np.float_], radius: Number, coeffs_cos: NDArray[np.float_], coeffs_sin: NDArray[np.float_]): self.center = np.array(center, dtype=config.real(np)) self.radius = radius self.coeffs_cos = coeffs_cos @@ -463,7 +469,7 @@ def __init__(self, center, radius, coeffs_cos, coeffs_sin): 2 * max_radius, ) - def _r_theta(self, theta): + def _r_theta(self, theta: Number) -> NDArray[np.float_]: """Define the parametrization r(theta) at angles theta.""" result = self.radius * np.ones(theta.shape) for i, (coeff_cos, coeff_sin) in enumerate( @@ -472,7 +478,7 @@ def _r_theta(self, theta): result += coeff_cos * np.cos(i * theta) + coeff_sin * np.sin(i * theta) return result - def _dr_theta(self, theta): + def _dr_theta(self, theta: Number) -> NDArray[np.float_]: """Evalutate the polar derivative r'(theta) at angles theta""" result = np.zeros(theta.shape) for i, (coeff_cos, coeff_sin) in enumerate( @@ -483,7 +489,7 @@ def _dr_theta(self, theta): ) return result - def inside(self, x): + def inside(self, x) -> NDArray[np.bool_]: r, theta = polar(x - self.center) r_theta = self._r_theta(theta) return r_theta >= r @@ -537,7 +543,7 @@ class Triangle(Geometry): vertices will be re-ordered in counterclockwise (right hand rule). """ - def __init__(self, x1, x2, x3): + def __init__(self, x1: Number, x2: Number, x3: Number): self.area = polygon_signed_area([x1, x2, x3]) # Clockwise if self.area < 0: @@ -684,7 +690,7 @@ def boundary_constraint_factor( self, x, smoothness: Literal["C0", "C0+", "Cinf"] = "C0+", - where: Union[None, Literal["x1-x2", "x1-x3", "x2-x3"]] = None, + where: Literal["x1-x2", "x1-x3", "x2-x3"] | None = None, ): """Compute the hard constraint factor at x for the boundary. @@ -739,7 +745,7 @@ def boundary_constraint_factor( self.x2_tensor = bkd.as_tensor(self.x2) self.x3_tensor = bkd.as_tensor(self.x3) - diff_x1_x2 = diff_x1_x3 = diff_x2_x3 = None + # diff_x1_x2 = diff_x1_x3 = diff_x2_x3 = 0.0 if where not in ["x1-x3", "x2-x3"]: diff_x1_x2 = ( bkd.norm(x - self.x1_tensor, axis=-1, keepdims=True) @@ -779,7 +785,7 @@ class Polygon(Geometry): rule). """ - def __init__(self, vertices): + def __init__(self, vertices: ArrayLike): self.vertices = np.array(vertices, dtype=config.real(np)) if len(vertices) == 3: raise ValueError("The polygon is a triangle. Use Triangle instead.") @@ -813,7 +819,7 @@ def __init__(self, vertices): self.normal = clockwise_rotation_90(self.segments.T).T self.normal = self.normal / np.linalg.norm(self.normal, axis=1).reshape(-1, 1) - def inside(self, x): + def inside(self, x) -> NDArray[np.bool_]: def wn_PnPoly(P, V): """Winding number algorithm. @@ -927,7 +933,7 @@ def random_boundary_points(self, n, random="pseudo"): return np.vstack(x) -def polygon_signed_area(vertices): +def polygon_signed_area(vertices: ArrayLike) -> Number: """The (signed) area of a simple polygon. If the vertices are in the counterclockwise direction, then the area is positive; if @@ -941,12 +947,12 @@ def polygon_signed_area(vertices): return 0.5 * (np.sum(x[:-1] * y[1:]) - np.sum(x[1:] * y[:-1])) -def clockwise_rotation_90(v): +def clockwise_rotation_90(v: NDArray[np.float_]) -> NDArray[np.float_]: """Rotate a vector of 90 degrees clockwise about the origin.""" return np.array([v[1], -v[0]]) -def is_left(P0, P1, P2): +def is_left(P0: ArrayLike, P1: ArrayLike, P2: ArrayLike) -> NDArray[Any]: """Test if a point is Left|On|Right of an infinite line. See: the January 2001 Algorithm "Area of 2D and 3D Triangles and Polygons". @@ -963,7 +969,7 @@ def is_left(P0, P1, P2): return np.cross(P1 - P0, P2 - P0, axis=-1).reshape((-1, 1)) -def is_rectangle(vertices): +def is_rectangle(vertices: ArrayLike) -> bool: """Check if the geometry is a rectangle. https://stackoverflow.com/questions/2303278/find-if-4-points-on-a-plane-form-a-rectangle/2304031 @@ -979,7 +985,7 @@ def is_rectangle(vertices): return np.allclose(d, np.full(4, d[0])) -def is_on_line_segment(P0, P1, P2): +def is_on_line_segment(P0: ArrayLike, P1: ArrayLike, P2: ArrayLike) -> bool: """Test if a point is between two other points on a line segment. Args: @@ -1005,7 +1011,7 @@ def is_on_line_segment(P0, P1, P2): # or isclose(np.linalg.norm(v12), 0) # check whether P2 is close to P1 -def polar(x): +def polar(x: NDArray[np.float_]) -> tuple[NDArray[np.float_], NDArray[np.float_]]: """Get the polar coordinated for a 2d vector in cartesian coordinates.""" r = np.sqrt(x[:, 0] ** 2 + x[:, 1] ** 2) theta = np.arctan2(x[:, 1], x[:, 0]) diff --git a/deepxde/geometry/geometry_3d.py b/deepxde/geometry/geometry_3d.py index ae08a239d..b52071a52 100644 --- a/deepxde/geometry/geometry_3d.py +++ b/deepxde/geometry/geometry_3d.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import itertools -from typing import Union, Literal +from typing import Literal import numpy as np @@ -72,9 +74,7 @@ def boundary_constraint_factor( self, x, smoothness: Literal["C0", "C0+", "Cinf"] = "C0+", - where: Union[ - None, Literal["back", "front", "left", "right", "bottom", "top"] - ] = None, + where: Literal["back", "front", "left", "right", "bottom", "top"] | None = None, inside: bool = True, ): """Compute the hard constraint factor at x for the boundary. @@ -138,7 +138,6 @@ def boundary_constraint_factor( self.xmin_tensor = bkd.as_tensor(self.xmin) self.xmax_tensor = bkd.as_tensor(self.xmax) - dist_l = dist_r = None if where not in ["front", "right", "top"]: dist_l = bkd.abs( (x - self.xmin_tensor) / (self.xmax_tensor - self.xmin_tensor) * 2 diff --git a/deepxde/geometry/geometry_nd.py b/deepxde/geometry/geometry_nd.py index a011cd417..db17009d9 100644 --- a/deepxde/geometry/geometry_nd.py +++ b/deepxde/geometry/geometry_nd.py @@ -1,7 +1,12 @@ +from __future__ import annotations + import itertools -from typing import Literal +from numbers import Number +from typing import Callable, Literal import numpy as np +from numpy.typing import NDArray + from scipy import stats from sklearn import preprocessing @@ -9,11 +14,12 @@ from .sampler import sample from .. import backend as bkd from .. import config +from ..types import Tensor from ..utils import isclose class Hypercube(Geometry): - def __init__(self, xmin, xmax): + def __init__(self, xmin: NDArray[np.float_], xmax: NDArray[np.float_]): if len(xmin) != len(xmax): raise ValueError("Dimensions of xmin and xmax do not match.") @@ -28,19 +34,19 @@ def __init__(self, xmin, xmax): ) self.volume = np.prod(self.side_length) - def inside(self, x): + def inside(self, x: NDArray[np.float_]) -> NDArray[np.bool_]: return np.logical_and( np.all(x >= self.xmin, axis=-1), np.all(x <= self.xmax, axis=-1) ) - def on_boundary(self, x): + def on_boundary(self, x: NDArray[np.float_]) -> NDArray[np.bool_]: _on_boundary = np.logical_or( np.any(isclose(x, self.xmin), axis=-1), np.any(isclose(x, self.xmax), axis=-1), ) return np.logical_and(self.inside(x), _on_boundary) - def boundary_normal(self, x): + def boundary_normal(self, x: NDArray[np.float_]) -> NDArray[np.float_]: _n = -isclose(x, self.xmin).astype(config.real(np)) + isclose(x, self.xmax) # For vertices, the normal is averaged for all directions idx = np.count_nonzero(_n, axis=-1) > 1 @@ -53,7 +59,7 @@ def boundary_normal(self, x): _n[idx] /= l return _n - def uniform_points(self, n, boundary=True): + def uniform_points(self, n: int, boundary: bool = True) -> NDArray[np.float_]: dx = (self.volume / n) ** (1 / self.dim) xi = [] for i in range(self.dim): @@ -81,11 +87,11 @@ def uniform_points(self, n, boundary=True): ) return x - def random_points(self, n, random="pseudo"): + def random_points(self, n: int, random: str = "pseudo") -> NDArray[np.float_]: x = sample(n, self.dim, random) return (self.xmax - self.xmin) * x + self.xmin - def random_boundary_points(self, n, random="pseudo"): + def random_boundary_points(self, n: int, random: str = "pseudo") -> NDArray[np.float_]: x = sample(n, self.dim, random) # Randomly pick a dimension rand_dim = np.random.randint(self.dim, size=n) @@ -93,7 +99,7 @@ def random_boundary_points(self, n, random="pseudo"): x[np.arange(n), rand_dim] = np.round(x[np.arange(n), rand_dim]) return (self.xmax - self.xmin) * x + self.xmin - def periodic_point(self, x, component): + def periodic_point(self, x: NDArray[np.float_], component: int | list[int]) -> NDArray[np.float_]: y = np.copy(x) _on_xmin = isclose(y[:, component], self.xmin[component]) _on_xmax = isclose(y[:, component], self.xmax[component]) @@ -103,11 +109,11 @@ def periodic_point(self, x, component): def boundary_constraint_factor( self, - x, + x: NDArray[np.float_], smoothness: Literal["C0", "C0+", "Cinf"] = "C0", - where: None = None, + where: str | None = None, inside: bool = True, - ): + ) -> Tensor: """Compute the hard constraint factor at x for the boundary. This function is used for the hard-constraint methods in Physics-Informed Neural Networks (PINNs). @@ -184,7 +190,7 @@ def boundary_constraint_factor( class Hypersphere(Geometry): - def __init__(self, center, radius): + def __init__(self, center: Number, radius: Number): self.center = np.array(center, dtype=config.real(np)) self.radius = radius super().__init__( @@ -193,13 +199,13 @@ def __init__(self, center, radius): self._r2 = radius**2 - def inside(self, x): + def inside(self, x: NDArray[np.float_]) -> NDArray[np.bool_]: return np.linalg.norm(x - self.center, axis=-1) <= self.radius - def on_boundary(self, x): + def on_boundary(self, x: NDArray[np.float_]) -> NDArray[np.bool_]: return isclose(np.linalg.norm(x - self.center, axis=-1), self.radius) - def distance2boundary_unitdirn(self, x, dirn): + def distance2boundary_unitdirn(self, x: NDArray[np.float_], dirn: Number) -> NDArray[np.float_]: # https://en.wikipedia.org/wiki/Line%E2%80%93sphere_intersection xc = x - self.center ad = np.dot(xc, dirn) @@ -207,15 +213,15 @@ def distance2boundary_unitdirn(self, x, dirn): config.real(np) ) - def distance2boundary(self, x, dirn): + def distance2boundary(self, x: NDArray[np.float_], dirn: Number) -> NDArray[np.float_]: return self.distance2boundary_unitdirn(x, dirn / np.linalg.norm(dirn)) - def mindist2boundary(self, x): + def mindist2boundary(self, x: NDArray[np.float_]) -> NDArray[np.float_]: return np.amin(self.radius - np.linalg.norm(x - self.center, axis=-1)) def boundary_constraint_factor( - self, x, smoothness: Literal["C0", "C0+", "Cinf"] = "C0+" - ): + self, x: NDArray[np.float_], smoothness: Literal["C0", "C0+", "Cinf"] = "C0+" + ) -> Tensor: if smoothness not in ["C0", "C0+", "Cinf"]: raise ValueError("smoothness must be one of C0, C0+, Cinf") @@ -230,13 +236,13 @@ def boundary_constraint_factor( dist = bkd.abs(dist) return dist - def boundary_normal(self, x): + def boundary_normal(self, x: NDArray[np.float_]) -> NDArray[np.float_]: _n = x - self.center l = np.linalg.norm(_n, axis=-1, keepdims=True) _n = _n / l * isclose(l, self.radius) return _n - def random_points(self, n, random="pseudo"): + def random_points(self, n: int, random: str = "pseudo") -> NDArray[np.float_]: # https://math.stackexchange.com/questions/87230/picking-random-points-in-the-volume-of-sphere-with-uniform-probability if random == "pseudo": U = np.random.rand(n, 1).astype(config.real(np)) @@ -249,7 +255,7 @@ def random_points(self, n, random="pseudo"): X = U ** (1 / self.dim) * X return self.radius * X + self.center - def random_boundary_points(self, n, random="pseudo"): + def random_boundary_points(self, n: int, random: str = "pseudo") -> NDArray[np.float_]: # http://mathworld.wolfram.com/HyperspherePointPicking.html if random == "pseudo": X = np.random.normal(size=(n, self.dim)).astype(config.real(np)) @@ -259,7 +265,7 @@ def random_boundary_points(self, n, random="pseudo"): X = preprocessing.normalize(X) return self.radius * X + self.center - def background_points(self, x, dirn, dist2npt, shift): + def background_points(self, x: NDArray[np.float_], dirn: Number, dist2npt: Callable[[NDArray[np.float_]], int], shift: int) -> NDArray[np.float_]: dirn = dirn / np.linalg.norm(dirn) dx = self.distance2boundary_unitdirn(x, -dirn) n = max(dist2npt(dx), 1) diff --git a/deepxde/geometry/pointcloud.py b/deepxde/geometry/pointcloud.py index 5644a79ae..d9883cfd9 100644 --- a/deepxde/geometry/pointcloud.py +++ b/deepxde/geometry/pointcloud.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import numpy as np +from numpy.typing import NDArray from .geometry import Geometry from .. import config -from ..data import BatchSampler +from ..data.sampler import BatchSampler from ..utils import isclose @@ -17,7 +20,7 @@ class PointCloud(Geometry): boundary_normals: A 2-D NumPy array. """ - def __init__(self, points, boundary_points=None, boundary_normals=None): + def __init__(self, points: NDArray[np.float_], boundary_points: NDArray[np.float_] | None = None, boundary_normals: NDArray[np.float_] | None = None): self.points = np.asarray(points, dtype=config.real(np)) self.num_points = len(points) self.boundary_points = None diff --git a/deepxde/geometry/sampler.py b/deepxde/geometry/sampler.py index 7b384eddc..19da42341 100644 --- a/deepxde/geometry/sampler.py +++ b/deepxde/geometry/sampler.py @@ -1,12 +1,14 @@ __all__ = ["sample"] +from typing import Literal import numpy as np +from numpy.typing import NDArray import skopt from .. import config -def sample(n_samples, dimension, sampler="pseudo"): +def sample(n_samples: int, dimension: int, sampler: Literal["pseudo", "LHS", "Halton", "Hammersley", "Sobol"] = "pseudo"): """Generate pseudorandom or quasirandom samples in [0, 1]^dimension. Args: @@ -23,7 +25,7 @@ def sample(n_samples, dimension, sampler="pseudo"): raise ValueError("f{sampler} sampling is not available.") -def pseudorandom(n_samples, dimension): +def pseudorandom(n_samples: int, dimension: int) -> NDArray[np.float_]: """Pseudo random.""" # If random seed is set, then the rng based code always returns the same random # number, which may not be what we expect. @@ -32,7 +34,7 @@ def pseudorandom(n_samples, dimension): return np.random.random(size=(n_samples, dimension)).astype(config.real(np)) -def quasirandom(n_samples, dimension, sampler): +def quasirandom(n_samples: int, dimension: int, sampler: Literal["LHS", "Halton", "Hammersley", "Sobol"]) -> NDArray[np.float_]: # Certain points should be removed: # - Boundary points such as [..., 0, ...] # - Special points [0, 0, 0, ...] and [0.5, 0.5, 0.5, ...], which cause error in diff --git a/deepxde/geometry/timedomain.py b/deepxde/geometry/timedomain.py index 0ee871b6c..483cee95f 100644 --- a/deepxde/geometry/timedomain.py +++ b/deepxde/geometry/timedomain.py @@ -1,7 +1,10 @@ import itertools +from numbers import Number import numpy as np +from numpy.typing import NDArray +from .geometry import Geometry from .geometry_1d import Interval from .geometry_2d import Rectangle from .geometry_3d import Cuboid @@ -11,17 +14,17 @@ class TimeDomain(Interval): - def __init__(self, t0, t1): + def __init__(self, t0: Number, t1: Number): super().__init__(t0, t1) self.t0 = t0 self.t1 = t1 - def on_initial(self, t): + def on_initial(self, t: NDArray[np.float_]) -> NDArray[np.bool_]: return isclose(t, self.t0).flatten() -class GeometryXTime: - def __init__(self, geometry, timedomain): +class GeometryXTime(): + def __init__(self, geometry: Geometry, timedomain: TimeDomain): self.geometry = geometry self.timedomain = timedomain self.dim = geometry.dim + timedomain.dim diff --git a/deepxde/icbc/boundary_conditions.py b/deepxde/icbc/boundary_conditions.py index e1f863a08..01bd7a26d 100644 --- a/deepxde/icbc/boundary_conditions.py +++ b/deepxde/icbc/boundary_conditions.py @@ -1,3 +1,4 @@ +from __future__ import annotations """Boundary conditions.""" __all__ = [ @@ -14,8 +15,10 @@ import numbers from abc import ABC, abstractmethod from functools import wraps +from typing import Any, Callable, overload import numpy as np +from numpy.typing import NDArray, ArrayLike from .. import backend as bkd from .. import config @@ -23,6 +26,8 @@ from .. import gradients as grad from .. import utils from ..backend import backend_name +from ..geometry import Geometry +from ..types import Tensor, TensorOrTensors class BC(ABC): @@ -30,11 +35,17 @@ class BC(ABC): Args: geom: A ``deepxde.geometry.Geometry`` instance. - on_boundary: A function: (x, Geometry.on_boundary(x)) -> True/False. - component: The output component satisfying this BC. + on_boundary: A function: `(x, Geometry.on_boundary(x))` -> True/False. + component: The output component satisfying this BC, should be provided + if ``BC.error`` involves derivatives and the output has multiple components. """ - def __init__(self, geom, on_boundary, component): + def __init__( + self, + geom: Geometry, + on_boundary: Callable[[NDArray[Any], NDArray[Any]], NDArray[np.bool_]], + component: list[int] | int, + ): self.geom = geom self.on_boundary = lambda x, on: np.array( [on_boundary(x[i], on[i]) for i in range(len(x))] @@ -45,28 +56,57 @@ def __init__(self, geom, on_boundary, component): utils.return_tensor(self.geom.boundary_normal) ) - def filter(self, X): + def filter(self, X: NDArray[Any]) -> NDArray[np.bool_]: return X[self.on_boundary(X, self.geom.on_boundary(X))] - def collocation_points(self, X): + def collocation_points(self, X: NDArray[Any]) -> NDArray[Any]: return self.filter(X) - def normal_derivative(self, X, inputs, outputs, beg, end): + def normal_derivative( + self, + X: NDArray[Any], + inputs: TensorOrTensors, + outputs: Tensor, + beg: int, + end: int, + ) -> Tensor: dydx = grad.jacobian(outputs, inputs, i=self.component, j=None)[beg:end] n = self.boundary_normal(X, beg, end, None) return bkd.sum(dydx * n, 1, keepdims=True) @abstractmethod - def error(self, X, inputs, outputs, beg, end, aux_var=None): + def error( + self, + X: NDArray[Any], + inputs: TensorOrTensors, + outputs: Tensor, + beg: int, + end: int, + aux_var: NDArray[np.float_] | None = None, + ) -> Tensor: """Returns the loss.""" # aux_var is used in PI-DeepONet, where aux_var is the input function evaluated # at x. class DirichletBC(BC): - """Dirichlet boundary conditions: y(x) = func(x).""" + """Dirichlet boundary conditions: `y(x) = func(x)`. + + Args: + geom: A ``deepxde.geometry.Geometry`` instance. + func: A function: `x` -> `y`. + on_boundary: A function: `(x, Geometry.on_boundary(x))` -> True/False. + component: The output component satisfying this BC, should be provided + if ``BC.error`` involves derivatives and the output has multiple components. + """ - def __init__(self, geom, func, on_boundary, component=0): + def __init__( + self, + geom: Geometry, + func: Callable[[NDArray[np.float_]], NDArray[np.float_]], + on_boundary: Callable[[NDArray[Any], NDArray[Any]], NDArray[np.bool_]], + component: list[int] | int = 0, + ): super().__init__(geom, on_boundary, component) self.func = npfunc_range_autocache(utils.return_tensor(func)) @@ -81,9 +121,23 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None): class NeumannBC(BC): - """Neumann boundary conditions: dy/dn(x) = func(x).""" + """Neumann boundary conditions: `dy/dn(x) = func(x)`. - def __init__(self, geom, func, on_boundary, component=0): + Args: + geom: A ``deepxde.geometry.Geometry`` instance. + func: A function: `x` -> `dy/dn`. + on_boundary: A function: `(x, Geometry.on_boundary(x))` -> True/False. + component: The output component satisfying this BC, should be provided + if ``BC.error`` involves derivatives and the output has multiple components. + """ + + def __init__( + self, + geom: Geometry, + func: Callable[[NDArray[np.float_]], NDArray[np.float_]], + on_boundary: Callable[[NDArray[Any], NDArray[Any]], NDArray[np.bool_]], + component: list[int] | int = 0, + ): super().__init__(geom, on_boundary, component) self.func = npfunc_range_autocache(utils.return_tensor(func)) @@ -93,9 +147,23 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None): class RobinBC(BC): - """Robin boundary conditions: dy/dn(x) = func(x, y).""" + """Robin boundary conditions: `dy/dn(x) = func(x, y)`. + + Args: + geom: A ``deepxde.geometry.Geometry`` instance. + func: A function: `(x, y)` -> `dy/dn`. + on_boundary: A function: `(x, Geometry.on_boundary(x))` -> True/False. + component: The output component satisfying this BC, should be provided + if ``BC.error`` involves derivatives and the output has multiple components. + """ - def __init__(self, geom, func, on_boundary, component=0): + def __init__( + self, + geom: Geometry, + func: Callable[[NDArray[np.float_]], NDArray[np.float_]], + on_boundary: Callable[[NDArray[Any], NDArray[Any]], NDArray[np.bool_]], + component: list[int] | int = 0, + ): super().__init__(geom, on_boundary, component) self.func = func @@ -106,9 +174,25 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None): class PeriodicBC(BC): - """Periodic boundary conditions on component_x.""" + """Periodic boundary conditions on component_x. + + Args: + geom: A ``deepxde.geometry.Geometry`` instance. + component_x: The component of the input satisfying this BC. + on_boundary: A function: `(x, Geometry.on_boundary(x))` -> True/False. + derivative_order: The derivative order of the output satisfying this BC. + component: The output component satisfying this BC, should be provided + if ``BC.error`` involves derivatives and the output has multiple components. + """ - def __init__(self, geom, component_x, on_boundary, derivative_order=0, component=0): + def __init__( + self, + geom: Geometry, + component_x: int, + on_boundary: Callable[[NDArray[Any], NDArray[Any]], NDArray[np.bool_]], + derivative_order: int = 0, + component: list[int] | int = 0, + ): super().__init__(geom, on_boundary, component) self.component_x = component_x self.derivative_order = derivative_order @@ -135,11 +219,11 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None): class OperatorBC(BC): - """General operator boundary conditions: func(inputs, outputs, X) = 0. + """General operator boundary conditions: `func(inputs, outputs, X) = 0`. Args: - geom: ``Geometry``. - func: A function takes arguments (`inputs`, `outputs`, `X`) + geom: A ``deepxde.geometry.Geometry`` instance. + func: A function takes arguments `(inputs, outputs, X)` and outputs a tensor of size `N x 1`, where `N` is the length of `inputs`. `inputs` and `outputs` are the network input and output tensors, respectively; `X` are the NumPy array of the `inputs`. @@ -153,7 +237,12 @@ class OperatorBC(BC): which cannot be fixed in an easy way for all backends. """ - def __init__(self, geom, func, on_boundary): + def __init__( + self, + geom: Geometry, + func: Callable[[TensorOrTensors, Tensor, NDArray[np.float_]], Tensor], + on_boundary: Callable[[NDArray[Any], NDArray[Any]], NDArray[np.bool_]], + ): super().__init__(geom, on_boundary, 0) self.func = func @@ -161,7 +250,7 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None): return self.func(inputs, outputs, X)[beg:end] -class PointSetBC: +class PointSetBC(BC): """Dirichlet boundary condition for a set of points. Compare the output (that associates with `points`) with `values` (target data). @@ -172,7 +261,7 @@ class PointSetBC: points: An array of points where the corresponding target values are known and used for training. values: A scalar or a 2D-array of values that gives the exact solution of the problem. - component: Integer or a list of integers. The output components satisfying this BC. + omponent: Integer or a list of integers. The output components satisfying this BC. List of integers only supported for the backend PyTorch. batch_size: The number of points per minibatch, or `None` to return all points. This is only supported for the backend PyTorch and PaddlePaddle. @@ -181,7 +270,14 @@ class PointSetBC: shuffle: Randomize the order on each pass through the data when batching. """ - def __init__(self, points, values, component=0, batch_size=None, shuffle=True): + def __init__( + self, + points: ArrayLike, + values: ArrayLike, + component: list[int] | int = 0, + batch_size: int | None = None, + shuffle: bool = True, + ): self.points = np.array(points, dtype=config.real(np)) self.values = bkd.as_tensor(values, dtype=config.real(bkd.lib)) self.component = component @@ -233,7 +329,7 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None): return outputs[beg:end, self.component] - self.values -class PointSetOperatorBC: +class PointSetOperatorBC(BC): """General operator boundary conditions for a set of points. Compare the function output, func, (that associates with `points`) @@ -249,7 +345,12 @@ class PointSetOperatorBC: tensors, respectively; `X` are the NumPy array of the `inputs`. """ - def __init__(self, points, values, func): + def __init__( + self, + points: ArrayLike, + values: ArrayLike, + func: Callable[[TensorOrTensors, Tensor, NDArray[np.float_]], Tensor], + ): self.points = np.array(points, dtype=config.real(np)) if not isinstance(values, numbers.Number) and values.shape[1] != 1: raise RuntimeError("PointSetOperatorBC should output 1D values") @@ -263,6 +364,20 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None): return self.func(inputs, outputs, X)[beg:end] - self.values +@overload +def npfunc_range_autocache( + func: Callable[[NDArray[np.float_]], NDArray[np.float_]] +) -> NDArray[np.float_]: + ... + + +@overload +def npfunc_range_autocache( + func: Callable[[NDArray[np.float_], NDArray[np.float_]], NDArray[np.float_]] +) -> NDArray[np.float_]: + ... + + def npfunc_range_autocache(func): """Call a NumPy function on a range of the input ndarray. @@ -291,22 +406,30 @@ def npfunc_range_autocache(func): cache = {} @wraps(func) - def wrapper_nocache(X, beg, end, _): + def wrapper_nocache( + X: NDArray[np.float_], beg: int, end: int, _ + ) -> NDArray[np.float_]: return func(X[beg:end]) @wraps(func) - def wrapper_nocache_auxiliary(X, beg, end, aux_var): + def wrapper_nocache_auxiliary( + X: NDArray[np.float_], beg: int, end: int, aux_var: NDArray[np.float_] + ) -> NDArray[np.float_]: return func(X[beg:end], aux_var[beg:end]) @wraps(func) - def wrapper_cache(X, beg, end, _): + def wrapper_cache( + X: NDArray[np.float_], beg: int, end: int, _ + ) -> NDArray[np.float_]: key = (id(X), beg, end) if key not in cache: cache[key] = func(X[beg:end]) return cache[key] @wraps(func) - def wrapper_cache_auxiliary(X, beg, end, aux_var): + def wrapper_cache_auxiliary( + X: NDArray[np.float_], beg: int, end: int, aux_var: NDArray[np.float_] + ) -> NDArray[np.float_]: # Even if X is the same one, aux_var could be different key = (id(X), beg, end) if key not in cache: diff --git a/deepxde/icbc/initial_conditions.py b/deepxde/icbc/initial_conditions.py index a7ca57f2f..cc4e0bd6e 100644 --- a/deepxde/icbc/initial_conditions.py +++ b/deepxde/icbc/initial_conditions.py @@ -1,18 +1,30 @@ +from __future__ import annotations """Initial conditions.""" __all__ = ["IC"] +from typing import Any, Callable + import numpy as np +from numpy.typing import NDArray, ArrayLike from .boundary_conditions import npfunc_range_autocache from .. import backend as bkd from .. import utils +from ..geometry import Geometry +from ..types import Tensor, TensorOrTensors class IC: """Initial conditions: y([x, t0]) = func([x, t0]).""" - def __init__(self, geom, func, on_initial, component=0): + def __init__( + self, + geom: Geometry, + func: Callable[[NDArray[np.float_]], NDArray[np.float_]], + on_initial: Callable[[NDArray[Any], NDArray[Any]], NDArray[np.bool_]], + component: list[int] | int = 0, + ): self.geom = geom self.func = npfunc_range_autocache(utils.return_tensor(func)) self.on_initial = lambda x, on: np.array( @@ -20,13 +32,21 @@ def __init__(self, geom, func, on_initial, component=0): ) self.component = component - def filter(self, X): + def filter(self, X: NDArray[np.float_]) -> NDArray[np.bool_]: return X[self.on_initial(X, self.geom.on_initial(X))] - def collocation_points(self, X): + def collocation_points(self, X: NDArray[np.float_]) -> NDArray[np.float_]: return self.filter(X) - def error(self, X, inputs, outputs, beg, end, aux_var=None): + def error( + self, + X: NDArray[np.float_], + inputs: TensorOrTensors, + outputs: Tensor, + beg: int, + end: int, + aux_var: NDArray[np.float_] | None = None, + ) -> Tensor: values = self.func(X, beg, end, aux_var) if bkd.ndim(values) == 2 and bkd.shape(values)[1] != 1: raise RuntimeError( diff --git a/deepxde/types.py b/deepxde/types.py new file mode 100644 index 000000000..159cc4e5a --- /dev/null +++ b/deepxde/types.py @@ -0,0 +1,15 @@ +from __future__ import annotations +from typing import Sequence, TypeVar, Union + +# dtype from any backend +dtype = TypeVar("dtype") + +# NN from any backend (Using the `NN` from deepxde is recommended.) +NN = TypeVar("NN") + +# SparseTensor from any backend +SparseTensor = TypeVar("SparseTensor") + +# Tensor from any backend +Tensor = TypeVar("Tensor") +TensorOrTensors = Union[Tensor, Sequence[Tensor]] diff --git a/deepxde/utils/external.py b/deepxde/utils/external.py index bca1cc1d1..caa9bc0c0 100644 --- a/deepxde/utils/external.py +++ b/deepxde/utils/external.py @@ -1,3 +1,4 @@ +from __future__ import annotations """External utilities.""" import csv @@ -6,6 +7,7 @@ import matplotlib.pyplot as plt import numpy as np +from numpy.typing import NDArray import scipy.spatial.distance from mpl_toolkits.mplot3d import Axes3D from sklearn import preprocessing @@ -376,7 +378,7 @@ def dat_to_csv(dat_file_path, csv_file_path, columns): csv_writer.writerow(row) -def isclose(a, b): +def isclose(a: NDArray[np.float_], b: NDArray[np.float_]) -> NDArray[np.bool_]: """A modified version of `np.isclose` for DeepXDE. This function changes the value of `atol` due to the dtype of `a` and `b`. diff --git a/pyproject.toml b/pyproject.toml index a1f7c55bd..f455cf01f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,3 +54,8 @@ exclude = ["docker", "docs*", "examples*"] [tool.setuptools_scm] write_to = "deepxde/_version.py" + +[mypy] +python_version = 3.8 +strict = false +ignore_missing_imports = true \ No newline at end of file