Skip to content

Commit 38ee9c7

Browse files
committed
change user in new context
1 parent f1c427d commit 38ee9c7

File tree

4 files changed

+17
-5
lines changed

4 files changed

+17
-5
lines changed

template/server/api/models/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ class Context(BaseModel):
66
id: StrictStr = Field(description="Context ID")
77
language: StrictStr = Field(description="Language of the context")
88
cwd: StrictStr = Field(description="Current working directory of the context")
9+
user: StrictStr = Field(description="User of the context")
910

1011
def __hash__(self):
1112
return hash(self.id)

template/server/api/models/create_context.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55

66
class CreateContext(BaseModel):
7+
user: Optional[StrictStr] = Field(
8+
default="root",
9+
description="User to run the context",
10+
)
711
cwd: Optional[StrictStr] = Field(
812
default="/home/user",
913
description="Current working directory",

template/server/contexts.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,16 @@ def normalize_language(language: Optional[str]) -> str:
2424
return language
2525

2626

27-
async def create_context(client, websockets: dict, language: str, cwd: str) -> Context:
27+
def get_kernel_name(language: str, user: str) -> str:
28+
if user == "root":
29+
return language+"_root"
30+
return language
31+
32+
33+
async def create_context(client, websockets: dict, language: str, cwd: str, user: str) -> Context:
2834
data = {
2935
"path": str(uuid.uuid4()),
30-
"kernel": {"name": language},
36+
"kernel": {"name": get_kernel_name(language, user)}, # replace with root kernel when user is root
3137
"type": "notebook",
3238
"name": str(uuid.uuid4()),
3339
}
@@ -59,4 +65,4 @@ async def create_context(client, websockets: dict, language: str, cwd: str) -> C
5965
status_code=500,
6066
)
6167

62-
return Context(language=language, id=context_id, cwd=cwd)
68+
return Context(language=language, id=context_id, cwd=cwd, user=user)

template/server/main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ async def post_execute(request: ExecutionRequest):
9191
if not context_id:
9292
try:
9393
context = await create_context(
94-
client, websockets, language, "/home/user"
94+
client, websockets, language, "/home/user", "root"
9595
)
9696
except Exception as e:
9797
return PlainTextResponse(str(e), status_code=500)
@@ -127,9 +127,10 @@ async def post_contexts(request: CreateContext) -> Context:
127127

128128
language = normalize_language(request.language)
129129
cwd = request.cwd or "/home/user"
130+
user = request.user or "root"
130131

131132
try:
132-
return await create_context(client, websockets, language, cwd)
133+
return await create_context(client, websockets, language, cwd, user)
133134
except Exception as e:
134135
return PlainTextResponse(str(e), status_code=500)
135136

0 commit comments

Comments
 (0)