forked from rust-lang/triagebot
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdb.rs
More file actions
373 lines (336 loc) · 10.9 KB
/
db.rs
File metadata and controls
373 lines (336 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
use crate::{db::jobs::*, handlers::Context, jobs::jobs};
use anyhow::Context as _;
use chrono::Utc;
use native_tls::{Certificate, TlsConnector};
use postgres_native_tls::MakeTlsConnector;
use std::sync::{Arc, LazyLock, Mutex};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio_postgres::Client as DbClient;
pub mod issue_data;
pub mod jobs;
pub mod notifications;
pub mod review_prefs;
pub mod rustc_commits;
pub mod users;
/// Location of the Amazon RDS global CA bundle, used as trust roots for TLS connections to RDS.
const CERT_URL: &str = "https://truststore.pki.rds.amazonaws.com/global/global-bundle.pem";

/// Raw PEM bytes of the bundle at [`CERT_URL`], downloaded once on first access.
///
/// # Panics
/// Forcing this value panics if the HTTP request or reading the response body fails.
///
/// NOTE(review): this uses `reqwest::blocking`, and first access can happen from the async
/// `make_client` (via `make_certificates`). reqwest documents that its blocking client must not be
/// used inside an async runtime — confirm this is safe here (e.g. force it in `spawn_blocking`).
static CERTIFICATE_PEMS: LazyLock<Vec<u8>> = LazyLock::new(|| {
    let http = reqwest::blocking::Client::new();
    let response = http.get(CERT_URL).send().expect("failed to get RDS cert");
    let body = response.bytes().expect("failed to get RDS cert body");
    body.to_vec()
});
/// A pool of PostgreSQL connections to a single database URL.
///
/// The number of [`PooledClient`]s alive at the same time is bounded by `permits`; idle
/// connections are kept in `connections` and reused by `ClientPool::get`.
pub struct ClientPool {
    /// Idle connections. Shared with every `PooledClient`, which pushes its connection back here
    /// when it is dropped.
    connections: Arc<Mutex<Vec<tokio_postgres::Client>>>,
    /// One permit per connection that may be checked out at the same time.
    permits: Arc<Semaphore>,
    /// Connection string passed to `make_client` when no reusable idle connection exists.
    db_url: String,
}
/// A connection checked out of a [`ClientPool`].
///
/// Dereferences to the underlying `tokio_postgres::Client`. On drop, the client is pushed back onto
/// the pool's idle list, and then (as fields drop after `Drop::drop`) the permit is released.
pub struct PooledClient {
    /// `Some` from construction in `ClientPool::get` until `Drop::drop` takes it.
    client: Option<tokio_postgres::Client>,
    #[allow(unused)] // never read: holding it occupies a pool slot until this value is dropped
    permit: OwnedSemaphorePermit,
    /// The owning pool's idle-connection list (same `Arc` as `ClientPool::connections`).
    pool: Arc<Mutex<Vec<tokio_postgres::Client>>>,
}
impl Drop for PooledClient {
    /// Returns the wrapped connection to the pool's idle list.
    ///
    /// A poisoned mutex is recovered with `into_inner` instead of panicking. The semaphore permit
    /// is a field, so it is only released after this body has run, i.e. once the client is back
    /// in the list.
    fn drop(&mut self) {
        let mut idle = match self.pool.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let client = self.client.take().unwrap();
        idle.push(client);
    }
}
impl std::ops::Deref for PooledClient {
    type Target = tokio_postgres::Client;

    /// Borrows the pooled connection.
    ///
    /// # Panics
    /// Panics if the client has already been taken, which only `Drop::drop` does.
    fn deref(&self) -> &tokio_postgres::Client {
        let inner: Option<&tokio_postgres::Client> = self.client.as_ref();
        inner.unwrap()
    }
}
impl std::ops::DerefMut for PooledClient {
fn deref_mut(&mut self) -> &mut Self::Target {
self.client.as_mut().unwrap()
}
}
impl ClientPool {
pub fn new(db_url: String) -> ClientPool {
ClientPool {
connections: Arc::new(Mutex::new(Vec::with_capacity(16))),
permits: Arc::new(Semaphore::new(16)),
db_url,
}
}
pub async fn get(&self) -> PooledClient {
let permit = self.permits.clone().acquire_owned().await.unwrap();
{
let mut slots = self.connections.lock().unwrap_or_else(|e| e.into_inner());
// Pop connections until we hit a non-closed connection (or there are no
// "possibly open" connections left).
while let Some(c) = slots.pop() {
if !c.is_closed() {
return PooledClient {
client: Some(c),
permit,
pool: self.connections.clone(),
};
}
}
}
PooledClient {
client: Some(make_client(&self.db_url).await.unwrap()),
permit,
pool: self.connections.clone(),
}
}
}
pub async fn make_client(db_url: &str) -> anyhow::Result<tokio_postgres::Client> {
if db_url.contains("rds.amazonaws.com") {
let mut builder = TlsConnector::builder();
for cert in make_certificates() {
builder.add_root_certificate(cert);
}
let connector = builder.build().context("built TlsConnector")?;
let connector = MakeTlsConnector::new(connector);
let (db_client, connection) = match tokio_postgres::connect(db_url, connector).await {
Ok(v) => v,
Err(e) => {
anyhow::bail!("failed to connect to DB: {}", e);
}
};
tokio::task::spawn(async move {
if let Err(e) = connection.await {
eprintln!("database connection error: {e}");
}
});
Ok(db_client)
} else {
eprintln!("Warning: Non-TLS connection to non-RDS DB");
let (db_client, connection) =
match tokio_postgres::connect(db_url, tokio_postgres::NoTls).await {
Ok(v) => v,
Err(e) => {
anyhow::bail!("failed to connect to DB: {}", e);
}
};
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("database connection error: {e}");
}
});
Ok(db_client)
}
}
/// Converts the downloaded RDS CA bundle (`CERTIFICATE_PEMS`) into `native_tls` certificates.
///
/// Each certificate of the PEM chain is parsed with `x509_cert` and re-encoded as its own PEM
/// document, because `native_tls::Certificate::from_pem` loads a single certificate.
///
/// # Panics
/// Panics if the bundle cannot be downloaded (on first access to `CERTIFICATE_PEMS`) or parsed
/// as a PEM chain, or if any certificate fails to re-encode or load.
fn make_certificates() -> Vec<Certificate> {
    use x509_cert::der::pem::LineEnding;
    use x509_cert::der::EncodePem;

    let chain = x509_cert::Certificate::load_pem_chain(&CERTIFICATE_PEMS[..]).unwrap();
    let mut roots = Vec::with_capacity(chain.len());
    for x509 in chain {
        let pem = x509.to_pem(LineEnding::LF).unwrap();
        let native = Certificate::from_pem(pem.as_bytes()).unwrap();
        roots.push(native);
    }
    roots
}
// Makes sure we successfully download and parse the RDS certificate bundle and load every
// certificate into the native-tls compatible format. An empty bundle would leave TLS connections
// without trust roots, so that counts as a failure too.
// Note: this test needs network access to `CERT_URL`.
#[test]
fn cert() {
    let certs = make_certificates();
    assert!(
        !certs.is_empty(),
        "the RDS certificate bundle should contain at least one certificate"
    );
}
pub async fn run_migrations(client: &mut DbClient) -> anyhow::Result<()> {
client
.execute(
"CREATE TABLE IF NOT EXISTS database_versions (
zero INTEGER PRIMARY KEY,
migration_counter INTEGER
);",
&[],
)
.await
.context("creating database versioning table")?;
client
.execute(
"INSERT INTO database_versions (zero, migration_counter)
VALUES (0, 0)
ON CONFLICT DO NOTHING",
&[],
)
.await
.context("inserting initial database_versions")?;
let migration_idx: i32 = client
.query_one("SELECT migration_counter FROM database_versions", &[])
.await
.context("getting migration counter")?
.get(0);
let migration_idx = migration_idx as usize;
for (idx, migration) in MIGRATIONS.iter().enumerate() {
if idx >= migration_idx {
let tx = client
.transaction()
.await
.context("Cannot create migration transaction")?;
tx.execute(*migration, &[])
.await
.with_context(|| format!("executing {idx}th migration"))?;
tx.execute(
"UPDATE database_versions SET migration_counter = $1",
&[&(idx as i32 + 1)],
)
.await
.with_context(|| format!("updating migration counter to {idx}"))?;
tx.commit()
.await
.context("Cannot commit migration transaction")?;
}
}
Ok(())
}
/// Schedules the next upcoming occurrence of every job in `jobs`.
///
/// For each [`JobSchedule`], the first time yielded by `schedule.upcoming(Utc)` is handed to
/// [`schedule_job`]; a schedule with no upcoming time is skipped.
///
/// # Errors
/// Stops at and returns the first error from [`schedule_job`].
pub async fn schedule_jobs(db: &DbClient, jobs: Vec<JobSchedule>) -> anyhow::Result<()> {
    for entry in jobs {
        let Some(next_run) = entry.schedule.upcoming(Utc).next() else {
            continue;
        };
        schedule_job(db, entry.name, entry.metadata, next_run).await?;
    }
    Ok(())
}
/// Inserts a job named `job_name` scheduled at `when` with `job_metadata`, unless a job with the
/// same name and time can already be fetched.
///
/// # Errors
/// Fails if `job_name` is not the name of any job returned by `jobs()`, or if `insert_job` fails.
pub async fn schedule_job(
    db: &DbClient,
    job_name: &str,
    job_metadata: serde_json::Value,
    when: chrono::DateTime<Utc>,
) -> anyhow::Result<()> {
    let registered = jobs();
    let is_known = registered.iter().any(|job| job.name() == job_name);
    if !is_known {
        anyhow::bail!("Job {} does not exist in the current job list.", job_name);
    }

    // NOTE(review): any lookup error, not only "no such row", is read as "not scheduled yet",
    // so a transient DB error leads to an insert attempt (the unique index on
    // (name, scheduled_at) then rejects a true duplicate). Confirm this is intended.
    let already_scheduled = get_job_by_name_and_scheduled_at(db, job_name, &when)
        .await
        .is_ok();
    if !already_scheduled {
        insert_job(db, job_name, &when, &job_metadata).await?;
    }
    Ok(())
}
/// Executes every job returned by `get_jobs_to_execute`.
///
/// Each job is first marked with `update_job_executed_at`, then run through `handle_job`. A job
/// that succeeds is deleted; a job that fails is logged and its error message stored with
/// `update_job_error_message`.
///
/// # Errors
/// Returns on the first error from one of the job-table helpers; a job's own failure is
/// recorded in the table, not returned.
pub async fn run_scheduled_jobs(ctx: &Context) -> anyhow::Result<()> {
    let db = &ctx.db.get().await;
    let jobs = get_jobs_to_execute(db).await?;
    tracing::trace!("jobs to execute: {jobs:#?}");

    for job in &jobs {
        update_job_executed_at(db, &job.id).await?;
        if let Err(e) = handle_job(ctx, &job.name, &job.metadata).await {
            tracing::error!("job failed on execution (id={:?}, error={e:?})", job.id);
            update_job_error_message(db, &job.id, &e.to_string()).await?;
        } else {
            tracing::trace!("job successfully executed (id={})", job.id);
            delete_job(db, &job.id).await?;
        }
    }
    Ok(())
}
/// Runs the first job from `jobs()` whose name equals `name`, passing it `metadata`.
///
/// A name that matches no registered job is only logged at trace level and reported as success
/// (`Ok(())`), so `run_scheduled_jobs` deletes such a job as if it had run.
async fn handle_job(ctx: &Context, name: &str, metadata: &serde_json::Value) -> anyhow::Result<()> {
    let Some(job) = jobs().into_iter().find(|candidate| candidate.name() == name) else {
        tracing::trace!("handle_job fell into default case: (name={name:?}, metadata={metadata:?})");
        return Ok(());
    };
    job.run(ctx, metadata).await
}
// Schema migrations, applied in order by `run_migrations`.
//
// Important notes when adding migrations:
// - Only ever append new entries at the end. `database_versions.migration_counter` stores how
//   many entries have already been applied, and `run_migrations` only runs entries whose index is
//   >= that counter. Editing an existing entry is never re-run on existing databases. Removing or
//   reordering entries shifts which ones count as applied.
// - Each DB change is an element in this array and must be a single SQL statement:
//   `run_migrations` passes each element to one `execute` call (a prepared statement), which
//   does not accept multiple commands.
// - Each entry runs in its own transaction together with the counter update. After a successful
//   `run_migrations`, `database_versions.migration_counter` equals the number of items here.
// - The index comments below match the `{idx}` in the "executing {idx}th migration" error.
static MIGRATIONS: &[&str] = &[
// 0: notifications table
"
CREATE TABLE notifications (
notification_id BIGSERIAL PRIMARY KEY,
user_id BIGINT,
origin_url TEXT NOT NULL,
origin_html TEXT,
time TIMESTAMP WITH TIME ZONE
);
",
// 1: users table
"
CREATE TABLE users (
user_id BIGINT PRIMARY KEY,
username TEXT NOT NULL
);
",
// 2-5: extra notification columns
"ALTER TABLE notifications ADD COLUMN short_description TEXT;",
"ALTER TABLE notifications ADD COLUMN team_name TEXT;",
"ALTER TABLE notifications ADD COLUMN idx INTEGER;",
"ALTER TABLE notifications ADD COLUMN metadata TEXT;",
// 6: rustc_commits table
"
CREATE TABLE rustc_commits (
sha TEXT PRIMARY KEY,
parent_sha TEXT NOT NULL,
time TIMESTAMP WITH TIME ZONE
);
",
// 7: PR number for rustc_commits
"ALTER TABLE rustc_commits ADD COLUMN pr INTEGER;",
// 8: issue_data table
"
CREATE TABLE issue_data (
repo TEXT,
issue_number INTEGER,
key TEXT,
data JSONB,
PRIMARY KEY (repo, issue_number, key)
);
",
// 9: jobs table
"
CREATE TABLE jobs (
id UUID DEFAULT gen_random_uuid() PRIMARY KEY,
name TEXT NOT NULL,
scheduled_at TIMESTAMP WITH TIME ZONE NOT NULL,
metadata JSONB,
executed_at TIMESTAMP WITH TIME ZONE,
error_message TEXT
);
",
// 10: at most one job per (name, scheduled_at)
"
CREATE UNIQUE INDEX jobs_name_scheduled_at_unique_index
ON jobs (
name, scheduled_at
);
",
// 11: review_prefs table
"
CREATE table review_prefs (
id UUID DEFAULT gen_random_uuid() PRIMARY KEY,
user_id BIGINT REFERENCES users(user_id),
assigned_prs INT[] NOT NULL DEFAULT array[]::INT[]
);",
// 12: intarray extension
"
CREATE EXTENSION IF NOT EXISTS intarray;",
// 13: one review_prefs row per user
"
CREATE UNIQUE INDEX IF NOT EXISTS review_prefs_user_id ON review_prefs(user_id);
",
// 14: review_prefs.max_assigned_prs
"
ALTER TABLE review_prefs ADD COLUMN IF NOT EXISTS max_assigned_prs INTEGER DEFAULT NULL;
",
// 15: review_prefs.rotation_mode
"
ALTER TABLE review_prefs ADD COLUMN IF NOT EXISTS rotation_mode TEXT NOT NULL DEFAULT 'on-rotation';
",
// 16: per-team review preferences
r#"
CREATE TABLE IF NOT EXISTS team_review_prefs (
id UUID DEFAULT gen_random_uuid() PRIMARY KEY,
user_id BIGINT REFERENCES users(user_id),
team TEXT NOT NULL,
rotation_mode TEXT NOT NULL,
UNIQUE(user_id, team)
);
"#,
// 17: per-repository review preferences
r#"
CREATE TABLE IF NOT EXISTS repo_review_prefs (
id UUID DEFAULT gen_random_uuid() PRIMARY KEY,
user_id BIGINT REFERENCES users(user_id),
repo TEXT NOT NULL,
max_assigned_prs INTEGER DEFAULT NULL,
UNIQUE(user_id, repo)
);
"#,
// 18:
// Backfill existing repository preferences - the global ones were treated as being for
// rust-lang/rust
r#"
INSERT INTO repo_review_prefs(user_id, repo, max_assigned_prs)
SELECT user_id, 'rust-lang/rust', max_assigned_prs
FROM review_prefs
WHERE max_assigned_prs IS NOT NULL
"#,
];