Skip to content

Commit cef6d6f

Browse files
authored
Merge pull request #49 from Accenture/version/1.31.1
Update to version 1.31.1
2 parents 3be186d + f5a1161 commit cef6d6f

10 files changed

Lines changed: 784 additions & 28 deletions

File tree

air/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
try:
2727
__version__: str = _metadata.version(__package__ or "airefinery-sdk")
2828
except _metadata.PackageNotFoundError: # pragma: no cover
29-
__version__ = "1.30.0"
29+
__version__ = "1.31.1"
3030

3131
# Decide the default base url
3232
# - Default: api.airefinery.accenture.com (production K8s cluster)
@@ -49,6 +49,10 @@
4949
from air.client import AIRefinery, AsyncAIRefinery # noqa: E402
5050
from air.distiller.client import AsyncDistillerClient # noqa: E402
5151
from air.distiller.realtime_client import AsyncRealtimeDistillerClient # noqa: E402
52+
from air.document_analysis import ( # noqa: E402
53+
AsyncDocumentAnalysisClient,
54+
DocumentAnalysisClient,
55+
)
5256
from air.governance import AsyncGovernanceClient, GovernanceClient # noqa: E402
5357

5458
# Backwards-compatibility alias
@@ -61,9 +65,11 @@
6165
"AsyncAIRefinery",
6266
"AsyncDistillerClient",
6367
"AsyncRealtimeDistillerClient",
68+
"AsyncDocumentAnalysisClient",
6469
"AsyncGovernanceClient",
65-
"GovernanceClient",
70+
"DocumentAnalysisClient",
6671
"DistillerClient",
72+
"GovernanceClient",
6773
# Constants
6874
"BASE_URL",
6975
"CACHE_DIR",

air/client.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
from air.audio import AsyncAudio, Audio
99
from air.auth import TokenProvider
1010
from air.chat import AsyncChatClient, ChatClient
11+
from air.compression import AsyncCompressionClient, CompressionClient
1112
from air.distiller import AsyncDistillerClient, AsyncRealtimeDistillerClient
1213
from air.embeddings import (
1314
AsyncEmbeddingsClient,
1415
EmbeddingsClient,
1516
)
16-
from air.governance import AsyncGovernanceClient, GovernanceClient
1717
from air.fine_tuning import AsyncFineTuningClient, FineTuningClient
18+
from air.governance import AsyncGovernanceClient, GovernanceClient
1819
from air.images import (
1920
AsyncImagesClient,
2021
ImagesClient,
@@ -161,6 +162,12 @@ def __init__(
161162
api_key=self.api_key,
162163
default_headers=self.default_headers,
163164
)
165+
# Provides async compression functionalities
166+
self.compression = AsyncCompressionClient(
167+
base_url=self.base_url,
168+
api_key=self.api_key,
169+
default_headers=self.default_headers,
170+
)
164171

165172
# Provides async knowledge functionalities
166173
self.knowledge = AsyncKnowledgeClient(
@@ -315,6 +322,12 @@ def __init__(
315322
api_key=self.api_key,
316323
default_headers=self.default_headers,
317324
)
325+
# Provides sync compression functionalities
326+
self.compression = CompressionClient(
327+
base_url=self.base_url,
328+
api_key=self.api_key,
329+
default_headers=self.default_headers,
330+
)
318331

319332
# Provides sync knowledge functionalities
320333
self.knowledge = KnowledgeClient(

air/compression/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from air.compression.client import AsyncCompressionClient, CompressionClient

air/compression/client.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
"""
2+
Module providing clients for prompt compression operations.
3+
All responses are validated using Pydantic models.
4+
5+
This module includes:
6+
- `CompressionClient` for synchronous calls.
7+
- `AsyncCompressionClient` for asynchronous calls.
8+
9+
Both clients call the `/v1/compress` endpoint.
10+
Each returned item is validated as a `CompressedPrompt` and wrapped in a `CompressionResponse`.
11+
"""
12+
13+
from typing import List, Optional, Union
14+
15+
import aiohttp
16+
import requests
17+
18+
from air import BASE_URL
19+
from air.auth.token_provider import TokenProvider
20+
from air.types.compression import CompressedPrompt, CompressionResponse
21+
from air.types.constants import DEFAULT_TIMEOUT
22+
from air.utils import get_base_headers, get_base_headers_async
23+
24+
ENDPOINT_COMPRESS = "{base_url}/v1/compress"
25+
26+
27+
class CompressionClient:
    """
    Synchronous client for the prompt compression endpoint.

    Requests are sent with `requests.post`; every item in the JSON reply is
    validated as a `CompressedPrompt` and returned inside a
    `CompressionResponse`.
    """

    def __init__(
        self,
        api_key: str | TokenProvider,
        *,
        base_url: str = BASE_URL,
        default_headers: dict[str, str] | None = None,
    ):
        """
        Store the connection settings used by every request.

        Args:
            api_key (str | TokenProvider): Credential passed to
                `get_base_headers` when building request headers
            base_url (str): Base URL substituted into the endpoint template
            default_headers (dict[str, str] | None): Headers applied to
                every request; a falsy value becomes an empty dict
        """
        self.api_key = api_key
        self.base_url = base_url
        self.default_headers = default_headers if default_headers else {}

    def compress(
        self,
        *,
        context: Union[str, List[str]],
        model: str,
        rate: float = 0.5,
        target_token: int = -1,
        instruction: Optional[str] = None,
        question: Optional[str] = None,
        force_tokens: Optional[List[str]] = None,
        timeout: float | None = None,
        extra_headers: dict[str, str] | None = None,
        **kwargs,
    ) -> CompressionResponse:
        """
        Compresses a prompt synchronously.

        Args:
            context (str | List[str]): Text or list of texts to compress
            model (str): The compression model name
            rate (float): Target compression rate (0.0 to 1.0). Default 0.5
            target_token (int): Explicit target token count (-1 for rate-based). Default -1
            instruction (str | None): Optional instruction for compression context
            question (str | None): Optional question for compression context
            force_tokens (List[str] | None): Tokens to preserve in compressed output
            timeout (float | None): Max time (in seconds) to wait for a response;
                `DEFAULT_TIMEOUT` is used when None
            extra_headers (dict[str, str] | None): Request-specific headers,
                applied after the client's default headers
            **kwargs: Additional compression parameters, sent as extra
                top-level keys of the JSON payload

        Returns:
            CompressionResponse: The parsed response containing compressed prompts

        Raises:
            requests.HTTPError: If the server answers with an error status
                (raised by `raise_for_status`)
        """
        request_timeout = DEFAULT_TIMEOUT if timeout is None else timeout
        url = ENDPOINT_COMPRESS.format(base_url=self.base_url)

        # Always-present fields first, then caller-supplied extras.
        body: dict = {
            "model": model,
            "context": context,
            "rate": rate,
            "target_token": target_token,
        }
        body.update(kwargs)

        # Optional fields are omitted entirely when left as None.
        optional_fields = {
            "instruction": instruction,
            "question": question,
            "force_tokens": force_tokens,
        }
        body.update(
            {key: value for key, value in optional_fields.items() if value is not None}
        )

        # Precedence (lowest to highest): base, client defaults, per-request.
        request_headers = get_base_headers(self.api_key)
        request_headers.update(self.default_headers)
        if extra_headers:
            request_headers.update(extra_headers)

        http_response = requests.post(
            url, json=body, headers=request_headers, timeout=request_timeout
        )
        http_response.raise_for_status()

        raw = http_response.json()
        # Platform returns a single object; raw server returns a list
        items = [raw] if isinstance(raw, dict) else raw
        parsed = [CompressedPrompt.model_validate(item) for item in items]
        return CompressionResponse(data=parsed)
113+
114+
115+
class AsyncCompressionClient:
    """
    An asynchronous client for the prompt compression endpoint.

    Requests are sent with a short-lived `aiohttp.ClientSession`; every item
    in the JSON reply is validated as a `CompressedPrompt` and returned
    inside a `CompressionResponse`.
    """

    def __init__(
        self,
        api_key: str | TokenProvider,
        *,
        # Defaulted to BASE_URL so the signature matches CompressionClient.
        base_url: str = BASE_URL,
        default_headers: dict[str, str] | None = None,
    ):
        """
        Store the connection settings used by every request.

        Args:
            api_key (str | TokenProvider): Credential passed to
                `get_base_headers_async` when building request headers
            base_url (str): Base URL substituted into the endpoint template.
                Defaults to `BASE_URL`
            default_headers (dict[str, str] | None): Headers applied to
                every request; a falsy value becomes an empty dict
        """
        self.base_url = base_url
        self.api_key = api_key
        self.default_headers = default_headers or {}

    async def compress(
        self,
        *,
        context: Union[str, List[str]],
        model: str,
        rate: float = 0.5,
        target_token: int = -1,
        instruction: Optional[str] = None,
        question: Optional[str] = None,
        force_tokens: Optional[List[str]] = None,
        timeout: float | None = None,
        extra_headers: dict[str, str] | None = None,
        **kwargs,
    ) -> CompressionResponse:
        """
        Compresses a prompt asynchronously.

        Args:
            context (str | List[str]): Text or list of texts to compress
            model (str): The compression model name
            rate (float): Target compression rate (0.0 to 1.0). Default 0.5
            target_token (int): Explicit target token count (-1 for rate-based). Default -1
            instruction (str | None): Optional instruction for compression context
            question (str | None): Optional question for compression context
            force_tokens (List[str] | None): Tokens to preserve in compressed output
            timeout (float | None): Max total time (in seconds) for the request;
                `DEFAULT_TIMEOUT` is used when None
            extra_headers (dict[str, str] | None): Request-specific headers,
                applied after the client's default headers
            **kwargs: Additional compression parameters, sent as extra
                top-level keys of the JSON payload

        Returns:
            CompressionResponse: The parsed response containing compressed prompts

        Raises:
            aiohttp.ClientResponseError: If the server answers with an error
                status (raised by `raise_for_status`)
        """
        effective_timeout = DEFAULT_TIMEOUT if timeout is None else timeout

        endpoint = ENDPOINT_COMPRESS.format(base_url=self.base_url)

        payload: dict = {
            "model": model,
            "context": context,
            "rate": rate,
            "target_token": target_token,
            **kwargs,
        }
        # Optional fields are omitted entirely when left as None.
        if instruction is not None:
            payload["instruction"] = instruction
        if question is not None:
            payload["question"] = question
        if force_tokens is not None:
            payload["force_tokens"] = force_tokens

        # Precedence (lowest to highest): base, client defaults, per-request.
        headers = await get_base_headers_async(self.api_key)
        headers.update(self.default_headers)
        if extra_headers:
            headers.update(extra_headers)

        client_timeout = aiohttp.ClientTimeout(total=effective_timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.post(endpoint, json=payload, headers=headers) as resp:
                resp.raise_for_status()
                # Read the body while the response is still open.
                results = await resp.json()

        # Platform returns a single object; raw server returns a list
        if isinstance(results, dict):
            results = [results]
        return CompressionResponse(
            data=[CompressedPrompt.model_validate(r) for r in results]
        )

air/document_analysis/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""Document analysis client for PaddleX-based OCR, layout detection, and text detection."""
2+
3+
from air.document_analysis.client import (
4+
AsyncDocumentAnalysisClient,
5+
DocumentAnalysisClient,
6+
)
7+
8+
__all__ = ["DocumentAnalysisClient", "AsyncDocumentAnalysisClient"]

0 commit comments

Comments
 (0)