Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions cpp/include/segmentation.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once
include "xtensor/"

namespace libneutorch{

template <class T, xt::layout_type L>
auto label2affinity(pytensor<T, 3, L> label){

}

} // namespace libneutorch
20 changes: 20 additions & 0 deletions cpp/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include <pybind11/pybind11.h>


PYBIND11_MODULE(libneutorch, m) {
m.doc() = R"pbdoc(
libneutorch
-----------------------
.. currentmodule:: libneutorch
.. autosummary::
:toctree: _generate
warp3d
)pbdoc";

m.def("warp3d", &warp3d, R"pbdoc(
Warp 3d image

used for patch augmentation.
)pbdoc");

}
1 change: 1 addition & 0 deletions neutorch/__version__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.0.2"
16 changes: 15 additions & 1 deletion neutorch/dataset/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self, image: np.ndarray, label: np.ndarray,
assert image.shape == label.shape
self.image = image
self.label = label
self.target = None
self.delayed_shrink_size = delayed_shrink_size

def accumulate_delayed_shrink_size(self, shrink_size: tuple):
Expand Down Expand Up @@ -61,4 +62,17 @@ def shape(self):
@property
@lru_cache
def center(self):
return tuple(ps // 2 for ps in self.shape)
return tuple(ps // 2 for ps in self.shape)

@property
def target(self):
if self.target is None:
assert np.issubdtype(self.label.dtype, np.floating)
return self.label
else:
return self.target

@property
@lru_cache
def affinity_map(self):

4 changes: 4 additions & 0 deletions neutorch/dataset/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,3 +536,7 @@ def transform(self, patch: Patch):
strength=random.randint(1, self.max_strength),
radius = (patch.shape[-1] + patch.shape[-2]) // 4,
)

class LabelAsTarget()

class Label2Affinity
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ torchvision
toml
h5py
tensorboard
pybind11
scikit-image
opencv-python
scikit-image
45 changes: 44 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,56 @@
#!/usr/bin/env python
import os
import re

from setuptools import setup, find_packages

from pybind11.setup_helpers import Pybind11Extension, build_ext


PACKAGE_DIR = os.path.dirname(os.path.abspath(__file__))

with open(os.path.join(PACKAGE_DIR, 'requirements.txt')) as f:
requirements = f.read().splitlines()
requirements = [l for l in requirements if not l.startswith('#')]

with open("README.md", "r") as fh:
long_description = fh.read()

VERSIONFILE = os.path.join(PACKAGE_DIR, "neutorch/__version__.py")
verstrline = open(VERSIONFILE, "rt").read()
VSRE = r"^__version__ = ['\"]([^'\"]*)['\"]"
mo = re.search(VSRE, verstrline, re.M)
if mo:
version = mo.group(1)
else:
raise RuntimeError("Unable to find version string in %s." %
(VERSIONFILE, ))


ext_modules = [
Pybind11Extension("libneutorch",
["cpp/main.cpp"],
# Example: passing in the version to the compiled code
define_macros = [('VERSION_INFO', version)],
),
]

setup(
name='neutorch',
version='0.0.1',
version=version,
description='Deep Learning for brain connectomics using PyTorch',
long_description=long_description,
long_description_content_type="text/markdown",
author='Jingpeng Wu',
author_email='[email protected]',
url='https://github.com/brain-connectome/neutorch',
packages=find_packages(exclude=['bin']),
cmdclass={"build_ext": build_ext},
ext_modules=ext_modules,
install_requires=requirements,
tests_require=[
'pytest',
],
entry_points='''
[console_scripts]
neutrain-tbar=neutorch.cli.train_tbar:train
Expand All @@ -25,4 +66,6 @@
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
],
python_requires='>=3',
zip_safe=False,
)
4 changes: 0 additions & 4 deletions tests/train.py

This file was deleted.