|
1 | 1 | import logging |
2 | 2 | import os |
3 | | -import traceback |
| 3 | +from contextlib import asynccontextmanager |
4 | 4 | from datetime import datetime |
| 5 | +from typing import List |
5 | 6 |
|
6 | | -import psycopg |
| 7 | +import asyncpg |
7 | 8 | from dotenv import load_dotenv |
8 | | -from flask import Flask, request |
9 | | -from psycopg.rows import dict_row |
| 9 | +from fastapi import Depends, FastAPI, HTTPException, Query |
10 | 10 |
|
11 | 11 | load_dotenv() |
12 | 12 |
|
| 13 | +DATABASE_URL = "postgres://student:[email protected]:5432/chartsdb" |
13 | 14 |
|
14 | | -app = Flask(__name__) |
15 | | -app.config['DATABASE_URL'] = os.getenv('DATABASE_URL') |
| 15 | +# Менеджер контекста для подключения к БД |
| 16 | +@asynccontextmanager |
| 17 | +async def lifespan(app: FastAPI): |
| 18 | + # Создаем пул соединений при старте |
| 19 | + app.state.pool = await asyncpg.create_pool( |
| 20 | + DATABASE_URL, |
| 21 | + min_size=5, |
| 22 | + max_size=20 |
| 23 | + ) |
| 24 | + yield |
| 25 | + # Закрываем пул при выключении |
| 26 | + await app.state.pool.close() |
16 | 27 |
|
| 28 | +app = FastAPI(lifespan=lifespan) |
17 | 29 |
|
18 | | -def get_db(): |
19 | | - return psycopg.connect(app.config['DATABASE_URL']) |
| 30 | +# Зависимость для получения соединения с БД |
| 31 | +async def get_db(): |
| 32 | + async with app.state.pool.acquire() as conn: |
| 33 | + yield conn |
20 | 34 |
|
| 35 | +@app.get("/") |
| 36 | +async def index(): |
| 37 | + return {"status": "It Works"} |
21 | 38 |
|
22 | | -@app.route('/') |
23 | | -def index(): |
24 | | - return 'It Works' |
25 | | - |
26 | | - |
27 | | -@app.get('/visits') |
28 | | -def get_visits(): |
29 | | - query = request.args |
| 39 | +@app.get("/visits") |
| 40 | +async def get_visits( |
| 41 | + begin: datetime = Query(..., description="Start date in ISO format"), |
| 42 | + end: datetime = Query(..., description="End date in ISO format"), |
| 43 | + db = Depends(get_db) |
| 44 | +): |
30 | 45 | try: |
31 | | - begin_date = datetime.fromisoformat(query['begin']) |
32 | | - end_date = datetime.fromisoformat(query['end']) |
33 | | - except ValueError as e: |
34 | | - logging.error(traceback.format_exc(2, chain=False)) |
35 | | - return 'Invalid date. Please check your input', 400 |
36 | | - query = ''' |
37 | | - SELECT * |
38 | | - FROM visits |
39 | | - WHERE visits.datetime BETWEEN (%s) AND (%s);''' |
40 | | - conn = get_db() |
41 | | - with conn.cursor(row_factory=dict_row) as c: |
42 | | - c.execute(query, [begin_date, end_date]) |
43 | | - res = c.fetchall() |
44 | | - return res |
45 | | - |
| 46 | + query = """ |
| 47 | + SELECT * |
| 48 | + FROM visits |
| 49 | + WHERE visits.datetime BETWEEN $1 AND $2; |
| 50 | + """ |
| 51 | + records = await db.fetch(query, begin, end) |
| 52 | + return records |
| 53 | + except Exception as e: |
| 54 | + logging.error(f"Error fetching visits: {str(e)}") |
| 55 | + raise HTTPException( |
| 56 | + status_code=500, |
| 57 | + detail="Internal server error occurred while fetching visits" |
| 58 | + ) |
46 | 59 |
|
47 | | -@app.get('/registrations') |
48 | | -def get_registrations(): |
49 | | - query = request.args |
| 60 | +@app.get("/registrations") |
| 61 | +async def get_registrations( |
| 62 | + begin: datetime = Query(..., description="Start date in ISO format"), |
| 63 | + end: datetime = Query(..., description="End date in ISO format"), |
| 64 | + db = Depends(get_db) |
| 65 | +): |
50 | 66 | try: |
51 | | - begin_date = datetime.fromisoformat(query['begin']) |
52 | | - end_date = datetime.fromisoformat(query['end']) |
53 | | - except ValueError as e: |
54 | | - logging.error(traceback.format_exc(2, chain=False)) |
55 | | - return 'Invalid date. Please check your input', 400 |
56 | | - query = ''' |
57 | | - SELECT * |
58 | | - FROM registrations |
59 | | - WHERE registrations.datetime BETWEEN (%s) AND (%s);''' |
60 | | - conn = get_db() |
61 | | - with conn.cursor(row_factory=dict_row) as c: |
62 | | - c.execute(query, [begin_date, end_date]) |
63 | | - res = c.fetchall() |
64 | | - return res |
| 67 | + query = """ |
| 68 | + SELECT * |
| 69 | + FROM registrations |
| 70 | + WHERE registrations.datetime BETWEEN $1 AND $2; |
| 71 | + """ |
| 72 | + records = await db.fetch(query, begin, end) |
| 73 | + return records |
| 74 | + except Exception as e: |
| 75 | + logging.error(f"Error fetching registrations: {str(e)}") |
| 76 | + raise HTTPException( |
| 77 | + status_code=500, |
| 78 | + detail="Internal server error occurred while fetching registrations" |
| 79 | + ) |
0 commit comments