Skip to content

Commit 3d624e8

Browse files
authored
Support Expressions in HAVING and ORDER BY (#525)
* Sort and Having allow unnamed expressions * Recursive replace expressions * order by and having support expressions * Fix for Spark CI * Release 0.2.9 Co-authored-by: Joshua <joshua-oss@users.noreply.github.com> Co-authored-by: = <=>
1 parent 7e9f093 commit 3d624e8

10 files changed

Lines changed: 282 additions & 23 deletions

File tree

sql/HISTORY.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# SmartNoise SQL v0.2.9 Release Notes
2+
3+
* MySql and SQLite readers
4+
* HAVING and ORDER BY allow expresssions in addition to columns
5+
16
# SmartNoise SQL v0.2.8 Release Notes
27

38
* Fix bug where integer sums can overflow i32. All engines default to 64-bit integers now.

sql/VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.2.8
1+
0.2.9

sql/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "smartnoise-sql"
3-
version = "0.2.8"
3+
version = "0.2.9"
44
description = "Differentially Private SQL Queries"
55
authors = ["SmartNoise Team <smartnoise@opendp.org>"]
66
license = "MIT"

sql/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
setup_kwargs = {
2727
'name': 'smartnoise-sql',
28-
'version': '0.2.8',
28+
'version': '0.2.9',
2929
'description': 'Differentially Private SQL Queries',
3030
'long_description': '[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![Python](https://img.shields.io/badge/python-3.7%20%7C%203.8-blue)](https://www.python.org/)\n\n<a href="https://smartnoise.org"><img src="https://github.com/opendp/smartnoise-sdk/raw/main/images/SmartNoise/SVG/Logo%20Mark_grey.svg" align="left" height="65" vspace="8" hspace="18"></a>\n\n## SmartNoise SQL\n\nDifferentially private SQL queries. Tested with:\n* PostgreSQL\n* SQL Server\n* Spark\n* Pandas (SQLite)\n* PrestoDB\n* BigQuery\n\nSmartNoise is intended for scenarios where the analyst is trusted by the data owner. SmartNoise uses the [OpenDP](https://github.com/opendp/opendp) library of differential privacy algorithms.\n\n## Installation\n\n```\npip install smartnoise-sql\n```\n\n## Querying a Pandas DataFrame\n\nUse the `from_df` method to create a private reader that can issue queries against a pandas dataframe.\n\n```python\nimport snsql\nfrom snsql import Privacy\nimport pandas as pd\nprivacy = Privacy(epsilon=1.0, delta=0.01)\n\ncsv_path = \'PUMS.csv\'\nmeta_path = \'PUMS.yaml\'\n\npums = pd.read_csv(csv_path)\nreader = snsql.from_df(pums, privacy=privacy, metadata=meta_path)\n\nresult = reader.execute(\'SELECT sex, AVG(age) AS age FROM PUMS.PUMS GROUP BY sex\')\n```\n\n## Querying a SQL Database\n\nUse `from_connection` to wrap an existing database connection.\n\n```python\nimport snsql\nfrom snsql import Privacy\nimport psycopg2\n\nprivacy = Privacy(epsilon=1.0, delta=0.01)\nmeta_path = \'PUMS.yaml\'\n\npumsdb = psycopg2.connect(user=\'postgres\', host=\'localhost\', database=\'PUMS\')\nreader = snsql.from_connection(pumsdb, privacy=privacy, metadata=meta_path)\n\nresult = reader.execute(\'SELECT sex, AVG(age) AS age FROM PUMS.PUMS GROUP BY sex\')\n```\n\n## Querying a Spark DataFrame\n\nUse `from_connection` to wrap a spark session.\n\n```python\nimport pyspark\nfrom pyspark.sql import SparkSession\nspark = SparkSession.builder.getOrCreate()\nfrom snsql import *\n\npums = spark.read.load(...) # load a Spark DataFrame\npums.createOrReplaceTempView("PUMS_large")\n\nmetadata = \'PUMS_large.yaml\'\n\nprivate_reader = from_connection(\n spark, \n metadata=metadata, \n privacy=Privacy(epsilon=3.0, delta=1/1_000_000)\n)\nprivate_reader.reader.compare.search_path = ["PUMS"]\n\n\nres = private_reader.execute(\'SELECT COUNT(*) FROM PUMS_large\')\nres.show()\n```\n\n## Privacy Cost\n\nThe privacy parameters epsilon and delta are passed in to the private connection at instantiation time, and apply to each computed column during the life of the session. Privacy cost accrues indefinitely as new queries are executed, with the total accumulated privacy cost being available via the `spent` property of the connection\'s `odometer`:\n\n```python\nprivacy = Privacy(epsilon=0.1, delta=10e-7)\n\nreader = from_connection(conn, metadata=metadata, privacy=privacy)\nprint(reader.odometer.spent) # (0.0, 0.0)\n\nresult = reader.execute(\'SELECT COUNT(*) FROM PUMS.PUMS\')\nprint(reader.odometer.spent) # approximately (0.1, 10e-7)\n```\n\nThe privacy cost increases with the number of columns:\n\n```python\nreader = from_connection(conn, metadata=metadata, privacy=privacy)\nprint(reader.odometer.spent) # (0.0, 0.0)\n\nresult = reader.execute(\'SELECT AVG(age), AVG(income) FROM PUMS.PUMS\')\nprint(reader.odometer.spent) # approximately (0.4, 10e-6)\n```\n\nThe odometer is advanced immediately before the differentially private query result is returned to the caller. If the caller wishes to estimate the privacy cost of a query without running it, `get_privacy_cost` can be used:\n\n```python\nreader = from_connection(conn, metadata=metadata, privacy=privacy)\nprint(reader.odometer.spent) # (0.0, 0.0)\n\ncost = reader.get_privacy_cost(\'SELECT AVG(age), AVG(income) FROM PUMS.PUMS\')\nprint(cost) # approximately (0.4, 10e-6)\n\nprint(reader.odometer.spent) # (0.0, 0.0)\n```\n\nNote that the total privacy cost of a session accrues at a slower rate than the sum of the individual query costs obtained by `get_privacy_cost`. The odometer accrues all invocations of mechanisms for the life of a session, and uses them to compute total spend.\n\n```python\nreader = from_connection(conn, metadata=metadata, privacy=privacy)\nquery = \'SELECT COUNT(*) FROM PUMS.PUMS\'\nepsilon_single, _ = reader.get_privacy_cost(query)\nprint(epsilon_single) # 0.1\n\n# no queries executed yet\nprint(reader.odometer.spent) # (0.0, 0.0)\n\nfor _ in range(100):\n reader.execute(query)\n\nepsilon_many, _ = reader.odometer.spent\nprint(f\'{epsilon_many} < {epsilon_single * 100}\')\n```\n\n## Histograms\n\nSQL `group by` queries represent histograms binned by grouping key. Queries over a grouping key with unbounded or non-public dimensions expose privacy risk. For example:\n\n```sql\nSELECT last_name, COUNT(*) FROM Sales GROUP BY last_name\n```\n\nIn the above query, if someone with a distinctive last name is included in the database, that person\'s record might accidentally be revealed, even if the noisy count returns 0 or negative. To prevent this from happening, the system will automatically censor dimensions which would violate differential privacy.\n\n## Private Synopsis\n\nA private synopsis is a pre-computed set of differentially private aggregates that can be filtered and aggregated in various ways to produce new reports. Because the private synopsis is differentially private, reports generated from the synopsis do not need to have additional privacy applied, and the synopsis can be distributed without risk of additional privacy loss. Reports over the synopsis can be generated with non-private SQL, within an Excel Pivot Table, or through other common reporting tools.\n\nYou can see a sample [notebook for creating private synopsis](samples/Synopsis.ipynb) suitable for consumption in Excel or SQL.\n\n## Limitations\n\nYou can think of the data access layer as simple middleware that allows composition of `opendp` computations using the SQL language. The SQL language provides a limited subset of what can be expressed through the full `opendp` library. For example, the SQL language does not provide a way to set per-field privacy budget.\n\nBecause we delegate the computation of exact aggregates to the underlying database engines, execution through the SQL layer can be considerably faster, particularly with database engines optimized for precomputed aggregates. However, this design choice means that analysis graphs composed with SQL language do not access data in the engine on a per-row basis. Therefore, SQL queries do not currently support algorithms that require per-row access, such as quantile algorithms that use underlying values. This is a limitation that future releases will relax for database engines that support row-based access, such as Spark.\n\nThe SQL processing layer has limited support for bounding contributions when individuals can appear more than once in the data. This includes ability to perform reservoir sampling to bound contributions of an individual, and to scale the sensitivity parameter. These parameters are important when querying reporting tables that might be produced from subqueries and joins, but require caution to use safely.\n\nFor this release, we recommend using the SQL functionality while bounding user contribution to 1 row. The platform defaults to this option by setting `max_contrib` to 1, and should only be overridden if you know what you are doing. Future releases will focus on making these options easier for non-experts to use safely.\n\n\n## Communication\n\n- You are encouraged to join us on [GitHub Discussions](https://github.com/opendp/opendp/discussions/categories/smartnoise)\n- Please use [GitHub Issues](https://github.com/opendp/smartnoise-sdk/issues) for bug reports and feature requests.\n- For other requests, including security issues, please contact us at [smartnoise@opendp.org](mailto:smartnoise@opendp.org).\n\n## Releases and Contributing\n\nPlease let us know if you encounter a bug by [creating an issue](https://github.com/opendp/smartnoise-sdk/issues).\n\nWe appreciate all contributions. Please review the [contributors guide](../contributing.rst). We welcome pull requests with bug-fixes without prior discussion.\n\nIf you plan to contribute new features, utility functions or extensions, please first open an issue and discuss the feature with us.\n',
3131
'author': 'SmartNoise Team',

sql/snsql/_ast/tokens.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,29 @@ def symbol(self, relations):
356356
)
357357
return self
358358

359+
"""
360+
Replace all instances of an expression in the tree with another expression.
361+
362+
:param old: the old expression
363+
:param new: the new expression
364+
:param lock: if True, then the new expression will be locked
365+
such that it cannot be replaced again
366+
:return: the updated expression
367+
"""
368+
def replaced(self, old, new, lock=False):
369+
if hasattr(self, "_locked") and self._locked:
370+
return self
371+
if self == old:
372+
if lock:
373+
new._locked = True
374+
return new
375+
else:
376+
props = self.__dict__
377+
for k, v in props.items():
378+
if isinstance(v, SqlExpr) and str(v) != '*':
379+
props[k] = v.replaced(old, new, lock)
380+
return self
381+
359382
@property
360383
def is_key_count(self):
361384
return False

sql/snsql/sql/private_reader.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from copy import deepcopy
12
from typing import List, Union
23
import warnings
34
import numpy as np
@@ -10,7 +11,7 @@
1011
from .private_rewriter import Rewriter
1112
from .parse import QueryParser
1213
from .reader import PandasReader
13-
from .reader.base import SortKey
14+
from .reader.base import SortKeyExpressions
1415

1516
from snsql._ast.ast import Query, Top
1617
from snsql._ast.expressions import sql as ast
@@ -20,6 +21,7 @@
2021
from ._mechanisms import *
2122

2223
import itertools
24+
import string
2325

2426
class PrivateReader(Reader):
2527
"""Executes SQL queries against tabular data sources and returns differentially private results.
@@ -501,6 +503,7 @@ def _execute_ast(self, query, *ignore, accuracy:bool=False, pre_aggregated=None,
501503
if isinstance(query, str):
502504
raise ValueError("Please pass AST to _execute_ast.")
503505

506+
_orig_query = query
504507
subquery, query = self._rewrite_ast(query)
505508

506509
if pre_aggregated is not None:
@@ -591,6 +594,8 @@ def process_clamp_counts(row_in):
591594
out_syms = query._select_symbols
592595
out_types = [s.expression.type() for s in out_syms]
593596
out_col_names = [s.name for s in out_syms]
597+
bind_prefix = ''.join(np.random.choice(list(string.ascii_lowercase), 5))
598+
binding_col_names = [name if name != "???" else f"col_{bind_prefix}_{i}" for i, name in enumerate(out_col_names)]
594599

595600
def convert(val, type):
596601
if val is None:
@@ -641,12 +646,15 @@ def process_out_row(row):
641646
out = map(process_out_row, out)
642647

643648
def filter_aggregate(row, condition):
644-
bindings = dict((name.lower(), val) for name, val in zip(out_col_names, row[0]))
649+
bindings = dict((name.lower(), val) for name, val in zip(binding_col_names, row[0]))
645650
keep = condition.evaluate(bindings)
646651
return keep
647652

648653
if query.having is not None:
649-
condition = query.having.condition
654+
condition = deepcopy(query.having.condition)
655+
for i, ne in enumerate(_orig_query.select.namedExpressions):
656+
source_col = binding_col_names[i]
657+
condition = condition.replaced(ne.expression, ast.Column(source_col), lock=True)
650658
if hasattr(out, "filter"):
651659
# it's an RDD
652660
out = out.filter(lambda row: filter_aggregate(row, condition))
@@ -655,29 +663,23 @@ def filter_aggregate(row, condition):
655663

656664
# sort it if necessary
657665
if query.order is not None:
658-
sort_fields = []
666+
sort_expressions = []
659667
for si in query.order.sortItems:
660-
if type(si.expression) is not ast.Column:
661-
raise ValueError("We only know how to sort by column names right now")
662-
colname = si.expression.name.lower()
663-
if colname not in out_col_names:
664-
raise ValueError(
665-
"Can't sort by {0}, because it's not in output columns: {1}".format(
666-
colname, out_col_names
667-
)
668-
)
669-
colidx = out_col_names.index(colname)
670668
desc = False
671669
if si.order is not None and si.order.lower() == "desc":
672670
desc = True
673-
if desc and not (out_types[colidx] in ["int", "float", "boolean", "datetime"]):
674-
raise ValueError("We don't know how to sort descending by " + out_types[colidx])
675-
sf = (desc, colidx)
676-
sort_fields.append(sf)
671+
if type(si.expression) is ast.Column and si.expression.name.lower() in out_col_names:
672+
sort_expressions.append((desc, si.expression))
673+
else:
674+
expr = deepcopy(si.expression)
675+
for i, ne in enumerate(_orig_query.select.namedExpressions):
676+
source_col = binding_col_names[i]
677+
expr = expr.replaced(ne.expression, ast.Column(source_col), lock=True)
678+
sort_expressions.append((desc, expr))
677679

678680
def sort_func(row):
679681
# use index 0, since index 1 is accuracy
680-
return SortKey(row[0], sort_fields)
682+
return SortKeyExpressions(row[0], sort_expressions, binding_col_names)
681683

682684
if hasattr(out, "sortBy"):
683685
out = out.sortBy(sort_func)

sql/snsql/sql/reader/base.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,14 @@ def serialize(self, query):
147147
return str(query)
148148

149149
class SortKey:
150+
"""
151+
Handles comparison operators for sorting
152+
153+
:param obj: The object to be sorted (a row)
154+
:param sort_fields: A list of tuples, where each tuple is a pair of (bool, int)
155+
The bool indicates whether the sort is descending (True) or ascending (False)
156+
The int indicates the column index to sort on
157+
"""
150158
def __init__(self, obj, sort_fields, *args):
151159
self.obj = obj
152160
self.sort_fields = sort_fields
@@ -181,3 +189,55 @@ def __ge__(self, other):
181189

182190
def __ne__(self, other):
183191
return self.mycmp(self.obj, other.obj, self.sort_fields) != 0
192+
193+
class SortKeyExpressions:
194+
"""
195+
Handles comparison operators for sorting
196+
197+
:param obj: The object to be sorted (a row)
198+
:param sort_expressions: A list of tuples of SqlExpression objects to be used for comparison
199+
each tuple is a boolean indicating whether the sort is descending (True) or ascending (False)
200+
followed by the SqlExpression object to be used for comparison.
201+
:param binding_col_names: A list of column names to be used for binding the sort expression
202+
"""
203+
def __init__(self, obj, sort_expressions, binding_col_names, *args):
204+
self.sort_expressions = sort_expressions
205+
self.bindings = dict((name.lower(), val) for name, val in zip(binding_col_names, obj))
206+
def mycmp(self, bindings_a, bindings_b, sort_expressions):
207+
for desc, expr in sort_expressions:
208+
try:
209+
v_a = expr.evaluate(bindings_a)
210+
v_b = expr.evaluate(bindings_b)
211+
if desc:
212+
if v_a < v_b:
213+
return 1
214+
elif v_a > v_b:
215+
return -1
216+
else:
217+
if v_a < v_b:
218+
return -1
219+
elif v_a > v_b:
220+
return 1
221+
except Exception as e:
222+
message = f"Error evaluating sort expression {expr}"
223+
message += "\nWe can only sort using expressions that can be evaluated on output columns."
224+
raise ValueError(message) from e
225+
return 0
226+
227+
def __lt__(self, other):
228+
return self.mycmp(self.bindings, other.bindings, self.sort_expressions) < 0
229+
230+
def __gt__(self, other):
231+
return self.mycmp(self.bindings, other.bindings, self.sort_expressions) > 0
232+
233+
def __eq__(self, other):
234+
return self.mycmp(self.bindings, other.bindings, self.sort_expressions) == 0
235+
236+
def __le__(self, other):
237+
return self.mycmp(self.bindings, other.bindings, self.sort_expressions) <= 0
238+
239+
def __ge__(self, other):
240+
return self.mycmp(self.bindings, other.bindings, self.sort_expressions) >= 0
241+
242+
def __ne__(self, other):
243+
return self.mycmp(self.bindings, other.bindings, self.sort_expressions) != 0

sql/tests/mechanism/test_approx_bounds.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_bounds_zero_negative(self):
3434
assert (max == 0.0)
3535
def test_bounds_increment(self):
3636
powers = np.arange(10) * 4
37-
vals = [2**p for p in powers] * 100
37+
vals = [2.0**p for p in powers] * 100
3838
min, max = approx_bounds(vals, 10.0)
3939
assert (min == 1.0)
4040
assert (max >= 2**35 and max <= 2**37)

0 commit comments

Comments
 (0)