Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions comps/cores/mega/keycloak.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import os
from typing import Dict, Optional

import jwt
import requests
from jwt import ExpiredSignatureError, InvalidTokenError


class Keycloak:
    """Verifies Keycloak-issued JWTs against the public keys published by a realm.

    The realm's signing keys are downloaded once, when the instance is created,
    and kept in ``public_keys`` keyed by key ID (``kid``).
    """

    # Upper bound (seconds) for the key download so an unreachable Keycloak
    # server cannot block the caller indefinitely.
    REQUEST_TIMEOUT_SECONDS = 10

    def __init__(self, realm_url: Optional[str] = None, algorithm: str = "RS256"):
        """Initializes the Keycloak JWT Interface with the realm URL and algorithm.

        :param realm_url: Keycloak realm URL to fetch public keys for token verification.
            When omitted, the ``REALM_URL`` environment variable is read at construction time.
        :param algorithm: Algorithm used for the token, usually 'RS256' for Keycloak.
        """
        # Resolve the environment variable here rather than in the signature:
        # a default argument is evaluated once at import time and would miss
        # a REALM_URL that is set afterwards.
        self.realm_url = realm_url if realm_url is not None else os.getenv("REALM_URL")
        self.algorithm = algorithm
        self.public_keys = self.fetch_public_keys()

    def fetch_public_keys(self) -> Dict[str, object]:
        """Fetches and returns Keycloak public keys for token verification.

        This is best-effort: any failure is reported on stdout and results in an
        empty mapping, in which case ``decode_token`` rejects every token.

        :return: Dictionary mapping key IDs to their corresponding public key objects.
        """
        if not self.realm_url:
            print("Keycloak realm URL is not configured; no public keys loaded.")
            return {}
        try:
            response = requests.get(
                f"{self.realm_url}/protocol/openid-connect/certs",
                timeout=self.REQUEST_TIMEOUT_SECONDS,
            )
            response.raise_for_status()
            certs = response.json()
            # Only RSA keys can be parsed by RSAAlgorithm; skipping other key
            # types keeps one unsupported entry from discarding the whole set.
            return {
                key["kid"]: jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key))
                for key in certs["keys"]
                if key.get("kty") == "RSA"
            }
        except (requests.RequestException, jwt.PyJWTError, KeyError, ValueError) as exc:
            print(f"Failed to fetch Keycloak public keys: {exc}")
            return {}

    def decode_token(self, token: str) -> Optional[Dict]:
        """Decodes a Keycloak JWT token and verifies its signature and expiration.

        :param token: JWT token as a string.
        :return: Decoded payload as a dictionary if valid, None when the token's
            key ID is not among the fetched public keys.
        :raises ExpiredSignatureError: If the token has expired.
        :raises InvalidTokenError: If the token is malformed or fails verification.
        """
        # PyJWT errors propagate unchanged so callers keep the original message
        # and traceback; all of them derive from InvalidTokenError.
        unverified_header = jwt.get_unverified_header(token)
        key = self.public_keys.get(unverified_header.get("kid"))
        if key is None:
            print("Invalid token header: key ID not found.")
            return None
        # NOTE(review): no ``audience`` is passed to jwt.decode; confirm how tokens
        # carrying an ``aud`` claim are expected to be handled in this deployment.
        return jwt.decode(token, key=key, algorithms=[self.algorithm])

    def verify_token(self, token: str) -> bool:
        """Verifies if the token is valid and not expired.

        :param token: JWT token as a string.
        :return: True if valid, False otherwise.
        """
        try:
            return self.decode_token(token) is not None
        except InvalidTokenError:
            # Covers expired tokens too (ExpiredSignatureError is a subclass),
            # so this predicate honours its documented boolean contract.
            return False

    def get_user_info(self, token: str) -> Optional[Dict]:
        """Extracts user information from the JWT token payload.

        :param token: JWT token as a string.
        :return: Dictionary with ``username``, ``email`` and ``roles`` if the token
            decodes to a non-empty payload, None otherwise.
        :raises InvalidTokenError: Propagated from ``decode_token`` for invalid or expired tokens.
        """
        decoded = self.decode_token(token)
        if not decoded:
            return None
        # ``realm_access`` may be absent or null; treat both as "no roles".
        realm_access = decoded.get("realm_access") or {}
        return {
            "username": decoded.get("preferred_username"),
            "email": decoded.get("email"),
            "roles": realm_access.get("roles", []),
        }

Check warning on line 78 in comps/cores/mega/keycloak.py

View check run for this annotation

Codecov / codecov/patch

comps/cores/mega/keycloak.py#L78

Added line #L78 was not covered by tests
69 changes: 69 additions & 0 deletions comps/cores/mega/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
from socket import AF_INET, SOCK_STREAM, socket
from typing import List, Optional, Union

import jwt
import requests
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from .keycloak import Keycloak
from .logger import CustomLogger


Expand Down Expand Up @@ -244,6 +248,71 @@
return ""


bearer_scheme = HTTPBearer(auto_error=False)


def token_validator(allowed_roles: Optional[List[str]] = None):
    """Builds a FastAPI dependency that authenticates a request with a Keycloak bearer token.

    :param allowed_roles: Realm roles permitted to pass. When None or empty, any
        request carrying a valid token is accepted.
    :return: Async callable intended to be used with ``Depends``.
    """

    async def validate_token(
        request: Request, credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme)
    ):
        """Validates the token, checks for allowed roles, and sets user details in request.state.user if valid.

        Raises HTTPException with appropriate status code and message if validation fails.
        """
        # Authentication is opt-in via the JWT_AUTH environment variable. Only an
        # unset/empty value or an explicit negative disables it, so JWT_AUTH=False
        # no longer switches validation on while any other value still enables it.
        jwt_auth = os.getenv("JWT_AUTH", "").strip().lower()
        if jwt_auth in ("", "0", "false", "no", "off"):
            request.state.user = None
            return

        # Challenge header sent with every 401 so clients know which scheme to use.
        bearer_challenge = {"WWW-Authenticate": "Bearer"}

        if credentials is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid authentication scheme or Missing Token",
                headers=bearer_challenge,
            )
        # The scheme name is case-insensitive, so "bearer" must be accepted too.
        if credentials.scheme.lower() != "bearer":
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid authentication scheme",
                headers=bearer_challenge,
            )
        try:
            token = credentials.credentials
            # NOTE(review): Keycloak() downloads the realm keys with a blocking HTTP
            # call on every request; consider reusing an instance if this shows up
            # in latency.
            identity_provider = Keycloak()
            decoded_token = identity_provider.decode_token(token)

            if not decoded_token:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid token.",
                    headers=bearer_challenge,
                )

            # Extract roles from the token; ``realm_access`` may be absent or null.
            user_roles = (decoded_token.get("realm_access") or {}).get("roles", [])

            # Check if user has any of the allowed roles
            if allowed_roles and not any(role in user_roles for role in allowed_roles):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN, detail="User does not have required permissions."
                )

            # Set user details in request.state.user
            request.state.user = {
                "username": decoded_token.get("preferred_username"),
                "email": decoded_token.get("email"),
                "roles": user_roles,
            }

        except jwt.ExpiredSignatureError as exc:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired.",
                headers=bearer_challenge,
            ) from exc
        except jwt.InvalidTokenError as exc:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token signature.",
                headers=bearer_challenge,
            ) from exc
        except HTTPException:
            # Already a well-formed HTTP error raised above; pass it through untouched.
            raise
        except Exception as exc:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Token validation error: {str(exc)}"
            ) from exc

    return validate_token

Check warning on line 313 in comps/cores/mega/utils.py

View check run for this annotation

Codecov / codecov/patch

comps/cores/mega/utils.py#L313

Added line #L313 was not covered by tests


class SafeContextManager:
"""This context manager ensures that the `__exit__` method of the
sub context is called, even when there is an Exception in the
Expand Down
12 changes: 12 additions & 0 deletions comps/dataprep/redis/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,18 @@ export INDEX_NAME=${your_index_name}
export HUGGINGFACEHUB_API_TOKEN=${your_hf_api_token}
```

If authorization with Keycloak is needed, also set:

```bash
realm_name=productivitysuite
export JWT_AUTH=True
export REALM_URL="http://${your_ip}/realms/$realm_name"
export ADMIN_ROLE="admin"
export USER_ROLE="user"
```

If `JWT_AUTH` is enabled, make sure to follow the [Keycloak setup guide](https://github.com/opea-project/GenAIExamples/blob/main/ProductivitySuite/docker_compose/intel/cpu/xeon/keycloak_setup_guide.md).

### 2.3 Build Docker Image

- Build docker image with langchain
Expand Down
3 changes: 3 additions & 0 deletions comps/dataprep/redis/langchain/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,6 @@ def format_redis_conn_from_env():
TIMEOUT_SECONDS = int(os.getenv("TIMEOUT_SECONDS", 600))

SEARCH_BATCH_SIZE = int(os.getenv("SEARCH_BATCH_SIZE", 10))

# Realm role names used for endpoint authorization. ADMIN_ROLE_KEY / USER_ROLE_KEY
# take precedence; ADMIN_ROLE / USER_ROLE are honoured as well because those are
# the variable names the service README tells users to export.
ADMIN_ROLE = os.getenv("ADMIN_ROLE_KEY") or os.getenv("ADMIN_ROLE") or "admin"
USER_ROLE = os.getenv("USER_ROLE_KEY") or os.getenv("USER_ROLE") or "user"
12 changes: 8 additions & 4 deletions comps/dataprep/redis/langchain/prepare_doc_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

# from pyspark import SparkConf, SparkContext
import redis
from config import EMBED_MODEL, INDEX_NAME, KEY_INDEX_NAME, REDIS_URL, SEARCH_BATCH_SIZE
from fastapi import Body, File, Form, HTTPException, UploadFile
from config import ADMIN_ROLE, EMBED_MODEL, INDEX_NAME, KEY_INDEX_NAME, REDIS_URL, SEARCH_BATCH_SIZE, USER_ROLE
from fastapi import Body, Depends, File, Form, HTTPException, UploadFile
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import Redis
Expand All @@ -19,6 +19,7 @@
from redis.commands.search.indexDefinition import IndexDefinition, IndexType

from comps import CustomLogger, DocPath, opea_microservices, register_microservice
from comps.cores.mega.utils import token_validator
from comps.dataprep.utils import (
create_upload_folder,
document_loader,
Expand Down Expand Up @@ -223,6 +224,7 @@ async def ingest_documents(
chunk_overlap: int = Form(100),
process_table: bool = Form(False),
table_strategy: str = Form("fast"),
_: Optional[str] = Depends(token_validator([ADMIN_ROLE])),
):
if logflag:
logger.info(f"[ upload ] files:{files}")
Expand Down Expand Up @@ -341,7 +343,7 @@ async def ingest_documents(
@register_microservice(
name="opea_service@prepare_doc_redis", endpoint="/v1/dataprep/get_file", host="0.0.0.0", port=6007
)
async def rag_get_file_structure():
async def rag_get_file_structure(_: Optional[str] = Depends(token_validator([USER_ROLE, ADMIN_ROLE]))):
if logflag:
logger.info("[ get ] start to get file structure")

Expand Down Expand Up @@ -375,7 +377,9 @@ async def rag_get_file_structure():
@register_microservice(
name="opea_service@prepare_doc_redis", endpoint="/v1/dataprep/delete_file", host="0.0.0.0", port=6007
)
async def delete_single_file(file_path: str = Body(..., embed=True)):
async def delete_single_file(
file_path: str = Body(..., embed=True), _: Optional[str] = Depends(token_validator([ADMIN_ROLE]))
):
"""Delete file according to `file_path`.

`file_path`:
Expand Down
1 change: 1 addition & 0 deletions comps/dataprep/redis/langchain/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ opentelemetry-sdk
pandas
Pillow
prometheus-fastapi-instrumentator
pyjwt
pymupdf
pyspark
pytesseract
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ opentelemetry-exporter-otlp
opentelemetry-sdk
Pillow
prometheus-fastapi-instrumentator
pyjwt
pypdf
python-multipart
pyyaml
Expand Down
Loading