|
20 | 20 |
|
21 | 21 | def upgrade() -> None: |
22 | 22 | """Upgrade schema.""" |
23 | | - pass |
| 23 | + conn = op.get_bind() |
| 24 | + inspector = sa.inspect(conn) |
| 25 | + columns = [col['name'] for col in inspector.get_columns('users')] |
| 26 | + |
| 27 | + if 'id' not in columns: |
| 28 | + # 1. Create a temporary table with the correct current schema |
| 29 | + op.create_table('users_fix', |
| 30 | + sa.Column('id', sa.Integer(), primary_key=True), |
| 31 | + sa.Column('name', sa.String(), nullable=True), |
| 32 | + sa.Column('phone_number', sa.String(), unique=True, nullable=True), |
| 33 | + sa.Column('email', sa.String(), unique=True, nullable=True), |
| 34 | + sa.Column('business_id', sa.Integer(), sa.ForeignKey('businesses.id'), nullable=False), |
| 35 | + sa.Column('role', sa.String(), nullable=False), |
| 36 | + sa.Column('created_at', sa.DateTime(), nullable=False), |
| 37 | + sa.Column('preferred_channel', sa.String(), nullable=False, server_default='WHATSAPP'), |
| 38 | + sa.Column('preferences', sa.JSON(), nullable=False), |
| 39 | + sa.Column('timezone', sa.String(), nullable=False, server_default='UTC'), |
| 40 | + sa.Column('default_start_location_lat', sa.Float(), nullable=True), |
| 41 | + sa.Column('default_start_location_lng', sa.Float(), nullable=True), |
| 42 | + sa.Column('google_calendar_credentials', sa.JSON(), nullable=True), |
| 43 | + sa.Column('google_calendar_sync_enabled', sa.Boolean(), nullable=False, server_default='0'), |
| 44 | + sa.Column('clerk_id', sa.String(), unique=True, nullable=True), |
| 45 | + sa.Column('current_latitude', sa.Float(), nullable=True), |
| 46 | + sa.Column('current_longitude', sa.Float(), nullable=True), |
| 47 | + sa.Column('location_updated_at', sa.DateTime(), nullable=True), |
| 48 | + sa.Column('current_shift_start', sa.DateTime(), nullable=True), |
| 49 | + sa.Column('geocoding_count', sa.Integer(), nullable=False, server_default='0'), |
| 50 | + ) |
| 51 | + |
| 52 | + # 2. Copy data from old users table to users_fix |
| 53 | + # We map columns that we know exist in the old schema |
| 54 | + old_cols = [c for c in columns if c in [ |
| 55 | + 'name', 'phone_number', 'email', 'business_id', 'role', 'created_at', |
| 56 | + 'preferred_channel', 'preferences', 'timezone', 'default_start_location_lat', |
| 57 | + 'default_start_location_lng', 'google_calendar_credentials', |
| 58 | + 'google_calendar_sync_enabled', 'clerk_id', 'current_latitude', |
| 59 | + 'current_longitude', 'location_updated_at', 'current_shift_start', 'geocoding_count' |
| 60 | + ]] |
| 61 | + col_list = ", ".join(old_cols) |
| 62 | + op.execute(f"INSERT INTO users_fix ({col_list}) SELECT {col_list} FROM users") |
| 63 | + |
| 64 | + # 3. Swap tables |
| 65 | + op.drop_table('users') |
| 66 | + op.rename_table('users_fix', 'users') |
| 67 | + |
| 68 | + # 4. Re-create indexes |
| 69 | + op.create_index(op.f('ix_users_business_id'), 'users', ['business_id'], unique=False) |
| 70 | + |
| 71 | + # 5. Fix messages table if it has user_id but it's null/incorrect |
| 72 | + msg_cols = [col['name'] for col in inspector.get_columns('messages')] |
| 73 | + if 'user_id' in msg_cols: |
| 74 | + op.execute(""" |
| 75 | + UPDATE messages |
| 76 | + SET user_id = (SELECT u.id FROM users u WHERE u.phone_number = messages.from_number) |
| 77 | + WHERE role = 'USER' AND user_id IS NULL |
| 78 | + """) |
| 79 | + op.execute(""" |
| 80 | + UPDATE messages |
| 81 | + SET user_id = (SELECT u.id FROM users u WHERE u.phone_number = messages.to_number) |
| 82 | + WHERE role = 'ASSISTANT' AND user_id IS NULL |
| 83 | + """) |
| 84 | + |
| 85 | + # 6. Fix conversation_states table if it uses phone_number instead of user_id |
| 86 | + cs_cols = [col['name'] for col in inspector.get_columns('conversation_states')] |
| 87 | + if 'user_id' not in cs_cols and 'phone_number' in cs_cols: |
| 88 | + op.create_table('cs_fix', |
| 89 | + sa.Column('user_id', sa.Integer(), sa.ForeignKey('users.id'), primary_key=True), |
| 90 | + sa.Column('state', sa.String(), nullable=False), |
| 91 | + sa.Column('draft_data', sa.JSON(), nullable=True), |
| 92 | + sa.Column('last_action_metadata', sa.JSON(), nullable=True), |
| 93 | + sa.Column('last_updated', sa.DateTime(), nullable=False), |
| 94 | + sa.Column('pending_action_timestamp', sa.DateTime(), nullable=True), |
| 95 | + sa.Column('pending_action_payload', sa.JSON(), nullable=True), |
| 96 | + sa.Column('active_channel', sa.String(), nullable=False, server_default='WHATSAPP'), |
| 97 | + ) |
| 98 | + op.execute(""" |
| 99 | + INSERT INTO cs_fix (user_id, state, draft_data, last_action_metadata, last_updated) |
| 100 | + SELECT u.id, cs.state, cs.draft_data, cs.last_action_metadata, cs.last_updated |
| 101 | + FROM conversation_states cs |
| 102 | + JOIN users u ON cs.phone_number = u.phone_number |
| 103 | + """) |
| 104 | + op.drop_table('conversation_states') |
| 105 | + op.rename_table('cs_fix', 'conversation_states') |
| 106 | + |
| 107 | + # Also check if expenses needs fixing due to previous broken states |
| 108 | + if 'expenses' in inspector.get_table_names(): |
| 109 | + pass |
24 | 110 |
|
25 | 111 |
|
26 | 112 | def downgrade() -> None: |
|
0 commit comments