diff --git a/.github/workflows/linters.yml b/.github/workflows/linters.yml
index 90d35f6b..44f5227f 100644
--- a/.github/workflows/linters.yml
+++ b/.github/workflows/linters.yml
@@ -4,14 +4,15 @@ jobs:
run-black:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v1
- - name: Set up Python 3.9.23
- uses: actions/setup-python@v1
+ - uses: actions/checkout@v4
+ - name: Set up Python 3.10
+ uses: actions/setup-python@v5
with:
- python-version: 3.9.23
+ python-version: '3.10'
- name: Debug Message - Check Github branch
run: echo "Current Git branch is ${GITHUB_REF##*/}"
- name: Install Black
- run: pip install black==23.1.0
+ run: pip install black==24.3.0
- name: Run black . to format code.
run: black .
+
diff --git a/backend/anudesh_backend/settings.py b/backend/anudesh_backend/settings.py
index 53528d1b..2840e148 100644
--- a/backend/anudesh_backend/settings.py
+++ b/backend/anudesh_backend/settings.py
@@ -12,18 +12,20 @@
import logging
import os
+import socket
from datetime import timedelta
from pathlib import Path
from dotenv import load_dotenv
-load_dotenv()
-
if os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
from google.cloud import logging as gc_logging
# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent
+# Load environment variables from .env file
+load_dotenv(os.path.join(BASE_DIR, ".env"))
+
# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/4.0/howto/deployment/checklist/
@@ -32,7 +34,11 @@
SECRET_KEY = os.getenv("SECRET_KEY")
# SECURITY WARNING: don't run with debug turned on in production!
-DEBUG = os.getenv("ENV") == "dev"
+DEBUG = os.getenv("ENV") == "dev" or os.getenv("DEBUG", "False").lower() in (
+ "true",
+ "1",
+ "t",
+)
if DEBUG:
ALLOWED_HOSTS = ["127.0.0.1", "localhost", "0.0.0.0", "*"]
@@ -41,8 +47,16 @@
"dev.anudesh.ai4bharat.org",
"0.0.0.0",
"backend.dev.anudesh.ai4bharat.org",
+ "backend.anudesh.ai4bharat.org",
]
+CSRF_TRUSTED_ORIGINS = [
+ "https://backend.anudesh.ai4bharat.org",
+ "https://anudesh.ai4bharat.org",
+ "https://backend.dev.anudesh.ai4bharat.org",
+ "https://dev.anudesh.ai4bharat.org",
+]
+
# Application definition
INSTALLED_APPS = [
@@ -74,7 +88,7 @@
CSRF_COOKIE_SECURE = False
-CSRF_TRUSTED_ORIGINS=['https://*.anudesh.ai4bharat.org']
+CSRF_TRUSTED_ORIGINS = ["https://*.anudesh.ai4bharat.org"]
MIDDLEWARE = [
"corsheaders.middleware.CorsMiddleware",
@@ -88,7 +102,13 @@
"whitenoise.middleware.WhiteNoiseMiddleware",
]
-CORS_ORIGIN_ALLOW_ALL = True
+CORS_ALLOWED_ORIGINS = [
+ "https://anudesh.ai4bharat.org",
+ "https://dev.anudesh.ai4bharat.org",
+ "http://localhost:3000",
+ "http://127.0.0.1:3000",
+]
+
CORS_ALLOW_CREDENTIALS = True
@@ -185,6 +205,10 @@
"rest_framework.authentication.SessionAuthentication",
),
"DEFAULT_PAGINATION_CLASS": "anudesh_backend.pagination.CustomPagination",
+ "DEFAULT_THROTTLE_RATES": {
+ "anon": "10/minute",
+ "user": "60/minute",
+ },
}
@@ -298,7 +322,12 @@
# Celery settings
CELERY_TIMEZONE = "Asia/Kolkata"
-CELERY_BROKER_URL = "redis://redis:6379/0"
+try:
+ socket.gethostbyname("redis")
+ CELERY_BROKER_URL = "redis://redis:6379/0"
+except socket.gaierror:
+ CELERY_BROKER_URL = "redis://localhost:6379/0"
+
CELERY_BEAT_SCHEDULER = "django_celery_beat.schedulers:DatabaseScheduler"
diff --git a/backend/dataset/admin.py b/backend/dataset/admin.py
index e4f51ec2..fa4a7f67 100644
--- a/backend/dataset/admin.py
+++ b/backend/dataset/admin.py
@@ -1,4 +1,8 @@
-import resource
+try:
+ import resource
+except ImportError:
+ resource = None
+
from django.contrib import admin
from import_export.admin import ImportExportActionModelAdmin
from .resources import *
diff --git a/backend/deploy/requirements-mac.txt b/backend/deploy/requirements-mac.txt
index 2c9f7c4e..8d286eb8 100644
--- a/backend/deploy/requirements-mac.txt
+++ b/backend/deploy/requirements-mac.txt
@@ -22,7 +22,7 @@ botocore==1.40.11
cachetools==5.5.2
celery==5.3.4
certifi==2025.8.3
-cffi==1.17.1
+cffi==2.0.0
charset-normalizer==3.4.3
click==8.2.1
click-didyoumean==0.3.1
@@ -30,7 +30,7 @@ click-plugins==1.1.1.2
click-repl==0.3.0
colorama==0.4.6
cron-descriptor==1.4.5
-cryptography==45.0.6
+cryptography==46.0.7
datamodel-code-generator==0.26.1
defusedxml==0.8.0rc2
diff-match-patch==20241021
@@ -193,7 +193,7 @@ sacrebleu==2.3.1
schedule==1.2.0
semver==2.13.0
sentry-sdk==2.35.0
-setuptools==80.9.0
+setuptools<81
simplejson==3.19.1
six==1.16.0
smart_open==7.3.0.post1
diff --git a/backend/deploy/requirements.txt b/backend/deploy/requirements.txt
index 160f9eef..13603558 100644
--- a/backend/deploy/requirements.txt
+++ b/backend/deploy/requirements.txt
@@ -1,174 +1,172 @@
-aiohttp==3.8.6
-aiosignal==1.3.1
-alabaster==0.7.13
-amqp==5.1.1
+aiohappyeyeballs==2.6.1
+aiohttp==3.12.15
+aiosignal==1.4.0
+alabaster==0.7.16
+amqp==5.3.1
+annotated-types==0.7.0
appdirs==1.4.4
-asgiref==3.7.2
-async-timeout==4.0.3
-asynctest==0.13.0
+asgiref==3.9.1
attr==0.3.1
-attrs==23.1.0
-azure-core==1.29.2
-azure-storage-blob==12.17.0
-Babel==2.12.1
+attrs==25.3.0
+azure-core==1.35.0
+azure-storage-blob==12.26.0
+Babel==2.17.0
beautifulsoup4==4.12.2
-billiard==3.6.4.0
+billiard==4.2.1
boto==2.49.0
-boto3==1.16.63
-botocore==1.19.63
-boxing==0.1.4
-cached-property==1.5.2
-cachetools==5.3.1
-celery==5.2.2
-certifi==2023.7.22
-cffi==1.15.1
-charset-normalizer==2.0.12
-click==8.1.7
-click-didyoumean==0.3.0
-click-plugins==1.1.1
+boto3==1.40.11
+botocore==1.40.11
+cachetools==5.5.2
+celery==5.3.4
+certifi==2025.8.3
+cffi==2.0.0
+charset-normalizer==3.4.3
+click==8.2.1
+click-didyoumean==0.3.1
+click-plugins==1.1.1.2
click-repl==0.3.0
colorama==0.4.6
coreapi==2.3.3
coreschema==0.0.4
-cryptography==41.0.3
+cryptography==46.0.7
defusedxml==0.7.1
Deprecated==1.2.14
-diff-match-patch==20230430
-Django==3.2.14
+diff-match-patch==20241021
+Django==5.1.11
django-annoying==0.10.6
-django-celery-beat==2.2.1
-django-celery-results==2.2.0
-django-cors-headers==3.6.0
+django-celery-beat==2.8.1
+django-celery-results==2.6.0
+django-cors-headers==4.7.0
django-debug-toolbar==3.2.1
-django-extensions==3.1.0
-django-filter==2.4.0
-django-import-export==3.2.0
+django-extensions==3.2.3
+django-filter==24.3
+django-import-export==4.3.9
django-model-utils==4.1.1
django-ranged-fileresponse==0.1.2
-django-rest-swagger==2.2.0
-django-rq==2.5.1
+django-rq==2.10.3
django-smtp-ssl==1.0
django-templated-mail==1.1.1
-django-timezone-field==4.2.3
+django-timezone-field==7.1
django-user-agents==0.4.0
-djangorestframework==3.13.1
-djangorestframework-simplejwt==4.8.0
-djoser==2.1.0
+djangorestframework==3.15.2
+djangorestframework-simplejwt==5.5.1
+djoser==2.3.3
docopt==0.6.2
docutils==0.17.1
drf-dynamic-fields==0.3.0
drf-flex-fields==0.9.5
drf-generators==0.3.0
-drf-yasg==1.20.0
+drf-yasg==1.21.10
et-xmlfile==1.1.0
-expiringdict==1.1.4
+expiringdict==1.2.2
ffmpeg
-frozenlist==1.3.3
+flower==2.0.1
+frozenlist==1.7.0
gcloud==0.18.3
google==3.0.0
-google-api-core==2.10.0
-google-auth==2.11.0
-google-cloud-appengine-logging==1.1.0
-google-cloud-audit-log==0.2.0
-google-cloud-core==2.3.2
-google-cloud-logging==2.7.2
-google-cloud-storage==2.5.0
-google-cloud-translate==3.8.4
-google-cloud-vision==3.1.4
-google-crc32c==1.5.0
-google-resumable-media==2.3.3
-googleapis-common-protos==1.56.4
-grpc-google-iam-v1==0.12.3
-grpcio==1.57.0
-grpcio-status==1.48.2
+google-api-core==2.25.1
+google-auth==2.40.3
+google-cloud-appengine-logging==1.6.2
+google-cloud-audit-log==0.3.2
+google-cloud-core==2.4.3
+google-cloud-logging==3.12.1
+google-cloud-storage==2.19.0
+google-cloud-translate==3.21.1
+google-cloud-vision==3.10.2
+google-crc32c==1.7.1
+google-resumable-media==2.7.2
+googleapis-common-protos==1.70.0
+grpc-google-iam-v1==0.14.2
+grpcio==1.74.0
+grpcio-status==1.74.0
gunicorn==21.2.0
htmlmin==0.1.12
httplib2==0.22.0
-idna==3.4
+httpx<0.28.0
+idna==3.10
imagesize==1.4.1
-importlib-metadata==1.7.0
+importlib-metadata==8.7.0
indic-nlp-library==0.92
inflection==0.5.1
-isodate==0.6.1
+isodate==0.7.2
itypes==1.2.0
-Jinja2==3.1.2
-jiwer==3.0.2
-jmespath==0.10.0
-joblib==1.3.2
-jsonschema==3.2.0
-jwcrypto==1.5.1
-kombu==5.2.4
-label-studio==1.6.0
-label-studio-converter==0.0.44
-label-studio-tools==0.0.1
-launchdarkly-server-sdk==7.3.0
+Jinja2==3.1.6
+jiwer==4.0.0
+jmespath==1.0.1
+joblib==1.5.1
+jsonschema==4.25.0
+jwcrypto==1.5.6
+kombu==5.3.7
+label-studio==1.20.0
+label-studio-converter==0.0.59
+label-studio-tools==0.0.4
+launchdarkly-server-sdk==8.2.1
Levenshtein==0.21.1
lockfile==0.12.2
-lxml==4.9.3
+lxml==6.0.0
MarkupPy==1.14
-MarkupSafe==2.1.3
+MarkupSafe==3.0.2
Morfessor==2.0.6
mosestokenizer==1.2.1
-multidict==6.0.4
-nltk==3.6.7
-numpy==1.21.6
+multidict==6.6.4
+nltk==3.9.1
+numpy==1.26.4
oauth2client==4.1.3
-oauthlib==3.2.2
+oauthlib==3.3.1
odfpy==1.4.1
-openai==1.53.0
+openai==1.102.0
openapi-codec==1.3.2
openfile==0.0.7
openpyxl==3.1.2
ordered-set==4.0.2
-packaging==23.1
-pandas==1.3.5
-Pillow==9.0.1
-portalocker==2.7.0
+packaging==25.0
+pandas==2.3.1
+Pillow==11.3.0
+portalocker==3.2.0
pretty-html-table==0.9.16
-prompt-toolkit==3.0.39
-proto-plus==1.22.3
-protobuf==3.20.3
-psycopg2==2.9.7
-psycopg2-binary==2.9.1
-pyasn1==0.5.0
-pyasn1-modules==0.3.0
-pycparser==2.21
-pycryptodome==3.19.1
-pydantic==1.8.2
-Pygments==2.16.1
-PyJWT==2.8.0
-pyparsing==3.1.1
-Pyrebase4==4.7.1
-pyRFC3339==1.1
-pyrsistent==0.19.3
-python-crontab==3.0.0
-python-dateutil==2.8.1
-python-dotenv==0.21.1
+prompt-toolkit==3.0.51
+proto-plus==1.26.1
+protobuf==6.32.0
+psycopg2-binary==2.9.10
+pyasn1==0.6.1
+pyasn1-modules==0.4.2
+pycparser==2.22
+pycryptodome==3.23.0
+pydantic==2.11.7
+Pygments==2.19.2
+PyJWT==2.10.1
+pyparsing==3.2.3
+Pyrebase4==4.8.0
+pyRFC3339==2.0.1
+python-crontab==3.3.0
+python-dateutil==2.9.0.post0
+python-dotenv==1.1.1
python-json-logger==2.0.4
python-jwt==4.1.0
python3-openid==3.2.0
-pytz==2019.3
-PyYAML==6.0.1
-rapidfuzz==2.13.7
-redis==5.0.0
-regex==2023.8.8
-requests==2.27.1
-requests-oauthlib==1.3.1
+pytz==2022.7.1
+PyYAML==6.0.2
+RapidFuzz==3.13.0
+redis==5.2.1
+regex==2025.7.34
+requests==2.32.4
+requests-oauthlib==2.0.0
requests-toolbelt==0.10.1
-rq==1.10.1
-rsa==4.9
+rq==1.16.2
+rsa==4.9.1
ruamel.yaml==0.17.32
ruamel.yaml.clib==0.2.7
-rules==2.2
-s3transfer==0.3.7
+rules==3.4
+s3transfer==0.13.1
sacrebleu==2.3.1
schedule==1.2.0
semver==2.13.0
-sentry-sdk==1.29.2
+sentry-sdk==2.35.0
+setuptools<81
simplejson==3.19.1
six==1.16.0
snowballstemmer==2.2.0
-social-auth-app-django==4.0.0
+social-auth-app-django==5.5.1
social-auth-core==4.4.2
soupsieve==2.4.1
Sphinx==4.3.2
@@ -182,16 +180,16 @@ sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
sqlparse==0.4.4
-tablib==3.4.0
+tablib==3.8.0
tabulate==0.9.0
toolwrapper==2.1.0
tqdm==4.66.1
-typing_extensions==4.7.1
+typing_extensions==4.14.1
ua-parser==0.18.0
uctools==1.3.0
-ujson==5.7.0
+ujson==5.10.0
uritemplate==4.1.1
-urllib3==1.26.16
+urllib3==1.26.20
user-agents==2.2.0
vine==5.0.0
wcwidth==0.2.6
@@ -199,7 +197,6 @@ whitenoise==6.5.0
wrapt==1.16.0
xlrd==2.0.1
xlwt==1.3.0
-xmljson==0.2.0
-yarl==1.9.4
-zipp==3.15.0
-flower==2.0.1
\ No newline at end of file
+xmljson==0.2.1
+yarl==1.20.1
+zipp==3.23.0
\ No newline at end of file
diff --git a/backend/functions/views.py b/backend/functions/views.py
index b8ab7cc5..882a2e2c 100644
--- a/backend/functions/views.py
+++ b/backend/functions/views.py
@@ -2,7 +2,9 @@
import json
import os
import uuid
-from azure.storage.blob import BlobServiceClient
+import io
+from azure.storage.blob import BlobServiceClient, ContentSettings
+from PIL import Image
from anudesh_backend.locks import Lock
from urllib import request
@@ -12,7 +14,13 @@
from drf_yasg.utils import swagger_auto_schema
from projects.models import *
from rest_framework import status
-from rest_framework.decorators import api_view, permission_classes, action
+from rest_framework.decorators import (
+ api_view,
+ permission_classes,
+ action,
+ throttle_classes,
+)
+from rest_framework.throttling import UserRateThrottle
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.response import Response
from users.utils import (
@@ -35,6 +43,7 @@
)
from datetime import timezone
+
@api_view(["GET"])
def get_indic_trans_supported_langs_model_codes(request):
"""Function to get the supported languages and the translations supported by the indic-trans models"""
@@ -307,22 +316,44 @@ def download_all_projects(request):
)
-@permission_classes([AllowAny])
+@permission_classes([IsAuthenticated])
+@throttle_classes([UserRateThrottle])
@api_view(["POST"])
def chat_log(request):
try:
interaction_json = request.data.get("interaction_json")
user_data = request.data.get("user_data", {})
- x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
+
+ if interaction_json is None:
+ return Response(
+ {"message": "Missing interaction_json"},
+ status=status.HTTP_400_BAD_REQUEST,
+ )
+
+ if not isinstance(interaction_json, dict) or not isinstance(user_data, dict):
+ return Response(
+ {"message": "Invalid payload format"},
+ status=status.HTTP_400_BAD_REQUEST,
+ )
+
+ if len(json.dumps(interaction_json)) > 1024 * 1024:
+ return Response(
+ {"message": "Payload too large"}, status=status.HTTP_400_BAD_REQUEST
+ )
+ x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
user = request.user
- session_id = request.data.get("session_id", datetime.now(timezone.utc).strftime("%H:%M:%S %d-%m-%Y"))
+ session_id = request.data.get(
+ "session_id", datetime.now(timezone.utc).strftime("%H:%M:%S %d-%m-%Y")
+ )
if x_forwarded_for:
- ip = x_forwarded_for.split(',')[0].strip()
+ ip = x_forwarded_for.split(",")[0].strip()
else:
- ip = request.META.get('REMOTE_ADDR')
+ ip = request.META.get("REMOTE_ADDR")
user_data["ip_address"] = ip
user_data["user"] = user.email
- interaction_json["timestamp"] = datetime.now(timezone.utc).strftime("%H:%M:%S %d-%m-%Y")
+ interaction_json["timestamp"] = datetime.now(timezone.utc).strftime(
+ "%H:%M:%S %d-%m-%Y"
+ )
log_entry_string = json.dumps(interaction_json) + "\n"
log_entry_bytes = log_entry_string.encode("utf-8")
connection_string = os.getenv("CONNECTION_STRING_CHAT_LOG")
@@ -340,15 +371,19 @@ def chat_log(request):
container_client = blob_service_client.get_container_client(container_name)
name = f"{datetime.now(timezone.utc).strftime('%Y-%m-%d')}/{user.email}-{session_id}.jsonl"
blob_client = container_client.get_blob_client(name)
- blob_service_client = BlobServiceClient.from_connection_string(connection_string)
- blob_client = blob_service_client.get_blob_client(container=container_name, blob=name)
+ blob_service_client = BlobServiceClient.from_connection_string(
+ connection_string
+ )
+ blob_client = blob_service_client.get_blob_client(
+ container=container_name, blob=name
+ )
if not blob_client.exists():
blob_client.create_append_blob()
- user_entry_string = json.dumps({'user_data': user_data}) + "\n"
+ user_entry_string = json.dumps({"user_data": user_data}) + "\n"
user_entry_bytes = user_entry_string.encode("utf-8")
blob_client.append_block(user_entry_bytes)
-
+
blob_client.append_block(log_entry_bytes)
return Response(
@@ -362,7 +397,8 @@ def chat_log(request):
)
-@permission_classes([AllowAny])
+@permission_classes([IsAuthenticated])
+@throttle_classes([UserRateThrottle])
@api_view(["POST"])
def chat_output(request):
prompt = request.data.get("message")
@@ -382,27 +418,58 @@ def chat_output(request):
status=status.HTTP_200_OK,
)
+
@permission_classes([IsAuthenticated])
@api_view(["POST"])
def upload_chat_image(request):
- image_file = request.FILES.get('image')
+ image_file = request.FILES.get("image")
user = request.user
if image_file:
+ try:
+ # Verify and re-encode with Pillow to strip EXIF
+ image = Image.open(image_file)
+ image.verify()
+
+ image_file.seek(0)
+ image = Image.open(image_file)
+
+ output_io = io.BytesIO()
+ img_format = image.format if image.format else "JPEG"
+ image.save(output_io, format=img_format)
+ image_bytes = output_io.getvalue()
+ content_type = f"image/{img_format.lower()}"
+ except Exception as e:
+ return Response(
+ {"error": f"Invalid image file: {str(e)}"},
+ status=status.HTTP_400_BAD_REQUEST,
+ )
+
account_url = os.getenv("AZURE_ACCOUNT_URL_CHAT_IMAGES")
container_name = os.getenv("AZURE_CONTAINER_NAME_CHAT_IMAGES")
sas_token = os.getenv("AZURE_SAS_TOKEN_CHAT_IMAGES")
file_extension = os.path.splitext(image_file.name)[1]
blob_name = f"image-{user.email}{uuid.uuid4()}{file_extension}"
try:
- blob_service_client = BlobServiceClient(account_url=account_url, credential=sas_token)
- blob_client = blob_service_client.get_blob_client(container=container_name, blob=blob_name)
- blob_client.upload_blob(image_file.read(), blob_type="BlockBlob")
+ blob_service_client = BlobServiceClient(
+ account_url=account_url, credential=sas_token
+ )
+ blob_client = blob_service_client.get_blob_client(
+ container=container_name, blob=blob_name
+ )
+ blob_client.upload_blob(
+ image_bytes,
+ blob_type="BlockBlob",
+ content_settings=ContentSettings(
+ content_type=content_type, content_disposition="attachment"
+ ),
+ )
image_url = blob_client.url
return Response(
{"image_url": image_url},
- status=status.HTTP_201_CREATED,)
+ status=status.HTTP_201_CREATED,
+ )
except Exception as e:
return Response(
{"error": f"Failed to upload image: {str(e)}"},
- status=status.HTTP_500_INTERNAL_SERVER_ERROR
- )
\ No newline at end of file
+ status=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ )
diff --git a/backend/tasks/views.py b/backend/tasks/views.py
index f71c99ae..1f139e74 100644
--- a/backend/tasks/views.py
+++ b/backend/tasks/views.py
@@ -1,6 +1,6 @@
from datetime import datetime, timezone
from locale import normalize
-from urllib.parse import unquote
+from urllib.parse import unquote, quote
import ast
from django.http import JsonResponse
from rest_framework import viewsets
@@ -8,7 +8,7 @@
from rest_framework import status
from rest_framework.response import Response
from rest_framework.decorators import action, api_view
-from rest_framework.permissions import IsAuthenticated, AllowAny
+from rest_framework.permissions import IsAuthenticated
from django.core.paginator import Paginator
import requests
@@ -19,7 +19,11 @@
PredictionSerializer,
TaskAnnotationSerializer,
)
-from tasks.utils import compute_meta_stats_for_instruction_driven_chat, compute_meta_stats_for_multiple_llm_idc, query_flower
+from tasks.utils import (
+ compute_meta_stats_for_instruction_driven_chat,
+ compute_meta_stats_for_multiple_llm_idc,
+ query_flower,
+)
from tasks.utils import Queued_Task_name, convert_audio_base64_to_mp3
from utils.pagination import paginate_queryset
from notifications.views import createNotification
@@ -81,18 +85,17 @@ class TaskViewSet(viewsets.ModelViewSet, mixins.ListModelMixin):
# 🔹 Swagger docs for the unassigned review summary endpoint
@swagger_auto_schema(
- method='get',
+ method="get",
operation_summary="Get Unassigned Review Summary",
operation_description="""
Returns the number of **unassigned review tasks** grouped by annotator
for a specific project.
-
Use this to display a popup summary of pending review assignments
per annotator in the Review Tasks tab.
""",
manual_parameters=[
openapi.Parameter(
- 'project_id',
+ "project_id",
openapi.IN_QUERY,
description="The ID of the project for which to fetch unassigned review task summary",
type=openapi.TYPE_INTEGER,
@@ -107,11 +110,10 @@ class TaskViewSet(viewsets.ModelViewSet, mixins.ListModelMixin):
},
)
@action(
- detail=False,
- methods=["get"],
- url_path="unassigned-review-summary",
- url_name="unassigned_review_summary",
- permission_classes=[AllowAny],
+ detail=False,
+ methods=["get"],
+ url_path="unassigned-review-summary",
+ url_name="unassigned_review_summary",
)
def unassigned_review_summary(self, request):
"""
@@ -146,8 +148,10 @@ def unassigned_review_summary(self, request):
.distinct()
)
+ from tasks.models import Annotation as Annotation_model, ANNOTATOR_ANNOTATION
+
annotator_task_counts = (
- Annotation.objects.filter(
+ Annotation_model.objects.filter(
task__in=unassigned_tasks,
annotation_type=ANNOTATOR_ANNOTATION,
completed_by__in=all_annotators,
@@ -168,20 +172,23 @@ def unassigned_review_summary(self, request):
result = []
for annotator in all_annotators:
- counts = count_lookup.get(annotator.id, {"unassigned_count": 0, "task_ids": []})
- result.append({
- "annotator_id": annotator.id,
- "annotator_email": annotator.email,
- "annotator_username": annotator.username,
- "unassigned_count": counts["unassigned_count"],
- "task_ids": counts["task_ids"],
- })
+ counts = count_lookup.get(
+ annotator.id, {"unassigned_count": 0, "task_ids": []}
+ )
+ result.append(
+ {
+ "annotator_id": annotator.id,
+ "annotator_email": annotator.email,
+ "annotator_username": annotator.username,
+ "unassigned_count": counts["unassigned_count"],
+ "task_ids": counts["task_ids"],
+ }
+ )
result.sort(key=lambda x: x["unassigned_count"], reverse=True)
return Response(result, status=status.HTTP_200_OK)
-
@action(
detail=False,
methods=["get"],
@@ -407,7 +414,7 @@ def list(self, request, *args, **kwargs):
task_obj["id"] = an.task_id
task_obj["annotation_status"] = an.annotation_status
task_obj["user_mail"] = an.completed_by.email
- if("unlabeled" not in ann_status):
+ if "unlabeled" not in ann_status:
task_obj["updated_at"] = utc_to_ist(an.updated_at)
task_objs.append(task_obj)
task_objs.sort(key=lambda x: x["id"])
@@ -418,7 +425,7 @@ def list(self, request, *args, **kwargs):
tas = tas.values()[0]
tas["annotation_status"] = task_obj["annotation_status"]
tas["user_mail"] = task_obj["user_mail"]
- if("unlabeled" not in ann_status):
+ if "unlabeled" not in ann_status:
tas["updated_at"] = task_obj["updated_at"]
ordered_tasks.append(tas)
if page_number is not None:
@@ -483,7 +490,7 @@ def list(self, request, *args, **kwargs):
task_obj["annotation_status"] = an.annotation_status
task_obj["user_mail"] = an.completed_by.email
task_obj["annotation_result_json"] = an.result
- if("unlabeled" not in ann_status):
+ if "unlabeled" not in ann_status:
task_obj["updated_at"] = utc_to_ist(an.updated_at)
task_objs.append(task_obj)
task_objs.sort(key=lambda x: x["id"])
@@ -495,7 +502,7 @@ def list(self, request, *args, **kwargs):
tas = tas.values()[0]
tas["annotation_status"] = task_obj["annotation_status"]
tas["user_mail"] = task_obj["user_mail"]
- if("unlabeled" not in ann_status):
+ if "unlabeled" not in ann_status:
tas["updated_at"] = task_obj["updated_at"]
if (ann_status[0] in ["labeled", "draft", "to_be_revised"]) and (
proj_type == "ContextualTranslationEditing"
@@ -899,11 +906,11 @@ def list(self, request, *args, **kwargs):
tas = Task.objects.filter(id=task_obj["id"])
tas = tas.values()[0]
tas["supercheck_status"] = task_obj["annotation_status"]
- if "UNVALIDATED" not in supercheck_status:
- tas["updated_at"] = task_obj["updated_at"]
tas["user_mail"] = task_obj["user_mail"]
tas["reviewer_mail"] = task_obj["reviewer_mail"]
tas["annotator_mail"] = task_obj["annotator_mail"]
+ if "UNVALIDATED" not in supercheck_status:
+ tas["updated_at"] = task_obj["updated_at"]
if proj_type == "ContextualTranslationEditing":
if supercheck_status[0] in [
"draft",
@@ -960,7 +967,9 @@ def list(self, request, *args, **kwargs):
task_status__in=tas_status,
)
if start_date and end_date:
- tasks = tasks.filter(annotations__updated_at__range=[start_date, end_date]).distinct()
+ tasks = tasks.filter(
+ annotations__updated_at__range=[start_date, end_date]
+ ).distinct()
# Handle search query (if any)
if len(tasks):
@@ -1001,7 +1010,9 @@ def list(self, request, *args, **kwargs):
annotation_users=user_id,
)
if start_date and end_date:
- tasks = tasks.filter(annotations__updated_at__range=[start_date, end_date]).distinct()
+ tasks = tasks.filter(
+ annotations__updated_at__range=[start_date, end_date]
+ ).distinct()
# Handle search query (if any)
if len(tasks):
@@ -1039,7 +1050,9 @@ def list(self, request, *args, **kwargs):
review_user_id=user_id,
)
if start_date and end_date:
- tasks = tasks.filter(annotations__updated_at__range=[start_date, end_date]).distinct()
+ tasks = tasks.filter(
+ annotations__updated_at__range=[start_date, end_date]
+ ).distinct()
# Handle search query (if any)
if len(tasks):
@@ -1077,7 +1090,9 @@ def list(self, request, *args, **kwargs):
super_checker_user_id=user_id,
)
if start_date and end_date:
- tasks = tasks.filter(annotations__updated_at__range=[start_date, end_date]).distinct()
+ tasks = tasks.filter(
+ annotations__updated_at__range=[start_date, end_date]
+ ).distinct()
tasks = tasks.order_by("id")
# Handle search query (if any)
@@ -1771,7 +1786,7 @@ def partial_update(self, request, pk=None):
== "MultipleLLMInstructionDrivenChat"
):
if isinstance(request.data["result"], str):
- if request.data["result"]=="":
+ if request.data["result"] == "":
# preferred_model = request.data.get("preferred_response")
# preferred_id = request.data.get("prompt_output_pair_id")
# for model_entry in annotation_obj.result:
@@ -1782,29 +1797,38 @@ def partial_update(self, request, pk=None):
eval_form_vals = request.data.get("model_responses_json")
preferred_id = request.data.get("prompt_output_pair_id")
eval_form_entry = next(
- (entry for entry in annotation_obj.result if "eval_form" in entry),
- None
+ (
+ entry
+ for entry in annotation_obj.result
+ if "eval_form" in entry
+ ),
+ None,
)
if eval_form_entry is None:
eval_form_entry = {"eval_form": []}
annotation_obj.result.insert(0, eval_form_entry)
existing_entry = next(
- (item for item in eval_form_entry["eval_form"] if item.get("prompt_output_pair_id") == preferred_id),
- None
+ (
+ item
+ for item in eval_form_entry["eval_form"]
+ if item.get("prompt_output_pair_id") == preferred_id
+ ),
+ None,
)
if existing_entry:
existing_entry["model_responses_json"] = eval_form_vals
else:
- eval_form_entry["eval_form"].append({
- "prompt_output_pair_id": preferred_id,
- "model_responses_json": eval_form_vals
- })
+ eval_form_entry["eval_form"].append(
+ {
+ "prompt_output_pair_id": preferred_id,
+ "model_responses_json": eval_form_vals,
+ }
+ )
else:
if not annotation_obj.result:
- annotation_obj.result.append({
- "eval_form": [],
- "model_interactions": []
- })
+ annotation_obj.result.append(
+ {"eval_form": [], "model_interactions": []}
+ )
result_entry = annotation_obj.result[0]
if "model_interactions" not in result_entry:
result_entry["model_interactions"] = []
@@ -1814,7 +1838,7 @@ def partial_update(self, request, pk=None):
annotation_obj.task,
annotation_obj,
annotation_obj.task.project_id.metadata_json,
- task.data["model"]
+ task.data["model"],
)
if output_result == -1:
ret_dict = {
@@ -1824,6 +1848,9 @@ def partial_update(self, request, pk=None):
return Response(ret_dict, status=ret_status)
elif isinstance(output_result, Response):
return output_result
+ for model_out in output_result.values():
+ if isinstance(model_out, Response):
+ return model_out
# store the result of all checks as well
prompt_text = request.data["result"]
for model_name, model_output in output_result.items():
@@ -1831,23 +1858,28 @@ def partial_update(self, request, pk=None):
"prompt": prompt_text,
"output": model_output,
"preferred_response": False,
- "prompt_output_pair_id": request.data['prompt_output_pair_id']
+ "prompt_output_pair_id": request.data[
+ "prompt_output_pair_id"
+ ],
}
model_found = False
for model_entry in result_entry["model_interactions"]:
if model_entry.get("model_name") == model_name:
- model_entry["interaction_json"].append(new_interaction)
+ model_entry["interaction_json"].append(
+ new_interaction
+ )
model_found = True
break
# If model not found, create a new one
if not model_found:
- result_entry["model_interactions"].append({
- "model_name": model_name,
- "interaction_json": [new_interaction]
-
- })
+ result_entry["model_interactions"].append(
+ {
+ "model_name": model_name,
+ "interaction_json": [new_interaction],
+ }
+ )
else:
annotation_obj.result = request.data["result"]
annotation_obj.meta_stats = (
@@ -1903,9 +1935,16 @@ def partial_update(self, request, pk=None):
)
if is_IDC:
annotation_response.data["output"] = output_result
- if (annotation_obj.task.project_id.project_type == "MultipleLLMInstructionDrivenChat"):
+ if (
+ annotation_obj.task.project_id.project_type
+ == "MultipleLLMInstructionDrivenChat"
+ ):
metadata = annotation_obj.task.project_id.metadata_json
- annotation_response.data["enable_preferrence_selection"] = metadata.get("enable_preferrence_selection", False) if isinstance(metadata, dict) else False
+ annotation_response.data["enable_preferrence_selection"] = (
+ metadata.get("enable_preferrence_selection", False)
+ if isinstance(metadata, dict)
+ else False
+ )
response_message = "Success"
else:
if "annotation_status" in dict(request.data) and request.data[
@@ -1981,16 +2020,16 @@ def partial_update(self, request, pk=None):
and len(annotation_obj.result) > len(request.data["result"])
):
request.data["result"] = annotation_obj.result
- request.data[
- "meta_stats"
- ] = compute_meta_stats_for_multiple_llm_idc(
- annotation_obj.result
+ request.data["meta_stats"] = (
+ compute_meta_stats_for_multiple_llm_idc(
+ annotation_obj.result
+ )
)
else:
- request.data[
- "meta_stats"
- ] = compute_meta_stats_for_multiple_llm_idc(
- request.data["result"]
+ request.data["meta_stats"] = (
+ compute_meta_stats_for_multiple_llm_idc(
+ request.data["result"]
+ )
)
annotation_response = super().partial_update(request)
if is_IDC:
@@ -2040,33 +2079,42 @@ def partial_update(self, request, pk=None):
== "MultipleLLMInstructionDrivenChat"
):
if isinstance(request.data["result"], str):
- if(request.data["result"]==""):
+ if request.data["result"] == "":
eval_form_vals = request.data.get("model_responses_json")
preferred_id = request.data.get("prompt_output_pair_id")
eval_form_entry = next(
- (entry for entry in annotation_obj.result if "eval_form" in entry),
- None
+ (
+ entry
+ for entry in annotation_obj.result
+ if "eval_form" in entry
+ ),
+ None,
)
if eval_form_entry is None:
eval_form_entry = {"eval_form": []}
annotation_obj.result.insert(0, eval_form_entry)
existing_entry = next(
- (item for item in eval_form_entry["eval_form"] if item.get("prompt_output_pair_id") == preferred_id),
- None
+ (
+ item
+ for item in eval_form_entry["eval_form"]
+ if item.get("prompt_output_pair_id") == preferred_id
+ ),
+ None,
)
if existing_entry:
existing_entry["model_responses_json"] = eval_form_vals
else:
- eval_form_entry["eval_form"].append({
- "prompt_output_pair_id": preferred_id,
- "model_responses_json": eval_form_vals
- })
+ eval_form_entry["eval_form"].append(
+ {
+ "prompt_output_pair_id": preferred_id,
+ "model_responses_json": eval_form_vals,
+ }
+ )
else:
if not annotation_obj.result:
- annotation_obj.result.append({
- "eval_form": [],
- "model_interactions": []
- })
+ annotation_obj.result.append(
+ {"eval_form": [], "model_interactions": []}
+ )
result_entry = annotation_obj.result[0]
if "model_interactions" not in result_entry:
result_entry["model_interactions"] = []
@@ -2076,7 +2124,7 @@ def partial_update(self, request, pk=None):
annotation_obj.task,
annotation_obj,
annotation_obj.task.project_id.metadata_json,
- task.data["model"]
+ task.data["model"],
)
if output_result == -1:
ret_dict = {
@@ -2086,6 +2134,9 @@ def partial_update(self, request, pk=None):
return Response(ret_dict, status=ret_status)
elif isinstance(output_result, Response):
return output_result
+ for model_out in output_result.values():
+ if isinstance(model_out, Response):
+ return model_out
# store the result of all checks as well
prompt_text = request.data["result"]
for model_name, model_output in output_result.items():
@@ -2093,23 +2144,28 @@ def partial_update(self, request, pk=None):
"prompt": prompt_text,
"output": model_output,
"preferred_response": False,
- "prompt_output_pair_id": request.data['prompt_output_pair_id'],
+ "prompt_output_pair_id": request.data[
+ "prompt_output_pair_id"
+ ],
}
model_found = False
for model_entry in result_entry["model_interactions"]:
if model_entry.get("model_name") == model_name:
- model_entry["interaction_json"].append(new_interaction)
+ model_entry["interaction_json"].append(
+ new_interaction
+ )
model_found = True
break
# If model not found, create a new one
if not model_found:
- result_entry["model_interactions"].append({
- "model_name": model_name,
- "interaction_json": [new_interaction]
-
- })
+ result_entry["model_interactions"].append(
+ {
+ "model_name": model_name,
+ "interaction_json": [new_interaction],
+ }
+ )
else:
annotation_obj.result = request.data["result"]
annotation_obj.meta_stats = (
@@ -2165,9 +2221,16 @@ def partial_update(self, request, pk=None):
)
if is_IDC:
annotation_response.data["output"] = output_result
- if (annotation_obj.task.project_id.project_type == "MultipleLLMInstructionDrivenChat"):
+ if (
+ annotation_obj.task.project_id.project_type
+ == "MultipleLLMInstructionDrivenChat"
+ ):
metadata = annotation_obj.task.project_id.metadata_json
- annotation_response.data["enable_preferrence_selection"] = metadata.get("enable_preferrence_selection", False) if isinstance(metadata, dict) else False
+ annotation_response.data["enable_preferrence_selection"] = (
+ metadata.get("enable_preferrence_selection", False)
+ if isinstance(metadata, dict)
+ else False
+ )
response_message = "Success"
else:
@@ -2282,16 +2345,16 @@ def partial_update(self, request, pk=None):
and len(annotation_obj.result) > len(request.data["result"])
):
request.data["result"] = annotation_obj.result
- request.data[
- "meta_stats"
- ] = compute_meta_stats_for_multiple_llm_idc(
- annotation_obj.result
+ request.data["meta_stats"] = (
+ compute_meta_stats_for_multiple_llm_idc(
+ annotation_obj.result
+ )
)
else:
- request.data[
- "meta_stats"
- ] = compute_meta_stats_for_multiple_llm_idc(
- request.data["result"]
+ request.data["meta_stats"] = (
+ compute_meta_stats_for_multiple_llm_idc(
+ request.data["result"]
+ )
)
annotation_response = super().partial_update(request)
if is_IDC:
@@ -2371,33 +2434,42 @@ def partial_update(self, request, pk=None):
== "MultipleLLMInstructionDrivenChat"
):
if isinstance(request.data["result"], str):
- if(request.data["result"]==""):
+ if request.data["result"] == "":
eval_form_vals = request.data.get("model_responses_json")
preferred_id = request.data.get("prompt_output_pair_id")
eval_form_entry = next(
- (entry for entry in annotation_obj.result if "eval_form" in entry),
- None
+ (
+ entry
+ for entry in annotation_obj.result
+ if "eval_form" in entry
+ ),
+ None,
)
if eval_form_entry is None:
eval_form_entry = {"eval_form": []}
annotation_obj.result.insert(0, eval_form_entry)
existing_entry = next(
- (item for item in eval_form_entry["eval_form"] if item.get("prompt_output_pair_id") == preferred_id),
- None
+ (
+ item
+ for item in eval_form_entry["eval_form"]
+ if item.get("prompt_output_pair_id") == preferred_id
+ ),
+ None,
)
if existing_entry:
existing_entry["model_responses_json"] = eval_form_vals
else:
- eval_form_entry["eval_form"].append({
- "prompt_output_pair_id": preferred_id,
- "model_responses_json": eval_form_vals
- })
+ eval_form_entry["eval_form"].append(
+ {
+ "prompt_output_pair_id": preferred_id,
+ "model_responses_json": eval_form_vals,
+ }
+ )
else:
if not annotation_obj.result:
- annotation_obj.result.append({
- "eval_form": [],
- "model_interactions": []
- })
+ annotation_obj.result.append(
+ {"eval_form": [], "model_interactions": []}
+ )
result_entry = annotation_obj.result[0]
if "model_interactions" not in result_entry:
result_entry["model_interactions"] = []
@@ -2407,7 +2479,7 @@ def partial_update(self, request, pk=None):
annotation_obj.task,
annotation_obj,
annotation_obj.task.project_id.metadata_json,
- task.data["model"]
+ task.data["model"],
)
if output_result == -1:
ret_dict = {
@@ -2424,23 +2496,28 @@ def partial_update(self, request, pk=None):
"prompt": prompt_text,
"output": model_output,
"preferred_response": False,
- "prompt_output_pair_id": request.data['prompt_output_pair_id'],
+ "prompt_output_pair_id": request.data[
+ "prompt_output_pair_id"
+ ],
}
model_found = False
for model_entry in result_entry["model_interactions"]:
if model_entry.get("model_name") == model_name:
- model_entry["interaction_json"].append(new_interaction)
+ model_entry["interaction_json"].append(
+ new_interaction
+ )
model_found = True
break
# If model not found, create a new one
if not model_found:
- result_entry["model_interactions"].append({
- "model_name": model_name,
- "interaction_json": [new_interaction]
-
- })
+ result_entry["model_interactions"].append(
+ {
+ "model_name": model_name,
+ "interaction_json": [new_interaction],
+ }
+ )
else:
annotation_obj.result = request.data["result"]
annotation_obj.meta_stats = (
@@ -2496,9 +2573,16 @@ def partial_update(self, request, pk=None):
)
if is_IDC:
annotation_response.data["output"] = output_result
- if (annotation_obj.task.project_id.project_type == "MultipleLLMInstructionDrivenChat"):
+ if (
+ annotation_obj.task.project_id.project_type
+ == "MultipleLLMInstructionDrivenChat"
+ ):
metadata = annotation_obj.task.project_id.metadata_json
- annotation_response.data["enable_preferrence_selection"] = metadata.get("enable_preferrence_selection", False) if isinstance(metadata, dict) else False
+ annotation_response.data["enable_preferrence_selection"] = (
+ metadata.get("enable_preferrence_selection", False)
+ if isinstance(metadata, dict)
+ else False
+ )
response_message = "Success"
else:
@@ -2604,23 +2688,30 @@ def partial_update(self, request, pk=None):
and len(annotation_obj.result) > len(request.data["result"])
):
request.data["result"] = annotation_obj.result
- request.data[
- "meta_stats"
- ] = compute_meta_stats_for_multiple_llm_idc(
- annotation_obj.result
+ request.data["meta_stats"] = (
+ compute_meta_stats_for_multiple_llm_idc(
+ annotation_obj.result
+ )
)
else:
- request.data[
- "meta_stats"
- ] = compute_meta_stats_for_multiple_llm_idc(
- request.data["result"]
+ request.data["meta_stats"] = (
+ compute_meta_stats_for_multiple_llm_idc(
+ request.data["result"]
+ )
)
annotation_response = super().partial_update(request)
if is_IDC:
annotation_response.data["output"] = output_result
- if (annotation_obj.task.project_id.project_type == "MultipleLLMInstructionDrivenChat"):
+ if (
+ annotation_obj.task.project_id.project_type
+ == "MultipleLLMInstructionDrivenChat"
+ ):
metadata = annotation_obj.task.project_id.metadata_json
- annotation_response.data["enable_preferrence_selection"] = metadata.get("enable_preferrence_selection", False) if isinstance(metadata, dict) else False
+ annotation_response.data["enable_preferrence_selection"] = (
+ metadata.get("enable_preferrence_selection", False)
+ if isinstance(metadata, dict)
+ else False
+ )
annotation_id = annotation_response.data["id"]
annotation = Annotation.objects.get(pk=annotation_id)
@@ -2950,7 +3041,10 @@ def get_llm_output(prompt, task, annotation, project_metadata_json):
if isinstance(project_metadata_json, str)
else project_metadata_json
)
- if isinstance(project_metadata, dict) and project_metadata.get("blank_response") == True:
+ if (
+ isinstance(project_metadata, dict)
+ and project_metadata.get("blank_response") == True
+ ):
return ""
if prompt in [None, "Null", 0, "None", "", " "]:
return -1
@@ -2996,6 +3090,7 @@ def get_llm_output(prompt, task, annotation, project_metadata_json):
return -1
return res
+
def get_all_llm_output(prompt, task, annotation, project_metadata_json, models_to_run):
# CHECKS
intent = task.data["meta_info_intent"]
@@ -3043,17 +3138,17 @@ def get_all_llm_output(prompt, task, annotation, project_metadata_json, models_t
# GET MODEL OUTPUT
history = ann_result[0]
-
model_output = get_all_model_output(
"We will be rendering your response on a frontend. so please add spaces or indentation or nextline chars or "
"bullet or numberings etc. suitably for code or the text. wherever required.",
prompt,
history,
- models_to_run
+ models_to_run,
)
return model_output
+
def format_model_output(model_output):
result = ""
if isinstance(model_output, list):
@@ -3122,13 +3217,17 @@ class TransliterationAPIView(APIView):
def get(self, request, target_language, data, *args, **kwargs):
response_transliteration = requests.get(
- os.getenv("TRANSLITERATION_URL") + target_language + "/" + data,
+ os.getenv("TRANSLITERATION_URL")
+ + quote(target_language)
+ + "/"
+ + quote(data),
headers={"Authorization": "Bearer " + os.getenv("TRANSLITERATION_KEY")},
)
transliteration_output = response_transliteration.json()
return Response(transliteration_output, status=status.HTTP_200_OK)
+
class TranscribeAPIView(APIView):
permission_classes = [IsAuthenticated]
@@ -3140,23 +3239,28 @@ def post(self, request, *args, **kwargs):
chunk_data = {
"config": {
- "serviceId": os.getenv("DHRUVA_SERVICE_ID") if lang != "en" else os.getenv("DHRUVA_SERVICE_ID_EN"),
+ "serviceId": (
+ os.getenv("DHRUVA_SERVICE_ID")
+ if lang != "en"
+ else os.getenv("DHRUVA_SERVICE_ID_EN")
+ ),
"language": {"sourceLanguage": lang},
- "transcriptionFormat": {"value": "transcript"}
- },
- "audio": [
- {
- "audioContent":mp3_base64
- }
- ]
- }
+ "transcriptionFormat": {"value": "transcript"},
+ },
+ "audio": [{"audioContent": mp3_base64}],
+ }
try:
- response = requests.post(os.getenv("DHRUVA_API_URL"),
- headers={"authorization": os.getenv("DHRUVA_KEY")},
- json=chunk_data,
+ response = requests.post(
+ os.getenv("DHRUVA_API_URL"),
+ headers={"authorization": os.getenv("DHRUVA_KEY")},
+ json=chunk_data,
)
transcript = response.json()["output"][0]["source"]
- return Response({"transcript": transcript+" " or ""}, status=status.HTTP_200_OK)
+ return Response(
+ {"transcript": transcript + " " or ""}, status=status.HTTP_200_OK
+ )
except Exception as e:
print("Error:", e)
- return Response({"message": "Failed to transcribe"}, status=status.HTTP_400_BAD_REQUEST)
+ return Response(
+ {"message": "Failed to transcribe"}, status=status.HTTP_400_BAD_REQUEST
+ )
diff --git a/backend/utils/llm_interactions.py b/backend/utils/llm_interactions.py
index eee3ad7f..9d7b91d4 100644
--- a/backend/utils/llm_interactions.py
+++ b/backend/utils/llm_interactions.py
@@ -41,7 +41,7 @@
import re
-from openai import OpenAI
+from openai import OpenAI, AzureOpenAI
import requests
from rest_framework import status
from rest_framework.response import Response
@@ -57,6 +57,7 @@ def process_history(history):
messages.append(system_side)
return messages
+
def get_gpt4_output(system_prompt, user_prompt, history, model):
if model == "GPT4":
deployment = os.getenv("LLM_INTERACTIONS_OPENAI_ENGINE_GPT_4")
@@ -66,10 +67,11 @@ def get_gpt4_output(system_prompt, user_prompt, history, model):
deployment = os.getenv("LLM_INTERACTIONS_OPENAI_ENGINE_GPT_4O_MINI")
else:
deployment = model
-
- client = OpenAI(
+
+ client = AzureOpenAI(
api_key=os.getenv("OPENAI_API_KEY"),
- base_url=f"{os.getenv('LLM_INTERACTIONS_OPENAI_API_BASE')}openai/deployments/{deployment}"
+ api_version=os.getenv("LLM_INTERACTIONS_OPENAI_API_VERSION"),
+ azure_endpoint=os.getenv("LLM_INTERACTIONS_OPENAI_API_BASE"),
)
history_messages = process_history(history)
@@ -86,7 +88,6 @@ def get_gpt4_output(system_prompt, user_prompt, history, model):
top_p=0.95,
frequency_penalty=0,
presence_penalty=0,
- extra_query={"api-version": os.getenv("LLM_INTERACTIONS_OPENAI_API_VERSION")},
)
return response.choices[0].message.content.strip()
@@ -104,12 +105,14 @@ def get_gpt4_output(system_prompt, user_prompt, history, model):
st = status.HTTP_500_INTERNAL_SERVER_ERROR
return Response({"message": message}, status=st)
+
def get_gpt3_output(system_prompt, user_prompt, history):
model = os.getenv("LLM_INTERACTIONS_OPENAI_ENGINE_GPT35")
- client = OpenAI(
+ client = AzureOpenAI(
api_key=os.getenv("OPENAI_API_KEY"),
- base_url=f"{os.getenv('LLM_INTERACTIONS_OPENAI_API_BASE')}openai/deployments/{model}"
+ api_version=os.getenv("LLM_INTERACTIONS_OPENAI_API_VERSION"),
+ azure_endpoint=os.getenv("LLM_INTERACTIONS_OPENAI_API_BASE"),
)
history_messages = process_history(history)
@@ -126,7 +129,6 @@ def get_gpt3_output(system_prompt, user_prompt, history):
top_p=0.95,
frequency_penalty=0,
presence_penalty=0,
- extra_query={"api-version": os.getenv("LLM_INTERACTIONS_OPENAI_API_VERSION")},
)
return response.choices[0].message.content.strip()
@@ -144,6 +146,7 @@ def get_gpt3_output(system_prompt, user_prompt, history):
st = status.HTTP_500_INTERNAL_SERVER_ERROR
return Response({"message": message}, status=st)
+
def get_llama2_output(system_prompt, conv_history, user_prompt):
api_base = os.getenv("LLM_INTERACTION_LLAMA2_API_BASE")
token = os.getenv("LLM_INTERACTION_LLAMA2_API_TOKEN")
@@ -165,21 +168,19 @@ def get_llama2_output(system_prompt, conv_history, user_prompt):
result = s.post(url, headers={"Authorization": f"Bearer {token}"}, json=body)
return result.json()["choices"][0]["message"]["content"].strip()
+
def get_sarvam_m_output(system_prompt, conv_history, user_prompt):
api_base = os.getenv("SARVAM_M_API_BASE")
- api_key = os.getenv("SARVAM_M_API_KEY")
+ api_key = os.getenv("SARVAM_M_API_KEY")
url = f"{api_base}/chat/completions"
- headers = {
- "api-subscription-key": api_key,
- "Content-Type": "application/json"
- }
+ headers = {"api-subscription-key": api_key, "Content-Type": "application/json"}
history = process_history(conv_history)
messages = [{"role": "system", "content": system_prompt}]
messages.extend(history)
if type(user_prompt) == list:
- messages.append({"role": "user", "content": user_prompt[0]['text']})
+ messages.append({"role": "user", "content": user_prompt[0]["text"]})
else:
messages.append({"role": "user", "content": user_prompt})
@@ -189,13 +190,12 @@ def get_sarvam_m_output(system_prompt, conv_history, user_prompt):
"temperature": 0.2,
"max_tokens": 2048,
"top_p": 1,
- "reasoning_effort": None,
}
-
+
try:
s = requests.Session()
response = s.post(url, headers=headers, json=body)
- response.raise_for_status()
+ response.raise_for_status()
response_data = response.json()
return response_data["choices"][0]["message"]["content"].strip()
except requests.exceptions.RequestException as e:
@@ -206,11 +206,12 @@ def get_sarvam_m_output(system_prompt, conv_history, user_prompt):
print(f"Full response data: {response_data}")
raise
+
def get_deepinfra_output(system_prompt, user_prompt, history, model):
try:
client = OpenAI(
api_key=os.getenv("DEEPINFRA_API_KEY"),
- base_url=os.getenv("DEEPINFRA_BASE_URL")
+ base_url=os.getenv("DEEPINFRA_BASE_URL"),
)
history_messages = process_history(history)
@@ -226,7 +227,7 @@ def get_deepinfra_output(system_prompt, user_prompt, history, model):
)
output = response.choices[0].message.content.strip()
- cleaned_response = re.sub(r'.*?\s*', '', output, flags=re.DOTALL)
+ cleaned_response = re.sub(r".*?\s*", "", output, flags=re.DOTALL)
return cleaned_response
except Exception as e:
@@ -241,7 +242,8 @@ def get_deepinfra_output(system_prompt, user_prompt, history, model):
message = f"An error occurred while interacting with LLM: {err_msg}"
st = status.HTTP_500_INTERNAL_SERVER_ERROR
return Response({"message": message}, status=st)
-
+
+
def get_model_output(system_prompt, user_prompt, history, model=GPT4OMini):
# Assume that translation happens outside (and the prompt is already translated)
out = ""
@@ -257,6 +259,7 @@ def get_model_output(system_prompt, user_prompt, history, model=GPT4OMini):
out = get_deepinfra_output(system_prompt, user_prompt, history, model)
return out
+
def get_all_model_output(system_prompt, user_prompt, history, models_to_run):
results = {}
@@ -272,17 +275,25 @@ def get_all_model_output(system_prompt, user_prompt, history, models_to_run):
for interaction in history.get("model_interactions", [])
if interaction.get("model_name") == model
),
- []
+ [],
)
if model == GPT35:
results[model] = get_gpt3_output(system_prompt, user_prompt, model_history)
elif model in [GPT4, GPT4O, GPT4OMini]:
- results[model] = get_gpt4_output(system_prompt, user_prompt, model_history, model)
+ results[model] = get_gpt4_output(
+ system_prompt, user_prompt, model_history, model
+ )
elif model == LLAMA2:
- results[model] = get_llama2_output(system_prompt, model_history, user_prompt)
+ results[model] = get_llama2_output(
+ system_prompt, model_history, user_prompt
+ )
elif model == SARVAM_M:
- results[model] = get_sarvam_m_output(system_prompt, model_history, user_prompt)
+ results[model] = get_sarvam_m_output(
+ system_prompt, model_history, user_prompt
+ )
else:
- results[model] = get_deepinfra_output(system_prompt, user_prompt, model_history, model)
+ results[model] = get_deepinfra_output(
+ system_prompt, user_prompt, model_history, model
+ )
return results
diff --git a/backend/workspaces/views.py b/backend/workspaces/views.py
index 4775abee..baf6f595 100644
--- a/backend/workspaces/views.py
+++ b/backend/workspaces/views.py
@@ -66,7 +66,6 @@
)
from projects.registry_helper import ProjectRegistry
-
# Create your views here.
EMAIL_VALIDATION_REGEX = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"
@@ -140,9 +139,11 @@ def list(self, request, *args, **kwargs):
)
serializer = WorkspaceSerializer(data, many=True)
return Response(serializer.data)
- elif (int(request.user.role) == User.ORGANIZATION_OWNER) or (
- request.user.is_superuser
- ) or (int(request.user.role) == User.ADMIN):
+ elif (
+ (int(request.user.role) == User.ORGANIZATION_OWNER)
+ or (request.user.is_superuser)
+ or (int(request.user.role) == User.ADMIN)
+ ):
data = self.queryset.filter(organization=request.user.organization)
serializer = WorkspaceSerializer(data, many=True)
return Response(serializer.data)
@@ -488,8 +489,7 @@ def archive(self, request, pk=None, *args, **kwargs):
methods=["POST"],
name="Bulk add Members to Projects",
url_name="bulk_add_members_to_projects",
- )
-
+ )
@is_particular_organization_owner
def bulk_add_members_to_projects(self, request, pk=None, *args, **kwargs):
"""
@@ -505,7 +505,9 @@ def bulk_add_members_to_projects(self, request, pk=None, *args, **kwargs):
)
if role not in ["annotator", "reviewer", "super_checker"]:
return Response(
- {"message": "Invalid role. Must be annotator or reviewer or super_checker."},
+ {
+ "message": "Invalid role. Must be annotator or reviewer or super_checker."
+ },
status=status.HTTP_400_BAD_REQUEST,
)
valid_users = []
@@ -521,7 +523,7 @@ def bulk_add_members_to_projects(self, request, pk=None, *args, **kwargs):
excepted_additions = []
for pid in project_ids:
try:
- project = Project.objects.get(pk=pid)
+ project = Project.objects.get(pk=pid, workspace_id=pk)
valid_projects.append(project)
except Project.DoesNotExist:
invalid_project_ids.append(pid)
@@ -548,13 +550,14 @@ def bulk_add_members_to_projects(self, request, pk=None, *args, **kwargs):
project.save()
message = "Users added to projects successfully."
if excepted_additions != []:
- message += f'Following users were not yet added: {excepted_additions}'
+ message += f"Following users were not yet added: {excepted_additions}"
return Response(
{
"message": message,
},
status=status.HTTP_200_OK,
)
+
@action(
detail=True, methods=["POST"], name="Assign Manager", url_name="assign_manager"
)
@@ -1751,7 +1754,7 @@ def cumulative_tasks_count_all(self, request, pk=None):
"InstructionDrivenChat",
"ModelInteractionEvaluation",
"ModelOutputEvaluation",
- "MultipleLLMInstructionDrivenChat"
+ "MultipleLLMInstructionDrivenChat",
]
if "project_type" in dict(request.query_params):
project_type = request.query_params["project_type"]