Skip to content

Commit dfb8e8b

Browse files
Merge pull request #137 from IBM/fix/global_dtype
Fix/global dtype
2 parents b5c5608 + f276fc3 commit dfb8e8b

16 files changed

+96
-25
lines changed

simulai/models/_pytorch_models/_autoencoder.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy as np
1818
import torch
1919

20+
from simulai import ARRAY_DTYPE
2021
from simulai.regression import ConvolutionalNetwork, DenseNetwork, Linear
2122
from simulai.templates import (
2223
NetworkTemplate,
@@ -508,7 +509,7 @@ def eval(self, input_data: Union[np.ndarray, torch.Tensor] = None) -> np.ndarray
508509
"""
509510

510511
if isinstance(input_data, np.ndarray):
511-
input_data = torch.from_numpy(input_data.astype("float32"))
512+
input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE))
512513

513514
input_data = input_data.to(self.device)
514515

@@ -995,7 +996,7 @@ def predict(
995996
996997
"""
997998
if isinstance(input_data, np.ndarray):
998-
input_data = torch.from_numpy(input_data.astype("float32"))
999+
input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE))
9991000

10001001
predictions = list()
10011002
latent = self.projection(input_data=input_data)
@@ -1694,7 +1695,7 @@ def project(self, input_data: Union[np.ndarray, torch.Tensor] = None) -> np.ndar
16941695
>>> projected_data = autoencoder.project(input_data=input_data)
16951696
"""
16961697
if isinstance(input_data, np.ndarray):
1697-
input_data = torch.from_numpy(input_data.astype("float32"))
1698+
input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE))
16981699

16991700
input_data = input_data.to(self.device)
17001701

@@ -1725,7 +1726,7 @@ def reconstruct(
17251726
>>> reconstructed_data = autoencoder.reconstruct(input_data=input_data)
17261727
"""
17271728
if isinstance(input_data, np.ndarray):
1728-
input_data = torch.from_numpy(input_data.astype("float32"))
1729+
input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE))
17291730

17301731
input_data = input_data.to(self.device)
17311732

@@ -1754,7 +1755,7 @@ def eval(self, input_data: Union[np.ndarray, torch.Tensor] = None) -> np.ndarray
17541755
>>> reconstructed_data = autoencoder.eval(input_data=input_data)
17551756
"""
17561757
if isinstance(input_data, np.ndarray):
1757-
input_data = torch.from_numpy(input_data.astype("float32"))
1758+
input_data = torch.from_numpy(input_data.astype(ARRAY_DTYPE))
17581759

17591760
input_data = input_data.to(self.device)
17601761

tests/PINN/test_deep_operator_pinn.py

+3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
import numpy as np
1818

19+
from tests.config import configure_dtype
20+
torch = configure_dtype()
21+
1922
from simulai.optimization import Optimizer
2023
from simulai.residuals import SymbolicOperator
2124

tests/PINN/test_vanilla_pinn.py

+3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
import matplotlib.pyplot as plt
1818
import numpy as np
1919

20+
from tests.config import configure_dtype
21+
torch = configure_dtype()
22+
2023
from simulai.optimization import Optimizer
2124
from simulai.residuals import SymbolicOperator
2225

tests/config.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# (C) Copyright IBM Corp. 2019, 2020, 2021, 2022.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# (C) Copyright IBM Corporation 2017, 2018, 2019
16+
# U.S. Government Users Restricted Rights: Use, duplication or disclosure restricted
17+
# by GSA ADP Schedule Contract with IBM Corp.
18+
#
19+
# Author: Joao Lucas S. Almeida <[email protected]>
20+
21+
import os
22+
import torch
23+
24+
def configure_dtype():
25+
26+
test_dtype_var = os.environ.get("TEST_DTYPE")
27+
28+
if test_dtype_var is not None:
29+
test_dtype = getattr(torch, test_dtype_var)
30+
else:
31+
test_dtype = torch.float32
32+
33+
torch.set_default_dtype(test_dtype)
34+
35+
print(f"Using dtype {test_dtype} in tests.")
36+
37+
return torch
38+
39+
40+

tests/metrics/test_mahalanobis.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from unittest import TestCase
2-
3-
import torch
2+
from tests.config import configure_dtype
3+
torch = configure_dtype()
44

55
from simulai.metrics import MahalanobisDistance
66

tests/metrics/test_pointwise.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
from unittest import TestCase
1616

1717
import numpy as np
18-
import torch
18+
19+
from tests.config import configure_dtype
20+
torch = configure_dtype()
1921

2022
from simulai.metrics import PointwiseError
2123

tests/network/test_conv_1d.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616
from unittest import TestCase
1717

1818
import numpy as np
19-
import torch
19+
from tests.config import configure_dtype
20+
torch = configure_dtype()
21+
2022
from utils import configure_device
2123

24+
from simulai import ARRAY_DTYPE
2225
from simulai.file import SPFile
2326
from simulai.optimization import Optimizer
2427

@@ -34,8 +37,8 @@ def generate_data(
3437
input_data = np.random.rand(n_samples, n_inputs, vector_size)
3538
output_data = np.random.rand(n_samples, n_outputs)
3639

37-
return torch.from_numpy(input_data.astype("float32")), torch.from_numpy(
38-
output_data.astype("float32")
40+
return torch.from_numpy(input_data.astype(ARRAY_DTYPE)), torch.from_numpy(
41+
output_data.astype(ARRAY_DTYPE)
3942
)
4043

4144

tests/network/test_conv_2d.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616
from unittest import TestCase
1717

1818
import numpy as np
19-
import torch
19+
from tests.config import configure_dtype
20+
torch = configure_dtype()
21+
2022
from utils import configure_device
2123

24+
from simulai import ARRAY_DTYPE
2225
from simulai.file import SPFile
2326
from simulai.optimization import Optimizer
2427

@@ -34,8 +37,8 @@ def generate_data(
3437
input_data = np.random.rand(n_samples, n_inputs, *image_size)
3538
output_data = np.random.rand(n_samples, n_outputs)
3639

37-
return torch.from_numpy(input_data.astype("float32")), torch.from_numpy(
38-
output_data.astype("float32")
40+
return torch.from_numpy(input_data.astype(ARRAY_DTYPE)), torch.from_numpy(
41+
output_data.astype(ARRAY_DTYPE)
3942
)
4043

4144

tests/network/test_deeponet.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from unittest import TestCase
1616

1717
import numpy as np
18-
import torch
18+
from tests.config import configure_dtype
19+
torch = configure_dtype()
1920
from utils import configure_device
2021

2122
DEVICE = configure_device()

tests/network/test_flexible_deeponet.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
from unittest import TestCase
1616

1717
import numpy as np
18-
import torch
18+
from tests.config import configure_dtype
19+
torch = configure_dtype()
20+
1921
from utils import configure_device
2022

2123
DEVICE = configure_device()

tests/network/test_improved_deeponet.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
from unittest import TestCase
1616

1717
import numpy as np
18-
import torch
18+
from tests.config import configure_dtype
19+
torch = configure_dtype()
20+
1921
from utils import configure_device
2022

2123
DEVICE = configure_device()

tests/network/test_residual_cnn.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
import matplotlib.pyplot as plt
2121
import numpy as np
22-
import torch
22+
23+
from tests.config import configure_dtype
24+
torch = configure_dtype()
2325

2426
torch.autograd.set_detect_anomaly(True)
2527

tests/network/test_template_gen.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@
1616
from unittest import TestCase
1717

1818
import numpy as np
19-
import torch
19+
from tests.config import configure_dtype
20+
torch = configure_dtype()
21+
2022
from utils import configure_device
2123

2224
DEVICE = configure_device()
2325

26+
from simulai import ARRAY_DTYPE
2427

2528
def generate_data_2d(
2629
n_samples: int = None,
@@ -31,8 +34,8 @@ def generate_data_2d(
3134
input_data = np.random.rand(n_samples, n_inputs, *image_size)
3235
output_data = np.random.rand(n_samples, n_outputs)
3336

34-
return torch.from_numpy(input_data.astype("float32")), torch.from_numpy(
35-
output_data.astype("float32")
37+
return torch.from_numpy(input_data.astype(ARRAY_DTYPE)), torch.from_numpy(
38+
output_data.astype(ARRAY_DTYPE)
3639
)
3740

3841

@@ -45,8 +48,8 @@ def generate_data_1d(
4548
input_data = np.random.rand(n_samples, n_inputs, vector_size)
4649
output_data = np.random.rand(n_samples, n_outputs)
4750

48-
return torch.from_numpy(input_data.astype("float32")), torch.from_numpy(
49-
output_data.astype("float32")
51+
return torch.from_numpy(input_data.astype(ARRAY_DTYPE)), torch.from_numpy(
52+
output_data.astype(ARRAY_DTYPE)
5053
)
5154

5255

tests/network/utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ def configure_device():
77
if not simulai_network_gpu:
88
device = "cpu"
99
else:
10-
import torch
10+
from tests.config import configure_dtype
11+
torch = configure_dtype()
1112

1213
if not torch.cuda.is_available():
1314
raise Exception("There is no gpu available to execute the tests.")

tests/residuals/test_symbolicoperator.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from unittest import TestCase
1717

1818
import numpy as np
19-
import torch
19+
from tests.config import configure_dtype
20+
torch = configure_dtype()
21+
2022

2123
from simulai.residuals import SymbolicOperator
2224

tests/rom/test_cnn_autoencoder.py

+3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
import numpy as np
55

6+
from tests.config import configure_dtype
7+
torch = configure_dtype()
8+
69
from simulai.file import SPFile
710
from simulai.optimization import Optimizer
811

0 commit comments

Comments
 (0)