Skip to content

Commit 4542cbf

Browse files
authored
Merge pull request #27 from eyeseast/rowid-bug
Fix error with rowid tables
2 parents 5bd0096 + 5d3f970 commit 4542cbf

File tree

3 files changed

+22
-16
lines changed

3 files changed

+22
-16
lines changed

Diff for: geocode_sqlite/cli.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,13 @@ def geocode(ctx, geocoder):
159159
)
160160

161161
with click.progressbar(gen, length=count, label=f"{count} rows") as bar:
162-
for row, success in bar:
163-
pks = [row[pk] for pk in table.pks]
162+
for pk, row, success in bar:
163+
# pks = [row[pk] for pk in table.pks]
164164
if success:
165-
table.update(pks, row)
165+
table.update(pk, row)
166166
done += 1
167167
else:
168-
errors.append(pks)
168+
errors.append(pk)
169169

170170
click.echo("Geocoded {} rows".format(done))
171171
if errors:

Diff for: geocode_sqlite/utils.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ def geocode_table(
6363

6464
count = 0
6565
log.info(f"Geocoding {todo} rows from {table.name}")
66-
for row in rows:
66+
for pk, row in rows:
6767
result = geocode_row(geocode, query_template, row, **kwargs)
6868
if result:
69-
pks = [row[pk] for pk in table.pks]
69+
# pks = [row[pk] for pk in table.pks]
7070
table.update(
71-
pks,
71+
pk,
7272
{
7373
latitude_column: result.latitude,
7474
longitude_column: result.longitude,
@@ -95,20 +95,21 @@ def geocode_list(
9595
"""
9696
Geocode an arbitrary list of rows, returning a generator.
9797
This does not query or save geocoded results into a table.
98-
If geocoding succeeds, it will yield a two-tuple:
98+
If geocoding succeeds, it will yield a three-tuple:
99+
- the primary key of the row (rowid or actual PK)
99100
- the row with latitude and longitude columns set
100101
- and True
101102
102103
If geocoding fails, it will yield the original row and False.
103104
"""
104-
for row in rows:
105+
for pk, row in rows:
105106
result = geocode_row(geocode, query_template, row, **kwargs)
106107
if result:
107108
row[longitude_column] = result.longitude
108109
row[latitude_column] = result.latitude
109110
row["geocoder"] = get_geocoder_class(geocode)
110111

111-
yield row, bool(result)
112+
yield pk, row, bool(result)
112113

113114

114115
def geocode_row(geocode, query_template, row, **kwargs):
@@ -135,7 +136,9 @@ def select_ungeocoded(
135136
if count:
136137
count = count[0]
137138

138-
rows = table.rows_where(f"{latitude_column} IS NULL OR {longitude_column} IS NULL")
139+
rows = table.pks_and_rows_where(
140+
f"{latitude_column} IS NULL OR {longitude_column} IS NULL"
141+
)
139142

140143
return rows, count
141144

Diff for: tests/test_geocode_sqlite.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,14 @@
2424

2525

2626
@pytest.fixture
27-
def db():
27+
def db(request):
2828
db = Database(DB_PATH)
2929
table = db[TABLE_NAME]
3030

31+
pk = getattr(request, "param", "id")
32+
3133
# load csv data, which will be geocoded
32-
table.insert_all(csv.DictReader(open(CSV_DATA)), alter=True, pk="id")
34+
table.insert_all(csv.DictReader(open(CSV_DATA)), alter=True, pk=pk)
3335

3436
# load our geojson data, for our fake geocoder
3537
fc = json.load(open(GEOJSON_DATA))
@@ -56,6 +58,7 @@ def test_version():
5658
assert result.output.startswith("cli, version ")
5759

5860

61+
@pytest.mark.parametrize("db", ["id", None], indirect=True)
5962
def test_cli_geocode_table(db, geocoder):
6063
runner = CliRunner()
6164
table = db[TABLE_NAME]
@@ -243,7 +246,7 @@ def test_resume_table(db, geocoder):
243246
def test_geocode_list(db, geocoder):
244247
table = db[TABLE_NAME]
245248

246-
utah = list(table.rows_where('"state" = "UT"'))
249+
utah = list(table.pks_and_rows_where('"state" = "UT"'))
247250
assert len(utah) == 10
248251

249252
gen = geocode_list(utah, geocoder.geocode, "{id}")
@@ -255,9 +258,9 @@ def test_geocode_list(db, geocoder):
255258
# geocode the whole table, to cheeck results
256259
geocode_table(db, TABLE_NAME, geocoder, "{id}")
257260

258-
for row, success in done:
261+
for pk, row, success in done:
259262
assert success
260-
assert row == table.get(row["id"])
263+
assert row == table.get(pk)
261264

262265

263266
def test_label_results(db, geocoder):

0 commit comments

Comments
 (0)