Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/karapace/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class Config(BaseSettings):
registry_authfile: str | None = None
rest_authorization: bool = False
rest_base_uri: str | None = None
rest_avro_extended_json_parser: bool = False
log_handler: str | None = "stdout"
log_level: str = "DEBUG"
log_format: str = "%(name)-20s\t%(threadName)s\t%(levelname)-8s\t%(message)s"
Expand Down
167 changes: 164 additions & 3 deletions src/karapace/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,24 @@
import asyncio
import avro
import avro.schema
import base64
import datetime
import decimal
import io
import re
import struct

START_BYTE = 0x0
HEADER_FORMAT = ">bI"
HEADER_SIZE = 5

_EPOCH_DATETIME = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
_EPOCH_DATE = datetime.date(1970, 1, 1)
_MILLIS_PER_DAY = 86_400_000
_MICROS_PER_DAY = 86_400_000_000
_DECIMAL_TEN = decimal.Decimal(10)
_DECIMAL_STRING_RE = re.compile(r"-?\d+(\.\d+)?([Ee][+-]?\d+)?")


class DeserializationError(Exception):
pass
Expand Down Expand Up @@ -450,6 +461,150 @@ def get_name(obj) -> str:
return value


def convert_logical_types(schema: avro.schema.Schema, value: Any, extended_json_parser: bool = False) -> Any:
    """Recursively coerce JSON-friendly Avro values to logical Python types.
    https://avro.apache.org/docs/++version++/specification/#logical-types

    The function traverses records, arrays, maps, and unions, converting values
    for known logical types:

    - timestamp-millis / timestamp-micros:
        int (ms/µs since epoch) -> timezone-aware UTC datetime.datetime.
        str ISO 8601 (extended_json_parser only) -> UTC datetime.datetime;
        timezone-aware strings are shifted to UTC, naive strings are assumed UTC.
    - date:
        int (days since epoch) -> datetime.date.
        str ISO 8601 (extended_json_parser only) -> datetime.date.
    - time-millis / time-micros:
        int (ms/µs of day, wrapped modulo one day) -> datetime.time.
        str ISO 8601 (extended_json_parser only) -> datetime.time.
    - decimal:
        int (both modes) -> decimal.Decimal quantized to schema scale (ROUND_HALF_UP).
        extended_json_parser=True: numeric strings ("123.45", "-7")
        -> decimal.Decimal quantized to schema scale (ROUND_HALF_UP).
        extended_json_parser=False (default): Confluent-compatible Base64-encoded
        two's complement unscaled bytes (e.g. "BZw=" for 14.36 at scale=2)
        -> decimal.Decimal.
        float inputs are intentionally not accepted to avoid silent precision loss.

    bool values are never coerced even though bool is a subclass of int; this keeps
    booleans intact when a union mixes a boolean branch with a logical type.

    Args:
        schema: The Avro schema for the current node.
        value: The JSON-decoded value to coerce.
        extended_json_parser: When True, temporal fields additionally accept ISO 8601
            strings and decimal fields accept numeric strings. Defaults to False
            (Confluent-compatible behaviour).

    For unions, each branch is tried in order; the first branch that validates after
    conversion is returned. If conversion is not applicable or fails, the original
    value is returned unchanged.
    """

    def _is_plain_int(v: Any) -> bool:
        # bool subclasses int; exclude it so True/False are never coerced.
        return isinstance(v, int) and not isinstance(v, bool)

    def _iso_to_utc_datetime(text: str) -> datetime.datetime | str:
        # ISO 8601 -> aware UTC datetime; naive inputs are assumed to be UTC.
        # Returns the input unchanged when it is not a valid ISO 8601 string.
        try:
            parsed = datetime.datetime.fromisoformat(text)
        except ValueError:
            return text
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=datetime.timezone.utc)
        return parsed.astimezone(datetime.timezone.utc)

    def _time_of_day(total_micros: int) -> datetime.time:
        # Wrap into one day, then split into h/m/s plus microsecond remainder.
        total_micros %= _MICROS_PER_DAY
        seconds, micros = divmod(total_micros, 1_000_000)
        hours, rem = divmod(seconds, 3600)
        minutes, seconds = divmod(rem, 60)
        return datetime.time(hour=hours, minute=minutes, second=seconds, microsecond=micros)

    if isinstance(schema, avro.schema.RecordSchema) and isinstance(value, dict):
        result: dict[Any, Any] = dict(value)
        for field in schema.fields:
            if field.name in value:
                result[field.name] = convert_logical_types(field.type, value[field.name], extended_json_parser)
        return result

    if isinstance(schema, avro.schema.UnionSchema):
        # Try to find a branch schema that validates after conversion.
        for branch in schema.schemas:
            converted = convert_logical_types(branch, value, extended_json_parser)
            if avro.io.validate(branch, converted):
                return converted
        return value

    if isinstance(schema, avro.schema.ArraySchema) and isinstance(value, list):
        return [convert_logical_types(schema.items, v, extended_json_parser) for v in value]

    if isinstance(schema, avro.schema.MapSchema) and isinstance(value, dict):
        return {k: convert_logical_types(schema.values, v, extended_json_parser) for (k, v) in value.items()}

    if isinstance(schema, avro.schema.LogicalSchema):
        logical_type = getattr(schema, "logical_type", None)

        # Timestamps
        if logical_type in ("timestamp-millis", "timestamp-micros"):
            if _is_plain_int(value):
                delta = (
                    datetime.timedelta(milliseconds=value)
                    if logical_type == "timestamp-millis"
                    else datetime.timedelta(microseconds=value)
                )
                return _EPOCH_DATETIME + delta
            if extended_json_parser and isinstance(value, str):
                return _iso_to_utc_datetime(value)
            return value

        # Date
        if logical_type == "date":
            if _is_plain_int(value):
                return _EPOCH_DATE + datetime.timedelta(days=value)
            if extended_json_parser and isinstance(value, str):
                try:
                    return datetime.date.fromisoformat(value)
                except ValueError:
                    return value
            return value

        # Time
        if logical_type in ("time-millis", "time-micros"):
            if _is_plain_int(value):
                return _time_of_day(value * 1000 if logical_type == "time-millis" else value)
            if extended_json_parser and isinstance(value, str):
                try:
                    return datetime.time.fromisoformat(value)
                except ValueError:
                    return value
            return value

        # Decimal: accept ints (both modes), numeric strings (extended mode), or
        # Confluent-style base64-encoded two's complement unscaled bytes
        # (e.g. "BZw=" for 14.36 at scale=2).
        if logical_type == "decimal" and (_is_plain_int(value) or isinstance(value, str)):
            scale: int = getattr(schema, "scale", 0)
            quantum = _DECIMAL_TEN**-scale
            # Numeric path: plain ints always carry value semantics; decimal-looking
            # strings ("123.45", "-7") do so only in extended mode.
            if _is_plain_int(value) or (extended_json_parser and _DECIMAL_STRING_RE.fullmatch(value)):
                try:
                    return decimal.Decimal(str(value)).quantize(quantum, rounding=decimal.ROUND_HALF_UP)
                except (decimal.InvalidOperation, ValueError):
                    return value
            # Only strings reach this point: try the Confluent base64 bytes encoding.
            try:
                raw = base64.b64decode(value, validate=True)
            except ValueError:  # binascii.Error subclasses ValueError
                return value
            unscaled = int.from_bytes(raw, byteorder="big", signed=True)
            return decimal.Decimal(unscaled).scaleb(-scale)

    return value


def read_value(config: Config, schema: TypedSchema, bio: io.BytesIO):
if schema.schema_type is SchemaType.AVRO:
reader = DatumReader(writers_schema=schema.schema)
Expand All @@ -475,10 +630,16 @@ def read_value(config: Config, schema: TypedSchema, bio: io.BytesIO):
def write_value(config: Config, schema: TypedSchema, bio: io.BytesIO, value: dict) -> None:
if schema.schema_type is SchemaType.AVRO:
# Backwards compatibility: Support JSON encoded data without the tags for unions.
if avro.io.validate(schema.schema, value):
data = value
# First, try to convert logical types on the original value. If the resulting
# value validates against the schema, use it as-is to preserve backwards
# compatibility with existing union encodings. Otherwise, fall back to
# flattening unions and then converting logical types.
converted = convert_logical_types(schema.schema, value, config.rest_avro_extended_json_parser)
if avro.io.validate(schema.schema, converted):
Comment thread
e11it marked this conversation as resolved.
data = converted
else:
data = flatten_unions(schema.schema, value)
flattened = flatten_unions(schema.schema, value)
data = convert_logical_types(schema.schema, flattened, config.rest_avro_extended_json_parser)

writer = DatumWriter(writers_schema=schema.schema)
writer.write(data, BinaryEncoder(bio))
Expand Down
64 changes: 59 additions & 5 deletions src/karapace/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from aiohttp.web_request import BaseRequest
from aiohttp.web_response import StreamResponse
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
import datetime as _dt_module
from datetime import date, datetime, timedelta, timezone
from decimal import Decimal
from pathlib import Path
from types import MappingProxyType
Expand Down Expand Up @@ -42,7 +43,7 @@ def loads(s: bytes | str):
@staticmethod
def dumps(obj, *, default=None, indent=None, sort_keys=False, separators=None, **kwargs):
"""Dump object to JSON string (returns str for compatibility)."""
options = 0
options = orjson.OPT_PASSTHROUGH_DATETIME
if sort_keys:
options |= orjson.OPT_SORT_KEYS
if indent is not None:
Expand All @@ -63,7 +64,7 @@ def load(fp):
@staticmethod
def dump(obj, fp, *, default=None, indent=None, sort_keys=False, separators=None, **kwargs):
"""Dump object to JSON file."""
options = 0
options = orjson.OPT_PASSTHROUGH_DATETIME
if sort_keys:
options |= orjson.OPT_SORT_KEYS
if indent is not None:
Expand All @@ -76,7 +77,48 @@ def dump(obj, fp, *, default=None, indent=None, sort_keys=False, separators=None
elif importlib.util.find_spec("ujson"):
from ujson import JSONDecodeError # noqa: F401

import ujson as json
import ujson as _ujson

class _JsonModule: # type: ignore[no-redef]
"""Wrapper around ujson that routes datetime/date/time through default_json_serialization.

ujson serialises datetime objects natively as "+00:00" offset strings, bypassing
the ``default`` callback (which is only invoked for *unknown* types). To keep
the output format consistent with the stdlib-json backend (which calls
``_isoformat`` and produces "Z"), we recursively replace datetime/date/time
objects before handing the value to ujson.
"""

@staticmethod
def _preprocess(obj):
"""Recursively convert datetime/date/time so ujson never sees them natively."""
if isinstance(obj, datetime):
return _isoformat(obj)
if isinstance(obj, (date, _dt_module.time)):
return obj.isoformat()
if isinstance(obj, dict):
return {k: _JsonModule._preprocess(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [_JsonModule._preprocess(v) for v in obj]
return obj

@staticmethod
def loads(s):
return _ujson.loads(s)

@staticmethod
def dumps(obj, *, default=None, indent=None, sort_keys=False, separators=None, **kwargs):
return _ujson.dumps(_JsonModule._preprocess(obj), default=default, indent=indent or 0, sort_keys=sort_keys)

@staticmethod
def load(fp):
return _ujson.load(fp)

@staticmethod
def dump(obj, fp, *, default=None, indent=None, sort_keys=False, separators=None, **kwargs):
return _ujson.dump(_JsonModule._preprocess(obj), fp, default=default, indent=indent or 0, sort_keys=sort_keys)

json = _JsonModule()
else:
from json import JSONDecodeError # noqa: F401

Expand Down Expand Up @@ -109,19 +151,31 @@ def default_json_serialization(obj: timedelta) -> float: ...
def default_json_serialization(obj: Decimal) -> str: ...


@overload
def default_json_serialization(obj: date) -> str: ...


@overload
def default_json_serialization(obj: _dt_module.time) -> str: ...


@overload
def default_json_serialization(obj: MappingProxyType) -> dict: ...


def default_json_serialization(
    obj: datetime | timedelta | Decimal | date | _dt_module.time | MappingProxyType,
) -> str | float | dict:
    """Serialize values the JSON encoder cannot handle natively.

    Used as the ``default=`` hook of the JSON backends in this module. Produces:
    datetime -> string via the module helper ``_isoformat``; timedelta -> total
    seconds as float; Decimal -> str; date/time -> ISO 8601 string via
    ``isoformat()``; MappingProxyType -> plain dict.
    """
    # NOTE: datetime is checked before date on purpose — datetime subclasses date,
    # so reversing the order would route datetimes through date.isoformat().
    if isinstance(obj, datetime):
        return _isoformat(obj)
    if isinstance(obj, timedelta):
        return obj.total_seconds()
    if isinstance(obj, Decimal):
        return str(obj)
    if isinstance(obj, date):
        return obj.isoformat()
    if isinstance(obj, _dt_module.time):
        return obj.isoformat()
    if isinstance(obj, MappingProxyType):
        return dict(obj)

Expand Down
67 changes: 67 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,73 @@ async def get_client(**kwargs) -> TestClient:
await client.close()


@pytest.fixture(scope="function", name="rest_async_extended_parser")
async def fixture_rest_async_extended_parser(
    request: SubRequest,
    loop: asyncio.AbstractEventLoop,
    kafka_servers: KafkaServers,
    registry_async_client: Client,
) -> AsyncIterator[KafkaRest | None]:
    """Yield a KafkaRest instance with ``rest_avro_extended_json_parser`` enabled.

    Yields ``None`` when an external REST service was supplied via ``--rest-url``.
    """
    # Do not start a REST api when the user provided an external service. Doing
    # so would cause this node to join the existing group and participate in
    # the election process. Without proper configuration for the listeners that
    # won't work and will cause test failures.
    if request.config.getoption("rest_url"):
        yield None
        return

    rest_config = Config()
    rest_config.admin_metadata_max_age = 2
    rest_config.bootstrap_uri = kafka_servers.bootstrap_servers[0]
    rest_config.producer_max_request_size = REST_PRODUCER_MAX_REQUEST_BYTES
    rest_config.rest_avro_extended_json_parser = True
    rest_config.waiting_time_before_acting_as_master_ms = 500

    kafka_rest = KafkaRest(config=rest_config)
    assert kafka_rest.serializer.registry_client
    kafka_rest.serializer.registry_client.client = registry_async_client

    try:
        yield kafka_rest
    finally:
        await kafka_rest.close()


@pytest.fixture(scope="function", name="rest_async_extended_parser_client")
async def fixture_rest_async_extended_parser_client(
    request: SubRequest,
    loop: asyncio.AbstractEventLoop,
    rest_async_extended_parser: KafkaRest,
    aiohttp_client: AiohttpClient,
) -> AsyncIterator[Client]:
    """Yield an HTTP client bound to the extended-parser REST fixture.

    Talks to the external service when ``--rest-url`` is given, otherwise to an
    in-process aiohttp test client wrapping ``rest_async_extended_parser``.
    """
    rest_url = request.config.getoption("rest_url")

    # client and server_uri are incompatible settings.
    if not rest_url:

        async def _spawn_test_client(**kwargs) -> TestClient:
            return await aiohttp_client(rest_async_extended_parser.app)

        client = Client(client_factory=_spawn_test_client)
    else:
        client = Client(server_uri=rest_url)

    try:
        # wait until the server is listening, otherwise the tests may fail
        await repeat_until_successful_request(
            client.get,
            "brokers",
            json_data=None,
            headers=None,
            error_msg="REST API is unreachable",
            timeout=10,
            sleep=0.3,
        )
        yield client
    finally:
        await client.close()


@pytest.fixture(scope="function", name="rest_async_registry_auth")
async def fixture_rest_async_registry_auth(
request: SubRequest,
Expand Down
Loading
Loading