diff --git a/cpp/include/segmentation.hpp b/cpp/include/segmentation.hpp new file mode 100644 index 0000000..00d901c --- /dev/null +++ b/cpp/include/segmentation.hpp @@ -0,0 +1,11 @@ +#pragma once +include "xtensor/" + +namespace libneutorch{ + +template +auto label2affinity(pytensor label){ + +} + +} // namespace libneutorch \ No newline at end of file diff --git a/cpp/main.cpp b/cpp/main.cpp new file mode 100644 index 0000000..02fbea7 --- /dev/null +++ b/cpp/main.cpp @@ -0,0 +1,20 @@ +#include + + +PYBIND11_MODULE(libneutorch, m) { + m.doc() = R"pbdoc( + libneutorch + ----------------------- + .. currentmodule:: libneutorch + .. autosummary:: + :toctree: _generate + warp3d + )pbdoc"; + + m.def("warp3d", &warp3d, R"pbdoc( + Warp 3d image + + used for patch augmentation. + )pbdoc"); + +} \ No newline at end of file diff --git a/neutorch/__version__.py b/neutorch/__version__.py new file mode 100644 index 0000000..3b93d0b --- /dev/null +++ b/neutorch/__version__.py @@ -0,0 +1 @@ +__version__ = "0.0.2" diff --git a/neutorch/dataset/patch.py b/neutorch/dataset/patch.py index 77e6e64..3a68e05 100644 --- a/neutorch/dataset/patch.py +++ b/neutorch/dataset/patch.py @@ -19,6 +19,7 @@ def __init__(self, image: np.ndarray, label: np.ndarray, assert image.shape == label.shape self.image = image self.label = label + self.target = None self.delayed_shrink_size = delayed_shrink_size def accumulate_delayed_shrink_size(self, shrink_size: tuple): @@ -61,4 +62,17 @@ def shape(self): @property @lru_cache def center(self): - return tuple(ps // 2 for ps in self.shape) \ No newline at end of file + return tuple(ps // 2 for ps in self.shape) + + @property + def target(self): + if self.target is None: + assert np.issubdtype(self.label.dtype, np.floating) + return self.label + else: + return self.target + + @property + @lru_cache + def affinity_map(self): + \ No newline at end of file diff --git a/neutorch/dataset/transform.py b/neutorch/dataset/transform.py index dec63c3..2dad369 100644 --- a/neutorch/dataset/transform.py +++ b/neutorch/dataset/transform.py @@ -536,3 +536,7 @@ def transform(self, patch: Patch): strength=random.randint(1, self.max_strength), radius = (patch.shape[-1] + patch.shape[-2]) // 4, ) + +class LabelAsTarget() + +class Label2Affinity \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 091e976..7e49152 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,6 @@ torchvision toml h5py tensorboard +pybind11 +scikit-image opencv-python -scikit-image \ No newline at end of file diff --git a/setup.py b/setup.py index a2411ac..7cfd179 100755 --- a/setup.py +++ b/setup.py @@ -1,15 +1,56 @@ #!/usr/bin/env python +import os +import re from setuptools import setup, find_packages +from pybind11.setup_helpers import Pybind11Extension, build_ext + + +PACKAGE_DIR = os.path.dirname(os.path.abspath(__file__)) + +with open(os.path.join(PACKAGE_DIR, 'requirements.txt')) as f: + requirements = f.read().splitlines() + requirements = [l for l in requirements if not l.startswith('#')] + +with open("README.md", "r") as fh: + long_description = fh.read() + +VERSIONFILE = os.path.join(PACKAGE_DIR, "neutorch/__version__.py") +verstrline = open(VERSIONFILE, "rt").read() +VSRE = r"^__version__ = ['\"]([^'\"]*)['\"]" +mo = re.search(VSRE, verstrline, re.M) +if mo: + version = mo.group(1) +else: + raise RuntimeError("Unable to find version string in %s." % + (VERSIONFILE, )) + + +ext_modules = [ + Pybind11Extension("libneutorch", + ["cpp/main.cpp"], + # Example: passing in the version to the compiled code + define_macros = [('VERSION_INFO', version)], + ), +] + setup( name='neutorch', - version='0.0.1', + version=version, description='Deep Learning for brain connectomics using PyTorch', + long_description=long_description, + long_description_content_type="text/markdown", author='Jingpeng Wu', author_email='jingpeng.wu@gmail.com', url='https://github.com/brain-connectome/neutorch', packages=find_packages(exclude=['bin']), + cmdclass={"build_ext": build_ext}, + ext_modules=ext_modules, + install_requires=requirements, + tests_require=[ + 'pytest', + ], entry_points=''' [console_scripts] neutrain-tbar=neutorch.cli.train_tbar:train @@ -25,4 +66,6 @@ "Operating System :: OS Independent", "Programming Language :: Python :: 3", ], + python_requires='>=3', + zip_safe=False, ) diff --git a/tests/train.py b/tests/train.py deleted file mode 100644 index 586a5fd..0000000 --- a/tests/train.py +++ /dev/null @@ -1,4 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -