Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ services:
# Telemetry helps us prioritize feature development and understand how people are using Khoj
# Read more at https://docs.khoj.dev/miscellaneous/telemetry
# - KHOJ_TELEMETRY_DISABLE=True
#
# Uncomment the line below to add custom image domains to Content-Security-Policy.
# Useful when using custom image sources like CDN or external image hosts.
# Comma-separated list of domains (without https:// prefix).
# - KHOJ_CSP_IMG_DOMAINS=static.example.com,cdn.example.com
# Comment out this line when you're using the official ghcr.io/khoj-ai/khoj-cloud:latest prod image.
command: --host="0.0.0.0" --port=42110 -vv --anonymous-mode --non-interactive

Expand Down
14 changes: 11 additions & 3 deletions src/interface/web/app/common/layoutHelper.tsx
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
export function ContentSecurityPolicy() {
// Allow additional image domains via environment variable (comma-separated)
// e.g., KHOJ_CSP_IMG_DOMAINS=static.example.com,cdn.example.com
const additionalImgDomains = process.env.NEXT_PUBLIC_CSP_IMG_DOMAINS
? process.env.NEXT_PUBLIC_CSP_IMG_DOMAINS.split(',').map(d => `https://${d.trim()}`).join(' ')
: '';

const imgSrc = `'self' data: blob: https://*.khoj.dev https://accounts.google.com https://*.googleusercontent.com https://*.google.com/ https://*.gstatic.com ${additionalImgDomains}`;

return (
<meta
httpEquiv="Content-Security-Policy"
content="default-src 'self' https://assets.khoj.dev;
content={`default-src 'self' https://assets.khoj.dev;
media-src * blob:;
script-src 'self' https://assets.khoj.dev https://app.chatwoot.com https://accounts.google.com 'unsafe-inline' 'unsafe-eval';
connect-src 'self' blob: https://ipapi.co/json ws://localhost:42110 https://accounts.google.com;
style-src 'self' https://assets.khoj.dev 'unsafe-inline' https://fonts.googleapis.com https://accounts.google.com;
img-src 'self' data: blob: https://*.khoj.dev https://accounts.google.com https://*.googleusercontent.com https://*.google.com/ https://*.gstatic.com;
img-src ${imgSrc};
font-src 'self' https://assets.khoj.dev https://fonts.gstatic.com;
frame-src 'self' https://accounts.google.com https://app.chatwoot.com;
child-src 'self' https://app.chatwoot.com;
object-src 'none';"
object-src 'none';`}
></meta>
);
}
19 changes: 19 additions & 0 deletions src/khoj/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,31 @@ async def dispatch(self, request: Request, call_next):
# Re-raise for API routes and non-5xx errors
raise e

class CSPHeadersMiddleware(BaseHTTPMiddleware):
"""Add Content-Security-Policy headers with configurable image domains."""
async def dispatch(self, request: Request, call_next):
response = await call_next(request)

# Only add CSP headers for HTML responses
content_type = response.headers.get("content-type", "")
if "text/html" in content_type:
# Get additional image domains from environment variable
additional_img_domains = os.environ.get("KHOJ_CSP_IMG_DOMAINS", "")
if additional_img_domains:
domains = " ".join([f"https://{d.strip()}" for d in additional_img_domains.split(",") if d.strip()])
# Note: This adds to existing CSP meta tag in HTML
# For full CSP control, consider using header-based CSP instead of meta tag
response.headers["X-Khoj-CSP-Img-Domains"] = domains

return response

if ssl_enabled:
app.add_middleware(HTTPSRedirectMiddleware)
app.add_middleware(SuppressClientDisconnectMiddleware)
app.add_middleware(AsyncCloseConnectionsMiddleware)
app.add_middleware(AuthenticationMiddleware, backend=UserAuthenticationBackend())
app.add_middleware(ServerErrorMiddleware) # Add after AuthenticationMiddleware to catch its exceptions
app.add_middleware(CSPHeadersMiddleware) # Add CSP headers with configurable image domains
app.add_middleware(NextJsMiddleware)
app.add_middleware(SessionMiddleware, secret_key=os.environ.get("KHOJ_DJANGO_SECRET_KEY", "!secret"))

Expand Down
2 changes: 1 addition & 1 deletion src/khoj/database/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ class ModelType(models.TextChoices):
strengths = models.TextField(default=None, null=True, blank=True)

def __str__(self):
return self.friendly_name
return self.friendly_name or self.name


class VoiceModelOption(DbBaseModel):
Expand Down
12 changes: 10 additions & 2 deletions src/khoj/processor/conversation/openai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
# Default completion tokens
# Reduce premature termination, especially when streaming structured responses
MAX_COMPLETION_TOKENS = 16000
# Groq API has a lower max_completion_tokens limit
MAX_COMPLETION_TOKENS_GROQ = 8192


def _extract_text_for_instructions(content: Union[str, List, Dict, None]) -> str:
Expand Down Expand Up @@ -115,7 +117,10 @@ def completion_with_backoff(

model_kwargs["temperature"] = temperature
model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95)
model_kwargs["max_completion_tokens"] = model_kwargs.get("max_completion_tokens", MAX_COMPLETION_TOKENS)

# Set max_completion_tokens with Groq-specific limit
default_max_tokens = MAX_COMPLETION_TOKENS_GROQ if is_groq_api(api_base_url) else MAX_COMPLETION_TOKENS
model_kwargs["max_completion_tokens"] = model_kwargs.get("max_completion_tokens", default_max_tokens)

formatted_messages = format_message_for_api(messages, model_name, api_base_url)

Expand Down Expand Up @@ -308,7 +313,10 @@ async def chat_completion_with_backoff(
model_kwargs.pop("stream_options", None)

model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95)
model_kwargs["max_completion_tokens"] = model_kwargs.get("max_completion_tokens", MAX_COMPLETION_TOKENS)

# Set max_completion_tokens with Groq-specific limit
default_max_tokens = MAX_COMPLETION_TOKENS_GROQ if is_groq_api(api_base_url) else MAX_COMPLETION_TOKENS
model_kwargs["max_completion_tokens"] = model_kwargs.get("max_completion_tokens", default_max_tokens)

formatted_messages = format_message_for_api(messages, model_name, api_base_url)

Expand Down
22 changes: 18 additions & 4 deletions src/khoj/processor/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,21 @@ def __init__(
self.inference_endpoint = embeddings_inference_endpoint
self.api_key = embeddings_inference_endpoint_api_key
self.inference_endpoint_type = embeddings_inference_endpoint_type
if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
# Only load model locally if no inference endpoint is configured
if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL and not self.inference_endpoint:
with timer(f"Loaded embedding model {self.model_name}", logger):
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
else:
self.embeddings_model = None

def embed_query(self, query):
if self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
return self.embed_with_hf([query])[0]
elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
return self.embed_with_openai([query])[0]
elif self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL and self.inference_endpoint:
# Use OpenAI-compatible API for local inference endpoints (e.g., llama.cpp, vLLM)
return self.embed_with_openai([query])[0]
return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]

@retry(
Expand Down Expand Up @@ -87,13 +93,21 @@ def embed_with_hf(self, docs):
before_sleep=before_sleep_log(logger, logging.DEBUG),
)
def embed_with_openai(self, docs):
client = get_openai_client(self.api_key, self.inference_endpoint)
# Use empty string for API key if not provided (local servers may not require auth)
client = get_openai_client(self.api_key or "", self.inference_endpoint)
response = client.embeddings.create(input=docs, model=self.model_name, encoding_format="float")
return [item.embedding for item in response.data]

def embed_documents(self, docs):
if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
# If inference endpoint is configured, use OpenAI-compatible API
if self.inference_endpoint:
embed_with_api = self.embed_with_openai
elif self.embeddings_model:
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
else:
logger.warning("No local embedding model or endpoint configured")
return []
elif self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
embed_with_api = self.embed_with_hf
elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
Expand All @@ -102,7 +116,7 @@ def embed_documents(self, docs):
logger.warning(
f"Unsupported inference endpoint: {self.inference_endpoint_type}. Generating embeddings locally instead."
)
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
embeddings = []
with tqdm.tqdm(total=len(docs)) as pbar:
Expand Down