Merge pull request #111 from tjmlabs/proxy
Proxy support
Jonathan-Adly authored Dec 6, 2024
2 parents 78123d6 + 28e2b1b commit 5a4566d
Showing 4 changed files with 230 additions and 104 deletions.
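The diff below threads a new use_proxy flag from embed_document through _prep_document down to the network fetch, and collapses the old HEAD-then-GET pair into a single GET that returns content type, filename, and body at once. A minimal sketch of how a caller might opt in; the model class name Document, the import path, and the async ORM lookup are assumptions for illustration, not part of the diff:

import asyncio

from api.models import Document  # import path assumed from web/api/models.py


async def embed_with_proxy(pk: int) -> None:
    # Hypothetical lookup; any way of obtaining a model instance works,
    # and this assumes an initialized Django environment.
    document = await Document.objects.aget(pk=pk)
    # Route the URL fetch through the proxy configured in settings.PROXY_URL.
    await document.embed_document(use_proxy=True)


asyncio.run(embed_with_proxy(1))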
39 changes: 20 additions & 19 deletions web/api/models.py
@@ -6,7 +6,7 @@
 import re
 import urllib.parse
 from io import BytesIO
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import aiohttp
 import magic
@@ -169,7 +169,7 @@ async def get_url(self) -> str:
             return self.s3_file.url
         return self.url
 
-    async def embed_document(self) -> None:
+    async def embed_document(self, use_proxy: Optional[bool] = False) -> None:
         """
         Process a document by embedding its pages and storing the results.
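One annotation note on the new signatures: Optional[bool] = False says that None is also an accepted value, while the default remains False. If callers are only ever expected to pass a real boolean, a plain bool = False expresses the same intent more tightly. A tiny runnable sketch of the two spellings (illustrative only, not part of the commit):

from typing import List, Optional


# The commit's spelling: the annotation admits None, though the code never passes it.
async def prep_document_a(document_data=None, use_proxy: Optional[bool] = False) -> List[str]:
    return []


# Same default behavior with a narrower type: the flag is always a real boolean.
async def prep_document_b(document_data=None, use_proxy: bool = False) -> List[str]:
    return []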
@@ -249,7 +249,7 @@ async def send_batch(
             )
             return out["output"]["data"]
 
-        base64_images = await self._prep_document()
+        base64_images = await self._prep_document(use_proxy=use_proxy)
         logger.info(f"Successfully prepped document {self.name}")
         # Split the images into batches
         batches = [
@@ -328,7 +328,9 @@ async def send_batch(
 
         return
 
-    async def _prep_document(self, document_data=None) -> List[str]:
+    async def _prep_document(
+        self, document_data=None, use_proxy: Optional[bool] = False
+    ) -> List[str]:
         """
         The goal of this method is to take a document and convert it into a series of base64 images.
         Steps:
@@ -494,7 +496,9 @@ async def _prep_document(self, document_data=None) -> List[str]:
             logger.info(f"Document filename: {filename}")
 
         elif self.url:
-            content_type, filename = await self._get_url_info()
+            content_type, filename, document_data = await self._fetch_document(
+                use_proxy
+            )
             if "text/html" in content_type:
                 logger.info("Document is a webpage.")
                 # It's a webpage, convert to PDF
@@ -504,7 +508,6 @@
             else:
                 # It's a regular file
                 logger.info(f"Fetching document from URL: {self.url}")
-                document_data = await self._fetch_document()
                 if "application/pdf" in content_type:
                     extension = "pdf"
                 else:
@@ -571,13 +574,20 @@ async def _prep_document(self, document_data=None) -> List[str]:
         # Step 5: returning the base64 images
         return base64_images
 
-    async def _get_url_info(self):
-        """Get content type and filename from URL via HEAD request"""
+    async def _fetch_document(self, use_proxy: Optional[bool] = False):
+        proxy = None
+        if use_proxy:
+            proxy = settings.PROXY_URL
+            # replace https with http for the proxy
+            self.url = self.url.replace("https://", "http://")
+            logger.info("Using proxy to fetch document.")
+
+        MAX_SIZE_BYTES = 50 * 1024 * 1024  # 50 MB
         async with aiohttp.ClientSession() as session:
-            async with session.head(self.url, allow_redirects=True) as response:
+            async with session.get(self.url, proxy=proxy) as response:
                 # handle when the response is not 200
                 if response.status != 200:
+                    logger.info(f"response status: {response.status}")
                     raise ValidationError(
                         "Failed to fetch document info from URL. Some documents are protected by anti-scraping measures. We recommend you download them and send us base64."
                     )
@@ -594,16 +604,7 @@ async def _get_url_info(self):
                 )
                 if not filename:
                     filename = "downloaded_file"
-                return content_type, filename
-
-    async def _fetch_document(self):
-        async with aiohttp.ClientSession() as session:
-            async with session.get(self.url) as response:
-                if response.status != 200:
-                    raise ValidationError(
-                        "Failed to fetch document info from URL. Some documents are protected by anti-scraping measures. We recommend you download them and send us base64."
-                    )
-                return await response.read()
+                return content_type, filename, await response.read()
 
     @retry(
         stop=stop_after_attempt(3),
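For reference, a self-contained sketch of the fetch pattern the new _fetch_document adopts: one GET request, optionally routed through a plain-HTTP proxy via aiohttp's proxy= argument; aiohttp's proxy support is limited to plain-HTTP proxies, which is presumably why the commit downgrades the URL scheme before proxying. The diff declares MAX_SIZE_BYTES but the lines enforcing it fall outside the hunks shown, so the size check below is an assumption; PROXY_URL is a stand-in for settings.PROXY_URL, and errors are raised as plain RuntimeError to keep the sketch framework-free:

import asyncio

import aiohttp

PROXY_URL = "http://user:pass@proxy.example.com:8080"  # stand-in for settings.PROXY_URL
MAX_SIZE_BYTES = 50 * 1024 * 1024  # 50 MB, the same cap the diff declares


async def fetch_document(url: str, use_proxy: bool = False) -> tuple[str, bytes]:
    proxy = None
    if use_proxy:
        proxy = PROXY_URL
        # Mirror the commit: downgrade the scheme so a plain-HTTP proxy can serve it.
        url = url.replace("https://", "http://")

    async with aiohttp.ClientSession() as session:
        async with session.get(url, proxy=proxy) as response:
            if response.status != 200:
                raise RuntimeError(f"fetch failed with status {response.status}")
            content_type = response.headers.get("Content-Type", "")
            data = await response.read()
            if len(data) > MAX_SIZE_BYTES:  # assumed enforcement, not shown in the diff
                raise RuntimeError("document exceeds the 50 MB limit")
            return content_type, data


# Example:
# content_type, data = asyncio.run(fetch_document("https://example.com/doc.pdf", use_proxy=True))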