diff --git a/pycti/api/opencti_api_client.py b/pycti/api/opencti_api_client.py index bbde7d68..9b9afbbb 100644 --- a/pycti/api/opencti_api_client.py +++ b/pycti/api/opencti_api_client.py @@ -264,13 +264,15 @@ def set_retry_number(self, retry_number): "" if retry_number is None else str(retry_number) ) - def query(self, query, variables=None): + def query(self, query, variables=None, disable_impersonate=False): """submit a query to the OpenCTI GraphQL API :param query: GraphQL query string :type query: str :param variables: GraphQL query variables, defaults to {} :type variables: dict, optional + :param disable_impersonate: removes impersonate header if set to True, defaults to False + :type disable_impersonate: bool, optional :return: returns the response json content :rtype: Any """ @@ -295,6 +297,9 @@ def query(self, query, variables=None): else: query_var[key] = val + query_headers = self.request_headers.copy() + if disable_impersonate and "opencti-applicant-id" in query_headers: + del query_headers["opencti-applicant-id"] # If yes, transform variable (file to null) and create multipart query if len(files_vars) > 0: multipart_data = { @@ -361,7 +366,7 @@ def query(self, query, variables=None): self.api_url, data=multipart_data, files=multipart_files, - headers=self.request_headers, + headers=query_headers, verify=self.ssl_verify, cert=self.cert, proxies=self.proxies, @@ -372,7 +377,7 @@ def query(self, query, variables=None): r = self.session.post( self.api_url, json={"query": query, "variables": variables}, - headers=self.request_headers, + headers=query_headers, verify=self.ssl_verify, cert=self.cert, proxies=self.proxies, diff --git a/pycti/api/opencti_api_work.py b/pycti/api/opencti_api_work.py index 4ef5dde8..5d11a0fb 100644 --- a/pycti/api/opencti_api_work.py +++ b/pycti/api/opencti_api_work.py @@ -20,7 +20,7 @@ def to_received(self, work_id: str, message: str): } } """ - self.api.query(query, {"id": work_id, "message": message}) + self.api.query(query, {"id": work_id, "message": message}, True) def to_processed(self, work_id: str, message: str, in_error: bool = False): if self.api.bundle_send_to_queue: @@ -35,7 +35,7 @@ def to_processed(self, work_id: str, message: str, in_error: bool = False): } """ self.api.query( - query, {"id": work_id, "message": message, "inError": in_error} + query, {"id": work_id, "message": message, "inError": in_error}, True ) def ping(self, work_id: str): @@ -60,7 +60,7 @@ def report_expectation(self, work_id: str, error): } """ try: - self.api.query(query, {"id": work_id, "error": error}) + self.api.query(query, {"id": work_id, "error": error}, True) except: self.api.app_logger.error("Cannot report expectation") @@ -78,7 +78,9 @@ def add_expectations(self, work_id: str, expectations: int): } """ try: - self.api.query(query, {"id": work_id, "expectations": expectations}) + self.api.query( + query, {"id": work_id, "expectations": expectations}, True + ) except: self.api.app_logger.error("Cannot report expectation") @@ -96,7 +98,9 @@ def add_draft_context(self, work_id: str, draft_context: str): } """ try: - self.api.query(query, {"id": work_id, "draftContext": draft_context}) + self.api.query( + query, {"id": work_id, "draftContext": draft_context}, True + ) except: self.api.app_logger.error("Cannot report draft context") @@ -111,7 +115,9 @@ def initiate_work(self, connector_id: str, friendly_name: str) -> str: } """ work = self.api.query( - query, {"connectorId": connector_id, "friendlyName": friendly_name} + query, + {"connectorId": connector_id, "friendlyName": friendly_name}, + True, ) return work["data"]["workAdd"]["id"] @@ -122,10 +128,7 @@ def delete_work(self, work_id: str): delete } }""" - work = self.api.query( - query, - {"workId": work_id}, - ) + work = self.api.query(query, {"workId": work_id}, True) return work["data"] def wait_for_work_to_finish(self, work_id: str): @@ -179,10 +182,7 @@ def get_work(self, work_id: str) -> Dict: } } """ - result = self.api.query( - query, - {"id": work_id}, - ) + result = self.api.query(query, {"id": work_id}, True) return result["data"]["work"] def get_connector_works(self, connector_id: str) -> List[Dict]: @@ -243,6 +243,7 @@ def get_connector_works(self, connector_id: str) -> List[Dict]: "filterGroups": [], }, }, + True, ) result = result["data"]["works"]["edges"] return_value = []