@@ -1159,20 +1159,31 @@ def _np_z_up_to_R(z, up=None, out=None):
11591159 if z_norm_i > gs .EPS :
11601160 z /= z_norm_i
11611161 else :
1162- z [:] = 0.0 , 1.0 , 0.0
1162+ if up is None or abs (up [i ][1 ]) < 0.5 :
1163+ z [:] = 0.0 , 1.0 , 0.0
1164+ else :
1165+ z [:] = 0.0 , 0.0 , 1.0
11631166
11641167 if up is not None :
11651168 x [:] = np .cross (up [i ], z )
11661169 else :
1167- x [0 ] = z [1 ]
1168- x [1 ] = - z [0 ]
1169- x [2 ] = 0.0
1170+ if abs (z [2 ]) < 1.0 - gs .EPS :
1171+ # up = (0.0, 0.0, 1.0)
1172+ x [0 ] = z [1 ]
1173+ x [1 ] = - z [0 ]
1174+ x [2 ] = 0.0
1175+ else :
1176+ # up = (0.0, 1.0, 0.0)
1177+ x [0 ] = z [2 ]
1178+ x [1 ] = 0.0
1179+ x [2 ] = - z [0 ]
11701180
11711181 x_norm = np .linalg .norm (x )
11721182 if x_norm > gs .EPS :
11731183 x /= x_norm
11741184 y [:] = np .cross (z , x )
11751185 else :
1186+ # This would only occurs if the user specified non-zero colinear z and up
11761187 R [:] = np .eye (3 , dtype = R .dtype )
11771188
11781189 return out_
@@ -1199,16 +1210,22 @@ def _tc_z_up_to_R(z, up=None, out=None):
11991210 # Handle zero norm cases
12001211 zero_mask = z_norm [..., 0 ] < gs .EPS
12011212 if zero_mask .any ():
1202- z [zero_mask ] = torch .tensor ((0.0 , 1.0 , 0.0 ), device = z .device , dtype = z .dtype )
1213+ if up is None :
1214+ z [zero_mask ] = torch .tensor ((0.0 , 1.0 , 0.0 ), device = z .device , dtype = z .dtype )
1215+ else :
1216+ up_mask = up [..., 1 ].abs () < 0.5
1217+ z [zero_mask & up_mask ] = torch .tensor ((0.0 , 1.0 , 0.0 ), device = z .device , dtype = z .dtype )
1218+ z [zero_mask & ~ up_mask ] = torch .tensor ((0.0 , 0.0 , 1.0 ), device = z .device , dtype = z .dtype )
12031219
12041220 # Compute x vectors (first column)
12051221 if up is not None :
12061222 x [:] = torch .cross (up , z , dim = - 1 )
12071223 else :
1208- # Default up vector case
1209- x [..., 0 ] = z [..., 1 ]
1210- x [..., 1 ] = - z [..., 0 ]
1211- x [..., 2 ] = 0.0
1224+ up_mask = z [..., 2 ].abs () < 1.0 - gs .EPS
1225+ _zero = torch .tensor (0.0 , device = z .device , dtype = z .dtype )
1226+ torch .where (up_mask , z [..., 1 ], z [..., 2 ], out = x [..., 0 ])
1227+ torch .where (up_mask , - z [..., 0 ], _zero , out = x [..., 1 ])
1228+ torch .where (up_mask , _zero , - z [..., 0 ], out = x [..., 2 ])
12121229
12131230 # Normalize x vectors
12141231 x_norm = torch .norm (x , dim = - 1 , keepdim = True )
0 commit comments