Skip to content

Commit e918d4e

Browse files
committed
✨ feat(app): Add serving part
Finish the API design, make the inference work, publish the first Dockerizable version
1 parent c86bc36 commit e918d4e

File tree

12 files changed

+683
-3
lines changed

12 files changed

+683
-3
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,7 @@ workplace/
140140

141141
# logs
142142
logs/
143+
144+
# models directory
145+
models/*
146+
!models/.gitkeep

configs/app.yaml

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# @package _global_
2+
3+
# specify here default configuration
4+
# order of defaults determines the order in which configs override each other
5+
defaults:
6+
- _self_
7+
- model: deeplabv3p.yaml
8+
- app: default.yaml
9+
- paths: default.yaml
10+
- hydra: default.yaml
11+
12+
# experiment configs allow for version control of specific hyperparameters
13+
# e.g. best hyperparameters for given model and datamodule
14+
- experiment: null
15+
16+
# optional local config for machine/user specific settings
17+
# it's optional since it doesn't need to exist and is excluded from version control
18+
- optional local: default.yaml
19+
20+
# debugging config (enable through command line, e.g. `python train.py debug=default`)
21+
- debug: null
22+
23+
# task name, determines output directory path
24+
task_name: "app"
25+
26+
# the checkpoint path to do inference
27+
ckpt_path: null
28+
29+
# the device to do inference
30+
device: cpu

configs/app/default.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# the host address of the application
2+
host: 127.0.0.1
3+
4+
# the port to run the application on
5+
port: 8080
6+
7+
# the step length of the sliding window in number of samples
8+
hop_length_in_npts: 2400

configs/experiment/app_serve.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# @package _global_
2+
3+
ckpt_path: "${paths.root_dir}/models/model.ckpt"

models/.gitkeep

Whitespace-only changes.

notebook/test_inference_api.ipynb

Lines changed: 218 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ hydra-submitit-launcher = "^1.2.0"
1616
lightning = "~2.0"
1717
matplotlib = "^3.7.1"
1818
numpy = "^1.24.2"
19+
obspy = "^1.4.0"
1920
pandas = "^2.0.0"
2021
pyrootutils = "^1.0.4"
2122
python = ">=3.9,<3.12"
@@ -25,13 +26,18 @@ segmentation-models-pytorch = "^0.3.2"
2526
torch = "^2.0.0"
2627
torchaudio = "^2.0.1"
2728
wandb = "^0.14.2"
28-
obspy = "^1.4.0"
2929

3030
[tool.poetry.group.dev.dependencies]
3131
autopep8 = "^2.0.2"
3232
ipykernel = "^6.22.0"
33-
pylint = "^2.17.2"
3433
joblib = "^1.2.0"
34+
pylint = "^2.17.2"
35+
36+
[tool.poetry.group.api.dependencies]
37+
fastapi = "0.86"
38+
pydantic = "^1.10.7"
39+
uvicorn = "^0.21.1"
40+
httpx = "^0.24.0"
3541

3642
[build-system]
3743
build-backend = "poetry.core.masonry.api"

src/app.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
"""
2+
app.py: serve the model as a REST API
3+
"""
4+
from pathlib import Path
5+
from typing import Dict, List, Optional
6+
7+
import hydra
8+
import pandas as pd
9+
import pyrootutils
10+
import torch
11+
import uvicorn
12+
from fastapi import FastAPI
13+
from lightning import LightningModule
14+
from omegaconf import DictConfig
15+
from pydantic import BaseModel
16+
17+
from src import utils
18+
from src.inference.single_inference import single_inference
19+
20+
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
21+
# ------------------------------------------------------------------------------------ #
22+
# the setup_root above is equivalent to:
23+
# - adding project root dir to PYTHONPATH
24+
# (so you don't need to force user to install project as a package)
25+
# (necessary before importing any local modules e.g. `from src import utils`)
26+
# - setting up PROJECT_ROOT environment variable
27+
# (which is used as a base for paths in "configs/paths/default.yaml")
28+
# (this way all filepaths are the same no matter where you run the code)
29+
# - loading environment variables from ".env" in root dir
30+
#
31+
# you can remove it if you:
32+
# 1. either install project as a package or move entry files to project root dir
33+
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
34+
#
35+
# more info: https://github.com/ashleve/pyrootutils
36+
# ------------------------------------------------------------------------------------ #
37+
38+
39+
log = utils.get_pylogger(__name__)
40+
app = FastAPI()
41+
42+
# global variables
43+
model: Optional[LightningModule] = None
44+
dt_s: float = 0.0
45+
model_phases: List[str] = []
46+
window_length_in_npts: int = 0
47+
hop_length_in_npts: int = 0
48+
sensitive_distances_in_seconds: float = 0.0
49+
50+
51+
@hydra.main(version_base="1.3", config_path="../configs", config_name="app.yaml")
def main(cfg: DictConfig) -> None:
    """Load the checkpoint named in the config and serve the model over HTTP.

    Populates the module-level globals read by the /predict endpoint, then
    blocks inside ``uvicorn.run`` until the server is stopped.

    Args:
        cfg (DictConfig): Hydra config composed from ``configs/app.yaml``.

    Raises:
        FileNotFoundError: if ``cfg.ckpt_path`` is missing, null, or does not
            point to an existing file.
    """
    global model, dt_s, model_phases, window_length_in_npts, hop_length_in_npts, sensitive_distances_in_seconds
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    utils.extras(cfg)

    # ckpt_path defaults to null in configs/app.yaml, in which case
    # cfg.get("ckpt_path", "") returns None (the key exists) and Path(None)
    # would raise a confusing TypeError — so reject falsy values explicitly.
    ckpt_path = cfg.get("ckpt_path")
    if not ckpt_path or not Path(ckpt_path).is_file():
        raise FileNotFoundError(f"Checkpoint file not found: {ckpt_path}")

    # load model from checkpoint
    log.info(f"Instantiating model <{cfg.model._target_}>")
    Model = hydra.utils.get_class(cfg.model._target_)

    # save_hyperparameters excludes net and sgram_generator in core_module.py, so we need to pass them in
    model = Model.load_from_checkpoint(
        cfg.ckpt_path, map_location=torch.device(cfg.get("device", "cpu")))

    # load other parameters used by the sliding-window inference in /predict
    dt_s = cfg.model.get("dt_s", 0.025)
    model_phases = cfg.model.get("phases", ["P", "S", "PS"])
    window_length_in_npts = cfg.model.get("window_length_in_npts", 4800)
    hop_length_in_npts = cfg.app.get("hop_length_in_npts", 2400)
    sensitive_distances_in_seconds = cfg.model.get(
        "extract_peaks_sensitive_distances_in_seconds", 5.0)

    # blocking call: serve the FastAPI app defined at module level
    uvicorn.run(app, host=cfg.app.host, port=cfg.app.port)
78+
79+
80+
class PredictionRequest(BaseModel):
    """
    Request body for prediction endpoint.

    All top-level lists are parallel: entry i of each field describes trace i.

    Args:
        id (List[str]): List of IDs, one per trace.
        timestamp (List[str]): List of trace start timestamps, one per trace,
            in a format pandas.Timestamp can parse (e.g. "2020-01-01 00:00:00.000").
        vec (List[List[List[float]]]): List of 3 X NPTS points (one
            3-component waveform per trace).
        extract_phases (List[str]): Phase names to extract picks for; forwarded
            unchanged to single_inference.
        extract_phase_sensitivity (List[float]): One value per entry of
            extract_phases; presumably a pick/possibility threshold — confirm
            against single_inference.
    """
    id: List[str]
    timestamp: List[str]
    vec: List[List[List[float]]]
    extract_phases: List[str]
    extract_phase_sensitivity: List[float]
94+
95+
96+
class PredictResponse(BaseModel):
    """
    Response body for prediction endpoint.

    All top-level lists are parallel to the request's `id` list (entry i of
    each field describes trace i).

    Args:
        id (List[str]): List of IDs, echoed from the request.
        possibility (List[Dict[str, List[float]]]): List of possibility, the keys of the dict are phases, and the values are the possibility of the phases.
        arrivals (List[Dict[str, List[float]]]): List of arrivals, the keys of the dict are phases, and the values are the arrival index of the phases.
        amps (List[Dict[str, List[float]]]): List of amps, the keys of the dict are phases, and the values are the amps of the phases possibility.
        arrival_times (List[Dict[str, List[str]]]): List of arrival times, the
            keys of the dict are phases, and the values are the arrival times
            formatted as "%Y-%m-%d %H:%M:%S.%f" strings (not floats — see the
            strftime conversion in the /predict handler).
    """
    id: List[str]
    possibility: List[Dict[str, List[float]]]
    arrivals: List[Dict[str, List[float]]]
    amps: List[Dict[str, List[float]]]
    arrival_times: List[Dict[str, List[str]]]
112+
113+
114+
@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictionRequest) -> PredictResponse:
    """
    Prediction endpoint: run sliding-window inference on each trace.

    Args:
        request (PredictionRequest): Request body; one entry per trace in
            id/timestamp/vec.

    Returns:
        PredictResponse: per-trace possibility curves, arrival sample indices,
        amplitudes, and formatted arrival-time strings. (The previous return
        annotation List[Dict[str, List[float]]] contradicted both
        response_model and the actual returned object.)
    """
    predictions = []
    for itrace in range(len(request.vec)):
        # run inference on a single 3 x NPTS trace
        pred = single_inference(
            model=model,
            data=torch.tensor(request.vec[itrace], dtype=torch.float32),
            extract_phases=request.extract_phases,
            extract_phase_sensitivity=request.extract_phase_sensitivity,
            model_phases=model_phases,
            window_length_in_npts=window_length_in_npts,
            hop_length_in_npts=hop_length_in_npts,
            dt_s=dt_s,
            sensitive_distances_in_seconds=sensitive_distances_in_seconds
        )

        # pred has keys "possibility", "arrivals", "amps"; derive
        # "arrival_times" by offsetting the trace start timestamp by each
        # arrival's sample index converted to seconds (index * dt_s).
        # example request.timestamp[itrace]: 2020-01-01 00:00:00.000
        start_time = pd.Timestamp(request.timestamp[itrace])
        pred["arrival_times"] = {
            phase: [
                (start_time + pd.Timedelta(seconds=arrival * dt_s)
                 ).strftime("%Y-%m-%d %H:%M:%S.%f")
                for arrival in arrivals
            ]
            for phase, arrivals in pred["arrivals"].items()
        }
        predictions.append(pred)

    # repackage the per-trace dicts into the column-oriented response model
    return PredictResponse(
        id=request.id,
        possibility=[p["possibility"] for p in predictions],
        arrivals=[p["arrivals"] for p in predictions],
        amps=[p["amps"] for p in predictions],
        arrival_times=[p["arrival_times"] for p in predictions],
    )
165+
166+
167+
if __name__ == "__main__":
    # hydra.main parses CLI overrides and calls main with the composed config
    main()

src/inference/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)