Skip to content

Commit 2f648b4

Browse files
authored
Merge pull request #9 from The-Strategy-Unit/add-get_inequalities-methods
2 parents f7e0715 + 9fa0aea commit 2f648b4

File tree

3 files changed

+42
-13
lines changed

3 files changed

+42
-13
lines changed

src/nhp/databricks/icb.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,16 @@ def get_hsa_gams(self):
185185
"""Get the health status adjustment gams."""
186186
# this is not supported in our data bricks environment currently
187187
raise NotImplementedError
188+
189+
def get_inequalities(self) -> pd.DataFrame:
190+
"""Get the inequalities dataframe.
191+
192+
Returns:
193+
The inequalities dataframe.
194+
"""
195+
return (
196+
self._spark.read.parquet(f"{self._data_path}/inequalities")
197+
.filter(F.col("icb") == self._icb)
198+
.filter(F.col("fyear") == self._year)
199+
.toPandas()
200+
)

src/nhp/databricks/national.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,7 @@ def create(
5555
:return: a function to initialise the object
5656
:rtype: Callable[[str, str], Databricks]
5757
"""
58-
return lambda fyear, _: DatabricksNational(
59-
spark, data_path, fyear, sample_rate, seed
60-
)
58+
return lambda fyear, _: DatabricksNational(spark, data_path, fyear, sample_rate, seed)
6159

6260
def get_ip(self) -> pd.DataFrame:
6361
"""Get the inpatients dataframe.
@@ -98,14 +96,10 @@ def get_op(self) -> pd.DataFrame:
9896
# TODO: temporary fix, see #353
9997
.withColumn("sushrg_trimmed", F.lit("HRG"))
10098
.withColumn("imd_quintile", F.lit(0))
101-
.groupBy(
102-
op.drop("index", "fyear", "attendances", "tele_attendances").columns
103-
)
99+
.groupBy(op.drop("index", "fyear", "attendances", "tele_attendances").columns)
104100
.agg(
105101
(F.sum("attendances") * self._sample_rate).alias("attendances"),
106-
(F.sum("tele_attendances") * self._sample_rate).alias(
107-
"tele_attendances"
108-
),
102+
(F.sum("tele_attendances") * self._sample_rate).alias("tele_attendances"),
109103
)
110104
# TODO: how do we make this stable? at the moment we can't use full model results with
111105
# national
@@ -209,4 +203,15 @@ def get_hsa_gams(self):
209203
"""Get the health status adjustment gams."""
210204
# this is not supported in our data bricks environment currently
211205
raise NotImplementedError
212-
raise NotImplementedError
206+
207+
def get_inequalities(self) -> pd.DataFrame:
208+
"""Get the inequalities dataframe.
209+
210+
Returns:
211+
The inequalities dataframe.
212+
"""
213+
return (
214+
self._spark.read.parquet(f"{self._data_path}/inequalities")
215+
.filter(F.col("fyear") == self._year)
216+
.toPandas()
217+
)

src/nhp/databricks/provider.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@ def create(spark: SparkSession, data_path: str) -> Callable[[int, str], Any]:
3434
:return: a function to initialise the object
3535
:rtype: Callable[[str, str], Databricks]
3636
"""
37-
return lambda fyear, dataset: DatabricksProvider(
38-
spark, data_path, fyear, dataset
39-
)
37+
return lambda fyear, dataset: DatabricksProvider(spark, data_path, fyear, dataset)
4038

4139
@property
4240
def _apc(self):
@@ -144,3 +142,16 @@ def get_hsa_gams(self):
144142
"""Get the health status adjustment gams."""
145143
# this is not supported in our data bricks environment currently
146144
raise NotImplementedError
145+
146+
def get_inequalities(self) -> pd.DataFrame:
147+
"""Get the inequalities dataframe.
148+
149+
Returns:
150+
The inequalities dataframe.
151+
"""
152+
return (
153+
self._spark.read.parquet(f"{self._data_path}/inequalities")
154+
.filter(F.col("dataset") == self._dataset)
155+
.filter(F.col("fyear") == self._year)
156+
.toPandas()
157+
)

0 commit comments

Comments
 (0)