Skip to content

Commit 8429769

Browse files
authored
Merge pull request #232 from labmlai/archive
Archive
2 parents 3d48565 + e3184b7 commit 8429769

File tree

15 files changed

+385
-51
lines changed

15 files changed

+385
-51
lines changed

app/server/labml_app/db/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pymongo.errors import ConnectionFailure
1212

1313
from .. import settings
14-
from . import project
14+
from . import project, folder
1515
from . import user
1616
from . import status
1717
from . import app_token
@@ -58,6 +58,7 @@ def msave_dict(self, keys: List[str], data: List[ModelDict]):
5858

5959
models = [user.User,
6060
project.Project,
61+
folder.Folder,
6162
status.Status,
6263
status.RunStatus,
6364
app_token.AppToken,

app/server/labml_app/db/folder.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import time
2+
from enum import Enum
3+
from typing import List, Set
4+
5+
from labml_db import Model
6+
7+
from labml_app.db import run
8+
9+
10+
class DefaultFolders(Enum):
11+
ARCHIVE = 'archive'
12+
DEFAULT = 'default'
13+
14+
15+
class Folder(Model['Folder']):
16+
name: str
17+
created_at: float
18+
run_uuids: Set[str]
19+
20+
@classmethod
21+
def defaults(cls):
22+
return dict(name='',
23+
created_at=time.time(),
24+
run_uuids=set(),
25+
)
26+
27+
def get_runs(self) -> List['run.Run']:
28+
res = []
29+
for run_uuid in self.run_uuids:
30+
r_key = run.RunIndex.get(run_uuid)
31+
if r_key is not None:
32+
r = r_key.load()
33+
if r is not None:
34+
res.append(r)
35+
36+
res.sort(key=lambda x: x.start_time, reverse=True)
37+
return res
38+
39+
def add_run(self, run_uuid: str):
40+
self.run_uuids.add(run_uuid)
41+
self.save()
42+
43+
def remove_run(self, run_uuid: str):
44+
self.run_uuids.remove(run_uuid)
45+
self.save()
46+

app/server/labml_app/db/project.py

+84-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from labml_db import Model, Key, Index
44

5-
from . import run
5+
from . import run, folder
66
from . import session
77
from ..logger import logger
88

@@ -14,6 +14,7 @@ class Project(Model['Project']):
1414
runs: Dict[str, Key['run.Run']]
1515
sessions: Dict[str, Key['session.Session']]
1616
is_run_added: bool
17+
folders: Dict[str, Key['folder.Folder']]
1718

1819
@classmethod
1920
def defaults(cls):
@@ -23,6 +24,7 @@ def defaults(cls):
2324
runs={},
2425
sessions={},
2526
is_run_added=False,
27+
folders={},
2628
)
2729

2830
def is_project_run(self, run_uuid: str) -> bool:
@@ -31,10 +33,15 @@ def is_project_run(self, run_uuid: str) -> bool:
3133
def is_project_session(self, session_uuid: str) -> bool:
3234
return session_uuid in self.sessions
3335

34-
def get_runs(self) -> List['run.Run']:
36+
def get_runs(self, folder_name: str = "Default") -> List['run.Run']:
3537
res = []
3638
likely_deleted = []
37-
for run_uuid, run_key in self.runs.items():
39+
run_uuids = self.runs.keys()
40+
if folder_name in self.folders:
41+
f = self.folders[folder_name].load()
42+
if f:
43+
run_uuids = f.run_uuids
44+
for run_uuid in run_uuids:
3845
try:
3946
r = run.get(run_uuid)
4047
if r:
@@ -71,6 +78,7 @@ def delete_runs(self, run_uuids: List[str], project_owner: str) -> None:
7178
r = run.get(run_uuid)
7279
if r and r.owner == project_owner:
7380
try:
81+
self.delete_from_folder(r)
7482
run.delete(run_uuid)
7583
except TypeError:
7684
logger.error(f'error while deleting the run {run_uuid}')
@@ -95,9 +103,24 @@ def add_run(self, run_uuid: str) -> None:
95103

96104
if r:
97105
self.runs[run_uuid] = r.key
106+
self.add_to_folder(folder.DefaultFolders.DEFAULT.value, r)
98107

99108
self.save()
100109

110+
def add_run_with_model(self, r: run.Run) -> None:
111+
self.runs[r.run_uuid] = r.key
112+
self.is_run_added = True
113+
114+
self.add_to_folder(folder.DefaultFolders.DEFAULT.value, r)
115+
116+
self.save()
117+
118+
def get_run(self, run_uuid: str) -> Optional['run.Run']:
119+
if run_uuid in self.runs:
120+
return self.runs[run_uuid].load()
121+
else:
122+
return None
123+
101124
def add_session(self, session_uuid: str) -> None:
102125
s = session.get(session_uuid)
103126

@@ -106,6 +129,62 @@ def add_session(self, session_uuid: str) -> None:
106129

107130
self.save()
108131

132+
def add_to_folder(self, folder_name: str, r: run.Run) -> None:
133+
if folder_name not in self.folders:
134+
f = folder.Folder(name=folder_name)
135+
self.folders[folder_name] = f.key
136+
else:
137+
f = self.folders[folder_name].load()
138+
139+
f.add_run(r.run_uuid)
140+
f.save()
141+
142+
r.parent_folder = f.name
143+
r.save()
144+
145+
def remove_from_folder(self, folder_name: str, r: run.Run) -> None:
146+
if folder_name not in self.folders:
147+
f = folder.Folder(name=folder_name)
148+
self.folders[folder_name] = f.key
149+
else:
150+
f = self.folders[folder_name].load()
151+
152+
f.remove_run(r.run_uuid)
153+
f.save()
154+
155+
r.parent_folder = ''
156+
r.save()
157+
158+
def archive_runs(self, run_uuids: List[str]) -> None:
159+
for run_uuid in run_uuids:
160+
if run_uuid in self.runs:
161+
r = self.runs[run_uuid].load()
162+
if r:
163+
self.remove_from_folder(folder.DefaultFolders.DEFAULT.value, r)
164+
self.add_to_folder(folder.DefaultFolders.ARCHIVE.value, r)
165+
166+
self.save()
167+
168+
def un_archive_runs(self, run_uuids: List[str]) -> None:
169+
for run_uuid in run_uuids:
170+
if run_uuid in self.runs:
171+
r = self.runs[run_uuid].load()
172+
if r:
173+
self.remove_from_folder(folder.DefaultFolders.ARCHIVE.value, r)
174+
self.add_to_folder(folder.DefaultFolders.DEFAULT.value, r)
175+
176+
self.save()
177+
178+
def delete_from_folder(self, r: run.Run) -> None:
179+
folder_name = r.parent_folder
180+
if folder_name in self.folders:
181+
return
182+
parent_folder = self.folders[folder_name].load()
183+
if parent_folder is None:
184+
return
185+
parent_folder.remove_run(r.run_uuid)
186+
parent_folder.save()
187+
109188

110189
class ProjectIndex(Index['Project']):
111190
pass
@@ -123,8 +202,8 @@ def get_project(labml_token: str) -> Union[None, Project]:
123202
def get_run(run_uuid: str, labml_token: str = '') -> Optional['run.Run']:
124203
p = get_project(labml_token)
125204

126-
if run_uuid in p.runs:
127-
return p.runs[run_uuid].load()
205+
if p.is_project_run(run_uuid):
206+
return p.get_run(run_uuid)
128207
else:
129208
return None
130209

app/server/labml_app/db/run.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from labml_db import Model, Key, Index, load_keys
77

88
from .. import auth
9-
from . import user
9+
from . import user, folder
1010
from .. import utils
1111
from . import project
1212
from . import computer
@@ -64,6 +64,7 @@ class Run(Model['Run']):
6464
selected_configs: List['str']
6565
favourite_configs: List['str']
6666
main_rank: int
67+
parent_folder: str
6768

6869
wildcard_indicators: Dict[str, Dict[str, Union[str, bool]]]
6970
indicators: Dict[str, Dict[str, Union[str, bool]]]
@@ -110,6 +111,7 @@ def defaults(cls):
110111
process_key=None,
111112
session_id='',
112113
main_rank=0,
114+
parent_folder='',
113115
)
114116

115117
@property
@@ -303,7 +305,8 @@ def get_data(self, request: Request, is_dist_run: bool = False) -> Dict[str, Uni
303305
'favourite_configs': self.favourite_configs,
304306
'selected_configs': self.selected_configs,
305307
'process_id': self.process_id,
306-
'session_id': self.session_id
308+
'session_id': self.session_id,
309+
'folder': self.parent_folder
307310
}
308311

309312
def get_summary(self) -> Dict[str, str]:
@@ -354,8 +357,8 @@ class RunIndex(Index['Run']):
354357
def get_or_create(request: Request, run_uuid: str, rank: int, world_size: int, main_rank: int, labml_token: str = '') -> 'Run':
355358
p = project.get_project(labml_token)
356359

357-
if run_uuid in p.runs:
358-
return p.runs[run_uuid].load()
360+
if p.is_project_run(run_uuid):
361+
return p.get_run(run_uuid)
359362

360363
run = get(run_uuid)
361364
if run is not None:
@@ -383,8 +386,7 @@ def get_or_create(request: Request, run_uuid: str, rank: int, world_size: int, m
383386
)
384387

385388
if run.rank == 0: # TODO
386-
p.runs[run.run_uuid] = run.key
387-
p.is_run_added = True
389+
p.add_run_with_model(run)
388390

389391
run.save()
390392
p.save()
@@ -414,10 +416,9 @@ def delete(run_uuid: str) -> None:
414416
analyses.AnalysisManager.delete_run(run_uuid)
415417

416418

417-
def get_runs(labml_token: str) -> List['Run']:
418-
res = []
419+
def get_runs(labml_token: str, folder_name: str) -> List['Run']:
419420
p = project.get_project(labml_token)
420-
load_keys(list(p.runs.values()))
421+
res = p.get_runs(folder_name)
421422

422423
return res
423424

app/server/labml_app/handlers.py

+43-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .logger import logger
1111
from . import settings
1212
from . import auth
13-
from .db import run
13+
from .db import run, folder
1414
from .db import computer
1515
from .db import session
1616
from .db import user
@@ -214,12 +214,11 @@ async def claim_run(request: Request, run_uuid: str, token: Optional[str] = None
214214

215215
default_project = u.default_project
216216

217-
if r.run_uuid not in default_project.runs:
217+
if not default_project.is_project_run(run_uuid):
218218
# float_project = project.get_project(labml_token=settings.FLOAT_PROJECT_TOKEN)
219219

220220
# if r.run_uuid in float_project.runs:
221-
default_project.runs[r.run_uuid] = r.key
222-
default_project.is_run_added = True
221+
default_project.add_run_with_model(r)
223222
default_project.save()
224223
r.is_claimed = True
225224
r.owner = u.email
@@ -362,15 +361,16 @@ async def get_session_status(request: Request, session_uuid: str) -> JSONRespons
362361

363362
@auth.login_required
364363
@auth.check_labml_token_permission
365-
async def get_runs(request: Request, labml_token: str, token: Optional[str] = None) -> EndPointRes:
364+
async def get_runs(request: Request, labml_token: str, token: Optional[str] = None,
365+
folder_name: str = folder.DefaultFolders.DEFAULT.value) -> EndPointRes:
366366
u = user.get_by_session_token(token)
367367

368368
if labml_token:
369-
runs_list = run.get_runs(labml_token)
369+
runs_list = run.get_runs(labml_token, folder_name)
370370
else:
371371
default_project = u.default_project
372372
labml_token = default_project.labml_token
373-
runs_list = default_project.get_runs()
373+
runs_list = default_project.get_runs(folder_name)
374374

375375
# run_uuids = [r.run_uuid for r in runs_list if r.world_size == 0]
376376
#
@@ -494,6 +494,39 @@ async def add_run(request: Request, run_uuid: str, token: Optional[str] = None)
494494
return {'is_successful': True}
495495

496496

497+
@auth.login_required
498+
async def archive_runs(request: Request, token: Optional[str] = None) -> EndPointRes:
499+
json = await request.json()
500+
run_uuids = json['run_uuids']
501+
502+
u = user.get_by_session_token(token)
503+
504+
try:
505+
u.default_project.archive_runs(run_uuids)
506+
except KeyError:
507+
return {'is_successful': False, 'error': "Failed to archive. Probably due to inconsistencies with the server."
508+
"Please refresh the page and try again."}
509+
510+
return {'is_successful': True}
511+
512+
513+
@auth.login_required
514+
async def un_archive_runs(request: Request, token: Optional[str] = None) -> EndPointRes:
515+
json = await request.json()
516+
run_uuids = json['run_uuids']
517+
518+
u = user.get_by_session_token(token)
519+
520+
try:
521+
u.default_project.un_archive_runs(run_uuids)
522+
except KeyError:
523+
return {'is_successful': False, 'error': "Failed to un-archive. Probably due to inconsistencies with the "
524+
"server."
525+
"Please refresh the page and try again."}
526+
527+
return {'is_successful': True}
528+
529+
497530
@auth.login_required
498531
async def add_session(request: Request, session_uuid: str, token: Optional[str] = None) -> EndPointRes:
499532
u = user.get_by_session_token(token)
@@ -553,6 +586,9 @@ def add_handlers(app: FastAPI):
553586
_add_server(app, 'POST', update_run, '{labml_token}/track')
554587
_add_server(app, 'POST', update_session, '{labml_token}/computer')
555588

589+
_add_ui(app, 'POST', archive_runs, 'runs/archive')
590+
_add_ui(app, 'POST', un_archive_runs, 'runs/unarchive')
591+
556592
_add_ui(app, 'GET', get_runs, 'runs/{labml_token}')
557593
_add_ui(app, 'PUT', delete_runs, 'runs')
558594
_add_ui(app, 'GET', get_sessions, 'sessions/{labml_token}')

app/ui/src/analyses/experiments/configs/view.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ class RunConfigsView extends ScreenView {
128128
}
129129

130130
await CACHE.getRun(this.uuid).updateRunData(data)
131-
await CACHE.getRunsList().localUpdateRun(this.run)
131+
await CACHE.getRunsList(this.run.folder).localUpdateRun(this.run)
132132
} catch (e) {
133133
this.userMessage.networkError(e, "Failed to save configurations")
134134
this.save.disabled = false

0 commit comments

Comments
 (0)