Skip to content

Commit a5e5ce8

Browse files
committed
Initial commit
1 parent 0776551 commit a5e5ce8

File tree

4 files changed

+301
-0
lines changed

4 files changed

+301
-0
lines changed

.gitignore

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
*.egg-info/
24+
.installed.cfg
25+
*.egg
26+
MANIFEST
27+
28+
# PyInstaller
29+
# Usually these files are written by a python script from a template
30+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
31+
*.manifest
32+
*.spec
33+
34+
# Installer logs
35+
pip-log.txt
36+
pip-delete-this-directory.txt
37+
38+
# Unit test / coverage reports
39+
htmlcov/
40+
.tox/
41+
.coverage
42+
.coverage.*
43+
.cache
44+
nosetests.xml
45+
coverage.xml
46+
*.cover
47+
.hypothesis/
48+
.pytest_cache/
49+
50+
# Translations
51+
*.mo
52+
*.pot
53+
54+
# Django stuff:
55+
*.log
56+
local_settings.py
57+
db.sqlite3
58+
59+
# Flask stuff:
60+
instance/
61+
.webassets-cache
62+
63+
# Scrapy stuff:
64+
.scrapy
65+
66+
# Sphinx documentation
67+
docs/_build/
68+
69+
# PyBuilder
70+
target/
71+
72+
# Jupyter Notebook
73+
.ipynb_checkpoints
74+
75+
# pyenv
76+
.python-version
77+
78+
.DS_Store
79+
80+
# celery beat schedule file
81+
celerybeat-schedule
82+
83+
# SageMath parsed files
84+
*.sage.py
85+
86+
# Environments
87+
.env
88+
.venv
89+
env/
90+
venv/
91+
ENV/
92+
env.bak/
93+
venv.bak/
94+
95+
# Spyder project settings
96+
.spyderproject
97+
.spyproject
98+
99+
# Rope project settings
100+
.ropeproject
101+
102+
# mkdocs documentation
103+
/site
104+
105+
# mypy
106+
.mypy_cache/
107+
108+
.idea
109+
110+
_build/
111+
_templates/
112+
113+
# BUILD FILES
114+
*.zip

LICENSE

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Copyright 2019 Justin Shenk
2+
3+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4+
5+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6+
7+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

closest_pairs/main.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import math
2+
3+
import numpy as np
4+
from sklearn.decomposition import PCA
5+
6+
7+
def solution(array: np.ndarray, n=1):
8+
"""Solve the closest pairs problem.
9+
Note: Code borrowed from Andriy Lazorenko's Medium post.
10+
"""
11+
if array.shape[1] > 2:
12+
array = reduce_dims(array)
13+
14+
pairs = []
15+
mis = []
16+
17+
x, y = np.split(array, 2, axis=1)
18+
for i in range(n):
19+
a = list(zip(x, y)) # This produces list of tuples
20+
ax = sorted(a, key=lambda x: x[0]) # Presorting x-wise
21+
ay = sorted(a, key=lambda x: x[1]) # Presorting y-wise
22+
p1, p2, mi = closest_pair(ax, ay) # Recursive D&C function
23+
24+
p1_arg = np.where(array[:, 0] == p1[0])[0]
25+
p2_arg = np.where(array[:, 0] == p2[0])[0]
26+
27+
pairs.append((p1_arg, p2_arg))
28+
mis.append(mi)
29+
30+
# Remove point from array
31+
x = np.delete(x, p1_arg)
32+
y = np.delete(y, p1_arg)
33+
34+
return np.array(pairs), np.array(mis)
35+
36+
37+
def reduce_dims(array: np.ndarray):
38+
pca = PCA(n_components=2)
39+
array_2d = pca.fit(array).transform(array)
40+
41+
return array_2d
42+
43+
44+
def closest_pair(ax: np.ndarray, ay: np.ndarray):
45+
"""Find the closest pair.
46+
Note: Code borrowed from Andriy Lazorenko.
47+
"""
48+
ln_ax = len(ax) # It's quicker to assign variable
49+
if ln_ax <= 3:
50+
return brute(ax) # A call to bruteforce comparison
51+
mid = ln_ax // 2 # Division without remainder, need int
52+
Qx = ax[:mid] # Two-part split
53+
Rx = ax[mid:]
54+
# Determine midpoint on x-axis
55+
midpoint = ax[mid][0]
56+
Qy = list()
57+
Ry = list()
58+
for x in ay: # split ay into 2 arrays using midpoint
59+
if x[0] <= midpoint:
60+
Qy.append(x)
61+
else:
62+
Ry.append(x)
63+
# Call recursively both arrays after split
64+
(p1, q1, mi1) = closest_pair(Qx, Qy)
65+
(p2, q2, mi2) = closest_pair(Rx, Ry)
66+
# Determine smaller distance between points of 2 arrays
67+
if mi1 <= mi2:
68+
d = mi1
69+
mn = (p1, q1)
70+
else:
71+
d = mi2
72+
mn = (p2, q2)
73+
# Call function to account for points on the boundary
74+
(p3, q3, mi3) = closest_split_pair(ax, ay, d, mn)
75+
# Determine smallest distance for the array
76+
if d <= mi3:
77+
return mn[0], mn[1], d
78+
else:
79+
return p3, q3, mi3
80+
81+
82+
def brute(ax: np.ndarray):
83+
mi = dist(ax[0], ax[1])
84+
p1 = ax[0]
85+
p2 = ax[1]
86+
ln_ax = len(ax)
87+
if ln_ax == 2:
88+
return p1, p2, mi
89+
for i in range(ln_ax - 1):
90+
for j in range(i + 1, ln_ax):
91+
if i != 0 and j != 1:
92+
d = dist(ax[i], ax[j])
93+
if d < mi: # Update min_dist and points
94+
mi = d
95+
p1, p2 = ax[i], ax[j]
96+
return p1, p2, mi
97+
98+
99+
def dist(p1: np.ndarray, p2: np.ndarray):
100+
return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
101+
102+
103+
def closest_split_pair(p_x: list, p_y: np.ndarray, delta: float, best_pair: tuple):
104+
"""Find the closest_split_pair.
105+
Note: Code modified from Andriy Lazorenko.
106+
"""
107+
ln_x = len(p_x) # store length - quicker
108+
mx_x = p_x[ln_x // 2][0] # select midpoint on x-sorted array
109+
# Create a subarray of points not further than delta from
110+
# midpoint on x-sorted array
111+
s_y = [x for x in p_y if mx_x - delta <= x[0] <= mx_x + delta]
112+
best = delta # assign best value to delta
113+
ln_y = len(s_y) # store length of subarray for quickness
114+
for i in range(ln_y - 1):
115+
for j in range(i + 1, min(i + 7, ln_y)):
116+
p, q = s_y[i], s_y[j]
117+
dst = dist(p, q)
118+
if dst < best:
119+
best_pair = p, q
120+
best = dst
121+
return best_pair[0], best_pair[1], best

setup.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
4+
from setuptools import setup, find_packages
5+
6+
import os
7+
import re
8+
9+
here = os.path.abspath(os.path.dirname(__file__))
10+
11+
12+
def read(*parts):
13+
with open(os.path.join(here, *parts), "r", encoding="utf8") as fp:
14+
return fp.read()
15+
16+
17+
# Get package version
18+
def find_version(*file_paths):
19+
version_file = read(*file_paths)
20+
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M)
21+
if version_match:
22+
return version_match.group(1)
23+
raise RuntimeError("Unable to find version string.")
24+
25+
26+
requirements = ["numpy", "scikit-learn"]
27+
28+
this_dir = os.path.abspath(os.path.dirname(__file__))
29+
with open(os.path.join(this_dir, "README.md"), encoding="utf-8") as f:
30+
long_description = f.read()
31+
32+
setup(
33+
name="closest_pairs",
34+
version=find_version("closest_pairs", "__init__.py"),
35+
description="closest_pairs finds the closest pairs of points in a dataset",
36+
url="https://github.com/justinshenk/closest_pairs",
37+
author="Justin Shenk",
38+
author_email="shenkjustin@gmail.com",
39+
long_description=long_description,
40+
long_description_content_type="text/markdown",
41+
install_requires=requirements,
42+
classifiers=[
43+
"Intended Audience :: Developers",
44+
"License :: OSI Approved :: MIT License",
45+
"Intended Audience :: Education",
46+
"Intended Audience :: Science/Research",
47+
"Programming Language :: Python :: 3",
48+
"Programming Language :: Python :: 3.6",
49+
"Topic :: Scientific/Engineering :: Mathematics",
50+
"Topic :: Software Development :: Libraries :: Python Modules",
51+
"Topic :: Software Development :: Libraries",
52+
],
53+
python_requires=">= 3.6",
54+
packages=find_packages(),
55+
include_package_data=True,
56+
license="MIT",
57+
keywords="mathematics geometry",
58+
zip_safe=False,
59+
)

0 commit comments

Comments
 (0)