@@ -581,39 +581,55 @@ def ti_to_python(
581581 # Get metadata
582582 ti_data_meta = _get_ti_metadata (value )
583583
584- # Extract value as a whole.
585- # Note that this is usually much faster than using a custom kernel to extract a slice.
586- # The implementation is based on `taichi.lang.(ScalarField | MatrixField).to_torch`.
587- is_metal = gs .device .type == "mps"
588- out_dtype = _to_torch_type_fast (ti_data_meta .dtype ) if to_torch else _to_numpy_type_fast (ti_data_meta .dtype )
589- data_type = type (value )
590- if issubclass (data_type , (ti .ScalarField , ti .ScalarNdarray )):
591- if to_torch :
592- out = torch .zeros (ti_data_meta .shape , dtype = out_dtype , device = "cpu" if is_metal else gs .device )
593- else :
594- out = np .zeros (ti_data_meta .shape , dtype = out_dtype )
595- TO_EXT_ARR_FAST_MAP [data_type ](value , out )
596- elif issubclass (data_type , ti .MatrixField ):
597- as_vector = value .m == 1
598- shape_ext = (value .n ,) if as_vector else (value .n , value .m )
599- if to_torch :
600- out = torch .empty (ti_data_meta .shape + shape_ext , dtype = out_dtype , device = "cpu" if is_metal else gs .device )
601- else :
602- out = np .zeros (ti_data_meta .shape + shape_ext , dtype = out_dtype )
603- TO_EXT_ARR_FAST_MAP [data_type ](value , out , as_vector )
604- elif issubclass (data_type , (ti .VectorNdarray , ti .MatrixNdarray )):
605- layout_is_aos = 1
606- as_vector = issubclass (data_type , ti .VectorNdarray )
607- shape_ext = (value .n ,) if as_vector else (value .n , value .m )
608- if to_torch :
609- out = torch .empty (ti_data_meta .shape + shape_ext , dtype = out_dtype , device = "cpu" if is_metal else gs .device )
584+ use_zerocopy = gs .use_zerocopy
585+ if gs .use_zerocopy :
586+ # Leverage zero-copy if enabled
587+ try :
588+ out = value ._tc
589+ if not to_torch :
590+ out = tensor_to_array (out )
591+ except AttributeError :
592+ gs .logger .debug ("Zezo-copy memory sharing not available for this tensor. Falling back to copy mode." )
593+ use_zerocopy = False
594+
595+ if not use_zerocopy :
596+ # Extract value as a whole.
597+ # Note that this is usually much faster than using a custom kernel to extract a slice.
598+ # The implementation is based on `taichi.lang.(ScalarField | MatrixField).to_torch`.
599+ is_metal = gs .device .type == "mps"
600+ out_dtype = _to_torch_type_fast (ti_data_meta .dtype ) if to_torch else _to_numpy_type_fast (ti_data_meta .dtype )
601+ data_type = type (value )
602+ if issubclass (data_type , (ti .ScalarField , ti .ScalarNdarray )):
603+ if to_torch :
604+ out = torch .zeros (ti_data_meta .shape , dtype = out_dtype , device = "cpu" if is_metal else gs .device )
605+ else :
606+ out = np .zeros (ti_data_meta .shape , dtype = out_dtype )
607+ TO_EXT_ARR_FAST_MAP [data_type ](value , out )
608+ elif issubclass (data_type , ti .MatrixField ):
609+ as_vector = value .m == 1
610+ shape_ext = (value .n ,) if as_vector else (value .n , value .m )
611+ if to_torch :
612+ out = torch .empty (
613+ ti_data_meta .shape + shape_ext , dtype = out_dtype , device = "cpu" if is_metal else gs .device
614+ )
615+ else :
616+ out = np .zeros (ti_data_meta .shape + shape_ext , dtype = out_dtype )
617+ TO_EXT_ARR_FAST_MAP [data_type ](value , out , as_vector )
618+ elif issubclass (data_type , (ti .VectorNdarray , ti .MatrixNdarray )):
619+ layout_is_aos = 1
620+ as_vector = issubclass (data_type , ti .VectorNdarray )
621+ shape_ext = (value .n ,) if as_vector else (value .n , value .m )
622+ if to_torch :
623+ out = torch .empty (
624+ ti_data_meta .shape + shape_ext , dtype = out_dtype , device = "cpu" if is_metal else gs .device
625+ )
626+ else :
627+ out = np .zeros (ti_data_meta .shape + shape_ext , dtype = out_dtype )
628+ TO_EXT_ARR_FAST_MAP [ti .MatrixNdarray ](value , out , layout_is_aos , as_vector )
610629 else :
611- out = np .zeros (ti_data_meta .shape + shape_ext , dtype = out_dtype )
612- TO_EXT_ARR_FAST_MAP [ti .MatrixNdarray ](value , out , layout_is_aos , as_vector )
613- else :
614- gs .raise_exception (f"Unsupported type '{ type (value )} '." )
615- if to_torch and is_metal :
616- out = out .to (gs .device )
630+ gs .raise_exception (f"Unsupported type '{ type (value )} '." )
631+ if to_torch and is_metal :
632+ out = out .to (gs .device )
617633
618634 # Transpose if necessary and requested.
619635 # Note that it is worth transposing here before slicing, as it preserve row-major memory alignment in case of
@@ -645,7 +661,7 @@ def extract_slice(
645661 """
646662 # Make sure that the user-arguments are valid if requested
647663 if not unsafe :
648- if value .ndim == 1 and col_mask is not None :
664+ if col_mask is not None and value .ndim == 1 :
649665 gs .raise_exception ("Cannot specify column mask for 1D tensor." )
650666 for i , mask in enumerate ((row_mask , col_mask )):
651667 if mask is None or isinstance (mask , slice ):
@@ -739,6 +755,8 @@ def ti_to_torch(
739755 unsafe (bool, optional): Whether to skip validity check of the masks.
740756 """
741757 tensor = ti_to_python (value , transpose , to_torch = True )
758+ if row_mask is None and col_mask is None :
759+ return tensor
742760
743761 ti_data_meta = _get_ti_metadata (value )
744762 if len (ti_data_meta .shape ) < 2 :
@@ -771,6 +789,8 @@ def ti_to_numpy(
771789 unsafe (bool, optional): Whether to skip validity check of the masks.
772790 """
773791 tensor = ti_to_python (value , transpose , to_torch = False )
792+ if row_mask is None and col_mask is None :
793+ return tensor
774794
775795 ti_data_meta = _get_ti_metadata (value )
776796 if len (ti_data_meta .shape ) < 2 :
0 commit comments