Skip to content

Commit 897e9bf

Browse files
authored
Merge pull request #111 from predicthq/BU-24-Update-SDK-for-Beam-Saved-Location-support
Bu 24 update sdk for beam saved location support
2 parents 5c74557 + 66c7527 commit 897e9bf

File tree

3 files changed

+92
-0
lines changed

3 files changed

+92
-0
lines changed

predicthq/endpoints/v1/beam/endpoint.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
CreateAnalysisGroupResponse,
99
AnalysisGroup,
1010
AnalysisGroupResultSet,
11+
GeoJson,
12+
Place
1113
)
1214
from predicthq.endpoints.decorators import accepts, returns
1315
from typing import overload, List, Optional, TextIO, Union
@@ -29,6 +31,7 @@ def create(
2931
location__unit: Optional[str] = None,
3032
location__google_place_id: Optional[str] = None,
3133
location__geoscope_paths: Optional[List[str]] = None,
34+
location__saved_location_id: Optional[str] = None,
3235
rank__type: Optional[str] = None,
3336
rank__levels__phq: Optional[dict] = None,
3437
rank__levels__local: Optional[dict] = None,
@@ -70,6 +73,7 @@ def search(
7073
limit: Optional[int] = None,
7174
external_id: Optional[List[str]] = None,
7275
label: Optional[List[str]] = None,
76+
location__saved_location_id: Optional[List[str]] = None,
7377
**params,
7478
): ...
7579
@accepts()
@@ -102,6 +106,7 @@ def update(
102106
location__unit: Optional[str] = None,
103107
location__google_place_id: Optional[str] = None,
104108
location__geoscope_paths: Optional[List[str]] = None,
109+
location__saved_location_id: Optional[str] = None,
105110
rank__type: Optional[str] = None,
106111
rank__levels__phq: Optional[dict] = None,
107112
rank__levels__local: Optional[dict] = None,

predicthq/endpoints/v1/beam/schemas.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,28 @@
33
from predicthq.endpoints.schemas import ArgKwargResultSet
44
from typing import Optional, List
55

6+
# Python < 3.11 does not have StrEnum in the enum module
7+
import sys
8+
if sys.version_info < (3, 11):
9+
import enum
10+
11+
class StrEnum(str, enum.Enum):
12+
pass
13+
else:
14+
from enum import StrEnum
15+
16+
# Python < 3.9 does not have Annotated
17+
if sys.version_info < (3, 9):
18+
from typing_extensions import Annotated
19+
else:
20+
from typing import Annotated
21+
22+
# Python < 3.8 does not have Literal
23+
if sys.version_info < (3, 8):
24+
from typing_extensions import Literal
25+
else:
26+
from typing import Literal
27+
628

729
class BeamPaginationResultSet(ArgKwargResultSet):
830
def has_next(self):
@@ -113,6 +135,58 @@ class DemandType(DemandTypeGroup):
113135
currency_code: str
114136

115137

138+
class RadiusUnit(StrEnum):
139+
m = "m"
140+
km = "km"
141+
mi = "mi"
142+
ft = "ft"
143+
144+
145+
class GeoJsonGeometryType(StrEnum):
146+
POINT = "Point"
147+
POLYGON = "Polygon"
148+
MULTI_POLYGON = "MultiPolygon"
149+
LINE_STRING = "LineString"
150+
MULTI_LINE_STRING = "MultiLineString"
151+
152+
153+
class GeoJsonProperties(BaseModel):
154+
radius: Annotated[float, Field(gt=0)]
155+
radius_unit: RadiusUnit
156+
157+
158+
class GeoJsonGeometry(BaseModel):
159+
type: GeoJsonGeometryType
160+
coordinates: Annotated[list, Field(min_length=1)]
161+
162+
163+
class GeoJson(BaseModel):
164+
type: Literal["Feature"]
165+
properties: Optional[GeoJsonProperties] = None
166+
geometry: GeoJsonGeometry
167+
168+
169+
class Place(BaseModel):
170+
place_id: int
171+
type: str
172+
name: str
173+
county: Optional[str] = None
174+
region: Optional[str] = None
175+
country: Optional[str] = None
176+
geojson: GeoJson
177+
178+
179+
class SavedLocation(BaseModel):
180+
name: Optional[str] = None
181+
formatted_address: Optional[str] = None
182+
geojson: Optional[GeoJson] = None
183+
h3: Optional[List[str]] = None
184+
place_ids: Optional[List[int]] = None
185+
place_hierarchies: Optional[List[str]] = None
186+
places: Optional[List[Place]] = None
187+
location_id: str
188+
189+
116190
class Analysis(BaseModel):
117191
model_config: ConfigDict = ConfigDict(extra="allow")
118192

@@ -134,6 +208,7 @@ class Analysis(BaseModel):
134208
processed_dt: Optional[datetime] = None
135209
external_id: Optional[str] = None
136210
label: Optional[List[str]] = None
211+
saved_location: Optional[SavedLocation] = None
137212

138213

139214
class AnalysisResultSet(BeamPaginationResultSet):

tests/endpoints/v1/test_beam.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def test_search_params_underscores(self, client):
5050
limit=10,
5151
external_id="external_id",
5252
label=["label1", "label2"],
53+
location__saved_location_id="saved_location_id",
5354
)
5455

5556
client.request.assert_called_once_with(
@@ -72,6 +73,7 @@ def test_search_params_underscores(self, client):
7273
"limit": 10,
7374
"external_id": "external_id",
7475
"label": "label1,label2",
76+
"location.saved_location_id": "saved_location_id",
7577
},
7678
verify=True,
7779
)
@@ -98,6 +100,7 @@ def test_search_params_dicts(self, client):
98100
limit=10,
99101
external_id="external_id",
100102
label=["label1", "label2"],
103+
location__saved_location_id="saved_location_id",
101104
)
102105

103106
client.request.assert_called_once_with(
@@ -120,6 +123,7 @@ def test_search_params_dicts(self, client):
120123
"limit": 10,
121124
"external_id": "external_id",
122125
"label": "label1,label2",
126+
"location.saved_location_id": "saved_location_id",
123127
},
124128
verify=True,
125129
)
@@ -135,6 +139,7 @@ def test_create_params_underscores(self, client):
135139
location__unit="km",
136140
location__google_place_id="google_place_id",
137141
location__geoscope_paths=["geoscope_path1", "geoscope_path2"],
142+
location__saved_location_id="saved_location_id",
138143
rank__type="type",
139144
rank__levels__phq={"min": 1.0, "max": 2.0},
140145
rank__levels__local={"min": 3.0, "max": 4.0},
@@ -158,6 +163,7 @@ def test_create_params_underscores(self, client):
158163
"unit": "km",
159164
"google_place_id": "google_place_id",
160165
"geoscope_paths": ["geoscope_path1", "geoscope_path2"],
166+
"saved_location_id": "saved_location_id",
161167
},
162168
"rank": {
163169
"type": "type",
@@ -192,6 +198,7 @@ def test_create_params_dicts(self, client):
192198
"unit": "km",
193199
"google_place_id": "google_place_id",
194200
"geoscope_paths": ["geoscope_path1", "geoscope_path2"],
201+
"saved_location_id": "saved_location_id",
195202
},
196203
rank={
197204
"type": "type",
@@ -222,6 +229,7 @@ def test_create_params_dicts(self, client):
222229
"unit": "km",
223230
"google_place_id": "google_place_id",
224231
"geoscope_paths": ["geoscope_path1", "geoscope_path2"],
232+
"saved_location_id": "saved_location_id",
225233
},
226234
"rank": {
227235
"type": "type",
@@ -255,6 +263,7 @@ def test_update_params_underscores(self, client):
255263
location__unit="km",
256264
location__google_place_id="google_place_id",
257265
location__geoscope_paths=["geoscope_path1", "geoscope_path2"],
266+
location__saved_location_id="saved_location_id",
258267
rank__type="type",
259268
rank__levels__phq={"min": 1.0, "max": 2.0},
260269
rank__levels__local={"min": 3.0, "max": 4.0},
@@ -278,6 +287,7 @@ def test_update_params_underscores(self, client):
278287
"unit": "km",
279288
"google_place_id": "google_place_id",
280289
"geoscope_paths": ["geoscope_path1", "geoscope_path2"],
290+
"saved_location_id": "saved_location_id",
281291
},
282292
"rank": {
283293
"type": "type",
@@ -310,6 +320,7 @@ def test_update_params_dicts(self, client):
310320
"unit": "km",
311321
"google_place_id": "google_place_id",
312322
"geoscope_paths": ["geoscope_path1", "geoscope_path2"],
323+
"saved_location_id": "saved_location_id",
313324
},
314325
rank={
315326
"type": "type",
@@ -340,6 +351,7 @@ def test_update_params_dicts(self, client):
340351
"unit": "km",
341352
"google_place_id": "google_place_id",
342353
"geoscope_paths": ["geoscope_path1", "geoscope_path2"],
354+
"saved_location_id": "saved_location_id",
343355
},
344356
"rank": {
345357
"type": "type",

0 commit comments

Comments
 (0)