Skip to content

Commit 4bbddbb

Browse files
authored
restructure for pip install
1 parent 2f019b7 commit 4bbddbb

22 files changed

Lines changed: 25 additions & 6 deletions

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
recursive-include torchsparsegradutils/tests *.yaml

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Things that are missing may be listed as [issues](https://github.com/cai4cai/tor
2424
## Installation
2525
The provided package can be installed using:
2626

27-
`pip install torchsparsegradutils` (TODO)
27+
`pip install torchsparsegradutils`
2828

2929
or
3030

setup.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,18 @@ def readme():
88

99
setuptools.setup(
1010
name="torchsparsegradutils",
11-
version="0.0.3",
11+
version="0.1.0",
1212
description="A collection of utility functions to work with PyTorch sparse tensors",
1313
long_description=readme(),
1414
long_description_content_type="text/markdown",
1515
classifiers=[
1616
"Operating System :: OS Independent",
1717
"License :: OSI Approved :: Apache Software License",
18-
"Programming Language :: Python :: 3",
18+
"Programming Language :: Python :: 3.8",
19+
"Programming Language :: Python :: 3.9",
20+
"Programming Language :: Python :: 3.10",
1921
],
22+
python_requires=">=3.8, <3.11",
2023
keywords="sparse torch utility",
2124
url="https://github.com/cai4cai/torchsparsegradutils",
2225
author="CAI4CAI research group",
@@ -26,8 +29,12 @@ def readme():
2629
install_requires=[
2730
"torch>=1.13",
2831
],
32+
setup_requires=["pytest-runner"],
33+
tests_require=["pytest"],
34+
test_suite="tests",
2935
extras_require={
3036
"extras": ["jax", "cupy"],
3137
},
3238
zip_safe=False,
39+
include_package_data=True,
3340
)

tests/test_bicgstab.py renamed to torchsparsegradutils/tests/test_bicgstab.py

File renamed without changes.

tests/test_cupy_bindings.py renamed to torchsparsegradutils/tests/test_cupy_bindings.py

File renamed without changes.

tests/test_cupy_sparse_solve.py renamed to torchsparsegradutils/tests/test_cupy_sparse_solve.py

File renamed without changes.

tests/test_distributions.py renamed to torchsparsegradutils/tests/test_distributions.py

File renamed without changes.

tests/test_encoders.py renamed to torchsparsegradutils/tests/test_encoders.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import yaml
55
import os
6+
from pathlib import Path
67
from ast import literal_eval
78

89
from functools import reduce
@@ -152,7 +153,13 @@ def test_trim(tensor_nd, offsets, expected_output_slice):
152153

153154
# Test neighbourgood coordinate generation:
154155

155-
with open("tests/test_params/xyz_coords.yaml") as f: # load test cases from file
156+
# Get the absolute path to the directory of the current module:
157+
current_dir = Path(os.path.abspath(os.path.dirname(__file__)))
158+
159+
# Construct the path to the yaml file:
160+
yaml_file = current_dir / "test_params" / "xyz_coords.yaml"
161+
162+
with open(yaml_file) as f: # load test cases from file
156163
coord_test_cases = yaml.safe_load(f)
157164

158165
params = [tuple(tc.values())[1:] for tc in coord_test_cases] # Skip the first value, which is 'id'
@@ -172,7 +179,9 @@ def test_gen_coords(radius, expected_coords):
172179

173180
# Test neighbourgood offset generation:
174181

175-
with open("tests/test_params/czyx_shifts.yaml") as f: # load test cases from file
182+
yaml_file = current_dir / "test_params" / "czyx_shifts.yaml"
183+
184+
with open(yaml_file) as f: # load test cases from file
176185
shift_test_cases = yaml.safe_load(f)
177186

178187
params = [tuple(tc.values())[1:] for tc in shift_test_cases] # Skip the first value, which is 'id'
@@ -252,7 +261,9 @@ def test_pairwise_coo_indices_unique(radius, volume_shape, diag, upper, channel_
252261

253262
# Test the indices generated are as expected for a simple (3, 2, 2, 2) volume:
254263

255-
with open("tests/test_params/pairwise_coo_indices.yaml") as f: # load test cases from file to avoid a massive mess
264+
yaml_file = current_dir / "test_params" / "pairwise_coo_indices.yaml"
265+
266+
with open(yaml_file) as f: # load test cases from file to avoid a massive mess
256267
data = yaml.safe_load(f)
257268

258269
test_cases = data["test_cases"]

tests/test_jax_bindings.py renamed to torchsparsegradutils/tests/test_jax_bindings.py

File renamed without changes.

0 commit comments

Comments
 (0)