Skip to content

Commit f52a5ba

Browse files
authored
Merge pull request #15 from eyeseast/label-results
Label results with geocoder used
2 parents 9f819bb + 4ad98f0 commit f52a5ba

File tree

3 files changed

+48
-14
lines changed

3 files changed

+48
-14
lines changed

Diff for: geocode_sqlite/cli.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ def geocode(ctx, geocoder):
8585
click.echo(f"Adding column: {longitude}")
8686
table.add_column(longitude, float)
8787

88+
if "geocoder" not in table.columns_dict:
89+
click.echo("Adding geocoder column")
90+
table.add_column("geocoder", str)
91+
8892
# always use a rate limiter, even if delay is zero
8993
geocode = RateLimiter(geocoder.geocode, min_delay_seconds=delay)
9094

@@ -93,13 +97,12 @@ def geocode(ctx, geocoder):
9397
)
9498

9599
done = 0
100+
errors = []
96101

97102
gen = geocode_list(
98103
rows, geocode, location, latitude_column=latitude, longitude_column=longitude
99104
)
100105

101-
errors = []
102-
103106
with click.progressbar(gen, length=count, label=f"{count} rows") as bar:
104107
for row, success in bar:
105108
pks = [row[pk] for pk in table.pks]

Diff for: geocode_sqlite/utils.py

+32-12
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
This is the Python interface
33
"""
44
import logging
5-
from geopy import geocoders
65
from geopy.extra.rate_limiter import RateLimiter
76
from sqlite_utils import Database
87

@@ -39,31 +38,41 @@ def geocode_table(
3938
table = db[table_name]
4039

4140
if latitude_column not in table.columns_dict:
41+
log.info(f"Adding latitude column: {latitude_column}")
4242
table.add_column(latitude_column, float)
4343

4444
if longitude_column not in table.columns_dict:
45+
log.info(f"Adding longitude column: {longitude_column}")
4546
table.add_column(longitude_column, float)
4647

47-
if force:
48-
rows = table.rows
49-
else:
50-
rows = table.rows_where(
51-
f"{latitude_column} IS NULL OR {longitude_column} IS NULL"
52-
)
48+
if "geocoder" not in table.columns_dict:
49+
log.info("Adding geocoder column")
50+
table.add_column("geocoder", str)
51+
52+
rows, todo = select_ungeocoded(
53+
db,
54+
table,
55+
latitude_column=latitude_column,
56+
longitude_column=longitude_column,
57+
force=force,
58+
)
5359

54-
if delay:
55-
geocode = RateLimiter(geocoder.geocode, min_delay_seconds=delay)
56-
else:
57-
geocode = geocoder.geocode
60+
# always use a rate limiter, even with no delay
61+
geocode = RateLimiter(geocoder.geocode, min_delay_seconds=delay)
5862

5963
count = 0
64+
log.info(f"Geocoding {todo} rows from {table.name}")
6065
for row in rows:
6166
result = geocode_row(geocode, query_template, row)
6267
if result:
6368
pks = [row[pk] for pk in table.pks]
6469
table.update(
6570
pks,
66-
{latitude_column: result.latitude, longitude_column: result.longitude},
71+
{
72+
latitude_column: result.latitude,
73+
longitude_column: result.longitude,
74+
"geocoder": geocoder.__class__.__name__,
75+
},
6776
)
6877
count += 1
6978

@@ -95,6 +104,7 @@ def geocode_list(
95104
if result:
96105
row[longitude_column] = result.longitude
97106
row[latitude_column] = result.latitude
107+
row["geocoder"] = get_geocoder_class(geocode)
98108

99109
yield row, bool(result)
100110

@@ -126,3 +136,13 @@ def select_ungeocoded(
126136
rows = table.rows_where(f"{latitude_column} IS NULL OR {longitude_column} IS NULL")
127137

128138
return rows, count
139+
140+
141+
def get_geocoder_class(geocode):
142+
"Walk back up to the original geocoder class"
143+
144+
if isinstance(geocode, RateLimiter):
145+
return geocode.func.__self__.__class__.__name__
146+
147+
# unwrapped function
148+
return geocode.__self__.__class__.__name__

Diff for: tests/test_geocode_sqlite.py

+11
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,14 @@ def test_geocode_list(db, geocoder):
231231
for row, success in done:
232232
assert success
233233
assert row == table.get(row["id"])
234+
235+
236+
def test_label_results(db, geocoder):
237+
table = db[TABLE_NAME]
238+
239+
# geocode it once
240+
geocode_table(db, TABLE_NAME, geocoder, "{id}")
241+
242+
for row in table.rows:
243+
assert "geocoder" in row
244+
assert row["geocoder"] == geocoder.__class__.__name__

0 commit comments

Comments
 (0)