|
| 1 | +""" |
| 2 | +app.py: serve the model as a REST API |
| 3 | +""" |
from pathlib import Path
from typing import Dict, List, Optional

import hydra
import pandas as pd
import pyrootutils
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from lightning import LightningModule
from omegaconf import DictConfig
from pydantic import BaseModel

from src import utils
from src.inference.single_inference import single_inference
| 19 | + |
# NOTE(review): this call runs *after* the `from src import ...` statements above,
# yet the banner below says it is necessary before importing local modules. The
# imports only resolve if the project is installed as a package or the root is
# already on PYTHONPATH — consider moving this call above the local imports.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
#       (so you don't need to force user to install project as a package)
#       (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
#       (which is used as a base for paths in "configs/paths/default.yaml")
#       (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/pyrootutils
# ------------------------------------------------------------------------------------ #
| 37 | + |
| 38 | + |
log = utils.get_pylogger(__name__)
app = FastAPI()

# Global inference state, populated once by `main()` before uvicorn starts
# serving; read-only afterwards inside the `/predict` handler.
model: Optional[LightningModule] = None  # model loaded from cfg.ckpt_path
dt_s: float = 0.0  # sampling interval of input traces, in seconds
model_phases: List[str] = []  # phase labels the model was trained on (e.g. P, S, PS)
window_length_in_npts: int = 0  # inference window length, in sample points
hop_length_in_npts: int = 0  # hop between consecutive windows, in sample points
sensitive_distances_in_seconds: float = 0.0  # peak-extraction sensitive distance, in seconds
| 49 | + |
| 50 | + |
@hydra.main(version_base="1.3", config_path="../configs", config_name="app.yaml")
def main(cfg: DictConfig) -> None:
    """
    Load the model checkpoint and serve the REST API.

    Populates the module-level inference state (model, sampling interval,
    phase names, windowing parameters) from the hydra config, then blocks
    in ``uvicorn.run`` serving ``app``.

    Args:
        cfg (DictConfig): hydra config; must provide ``ckpt_path`` and the
            ``model`` and ``app`` sections.

    Raises:
        FileNotFoundError: if ``ckpt_path`` is missing from the config or
            does not point to an existing file.
    """
    global model, dt_s, model_phases, window_length_in_npts, hop_length_in_npts, sensitive_distances_in_seconds
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    utils.extras(cfg)

    # read the path once via .get() so a missing key produces the clear
    # FileNotFoundError below instead of a config attribute error
    ckpt_path = cfg.get("ckpt_path", "")
    if not ckpt_path or not Path(ckpt_path).is_file():
        raise FileNotFoundError(f"Checkpoint file not found: {ckpt_path}")

    # load model from checkpoint
    log.info(f"Instantiating model <{cfg.model._target_}>")
    Model = hydra.utils.get_class(cfg.model._target_)
    model = Model.load_from_checkpoint(
        ckpt_path, map_location=torch.device(cfg.get("device", "cpu")))
    # the server only runs inference: switch off dropout / batch-norm updates
    model.eval()

    # load windowing / peak-extraction parameters (defaults mirror the model config)
    dt_s = cfg.model.get("dt_s", 0.025)
    model_phases = cfg.model.get("phases", ["P", "S", "PS"])
    window_length_in_npts = cfg.model.get("window_length_in_npts", 4800)
    hop_length_in_npts = cfg.app.get("hop_length_in_npts", 2400)
    sensitive_distances_in_seconds = cfg.model.get(
        "extract_peaks_sensitive_distances_in_seconds", 5.0)

    uvicorn.run(app, host=cfg.app.host, port=cfg.app.port)
| 78 | + |
| 79 | + |
class PredictionRequest(BaseModel):
    """
    Request body for prediction endpoint.

    Args:
        id (List[str]): List of IDs, one per trace.
        timestamp (List[str]): List of trace start timestamps, one per trace
            (e.g. "2020-01-01 00:00:00.000"; anything pandas.Timestamp parses).
        vec (List[List[List[float]]]): List of traces, each 3 X NPTS points.
        extract_phases (List[str]): Phase names to extract arrivals for.
        extract_phase_sensitivity (List[float]): Sensitivity value per entry of
            extract_phases (forwarded to single_inference; presumably a
            detection threshold — verify against its implementation).
    """
    id: List[str]
    timestamp: List[str]
    vec: List[List[List[float]]]
    extract_phases: List[str]
    extract_phase_sensitivity: List[float]
| 94 | + |
| 95 | + |
class PredictResponse(BaseModel):
    """
    Response body for prediction endpoint.

    Args:
        id (List[str]): List of IDs, echoed from the request.
        possibility (List[Dict[str, List[float]]]): List of possibility, the keys of the dict are phases, and the values are the possibility of the phases.
        arrivals (List[Dict[str, List[float]]]): List of arrivals, the keys of the dict are phases, and the values are the arrival index of the phases.
        amps (List[Dict[str, List[float]]]): List of amps, the keys of the dict are phases, and the values are the amps of the phases possibility.
        arrival_times (List[Dict[str, List[str]]]): List of arrival times, the keys of the dict are phases, and the values are the arrival times of the phases formatted as "%Y-%m-%d %H:%M:%S.%f" strings.
    """
    id: List[str]
    possibility: List[Dict[str, List[float]]]
    arrivals: List[Dict[str, List[float]]]
    amps: List[Dict[str, List[float]]]
    arrival_times: List[Dict[str, List[str]]]
| 112 | + |
| 113 | + |
@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictionRequest) -> PredictResponse:
    """
    Prediction endpoint.

    Runs the model on each waveform in the request and converts the predicted
    arrival indexes into absolute timestamps relative to each trace's start.

    Args:
        request (PredictionRequest): Request body.

    Returns:
        PredictResponse: per-trace possibilities, arrival indexes, amps, and
        formatted arrival times.

    Raises:
        HTTPException: 503 if the model has not been loaded, 422 if the
            request lists differ in length.
    """
    # main() loads the model before uvicorn starts; guard anyway so a
    # mis-configured deployment fails clearly instead of with a bare 500
    if model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")
    if len(request.timestamp) != len(request.vec):
        raise HTTPException(
            status_code=422,
            detail="timestamp and vec must have the same length")

    predictions = []
    for trace, timestamp in zip(request.vec, request.timestamp):
        # run inference on a single 3 X NPTS trace; pred has keys
        # "possibility", "arrivals", and "amps"
        pred = single_inference(
            model=model,
            data=torch.tensor(trace, dtype=torch.float32),
            extract_phases=request.extract_phases,
            extract_phase_sensitivity=request.extract_phase_sensitivity,
            model_phases=model_phases,
            window_length_in_npts=window_length_in_npts,
            hop_length_in_npts=hop_length_in_npts,
            dt_s=dt_s,
            sensitive_distances_in_seconds=sensitive_distances_in_seconds
        )

        # convert each arrival index to an absolute time string:
        # trace start time + index * sampling interval
        # example timestamp: 2020-01-01 00:00:00.000
        start_time = pd.Timestamp(timestamp)
        pred["arrival_times"] = {
            phase: [
                (start_time + pd.Timedelta(seconds=arrival * dt_s)).strftime(
                    "%Y-%m-%d %H:%M:%S.%f")
                for arrival in arrivals
            ]
            for phase, arrivals in pred["arrivals"].items()
        }
        predictions.append(pred)

    return PredictResponse(
        id=request.id,
        possibility=[pred["possibility"] for pred in predictions],
        arrivals=[pred["arrivals"] for pred in predictions],
        amps=[pred["amps"] for pred in predictions],
        arrival_times=[pred["arrival_times"] for pred in predictions]
    )
| 165 | + |
| 166 | + |
if __name__ == "__main__":
    # hydra entry point: parses CLI overrides into cfg, then blocks serving the app
    main()
0 commit comments