@@ -276,6 +276,95 @@ def render_attribute(self, pose, vertices, faces, ranges, attributes):
276276 return image [0 ], zs [0 ]
277277
278278
279+ def render_attribute_normal_many (self , poses , vertices , faces , ranges , attributes ):
280+ """
281+ Render many scenes to an image by rasterizing and then interpolating attributes.
282+
283+ Parameters:
284+ poses: float array, shape (num_scenes, num_objectsß, 4, 4)
285+ Object pose matrix.
286+ vertices: float array, shape (num_vertices, 3)
287+ Vertex position matrix.
288+ faces: int array, shape (num_triangles, 3)
289+ Faces Triangle matrix. The integers ßcorrespond to rows in the vertices matrix.
290+ ranges: int array, shape (num_objects, 2)
291+ Ranges matrix with the 2 elements specify start indices and counts into faces.
292+ attributes: float array, shape (num_vertices, num_attributes)
293+ Attributes corresponding to the vertices
294+
295+ Outputs:
296+ image: float array, shape (num_scenes, height, width, num_attributes)
297+ At each pixel the value is the barycentric interpolation of the attributes corresponding to the
298+ 3 vertices of the triangle with which the pixel's ray intersected. If the pixel's ray does not intersect
299+ any triangle the value at that pixel will be 0s.
300+ zs: float array, shape (num_scenes, height, width)
301+ Depth of the intersection point. Zero if the pixel ray doesn't collide a triangle.
302+ norm_im: approximate surface normal image (num_scenes, height, width, 3)
303+ """
304+ uvs , object_ids , triangle_ids , zs = self .rasterize_many (
305+ poses , vertices , faces , ranges
306+ )
307+ mask = object_ids > 0
308+
309+ interpolated_values = self .interpolate_many (
310+ attributes , uvs , triangle_ids , faces
311+ )
312+ image = interpolated_values * mask [..., None ]
313+
314+ def apply_pose (pose , points ):
315+ return pose .apply (points )
316+
317+ pose_apply_map = jax .vmap (apply_pose , (0 ,None ))
318+ new_vertices = pose_apply_map (poses , vertices [faces ])
319+
320+ def normal_vec (x ,y ,z ):
321+ vec = jnp .cross (y - x , z - x )
322+ norm_vec = vec / jnp .linalg .norm (vec )
323+ return norm_vec
324+
325+ normal_vec_vmap = jax .vmap (jax .vmap (normal_vec , (0 ,0 ,0 )))
326+ nvecs = normal_vec_vmap (new_vertices [...,0 ,:], new_vertices [...,1 ,:], new_vertices [...,2 ,:])
327+ norm_vecs = jnp .concatenate ((jnp .zeros ((len (nvecs ),1 ,3 )), nvecs ),axis = 1 )
328+
329+ def indexer (transformed_normals , triangle_ids ):
330+ return transformed_normals [triangle_ids ]
331+
332+ index_map = jax .vmap (indexer , (0 ,0 ))
333+ norm_im = index_map (norm_vecs , triangle_ids )
334+
335+ return image , zs , norm_im
336+
337+ def render_attribute_normal (self , pose , vertices , faces , ranges , attributes ):
338+ """
339+ Render a single scenes to an image by rasterizing and then interpolating attributes.
340+
341+ Parameters:
342+ poses: float array, shape (num_objects, 4, 4)
343+ Object pose matrix.
344+ vertices: float array, shape (num_vertices, 3)
345+ Vertex position matrix.
346+ faces: int array, shape (num_triangles, 3)
347+ Faces Triangle matrix. The integers correspond to rows in the vertices matrix.
348+ ranges: int array, shape (num_objects, 2)
349+ Ranges matrix with the 2 elements specify start indices and counts into faces.
350+ attributes: float array, shape (num_vertices, num_attributes)
351+ Attributes corresponding to the vertices
352+
353+ Outputs:
354+ image: float array, shape (height, width, num_attributes)
355+ At each pixel the value is the barycentric interpolation of the attributes corresponding to the
356+ 3 vertices of the triangle with which the pixel's ray intersected. If the pixel's ray does not intersect
357+ any triangle the value at that pixel will be 0s.
358+ zs: float array, shape (height, width)
359+ Depth of the intersection point. Zero if the pixel ray doesn't collide a triangle.
360+ norm_im: approximate surface normal image (height, width, 3)
361+ """
362+ image , zs , norm_im = self .render_attribute_normal_many (
363+ pose [None , ...], vertices , faces , ranges , attributes
364+ )
365+ return image [0 ], zs [0 ], norm_im [0 ]
366+
367+
279368# XLA array layout in memory
280369def default_layouts (* shapes ):
281370 return [range (len (shape ) - 1 , - 1 , - 1 ) for shape in shapes ]
0 commit comments