Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 179 additions & 0 deletions areal/infra/rpc/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,19 @@
import torch
from pydantic import BaseModel, Field

try:
from PIL import Image
from PIL.Image import Image as ImageObject
except ImportError: # pragma: no cover - optional dependency for non-VLM setups
Image = None
ImageObject = None

from areal.utils import logging

TOKENIZER_ARCHIVE_INLINE_THRESHOLD = 512 * 1024
TOKENIZER_ZSTD_THRESHOLD = 20 * 1024 * 1024
TokenizerCompression = Literal["zip", "zstd"]
ProcessorCompression = Literal["zip", "zstd"]

logger = logging.getLogger("RPCSerialization")

Expand Down Expand Up @@ -207,6 +215,37 @@ def to_array(self) -> np.ndarray:
return array.reshape(self.shape)


class SerializedPILImage(BaseModel):
    """Pydantic model for serialized PIL images.

    Pixel data is stored as a base64-encoded PNG. ``mode`` records the
    original PIL mode so it can be restored on deserialization, including
    for modes the PNG encoder cannot write directly.
    """

    # Discriminator checked by deserialize_value().
    type: Literal["pil_image"] = Field(default="pil_image")
    # Base64-encoded PNG bytes.
    data: str
    # Original PIL mode of the source image (e.g. "RGB", "L"); restored on load.
    mode: str | None = None

    @classmethod
    def from_image(cls, image: "ImageObject") -> "SerializedPILImage":
        """Serialize ``image`` into a base64-encoded PNG payload.

        Images whose mode the PNG encoder rejects (CMYK, YCbCr, LAB, HSV, F)
        are converted to RGB before saving; ``to_image`` converts them back
        using the recorded ``mode``. The round-trip for those modes is
        approximate (colorspace conversion) rather than bit-exact, but the
        save no longer raises ``OSError``.
        """
        original_mode = image.mode
        to_save = image
        if original_mode in ("CMYK", "YCbCr", "LAB", "HSV", "F"):
            # PNG cannot encode these modes; fall back to RGB for transport.
            to_save = image.convert("RGB")
        with io.BytesIO() as buffer:
            # Always use PNG to avoid format-specific save issues
            to_save.save(buffer, format="PNG")
            data_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return cls(data=data_b64, mode=original_mode)

    def to_image(self) -> "ImageObject":
        """Decode the payload back into a PIL image.

        Raises
        ------
        RuntimeError
            If Pillow is not installed.
        """
        if Image is None:  # pragma: no cover - depends on optional pillow dependency
            raise RuntimeError(
                "Pillow is required to deserialize PIL images but is not installed"
            )

        with io.BytesIO(base64.b64decode(self.data.encode("utf-8"))) as buffer:
            image = Image.open(buffer)
            # Force pixel data into memory before the backing buffer closes.
            image.load()

        if self.mode is not None and image.mode != self.mode:
            # Restore the original mode (also undoes the RGB fallback that
            # from_image applies to PNG-incompatible modes).
            image = image.convert(self.mode)

        return image


class SerializedDataclass(BaseModel):
"""Pydantic model for serialized dataclass with metadata.

Expand Down Expand Up @@ -380,6 +419,115 @@ def _maybe_decompress(self, blob: bytes) -> bytes:
raise ValueError(msg)


class SerializedProcessor(BaseModel):
    """Pydantic model for serialized Hugging Face processors.

    Attributes
    ----------
    type : str
        Type marker, always "processor"
    name_or_path : str
        Original ``name_or_path`` attribute captured from the processor
    data : str
        Base64-encoded ZIP (optionally Zstandard-compressed) archive of the processor files
    compression : {"zip", "zstd"}
        Compression algorithm applied to the archive payload
    """

    type: Literal["processor"] = Field(default="processor")
    name_or_path: str
    data: str
    compression: ProcessorCompression = Field(default="zip")

    @classmethod
    def from_processor(cls, processor: Any) -> "SerializedProcessor":
        """Create a serialized representation from a Hugging Face processor."""
        name_or_path = getattr(processor, "name_or_path", None)
        if name_or_path is None:
            # Some processors store name_or_path on their inner tokenizer
            inner_tokenizer = getattr(processor, "tokenizer", None)
            name_or_path = getattr(
                inner_tokenizer, "name_or_path", processor.__class__.__name__
            )
        payload, compression = cls._maybe_compress(cls._archive_processor(processor))
        return cls(
            name_or_path=name_or_path,
            data=base64.b64encode(payload).decode("utf-8"),
            compression=compression,
        )

    def to_processor(self) -> Any:
        """Reconstruct a Hugging Face processor from serialized data."""
        from transformers import AutoProcessor

        payload = self._maybe_decompress(base64.b64decode(self.data.encode("utf-8")))
        with tempfile.TemporaryDirectory() as extract_dir:
            with zipfile.ZipFile(io.BytesIO(payload)) as archive:
                archive.extractall(extract_dir)
            processor = AutoProcessor.from_pretrained(extract_dir)

        if hasattr(processor, "name_or_path"):
            # Preserve the original identifier rather than the temp dir path.
            processor.name_or_path = self.name_or_path
        return processor

    @staticmethod
    def _is_processor(obj: Any) -> bool:
        # transformers is an optional dependency; absence means "not a processor".
        try:
            from transformers import ProcessorMixin
        except ImportError:  # pragma: no cover - optional dependency
            return False
        return isinstance(obj, ProcessorMixin)

    @staticmethod
    def _archive_processor(processor: Any) -> bytes:
        """Export the processor to disk and pack the files into a ZIP archive."""
        archive_buffer = io.BytesIO()
        with tempfile.TemporaryDirectory() as export_dir:
            processor.save_pretrained(export_dir)
            file_paths = [
                os.path.join(root, name)
                for root, _, names in os.walk(export_dir)
                for name in names
            ]
            total_size = sum(os.path.getsize(path) for path in file_paths)
            # Small archives are stored uncompressed (cheap to inline);
            # larger ones use DEFLATE at a moderate level.
            if total_size < TOKENIZER_ARCHIVE_INLINE_THRESHOLD:
                zip_kwargs = {"compression": zipfile.ZIP_STORED}
            else:
                zip_kwargs = {"compression": zipfile.ZIP_DEFLATED, "compresslevel": 6}
            with zipfile.ZipFile(archive_buffer, "w", **zip_kwargs) as archive:
                for path in file_paths:
                    archive.write(path, arcname=os.path.relpath(path, export_dir))
        return archive_buffer.getvalue()

    @staticmethod
    def _maybe_compress(blob: bytes) -> tuple[bytes, ProcessorCompression]:
        # Only reach for zstd on large payloads, and only if the optional
        # zstandard package is importable.
        is_large = len(blob) > TOKENIZER_ZSTD_THRESHOLD
        if is_large and importlib.util.find_spec("zstandard") is not None:
            import zstandard as zstd

            return zstd.ZstdCompressor(level=3).compress(blob), "zstd"
        return blob, "zip"

    def _maybe_decompress(self, blob: bytes) -> bytes:
        """Undo the payload compression recorded in ``self.compression``."""
        if self.compression == "zip":
            return blob
        if self.compression == "zstd":
            import zstandard as zstd

            return zstd.ZstdDecompressor().decompress(blob)
        msg = f"Unsupported processor compression: {self.compression}"
        raise ValueError(msg)


def serialize_value(value: Any) -> Any:
"""Recursively serialize a value, converting tensors and dataclasses to serialized dicts.

Expand All @@ -388,6 +536,7 @@ def serialize_value(value: Any) -> Any:
- numpy.ndarray -> SerializedNDArray dict
- dataclass instances -> SerializedDataclass dict (preserves type information)
- Hugging Face tokenizers -> SerializedTokenizer dict
- Hugging Face processors -> SerializedProcessor dict
- dict -> recursively serialize values
- list/tuple -> recursively serialize elements
- primitives (int, float, str, bool, None) -> unchanged
Expand All @@ -414,6 +563,10 @@ def serialize_value(value: Any) -> Any:
if isinstance(value, np.ndarray):
return SerializedNDArray.from_array(value).model_dump()

# Handle PIL image payloads for VLM tasks
if ImageObject is not None and isinstance(value, ImageObject):
return SerializedPILImage.from_image(value).model_dump()

# Handle dataclass instances (check before dict, as dataclasses can be dict-like)
# Note: is_dataclass returns True for both classes and instances, so check it's not a type
if is_dataclass(value) and not isinstance(value, type):
Expand All @@ -432,6 +585,11 @@ def serialize_value(value: Any) -> Any:
tokenizer_payload = SerializedTokenizer.from_tokenizer(value)
return tokenizer_payload.model_dump()

# Handle Hugging Face processors (e.g. Qwen2_5_VLProcessor)
if SerializedProcessor._is_processor(value):
processor_payload = SerializedProcessor.from_processor(value)
return processor_payload.model_dump()

# Handle dict - recursively serialize values
if isinstance(value, dict):
return {key: serialize_value(val) for key, val in value.items()}
Expand Down Expand Up @@ -460,6 +618,7 @@ def deserialize_value(value: Any) -> Any:
- SerializedNDArray dict -> numpy.ndarray
- SerializedDataclass dict -> dataclass instance (reconstructed with original type)
- SerializedTokenizer dict -> Hugging Face tokenizer
- SerializedProcessor dict -> Hugging Face processor
- dict -> recursively deserialize values
- list -> recursively deserialize elements
- primitives -> unchanged
Expand Down Expand Up @@ -507,6 +666,16 @@ def deserialize_value(value: Any) -> Any:
f"Failed to deserialize tokenizer, treating as regular dict: {e}"
)

# Check for SerializedProcessor marker
if value.get("type") == "processor":
try:
serialized_processor = SerializedProcessor.model_validate(value)
return serialized_processor.to_processor()
except Exception as e:
logger.warning(
f"Failed to deserialize processor, treating as regular dict: {e}"
)

# Check for SerializedNDArray marker
if value.get("type") == "ndarray":
try:
Expand All @@ -517,6 +686,16 @@ def deserialize_value(value: Any) -> Any:
f"Failed to deserialize ndarray, treating as regular dict: {e}"
)

# Check for SerializedPILImage marker
if value.get("type") == "pil_image":
try:
serialized_image = SerializedPILImage.model_validate(value)
return serialized_image.to_image()
except Exception as e:
logger.warning(
f"Failed to deserialize PIL image, treating as regular dict: {e}"
)

# Check for SerializedTensor marker
if value.get("type") == "tensor":
try:
Expand Down
40 changes: 40 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Tests for RPC serialization utilities."""

from dataclasses import dataclass
from io import BytesIO

import numpy as np
import pytest
import torch
from PIL import Image
from transformers import AutoTokenizer

from tests.utils import get_model_path
Expand Down Expand Up @@ -79,6 +81,24 @@ def test_numpy_object_array_rejected(self):
with pytest.raises(ValueError, match="Object or void dtype"):
serialize_value(array)

def test_pil_image_roundtrip(self):
    """Test PIL image serialization for VLM RPC payloads."""
    original = Image.new("RGB", (8, 6), color=(12, 34, 56))

    serialized = serialize_value(original)
    assert serialized["type"] == "pil_image"

    deserialized = deserialize_value(serialized)
    assert isinstance(deserialized, Image.Image)
    assert deserialized.size == original.size
    assert deserialized.mode == original.mode

    # Validate pixel content with a deterministic byte compare
    def _png_bytes(img):
        with BytesIO() as buf:
            img.save(buf, format="PNG")
            return buf.getvalue()

    assert _png_bytes(original) == _png_bytes(deserialized)

def test_dataclass(self):
"""Test dataclass serialization with nested tensors."""
original = SampleData(
Expand Down Expand Up @@ -112,6 +132,26 @@ def test_tokenizer(self):
assert deserialized.vocab_size == original.vocab_size
assert deserialized.encode("test") == original.encode("test")

def test_processor(self):
    """Test Hugging Face processor serialization."""
    from transformers import AutoProcessor

    model_path = get_model_path(
        "/storage/openpsi/models/Qwen__Qwen3-VL-2B-Instruct",
        "Qwen/Qwen3-VL-2B-Instruct",
    )
    original = AutoProcessor.from_pretrained(model_path)

    serialized = serialize_value(original)
    assert serialized["type"] == "processor"

    deserialized = deserialize_value(serialized)
    original_tok = original.tokenizer
    restored_tok = deserialized.tokenizer
    assert restored_tok.vocab_size == original_tok.vocab_size
    assert restored_tok.encode("test") == original_tok.encode("test")

def test_nested_structure(self):
"""Test complex nested structure with multiple types."""
payload = {
Expand Down