Skip to content

Commit 33c9801

Browse files
committed
fix: inference api with jwt token
1 parent b7892f9 commit 33c9801

3 files changed

Lines changed: 54 additions & 5 deletions

File tree

README.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,37 @@ Then interactively run training by :
7171

7272
(remember to set any env variables needed beforehand)
7373

74+
#### Test Inference
75+
76+
```
77+
runai workspace submit openpulse-inference \
78+
-i ghcr.io/sdsc-ordes/open-pulse-graph-classifier-inference:latest \
79+
--image-pull-policy Always \
80+
--gpu-devices-request 1 --preemptible
81+
```
82+
83+
Testing from inside the container:
84+
85+
1. Enter the container: `runai workspace bash openpulse-inference`
86+
2. Go to the application directory: `cd ../app/`
87+
3. Install curl for testing the API: `apt-get -y update; apt-get -y install curl`
88+
4. Try the test endpoint:
89+
90+
```bash
91+
curl -X GET "http://localhost:8000/v1/test"
92+
```
93+
94+
The output should be: `{"data":{"type":"test","id":"1","attributes":{"message":"Test endpoint is working!"}}}`
95+
96+
5. Generate a valid JWT token with `utils/api_token.py` (this requires the JWT secret key and related settings to be configured).
97+
6. Try the inference endpoint:
98+
99+
```bash
100+
curl -X GET "http://localhost:8000/v1/inference/epfl/YOUR_DB" \
101+
-H "accept: application/json" \
102+
-H "Authorization: Bearer YOUR_JWT_TOKEN"
103+
```
104+
74105
#### Inference
75106

76107
To do: this needs to be updated for the new runai CLI

src/inference_api.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ def verify_jwt(credentials: HTTPAuthorizationCredentials = Depends(security)):
2323
if payload.get("service") != "airflow":
2424
raise HTTPException(status_code=403, detail="Invalid service identity")
2525
return payload
26-
except JWTError:
26+
except JWTError as e:
27+
print("JWT decode error:", e)
2728
raise HTTPException(status_code=403, detail="Invalid token")
2829

2930

@@ -40,12 +41,13 @@ async def test():
4041

4142
@app.get("/v1/inference/epfl/{neo4j_database}")
4243
async def do_inference(neo4j_database: str, token_data: dict = Depends(verify_jwt)):
43-
output = inference(neo4j_database)
44-
# should the predictions be returned in the response or should it be saved to NEO4J or something?
44+
inference(neo4j_database)
4545
return {
4646
"data": {
4747
"type": "inference",
4848
"id": neo4j_database,
49-
"attributes": {"results": output},
49+
"attributes": {
50+
"results": "Predictions were uploaded to your Neo4J database under the node property 'predition'."
51+
},
5052
}
5153
}

src/utils/api_token.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
from dotenv import load_dotenv
33
import os
44
import requests
5+
from datetime import datetime, timedelta
6+
from fastapi.security import HTTPAuthorizationCredentials
7+
8+
from inference_api import verify_jwt
59

610
load_dotenv()
711
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY")
@@ -17,8 +21,20 @@
1721
NEO4J_DATABASE = os.getenv("NEO4J_DATABASE")
1822

1923
# FOR RUNAI INFERENCE API
20-
payload = {"sub": AIRFLOW_SUB, "service": AIRFLOW_SERVICE}
24+
print("All configurations for token in place:")
25+
print(JWT_SECRET_KEY, JWT_ALGORITHM)
26+
payload = {
27+
"sub": AIRFLOW_SUB,
28+
"service": AIRFLOW_SERVICE,
29+
"exp": datetime.utcnow() + timedelta(hours=1),
30+
"iat": datetime.utcnow(),
31+
}
2132
token = jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
33+
print("Validating token")
34+
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
35+
print(credentials)
36+
payload = verify_jwt(credentials)
37+
print(payload)
2238
print(f"Generated token: {token}")
2339

2440
# for AIRFLOW API

0 commit comments

Comments
 (0)