diff --git a/airbyte/http_caching/proxy.py b/airbyte/http_caching/proxy.py
index a2e9f1f6..fd457a22 100644
--- a/airbyte/http_caching/proxy.py
+++ b/airbyte/http_caching/proxy.py
@@ -107,7 +107,9 @@ def _get_cache_path(self, key: str, *, is_read: bool = False) -> Path:
             The path to the cache file.
         """
         base_dir = self.read_dir if is_read else self.cache_dir
+        extension = ".json" if self.serialization_format == SerializationFormat.JSON else ".mitm"
+
         return base_dir / f"{key}{extension}"

     def request(self, flow: HTTPFlow) -> None:
@@ -127,14 +129,15 @@ def request(self, flow: HTTPFlow) -> None:
         if cache_path.exists():
             try:
                 cached_data: dict[str, Any] = self.serializer.deserialize(cache_path)
+                cached_flow = HTTPFlow.from_state(cached_data)
+                if hasattr(cached_flow, "response") and cached_flow.response:
                     flow.response = cached_flow.response
-                logger.info(f"Serving {flow.request.url} from cache")
+                    logger.info(f"Serving {flow.request.url} from cache")
+                    return
             except Exception as e:
-                logger.warning(f"Failed to load cached response: {e}")
-            else:
-                return
+                logger.warning(f"Failed to load cached response: {e}", exc_info=True)

         if self.mode == HttpCacheMode.READ_ONLY_FAIL_ON_MISS:
             flow.response = Response.make(
@@ -154,7 +157,9 @@ def response(self, flow: HTTPFlow) -> None:
         cache_path = self._get_cache_path(key, is_read=False)

         try:
+            cache_path.parent.mkdir(parents=True, exist_ok=True)
+
             self.serializer.serialize(flow.get_state(), cache_path)
             logger.info(f"Cached response for {flow.request.url}")
         except Exception as e:
-            logger.warning(f"Failed to cache response: {e}")
+            logger.warning(f"Failed to cache response: {e}", exc_info=True)
diff --git a/airbyte/http_caching/serialization.py b/airbyte/http_caching/serialization.py
index b3257d8e..d37df456 100644
--- a/airbyte/http_caching/serialization.py
+++ b/airbyte/http_caching/serialization.py
@@ -4,12 +4,16 @@
 from __future__ import annotations

 import json
+import logging
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Protocol

 from mitmproxy.io import io

+logger = logging.getLogger(__name__)
+
+
 if TYPE_CHECKING:
     from pathlib import Path

@@ -88,12 +92,14 @@ def serialize(self, data: T_SerializedData, path: Path) -> None:
         """
         path.parent.mkdir(parents=True, exist_ok=True)

-        if not str(path).endswith(".mitm"):
+        if path.suffix != ".mitm":
             path = path.with_suffix(".mitm")

+        flows = data.get("flows", [])
+
         with path.open("wb") as f:
             fw = io.FlowWriter(f)
-            for flow in data.get("flows", []):
+            for flow in flows:
                 fw.add(flow)

     def deserialize(self, path: Path) -> T_SerializedData:
@@ -105,14 +111,18 @@ def deserialize(self, path: Path) -> T_SerializedData:
         Returns:
             The deserialized data.
         """
-        if not str(path).endswith(".mitm"):
+        if path.suffix != ".mitm":
             path = path.with_suffix(".mitm")

         if not path.exists():
             return {"flows": []}

-        with path.open("rb") as f:
-            fr = io.FlowReader(f)
-            flows = list(fr.stream())
-
-        return {"flows": flows}
+        try:
+            with path.open("rb") as f:
+                fr = io.FlowReader(f)
+                flows = list(fr.stream())
+        except Exception as e:
+            logger.warning(f"Error reading flow file {path}: {e}")
+            return {"flows": []}
+        else:
+            return {"flows": flows}