
Commit abf1b72

Merge pull request #221 from labmlai/stdout
Stdout
2 parents ec2a012 + 4a89b60 commit abf1b72

24 files changed: +796 -138 lines

app/server/labml_app/analyses/experiments/stderr.py

+72

@@ -0,0 +1,72 @@
from typing import Any

from starlette.responses import JSONResponse

import labml_app.db.run
from labml_db import Model, Index
from labml_db.serializer.pickle import PickleSerializer
from labml_db.serializer.yaml import YamlSerializer
from fastapi import Request

from labml_app.analyses.analysis import Analysis
from labml_app.analyses.logs import Logs


class StdErr(Logs):
    pass


@Analysis.db_model(PickleSerializer, 'stderr')
class StdErrModel(Model['StdErrModel'], StdErr):
    pass


@Analysis.db_index(YamlSerializer, 'stderr_index.yaml')
class StdErrIndex(Index['StdErr']):
    pass


@Analysis.route('POST', 'logs/stderr/{run_uuid}')
async def get_std_err(request: Request, run_uuid: str) -> Any:
    """
    body data: {
        page: int
    }

    page = -2 means get all logs.
    page = -1 means get the last page.
    page = n means get the nth page.
    """
    run_uuid = labml_app.db.run.get_main_rank(run_uuid)
    if run_uuid is None:
        return JSONResponse(status_code=404, content={'message': 'Run not found'})

    json = await request.json()
    page = json.get('page', -1)

    key = StdErrIndex.get(run_uuid)
    std_err: StdErrModel

    if key is None:
        std_err = StdErrModel()
        std_err.save()
        StdErrIndex.set(run_uuid, std_err.key)
    else:
        std_err = key.load()

    return std_err.get_data(page_no=page)


def update_stderr(run_uuid: str, content: str):
    """Append new stderr content to the run's log, creating the model if needed."""
    key = StdErrIndex.get(run_uuid)
    std_err: StdErrModel

    if key is None:
        std_err = StdErrModel()
        std_err.save()
        StdErrIndex.set(run_uuid, std_err.key)
    else:
        std_err = key.load()

    std_err.update_logs(content)
    std_err.save()
app/server/labml_app/analyses/experiments/stdlogger.py

+72

@@ -0,0 +1,72 @@
from typing import Any

from starlette.responses import JSONResponse

import labml_app.db.run
from labml_db import Model, Index
from labml_db.serializer.pickle import PickleSerializer
from labml_db.serializer.yaml import YamlSerializer
from fastapi import Request

from labml_app.analyses.analysis import Analysis
from labml_app.analyses.logs import Logs


class StdLogger(Logs):
    pass


@Analysis.db_model(PickleSerializer, 'std_logger')
class StdLoggerModel(Model['StdLoggerModel'], StdLogger):
    pass


@Analysis.db_index(YamlSerializer, 'std_logger_index.yaml')
class StdLoggerIndex(Index['StdLogger']):
    pass


@Analysis.route('POST', 'logs/std_logger/{run_uuid}')
async def get_std_logger(request: Request, run_uuid: str) -> Any:
    """
    body data: {
        page: int
    }

    page = -2 means get all logs.
    page = -1 means get the last page.
    page = n means get the nth page.
    """
    run_uuid = labml_app.db.run.get_main_rank(run_uuid)
    if run_uuid is None:
        return JSONResponse(status_code=404, content={'message': 'Run not found'})

    json = await request.json()
    page = json.get('page', -1)

    key = StdLoggerIndex.get(run_uuid)
    std_logger: StdLoggerModel

    if key is None:
        std_logger = StdLoggerModel()
        std_logger.save()
        StdLoggerIndex.set(run_uuid, std_logger.key)
    else:
        std_logger = key.load()

    return std_logger.get_data(page_no=page)


def update_std_logger(run_uuid: str, content: str):
    """Append new logger output to the run's log, creating the model if needed."""
    key = StdLoggerIndex.get(run_uuid)
    std_logger: StdLoggerModel

    if key is None:
        std_logger = StdLoggerModel()
        std_logger.save()
        StdLoggerIndex.set(run_uuid, std_logger.key)
    else:
        std_logger = key.load()

    std_logger.update_logs(content)
    std_logger.save()
app/server/labml_app/analyses/experiments/stdout.py

+73

@@ -0,0 +1,73 @@
from typing import Any

from labml_db import Model, Index
from labml_db.serializer.pickle import PickleSerializer
from labml_db.serializer.yaml import YamlSerializer
from fastapi import Request
from starlette.responses import JSONResponse

import labml_app.db.run
from labml_app.analyses.analysis import Analysis
from labml_app.analyses.logs import Logs


class StdOut(Logs):
    pass


@Analysis.db_model(PickleSerializer, 'stdout')
class StdOutModel(Model['StdOutModel'], StdOut):
    pass


@Analysis.db_index(YamlSerializer, 'stdout_index.yaml')
class StdOutIndex(Index['StdOut']):
    pass


@Analysis.route('POST', 'logs/stdout/{run_uuid}')
async def get_stdout(request: Request, run_uuid: str) -> Any:
    """
    body data: {
        page: int
    }

    page = -2 means get all logs.
    page = -1 means get the last page.
    page = n means get the nth page.
    """
    # get the run (resolve to the main rank's UUID)
    run_uuid = labml_app.db.run.get_main_rank(run_uuid)
    if run_uuid is None:
        return JSONResponse(status_code=404, content={'message': 'Run not found'})

    json = await request.json()
    page = json.get('page', -1)

    key = StdOutIndex.get(run_uuid)
    std_out: StdOutModel

    if key is None:
        std_out = StdOutModel()
        std_out.save()
        StdOutIndex.set(run_uuid, std_out.key)
    else:
        std_out = key.load()

    return std_out.get_data(page_no=page)


def update_stdout(run_uuid: str, content: str):
    """Append new stdout content to the run's log, creating the model if needed."""
    key = StdOutIndex.get(run_uuid)
    std_out: StdOutModel

    if key is None:
        std_out = StdOutModel()
        std_out.save()
        StdOutIndex.set(run_uuid, std_out.key)
    else:
        std_out = key.load()

    std_out.update_logs(content)
    std_out.save()
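
All three endpoints share the same request contract. Below is a minimal client sketch; the base URL, prefix, and run UUID are placeholder assumptions for illustration, while the route suffix and the response shape come from this diff (see Logs.get_data in logs.py below).

import requests

# Assumed host/prefix: adjust to wherever the labml app server runs and
# whatever prefix Analysis.route mounts these handlers under.
BASE = 'http://localhost:5005/api/v1'
run_uuid = 'your-run-uuid'  # placeholder

# page = -1: last page, page = -2: all pages, page = n: the nth page
resp = requests.post(f'{BASE}/logs/stdout/{run_uuid}', json={'page': -1})
data = resp.json()

print('total pages:', data['page_length'])
for page_no, text in data['pages'].items():
    print(f'--- page {page_no} ---')
    print(text)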

app/server/labml_app/analyses/logs.py

+109

@@ -0,0 +1,109 @@
from typing import Any, Dict, List, Tuple

from labml_db import Key, Model
from labml_db.serializer.pickle import PickleSerializer

from labml_app.analyses.analysis import Analysis
from labml_app.settings import LOG_CHAR_LIMIT


class LogPage:
    logs: str
    logs_unmerged: str

    @classmethod
    def defaults(cls):
        return dict(logs='', logs_unmerged='')

    def update_logs(self, new_logs: str):
        unmerged = self.logs_unmerged + new_logs
        processed = ''
        if len(new_logs) > 1:
            processed, unmerged = self._format_output(unmerged)

        self.logs_unmerged = unmerged
        self.logs += processed

    @staticmethod
    def _format_output(output: str) -> Tuple[str, str]:
        # Merge carriage returns: a bare CR rewinds the current line (progress bars),
        # CRLF is treated as a newline. Returns (complete lines, trailing partial line).
        res = []
        temp = ''
        for i, c in enumerate(output):
            if c == '\n':
                temp += '\n'
                res.append(temp)
                temp = ''
            elif c == '\r' and len(output) > i + 1 and output[i + 1] == '\n':
                pass  # part of a CRLF; the '\n' that follows ends the line
            elif c == '\r':
                temp = ''  # bare CR: discard the current line and start over
            else:
                temp += c

        return ''.join(res), temp

    def get_data(self) -> Dict[str, Any]:
        return {
            'logs': self.logs + self.logs_unmerged,
        }

    def is_full(self):
        return len(self.logs) > LOG_CHAR_LIMIT


@Analysis.db_model(PickleSerializer, 'log_page')
class LogPageModel(Model['LogPageModel'], LogPage):
    pass


class Logs:
    log_pages: List[Key['LogPageModel']]

    @classmethod
    def defaults(cls):
        return dict(
            log_pages=[]
        )

    def get_data(self, page_no: int = -1):
        page_dict: Dict[str, str] = {}

        if page_no == -2:
            pages: List['LogPage'] = [page.load() for page in self.log_pages]
            for i, p in enumerate(pages):
                page_dict[str(i)] = p.logs + p.logs_unmerged
        elif len(self.log_pages) > page_no >= 0:
            page = self.log_pages[page_no].load()
            page_dict[str(page_no)] = page.logs + page.logs_unmerged

        if len(self.log_pages) > 0:  # always include the last page
            page = self.log_pages[-1].load()
            page_dict[str(len(self.log_pages) - 1)] = page.logs + page.logs_unmerged

        return {
            'pages': page_dict,
            'page_length': len(self.log_pages)
        }

    def update_logs(self, content: str):
        if len(self.log_pages) == 0:
            page = LogPageModel()
            page.save()
            self.log_pages.append(page.key)
        else:
            page = self.log_pages[-1].load()

        if page.is_full():
            # Start a new page; carry over the unmerged tail of the full page
            unmerged_logs = page.logs_unmerged
            page.logs_unmerged = ''
            page.save()
            content = unmerged_logs + content

            page = LogPageModel()
            page.update_logs(content)
            page.save()
            self.log_pages.append(page.key)
        else:
            page.update_logs(content)
            page.save()
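
The carriage-return handling in LogPage._format_output is what keeps progress-bar style output compact: a bare '\r' throws away the partially built line, so only its final state is merged into the stored logs, while the trailing incomplete line stays in logs_unmerged until more output arrives. A standalone sketch of that logic (re-implemented here for illustration; the sample strings are made up):

def format_output(output: str):
    """Mirror of LogPage._format_output: returns (merged complete lines, trailing partial line)."""
    res = []
    temp = ''
    for i, c in enumerate(output):
        if c == '\n':
            temp += '\n'
            res.append(temp)
            temp = ''
        elif c == '\r' and len(output) > i + 1 and output[i + 1] == '\n':
            pass  # part of a CRLF; the '\n' that follows ends the line
        elif c == '\r':
            temp = ''  # bare CR: restart the current line
        else:
            temp += c
    return ''.join(res), temp


merged, unmerged = format_output('epoch 1:  10%\repoch 1:  60%\repoch 1: 100%\nepoch 2:  10%')
print(repr(merged))    # 'epoch 1: 100%\n'  -- only the final state of the rewritten line survives
print(repr(unmerged))  # 'epoch 2:  10%'    -- incomplete line, kept until more output arrives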

app/server/labml_app/analyses_settings.sample.py

+4
@@ -10,6 +10,10 @@
 from .analyses.computers.disk import DiskAnalysis
 from .analyses.computers.process import ProcessAnalysis

+from .analyses.experiments.stdout import StdOutModel
+from .analyses.experiments.stderr import StdErrModel
+from .analyses.experiments.stdlogger import StdLoggerModel
+
 experiment_analyses = [MetricsAnalysis]

 computer_analyses = [CPUAnalysis,
