Skip to content
Merged
59 changes: 32 additions & 27 deletions src/ota_proxy/cache_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from pathlib import Path
from typing import NamedTuple

from multidict import CIMultiDict
from multidict import CIMultiDict, CIMultiDictProxy
from simple_sqlite3_orm import ORMBase, gen_sql_stmt
from simple_sqlite3_orm.utils import enable_wal_mode

Expand All @@ -51,20 +51,27 @@

class CacheIndexEntry(NamedTuple):
cache_size: int
file_compression_alg: str | None
content_encoding: str | None

def export_headers(self, file_sha256: str) -> CIMultiDict[str]:
res: CIMultiDict[str] = CIMultiDict()
if self.content_encoding:
res[HEADER_CONTENT_ENCODING] = self.content_encoding

if file_sha256 and not file_sha256.startswith(cfg.URL_BASED_HASH_PREFIX):
res[HEADER_OTA_FILE_CACHE_CONTROL] = export_kwargs_as_header_string(
file_sha256=file_sha256,
file_compression_alg=self.file_compression_alg or "",
)
return res
headers: CIMultiDictProxy[str]
"""Pre-computed response headers for this cache entry."""


def _build_index_entry_headers(
    file_sha256: str,  # NOTE: already interned by the caller
    *,
    content_encoding: str | None,
    file_compression_alg: str | None,
) -> CIMultiDictProxy[str]:
    """Build the response headers CIMultiDictProxy for a cache index entry.

    The returned proxy is read-only, so it can be pre-computed once per
    cache entry and shared across requests without copying.

    Args:
        file_sha256: sha256 value of the cached file. Falsy values and
            URL-based pseudo hashes (prefixed with cfg.URL_BASED_HASH_PREFIX)
            suppress the ota-file-cache-control header.
        content_encoding: value for the content-encoding header, omitted
            when empty/None.
        file_compression_alg: compression alg embedded into the
            ota-file-cache-control header string, empty when not set.

    Returns:
        A read-only CIMultiDictProxy over the computed headers.
    """
    res: CIMultiDict[str] = CIMultiDict()
    if content_encoding:
        # intern: many entries share the same small set of encoding strings
        res[HEADER_CONTENT_ENCODING] = sys.intern(content_encoding)

    if file_sha256 and not file_sha256.startswith(cfg.URL_BASED_HASH_PREFIX):
        res[HEADER_OTA_FILE_CACHE_CONTROL] = export_kwargs_as_header_string(
            file_sha256=file_sha256,
            # `or ""` is the idiomatic spelling of `x if x else ""`
            file_compression_alg=file_compression_alg or "",
        )
    return CIMultiDictProxy(res)


_STOP_SENTINEL = typing.cast("CacheMeta", object())
Expand Down Expand Up @@ -279,12 +286,11 @@ def _preload_from_db(self) -> tuple[dict[str, CacheIndexEntry], list[str]]:

_res[_key] = CacheIndexEntry(
cache_size=row.cache_size,
file_compression_alg=sys.intern(row.file_compression_alg)
if row.file_compression_alg
else None,
content_encoding=sys.intern(row.content_encoding)
if row.content_encoding
else None,
headers=_build_index_entry_headers(
_key,
content_encoding=row.content_encoding,
file_compression_alg=row.file_compression_alg,
),
)
_count += 1
else:
Expand Down Expand Up @@ -319,12 +325,11 @@ def commit_entry(self, entry: CacheMeta) -> bool:
key = sys.intern(entry.file_sha256)
index_entry = CacheIndexEntry(
cache_size=entry.cache_size,
file_compression_alg=sys.intern(entry.file_compression_alg)
if entry.file_compression_alg
else None,
content_encoding=sys.intern(entry.content_encoding)
if entry.content_encoding
else None,
headers=_build_index_entry_headers(
key,
content_encoding=entry.content_encoding,
file_compression_alg=entry.file_compression_alg,
),
)

if len(self._index) >= cfg.MAX_INDEX_ENTRIES:
Expand Down
6 changes: 3 additions & 3 deletions src/ota_proxy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pathlib import Path
from typing import Optional

from multidict import CIMultiDict
from multidict import CIMultiDict, CIMultiDictProxy
from simple_sqlite3_orm import (
ConstrainRepr,
CreateTableParams,
Expand Down Expand Up @@ -69,7 +69,7 @@ def __hash__(self) -> int:
tuple(getattr(self, attrn) for attrn in self.__class__.model_fields)
)

def export_headers_to_client(self) -> CIMultiDict[str]:
def export_headers_to_client(self) -> CIMultiDictProxy[str]:
"""Export required headers for client.

Currently includes content-type, content-encoding and ota-file-cache-control headers.
Expand All @@ -86,7 +86,7 @@ def export_headers_to_client(self) -> CIMultiDict[str]:
file_sha256=self.file_sha256,
file_compression_alg=self.file_compression_alg or "",
)
return res
return CIMultiDictProxy(res)


class CacheMetaORM(ORMBase[CacheMeta]):
Expand Down
15 changes: 5 additions & 10 deletions src/ota_proxy/ota_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ async def _retrieve_file_by_downloading(

async def _retrieve_file_by_cache_lookup(
self, *, raw_url: str, cache_policy: OTAFileCacheControl
) -> tuple[AsyncGenerator[bytes] | bytes, CIMultiDict[str]] | None:
) -> tuple[AsyncGenerator[bytes] | bytes, CIMultiDictProxy[str]] | None:
"""
Raises:
ReaderPoolBusy if exceeding max pending read tasks.
Expand Down Expand Up @@ -410,7 +410,7 @@ async def _retrieve_file_by_cache_lookup(
# NOTE: handle empty file entry, for empty file entry, we will not actually
# create empty file in cache folder.
if index_entry.cache_size == 0:
return b"", index_entry.export_headers(cache_identifier)
return b"", index_entry.headers

# NOTE: db_entry.file_sha256 can be either
# 1. valid sha256 value for corresponding plain uncompressed OTA file
Expand Down Expand Up @@ -440,21 +440,19 @@ async def _retrieve_file_by_cache_lookup(
# NOTE(20260403): not do the cleanup at here, let the new cache handler do it.
return

_headers = index_entry.export_headers(cache_identifier)

# fast path for small file, read one and directly return bytes
if index_entry.cache_size <= self._chunk_size:
return (
await self._read_pool.read_file_once(cache_file),
_headers,
index_entry.headers,
)

local_fd = await self._read_pool.stream_read_file(cache_file)
# NOTE: we don't verify the cache here even cache is old, but let otaclient's hash verification
# do the job. If cache is invalid, otaclient will use CacheControlHeader's retry_cache
# directory to indicate invalid cache.

return local_fd, _headers
return local_fd, index_entry.headers

async def _retrieve_file_by_external_cache(
self,
Expand Down Expand Up @@ -503,10 +501,7 @@ async def _retrieve_file_by_new_caching(
raw_url: str,
cache_policy: OTAFileCacheControl,
headers_from_client: CIMultiDict[str],
) -> (
tuple[AsyncGenerator[bytes] | bytes, CIMultiDictProxy[str] | CIMultiDict[str]]
| None
):
) -> tuple[AsyncGenerator[bytes] | bytes, CIMultiDictProxy[str]] | None:
"""
Raises:
ReaderPoolBusy if exceeding max pending read tasks.
Expand Down
14 changes: 9 additions & 5 deletions src/ota_proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,16 @@ async def read_file(


def read_file_once(fpath: StrOrPath | anyio.Path) -> bytes:
"""Read the whole file with once call.

This function is to serve small files read.

NOTE(20260420): for small files read, it increases the kernel
page cache pages with much slower speed and much
small amount, so let kernel handles the cache pages.
"""
with open(fpath, "rb") as f:
fd = f.fileno()
os.posix_fadvise(fd, 0, 0, os.POSIX_FADV_SEQUENTIAL)
data = f.read()
os.posix_fadvise(fd, 0, 0, os.POSIX_FADV_DONTNEED)
return data
return f.read()


def url_based_hash(raw_url: str) -> str:
Expand Down
90 changes: 89 additions & 1 deletion test/unit/test_otaproxy/test_cache_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@
import time

import pytest
from multidict import CIMultiDictProxy

from ota_proxy.cache_index import CacheDBWriter
from ota_proxy._consts import HEADER_CONTENT_ENCODING, HEADER_OTA_FILE_CACHE_CONTROL
from ota_proxy.cache_index import CacheDBWriter, _build_index_entry_headers
from ota_proxy.config import config as cfg
from ota_proxy.db import CacheMeta, init_db
from ota_proxy.utils import url_based_hash

# A syntactically valid sha256 value: 64 hex characters, no URL-based prefix.
_REAL_SHA256 = "a" * 64
# Pseudo hash derived from a URL; presumably carries cfg.URL_BASED_HASH_PREFIX
# so cache-control headers are suppressed for it — see the tests below.
_URL_HASH = url_based_hash("http://example.com/f")

# The configured bucket size thresholds (keys of cfg.BUCKET_FILE_SIZE_DICT).
BUCKET_SIZE_LIST = list(cfg.BUCKET_FILE_SIZE_DICT)


Expand Down Expand Up @@ -93,3 +98,86 @@ def test_bucket_idx_matches_old_lru(

assert entry.bucket_idx == expected_idx
assert entry.bucket_idx == _expected_bucket_idx(cache_size)


#
# ------------ Tests: _build_index_entry_headers ------------ #
#


class TestBuildIndexEntryHeaders:
    """Verify the headers pre-computed by _build_index_entry_headers.

    They must match the legacy export_headers output: content-encoding is
    present only when set, ota-file-cache-control is present only for real
    sha256 hashes (never URL-based ones), and the result is a read-only
    CIMultiDictProxy safe to share across requests.
    """

    def test_returns_cimultidictproxy(self):
        headers = _build_index_entry_headers(
            _REAL_SHA256, content_encoding=None, file_compression_alg=None
        )
        assert isinstance(headers, CIMultiDictProxy)

    def test_real_sha256_without_compression_sets_cache_control_only(self):
        headers = _build_index_entry_headers(
            _REAL_SHA256, content_encoding=None, file_compression_alg=None
        )
        assert HEADER_CONTENT_ENCODING not in headers
        assert f"file_sha256={_REAL_SHA256}" in headers[HEADER_OTA_FILE_CACHE_CONTROL]

    def test_real_sha256_with_compression_embeds_alg_in_cache_control(self):
        headers = _build_index_entry_headers(
            _REAL_SHA256, content_encoding="zstd", file_compression_alg="zst"
        )
        assert headers[HEADER_CONTENT_ENCODING] == "zstd"
        ctrl = headers[HEADER_OTA_FILE_CACHE_CONTROL]
        assert f"file_sha256={_REAL_SHA256}" in ctrl
        assert "file_compression_alg=zst" in ctrl

    def test_url_based_hash_omits_cache_control(self):
        """URL-based hash entries must not expose ota-file-cache-control."""
        headers = _build_index_entry_headers(
            _URL_HASH, content_encoding="gzip", file_compression_alg=None
        )
        assert HEADER_OTA_FILE_CACHE_CONTROL not in headers
        assert headers[HEADER_CONTENT_ENCODING] == "gzip"

    def test_empty_file_sha256_omits_cache_control(self):
        headers = _build_index_entry_headers(
            "", content_encoding=None, file_compression_alg=None
        )
        assert HEADER_OTA_FILE_CACHE_CONTROL not in headers
        assert HEADER_CONTENT_ENCODING not in headers


#
# ------------ Tests: CacheMeta.export_headers_to_client ------------ #
#


class TestCacheMetaExportHeaders:
    """Check CacheMeta.export_headers_to_client.

    It must emit the same headers as _build_index_entry_headers and must
    return a read-only CIMultiDictProxy, matching the return-type contract
    declared in ota_cache.py.
    """

    def test_returns_cimultidictproxy(self):
        meta = CacheMeta(file_sha256=_REAL_SHA256, url="http://example.com/f")
        headers = meta.export_headers_to_client()
        assert isinstance(headers, CIMultiDictProxy)

    def test_url_based_hash_has_no_cache_control(self):
        meta = CacheMeta(
            file_sha256=_URL_HASH, url="http://example.com/f", content_encoding="gzip"
        )
        headers = meta.export_headers_to_client()
        assert HEADER_OTA_FILE_CACHE_CONTROL not in headers
        assert headers[HEADER_CONTENT_ENCODING] == "gzip"

    def test_real_sha256_emits_cache_control_with_compression(self):
        meta = CacheMeta(
            file_sha256=_REAL_SHA256,
            url="http://example.com/f",
            file_compression_alg="zst",
        )
        ctrl = meta.export_headers_to_client()[HEADER_OTA_FILE_CACHE_CONTROL]
        assert f"file_sha256={_REAL_SHA256}" in ctrl
        assert "file_compression_alg=zst" in ctrl
57 changes: 57 additions & 0 deletions test/unit/test_otaproxy/test_ota_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from multidict import CIMultiDict, CIMultiDictProxy

from ota_proxy import config as cfg
from ota_proxy._consts import HEADER_CONTENT_ENCODING, HEADER_OTA_FILE_CACHE_CONTROL
from ota_proxy.cache_index import CacheIndex
from ota_proxy.db import CacheMeta
from ota_proxy.ota_cache import OTACache
Expand Down Expand Up @@ -119,6 +120,62 @@ def test_remove_entry(
assert idx.lookup_entry(entry.file_sha256) is None


class TestCacheIndexEntryHeaders:
    """commit_entry pre-computes CacheIndexEntry.headers, letting
    ota_cache.py hand them back per request without rebuilding them."""

    def test_committed_entry_exposes_precomputed_headers(
        self, tmp_path_factory: pytest.TempPathFactory
    ):
        cache_dir = tmp_path_factory.mktemp("cache_index_headers") / "ota-cache"
        cache_dir.mkdir()
        cache_index = CacheIndex(cache_dir / "db_f", cache_dir, force_init_db=True)
        try:
            file_sha = "b" * 64
            entry_meta = CacheMeta(
                file_sha256=file_sha,
                url="http://example.com/compressed",
                cache_size=1024,
                file_compression_alg="zst",
                content_encoding="zstd",
            )
            # the cache file must exist on disk for commit to succeed
            (cache_dir / file_sha).touch()
            assert cache_index.commit_entry(entry_meta)

            entry = cache_index.lookup_entry(file_sha)
            assert entry is not None
            assert entry.cache_size == 1024
            assert entry.headers[HEADER_CONTENT_ENCODING] == "zstd"
            ctrl = entry.headers[HEADER_OTA_FILE_CACHE_CONTROL]
            assert f"file_sha256={file_sha}" in ctrl
            assert "file_compression_alg=zst" in ctrl
        finally:
            cache_index.close()

    def test_url_based_hash_entry_headers_omit_cache_control(
        self, tmp_path_factory: pytest.TempPathFactory
    ):
        """No known file sha for URL-based hash entries, so the
        ota-file-cache-control header must not be emitted."""
        cache_dir = tmp_path_factory.mktemp("cache_index_url_headers") / "ota-cache"
        cache_dir.mkdir()
        cache_index = CacheIndex(cache_dir / "db_f", cache_dir, force_init_db=True)
        try:
            raw_url = "http://example.com/opaque"
            entry_key = url_based_hash(raw_url)
            (cache_dir / entry_key).touch()
            entry_meta = CacheMeta(file_sha256=entry_key, url=raw_url, cache_size=32)
            assert cache_index.commit_entry(entry_meta)

            entry = cache_index.lookup_entry(entry_key)
            assert entry is not None
            assert HEADER_OTA_FILE_CACHE_CONTROL not in entry.headers
        finally:
            cache_index.close()


#
# ------------ Tests: OTACache CDN cache-hit detection ------------ #
#
Expand Down
Loading