File tree Expand file tree Collapse file tree 5 files changed +29
-1
lines changed Expand file tree Collapse file tree 5 files changed +29
-1
lines changed Original file line number Diff line number Diff line change @@ -204,6 +204,18 @@ def concat(values, axis):
204204 """
205205
206206
207+ def stack (values , axis ):
208+ """Returns the stack of the input tensors along the given dim.
209+
210+ Args:
211+ values (list or tuple of Tensor). The input tensors in list or tuple.
212+ axis (int). The stacking dim.
213+
214+ Returns:
215+ Tensor: Stacked tensor.
216+ """
217+
218+
207219def expand_dims (tensor , axis ):
208220 """Expand dim for tensor along given axis.
209221
Original file line number Diff line number Diff line change @@ -90,6 +90,10 @@ def concat(values, axis):
9090 return paddle .concat (values , axis = axis )
9191
9292
93+ def stack (values , axis ):
94+ return paddle .stack (values , axis = axis )
95+
96+
9397def expand_dims (tensor , axis ):
9498 return paddle .unsqueeze (tensor , axis = axis )
9599
Original file line number Diff line number Diff line change @@ -98,6 +98,10 @@ def concat(values, axis):
9898 return torch .cat (values , axis )
9999
100100
101+ def stack (values , axis ):
102+ return torch .stack (values , axis )
103+
104+
101105def expand_dims (tensor , axis ):
102106 return torch .unsqueeze (tensor , axis )
103107
Original file line number Diff line number Diff line change @@ -108,6 +108,10 @@ def concat(values, axis):
108108 return tf .concat (values , axis )
109109
110110
111+ def stack (values , axis ):
112+ return tf .stack (values , axis )
113+
114+
111115def expand_dims (tensor , axis ):
112116 return tf .expand_dims (tensor , axis )
113117
Original file line number Diff line number Diff line change @@ -14,7 +14,11 @@ def istensorlist(values):
1414def convert_to_array (value ):
1515 """Convert a list of numpy arrays or tensors to a numpy array or a tensor."""
1616 if istensorlist (value ):
17- return bkd .concat (value , axis = 0 )
17+ # TODO: use concat instead of stack as paddle now use shape [1,]
18+ # for 0-D tensor, it will be solved soon.
19+ if bkd .backend_name == "paddle" :
20+ return bkd .concat (value , axis = 0 )
21+ return bkd .stack (value , axis = 0 )
1822 value = np .array (value )
1923 if value .dtype != config .real (np ):
2024 return value .astype (config .real (np ))
You can’t perform that action at this time.
0 commit comments