Skip to content

Commit

Permalink
doc v0.5 updates
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Jan 5, 2021
1 parent f8ae4ad commit 5f36307
Show file tree
Hide file tree
Showing 16 changed files with 333 additions and 219 deletions.
4 changes: 2 additions & 2 deletions MinkowskiEngine/MinkowskiCoordinateManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def set_gpu_allocator(backend: GPUMemoryAllocatorType):
than allocating GPU directly using raw CUDA calls.
By default, the Minkowski Engine uses
:attr:`ME.MemoryManagerBackend.PYTORCH` for memory management.
:attr:`ME.GPUMemoryAllocatorType.PYTORCH` for memory management.
Example::
Expand All @@ -84,7 +84,7 @@ def set_gpu_allocator(backend: GPUMemoryAllocatorType):
"""
assert isinstance(
backend, GPUMemoryAllocatorType
), f"Input must be an instance of MemoryManagerBackend not {backend}"
), f"Input must be an instance of GPUMemoryAllocatorType not {backend}"
global _allocator_type
_allocator_type = backend

Expand Down
11 changes: 10 additions & 1 deletion MinkowskiEngine/MinkowskiTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,26 @@ def set_sparse_tensor_operation_mode(operation_mode: SparseTensorOperationMode):
_sparse_tensor_operation_mode = operation_mode


def sparse_tensor_operation_mode():
def sparse_tensor_operation_mode() -> SparseTensorOperationMode:
r"""Return the current sparse tensor operation mode.
"""
global _sparse_tensor_operation_mode
return copy.deepcopy(_sparse_tensor_operation_mode)


def global_coordinate_manager():
r"""Return the current global coordinate manager
"""
global _global_coordinate_manager
return _global_coordinate_manager


def set_global_coordinate_manager(coordinate_manager):
r"""Set the global coordinate manager.
:attr:`MinkowskiEngine.CoordinateManager` The coordinate manager which will
be set to the global coordinate manager.
"""
global _global_coordinate_manager
_global_coordinate_manager = coordinate_manager

Expand Down
2 changes: 2 additions & 0 deletions MinkowskiEngine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
SparseTensorQuantizationMode,
set_sparse_tensor_operation_mode,
sparse_tensor_operation_mode,
global_coordinate_manager,
set_global_coordinate_manager,
clear_global_coordinate_manager,
)

Expand Down
2 changes: 1 addition & 1 deletion MinkowskiEngine/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@
# of the code.
from .quantization import sparse_quantize, ravel_hash_vec, fnv_hash_vec, unique_coordinate_map
from .collation import SparseCollation, batched_coordinates, sparse_collate, batch_sparse_collate
from .coords import get_coords_map
# from .coords import get_coords_map
from .init import kaiming_normal_
2 changes: 1 addition & 1 deletion docs/benchmark.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Benchmark

We report the feed forward and backward pass time of a convolution layer, and a small U-network. Note that the kernel map can be reused for other layers with the same tensor-stride, stride, and kernel offsets, thus the time reported in this page can be amortized across all layers used in a large nueral network.
We report the feed forward and backward pass time of a convolution layer, and a small U-network for v0.4.3. Note that the kernel map can be reused for other layers with the same tensor-stride, stride, and kernel offsets, thus the time reported in this page can be amortized across all layers used in a large nueral network.

We use a Titan X for the experiments.

Expand Down
20 changes: 10 additions & 10 deletions docs/coords.rst
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
Coordinate Management
=====================

CoordsKey
---------
CoordinateMapKey
----------------

.. autoclass:: MinkowskiEngine.CoordsKey
.. autoclass:: MinkowskiEngine.CoordinateMapKey
:members:
:undoc-members:
:exclude-members: __repr__

.. automethod:: __init__


CoordsManager
-------------
CoordinateManager
-----------------

.. autoclass:: MinkowskiEngine.CoordsManager
.. autoclass:: MinkowskiEngine.CoordinateManager
:members:
:undoc-members:
:exclude-members: __repr__

.. automethod:: __init__


Coordinate GPU Memory Manager
-----------------------------
GPU Memory Allocator
--------------------

.. autoclass:: MinkowskiEngine.MemoryManagerBackend
.. autoclass:: MinkowskiEngine.GPUMemoryAllocatorType
:members:

.. autofunction:: MinkowskiEngine.MinkowskiCoords.set_memory_manager_backend
.. autofunction:: MinkowskiEngine.set_gpu_allocator
4 changes: 2 additions & 2 deletions docs/demo/interop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ a min-batch.
kernel_size=3,
stride=2,
dilation=1,
has_bias=False,
bias=False,
dimension=D), ME.MinkowskiBatchNorm(64), ME.MinkowskiReLU(),
ME.MinkowskiConvolution(
in_channels=64,
Expand Down Expand Up @@ -66,7 +66,7 @@ accessing the features of the sparse tensor
# Get new data
coords, feat, label = data_loader()
input = ME.SparseTensor(feat, coords=coords).to(device)
input = ME.SparseTensor(features=feat, coordinates=coords, device=device)
label = label.to(device)
# Forward
Expand Down
Loading

0 comments on commit 5f36307

Please sign in to comment.