Skip to content

Commit fc978ad

Browse files
Bug fix: use bkd.stack instead of bkd.concat in convert_to_array function (#1101)
1 parent 9b8ddaf commit fc978ad

File tree

5 files changed

+29
-1
lines changed

5 files changed

+29
-1
lines changed

deepxde/backend/backend.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,18 @@ def concat(values, axis):
204204
"""
205205

206206

207+
def stack(values, axis):
208+
"""Returns the stack of the input tensors along the given dim.
209+
210+
Args:
211+
values (list or tuple of Tensor). The input tensors in list or tuple.
212+
axis (int). The stacking dim.
213+
214+
Returns:
215+
Tensor: Stacked tensor.
216+
"""
217+
218+
207219
def expand_dims(tensor, axis):
208220
"""Expand dim for tensor along given axis.
209221

deepxde/backend/paddle/tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ def concat(values, axis):
9090
return paddle.concat(values, axis=axis)
9191

9292

93+
def stack(values, axis):
94+
return paddle.stack(values, axis=axis)
95+
96+
9397
def expand_dims(tensor, axis):
9498
return paddle.unsqueeze(tensor, axis=axis)
9599

deepxde/backend/pytorch/tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ def concat(values, axis):
9898
return torch.cat(values, axis)
9999

100100

101+
def stack(values, axis):
102+
return torch.stack(values, axis)
103+
104+
101105
def expand_dims(tensor, axis):
102106
return torch.unsqueeze(tensor, axis)
103107

deepxde/backend/tensorflow_compat_v1/tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ def concat(values, axis):
108108
return tf.concat(values, axis)
109109

110110

111+
def stack(values, axis):
112+
return tf.stack(values, axis)
113+
114+
111115
def expand_dims(tensor, axis):
112116
return tf.expand_dims(tensor, axis)
113117

deepxde/utils/array_ops_compat.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@ def istensorlist(values):
1414
def convert_to_array(value):
1515
"""Convert a list of numpy arrays or tensors to a numpy array or a tensor."""
1616
if istensorlist(value):
17-
return bkd.concat(value, axis=0)
17+
# TODO: use concat instead of stack as paddle now use shape [1,]
18+
# for 0-D tensor, it will be solved soon.
19+
if bkd.backend_name == "paddle":
20+
return bkd.concat(value, axis=0)
21+
return bkd.stack(value, axis=0)
1822
value = np.array(value)
1923
if value.dtype != config.real(np):
2024
return value.astype(config.real(np))

0 commit comments

Comments
 (0)