Skip to content

Commit e045a0a

Browse files
author
Paul Duncan
committed
various fixes:
formatting, search params, test function name, unneeded code removed, imports cleaned up, endpoint extra slash removed
1 parent baedf12 commit e045a0a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+910
-589
lines changed

predicthq/client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ def get_headers(self, headers):
5555
@stamina.retry(on=RetriableError, attempts=3)
5656
def request(self, method, path, **kwargs):
5757
headers = self.get_headers(kwargs.pop("headers", {}))
58-
response = requests.request(method, self.build_url(path), headers=headers, **kwargs)
58+
response = requests.request(
59+
method, self.build_url(path), headers=headers, **kwargs
60+
)
5961
self.logger.debug(response.request.url)
6062
try:
6163
response.raise_for_status()

predicthq/config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313

1414
class Config(object):
15-
1615
_config_sections = (
1716
"endpoint",
1817
"oauth2",

predicthq/endpoints/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ def __new__(mcs, name, bases, data):
33
if "Meta" not in data:
44

55
class Meta:
6-
""" Used by decorators when overriding schema classes """
6+
"""Used by decorators when overriding schema classes"""
77

88
pass
99

@@ -35,6 +35,8 @@ def for_account(self, account_id):
3535

3636
def build_url(self, prefix, suffix):
3737
if self.account_id is not None:
38-
return f"/{prefix.strip('/')}/accounts/{self.account_id}/{suffix.strip('/')}/"
38+
return (
39+
f"/{prefix.strip('/')}/accounts/{self.account_id}/{suffix.strip('/')}/"
40+
)
3941
else:
4042
return super(UserBaseEndpoint, self).build_url(prefix, suffix)

predicthq/endpoints/oauth2/endpoint.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@
66

77

88
class OAuth2Endpoint(BaseEndpoint):
9-
109
@deprecated(
1110
reason=(
12-
"OAuth2 endpoints in the SDK are deprecated and will be removed in future releases. "
13-
"Use TokenAuth (API Access Token) with Client(..., access_token=...)."
11+
"OAuth2 endpoints in the SDK are deprecated and will be removed in future releases. "
12+
"Use TokenAuth (API Access Token) with Client(..., access_token=...)."
1413
),
1514
category=FutureWarning,
1615
)
@@ -31,11 +30,10 @@ def get_token(self, client_id, client_secret, scope, grant_type, **kwargs):
3130
verify=verify_ssl,
3231
)
3332

34-
3533
@deprecated(
3634
reason=(
37-
"OAuth2 endpoints in the SDK are deprecated and will be removed in future releases. "
38-
"Use TokenAuth (API Access Token) with Client(..., access_token=...)."
35+
"OAuth2 endpoints in the SDK are deprecated and will be removed in future releases. "
36+
"Use TokenAuth (API Access Token) with Client(..., access_token=...)."
3937
),
4038
category=FutureWarning,
4139
)

predicthq/endpoints/v1/accounts/endpoint.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55

66

77
class AccountsEndpoint(BaseEndpoint):
8-
98
@deprecated(
109
reason=(
11-
"The Accounts endpoint in the SDK is deprecated and will be removed in future releases. "
12-
"Account information can be managed via the PredictHQ dashboard."
10+
"The Accounts endpoint in the SDK is deprecated and will be removed in future releases. "
11+
"Account information can be managed via the PredictHQ dashboard."
1312
),
1413
category=FutureWarning,
1514
)

predicthq/endpoints/v1/beam/endpoint.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
AnalysisGroup,
1010
AnalysisGroupResultSet,
1111
GeoJson,
12-
Place
12+
Place,
1313
)
1414
from predicthq.endpoints.decorators import accepts, returns
1515
from typing import overload, List, Optional, TextIO, Union
@@ -185,8 +185,8 @@ def upload_demand(
185185
self,
186186
analysis_id: str,
187187
json: Optional[Union[str, TextIO]] = None,
188-
ndjson: Optional[Union[str, TextIO]] = None,
189-
csv: Optional[Union[str, TextIO]] = None,
188+
ndjson: Optional[Union[str, TextIO]] = None,
189+
csv: Optional[Union[str, TextIO]] = None,
190190
**params,
191191
): ...
192192
def upload_demand(self, analysis_id: str, **params):

predicthq/endpoints/v1/beam/schemas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
# Python < 3.11 does not have StrEnum in the enum module
77
import sys
8+
89
if sys.version_info < (3, 11):
910
import enum
1011

@@ -38,7 +39,6 @@ def get_next(self):
3839
return self._more(**self._kwargs)
3940

4041

41-
4242
class CreateAnalysisResponse(BaseModel):
4343
model_config: ConfigDict = ConfigDict(extra="allow")
4444

predicthq/endpoints/v1/events/schemas.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ class Event(BaseModel):
114114
private: Optional[bool] = None # Loop add-on
115115
rank: Optional[int] = None # PHQ Rank add-on
116116
predicted_event_spend: Optional[int] = None # Predicted Event Spend add-on
117-
predicted_event_spend_industries: Optional[PredictedEventSpendIndustries] = None # Predicted Event Spend add-on
117+
predicted_event_spend_industries: Optional[PredictedEventSpendIndustries] = (
118+
None # Predicted Event Spend add-on
119+
)
118120

119121

120122
class EventResultSet(ResultSet):

predicthq/endpoints/v1/features/schemas.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ def __flatten_json(d: dict, pk: str = "") -> dict:
2020
flat_json.update({f"{pk}{separator}{k}" if pk else k: v})
2121
return flat_json
2222

23-
return [__flatten_json(d.model_dump(exclude_none=True)) for d in self.iter_all()]
23+
return [
24+
__flatten_json(d.model_dump(exclude_none=True)) for d in self.iter_all()
25+
]
2426

2527
def to_csv(self, file: str, mode: str = "w+", separator: str = "_") -> None:
2628
header = None
@@ -65,4 +67,6 @@ def get_next(self):
6567
if not self.has_next() or not hasattr(self, "_more"):
6668
return
6769
params = self._parse_params(self.next)
68-
return self._more(_params=params, _json=self._kwargs.get("_json", {}) or self._kwargs)
70+
return self._more(
71+
_params=params, _json=self._kwargs.get("_json", {}) or self._kwargs
72+
)
Lines changed: 46 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,54 @@
1-
from predicthq.endpoints.base import UserBaseEndpoint
2-
from predicthq.endpoints.base import BaseEndpoint
3-
from predicthq.endpoints.decorators import accepts, returns
4-
from typing import overload, List, Optional, TextIO, Union
5-
from pydantic import BaseModel
1+
from predicthq.endpoints.base import UserBaseEndpoint, BaseEndpoint
62
from predicthq.endpoints.decorators import accepts, returns
3+
from typing import overload, List, Optional
74
from .schemas import (
85
SavedLocation,
96
SavedLocationResultSet,
107
CreateSavedLocationResponse,
118
PostSharingEnableResponse,
12-
SuggestedRadiusResponse,Location
9+
SuggestedRadiusResponse,
10+
Location,
1311
)
1412
from ..events.schemas import EventResultSet
1513

1614

17-
from typing import Optional, List
18-
19-
# Python < 3.11 does not have StrEnum in the enum module
20-
import sys
21-
if sys.version_info < (3, 11):
22-
import enum
23-
24-
class StrEnum(str, enum.Enum):
25-
pass
26-
else:
27-
from enum import StrEnum
28-
29-
# Python < 3.9 does not have Annotated
30-
if sys.version_info < (3, 9):
31-
from typing_extensions import Annotated
32-
else:
33-
from typing import Annotated
34-
35-
# Python < 3.8 does not have Literal
36-
if sys.version_info < (3, 8):
37-
from typing_extensions import Literal
38-
else:
39-
from typing import Literal
40-
41-
4215
class SavedLocationsEndpoint(UserBaseEndpoint):
43-
4416
@overload
4517
def search(
46-
self,
47-
location_id: Optional[List[str]] = None,
48-
location_code: Optional[List[str]] = None,
49-
labels: Optional[List[str]] = None,
50-
user_id: Optional[List[str]] = None,
51-
subscription_valid_types: Optional[List[str]] = None,
52-
q: Optional[str] = None,
53-
sort: Optional[List[str]] = None,
54-
offset: Optional[int] = None,
55-
limit: Optional[int] = None,
56-
**params,
18+
self,
19+
location_id: Optional[List[str]] = None,
20+
location_code: Optional[List[str]] = None,
21+
labels: Optional[List[str]] = None,
22+
user_id: Optional[List[str]] = None,
23+
subscription_valid_types: Optional[List[str]] = None,
24+
q: Optional[str] = None,
25+
sort: Optional[List[str]] = None,
26+
offset: Optional[int] = None,
27+
limit: Optional[int] = None,
28+
**params,
5729
): ...
5830
@accepts()
5931
@returns(SavedLocationResultSet)
6032
def search(self, **params):
6133
verify_ssl = params.pop("config.verify_ssl", True)
62-
result = self.client.get(
34+
result = self.client.get(
6335
self.build_url("v1", "saved-locations"),
6436
params=params,
6537
verify=verify_ssl,
6638
)
6739
return result
6840

69-
7041
@overload
7142
def create(
72-
self,
73-
name: str,
74-
geojson: dict,
75-
labels: Optional[List[str]] = None,
76-
location_code: Optional[str] = None,
77-
formatted_address: Optional[str] = None,
78-
description: Optional[str] = None,
79-
place_ids: Optional[List[str]] = None,
80-
**params,
43+
self,
44+
name: str,
45+
geojson: dict,
46+
labels: Optional[List[str]] = None,
47+
location_code: Optional[str] = None,
48+
formatted_address: Optional[str] = None,
49+
description: Optional[str] = None,
50+
place_ids: Optional[List[str]] = None,
51+
**params,
8152
): ...
8253
@accepts(query_string=False)
8354
@returns(CreateSavedLocationResponse)
@@ -91,7 +62,6 @@ def create(self, **params):
9162
)
9263
return result
9364

94-
9565
@accepts()
9666
@returns(SavedLocation)
9767
def get(self, location_id, **params):
@@ -102,30 +72,16 @@ def get(self, location_id, **params):
10272
verify=verify_ssl,
10373
)
10474

105-
10675
@overload
10776
def search_event_result_set(
108-
location_id : str,
109-
date_range_type : Optional[str] = None,
110-
offset : Optional[int] = None,
111-
limit : Optional[int] = None,
77+
location_id: str,
78+
date_range_type: Optional[str] = None,
79+
offset: Optional[int] = None,
80+
limit: Optional[int] = None,
11281
**params,
113-
):...
82+
): ...
11483
@returns(EventResultSet)
11584
def search_event_result_set(self, location_id, **params):
116-
"""
117-
Search for events for a saved location.
118-
119-
Args:
120-
location_id (str): The ID of the location.
121-
date_range_type (str, optional): The date range type filter.
122-
offset (int, optional): Pagination offset.
123-
limit (int, optional): Pagination limit.
124-
... (other query params)
125-
126-
Returns:
127-
EventResultSet: The result set of events.
128-
"""
12985
verify_ssl = params.pop("config.verify_ssl", True)
13086
url = f"{self.build_url('v1', 'saved-locations')}{location_id}/insights/events"
13187
response = self.client.get(
@@ -135,31 +91,29 @@ def search_event_result_set(self, location_id, **params):
13591
)
13692
return response
13793

138-
13994
@accepts()
14095
def refresh_location_insights(self, location_id: str, **params):
14196
verify_ssl = params.pop("config.verify_ssl", True)
14297
return self.client.post(
143-
f"{self.build_url('v1', 'saved-locations')}{location_id}/insights/refresh/",
98+
f"{self.build_url('v1', 'saved-locations')}{location_id}/insights/refresh",
14499
params=params,
145100
verify=verify_ssl,
146101
)
147102

148-
149103
@overload
150104
def replace_location_data(
151-
self,
152-
location_id: str,
153-
name: str,
154-
geojson: dict,
155-
labels: Optional[List[str]] = None,
156-
location_code: Optional[str] = None,
157-
formatted_address: Optional[str] = None,
158-
description: Optional[str] = None,
159-
place_ids: Optional[List[str]] = None,
160-
external_id: Optional[str] = None,
161-
**params,
162-
):...
105+
self,
106+
location_id: str,
107+
name: str,
108+
geojson: dict,
109+
labels: Optional[List[str]] = None,
110+
location_code: Optional[str] = None,
111+
formatted_address: Optional[str] = None,
112+
description: Optional[str] = None,
113+
place_ids: Optional[List[str]] = None,
114+
external_id: Optional[str] = None,
115+
**params,
116+
): ...
163117
@accepts(query_string=False)
164118
def replace_location_data(self, location_id: str, **params):
165119
verify_ssl = params.pop("config.verify_ssl", True)
@@ -171,7 +125,6 @@ def replace_location_data(self, location_id: str, **params):
171125
)
172126
return response
173127

174-
175128
@accepts()
176129
def delete_location(self, location_id: str, **params):
177130
verify_ssl = params.pop("config.verify_ssl", True)
@@ -180,8 +133,7 @@ def delete_location(self, location_id: str, **params):
180133
params=params,
181134
verify=verify_ssl,
182135
)
183-
#
184-
#
136+
185137
@accepts()
186138
@returns(PostSharingEnableResponse)
187139
def sharing_enable(self, location_id: str, **params):
@@ -192,15 +144,14 @@ def sharing_enable(self, location_id: str, **params):
192144
verify=verify_ssl,
193145
)
194146

195-
196147
@overload
197148
def suggested_radius(
198149
self,
199150
location_origin: Location,
200151
radius_unit: str,
201152
industry: str,
202153
**params,
203-
):...
154+
): ...
204155
@accepts(query_string=False)
205156
@returns(SuggestedRadiusResponse)
206157
def suggested_radius(self, **params):
@@ -214,4 +165,4 @@ def suggested_radius(self, **params):
214165
f"{self.build_url('v1', 'suggested-radius')}",
215166
params=params,
216167
verify=verify_ssl,
217-
)
168+
)

0 commit comments

Comments
 (0)