diff --git a/docker-compose.yml b/docker-compose.yml index 3bdb88a86..e78d7d0f6 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -121,6 +121,11 @@ services: # Telemetry helps us prioritize feature development and understand how people are using Khoj # Read more at https://docs.khoj.dev/miscellaneous/telemetry # - KHOJ_TELEMETRY_DISABLE=True + # + # Uncomment the line below to add custom image domains to Content-Security-Policy. + # Useful when using custom image sources like CDN or external image hosts. + # Comma-separated list of domains (without https:// prefix). + # - KHOJ_CSP_IMG_DOMAINS=static.example.com,cdn.example.com # Comment out this line when you're using the official ghcr.io/khoj-ai/khoj-cloud:latest prod image. command: --host="0.0.0.0" --port=42110 -vv --anonymous-mode --non-interactive diff --git a/src/interface/web/app/common/layoutHelper.tsx b/src/interface/web/app/common/layoutHelper.tsx index 2d5159d4b..918933c36 100644 --- a/src/interface/web/app/common/layoutHelper.tsx +++ b/src/interface/web/app/common/layoutHelper.tsx @@ -1,17 +1,25 @@ export function ContentSecurityPolicy() { + // Allow additional image domains via environment variable (comma-separated) + // e.g., NEXT_PUBLIC_CSP_IMG_DOMAINS=static.example.com,cdn.example.com + const additionalImgDomains = process.env.NEXT_PUBLIC_CSP_IMG_DOMAINS + ? 
process.env.NEXT_PUBLIC_CSP_IMG_DOMAINS.split(',').map(d => `https://${d.trim()}`).join(' ') + : ''; + + const imgSrc = `'self' data: blob: https://*.khoj.dev https://accounts.google.com https://*.googleusercontent.com https://*.google.com/ https://*.gstatic.com ${additionalImgDomains}`; + return ( ); } diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 50b2dd313..46bf1270d 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -396,12 +396,31 @@ async def dispatch(self, request: Request, call_next): # Re-raise for API routes and non-5xx errors raise e + class CSPHeadersMiddleware(BaseHTTPMiddleware): + """Expose configured extra image domains via the X-Khoj-CSP-Img-Domains response header.""" + async def dispatch(self, request: Request, call_next): + response = await call_next(request) + + # Only add the header for HTML responses + content_type = response.headers.get("content-type", "") + if "text/html" in content_type: + # Get additional image domains from environment variable + additional_img_domains = os.environ.get("KHOJ_CSP_IMG_DOMAINS", "") + if additional_img_domains: + domains = " ".join([f"https://{d.strip()}" for d in additional_img_domains.split(",") if d.strip()]) + # Note: This only sets a custom response header; it does not modify the CSP meta tag in HTML + # For full CSP control, consider using header-based CSP instead of meta tag + response.headers["X-Khoj-CSP-Img-Domains"] = domains + + return response + if ssl_enabled: app.add_middleware(HTTPSRedirectMiddleware) app.add_middleware(SuppressClientDisconnectMiddleware) app.add_middleware(AsyncCloseConnectionsMiddleware) app.add_middleware(AuthenticationMiddleware, backend=UserAuthenticationBackend()) app.add_middleware(ServerErrorMiddleware) # Add after AuthenticationMiddleware to catch its exceptions + app.add_middleware(CSPHeadersMiddleware) # Expose configurable image domains via X-Khoj-CSP-Img-Domains header app.add_middleware(NextJsMiddleware) app.add_middleware(SessionMiddleware, secret_key=os.environ.get("KHOJ_DJANGO_SECRET_KEY", "!secret")) 
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 56516b2db..6a8c3fe32 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -236,7 +236,7 @@ class ModelType(models.TextChoices): strengths = models.TextField(default=None, null=True, blank=True) def __str__(self): - return self.friendly_name + return self.friendly_name or self.name class VoiceModelOption(DbBaseModel): diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 3ebb2e62b..0ec2d0d20 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -58,6 +58,8 @@ # Default completion tokens # Reduce premature termination, especially when streaming structured responses MAX_COMPLETION_TOKENS = 16000 +# Groq API has a lower max_completion_tokens limit +MAX_COMPLETION_TOKENS_GROQ = 8192 def _extract_text_for_instructions(content: Union[str, List, Dict, None]) -> str: @@ -115,7 +117,10 @@ def completion_with_backoff( model_kwargs["temperature"] = temperature model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95) - model_kwargs["max_completion_tokens"] = model_kwargs.get("max_completion_tokens", MAX_COMPLETION_TOKENS) + + # Set max_completion_tokens with Groq-specific limit + default_max_tokens = MAX_COMPLETION_TOKENS_GROQ if is_groq_api(api_base_url) else MAX_COMPLETION_TOKENS + model_kwargs["max_completion_tokens"] = model_kwargs.get("max_completion_tokens", default_max_tokens) formatted_messages = format_message_for_api(messages, model_name, api_base_url) @@ -308,7 +313,10 @@ async def chat_completion_with_backoff( model_kwargs.pop("stream_options", None) model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95) - model_kwargs["max_completion_tokens"] = model_kwargs.get("max_completion_tokens", MAX_COMPLETION_TOKENS) + + # Set max_completion_tokens with Groq-specific limit + default_max_tokens = 
MAX_COMPLETION_TOKENS_GROQ if is_groq_api(api_base_url) else MAX_COMPLETION_TOKENS + model_kwargs["max_completion_tokens"] = model_kwargs.get("max_completion_tokens", default_max_tokens) formatted_messages = format_message_for_api(messages, model_name, api_base_url) diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index f9159474d..4aadcbf71 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -46,15 +46,21 @@ def __init__( self.inference_endpoint = embeddings_inference_endpoint self.api_key = embeddings_inference_endpoint_api_key self.inference_endpoint_type = embeddings_inference_endpoint_type - if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL: + # Only load model locally if no inference endpoint is configured + if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL and not self.inference_endpoint: with timer(f"Loaded embedding model {self.model_name}", logger): self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs) + else: + self.embeddings_model = None def embed_query(self, query): if self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE: return self.embed_with_hf([query])[0] elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI: return self.embed_with_openai([query])[0] + elif self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL and self.inference_endpoint: + # Use OpenAI-compatible API for local inference endpoints (e.g., llama.cpp, vLLM) + return self.embed_with_openai([query])[0] return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0] @retry( @@ -87,13 +93,21 @@ def embed_with_hf(self, docs): before_sleep=before_sleep_log(logger, logging.DEBUG), ) def embed_with_openai(self, docs): - client = get_openai_client(self.api_key, self.inference_endpoint) + # Use empty string for API key if not provided (local servers may not require auth) + client = get_openai_client(self.api_key 
or "", self.inference_endpoint) response = client.embeddings.create(input=docs, model=self.model_name, encoding_format="float") return [item.embedding for item in response.data] def embed_documents(self, docs): if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL: - return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else [] + # If inference endpoint is configured, use OpenAI-compatible API + if self.inference_endpoint: + embed_with_api = self.embed_with_openai + elif self.embeddings_model: + return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else [] + else: + logger.warning("No local embedding model or endpoint configured") + return [] elif self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE: embed_with_api = self.embed_with_hf elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI: @@ -102,7 +116,7 @@ def embed_documents(self, docs): logger.warning( f"Unsupported inference endpoint: {self.inference_endpoint_type}. Generating embeddings locally instead." ) - return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() + return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else [] # break up the docs payload in chunks of 1000 to avoid hitting rate limits embeddings = [] with tqdm.tqdm(total=len(docs)) as pbar: