Skip to content

Commit 7564e68

Browse files
committed
fix: lock when adding class
1 parent e9f8c22 commit 7564e68

File tree

4 files changed

+87
-31
lines changed

4 files changed

+87
-31
lines changed

backend/migrations/00062_add_rooms_table.sql renamed to backend/migrations/00063_add_rooms_table.sql

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@ CREATE TABLE public.rooms (
77
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
88
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
99
deleted_at TIMESTAMP WITH TIME ZONE,
10+
create_user_id INTEGER REFERENCES public.users(id) ON DELETE SET NULL,
11+
update_user_id INTEGER REFERENCES public.users(id) ON DELETE SET NULL,
1012
FOREIGN KEY (facility_id) REFERENCES public.facilities(id) ON UPDATE CASCADE ON DELETE CASCADE
1113
);
14+
CREATE INDEX idx_rooms_create_user_id ON public.rooms(create_user_id);
15+
CREATE INDEX idx_rooms_update_user_id ON public.rooms(update_user_id);
1216
CREATE INDEX idx_rooms_facility_id ON public.rooms(facility_id);
1317
CREATE INDEX idx_rooms_deleted_at ON public.rooms(deleted_at);
1418
CREATE UNIQUE INDEX idx_rooms_facility_name ON public.rooms(facility_id, name) WHERE deleted_at IS NULL;

backend/src/database/program_classes.go

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package database
22

33
import (
44
"UnlockEdv2/src/models"
5-
"context"
65
"fmt"
76
"strings"
87

@@ -41,27 +40,66 @@ func (db *DB) GetClassesForFacility(args *models.QueryContext) ([]models.Program
4140
return content, nil
4241
}
4342

44-
func (db *DB) CreateProgramClass(content *models.ProgramClass) (*models.ProgramClass, error) {
45-
err := Validate().Struct(content)
43+
func (db *DB) CreateProgramClass(content *models.ProgramClass, conflictReq *models.ConflictCheckRequest) (*models.ProgramClass, []models.RoomConflict, error) {
44+
if err := Validate().Struct(content); err != nil {
45+
return nil, nil, newCreateDBError(err, "create program classes validation error")
46+
}
47+
48+
if conflictReq == nil {
49+
if err := db.Create(&content).Error; err != nil {
50+
return nil, nil, newCreateDBError(err, "program classes")
51+
}
52+
return content, nil, nil
53+
}
54+
55+
tx := db.Begin()
56+
if tx.Error != nil {
57+
return nil, nil, NewDBError(tx.Error, "unable to start transaction")
58+
}
59+
60+
conflicts, err := LockRoomAndCheckConflicts(tx, conflictReq)
4661
if err != nil {
47-
return nil, newCreateDBError(err, "create program classes validation error")
62+
tx.Rollback()
63+
return nil, nil, err
4864
}
49-
if err := db.Create(&content).Error; err != nil {
50-
return nil, newCreateDBError(err, "program classes")
65+
if len(conflicts) > 0 {
66+
tx.Rollback()
67+
return nil, conflicts, nil
5168
}
52-
return content, nil
69+
70+
if err := tx.Create(&content).Error; err != nil {
71+
tx.Rollback()
72+
return nil, nil, newCreateDBError(err, "program classes")
73+
}
74+
75+
if err := tx.Commit().Error; err != nil {
76+
return nil, nil, NewDBError(err, "unable to commit transaction")
77+
}
78+
return content, nil, nil
5379
}
5480

55-
func (db *DB) UpdateProgramClass(content *models.ProgramClass, id int) (*models.ProgramClass, error) {
81+
func (db *DB) UpdateProgramClass(content *models.ProgramClass, id int, conflictReq *models.ConflictCheckRequest) (*models.ProgramClass, []models.RoomConflict, error) {
5682
var allChanges []models.ChangeLogEntry
5783
existing := &models.ProgramClass{}
5884
if err := db.Preload("Events").First(existing, "id = ?", id).Error; err != nil {
59-
return nil, newNotFoundDBError(err, "program classes")
85+
return nil, nil, newNotFoundDBError(err, "program classes")
6086
}
6187

6288
trans := db.Begin()
6389
if trans.Error != nil {
64-
return nil, NewDBError(trans.Error, "unable to start the database transaction")
90+
return nil, nil, NewDBError(trans.Error, "unable to start the database transaction")
91+
}
92+
93+
if conflictReq != nil {
94+
conflicts, err := LockRoomAndCheckConflicts(trans, conflictReq)
95+
if err != nil {
96+
trans.Rollback()
97+
return nil, nil, err
98+
}
99+
if len(conflicts) > 0 {
100+
trans.Rollback()
101+
return nil, conflicts, nil
102+
}
65103
}
66104

67105
ignoredFieldNames := []string{"create_user_id", "update_user_id", "enrollments", "facility", "facilities", "events", "facility_program", "program_id", "start_dt", "end_dt", "program", "enrolled"}
@@ -83,29 +121,29 @@ func (db *DB) UpdateProgramClass(content *models.ProgramClass, id int) (*models.
83121
models.UpdateStruct(existing, content)
84122
if err := trans.Session(&gorm.Session{FullSaveAssociations: false}).Updates(&existing).Error; err != nil {
85123
trans.Rollback()
86-
return nil, newUpdateDBError(err, "program classes")
124+
return nil, nil, newUpdateDBError(err, "program classes")
87125
}
88126

89127
if needsRoomUpdate {
90128
if err := trans.Model(&models.ProgramClassEvent{}).Where("id = ?", eventID).Update("room_id", newRoomID).Error; err != nil {
91129
trans.Rollback()
92-
return nil, newUpdateDBError(err, "program class event room")
130+
return nil, nil, newUpdateDBError(err, "program class event room")
93131
}
94132
existing.Events[0].RoomID = newRoomID
95133
}
96134

97135
if len(allChanges) > 0 {
98136
if err := trans.Create(&allChanges).Error; err != nil {
99137
trans.Rollback()
100-
return nil, newCreateDBError(err, "change_log_entries")
138+
return nil, nil, newCreateDBError(err, "change_log_entries")
101139
}
102140
}
103141

104142
if err := trans.Commit().Error; err != nil {
105-
return nil, NewDBError(err, "unable to commit the database transaction")
143+
return nil, nil, NewDBError(err, "unable to commit the database transaction")
106144
}
107145

108-
return existing, nil
146+
return existing, nil, nil
109147
}
110148

111149
func (db *DB) GetTotalEnrollmentsByClassID(id int) (int64, error) {

backend/src/database/rooms.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99

1010
"github.com/sirupsen/logrus"
1111
"github.com/teambition/rrule-go"
12+
"gorm.io/gorm"
13+
"gorm.io/gorm/clause"
1214
)
1315

1416
func (db *DB) GetRoomsForFacility(facilityID uint) ([]models.Room, error) {
@@ -37,9 +39,21 @@ func (db *DB) CreateRoom(room *models.Room) (*models.Room, error) {
3739
return room, nil
3840
}
3941

42+
func LockRoomAndCheckConflicts(tx *gorm.DB, req *models.ConflictCheckRequest) ([]models.RoomConflict, error) {
43+
var room models.Room
44+
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).First(&room, req.RoomID).Error; err != nil {
45+
return nil, newNotFoundDBError(err, "room")
46+
}
47+
return checkRRuleConflictsInternal(&DB{tx}, req)
48+
}
49+
4050
const maxConflictsToReturn = 50
4151

4252
func (db *DB) CheckRRuleConflicts(req *models.ConflictCheckRequest) ([]models.RoomConflict, error) {
53+
return checkRRuleConflictsInternal(db, req)
54+
}
55+
56+
func checkRRuleConflictsInternal(db *DB, req *models.ConflictCheckRequest) ([]models.RoomConflict, error) {
4357
if req.RecurrenceRule == "" {
4458
return nil, NewDBError(errors.New("recurrence rule is required"), "invalid conflict check request")
4559
}

backend/src/handlers/classes_handler.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -102,27 +102,27 @@ func (srv *Server) handleCreateClass(w http.ResponseWriter, r *http.Request, log
102102
claims := r.Context().Value(ClaimsKey).(*Claims)
103103
class.FacilityID = claims.FacilityID
104104
class.ProgramID = uint(id)
105+
106+
var conflictReq *models.ConflictCheckRequest
105107
if len(class.Events) > 0 && class.Events[0].RoomID != nil {
106108
if _, err := srv.Db.GetRoomByIDForFacility(*class.Events[0].RoomID, claims.FacilityID); err != nil {
107109
return newDatabaseServiceError(err)
108110
}
109-
conflicts, err := srv.Db.CheckRRuleConflicts(&models.ConflictCheckRequest{
111+
conflictReq = &models.ConflictCheckRequest{
110112
FacilityID: claims.FacilityID,
111113
RoomID: *class.Events[0].RoomID,
112114
RecurrenceRule: class.Events[0].RecurrenceRule,
113115
Duration: class.Events[0].Duration,
114-
})
115-
if err != nil {
116-
return newDatabaseServiceError(err)
117-
}
118-
if len(conflicts) > 0 {
119-
return writeConflictResponse(w, conflicts)
120116
}
121117
}
122-
newClass, err := srv.WithUserContext(r).CreateProgramClass(&class)
118+
119+
newClass, conflicts, err := srv.WithUserContext(r).CreateProgramClass(&class, conflictReq)
123120
if err != nil {
124121
return newDatabaseServiceError(err)
125122
}
123+
if len(conflicts) > 0 {
124+
return writeConflictResponse(w, conflicts)
125+
}
126126
log.add("program_id", id)
127127
log.add("class_id", newClass.ID)
128128
return writeJsonResponse(w, http.StatusCreated, newClass)
@@ -149,6 +149,8 @@ func (srv *Server) handleUpdateClass(w http.ResponseWriter, r *http.Request, log
149149
}
150150
claims := r.Context().Value(ClaimsKey).(*Claims)
151151
class.UpdateUserID = models.UintPtr(claims.UserID)
152+
153+
var conflictReq *models.ConflictCheckRequest
152154
if len(class.Events) > 0 && class.Events[0].RoomID != nil {
153155
if _, err := srv.Db.GetRoomByIDForFacility(*class.Events[0].RoomID, claims.FacilityID); err != nil {
154156
return newDatabaseServiceError(err)
@@ -162,25 +164,23 @@ func (srv *Server) handleUpdateClass(w http.ResponseWriter, r *http.Request, log
162164
existingRoomID = *existing.Events[0].RoomID
163165
}
164166
if *class.Events[0].RoomID != existingRoomID {
165-
conflicts, err := srv.Db.CheckRRuleConflicts(&models.ConflictCheckRequest{
167+
conflictReq = &models.ConflictCheckRequest{
166168
FacilityID: claims.FacilityID,
167169
RoomID: *class.Events[0].RoomID,
168170
RecurrenceRule: existing.Events[0].RecurrenceRule,
169171
Duration: existing.Events[0].Duration,
170172
ExcludeEventID: &existing.Events[0].ID,
171-
})
172-
if err != nil {
173-
return newDatabaseServiceError(err)
174-
}
175-
if len(conflicts) > 0 {
176-
return writeConflictResponse(w, conflicts)
177173
}
178174
}
179175
}
180-
updated, err := srv.WithUserContext(r).UpdateProgramClass(&class, id)
176+
177+
updated, conflicts, err := srv.WithUserContext(r).UpdateProgramClass(&class, id, conflictReq)
181178
if err != nil {
182179
return newDatabaseServiceError(err)
183180
}
181+
if len(conflicts) > 0 {
182+
return writeConflictResponse(w, conflicts)
183+
}
184184
return writeJsonResponse(w, http.StatusOK, updated)
185185
}
186186

0 commit comments

Comments
 (0)