-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathclient.py
136 lines (112 loc) · 3.69 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from copy import copy
import requests
from common import PullRequestInfo, RepositoryInfo
class ApiTokenClient:
    """
    Base class for API clients that authenticate with a token.

    Every request carries an ``Authorization`` header made of
    ``token_prefix`` followed by the instance token. Subclasses may override
    ``token_prefix`` and set ``common_url_path`` to a path segment that is
    shared by all of their endpoints.
    """

    # Scheme placed in front of the token in the Authorization header.
    token_prefix = "Bearer"
    # Optional path segment inserted between the instance URL and the
    # request path; falsy means "no common segment".
    common_url_path = None

    def __init__(
        self,
        instance_url: str,
        instance_token: str,
    ):
        """
        :param instance_url: base URL of the remote instance
        :param instance_token: token sent in the Authorization header
        """
        self.instance_url = instance_url
        self.instance_token = instance_token
        self.headers = {
            "Authorization": f"{self.token_prefix} {self.instance_token}",
        }

    def get(self, path: str) -> requests.Response:
        """
        Send a GET request.

        :param path: request path, resolved with ``get_url``
        :return: the response of the request
        """
        return requests.get(self.get_url(path), headers=self.headers)

    def post(self, path: str, json) -> requests.Response:
        """
        Send a POST request with a JSON body.

        :param path: request path, resolved with ``get_url``
        :param json: payload handed to ``requests`` as the JSON body
        :return: the response of the request
        """
        return requests.post(
            self.get_url(path),
            json=json,
            headers={**self.headers, "Content-Type": "application/json"},
        )

    def put(self, path: str, data: dict) -> requests.Response:
        """
        Send a PUT request.

        :param path: request path, resolved with ``get_url``
        :param data: passed to ``requests`` as ``files``, i.e. sent as a
            multipart-encoded body rather than as JSON
        :return: the response of the request
        """
        return requests.put(
            self.get_url(path),
            files=data,
            headers=self.headers,
        )

    def delete(self, path: str) -> requests.Response:
        """
        Send a DELETE request.

        :param path: request path, resolved with ``get_url``
        :return: the response of the request
        """
        return requests.delete(self.get_url(path), headers=self.headers)

    def get_url(self, path: str) -> str:
        """
        Build the full URL for a request path.

        Slashes at the joins are normalized so that a trailing slash on the
        instance URL or a leading slash on ``path`` / ``common_url_path``
        does not produce ``//`` inside the URL. A trailing slash on ``path``
        is preserved.

        :param path: request path relative to the instance (and to
            ``common_url_path`` when it is set)
        :return: the absolute URL
        """
        segments = [self.instance_url.rstrip("/")]
        common_path = (self.common_url_path or "").strip("/")
        if common_path:
            segments.append(common_path)
        segments.append(path.lstrip("/"))
        return "/".join(segments)
class VCSClient(ApiTokenClient):
    """
    Common interface for version control system (VCS) clients.

    The methods below only raise ``NotImplementedError`` in this class;
    they describe what a concrete client is expected to provide.
    """

    def validate_credentials(self) -> None:
        """
        Check the credentials and the connection to the VCS made with them.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

    def get_repository_info(self, repository_name: str) -> RepositoryInfo:
        """
        Look up the info about a repository.

        :param repository_name: name of the repository to look up
        :return: info about the repository
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()
class VCSRepoClient(VCSClient):
    """
    Base class for VCS clients that operate on one repository.

    The branch, commit and pull request primitives only raise
    ``NotImplementedError`` here; ``disseminate_in_pull_request`` combines
    them into a single workflow.
    """

    def __init__(self, repo_info: RepositoryInfo, *args, **kwargs):
        """
        :param repo_info: info about the repository this client works on
        :param args: positional arguments forwarded to the parent constructor
        :param kwargs: keyword arguments forwarded to the parent constructor
        """
        self.repo_info = repo_info
        super().__init__(*args, **kwargs)

    def create_new_branch(
        self,
        branch_name: str,
        target_branch: str,
    ) -> None:
        """
        Create a new branch on a repository from the target branch.

        :param branch_name: name of the branch to create
        :param target_branch: branch the new branch is created from
        """
        raise NotImplementedError

    def delete_branch(self, branch_name: str) -> None:
        """
        Delete a branch.

        :param branch_name: name of the branch to delete
        """
        raise NotImplementedError

    def create_new_commit(
        self, pr_info: PullRequestInfo, repo_info: RepositoryInfo
    ) -> None:
        """
        Create a new commit based on the info in PullRequestInfo.

        :param pr_info: description of the change to commit
        :param repo_info: info about the repository to commit to
        """
        raise NotImplementedError

    def create_pull_request(
        self, branch_name: str, target_branch: str, title: str
    ) -> str:
        """
        Create a pull request on a VCS repository.

        :param branch_name: branch that contains the change
        :param target_branch: branch the pull request is opened against
        :param title: title of the pull request
        :return: a string identifying the pull request (presumably its URL,
            to be confirmed against the concrete clients)
        """
        raise NotImplementedError

    # def get_repo_url(self, url: str) -> str:
    # """
    # Returns the URL prefix for the specific repository
    # :param url:
    # :return:
    # """
    # raise NotImplementedError

    def disseminate_in_pull_request(self, pr_info: PullRequestInfo) -> str:
        """
        Publish the change described by ``pr_info`` as a pull request.

        Creates ``pr_info.branch`` from the repository's default branch,
        commits to it and opens a pull request against the default branch,
        using ``pr_info.commit_message`` as the title. If the commit or the
        pull request fails, the freshly created branch is deleted again and
        the original error is re-raised.

        :param pr_info: branch name, commit message and change to publish
        :return: the value returned by ``create_pull_request``
        :raises Exception: whatever the commit or pull request step raised
        """
        import logging

        self.create_new_branch(pr_info.branch, self.repo_info.default_branch)
        try:
            self.create_new_commit(pr_info, self.repo_info)
            pull_request_url = self.create_pull_request(
                pr_info.branch,
                self.repo_info.default_branch,
                pr_info.commit_message,
            )
        except Exception:
            # Best-effort cleanup: a failure to delete the branch must not
            # replace the error that actually caused the rollback.
            try:
                self.delete_branch(branch_name=pr_info.branch)
            except Exception:
                logging.getLogger(__name__).exception(
                    "Could not delete branch %s while rolling back a failed "
                    "pull request",
                    pr_info.branch,
                )
            # Bare raise re-raises the original error with its traceback.
            raise
        return pull_request_url