Skip to content

Commit afbcc64

Browse files
authored
Merge pull request #69 from NatLabRockies/bnb/obs_2d_fix
fix: shape mismatch between mask and hr data in sup3r obs model layer…
2 parents 66d756c + d70b04b commit afbcc64

4 files changed

Lines changed: 11832 additions & 1 deletion

File tree

phygnn/utilities/tf_utilities.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def idw_fill(x, low_mem=True):
132132
Mask of the input tensor where 1 is not NaN and 0 is NaN.
133133
"""
134134
rank = len(x.shape)
135-
assert rank in [4, 5], 'Input tensor must be 4D or 5D'
135+
assert rank in {4, 5}, 'Input tensor must be 4D or 5D'
136136
x = tf.expand_dims(x, axis=-2) if rank == 4 else x
137137
mask = tf.math.logical_not(tf.math.is_nan(x))
138138
B, H, W, D, C = x.shape
@@ -185,4 +185,5 @@ def idw_fill(x, low_mem=True):
185185
filled = tf.stack(filled, axis=0)
186186
filled = tf.reshape(filled, [B, H, W, D, C])
187187
filled = tf.squeeze(filled, axis=-2) if rank == 4 else filled
188+
mask = tf.squeeze(mask, axis=-2) if rank == 4 else mask
188189
return tf.cast(filled, x.dtype), tf.cast(mask, x.dtype)

0 commit comments

Comments
 (0)