|
17 | 17 | import numpy as np
|
18 | 18 | import torch
|
19 | 19 |
|
| 20 | +from simulai import ARRAY_DTYPE |
20 | 21 | from simulai.regression import ConvolutionalNetwork, DenseNetwork, Linear
|
21 | 22 | from simulai.templates import (
|
22 | 23 | NetworkTemplate,
|
@@ -508,7 +509,7 @@ def eval(self, input_data: Union[np.ndarray, torch.Tensor] = None) -> np.ndarray
|
508 | 509 | """
|
509 | 510 |
|
510 | 511 | if isinstance(input_data, np.ndarray):
|
511 |
| - input_data = torch.from_numpy(input_data.astype("float32")) |
| 512 | + input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE)) |
512 | 513 |
|
513 | 514 | input_data = input_data.to(self.device)
|
514 | 515 |
|
@@ -995,7 +996,7 @@ def predict(
|
995 | 996 |
|
996 | 997 | """
|
997 | 998 | if isinstance(input_data, np.ndarray):
|
998 |
| - input_data = torch.from_numpy(input_data.astype("float32")) |
| 999 | + input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE)) |
999 | 1000 |
|
1000 | 1001 | predictions = list()
|
1001 | 1002 | latent = self.projection(input_data=input_data)
|
@@ -1694,7 +1695,7 @@ def project(self, input_data: Union[np.ndarray, torch.Tensor] = None) -> np.ndar
|
1694 | 1695 | >>> projected_data = autoencoder.project(input_data=input_data)
|
1695 | 1696 | """
|
1696 | 1697 | if isinstance(input_data, np.ndarray):
|
1697 |
| - input_data = torch.from_numpy(input_data.astype("float32")) |
| 1698 | + input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE)) |
1698 | 1699 |
|
1699 | 1700 | input_data = input_data.to(self.device)
|
1700 | 1701 |
|
@@ -1725,7 +1726,7 @@ def reconstruct(
|
1725 | 1726 | >>> reconstructed_data = autoencoder.reconstruct(input_data=input_data)
|
1726 | 1727 | """
|
1727 | 1728 | if isinstance(input_data, np.ndarray):
|
1728 |
| - input_data = torch.from_numpy(input_data.astype("float32")) |
| 1729 | + input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE)) |
1729 | 1730 |
|
1730 | 1731 | input_data = input_data.to(self.device)
|
1731 | 1732 |
|
@@ -1754,7 +1755,7 @@ def eval(self, input_data: Union[np.ndarray, torch.Tensor] = None) -> np.ndarray
|
1754 | 1755 | >>> reconstructed_data = autoencoder.eval(input_data=input_data)
|
1755 | 1756 | """
|
1756 | 1757 | if isinstance(input_data, np.ndarray):
|
1757 |
| - input_data = torch.from_numpy(input_data.astype("float32")) |
| 1758 | + input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE)) |
1758 | 1759 |
|
1759 | 1760 | input_data = input_data.to(self.device)
|
1760 | 1761 |
|
|
0 commit comments