Skip to content

fix: protect endpoints with auth API key #127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 35 additions & 65 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import List, Union, Generator, Iterator


from utils.pipelines.auth import bearer_security, get_current_user
from utils.pipelines.auth import get_current_user_or_abort
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should keep the get_current_user convention. Please refer to our main Open WebUI repo!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Timothy, I'm sorry but I don't understand what you mean. Are you talking about adding the HTTP exception in the get_current_user or implement something like the function get_current_user_by_api_key (https://github.com/open-webui/open-webui/blob/eff736acd2e0bbbdd0eeca4cc209b216a1f23b6a/backend/utils/utils.py#L116C5-L116C32)?

from utils.pipelines.main import get_last_user_message, stream_message_template
from utils.pipelines.misc import convert_to_raw_url

Expand All @@ -28,7 +28,7 @@
import sys


from config import API_KEY, PIPELINES_DIR
from config import PIPELINES_DIR

if not os.path.exists(PIPELINES_DIR):
os.makedirs(PIPELINES_DIR)
Expand Down Expand Up @@ -243,7 +243,7 @@ async def check_url(request: Request, call_next):

@app.get("/v1/models")
@app.get("/models")
async def get_models():
async def get_models(user: str = Depends(get_current_user_or_abort)):
"""
Returns the available pipelines
"""
Expand Down Expand Up @@ -288,32 +288,26 @@ async def get_status():

@app.get("/v1/pipelines")
@app.get("/pipelines")
async def list_pipelines(user: str = Depends(get_current_user)):
if user == API_KEY:
return {
"data": [
{
"id": pipeline_id,
"name": PIPELINE_NAMES[pipeline_id],
"type": (
PIPELINE_MODULES[pipeline_id].type
if hasattr(PIPELINE_MODULES[pipeline_id], "type")
else "pipe"
),
"valves": (
True
if hasattr(PIPELINE_MODULES[pipeline_id], "valves")
else False
),
}
for pipeline_id in list(PIPELINE_MODULES.keys())
]
}
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
async def list_pipelines(user: str = Depends(get_current_user_or_abort)):
return {
"data": [
{
"id": pipeline_id,
"name": PIPELINE_NAMES[pipeline_id],
"type": (
PIPELINE_MODULES[pipeline_id].type
if hasattr(PIPELINE_MODULES[pipeline_id], "type")
else "pipe"
),
"valves": (
True
if hasattr(PIPELINE_MODULES[pipeline_id], "valves")
else False
),
}
for pipeline_id in list(PIPELINE_MODULES.keys())
]
}


class AddPipelineForm(BaseModel):
Expand Down Expand Up @@ -346,14 +340,8 @@ async def download_file(url: str, dest_folder: str):
@app.post("/v1/pipelines/add")
@app.post("/pipelines/add")
async def add_pipeline(
form_data: AddPipelineForm, user: str = Depends(get_current_user)
form_data: AddPipelineForm, user: str = Depends(get_current_user_or_abort)
):
if user != API_KEY:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)

try:
url = convert_to_raw_url(form_data.url)

Expand All @@ -376,14 +364,8 @@ async def add_pipeline(
@app.post("/v1/pipelines/upload")
@app.post("/pipelines/upload")
async def upload_pipeline(
file: UploadFile = File(...), user: str = Depends(get_current_user)
file: UploadFile = File(...), user: str = Depends(get_current_user_or_abort)
):
if user != API_KEY:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)

file_ext = os.path.splitext(file.filename)[1]
if file_ext != ".py":
raise HTTPException(
Expand Down Expand Up @@ -425,14 +407,8 @@ class DeletePipelineForm(BaseModel):
@app.delete("/v1/pipelines/delete")
@app.delete("/pipelines/delete")
async def delete_pipeline(
form_data: DeletePipelineForm, user: str = Depends(get_current_user)
form_data: DeletePipelineForm, user: str = Depends(get_current_user_or_abort)
):
if user != API_KEY:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)

pipeline_id = form_data.id
pipeline_name = PIPELINE_NAMES.get(pipeline_id.split(".")[0], None)

Expand All @@ -457,20 +433,14 @@ async def delete_pipeline(

@app.post("/v1/pipelines/reload")
@app.post("/pipelines/reload")
async def reload_pipelines(user: str = Depends(get_current_user)):
if user == API_KEY:
await reload()
return {"message": "Pipelines reloaded successfully."}
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
async def reload_pipelines(user: str = Depends(get_current_user_or_abort)):
await reload()
return {"message": "Pipelines reloaded successfully."}


@app.get("/v1/{pipeline_id}/valves")
@app.get("/{pipeline_id}/valves")
async def get_valves(pipeline_id: str):
async def get_valves(pipeline_id: str, user: str = Depends(get_current_user_or_abort)):
if pipeline_id not in PIPELINE_MODULES:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand All @@ -490,7 +460,7 @@ async def get_valves(pipeline_id: str):

@app.get("/v1/{pipeline_id}/valves/spec")
@app.get("/{pipeline_id}/valves/spec")
async def get_valves_spec(pipeline_id: str):
async def get_valves_spec(pipeline_id: str, user: str = Depends(get_current_user_or_abort)):
if pipeline_id not in PIPELINE_MODULES:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand All @@ -510,7 +480,7 @@ async def get_valves_spec(pipeline_id: str):

@app.post("/v1/{pipeline_id}/valves/update")
@app.post("/{pipeline_id}/valves/update")
async def update_valves(pipeline_id: str, form_data: dict):
async def update_valves(pipeline_id: str, form_data: dict, user: str = Depends(get_current_user_or_abort)):

if pipeline_id not in PIPELINE_MODULES:
raise HTTPException(
Expand Down Expand Up @@ -553,7 +523,7 @@ async def update_valves(pipeline_id: str, form_data: dict):

@app.post("/v1/{pipeline_id}/filter/inlet")
@app.post("/{pipeline_id}/filter/inlet")
async def filter_inlet(pipeline_id: str, form_data: FilterForm):
async def filter_inlet(pipeline_id: str, form_data: FilterForm, user: str = Depends(get_current_user_or_abort)):
if pipeline_id not in app.state.PIPELINES:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand Down Expand Up @@ -585,7 +555,7 @@ async def filter_inlet(pipeline_id: str, form_data: FilterForm):

@app.post("/v1/{pipeline_id}/filter/outlet")
@app.post("/{pipeline_id}/filter/outlet")
async def filter_outlet(pipeline_id: str, form_data: FilterForm):
async def filter_outlet(pipeline_id: str, form_data: FilterForm, user: str = Depends(get_current_user_or_abort)):
if pipeline_id not in app.state.PIPELINES:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand Down Expand Up @@ -617,7 +587,7 @@ async def filter_outlet(pipeline_id: str, form_data: FilterForm):

@app.post("/v1/chat/completions")
@app.post("/chat/completions")
async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm, user: str = Depends(get_current_user_or_abort)):
messages = [message.model_dump() for message in form_data.messages]
user_message = get_last_user_message(messages)

Expand Down
10 changes: 9 additions & 1 deletion utils/pipelines/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import BaseModel
from typing import Union, Optional


from config import API_KEY
from passlib.context import CryptContext
from datetime import datetime, timedelta
import jwt
Expand Down Expand Up @@ -63,3 +63,11 @@ def get_current_user(
) -> Optional[dict]:
token = credentials.credentials
return token

def get_current_user_or_abort(user:str = Depends(get_current_user)) -> str:
if user != API_KEY:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
return user