diff --git a/cmd/seed/main.go b/cmd/seed/main.go index 6764e4c2..f71036a1 100644 --- a/cmd/seed/main.go +++ b/cmd/seed/main.go @@ -7,6 +7,7 @@ import ( "runtime/debug" "time" + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules" "github.com/iota-uz/iota-sdk/modules/core/domain/aggregates/user" "github.com/iota-uz/iota-sdk/modules/core/domain/value_objects/internet" @@ -14,6 +15,7 @@ import ( "github.com/iota-uz/iota-sdk/pkg/application" "github.com/iota-uz/iota-sdk/pkg/composables" "github.com/iota-uz/iota-sdk/pkg/configuration" + "github.com/iota-uz/iota-sdk/pkg/constants" "github.com/iota-uz/iota-sdk/pkg/eventbus" "github.com/jackc/pgx/v5/pgxpool" @@ -67,14 +69,28 @@ func main() { panicWithStack(err) } seeder.Register( + coreseed.CreateDefaultTenant, coreseed.CreateCurrencies, coreseed.CreatePermissions, coreseed.UserSeedFunc(usr), ) - if err := seeder.Seed(composables.WithTx(ctx, tx), app); err != nil { + + // Add default tenant to context + defaultTenant := &composables.Tenant{ + ID: uuid.MustParse("00000000-0000-0000-0000-000000000001"), + Name: "Default", + Domain: "default.localhost", + } + ctxWithTenant := context.WithValue( + composables.WithTx(ctx, tx), + constants.TenantKey, + defaultTenant, + ) + + if err := seeder.Seed(ctxWithTenant, app); err != nil { panicWithStack(err) } - if err := tx.Commit(ctx); err != nil { + if err := tx.Commit(ctxWithTenant); err != nil { panicWithStack(err) } } diff --git a/go.mod b/go.mod index 91b138fc..c9da1592 100644 --- a/go.mod +++ b/go.mod @@ -57,6 +57,7 @@ require ( cloud.google.com/go/auth v0.11.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect cloud.google.com/go/compute/metadata v0.6.0 // indirect + github.com/PuerkitoBio/goquery v1.10.1 // indirect github.com/agnivade/levenshtein v1.2.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054 // indirect diff --git a/go.sum b/go.sum index ebfbc65e..bd846ae4 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,8 @@ github.com/Oudwins/tailwind-merge-go v0.2.1 h1:jxRaEqGtwwwF48UuFIQ8g8XT7YSualNuG github.com/Oudwins/tailwind-merge-go v0.2.1/go.mod h1:kkZodgOPvZQ8f7SIrlWkG/w1g9JTbtnptnePIh3V72U= github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.31 h1:SIkzqC6Nv+znY4NGbWlJceWdns8QVmf9cwAYXd7Cg8k= github.com/PaulSonOfLars/gotgbot/v2 v2.0.0-rc.31/go.mod h1:kL1v4iIjlalwm3gCYGvF4NLa3hs+aKEfRkNJvj4aoDU= -github.com/PuerkitoBio/goquery v1.9.3 h1:mpJr/ikUA9/GNJB/DBZcGeFDXUtosHRyRrwh7KGdTG0= -github.com/PuerkitoBio/goquery v1.9.3/go.mod h1:1ndLHPdTz+DyQPICCWYlYQMPl0oXZj0G6D4LCYA6u4U= +github.com/PuerkitoBio/goquery v1.10.1 h1:Y8JGYUkXWTGRB6Ars3+j3kN0xg1YqqlwvdTV8WTFQcU= +github.com/PuerkitoBio/goquery v1.10.1/go.mod h1:IYiHrOMps66ag56LEH7QYDDupKXyo5A8qrjIx3ZtujY= github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqRtAwp2Xwc6WNPJEufxJ7fx3npB4UV/JOLmbu5I0= github.com/a-h/templ v0.3.857 h1:6EqcJuGZW4OL+2iZ3MD+NnIcG7nGkaQeF2Zq5kf9ZGg= github.com/a-h/templ v0.3.857/go.mod h1:qhrhAkRFubE7khxLZHsBFHfX+gWwVNKbzKeF9GlPV4M= @@ -34,8 +34,8 @@ github.com/agnivade/levenshtein v1.2.1/go.mod h1:QVVI16kDrtSuwcpd0p1+xMC6Z/VfhtC github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ= github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8= -github.com/andybalholm/cascadia v1.3.2 h1:3Xi6Dw5lHF15JtdcmAHD3i1+T8plmv7BQ/nsViSLyss= -github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU= +github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM= +github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= diff --git a/migrations/changes-1740741698.sql b/migrations/changes-1740741698.sql index f0ef329f..d0b7a6eb 100644 --- a/migrations/changes-1740741698.sql +++ b/migrations/changes-1740741698.sql @@ -4,13 +4,14 @@ CREATE TABLE uploads ( id SERIAL8 PRIMARY KEY, name VARCHAR(255) NOT NULL, - hash VARCHAR(255) NOT NULL UNIQUE, + hash VARCHAR(255) NOT NULL, path VARCHAR(1024) DEFAULT '' NOT NULL, size INT8 DEFAULT 0 NOT NULL, mimetype VARCHAR(255) NOT NULL, type VARCHAR(255) NOT NULL, created_at TIMESTAMPTZ DEFAULT now(), - updated_at TIMESTAMPTZ DEFAULT now() + updated_at TIMESTAMPTZ DEFAULT now(), + CONSTRAINT uploads_hash_key UNIQUE (hash) ); -- Change CREATE_TABLE: clients @@ -61,20 +62,22 @@ CREATE TABLE positions ( -- Change CREATE_TABLE: permissions CREATE TABLE permissions ( id UUID DEFAULT gen_random_uuid() NOT NULL PRIMARY KEY, - name VARCHAR(255) NOT NULL UNIQUE, + name VARCHAR(255) NOT NULL, resource VARCHAR(255) NOT NULL, action VARCHAR(255) NOT NULL, modifier VARCHAR(255) NOT NULL, - description TEXT + description TEXT, + CONSTRAINT permissions_name_key UNIQUE (name) ); -- Change CREATE_TABLE: roles CREATE TABLE roles ( id SERIAL8 PRIMARY KEY, - name VARCHAR(255) NOT NULL UNIQUE, + name VARCHAR(255) NOT NULL, description TEXT, created_at TIMESTAMPTZ DEFAULT now(), - updated_at TIMESTAMPTZ DEFAULT now() + updated_at TIMESTAMPTZ DEFAULT now(), + CONSTRAINT roles_name_key UNIQUE (name) ); -- Change CREATE_TABLE: warehouse_orders @@ -122,11 +125,12 @@ CREATE TABLE inventory ( CREATE TABLE warehouse_positions ( id SERIAL8 PRIMARY KEY, title VARCHAR(255) NOT NULL, - barcode VARCHAR(255) NOT NULL UNIQUE, + barcode VARCHAR(255) NOT NULL, description TEXT, unit_id INT8 REFERENCES warehouse_units (id) ON DELETE SET NULL, created_at TIMESTAMPTZ DEFAULT now(), - updated_at TIMESTAMPTZ DEFAULT now() + updated_at TIMESTAMPTZ DEFAULT now(), + CONSTRAINT warehouse_positions_barcode_key UNIQUE (barcode) ); -- Change CREATE_TABLE: employees @@ -135,7 +139,7 @@ CREATE TABLE employees ( first_name VARCHAR(255) NOT NULL, last_name VARCHAR(255) NOT NULL, middle_name VARCHAR(255), - email VARCHAR(255) NOT NULL UNIQUE, + email VARCHAR(255) NOT NULL, phone VARCHAR(255), salary DECIMAL(9,2) NOT NULL, salary_currency_id VARCHAR(3) REFERENCES currencies (code) ON DELETE SET NULL, @@ -143,7 +147,8 @@ CREATE TABLE employees ( coefficient FLOAT8 NOT NULL, avatar_id INT8 REFERENCES uploads (id) ON DELETE SET NULL, created_at TIMESTAMPTZ DEFAULT now(), - updated_at TIMESTAMPTZ DEFAULT now() + updated_at TIMESTAMPTZ DEFAULT now(), + CONSTRAINT employees_email_key UNIQUE (email) ); -- Change CREATE_TABLE: warehouse_position_images @@ -159,7 +164,7 @@ CREATE TABLE users ( first_name VARCHAR(255) NOT NULL, last_name VARCHAR(255) NOT NULL, middle_name VARCHAR(255), - email VARCHAR(255) NOT NULL UNIQUE, + email VARCHAR(255) NOT NULL, password VARCHAR(255), ui_language VARCHAR(3) NOT NULL, avatar_id INT8 REFERENCES uploads (id) ON DELETE SET NULL, @@ -167,7 +172,8 @@ CREATE TABLE users ( last_ip VARCHAR(255) NULL, last_action TIMESTAMPTZ NULL, created_at TIMESTAMPTZ DEFAULT now() NOT NULL, - updated_at TIMESTAMPTZ DEFAULT now() NOT NULL + updated_at TIMESTAMPTZ DEFAULT now() NOT NULL, + CONSTRAINT users_email_key UNIQUE (email) ); -- Change CREATE_TABLE: role_permissions @@ -229,10 +235,11 @@ CREATE TABLE employee_meta ( CREATE TABLE warehouse_products ( id SERIAL8 PRIMARY KEY, position_id INT8 NOT NULL REFERENCES warehouse_positions (id) ON DELETE CASCADE, - rfid VARCHAR(255) NULL UNIQUE, + rfid VARCHAR(255) NULL, status VARCHAR(255) NOT NULL, created_at TIMESTAMPTZ DEFAULT now(), - updated_at TIMESTAMPTZ DEFAULT now() + updated_at TIMESTAMPTZ DEFAULT now(), + CONSTRAINT warehouse_products_rfid UNIQUE (rfid) ); -- Change CREATE_TABLE: authentication_logs @@ -292,7 +299,7 @@ CREATE TABLE tabs ( href VARCHAR(255) NOT NULL, user_id INT8 NOT NULL REFERENCES users (id) ON DELETE CASCADE, "position" INT8 DEFAULT 0 NOT NULL, - UNIQUE (href, user_id) + CONSTRAINT tabs_href_user_id_key UNIQUE (href, user_id) ); -- Change CREATE_TABLE: counterparty_contacts diff --git a/migrations/changes-1740935243.sql b/migrations/changes-1740935243.sql index a6f102cd..3dae9cbf 100644 --- a/migrations/changes-1740935243.sql +++ b/migrations/changes-1740935243.sql @@ -35,7 +35,10 @@ ALTER TABLE clients -- Change ADD_COLUMN: phone ALTER TABLE users - ADD COLUMN phone VARCHAR(255) UNIQUE; + ADD COLUMN phone VARCHAR(255); + +ALTER TABLE users + ADD CONSTRAINT users_phone_key UNIQUE (phone); -- Change CREATE_TABLE: client_contacts CREATE TABLE client_contacts ( diff --git a/migrations/changes-1741630199.sql b/migrations/changes-1741630199.sql index f3ed9ccf..97da02e2 100644 --- a/migrations/changes-1741630199.sql +++ b/migrations/changes-1741630199.sql @@ -2,10 +2,11 @@ -- Change CREATE_TABLE: user_groups CREATE TABLE user_groups ( id uuid DEFAULT gen_random_uuid () PRIMARY KEY, - name varchar(255) NOT NULL UNIQUE, + name varchar(255) NOT NULL, description text, created_at timestamp DEFAULT now(), - updated_at timestamp DEFAULT now() + updated_at timestamp DEFAULT now(), + CONSTRAINT user_groups_name_key UNIQUE (name) ); -- Change CREATE_TABLE: group_roles diff --git a/migrations/changes-1744627696.sql b/migrations/changes-1744627696.sql new file mode 100644 index 00000000..aa3c0ac7 --- /dev/null +++ b/migrations/changes-1744627696.sql @@ -0,0 +1,511 @@ +-- +migrate Up + +-- Change CREATE_TABLE: tenants +CREATE TABLE tenants ( + id UUID DEFAULT gen_random_uuid() PRIMARY KEY, + name VARCHAR(255) NOT NULL UNIQUE, + domain VARCHAR(255), + is_active BOOL DEFAULT true NOT NULL, + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now() +); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE prompts ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE permissions DROP CONSTRAINT IF EXISTS permissions_name_key; + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE permissions ADD COLUMN tenant_id UUID REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE permissions ADD UNIQUE (tenant_id, name); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE money_accounts ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE money_accounts ADD UNIQUE (tenant_id, account_number); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE message_templates ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE warehouse_orders ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE counterparty ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE counterparty ADD UNIQUE (tenant_id, tin); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE positions ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE positions ADD UNIQUE (tenant_id, name); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE transactions ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE passports DROP CONSTRAINT IF EXISTS passports_passport_number_series_key; + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE passports ADD COLUMN tenant_id UUID REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE passports + ADD CONSTRAINT passports_tenant_passport_number_series_key UNIQUE (tenant_id, passport_number, series); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE user_groups ADD COLUMN tenant_id UUID REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE user_groups DROP CONSTRAINT IF EXISTS user_groups_name_key; + +ALTER TABLE user_groups ADD UNIQUE (tenant_id, name); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE uploads ADD COLUMN tenant_id UUID REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE uploads DROP CONSTRAINT IF EXISTS uploads_hash_key; + +ALTER TABLE uploads ADD UNIQUE (tenant_id, hash); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE inventory ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE inventory ADD UNIQUE (tenant_id, name); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE users ADD COLUMN tenant_id UUID REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE users DROP CONSTRAINT IF EXISTS users_email_key; + +ALTER TABLE users DROP CONSTRAINT IF EXISTS users_phone_key; + +ALTER TABLE users ADD UNIQUE (tenant_id, email); + +ALTER TABLE users ADD UNIQUE (tenant_id, phone); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE warehouse_units ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE warehouse_units ADD UNIQUE (tenant_id, title); + +ALTER TABLE warehouse_units ADD UNIQUE (tenant_id, short_title); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE authentication_logs ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE expense_categories ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE expense_categories ADD UNIQUE (tenant_id, name); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE roles ADD COLUMN tenant_id UUID REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE roles DROP CONSTRAINT IF EXISTS roles_name_key; + +ALTER TABLE roles ADD UNIQUE (tenant_id, name); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE sessions ADD COLUMN tenant_id UUID REFERENCES tenants (id) ON DELETE CASCADE; + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE clients ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE clients ADD UNIQUE (tenant_id, phone_number); + +ALTER TABLE clients ADD UNIQUE (tenant_id, email); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE warehouse_positions ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE warehouse_positions DROP CONSTRAINT IF EXISTS warehouse_positions_barcode_key; + +ALTER TABLE warehouse_positions ADD UNIQUE (tenant_id, barcode); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE tabs ADD COLUMN tenant_id UUID REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE tabs DROP CONSTRAINT IF EXISTS tabs_href_user_id_key; + +ALTER TABLE tabs ADD UNIQUE (tenant_id, href, user_id); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE inventory_checks ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE companies ADD COLUMN tenant_id UUID REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE companies ADD UNIQUE (tenant_id, name); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE chats ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE dialogues ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE employees ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE employees DROP CONSTRAINT IF EXISTS employees_email_key; + +ALTER TABLE employees ADD UNIQUE (tenant_id, email); + +ALTER TABLE employees ADD UNIQUE (tenant_id, phone); + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE inventory_check_results ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE action_logs ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +-- Change ADD_COLUMN: tenant_id +ALTER TABLE warehouse_products ADD COLUMN tenant_id UUID NOT NULL REFERENCES tenants (id) ON DELETE CASCADE; + +ALTER TABLE warehouse_products DROP CONSTRAINT IF EXISTS warehouse_products_rfid; + +ALTER TABLE warehouse_products ADD UNIQUE (tenant_id, rfid); + +-- Change CREATE_INDEX: user_groups_tenant_id_idx +CREATE INDEX user_groups_tenant_id_idx ON user_groups (tenant_id); + +-- Change CREATE_INDEX: counterparty_tenant_id_idx +CREATE INDEX counterparty_tenant_id_idx ON counterparty (tenant_id); + +-- Change CREATE_INDEX: action_logs_tenant_id_idx +CREATE INDEX action_logs_tenant_id_idx ON action_logs (tenant_id); + +-- Change CREATE_INDEX: employees_tenant_id_idx +CREATE INDEX employees_tenant_id_idx ON employees (tenant_id); + +-- Change CREATE_INDEX: warehouse_units_tenant_id_idx +CREATE INDEX warehouse_units_tenant_id_idx ON warehouse_units (tenant_id); + +-- Change CREATE_INDEX: warehouse_positions_tenant_id_idx +CREATE INDEX warehouse_positions_tenant_id_idx ON warehouse_positions (tenant_id); + +-- Change CREATE_INDEX: users_tenant_id_idx +CREATE INDEX users_tenant_id_idx ON users (tenant_id); + +-- Change CREATE_INDEX: inventory_checks_tenant_id_idx +CREATE INDEX inventory_checks_tenant_id_idx ON inventory_checks (tenant_id); + +-- Change CREATE_INDEX: employees_first_name_idx +CREATE INDEX employees_first_name_idx ON employees (first_name); + +-- Change CREATE_INDEX: inventory_tenant_id_idx +CREATE INDEX inventory_tenant_id_idx ON inventory (tenant_id); + +-- Change CREATE_INDEX: authentication_logs_tenant_id_idx +CREATE INDEX authentication_logs_tenant_id_idx ON authentication_logs (tenant_id); + +-- Change CREATE_INDEX: dialogues_tenant_id_idx +CREATE INDEX dialogues_tenant_id_idx ON dialogues (tenant_id); + +-- Change CREATE_INDEX: employees_email_idx +CREATE INDEX employees_email_idx ON employees (email); + +-- Change CREATE_INDEX: inventory_check_results_tenant_id_idx +CREATE INDEX inventory_check_results_tenant_id_idx ON inventory_check_results (tenant_id); + +-- Change CREATE_INDEX: sessions_tenant_id_idx +CREATE INDEX sessions_tenant_id_idx ON sessions (tenant_id); + +-- Change CREATE_INDEX: prompts_tenant_id_idx +CREATE INDEX prompts_tenant_id_idx ON prompts (tenant_id); + +-- Change CREATE_INDEX: warehouse_products_tenant_id_idx +CREATE INDEX warehouse_products_tenant_id_idx ON warehouse_products (tenant_id); + +-- Change CREATE_INDEX: transactions_tenant_id_idx +CREATE INDEX transactions_tenant_id_idx ON transactions (tenant_id); + +-- Change CREATE_INDEX: idx_message_templates_tenant_id +CREATE INDEX idx_message_templates_tenant_id ON message_templates (tenant_id); + +-- Change CREATE_INDEX: positions_tenant_id_idx +CREATE INDEX positions_tenant_id_idx ON positions (tenant_id); + +-- Change CREATE_INDEX: warehouse_orders_tenant_id_idx +CREATE INDEX warehouse_orders_tenant_id_idx ON warehouse_orders (tenant_id); + +-- Change CREATE_INDEX: permissions_tenant_id_idx +CREATE INDEX permissions_tenant_id_idx ON permissions (tenant_id); + +-- Change CREATE_INDEX: idx_chats_tenant_id +CREATE INDEX idx_chats_tenant_id ON chats (tenant_id); + +-- Change CREATE_INDEX: expense_categories_tenant_id_idx +CREATE INDEX expense_categories_tenant_id_idx ON expense_categories (tenant_id); + +-- Change CREATE_INDEX: employees_phone_idx +CREATE INDEX employees_phone_idx ON employees (phone); + +-- Change CREATE_INDEX: roles_tenant_id_idx +CREATE INDEX roles_tenant_id_idx ON roles (tenant_id); + +-- Change CREATE_INDEX: uploads_tenant_id_idx +CREATE INDEX uploads_tenant_id_idx ON uploads (tenant_id); + +-- Change CREATE_INDEX: tabs_tenant_id_idx +CREATE INDEX tabs_tenant_id_idx ON tabs (tenant_id); + +-- Change CREATE_INDEX: money_accounts_tenant_id_idx +CREATE INDEX money_accounts_tenant_id_idx ON money_accounts (tenant_id); + +-- Change CREATE_INDEX: idx_clients_tenant_id +CREATE INDEX idx_clients_tenant_id ON clients (tenant_id); + +-- Change CREATE_INDEX: employees_last_name_idx +CREATE INDEX employees_last_name_idx ON employees (last_name); + + +-- +migrate Down + +-- Undo CREATE_INDEX: employees_last_name_idx +DROP INDEX employees_last_name_idx; + +-- Undo CREATE_INDEX: idx_clients_tenant_id +DROP INDEX idx_clients_tenant_id; + +-- Undo CREATE_INDEX: money_accounts_tenant_id_idx +DROP INDEX money_accounts_tenant_id_idx; + +-- Undo CREATE_INDEX: tabs_tenant_id_idx +DROP INDEX tabs_tenant_id_idx; + +-- Undo CREATE_INDEX: uploads_tenant_id_idx +DROP INDEX uploads_tenant_id_idx; + +-- Undo CREATE_INDEX: roles_tenant_id_idx +DROP INDEX roles_tenant_id_idx; + +-- Undo CREATE_INDEX: employees_phone_idx +DROP INDEX employees_phone_idx; + +-- Undo CREATE_INDEX: expense_categories_tenant_id_idx +DROP INDEX expense_categories_tenant_id_idx; + +-- Undo CREATE_INDEX: idx_chats_tenant_id +DROP INDEX idx_chats_tenant_id; + +-- Undo CREATE_INDEX: permissions_tenant_id_idx +DROP INDEX permissions_tenant_id_idx; + +-- Undo CREATE_INDEX: warehouse_orders_tenant_id_idx +DROP INDEX warehouse_orders_tenant_id_idx; + +-- Undo CREATE_INDEX: positions_tenant_id_idx +DROP INDEX positions_tenant_id_idx; + +-- Undo CREATE_INDEX: idx_message_templates_tenant_id +DROP INDEX idx_message_templates_tenant_id; + +-- Undo CREATE_INDEX: transactions_tenant_id_idx +DROP INDEX transactions_tenant_id_idx; + +-- Undo CREATE_INDEX: warehouse_products_tenant_id_idx +DROP INDEX warehouse_products_tenant_id_idx; + +-- Undo CREATE_INDEX: prompts_tenant_id_idx +DROP INDEX prompts_tenant_id_idx; + +-- Undo CREATE_INDEX: sessions_tenant_id_idx +DROP INDEX sessions_tenant_id_idx; + +-- Undo CREATE_INDEX: inventory_check_results_tenant_id_idx +DROP INDEX inventory_check_results_tenant_id_idx; + +-- Undo CREATE_INDEX: employees_email_idx +DROP INDEX employees_email_idx; + +-- Undo CREATE_INDEX: dialogues_tenant_id_idx +DROP INDEX dialogues_tenant_id_idx; + +-- Undo CREATE_INDEX: authentication_logs_tenant_id_idx +DROP INDEX authentication_logs_tenant_id_idx; + +-- Undo CREATE_INDEX: inventory_tenant_id_idx +DROP INDEX inventory_tenant_id_idx; + +-- Undo CREATE_INDEX: employees_first_name_idx +DROP INDEX employees_first_name_idx; + +-- Undo CREATE_INDEX: inventory_checks_tenant_id_idx +DROP INDEX inventory_checks_tenant_id_idx; + +-- Undo CREATE_INDEX: users_tenant_id_idx +DROP INDEX users_tenant_id_idx; + +-- Undo CREATE_INDEX: warehouse_positions_tenant_id_idx +DROP INDEX warehouse_positions_tenant_id_idx; + +-- Undo CREATE_INDEX: warehouse_units_tenant_id_idx +DROP INDEX warehouse_units_tenant_id_idx; + +-- Undo CREATE_INDEX: employees_tenant_id_idx +DROP INDEX employees_tenant_id_idx; + +-- Undo CREATE_INDEX: action_logs_tenant_id_idx +DROP INDEX action_logs_tenant_id_idx; + +-- Undo CREATE_INDEX: counterparty_tenant_id_idx +DROP INDEX counterparty_tenant_id_idx; + +-- Undo CREATE_INDEX: user_groups_tenant_id_idx +DROP INDEX user_groups_tenant_id_idx; + +ALTER TABLE warehouse_products DROP CONSTRAINT IF EXISTS warehouse_products_tenant_id_rfid_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE warehouse_products DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE warehouse_products ADD CONSTRAINT warehouse_products_rfid UNIQUE (rfid); + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE action_logs DROP COLUMN IF EXISTS tenant_id; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE inventory_check_results DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE employees DROP CONSTRAINT IF EXISTS employees_tenant_id_phone_key; + +ALTER TABLE employees DROP CONSTRAINT IF EXISTS employees_tenant_id_email_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE employees DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE employees ADD CONSTRAINT employees_email_key UNIQUE (email); + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE dialogues DROP COLUMN IF EXISTS tenant_id; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE chats DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE companies DROP CONSTRAINT IF EXISTS companies_tenant_id_name_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE companies DROP COLUMN IF EXISTS tenant_id; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE inventory_checks DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE tabs DROP CONSTRAINT IF EXISTS tabs_tenant_id_href_user_id_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE tabs DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE tabs ADD CONSTRAINT tabs_href_user_id_key UNIQUE (href, user_id); + +ALTER TABLE warehouse_positions DROP CONSTRAINT IF EXISTS warehouse_positions_tenant_id_barcode_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE warehouse_positions DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE warehouse_positions ADD CONSTRAINT warehouse_positions_barcode_key UNIQUE (barcode); + +ALTER TABLE clients DROP CONSTRAINT IF EXISTS clients_tenant_id_email_key; + +ALTER TABLE clients DROP CONSTRAINT IF EXISTS clients_tenant_id_phone_number_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE clients DROP COLUMN IF EXISTS tenant_id; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE sessions DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE roles DROP CONSTRAINT IF EXISTS roles_tenant_id_name_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE roles DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE roles ADD CONSTRAINT roles_name_key UNIQUE (name); + +ALTER TABLE expense_categories DROP CONSTRAINT IF EXISTS expense_categories_tenant_id_name_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE expense_categories DROP COLUMN IF EXISTS tenant_id; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE authentication_logs DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE warehouse_units DROP CONSTRAINT IF EXISTS warehouse_units_tenant_id_short_title_key; + +ALTER TABLE warehouse_units DROP CONSTRAINT IF EXISTS warehouse_units_tenant_id_title_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE warehouse_units DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE users DROP CONSTRAINT IF EXISTS users_tenant_id_phone_key; + +ALTER TABLE users DROP CONSTRAINT IF EXISTS users_tenant_id_email_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE users DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE users ADD CONSTRAINT users_phone_key UNIQUE (phone); + +ALTER TABLE users ADD CONSTRAINT users_email_key UNIQUE (email); + +ALTER TABLE inventory DROP CONSTRAINT IF EXISTS inventory_tenant_id_name_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE inventory DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE uploads DROP CONSTRAINT IF EXISTS uploads_tenant_id_hash_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE uploads DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE uploads ADD CONSTRAINT uploads_hash_key UNIQUE (hash); + +ALTER TABLE user_groups DROP CONSTRAINT IF EXISTS user_groups_tenant_id_name_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE user_groups DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE user_groups ADD CONSTRAINT user_groups_name_key UNIQUE (name); + +ALTER TABLE passports DROP CONSTRAINT IF EXISTS passports_tenant_passport_number_series_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE passports DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE passports ADD CONSTRAINT passports_passport_number_series_key UNIQUE (passport_number, series); + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE transactions DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE positions DROP CONSTRAINT IF EXISTS positions_tenant_id_name_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE positions DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE counterparty DROP CONSTRAINT IF EXISTS counterparty_tenant_id_tin_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE counterparty DROP COLUMN IF EXISTS tenant_id; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE warehouse_orders DROP COLUMN IF EXISTS tenant_id; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE message_templates DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE money_accounts DROP CONSTRAINT IF EXISTS money_accounts_tenant_id_account_number_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE money_accounts DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE permissions DROP CONSTRAINT IF EXISTS permissions_tenant_id_name_key; + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE permissions DROP COLUMN IF EXISTS tenant_id; + +ALTER TABLE permissions ADD CONSTRAINT permissions_name_key UNIQUE (name); + +-- Undo ADD_COLUMN: tenant_id +ALTER TABLE prompts DROP COLUMN IF EXISTS tenant_id; + +-- Undo CREATE_TABLE: tenants +DROP TABLE IF EXISTS tenants CASCADE; + diff --git a/modules/bichat/domain/entities/dialogue/dialogue.go b/modules/bichat/domain/entities/dialogue/dialogue.go index 7381d789..27a6e9ea 100644 --- a/modules/bichat/domain/entities/dialogue/dialogue.go +++ b/modules/bichat/domain/entities/dialogue/dialogue.go @@ -3,6 +3,7 @@ package dialogue import ( "time" + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/bichat/domain/entities/llm" ) @@ -10,6 +11,7 @@ type Messages []llm.ChatCompletionMessage type Dialogue interface { ID() uint + TenantID() uuid.UUID UserID() uint Label() string Messages() Messages diff --git a/modules/bichat/domain/entities/dialogue/dialogue_impl.go b/modules/bichat/domain/entities/dialogue/dialogue_impl.go index 375944eb..d98de182 100644 --- a/modules/bichat/domain/entities/dialogue/dialogue_impl.go +++ b/modules/bichat/domain/entities/dialogue/dialogue_impl.go @@ -3,11 +3,13 @@ package dialogue import ( "time" + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/bichat/domain/entities/llm" ) -func New(userID uint, label string) Dialogue { +func New(tenantID uuid.UUID, userID uint, label string) Dialogue { return &dialogue{ + tenantID: tenantID, userID: userID, label: label, messages: Messages{}, @@ -16,9 +18,10 @@ func New(userID uint, label string) Dialogue { } } -func NewWithID(id uint, userID uint, label string, messages Messages, createdAt, updatedAt time.Time) Dialogue { +func NewWithID(id uint, tenantID uuid.UUID, userID uint, label string, messages Messages, createdAt, updatedAt time.Time) Dialogue { return &dialogue{ id: id, + tenantID: tenantID, userID: userID, label: label, messages: messages, @@ -29,6 +32,7 @@ func NewWithID(id uint, userID uint, label string, messages Messages, createdAt, type dialogue struct { id uint + tenantID uuid.UUID userID uint label string messages Messages @@ -40,6 +44,10 @@ func (d *dialogue) ID() uint { return d.id } +func (d *dialogue) TenantID() uuid.UUID { + return d.tenantID +} + func (d *dialogue) UserID() uint { return d.userID } @@ -71,6 +79,7 @@ func (d *dialogue) UpdatedAt() time.Time { func (d *dialogue) AddMessages(messages ...llm.ChatCompletionMessage) Dialogue { return &dialogue{ id: d.id, + tenantID: d.tenantID, userID: d.userID, label: d.label, messages: append(d.messages, messages...), @@ -82,6 +91,7 @@ func (d *dialogue) AddMessages(messages ...llm.ChatCompletionMessage) Dialogue { func (d *dialogue) SetMessages(messages Messages) Dialogue { return &dialogue{ id: d.id, + tenantID: d.tenantID, userID: d.userID, label: d.label, messages: messages, @@ -95,6 +105,7 @@ func (d *dialogue) SetLastMessage(msg llm.ChatCompletionMessage) Dialogue { messages[len(messages)-1] = msg return &dialogue{ id: d.id, + tenantID: d.tenantID, userID: d.userID, label: d.label, messages: messages, diff --git a/modules/bichat/infrastructure/persistence/bichat_mappers.go b/modules/bichat/infrastructure/persistence/bichat_mappers.go index 3c4dfa04..541bf2a2 100644 --- a/modules/bichat/infrastructure/persistence/bichat_mappers.go +++ b/modules/bichat/infrastructure/persistence/bichat_mappers.go @@ -3,6 +3,7 @@ package persistence import ( "encoding/json" + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/bichat/domain/entities/dialogue" "github.com/iota-uz/iota-sdk/modules/bichat/domain/entities/llm" "github.com/iota-uz/iota-sdk/modules/bichat/infrastructure/persistence/models" @@ -31,6 +32,7 @@ func toDBDialogue(entity dialogue.Dialogue) (*models.Dialogue, error) { } return &models.Dialogue{ ID: entity.ID(), + TenantID: entity.TenantID().String(), UserID: entity.UserID(), Label: entity.Label(), Messages: dbMessages, @@ -44,8 +46,13 @@ func toDomainDialogue(dbDialogue *models.Dialogue) (dialogue.Dialogue, error) { if err != nil { return nil, err } + tenantID, err := uuid.Parse(dbDialogue.TenantID) + if err != nil { + return nil, err + } return dialogue.NewWithID( dbDialogue.ID, + tenantID, dbDialogue.UserID, dbDialogue.Label, messages, diff --git a/modules/bichat/infrastructure/persistence/dialogue_repository.go b/modules/bichat/infrastructure/persistence/dialogue_repository.go index 189e6f0e..5f092a2c 100644 --- a/modules/bichat/infrastructure/persistence/dialogue_repository.go +++ b/modules/bichat/infrastructure/persistence/dialogue_repository.go @@ -20,6 +20,7 @@ var ( const ( dialogueFindQuery = ` SELECT id, + tenant_id, user_id, label, messages, @@ -31,28 +32,33 @@ const ( dialogueInsertQuery = ` INSERT INTO dialogues ( + tenant_id, user_id, label, messages, created_at, updated_at - ) VALUES ($1, $2, $3, $4, $5) RETURNING id` + ) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id` dialogueUpdateQuery = ` UPDATE dialogues SET label = $1, messages = $2, updated_at = $3 - WHERE id = $4` + WHERE id = $4 AND tenant_id = $5` - dialogueDeleteQuery = `DELETE FROM dialogues WHERE id = $1` + dialogueDeleteQuery = `DELETE FROM dialogues WHERE id = $1 AND tenant_id = $2` ) type GormDialogueRepository struct{} func (g *GormDialogueRepository) GetByUserID(ctx context.Context, userID uint) ([]dialogue.Dialogue, error) { - //TODO implement me - panic("implement me") + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + return g.queryDialogues(ctx, dialogueFindQuery+" WHERE user_id = $1 AND tenant_id = $2", userID, tenant.ID) } func NewDialogueRepository() dialogue.Repository { @@ -60,11 +66,16 @@ func NewDialogueRepository() dialogue.Repository { } func (g *GormDialogueRepository) GetPaginated(ctx context.Context, params *dialogue.FindParams) ([]dialogue.Dialogue, error) { - var args []interface{} - where := []string{"1 = 1"} + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + where := []string{"tenant_id = $1"} + args := []interface{}{tenant.ID} if params.Query != "" && params.Field != "" { - where = append(where, fmt.Sprintf("%s::VARCHAR ILIKE $%d", params.Field, len(where))) + where = append(where, fmt.Sprintf("%s::VARCHAR ILIKE $%d", params.Field, len(args)+1)) args = append(args, "%"+params.Query+"%") } @@ -81,19 +92,35 @@ func (g *GormDialogueRepository) Count(ctx context.Context) (int64, error) { if err != nil { return 0, err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, errors.Wrap(err, "failed to get tenant from context") + } + var count int64 - if err := tx.QueryRow(ctx, dialogueCountQuery).Scan(&count); err != nil { + if err := tx.QueryRow(ctx, dialogueCountQuery+" WHERE tenant_id = $1", tenant.ID).Scan(&count); err != nil { return 0, err } return count, nil } func (g *GormDialogueRepository) GetAll(ctx context.Context) ([]dialogue.Dialogue, error) { - return g.queryDialogues(ctx, dialogueFindQuery) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + return g.queryDialogues(ctx, dialogueFindQuery+" WHERE tenant_id = $1", tenant.ID) } func (g *GormDialogueRepository) GetByID(ctx context.Context, id uint) (dialogue.Dialogue, error) { - dialogues, err := g.queryDialogues(ctx, repo.Join(dialogueFindQuery, "WHERE id = $1"), id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + dialogues, err := g.queryDialogues(ctx, repo.Join(dialogueFindQuery, "WHERE id = $1 AND tenant_id = $2"), id, tenant.ID) if err != nil { return nil, errors.Wrap(err, "failed to get dialogue by id") } @@ -108,13 +135,22 @@ func (g *GormDialogueRepository) Create(ctx context.Context, d dialogue.Dialogue if err != nil { return nil, err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + dbDialogue, err := toDBDialogue(d) if err != nil { return nil, err } + dbDialogue.TenantID = tenant.ID.String() + row := tx.QueryRow( ctx, dialogueInsertQuery, + dbDialogue.TenantID, dbDialogue.UserID, dbDialogue.Label, dbDialogue.Messages, @@ -131,10 +167,17 @@ func (g *GormDialogueRepository) Create(ctx context.Context, d dialogue.Dialogue } func (g *GormDialogueRepository) Update(ctx context.Context, d dialogue.Dialogue) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return errors.Wrap(err, "failed to get tenant from context") + } + dbDialogue, err := toDBDialogue(d) if err != nil { return err } + dbDialogue.TenantID = tenant.ID.String() + return g.execQuery( ctx, dialogueUpdateQuery, @@ -142,11 +185,17 @@ func (g *GormDialogueRepository) Update(ctx context.Context, d dialogue.Dialogue dbDialogue.Messages, dbDialogue.UpdatedAt, dbDialogue.ID, + dbDialogue.TenantID, ) } func (g *GormDialogueRepository) Delete(ctx context.Context, id uint) error { - return g.execQuery(ctx, dialogueDeleteQuery, id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return errors.Wrap(err, "failed to get tenant from context") + } + + return g.execQuery(ctx, dialogueDeleteQuery, id, tenant.ID) } func (g *GormDialogueRepository) queryDialogues(ctx context.Context, query string, args ...interface{}) ([]dialogue.Dialogue, error) { @@ -166,6 +215,7 @@ func (g *GormDialogueRepository) queryDialogues(ctx context.Context, query strin var d models.Dialogue if err := rows.Scan( &d.ID, + &d.TenantID, &d.UserID, &d.Label, &d.Messages, diff --git a/modules/bichat/infrastructure/persistence/models/models.go b/modules/bichat/infrastructure/persistence/models/models.go index 8c379877..a82e0aa1 100644 --- a/modules/bichat/infrastructure/persistence/models/models.go +++ b/modules/bichat/infrastructure/persistence/models/models.go @@ -14,6 +14,7 @@ type Prompt struct { type Dialogue struct { ID uint + TenantID string UserID uint Label string Messages string diff --git a/modules/bichat/infrastructure/persistence/schema/bichat-schema.sql b/modules/bichat/infrastructure/persistence/schema/bichat-schema.sql index f13b3e29..614803e5 100644 --- a/modules/bichat/infrastructure/persistence/schema/bichat-schema.sql +++ b/modules/bichat/infrastructure/persistence/schema/bichat-schema.sql @@ -1,5 +1,6 @@ CREATE TABLE prompts ( id varchar(30) PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, title varchar(255) NOT NULL, description text NOT NULL, prompt text NOT NULL, @@ -8,6 +9,7 @@ CREATE TABLE prompts ( CREATE TABLE dialogues ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, user_id int NOT NULL REFERENCES users (id) ON DELETE CASCADE, label varchar(255) NOT NULL, messages json NOT NULL, @@ -17,3 +19,7 @@ CREATE TABLE dialogues ( CREATE INDEX dialogues_user_id_idx ON dialogues (user_id); +CREATE INDEX dialogues_tenant_id_idx ON dialogues (tenant_id); + +CREATE INDEX prompts_tenant_id_idx ON prompts (tenant_id); + diff --git a/modules/bichat/services/dialogue_service.go b/modules/bichat/services/dialogue_service.go index 2f3ec017..55209691 100644 --- a/modules/bichat/services/dialogue_service.go +++ b/modules/bichat/services/dialogue_service.go @@ -249,7 +249,12 @@ func (s *DialogueService) StartDialogue(ctx context.Context, message string, mod if err != nil { return nil, err } + tenant, err := localComposables.UseTenant(ctx) + if err != nil { + return nil, err + } data := dialogue.New( + tenant.ID, u.ID(), "Новый чат", ).AddMessages( diff --git a/modules/core/domain/aggregates/group/group.go b/modules/core/domain/aggregates/group/group.go index 0c18039f..75fc4b4b 100644 --- a/modules/core/domain/aggregates/group/group.go +++ b/modules/core/domain/aggregates/group/group.go @@ -14,6 +14,7 @@ type Option func(g *group) type Group interface { ID() uuid.UUID + TenantID() uuid.UUID Name() string Description() string Users() []user.User @@ -28,6 +29,7 @@ type Group interface { SetRoles(roles []role.Role) Group SetName(name string) Group SetDescription(desc string) Group + SetTenantID(tenantID uuid.UUID) Group } // ---- Implementations ---- @@ -38,6 +40,12 @@ func WithID(id uuid.UUID) Option { } } +func WithTenantID(tenantID uuid.UUID) Option { + return func(g *group) { + g.tenantID = tenantID + } +} + func WithDescription(desc string) Option { return func(g *group) { g.description = desc @@ -71,6 +79,7 @@ func WithUpdatedAt(t time.Time) Option { func New(name string, opts ...Option) Group { g := &group{ id: uuid.New(), + tenantID: uuid.Nil, name: name, createdAt: time.Now(), updatedAt: time.Now(), @@ -84,6 +93,7 @@ func New(name string, opts ...Option) Group { type group struct { id uuid.UUID + tenantID uuid.UUID name string description string roles []role.Role @@ -96,6 +106,10 @@ func (g *group) ID() uuid.UUID { return g.id } +func (g *group) TenantID() uuid.UUID { + return g.tenantID +} + func (g *group) Name() string { return g.name } @@ -134,6 +148,13 @@ func (g *group) SetDescription(desc string) Group { return &r } +func (g *group) SetTenantID(tenantID uuid.UUID) Group { + r := *g + r.tenantID = tenantID + r.updatedAt = time.Now() + return &r +} + func (g *group) AssignRole(r role.Role) Group { res := *g res.roles = append(res.roles, r) diff --git a/modules/core/domain/aggregates/group/group_repository.go b/modules/core/domain/aggregates/group/group_repository.go index d3c73b5a..218d23c5 100644 --- a/modules/core/domain/aggregates/group/group_repository.go +++ b/modules/core/domain/aggregates/group/group_repository.go @@ -12,6 +12,7 @@ type Field = int const ( CreatedAt Field = iota UpdatedAt + TenantID ) type SortBy repo.SortBy[Field] diff --git a/modules/core/domain/aggregates/role/role.go b/modules/core/domain/aggregates/role/role.go index 14ac1915..a2cf0836 100644 --- a/modules/core/domain/aggregates/role/role.go +++ b/modules/core/domain/aggregates/role/role.go @@ -3,6 +3,7 @@ package role import ( "time" + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/core/domain/entities/permission" ) @@ -34,6 +35,7 @@ func WithUpdatedAt(t time.Time) Option { type Role interface { ID() uint + TenantID() uuid.UUID Name() string Description() string Permissions() []*permission.Permission @@ -42,6 +44,7 @@ type Role interface { SetName(name string) Role SetDescription(description string) Role + SetTenantID(tenantID uuid.UUID) Role AddPermission(p *permission.Permission) Role SetPermissions(permissions []*permission.Permission) Role @@ -54,12 +57,19 @@ func WithDescription(description string) Option { } } +func WithTenantID(tenantID uuid.UUID) Option { + return func(r *role) { + r.tenantID = tenantID + } +} + func New( name string, opts ...Option, ) Role { r := &role{ id: 0, + tenantID: uuid.Nil, name: name, description: "", permissions: []*permission.Permission{}, @@ -74,6 +84,7 @@ func New( type role struct { id uint + tenantID uuid.UUID name string description string permissions []*permission.Permission @@ -85,6 +96,10 @@ func (r *role) ID() uint { return r.id } +func (r *role) TenantID() uuid.UUID { + return r.tenantID +} + func (r *role) Name() string { return r.name } @@ -119,6 +134,13 @@ func (r *role) SetDescription(description string) Role { return &result } +func (r *role) SetTenantID(tenantID uuid.UUID) Role { + result := *r + result.tenantID = tenantID + result.updatedAt = time.Now() + return &result +} + func (r *role) AddPermission(p *permission.Permission) Role { result := *r result.permissions = append(result.permissions, p) diff --git a/modules/core/domain/aggregates/role/role_repository.go b/modules/core/domain/aggregates/role/role_repository.go index 55441880..10a79f28 100644 --- a/modules/core/domain/aggregates/role/role_repository.go +++ b/modules/core/domain/aggregates/role/role_repository.go @@ -12,6 +12,7 @@ const ( Description CreatedAt PermissionID + TenantID ) type SortBy repo.SortBy[Field] diff --git a/modules/core/domain/aggregates/user/user.go b/modules/core/domain/aggregates/user/user.go index 41e78c22..3538b689 100644 --- a/modules/core/domain/aggregates/user/user.go +++ b/modules/core/domain/aggregates/user/user.go @@ -26,6 +26,12 @@ func WithID(id uint) Option { } } +func WithTenantID(id uuid.UUID) Option { + return func(u *user) { + u.tenantID = id + } +} + func WithMiddleName(middleName string) Option { return func(u *user) { u.middleName = middleName @@ -111,6 +117,7 @@ func WithPhone(p phone.Phone) Option { type User interface { ID() uint + TenantID() uuid.UUID FirstName() string LastName() string MiddleName() string @@ -160,6 +167,7 @@ func New( ) User { u := &user{ id: 0, + tenantID: uuid.Nil, firstName: firstName, lastName: lastName, middleName: "", @@ -186,6 +194,7 @@ func New( type user struct { id uint + tenantID uuid.UUID firstName string lastName string middleName string @@ -209,6 +218,10 @@ func (u *user) ID() uint { return u.id } +func (u *user) TenantID() uuid.UUID { + return u.tenantID +} + func (u *user) FirstName() string { return u.firstName } diff --git a/modules/core/domain/aggregates/user/user_repository.go b/modules/core/domain/aggregates/user/user_repository.go index f61028ab..8a535623 100644 --- a/modules/core/domain/aggregates/user/user_repository.go +++ b/modules/core/domain/aggregates/user/user_repository.go @@ -20,6 +20,7 @@ const ( LastLogin CreatedAt UpdatedAt + TenantID ) type SortBy repo.SortBy[Field] diff --git a/modules/core/domain/entities/authlog/authlog.go b/modules/core/domain/entities/authlog/authlog.go index 71e25582..626e794e 100644 --- a/modules/core/domain/entities/authlog/authlog.go +++ b/modules/core/domain/entities/authlog/authlog.go @@ -2,10 +2,13 @@ package authlog import ( "time" + + "github.com/google/uuid" ) type AuthenticationLog struct { ID uint + TenantID uuid.UUID UserID uint IP string UserAgent string diff --git a/modules/core/domain/entities/passport/passport.go b/modules/core/domain/entities/passport/passport.go index f9eaa7ef..05763458 100644 --- a/modules/core/domain/entities/passport/passport.go +++ b/modules/core/domain/entities/passport/passport.go @@ -9,6 +9,7 @@ import ( type Passport interface { ID() uuid.UUID + TenantID() uuid.UUID Series() string Number() string Identifier() string // Series + Number @@ -39,6 +40,12 @@ func WithID(id uuid.UUID) Option { } } +func WithTenantID(tenantID uuid.UUID) Option { + return func(p *passport) { + p.tenantID = tenantID + } +} + func WithFullName(firstName, lastName, middleName string) Option { return func(p *passport) { p.firstName = firstName @@ -135,6 +142,7 @@ func New(series, number string, opts ...Option) Passport { type passport struct { id uuid.UUID + tenantID uuid.UUID firstName string lastName string middleName string @@ -159,6 +167,10 @@ func (p *passport) ID() uuid.UUID { return p.id } +func (p *passport) TenantID() uuid.UUID { + return p.tenantID +} + func (p *passport) Series() string { return p.series } diff --git a/modules/core/domain/entities/permission/permission.go b/modules/core/domain/entities/permission/permission.go index 559acd13..d2a415a4 100644 --- a/modules/core/domain/entities/permission/permission.go +++ b/modules/core/domain/entities/permission/permission.go @@ -18,6 +18,7 @@ const ( type Permission struct { ID uuid.UUID + TenantID uuid.UUID Name string Resource Resource Action Action diff --git a/modules/core/domain/entities/session/session.go b/modules/core/domain/entities/session/session.go index 1881e281..97e599e0 100644 --- a/modules/core/domain/entities/session/session.go +++ b/modules/core/domain/entities/session/session.go @@ -1,13 +1,16 @@ package session import ( - "github.com/iota-uz/iota-sdk/pkg/configuration" "time" + + "github.com/google/uuid" + "github.com/iota-uz/iota-sdk/pkg/configuration" ) type Session struct { Token string `gorm:"primaryKey"` UserID uint + TenantID uuid.UUID IP string UserAgent string ExpiresAt time.Time @@ -17,6 +20,7 @@ type Session struct { type CreateDTO struct { Token string UserID uint + TenantID uuid.UUID IP string UserAgent string } @@ -25,6 +29,7 @@ func (d *CreateDTO) ToEntity() *Session { return &Session{ Token: d.Token, UserID: d.UserID, + TenantID: d.TenantID, IP: d.IP, UserAgent: d.UserAgent, ExpiresAt: time.Now().Add(configuration.Use().SessionDuration), diff --git a/modules/core/domain/entities/tab/tab.go b/modules/core/domain/entities/tab/tab.go index 47f2406f..836446d3 100644 --- a/modules/core/domain/entities/tab/tab.go +++ b/modules/core/domain/entities/tab/tab.go @@ -1,8 +1,13 @@ package tab +import ( + "github.com/google/uuid" +) + type Tab struct { ID uint Href string UserID uint Position uint + TenantID uuid.UUID } diff --git a/modules/core/domain/entities/tenant/tenant.go b/modules/core/domain/entities/tenant/tenant.go new file mode 100644 index 00000000..ed2dc714 --- /dev/null +++ b/modules/core/domain/entities/tenant/tenant.go @@ -0,0 +1,86 @@ +package tenant + +import ( + "time" + + "github.com/google/uuid" +) + +type Tenant struct { + id uuid.UUID + name string + domain string + isActive bool + createdAt time.Time + updatedAt time.Time +} + +type Option func(*Tenant) + +func WithID(id uuid.UUID) Option { + return func(t *Tenant) { + t.id = id + } +} + +func WithDomain(domain string) Option { + return func(t *Tenant) { + t.domain = domain + } +} + +func WithIsActive(isActive bool) Option { + return func(t *Tenant) { + t.isActive = isActive + } +} + +func WithCreatedAt(createdAt time.Time) Option { + return func(t *Tenant) { + t.createdAt = createdAt + } +} + +func WithUpdatedAt(updatedAt time.Time) Option { + return func(t *Tenant) { + t.updatedAt = updatedAt + } +} + +func New(name string, opts ...Option) *Tenant { + t := &Tenant{ + id: uuid.New(), + name: name, + isActive: true, + createdAt: time.Now(), + updatedAt: time.Now(), + } + for _, opt := range opts { + opt(t) + } + return t +} + +func (t *Tenant) ID() uuid.UUID { + return t.id +} + +func (t *Tenant) Name() string { + return t.name +} + +func (t *Tenant) Domain() string { + return t.domain +} + +func (t *Tenant) IsActive() bool { + return t.isActive +} + +func (t *Tenant) CreatedAt() time.Time { + return t.createdAt +} + +func (t *Tenant) UpdatedAt() time.Time { + return t.updatedAt +} diff --git a/modules/core/domain/entities/tenant/tenant_repository.go b/modules/core/domain/entities/tenant/tenant_repository.go new file mode 100644 index 00000000..c7580925 --- /dev/null +++ b/modules/core/domain/entities/tenant/tenant_repository.go @@ -0,0 +1,16 @@ +package tenant + +import ( + "context" + + "github.com/google/uuid" +) + +type Repository interface { + GetByID(ctx context.Context, id uuid.UUID) (*Tenant, error) + GetByDomain(ctx context.Context, domain string) (*Tenant, error) + Create(ctx context.Context, tenant *Tenant) (*Tenant, error) + Update(ctx context.Context, tenant *Tenant) (*Tenant, error) + Delete(ctx context.Context, id uuid.UUID) error + List(ctx context.Context) ([]*Tenant, error) +} diff --git a/modules/core/domain/entities/upload/upload.go b/modules/core/domain/entities/upload/upload.go index b9262dd3..29825c9b 100644 --- a/modules/core/domain/entities/upload/upload.go +++ b/modules/core/domain/entities/upload/upload.go @@ -8,6 +8,7 @@ import ( "time" "github.com/gabriel-vasile/mimetype" + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/core/presentation/assets" "github.com/iota-uz/iota-sdk/pkg/configuration" @@ -38,6 +39,7 @@ type Size interface { type Upload interface { ID() uint + TenantID() uuid.UUID Type() UploadType Hash() string Path() string @@ -66,6 +68,7 @@ func New( } return &upload{ id: 0, + tenantID: uuid.Nil, hash: hash, path: path, name: name, @@ -79,6 +82,7 @@ func New( func NewWithID( id uint, + tenantID uuid.UUID, hash, path, name string, size int, mimetype *mimetype.MIME, @@ -87,6 +91,7 @@ func NewWithID( ) Upload { return &upload{ id: id, + tenantID: tenantID, hash: hash, path: path, name: name, @@ -100,6 +105,7 @@ func NewWithID( type upload struct { id uint + tenantID uuid.UUID hash string path string name string @@ -114,6 +120,10 @@ func (u *upload) ID() uint { return u.id } +func (u *upload) TenantID() uuid.UUID { + return u.tenantID +} + func (u *upload) Type() UploadType { return u._type } diff --git a/modules/core/infrastructure/persistence/authlog_repository.go b/modules/core/infrastructure/persistence/authlog_repository.go index 8e4c57d4..473cab8c 100644 --- a/modules/core/infrastructure/persistence/authlog_repository.go +++ b/modules/core/infrastructure/persistence/authlog_repository.go @@ -4,9 +4,10 @@ import ( "context" "errors" "fmt" - "github.com/iota-uz/iota-sdk/pkg/repo" "strings" + "github.com/iota-uz/iota-sdk/pkg/repo" + "github.com/iota-uz/iota-sdk/modules/core/domain/entities/authlog" "github.com/iota-uz/iota-sdk/modules/core/infrastructure/persistence/models" "github.com/iota-uz/iota-sdk/pkg/composables" @@ -38,8 +39,14 @@ func (g *GormAuthLogRepository) GetPaginated( where, args = append(where, fmt.Sprintf("user_id = $%d", len(args)+1)), append(args, params.UserID) } + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, err + } + where, args = append(where, fmt.Sprintf("tenant_id = $%d", len(args)+1)), append(args, tenant.ID) + rows, err := pool.Query(ctx, ` - SELECT id, user_id, ip, user_agent, created_at + SELECT id, user_id, ip, user_agent, created_at, tenant_id FROM authentication_logs WHERE `+strings.Join(where, " AND ")+` ORDER BY id DESC @@ -59,6 +66,7 @@ func (g *GormAuthLogRepository) GetPaginated( &log.IP, &log.UserAgent, &log.CreatedAt, + &log.TenantID, ); err != nil { return nil, err } @@ -75,10 +83,16 @@ func (g *GormAuthLogRepository) Count(ctx context.Context) (int64, error) { if err != nil { return 0, err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, err + } + var count int64 if err := pool.QueryRow(ctx, ` - SELECT COUNT(*) as count FROM authentication_logs - `).Scan(&count); err != nil { + SELECT COUNT(*) as count FROM authentication_logs WHERE tenant_id = $1 + `, tenant.ID).Scan(&count); err != nil { return 0, err } return count, nil @@ -108,10 +122,18 @@ func (g *GormAuthLogRepository) Create(ctx context.Context, data *authlog.Authen if err != nil { return err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return err + } + dbRow := toDBAuthenticationLog(data) + dbRow.TenantID = tenant.ID.String() + if err := tx.QueryRow(ctx, ` - INSERT INTO authentication_logs (user_id, ip, user_agent) VALUES ($1, $2, $3) - `, dbRow.UserID, dbRow.IP, dbRow.UserAgent).Scan(&data.ID); err != nil { + INSERT INTO authentication_logs (user_id, ip, user_agent, tenant_id) VALUES ($1, $2, $3, $4) + `, dbRow.UserID, dbRow.IP, dbRow.UserAgent, dbRow.TenantID).Scan(&data.ID); err != nil { return err } return nil diff --git a/modules/core/infrastructure/persistence/core_mappers.go b/modules/core/infrastructure/persistence/core_mappers.go index fd4d4873..aff308be 100644 --- a/modules/core/infrastructure/persistence/core_mappers.go +++ b/modules/core/infrastructure/persistence/core_mappers.go @@ -6,6 +6,7 @@ import ( "time" "github.com/gabriel-vasile/mimetype" + "github.com/go-faster/errors" "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/core/domain/aggregates/group" @@ -30,7 +31,11 @@ import ( func ToDomainUser(dbUser *models.User, dbUpload *models.Upload, roles []role.Role, groupIDs []uuid.UUID, permissions []*permission.Permission) (user.User, error) { var avatar upload.Upload if dbUpload != nil { - avatar = ToDomainUpload(dbUpload) + var err error + avatar, err = ToDomainUpload(dbUpload) + if err != nil { + return nil, err + } } email, err := internet.NewEmail(dbUser.Email) @@ -40,6 +45,7 @@ func ToDomainUser(dbUser *models.User, dbUpload *models.Upload, roles []role.Rol options := []user.Option{ user.WithID(dbUser.ID), + user.WithTenantID(uuid.MustParse(dbUser.TenantID)), user.WithMiddleName(dbUser.MiddleName.String), user.WithPassword(dbUser.Password.String), user.WithRoles(roles), @@ -91,6 +97,7 @@ func toDBUser(entity user.User) (*models.User, []*models.Role) { return &models.User{ ID: entity.ID(), + TenantID: entity.TenantID().String(), FirstName: entity.FirstName(), LastName: entity.LastName(), MiddleName: mapping.ValueToSQLNullString(entity.MiddleName()), @@ -112,16 +119,23 @@ func toDomainRole(dbRole *models.Role, permissions []*models.Permission) (role.R for i, p := range permissions { dP, err := toDomainPermission(p) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to cast to domain permission") } domainPermissions[i] = dP } + + tenantID, err := uuid.Parse(dbRole.TenantID) + if err != nil { + return nil, errors.Wrap(err, "failed to parse uuid") + } + options := []role.Option{ role.WithID(dbRole.ID), role.WithDescription(dbRole.Description.String), role.WithPermissions(domainPermissions), role.WithCreatedAt(dbRole.CreatedAt), role.WithUpdatedAt(dbRole.UpdatedAt), + role.WithTenantID(tenantID), } return role.New(dbRole.Name, options...), nil } @@ -133,6 +147,7 @@ func toDBRole(entity role.Role) (*models.Role, []*models.Permission) { } return &models.Role{ ID: entity.ID(), + TenantID: entity.TenantID().String(), Name: entity.Name(), Description: mapping.ValueToSQLNullString(entity.Description()), CreatedAt: entity.CreatedAt(), @@ -143,6 +158,7 @@ func toDBRole(entity role.Role) (*models.Role, []*models.Permission) { func toDBPermission(entity *permission.Permission) *models.Permission { return &models.Permission{ ID: entity.ID.String(), + TenantID: entity.TenantID.String(), Name: entity.Name, Resource: string(entity.Resource), Action: string(entity.Action), @@ -155,8 +171,15 @@ func toDomainPermission(dbPermission *models.Permission) (*permission.Permission if err != nil { return nil, err } + + tenantID, err := uuid.Parse(dbPermission.TenantID) + if err != nil { + return nil, errors.Wrap(err, "failed to parse uuid") + } + return &permission.Permission{ ID: id, + TenantID: tenantID, Name: dbPermission.Name, Resource: permission.Resource(dbPermission.Resource), Action: permission.Action(dbPermission.Action), @@ -181,6 +204,7 @@ func ToDomainTin(s sql.NullString, c country.Country) (tax.Tin, error) { func ToDBUpload(upload upload.Upload) *models.Upload { return &models.Upload{ ID: upload.ID(), + TenantID: upload.TenantID().String(), Path: upload.Path(), Hash: upload.Hash(), Name: upload.Name(), @@ -192,13 +216,20 @@ func ToDBUpload(upload upload.Upload) *models.Upload { } } -func ToDomainUpload(dbUpload *models.Upload) upload.Upload { +func ToDomainUpload(dbUpload *models.Upload) (upload.Upload, error) { var mime *mimetype.MIME if dbUpload.Mimetype != "" { mime = mimetype.Lookup(dbUpload.Mimetype) } + + tenantID, err := uuid.Parse(dbUpload.TenantID) + if err != nil { + return nil, err + } + return upload.NewWithID( dbUpload.ID, + tenantID, dbUpload.Hash, dbUpload.Path, dbUpload.Name, @@ -207,7 +238,7 @@ func ToDomainUpload(dbUpload *models.Upload) upload.Upload { upload.UploadType(dbUpload.Type), dbUpload.CreatedAt, dbUpload.UpdatedAt, - ) + ), nil } func ToDBCurrency(entity *currency.Currency) *models.Currency { @@ -242,21 +273,29 @@ func ToDBTab(tab *tab.Tab) *models.Tab { Href: tab.Href, Position: tab.Position, UserID: tab.UserID, + TenantID: tab.TenantID.String(), } } func ToDomainTab(dbTab *models.Tab) (*tab.Tab, error) { + tenantID, err := uuid.Parse(dbTab.TenantID) + if err != nil { + return nil, err + } + return &tab.Tab{ ID: dbTab.ID, Href: dbTab.Href, Position: dbTab.Position, UserID: dbTab.UserID, + TenantID: tenantID, }, nil } -func toDBSession(session *session.Session) *models.Session { +func ToDBSession(session *session.Session) *models.Session { return &models.Session{ UserID: session.UserID, + TenantID: session.TenantID.String(), Token: session.Token, IP: session.IP, UserAgent: session.UserAgent, @@ -265,9 +304,15 @@ func toDBSession(session *session.Session) *models.Session { } } -func toDomainSession(dbSession *models.Session) *session.Session { +func ToDomainSession(dbSession *models.Session) *session.Session { + tenantID, err := uuid.Parse(dbSession.TenantID) + if err != nil { + tenantID = uuid.Nil + } + return &session.Session{ UserID: dbSession.UserID, + TenantID: tenantID, Token: dbSession.Token, IP: dbSession.IP, UserAgent: dbSession.UserAgent, @@ -279,6 +324,7 @@ func toDomainSession(dbSession *models.Session) *session.Session { func toDBAuthenticationLog(log *authlog.AuthenticationLog) *models.AuthenticationLog { return &models.AuthenticationLog{ ID: log.ID, + TenantID: log.TenantID.String(), UserID: log.UserID, IP: log.IP, UserAgent: log.UserAgent, @@ -287,8 +333,14 @@ func toDBAuthenticationLog(log *authlog.AuthenticationLog) *models.Authenticatio } func toDomainAuthenticationLog(dbLog *models.AuthenticationLog) *authlog.AuthenticationLog { + tenantID, err := uuid.Parse(dbLog.TenantID) + if err != nil { + tenantID = uuid.Nil + } + return &authlog.AuthenticationLog{ ID: dbLog.ID, + TenantID: tenantID, UserID: dbLog.UserID, IP: dbLog.IP, UserAgent: dbLog.UserAgent, @@ -302,8 +354,15 @@ func ToDomainPassport(dbPassport *models.Passport) (passport.Passport, error) { if err != nil { return nil, err } + + tenantID, err := uuid.Parse(dbPassport.TenantID) + if err != nil { + return nil, err + } + opts := []passport.Option{ passport.WithID(id), + passport.WithTenantID(tenantID), } if dbPassport.FirstName.Valid || dbPassport.LastName.Valid || dbPassport.MiddleName.Valid { @@ -412,6 +471,7 @@ func ToDBPassport(passportEntity passport.Passport) (*models.Passport, error) { return &models.Passport{ ID: passportEntity.ID().String(), + TenantID: passportEntity.TenantID().String(), FirstName: mapping.ValueToSQLNullString(passportEntity.FirstName()), LastName: mapping.ValueToSQLNullString(passportEntity.LastName()), MiddleName: mapping.ValueToSQLNullString(passportEntity.MiddleName()), @@ -441,8 +501,14 @@ func ToDomainGroup(dbGroup *models.Group, users []user.User, roles []role.Role) return nil, err } + tenantID, err := uuid.Parse(dbGroup.TenantID) + if err != nil { + return nil, err + } + opts := []group.Option{ group.WithID(groupID), + group.WithTenantID(tenantID), group.WithDescription(dbGroup.Description.String), group.WithUsers(users), group.WithRoles(roles), @@ -456,6 +522,7 @@ func ToDomainGroup(dbGroup *models.Group, users []user.User, roles []role.Role) func ToDBGroup(g group.Group) *models.Group { return &models.Group{ ID: g.ID().String(), + TenantID: g.TenantID().String(), Name: g.Name(), Description: mapping.ValueToSQLNullString(g.Description()), CreatedAt: g.CreatedAt(), diff --git a/modules/core/infrastructure/persistence/group_repository.go b/modules/core/infrastructure/persistence/group_repository.go index af6ced14..35b21df7 100644 --- a/modules/core/infrastructure/persistence/group_repository.go +++ b/modules/core/infrastructure/persistence/group_repository.go @@ -25,12 +25,13 @@ const ( g.name, g.description, g.created_at, - g.updated_at + g.updated_at, + g.tenant_id FROM user_groups g` groupCountQuery = `SELECT COUNT(DISTINCT g.id) FROM user_groups g` - groupDeleteQuery = `DELETE FROM user_groups WHERE id = $1` + groupDeleteQuery = `DELETE FROM user_groups WHERE id = $1 AND tenant_id = $2` groupUserDeleteQuery = `DELETE FROM group_users WHERE group_id = $1` groupRoleDeleteQuery = `DELETE FROM group_roles WHERE group_id = $1` groupUserInsertQuery = `INSERT INTO group_users (group_id, user_id) VALUES` @@ -50,6 +51,7 @@ func NewGroupRepository(userRepo user.Repository, roleRepo role.Repository) grou fieldMap: map[group.Field]string{ group.CreatedAt: "g.created_at", group.UpdatedAt: "g.updated_at", + group.TenantID: "g.tenant_id", }, } } @@ -135,7 +137,12 @@ func (g *PgGroupRepository) Count(ctx context.Context, params *group.FindParams) } func (g *PgGroupRepository) GetByID(ctx context.Context, id uuid.UUID) (group.Group, error) { - groups, err := g.queryGroups(ctx, groupFindQuery+" WHERE g.id = $1", id.String()) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + groups, err := g.queryGroups(ctx, groupFindQuery+" WHERE g.id = $1 AND g.tenant_id = $2", id.String(), tenant.ID) if err != nil { return nil, errors.Wrap(err, fmt.Sprintf("failed to query group with id: %s", id.String())) } @@ -151,8 +158,14 @@ func (g *PgGroupRepository) Exists(ctx context.Context, id uuid.UUID) (bool, err return false, errors.Wrap(err, "failed to get transaction") } + tenant, err := composables.UseTenant(ctx) + if err != nil { + return false, errors.Wrap(err, "failed to get tenant from context") + } + var exists bool - err = tx.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM user_groups WHERE id = $1)", id.String()).Scan(&exists) + err = tx.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM user_groups WHERE id = $1 AND tenant_id = $2)", + id.String(), tenant.ID).Scan(&exists) if err != nil { return false, errors.Wrap(err, "failed to check if group exists") } @@ -177,6 +190,11 @@ func (g *PgGroupRepository) create(ctx context.Context, entity group.Group) (gro return nil, errors.Wrap(err, "failed to get transaction") } + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + // Generate a new UUID if not provided var groupID uuid.UUID if entity.ID() == uuid.Nil { @@ -187,6 +205,7 @@ func (g *PgGroupRepository) create(ctx context.Context, entity group.Group) (gro dbGroup := ToDBGroup(entity) dbGroup.ID = groupID.String() + dbGroup.TenantID = tenant.ID.String() fields := []string{ "id", @@ -194,6 +213,7 @@ func (g *PgGroupRepository) create(ctx context.Context, entity group.Group) (gro "description", "created_at", "updated_at", + "tenant_id", } values := []interface{}{ @@ -202,6 +222,7 @@ func (g *PgGroupRepository) create(ctx context.Context, entity group.Group) (gro dbGroup.Description, dbGroup.CreatedAt, dbGroup.UpdatedAt, + dbGroup.TenantID, } _, err = tx.Exec(ctx, repo.Insert("user_groups", fields), values...) @@ -272,6 +293,11 @@ func (g *PgGroupRepository) update(ctx context.Context, entity group.Group) (gro func (g *PgGroupRepository) Delete(ctx context.Context, id uuid.UUID) error { uuidStr := id.String() + tenant, err := composables.UseTenant(ctx) + if err != nil { + return errors.Wrap(err, "failed to get tenant from context") + } + if err := g.execQuery(ctx, groupUserDeleteQuery, uuidStr); err != nil { return errors.Wrap(err, fmt.Sprintf("failed to delete users for group ID: %s", uuidStr)) } @@ -280,7 +306,7 @@ func (g *PgGroupRepository) Delete(ctx context.Context, id uuid.UUID) error { return errors.Wrap(err, fmt.Sprintf("failed to delete roles for group ID: %s", uuidStr)) } - if err := g.execQuery(ctx, groupDeleteQuery, uuidStr); err != nil { + if err := g.execQuery(ctx, groupDeleteQuery, uuidStr, tenant.ID); err != nil { return errors.Wrap(err, fmt.Sprintf("failed to delete group with ID: %s", uuidStr)) } @@ -309,6 +335,7 @@ func (g *PgGroupRepository) queryGroups(ctx context.Context, query string, args &dbGroup.Description, &dbGroup.CreatedAt, &dbGroup.UpdatedAt, + &dbGroup.TenantID, ); err != nil { return nil, errors.Wrap(err, "failed to scan group row") } diff --git a/modules/core/infrastructure/persistence/models/models.go b/modules/core/infrastructure/persistence/models/models.go index 9b163950..ae48f1e1 100644 --- a/modules/core/infrastructure/persistence/models/models.go +++ b/modules/core/infrastructure/persistence/models/models.go @@ -5,8 +5,18 @@ import ( "time" ) +type Tenant struct { + ID string + Name string + Domain sql.NullString + IsActive bool + CreatedAt time.Time + UpdatedAt time.Time +} + type Upload struct { ID uint + TenantID string Hash string Path string Name string @@ -27,6 +37,7 @@ type Currency struct { type Company struct { ID uint + TenantID string Name string About string Address string @@ -39,6 +50,7 @@ type Company struct { type Permission struct { ID string + TenantID string Name string Resource string Action string @@ -53,6 +65,7 @@ type RolePermission struct { type Role struct { ID uint + TenantID string Name string Description sql.NullString CreatedAt time.Time @@ -61,6 +74,7 @@ type Role struct { type User struct { ID uint + TenantID string // UUID stored as string FirstName string LastName string MiddleName sql.NullString @@ -95,6 +109,7 @@ type UploadedImage struct { type Session struct { Token string + TenantID string // UUID stored as string UserID uint ExpiresAt time.Time IP string @@ -104,6 +119,7 @@ type Session struct { type AuthenticationLog struct { ID uint + TenantID string // UUID stored as string UserID uint IP string UserAgent string @@ -112,6 +128,7 @@ type AuthenticationLog struct { type Tab struct { ID uint + TenantID string Href string Position uint UserID uint @@ -119,6 +136,7 @@ type Tab struct { type Passport struct { ID string + TenantID string FirstName sql.NullString LastName sql.NullString MiddleName sql.NullString @@ -143,6 +161,7 @@ type Passport struct { type Group struct { ID string + TenantID string Name string Description sql.NullString CreatedAt time.Time diff --git a/modules/core/infrastructure/persistence/passport_repository.go b/modules/core/infrastructure/persistence/passport_repository.go index 7c3506b8..d3f79159 100644 --- a/modules/core/infrastructure/persistence/passport_repository.go +++ b/modules/core/infrastructure/persistence/passport_repository.go @@ -18,7 +18,7 @@ var ( const ( selectPassportQuery = ` - SELECT + SELECT id, first_name, last_name, @@ -39,7 +39,8 @@ const ( signature_image, remarks, created_at, - updated_at + updated_at, + tenant_id FROM passports ` insertPassportQuery = ` @@ -62,19 +63,20 @@ const ( machine_readable_zone, biometric_data, signature_image, - remarks - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19) - ON CONFLICT (passport_number, series) DO UPDATE SET + remarks, + tenant_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20) + ON CONFLICT (tenant_id, passport_number, series) DO UPDATE SET first_name = COALESCE(NULLIF(EXCLUDED.first_name, ''), passports.first_name), last_name = COALESCE(NULLIF(EXCLUDED.last_name, ''), passports.last_name), middle_name = COALESCE(NULLIF(EXCLUDED.middle_name, ''), passports.middle_name) RETURNING id ` updatePassportQuery = ` - UPDATE passports + UPDATE passports SET first_name = $1, - last_name = $2, - middle_name = $3, + last_name = $2, + middle_name = $3, gender = $4, birth_date = $5, birth_place = $6, @@ -140,6 +142,7 @@ func (r *PassportRepository) queryPassports(ctx context.Context, query string, a &p.Remarks, &p.CreatedAt, &p.UpdatedAt, + &p.TenantID, ); err != nil { return nil, err } @@ -169,8 +172,14 @@ func (r *PassportRepository) exists(ctx context.Context, id string) (bool, error return false, err } + tenant, err := composables.UseTenant(ctx) + if err != nil { + return false, err + } + var exists bool - err = pool.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM passports WHERE id = $1)", id).Scan(&exists) + err = pool.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM passports WHERE id = $1 AND tenant_id = $2)", + id, tenant.ID.String()).Scan(&exists) if err != nil { return false, err } @@ -198,11 +207,18 @@ func (r *PassportRepository) Create(ctx context.Context, data passport.Passport) return nil, err } + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, err + } + dbRow, err := ToDBPassport(data) if err != nil { return nil, fmt.Errorf("failed to convert passport to db model: %w", err) } + dbRow.TenantID = tenant.ID.String() + var id string err = pool.QueryRow( ctx, @@ -226,6 +242,7 @@ func (r *PassportRepository) Create(ctx context.Context, data passport.Passport) dbRow.BiometricData, dbRow.SignatureImage, dbRow.Remarks, + dbRow.TenantID, ).Scan(&id) if err != nil { return nil, fmt.Errorf("failed to create passport: %w", err) @@ -235,7 +252,13 @@ func (r *PassportRepository) Create(ctx context.Context, data passport.Passport) } func (r *PassportRepository) GetByID(ctx context.Context, id uuid.UUID) (passport.Passport, error) { - passports, err := r.queryPassports(ctx, selectPassportQuery+" WHERE id = $1", id.String()) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, err + } + + passports, err := r.queryPassports(ctx, selectPassportQuery+" WHERE id = $1 AND tenant_id = $2", + id.String(), tenant.ID.String()) if err != nil { return nil, err } @@ -246,7 +269,13 @@ func (r *PassportRepository) GetByID(ctx context.Context, id uuid.UUID) (passpor } func (r *PassportRepository) GetByPassportNumber(ctx context.Context, series, number string) (passport.Passport, error) { - passports, err := r.queryPassports(ctx, selectPassportQuery+" WHERE series = $1 AND passport_number = $2", series, number) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, err + } + + passports, err := r.queryPassports(ctx, selectPassportQuery+" WHERE series = $1 AND passport_number = $2 AND tenant_id = $3", + series, number, tenant.ID.String()) if err != nil { return nil, err } @@ -262,6 +291,11 @@ func (r *PassportRepository) Update(ctx context.Context, id uuid.UUID, data pass return nil, err } + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, err + } + dbRow, err := ToDBPassport(data) if err != nil { return nil, fmt.Errorf("failed to convert passport to db model: %w", err) @@ -269,7 +303,7 @@ func (r *PassportRepository) Update(ctx context.Context, id uuid.UUID, data pass _, err = pool.Exec( ctx, - updatePassportQuery, + updatePassportQuery+" AND tenant_id = $21", dbRow.FirstName, dbRow.LastName, dbRow.MiddleName, @@ -290,6 +324,7 @@ func (r *PassportRepository) Update(ctx context.Context, id uuid.UUID, data pass dbRow.Remarks, time.Now(), id.String(), + tenant.ID.String(), ) if err != nil { return nil, fmt.Errorf("failed to update passport: %w", err) @@ -304,7 +339,12 @@ func (r *PassportRepository) Delete(ctx context.Context, id uuid.UUID) error { return err } - _, err = pool.Exec(ctx, deletePassportQuery, id.String()) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return err + } + + _, err = pool.Exec(ctx, deletePassportQuery+" AND tenant_id = $2", id.String(), tenant.ID.String()) if err != nil { return fmt.Errorf("failed to delete passport: %w", err) } diff --git a/modules/core/infrastructure/persistence/permission_repository.go b/modules/core/infrastructure/persistence/permission_repository.go index 53893d22..9cd28dd7 100644 --- a/modules/core/infrastructure/persistence/permission_repository.go +++ b/modules/core/infrastructure/persistence/permission_repository.go @@ -18,12 +18,12 @@ var ( ) const ( - permissionsSelectQuery = `SELECT id, name, resource, action, modifier, description FROM permissions` + permissionsSelectQuery = `SELECT id, name, resource, action, modifier, description, tenant_id FROM permissions` permissionsCountQuery = `SELECT COUNT(*) FROM permissions` permissionsInsertQuery = ` - INSERT INTO permissions (id, name, resource, action, modifier, description) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (name) DO UPDATE SET resource = permissions.resource + INSERT INTO permissions (id, name, resource, action, modifier, description, tenant_id) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (tenant_id, name) DO UPDATE SET resource = permissions.resource RETURNING id` permissionsUpdateQuery = ` UPDATE permissions @@ -41,6 +41,11 @@ func NewPermissionRepository() permission.Repository { func (g *GormPermissionRepository) GetPaginated( ctx context.Context, params *permission.FindParams, ) ([]*permission.Permission, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + sortFields := []string{} for _, f := range params.SortBy.Fields { switch f { @@ -56,16 +61,21 @@ func (g *GormPermissionRepository) GetPaginated( return nil, fmt.Errorf("unknown sort field: %v", f) } } - joins, args := []string{}, []interface{}{} + + joins, args := []string{}, []interface{}{tenant.ID} + where := []string{"permissions.tenant_id = $1"} + if params.RoleID != 0 { joins = append(joins, fmt.Sprintf("INNER JOIN role_permissions rp ON rp.permission_id = permissions.id and rp.role_id = $%d", len(args)+1)) args = append(args, params.RoleID) } + return g.queryPermissions( ctx, repo.Join( permissionsSelectQuery, repo.Join(joins...), + repo.JoinWhere(where...), repo.OrderBy(sortFields, params.SortBy.Ascending), repo.FormatLimitOffset(params.Limit, params.Offset), ), @@ -78,10 +88,17 @@ func (g *GormPermissionRepository) Count(ctx context.Context) (int64, error) { if err != nil { return 0, err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get tenant from context: %w", err) + } + var count int64 if err := pool.QueryRow( ctx, - permissionsCountQuery, + permissionsCountQuery+" WHERE tenant_id = $1", + tenant.ID, ).Scan(&count); err != nil { return 0, err } @@ -89,11 +106,21 @@ func (g *GormPermissionRepository) Count(ctx context.Context) (int64, error) { } func (g *GormPermissionRepository) GetAll(ctx context.Context) ([]*permission.Permission, error) { - return g.queryPermissions(ctx, permissionsSelectQuery) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + return g.queryPermissions(ctx, permissionsSelectQuery+" WHERE tenant_id = $1", tenant.ID) } func (g *GormPermissionRepository) GetByID(ctx context.Context, id string) (*permission.Permission, error) { - permissions, err := g.queryPermissions(ctx, permissionsSelectQuery+" WHERE id = $1", id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + permissions, err := g.queryPermissions(ctx, permissionsSelectQuery+" WHERE id = $1 AND tenant_id = $2", id, tenant.ID) if err != nil { return nil, err } @@ -108,7 +135,15 @@ func (g *GormPermissionRepository) Save(ctx context.Context, data *permission.Pe if err != nil { return err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + dbPerm := toDBPermission(data) + dbPerm.TenantID = tenant.ID.String() + if err := tx.QueryRow( ctx, permissionsInsertQuery, @@ -118,6 +153,7 @@ func (g *GormPermissionRepository) Save(ctx context.Context, data *permission.Pe dbPerm.Action, dbPerm.Modifier, dbPerm.Description, + dbPerm.TenantID, ).Scan(&data.ID); err != nil { return err } @@ -128,7 +164,13 @@ func (g *GormPermissionRepository) Delete(ctx context.Context, id string) error if err := uuid.Validate(id); err != nil { return err } - return g.execQuery(ctx, permissionsDeleteQuery, id) + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + return g.execQuery(ctx, permissionsDeleteQuery+" AND tenant_id = $2", id, tenant.ID) } func (g *GormPermissionRepository) queryPermissions( @@ -159,6 +201,7 @@ func (g *GormPermissionRepository) queryPermissions( &p.Action, &p.Modifier, &p.Description, + &p.TenantID, ); err != nil { return nil, err } diff --git a/modules/core/infrastructure/persistence/role_repository.go b/modules/core/infrastructure/persistence/role_repository.go index 32e7b1ae..953d4bdc 100644 --- a/modules/core/infrastructure/persistence/role_repository.go +++ b/modules/core/infrastructure/persistence/role_repository.go @@ -22,27 +22,29 @@ const ( r.name, r.description, r.created_at, - r.updated_at + r.updated_at, + r.tenant_id FROM roles r` rolePermissionsQuery = ` SELECT p.id, + p.tenant_id, p.name, p.resource, p.action, p.modifier, p.description, rp.role_id - FROM permissions p LEFT JOIN role_permissions rp ON rp.permission_id = p.id WHERE rp.role_id = ANY($1)` - roleCountQuery = `SELECT COUNT(DISTINCT roles.id) FROM roles` - roleInsertQuery = `INSERT INTO roles (name, description) VALUES ($1, $2) RETURNING id` - roleUpdateQuery = `UPDATE roles SET name = $1, description = $2, updated_at = $3 WHERE id = $4` + FROM permissions p LEFT JOIN role_permissions rp ON rp.permission_id = p.id WHERE rp.role_id = ANY($1) AND p.tenant_id = $2` + roleCountQuery = `SELECT COUNT(DISTINCT roles.id) FROM roles WHERE tenant_id = $1` + roleInsertQuery = `INSERT INTO roles (name, description, tenant_id) VALUES ($1, $2, $3) RETURNING id` + roleUpdateQuery = `UPDATE roles SET name = $1, description = $2, updated_at = $3 WHERE id = $4 AND tenant_id = $5` roleDeletePermissionsQuery = `DELETE FROM role_permissions WHERE role_id = $1` roleInsertPermissionQuery = ` INSERT INTO role_permissions (role_id, permission_id) VALUES ($1, $2) ON CONFLICT (role_id, permission_id) DO NOTHING` - roleDeleteQuery = `DELETE FROM roles WHERE id = $1` + roleDeleteQuery = `DELETE FROM roles WHERE id = $1 AND tenant_id = $2` ) type GormRoleRepository struct { @@ -56,6 +58,7 @@ func NewRoleRepository() role.Repository { role.Description: "r.description", role.CreatedAt: "r.created_at", role.PermissionID: "rp.permission_id", + role.TenantID: "r.tenant_id", }, } } @@ -166,14 +169,25 @@ func (g *GormRoleRepository) Count(ctx context.Context, params *role.FindParams) } func (g *GormRoleRepository) GetAll(ctx context.Context) ([]role.Role, error) { - return g.queryRoles(ctx, roleFindQuery) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + query := roleFindQuery + " WHERE r.tenant_id = $1" + return g.queryRoles(ctx, query, tenant.ID) } func (g *GormRoleRepository) GetByID(ctx context.Context, id uint) (role.Role, error) { - query := roleFindQuery + " WHERE r.id = $1" - roles, err := g.queryRoles(ctx, query, id) + tenant, err := composables.UseTenant(ctx) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + query := roleFindQuery + " WHERE r.id = $1 AND r.tenant_id = $2" + roles, err := g.queryRoles(ctx, query, id, tenant.ID.String()) + if err != nil { + return nil, errors.Wrap(err, "failed to query roles") } if len(roles) == 0 { return nil, ErrRoleNotFound @@ -184,18 +198,26 @@ func (g *GormRoleRepository) GetByID(ctx context.Context, id uint) (role.Role, e func (g *GormRoleRepository) Create(ctx context.Context, data role.Role) (role.Role, error) { tx, err := composables.UseTx(ctx) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to get tx from ctx") + } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") } entity, permissions := toDBRole(data) + entity.TenantID = tenant.ID.String() + var id uint if err := tx.QueryRow( ctx, roleInsertQuery, entity.Name, entity.Description, + entity.TenantID, ).Scan(&id); err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to insert role") } for _, permission := range permissions { @@ -212,6 +234,13 @@ func (g *GormRoleRepository) Create(ctx context.Context, data role.Role) (role.R func (g *GormRoleRepository) Update(ctx context.Context, data role.Role) (role.Role, error) { dbRole, dbPermissions := toDBRole(data) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + dbRole.TenantID = tenant.ID.String() + if err := g.execQuery( ctx, roleUpdateQuery, @@ -219,6 +248,7 @@ func (g *GormRoleRepository) Update(ctx context.Context, data role.Role) (role.R dbRole.Description, dbRole.UpdatedAt, dbRole.ID, + dbRole.TenantID, ); err != nil { return nil, err } @@ -239,10 +269,15 @@ func (g *GormRoleRepository) Update(ctx context.Context, data role.Role) (role.R } func (g *GormRoleRepository) Delete(ctx context.Context, id uint) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return errors.Wrap(err, "failed to get tenant from context") + } + if err := g.execQuery(ctx, roleDeletePermissionsQuery, id); err != nil { return err } - return g.execQuery(ctx, roleDeleteQuery, id) + return g.execQuery(ctx, roleDeleteQuery, id, tenant.ID) } func (g *GormRoleRepository) queryPermissions(ctx context.Context, roleIDs []uint) (map[uint][]*models.Permission, error) { @@ -251,7 +286,12 @@ func (g *GormRoleRepository) queryPermissions(ctx context.Context, roleIDs []uin return nil, err } - rows, err := tx.Query(ctx, rolePermissionsQuery, roleIDs) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + rows, err := tx.Query(ctx, rolePermissionsQuery, roleIDs, tenant.ID.String()) if err != nil { return nil, err } @@ -263,6 +303,7 @@ func (g *GormRoleRepository) queryPermissions(ctx context.Context, roleIDs []uin var p models.Permission if err := rows.Scan( &p.ID, + &p.TenantID, &p.Name, &p.Resource, &p.Action, @@ -303,6 +344,7 @@ func (g *GormRoleRepository) queryRoles(ctx context.Context, query string, args &r.Description, &r.CreatedAt, &r.UpdatedAt, + &r.TenantID, ); err != nil { return nil, err } @@ -325,7 +367,7 @@ func (g *GormRoleRepository) queryRoles(ctx context.Context, query string, args for _, dbRole := range dbRoles { entity, err := toDomainRole(dbRole, permissionsMap[dbRole.ID]) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to cast to domain role") } roles = append(roles, entity) } diff --git a/modules/core/infrastructure/persistence/schema/core-schema.sql b/modules/core/infrastructure/persistence/schema/core-schema.sql index 4cd0f9ce..e046356e 100644 --- a/modules/core/infrastructure/persistence/schema/core-schema.sql +++ b/modules/core/infrastructure/persistence/schema/core-schema.sql @@ -1,17 +1,29 @@ +CREATE TABLE tenants ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid (), + name varchar(255) NOT NULL UNIQUE, + domain varchar(255), + is_active boolean NOT NULL DEFAULT TRUE, + created_at timestamp with time zone DEFAULT now(), + updated_at timestamp with time zone DEFAULT now() +); + CREATE TABLE uploads ( id serial PRIMARY KEY, + tenant_id uuid REFERENCES tenants (id) ON DELETE CASCADE, name varchar(255) NOT NULL, -- original file name - hash VARCHAR(255) NOT NULL UNIQUE, -- md5 hash of the file + hash VARCHAR(255) NOT NULL, -- md5 hash of the file path varchar(1024) NOT NULL DEFAULT '', -- relative path to the file size int NOT NULL DEFAULT 0, -- in bytes mimetype varchar(255) NOT NULL, -- image/jpeg, application/pdf, etc. type VARCHAR(255) NOT NULL, -- image, document, etc. created_at timestamp with time zone DEFAULT now(), - updated_at timestamp with time zone DEFAULT now() + updated_at timestamp with time zone DEFAULT now(), + UNIQUE (tenant_id, hash) ); CREATE TABLE passports ( id uuid PRIMARY KEY DEFAULT gen_random_uuid (), + tenant_id uuid REFERENCES tenants (id) ON DELETE CASCADE, first_name varchar(255), last_name varchar(255), middle_name varchar(255), @@ -32,18 +44,20 @@ CREATE TABLE passports ( remarks text, -- Additional notes (e.g., travel restrictions, visa endorsements). created_at timestamp with time zone DEFAULT now(), updated_at timestamp with time zone DEFAULT now(), - CONSTRAINT passports_passport_number_series_key UNIQUE (passport_number, series) + CONSTRAINT passports_tenant_passport_number_series_key UNIQUE (tenant_id, passport_number, series) ); CREATE TABLE companies ( id serial PRIMARY KEY, + tenant_id uuid REFERENCES tenants (id) ON DELETE CASCADE, name varchar(255) NOT NULL, about text, address varchar(255), phone varchar(255), logo_id int REFERENCES uploads (id) ON DELETE SET NULL, created_at timestamp with time zone DEFAULT now(), - updated_at timestamp with time zone DEFAULT now() + updated_at timestamp with time zone DEFAULT now(), + UNIQUE (tenant_id, name) ); CREATE TABLE currencies ( @@ -56,27 +70,32 @@ CREATE TABLE currencies ( CREATE TABLE roles ( id serial PRIMARY KEY, - name varchar(255) NOT NULL UNIQUE, + tenant_id uuid REFERENCES tenants (id) ON DELETE CASCADE, + name varchar(255) NOT NULL, description text, created_at timestamp with time zone DEFAULT now(), - updated_at timestamp with time zone DEFAULT now() + updated_at timestamp with time zone DEFAULT now(), + UNIQUE (tenant_id, name) ); CREATE TABLE users ( id serial PRIMARY KEY, + tenant_id uuid REFERENCES tenants (id) ON DELETE CASCADE, first_name varchar(255) NOT NULL, last_name varchar(255) NOT NULL, middle_name varchar(255), - email varchar(255) NOT NULL UNIQUE, + email varchar(255) NOT NULL, password VARCHAR(255), ui_language varchar(3) NOT NULL, - phone varchar(255) UNIQUE, + phone varchar(255), avatar_id int REFERENCES uploads (id) ON DELETE SET NULL, last_login timestamp NULL, last_ip varchar(255) NULL, last_action timestamp with time zone NULL, created_at timestamp with time zone NOT NULL DEFAULT now(), - updated_at timestamp with time zone NOT NULL DEFAULT now() + updated_at timestamp with time zone NOT NULL DEFAULT now(), + UNIQUE (tenant_id, email), + UNIQUE (tenant_id, phone) ); CREATE TABLE user_roles ( @@ -88,10 +107,12 @@ CREATE TABLE user_roles ( CREATE TABLE user_groups ( id uuid PRIMARY KEY DEFAULT gen_random_uuid (), - name varchar(255) UNIQUE NOT NULL, + tenant_id uuid REFERENCES tenants (id) ON DELETE CASCADE, + name varchar(255) NOT NULL, description text, created_at timestamp DEFAULT now(), - updated_at timestamp DEFAULT now() + updated_at timestamp DEFAULT now(), + UNIQUE (tenant_id, name) ); CREATE TABLE group_users ( @@ -121,11 +142,13 @@ CREATE TABLE uploaded_images ( CREATE TABLE permissions ( id uuid PRIMARY KEY DEFAULT gen_random_uuid () NOT NULL, - name varchar(255) NOT NULL UNIQUE, + tenant_id uuid REFERENCES tenants (id) ON DELETE CASCADE, + name varchar(255) NOT NULL, resource varchar(255) NOT NULL, -- roles, users, etc. action varchar(255) NOT NULL, -- create, read, update, delete modifier varchar(255) NOT NULL, -- all / own - description text + description text, + UNIQUE (tenant_id, name) ); CREATE TABLE role_permissions ( @@ -142,6 +165,7 @@ CREATE TABLE user_permissions ( CREATE TABLE sessions ( token varchar(255) NOT NULL PRIMARY KEY, + tenant_id uuid REFERENCES tenants (id) ON DELETE CASCADE, user_id integer NOT NULL REFERENCES users (id) ON DELETE CASCADE, expires_at timestamp with time zone NOT NULL, ip varchar(255) NOT NULL, @@ -151,16 +175,21 @@ CREATE TABLE sessions ( CREATE TABLE tabs ( id serial PRIMARY KEY, + tenant_id uuid REFERENCES tenants (id) ON DELETE CASCADE, href varchar(255) NOT NULL, user_id int NOT NULL REFERENCES users (id) ON DELETE CASCADE, position int NOT NULL DEFAULT 0, - UNIQUE (href, user_id) + UNIQUE (tenant_id, href, user_id) ); +CREATE INDEX users_tenant_id_idx ON users (tenant_id); + CREATE INDEX users_first_name_idx ON users (first_name); CREATE INDEX users_last_name_idx ON users (last_name); +CREATE INDEX sessions_tenant_id_idx ON sessions (tenant_id); + CREATE INDEX sessions_user_id_idx ON sessions (user_id); CREATE INDEX sessions_expires_at_idx ON sessions (expires_at); @@ -171,3 +200,13 @@ CREATE INDEX role_permissions_permission_id_idx ON role_permissions (permission_ CREATE INDEX uploaded_images_upload_id_idx ON uploaded_images (upload_id); +CREATE INDEX uploads_tenant_id_idx ON uploads (tenant_id); + +CREATE INDEX roles_tenant_id_idx ON roles (tenant_id); + +CREATE INDEX user_groups_tenant_id_idx ON user_groups (tenant_id); + +CREATE INDEX permissions_tenant_id_idx ON permissions (tenant_id); + +CREATE INDEX tabs_tenant_id_idx ON tabs (tenant_id); + diff --git a/modules/core/infrastructure/persistence/session_repository.go b/modules/core/infrastructure/persistence/session_repository.go index 81d78993..9179d44a 100644 --- a/modules/core/infrastructure/persistence/session_repository.go +++ b/modules/core/infrastructure/persistence/session_repository.go @@ -4,11 +4,13 @@ import ( "context" "fmt" - "github.com/go-faster/errors" "github.com/iota-uz/iota-sdk/modules/core/domain/entities/session" "github.com/iota-uz/iota-sdk/modules/core/infrastructure/persistence/models" "github.com/iota-uz/iota-sdk/pkg/composables" "github.com/iota-uz/iota-sdk/pkg/repo" + "github.com/iota-uz/psql-parser/util/uuid" + + "github.com/go-faster/errors" ) var ( @@ -16,7 +18,7 @@ var ( ) const ( - sessionFindQuery = `SELECT token, user_id, expires_at, ip, user_agent, created_at FROM sessions` + sessionFindQuery = `SELECT token, user_id, expires_at, ip, user_agent, created_at, tenant_id FROM sessions` sessionCountQuery = `SELECT COUNT(*) as count FROM sessions` sessionInsertQuery = ` INSERT INTO sessions ( @@ -25,16 +27,17 @@ const ( expires_at, ip, user_agent, - created_at + created_at, + tenant_id ) - VALUES ($1, $2, $3, $4, $5, $6)` + VALUES ($1, $2, $3, $4, $5, $6, $7)` sessionUpdateQuery = ` UPDATE sessions SET expires_at = $1, ip = $2, user_agent = $3 - WHERE token = $4` - sessionDeleteQuery = `DELETE FROM sessions WHERE token = $1` + WHERE token = $4 AND tenant_id = $5` + sessionDeleteQuery = `DELETE FROM sessions WHERE token = $1 AND tenant_id = $2` ) type GormSessionRepository struct{} @@ -64,6 +67,13 @@ func (g *GormSessionRepository) GetPaginated(ctx context.Context, params *sessio args = append(args, params.Token) } + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, err + } + where = append(where, fmt.Sprintf("tenant_id = $%d", len(args)+1)) + args = append(args, tenant.ID) + return g.querySessions( ctx, repo.Join( @@ -81,19 +91,53 @@ func (g *GormSessionRepository) Count(ctx context.Context) (int64, error) { if err != nil { return 0, err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, err + } + var count int64 - if err := tx.QueryRow(ctx, sessionCountQuery).Scan(&count); err != nil { + if err := tx.QueryRow(ctx, sessionCountQuery+" WHERE tenant_id = $1", tenant.ID).Scan(&count); err != nil { return 0, err } return count, nil } func (g *GormSessionRepository) GetAll(ctx context.Context) ([]*session.Session, error) { - return g.querySessions(ctx, sessionFindQuery) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, err + } + + return g.querySessions(ctx, sessionFindQuery+" WHERE tenant_id = $1", tenant.ID) } func (g *GormSessionRepository) GetByToken(ctx context.Context, token string) (*session.Session, error) { - sessions, err := g.querySessions(ctx, repo.Join(sessionFindQuery, "WHERE token = $1"), token) + // First try with tenant from context + tenant, err := composables.UseTenant(ctx) + + // If tenant is not in context (like during login), get the session regardless of tenant + if err != nil { + // Ensure we have a transaction context + _, err := composables.UseTx(ctx) + if err != nil { + return nil, err + } + + // Query without tenant filter during login + sessions, err := g.querySessions(ctx, repo.Join(sessionFindQuery, "WHERE token = $1"), token) + if err != nil { + return nil, errors.Wrap(err, "failed to get session by token") + } + if len(sessions) == 0 { + return nil, ErrSessionNotFound + } + return sessions[0], nil + } + + // Normal flow with tenant from context + sessions, err := g.querySessions(ctx, repo.Join(sessionFindQuery, "WHERE token = $1 AND tenant_id = $2"), token, tenant.ID) if err != nil { return nil, errors.Wrap(err, "failed to get session by token") } @@ -104,7 +148,15 @@ func (g *GormSessionRepository) GetByToken(ctx context.Context, token string) (* } func (g *GormSessionRepository) Create(ctx context.Context, data *session.Session) error { - dbSession := toDBSession(data) + dbSession := ToDBSession(data) + + // First try to get tenant from context + tenant, err := composables.UseTenant(ctx) + if err == nil { + dbSession.TenantID = tenant.ID.String() + } + // If tenant is not in context but session has TenantID set (from session.CreateDTO), use that + return g.execQuery( ctx, sessionInsertQuery, @@ -114,11 +166,26 @@ func (g *GormSessionRepository) Create(ctx context.Context, data *session.Sessio dbSession.IP, dbSession.UserAgent, dbSession.CreatedAt, + dbSession.TenantID, ) } func (g *GormSessionRepository) Update(ctx context.Context, data *session.Session) error { - dbSession := toDBSession(data) + dbSession := ToDBSession(data) + + // First try to get tenant from context + tenant, err := composables.UseTenant(ctx) + if err == nil { + dbSession.TenantID = tenant.ID.String() + } else if dbSession.TenantID == uuid.Nil.String() { + // If tenant is not in context and session has no TenantID, get the current session's tenant ID + existingSession, err := g.GetByToken(ctx, dbSession.Token) + if err != nil { + return err + } + dbSession.TenantID = existingSession.TenantID.String() + } + return g.execQuery( ctx, sessionUpdateQuery, @@ -126,11 +193,17 @@ func (g *GormSessionRepository) Update(ctx context.Context, data *session.Sessio dbSession.IP, dbSession.UserAgent, dbSession.Token, + dbSession.TenantID, ) } func (g *GormSessionRepository) Delete(ctx context.Context, token string) error { - return g.execQuery(ctx, sessionDeleteQuery, token) + // First get the session to know its tenant ID + session, err := g.GetByToken(ctx, token) + if err != nil { + return err + } + return g.execQuery(ctx, sessionDeleteQuery, token, session.TenantID) } func (g *GormSessionRepository) querySessions(ctx context.Context, query string, args ...interface{}) ([]*session.Session, error) { @@ -155,10 +228,11 @@ func (g *GormSessionRepository) querySessions(ctx context.Context, query string, &sessionRow.IP, &sessionRow.UserAgent, &sessionRow.CreatedAt, + &sessionRow.TenantID, ); err != nil { return nil, err } - sessions = append(sessions, toDomainSession(&sessionRow)) + sessions = append(sessions, ToDomainSession(&sessionRow)) } if err := rows.Err(); err != nil { diff --git a/modules/core/infrastructure/persistence/setup_test.go b/modules/core/infrastructure/persistence/setup_test.go index 426d7ed9..7f0e620e 100644 --- a/modules/core/infrastructure/persistence/setup_test.go +++ b/modules/core/infrastructure/persistence/setup_test.go @@ -2,13 +2,14 @@ package persistence_test import ( "context" + "os" + "testing" + "github.com/iota-uz/iota-sdk/modules" "github.com/iota-uz/iota-sdk/pkg/application" "github.com/iota-uz/iota-sdk/pkg/composables" "github.com/iota-uz/iota-sdk/pkg/testutils" "github.com/jackc/pgx/v5/pgxpool" - "os" - "testing" ) func TestMain(m *testing.M) { @@ -33,23 +34,45 @@ func setupTest(t *testing.T) *testFixtures { pool := testutils.NewPool(testutils.DbOpts(t.Name())) ctx := context.Background() + + // Setup application and run migrations first (outside the transaction) + app, err := testutils.SetupApplication(pool, modules.BuiltInModules...) + if err != nil { + t.Fatal(err) + } + + // Run migrations first to create all tables including tenants + if err := app.Migrations().Run(); err != nil { + t.Fatal(err) + } + + // Create a test tenant outside transaction + tenant, err := testutils.CreateTestTenant(ctx, pool) + if err != nil { + t.Fatal(err) + } + + // Now start the transaction tx, err := pool.Begin(ctx) if err != nil { t.Fatal(err) } t.Cleanup(func() { - if err := tx.Commit(ctx); err != nil { - t.Fatal(err) + // Rollback instead of commit to ensure clean state + // This is safer as it ensures tests don't affect each other + if err := tx.Rollback(ctx); err != nil { + // Only fatal if it's not already committed + if err.Error() != "sql: transaction has already been committed or rolled back" { + t.Fatal(err) + } } pool.Close() }) + // Add transaction and tenant to context ctx = composables.WithTx(ctx, tx) - app, err := testutils.SetupApplication(pool, modules.BuiltInModules...) - if err != nil { - t.Fatal(err) - } + ctx = composables.WithTenant(ctx, tenant) return &testFixtures{ ctx: ctx, diff --git a/modules/core/infrastructure/persistence/tab_repository.go b/modules/core/infrastructure/persistence/tab_repository.go index 29b8723c..468891e9 100644 --- a/modules/core/infrastructure/persistence/tab_repository.go +++ b/modules/core/infrastructure/persistence/tab_repository.go @@ -16,9 +16,9 @@ var ( ) const ( - selectTabsQuery = `SELECT id, href, user_id, position FROM tabs` + selectTabsQuery = `SELECT id, href, user_id, position, tenant_id FROM tabs` countTabsQuery = `SELECT COUNT(*) as count FROM tabs` - insertTabsQuery = `INSERT INTO tabs (href, user_id, position) VALUES ($1, $2, $3) RETURNING id` + insertTabsQuery = `INSERT INTO tabs (href, user_id, position, tenant_id) VALUES ($1, $2, $3, $4) RETURNING id` updateTabsQuery = `UPDATE tabs SET href = $1, position = $2 WHERE id = $3` deleteTabsQuery = `DELETE FROM tabs WHERE id = $1` deleteUserTabsQuery = `DELETE FROM tabs WHERE user_id = $1` @@ -50,6 +50,7 @@ func (g *tabRepository) queryTabs(ctx context.Context, query string, args ...int &tab.Href, &tab.UserID, &tab.Position, + &tab.TenantID, ); err != nil { return nil, err } @@ -73,15 +74,26 @@ func (g *tabRepository) Count(ctx context.Context) (int64, error) { if err != nil { return 0, err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get tenant from context: %w", err) + } + var count int64 - if err := pool.QueryRow(ctx, countTabsQuery).Scan(&count); err != nil { + if err := pool.QueryRow(ctx, countTabsQuery+" WHERE tenant_id = $1", tenant.ID).Scan(&count); err != nil { return 0, err } return count, nil } func (g *tabRepository) GetAll(ctx context.Context, params *tab.FindParams) ([]*tab.Tab, error) { - where, args := []string{"1 = 1"}, []interface{}{} + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + where, args := []string{"tenant_id = $1"}, []interface{}{tenant.ID} if params.UserID != 0 { where, args = append(where, fmt.Sprintf("user_id = $%d", len(args)+1)), append(args, params.UserID) } @@ -100,7 +112,12 @@ func (g *tabRepository) GetUserTabs(ctx context.Context, userID uint) ([]*tab.Ta } func (g *tabRepository) GetByID(ctx context.Context, id uint) (*tab.Tab, error) { - tabs, err := g.queryTabs(ctx, repo.Join(selectTabsQuery, "WHERE id = $1"), id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + tabs, err := g.queryTabs(ctx, repo.Join(selectTabsQuery, "WHERE id = $1 AND tenant_id = $2"), id, tenant.ID) if err != nil { return nil, err } @@ -115,13 +132,23 @@ func (g *tabRepository) Create(ctx context.Context, data *tab.Tab) error { if err != nil { return err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + tab := ToDBTab(data) + tab.TenantID = tenant.ID.String() + data.TenantID = tenant.ID + if err := tx.QueryRow( ctx, insertTabsQuery, tab.Href, tab.UserID, tab.Position, + tab.TenantID, ).Scan(&data.ID); err != nil { return err } @@ -133,14 +160,24 @@ func (g *tabRepository) CreateMany(ctx context.Context, tabs []*tab.Tab) error { if err != nil { return err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + for _, data := range tabs { tab := ToDBTab(data) + tab.TenantID = tenant.ID.String() + data.TenantID = tenant.ID + if err := tx.QueryRow( ctx, insertTabsQuery, tab.Href, tab.UserID, tab.Position, + tab.TenantID, ).Scan(&data.ID); err != nil { return err } @@ -149,11 +186,17 @@ func (g *tabRepository) CreateMany(ctx context.Context, tabs []*tab.Tab) error { } func (g *tabRepository) CreateOrUpdate(ctx context.Context, data *tab.Tab) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + matches, err := g.queryTabs( ctx, - selectTabsQuery+" WHERE user_id = $1 AND href = $2", + selectTabsQuery+" WHERE user_id = $1 AND href = $2 AND tenant_id = $3", data.UserID, data.Href, + tenant.ID, ) if err != nil { return err @@ -173,13 +216,20 @@ func (g *tabRepository) Update(ctx context.Context, data *tab.Tab) error { if err != nil { return err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + tab := ToDBTab(data) if _, err := tx.Exec( ctx, - updateTabsQuery, + updateTabsQuery+" AND tenant_id = $4", tab.Href, tab.Position, tab.ID, + tenant.ID, ); err != nil { return err } @@ -191,7 +241,13 @@ func (g *tabRepository) Delete(ctx context.Context, id uint) error { if err != nil { return err } - if _, err := tx.Exec(ctx, deleteTabsQuery, id); err != nil { + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + if _, err := tx.Exec(ctx, deleteTabsQuery+" AND tenant_id = $2", id, tenant.ID); err != nil { return err } return nil @@ -202,7 +258,13 @@ func (g *tabRepository) DeleteUserTabs(ctx context.Context, userID uint) error { if err != nil { return err } - if _, err := tx.Exec(ctx, deleteUserTabsQuery, userID); err != nil { + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + if _, err := tx.Exec(ctx, deleteUserTabsQuery+" AND tenant_id = $2", userID, tenant.ID); err != nil { return err } return nil diff --git a/modules/core/infrastructure/persistence/tenant_repository.go b/modules/core/infrastructure/persistence/tenant_repository.go new file mode 100644 index 00000000..4fdd85e9 --- /dev/null +++ b/modules/core/infrastructure/persistence/tenant_repository.go @@ -0,0 +1,187 @@ +package persistence + +import ( + "context" + "fmt" + + "github.com/go-faster/errors" + "github.com/google/uuid" + "github.com/iota-uz/iota-sdk/modules/core/domain/entities/tenant" + "github.com/iota-uz/iota-sdk/modules/core/infrastructure/persistence/models" + "github.com/iota-uz/iota-sdk/pkg/composables" +) + +var ( + ErrTenantNotFound = fmt.Errorf("tenant not found") +) + +const ( + tenantFindQuery = `SELECT id, name, domain, is_active, created_at, updated_at FROM tenants` +) + +type TenantRepository struct{} + +func NewTenantRepository() tenant.Repository { + return &TenantRepository{} +} + +func (r *TenantRepository) GetByID(ctx context.Context, id uuid.UUID) (*tenant.Tenant, error) { + query := tenantFindQuery + " WHERE id = $1" + tenants, err := r.queryTenants(ctx, query, id.String()) + if err != nil { + return nil, err + } + + if len(tenants) == 0 { + return nil, ErrTenantNotFound + } + + return tenants[0], nil +} + +func (r *TenantRepository) GetByDomain(ctx context.Context, domain string) (*tenant.Tenant, error) { + query := tenantFindQuery + " WHERE domain = $1" + tenants, err := r.queryTenants(ctx, query, domain) + if err != nil { + return nil, err + } + + if len(tenants) == 0 { + return nil, ErrTenantNotFound + } + + return tenants[0], nil +} + +func (r *TenantRepository) Create(ctx context.Context, t *tenant.Tenant) (*tenant.Tenant, error) { + query := ` + INSERT INTO tenants (id, name, domain, is_active, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING id + ` + tx, err := composables.UseTx(ctx) + if err != nil { + return nil, err + } + + var idStr string + if err := tx.QueryRow( + ctx, + query, + t.ID().String(), + t.Name(), + t.Domain(), + t.IsActive(), + t.CreatedAt(), + t.UpdatedAt(), + ).Scan(&idStr); err != nil { + return nil, err + } + + id, err := uuid.Parse(idStr) + if err != nil { + return nil, err + } + + return r.GetByID(ctx, id) +} + +func (r *TenantRepository) Update(ctx context.Context, t *tenant.Tenant) (*tenant.Tenant, error) { + query := ` + UPDATE tenants + SET name = $1, domain = $2, is_active = $3, updated_at = $4 + WHERE id = $5 + RETURNING id + ` + tx, err := composables.UseTx(ctx) + if err != nil { + return nil, err + } + + var idStr string + if err := tx.QueryRow( + ctx, + query, + t.Name(), + t.Domain(), + t.IsActive(), + t.UpdatedAt(), + t.ID().String(), + ).Scan(&idStr); err != nil { + return nil, err + } + + id, err := uuid.Parse(idStr) + if err != nil { + return nil, err + } + + return r.GetByID(ctx, id) +} + +func (r *TenantRepository) Delete(ctx context.Context, id uuid.UUID) error { + query := `DELETE FROM tenants WHERE id = $1` + tx, err := composables.UseTx(ctx) + if err != nil { + return err + } + + _, err = tx.Exec(ctx, query, id.String()) + return err +} + +func (r *TenantRepository) List(ctx context.Context) ([]*tenant.Tenant, error) { + return r.queryTenants(ctx, tenantFindQuery) +} + +func (r *TenantRepository) queryTenants(ctx context.Context, query string, args ...interface{}) ([]*tenant.Tenant, error) { + tx, err := composables.UseTx(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get transaction") + } + + rows, err := tx.Query(ctx, query, args...) + if err != nil { + return nil, errors.Wrap(err, "failed to execute query") + } + defer rows.Close() + + var tenants []*tenant.Tenant + for rows.Next() { + var t models.Tenant + if err := rows.Scan( + &t.ID, + &t.Name, + &t.Domain, + &t.IsActive, + &t.CreatedAt, + &t.UpdatedAt, + ); err != nil { + return nil, errors.Wrap(err, "failed to scan tenant row") + } + tenants = append(tenants, toDomainTenant(&t)) + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(err, "row iteration error") + } + + return tenants, nil +} + +func toDomainTenant(t *models.Tenant) *tenant.Tenant { + id, err := uuid.Parse(t.ID) + if err != nil { + // Log error or handle it appropriately + id = uuid.Nil + } + + return tenant.New( + t.Name, + tenant.WithID(id), + tenant.WithDomain(t.Domain.String), + tenant.WithIsActive(t.IsActive), + tenant.WithCreatedAt(t.CreatedAt), + tenant.WithUpdatedAt(t.UpdatedAt), + ) +} diff --git a/modules/core/infrastructure/persistence/upload_repository.go b/modules/core/infrastructure/persistence/upload_repository.go index 80ea3cf2..12deddd6 100644 --- a/modules/core/infrastructure/persistence/upload_repository.go +++ b/modules/core/infrastructure/persistence/upload_repository.go @@ -16,12 +16,12 @@ var ( ) const ( - selectUploadQuery = `SELECT id, hash, path, name, size, type, mimetype, created_at, updated_at FROM uploads` + selectUploadQuery = `SELECT id, hash, path, name, size, type, mimetype, created_at, updated_at, tenant_id FROM uploads` countUploadsQuery = `SELECT COUNT(*) FROM uploads` - insertUploadQuery = `INSERT INTO uploads (hash, path, name, size, type, mimetype, created_at, updated_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + insertUploadQuery = `INSERT INTO uploads (hash, path, name, size, type, mimetype, created_at, updated_at, tenant_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id` updatedUploadQuery = `UPDATE uploads @@ -32,9 +32,9 @@ const ( type = $5, mimetype = $6, updated_at = $7 - WHERE id = $8` + WHERE id = $8 AND tenant_id = $9` - deleteUploadQuery = `DELETE FROM uploads WHERE id = $1` + deleteUploadQuery = `DELETE FROM uploads WHERE id = $1 AND tenant_id = $2` ) type GormUploadRepository struct{} @@ -70,10 +70,15 @@ func (g *GormUploadRepository) queryUploads( &dbUpload.Mimetype, &dbUpload.CreatedAt, &dbUpload.UpdatedAt, + &dbUpload.TenantID, ); err != nil { return nil, err } - uploads = append(uploads, ToDomainUpload(&dbUpload)) + domainUpload, err := ToDomainUpload(&dbUpload) + if err != nil { + return nil, err + } + uploads = append(uploads, domainUpload) } if err := rows.Err(); err != nil { return nil, err @@ -115,6 +120,12 @@ func (g *GormUploadRepository) GetPaginated( where, args = append(where, fmt.Sprintf("mimetype = $%d", len(args)+1)), append(args, params.Mimetype.String()) } + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, err + } + where, args = append(where, fmt.Sprintf("tenant_id = $%d", len(args)+1)), append(args, tenant.ID.String()) + return g.queryUploads( ctx, repo.Join( @@ -132,8 +143,14 @@ func (g *GormUploadRepository) Count(ctx context.Context) (int64, error) { if err != nil { return 0, err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, err + } + var count int64 - if err := pool.QueryRow(ctx, countUploadsQuery).Scan(&count); err != nil { + if err := pool.QueryRow(ctx, countUploadsQuery+" WHERE tenant_id = $1", tenant.ID.String()).Scan(&count); err != nil { return 0, err } return count, nil @@ -176,7 +193,15 @@ func (g *GormUploadRepository) Create(ctx context.Context, data upload.Upload) ( if err != nil { return nil, err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, err + } + dbUpload := ToDBUpload(data) + dbUpload.TenantID = tenant.ID.String() + if err := tx.QueryRow( ctx, insertUploadQuery, @@ -188,6 +213,7 @@ func (g *GormUploadRepository) Create(ctx context.Context, data upload.Upload) ( dbUpload.Mimetype, dbUpload.CreatedAt, dbUpload.UpdatedAt, + dbUpload.TenantID, ).Scan(&dbUpload.ID); err != nil { return nil, err } @@ -199,7 +225,15 @@ func (g *GormUploadRepository) Update(ctx context.Context, data upload.Upload) e if err != nil { return err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return err + } + dbUpload := ToDBUpload(data) + dbUpload.TenantID = tenant.ID.String() + if _, err := tx.Exec( ctx, updatedUploadQuery, @@ -211,6 +245,7 @@ func (g *GormUploadRepository) Update(ctx context.Context, data upload.Upload) e dbUpload.Mimetype, dbUpload.UpdatedAt, dbUpload.ID, + dbUpload.TenantID, ); err != nil { return err } @@ -222,7 +257,13 @@ func (g *GormUploadRepository) Delete(ctx context.Context, id uint) error { if err != nil { return err } - if _, err := tx.Exec(ctx, deleteUploadQuery, id); err != nil { + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return err + } + + if _, err := tx.Exec(ctx, deleteUploadQuery, id, tenant.ID.String()); err != nil { return err } return nil diff --git a/modules/core/infrastructure/persistence/upload_repository_test.go b/modules/core/infrastructure/persistence/upload_repository_test.go index 5cd24245..ee1144fb 100644 --- a/modules/core/infrastructure/persistence/upload_repository_test.go +++ b/modules/core/infrastructure/persistence/upload_repository_test.go @@ -131,6 +131,7 @@ func TestGormUploadRepository_CRUD(t *testing.T) { updatedMime := mimetype.Lookup("image/png") updatedUpload := upload.NewWithID( createdUpload.ID(), + createdUpload.TenantID(), "updated-hash", "uploads/updated.png", "updated.png", diff --git a/modules/core/infrastructure/persistence/user_repository.go b/modules/core/infrastructure/persistence/user_repository.go index 8310eceb..1879cb64 100644 --- a/modules/core/infrastructure/persistence/user_repository.go +++ b/modules/core/infrastructure/persistence/user_repository.go @@ -23,6 +23,7 @@ const ( userFindQuery = ` SELECT u.id, + u.tenant_id, u.first_name, u.last_name, u.middle_name, @@ -40,11 +41,11 @@ const ( userCountQuery = `SELECT COUNT(u.id) FROM users u` - userUpdateLastLoginQuery = `UPDATE users SET last_login = NOW() WHERE id = $1` + userUpdateLastLoginQuery = `UPDATE users SET last_login = NOW() WHERE id = $1 AND tenant_id = $2` - userUpdateLastActionQuery = `UPDATE users SET last_action = NOW() WHERE id = $1` + userUpdateLastActionQuery = `UPDATE users SET last_action = NOW() WHERE id = $1 AND tenant_id = $2` - userDeleteQuery = `DELETE FROM users WHERE id = $1` + userDeleteQuery = `DELETE FROM users WHERE id = $1 AND tenant_id = $2` userRoleDeleteQuery = `DELETE FROM user_roles WHERE user_id = $1` userRoleInsertQuery = `INSERT INTO user_roles (user_id, role_id) VALUES` @@ -55,16 +56,17 @@ const ( userPermissionInsertQuery = `INSERT INTO user_permissions (user_id, permission_id) VALUES` userRolePermissionsQuery = ` - SELECT p.id, p.name, p.resource, p.action, p.modifier, p.description + SELECT p.id, p.tenant_id, p.name, p.resource, p.action, p.modifier, p.description FROM role_permissions rp LEFT JOIN permissions p ON rp.permission_id = p.id WHERE role_id = $1` userPermissionsQuery = ` - SELECT p.id, p.name, p.resource, p.action, p.modifier, p.description + SELECT p.id, p.tenant_id, p.name, p.resource, p.action, p.modifier, p.description FROM user_permissions up LEFT JOIN permissions p ON up.permission_id = p.id WHERE up.user_id = $1` userRolesQuery = ` SELECT r.id, + r.tenant_id, r.name, r.description, r.created_at, @@ -100,6 +102,7 @@ func NewUserRepository(uploadRepo upload.Repository) user.Repository { user.LastLogin: "u.last_login", user.CreatedAt: "u.created_at", user.UpdatedAt: "u.updated_at", + user.TenantID: "u.tenant_id", }, } } @@ -180,6 +183,11 @@ func (g *PgUserRepository) GetPaginated(ctx context.Context, params *user.FindPa } func (g *PgUserRepository) Count(ctx context.Context, params *user.FindParams) (int64, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, errors.Wrap(err, "failed to get tenant from context") + } + tx, err := composables.UseTx(ctx) if err != nil { return 0, errors.Wrap(err, "failed to get transaction") @@ -190,6 +198,9 @@ func (g *PgUserRepository) Count(ctx context.Context, params *user.FindParams) ( return 0, err } + where = append(where, fmt.Sprintf("u.tenant_id = $%d", len(args)+1)) + args = append(args, tenant.ID) + baseQuery := userCountQuery for _, f := range params.Filters { @@ -220,7 +231,12 @@ func (g *PgUserRepository) Count(ctx context.Context, params *user.FindParams) ( } func (g *PgUserRepository) GetAll(ctx context.Context) ([]user.User, error) { - users, err := g.queryUsers(ctx, userFindQuery) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + users, err := g.queryUsers(ctx, userFindQuery+" WHERE u.tenant_id = $1", tenant.ID) if err != nil { return nil, errors.Wrap(err, "failed to get all users") } @@ -228,29 +244,65 @@ func (g *PgUserRepository) GetAll(ctx context.Context) ([]user.User, error) { } func (g *PgUserRepository) GetByID(ctx context.Context, id uint) (user.User, error) { - users, err := g.queryUsers(ctx, userFindQuery+" WHERE u.id = $1", id) - if err != nil { - return nil, errors.Wrap(err, fmt.Sprintf("failed to query user with id: %d", id)) + // First check if we have a tenant in context + tenant, err := composables.UseTenant(ctx) + + var users []user.User + if err == nil { + // If we have a tenant, use it to filter + users, err = g.queryUsers(ctx, userFindQuery+" WHERE u.id = $1 AND u.tenant_id = $2", id, tenant.ID) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("failed to query user with id: %d", id)) + } + } else { + // If no tenant in context, get user by ID without tenant filter + // This is less secure but needed for some operations + users, err = g.queryUsers(ctx, userFindQuery+" WHERE u.id = $1", id) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("failed to query user with id: %d", id)) + } } + if len(users) == 0 { return nil, errors.Wrap(ErrUserNotFound, fmt.Sprintf("id: %d", id)) } + return users[0], nil } func (g *PgUserRepository) GetByEmail(ctx context.Context, email string) (user.User, error) { - users, err := g.queryUsers(ctx, userFindQuery+" WHERE u.email = $1", email) - if err != nil { - return nil, errors.Wrap(err, fmt.Sprintf("failed to query user with email: %s", email)) + // First check if we have a tenant in context + tenant, err := composables.UseTenant(ctx) + + var users []user.User + if err == nil { + // If we have a tenant, use it to filter + users, err = g.queryUsers(ctx, userFindQuery+" WHERE u.email = $1 AND u.tenant_id = $2", email, tenant.ID) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("failed to query user with email: %s", email)) + } + } else { + // If no tenant in context (like during login), get user by email across all tenants + users, err = g.queryUsers(ctx, userFindQuery+" WHERE u.email = $1", email) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("failed to query user with email: %s", email)) + } } + if len(users) == 0 { return nil, errors.Wrap(ErrUserNotFound, fmt.Sprintf("email: %s", email)) } + return users[0], nil } func (g *PgUserRepository) GetByPhone(ctx context.Context, phone string) (user.User, error) { - users, err := g.queryUsers(ctx, userFindQuery+" WHERE u.phone = $1", phone) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + users, err := g.queryUsers(ctx, userFindQuery+" WHERE u.phone = $1 AND u.tenant_id = $2", phone, tenant.ID) if err != nil { return nil, errors.Wrap(err, fmt.Sprintf("failed to query user with phone: %s", phone)) } @@ -261,14 +313,43 @@ func (g *PgUserRepository) GetByPhone(ctx context.Context, phone string) (user.U } func (g *PgUserRepository) Create(ctx context.Context, data user.User) (user.User, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + tx, err := composables.UseTx(ctx) if err != nil { return nil, errors.Wrap(err, "failed to get transaction") } - dbUser, _ := toDBUser(data) + // Create a copy of the user with the tenant ID from context + updatedData := data + if data.TenantID() == uuid.Nil { + updatedData = user.New( + data.FirstName(), + data.LastName(), + data.Email(), + data.UILanguage(), + user.WithID(data.ID()), + user.WithTenantID(tenant.ID), + user.WithMiddleName(data.MiddleName()), + user.WithPassword(data.Password()), + user.WithRoles(data.Roles()), + user.WithGroupIDs(data.GroupIDs()), + user.WithPermissions(data.Permissions()), + user.WithCreatedAt(data.CreatedAt()), + user.WithUpdatedAt(data.UpdatedAt()), + ) + if data.Phone() != nil { + updatedData = updatedData.SetPhone(data.Phone()) + } + } + + dbUser, _ := toDBUser(updatedData) fields := []string{ + "tenant_id", "first_name", "last_name", "middle_name", @@ -282,6 +363,7 @@ func (g *PgUserRepository) Create(ctx context.Context, data user.User) (user.Use } values := []interface{}{ + dbUser.TenantID, dbUser.FirstName, dbUser.LastName, dbUser.MiddleName, @@ -322,14 +404,22 @@ func (g *PgUserRepository) Create(ctx context.Context, data user.User) (user.Use } func (g *PgUserRepository) Update(ctx context.Context, data user.User) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return errors.Wrap(err, "failed to get tenant from context") + } + tx, err := composables.UseTx(ctx) if err != nil { return errors.Wrap(err, "failed to get transaction") } - dbUser, _ := toDBUser(data) + if dbUser.TenantID == uuid.Nil.String() { + dbUser.TenantID = tenant.ID.String() + } fields := []string{ + "tenant_id", "first_name", "last_name", "middle_name", @@ -341,6 +431,7 @@ func (g *PgUserRepository) Update(ctx context.Context, data user.User) error { } values := []interface{}{ + dbUser.TenantID, dbUser.FirstName, dbUser.LastName, dbUser.MiddleName, @@ -387,20 +478,71 @@ func (g *PgUserRepository) Update(ctx context.Context, data user.User) error { } func (g *PgUserRepository) UpdateLastLogin(ctx context.Context, id uint) error { - if err := g.execQuery(ctx, userUpdateLastLoginQuery, id); err != nil { + // First check if we have a tenant in context + tenant, err := composables.UseTenant(ctx) + + // If tenant exists in context, use it + if err == nil { + if err := g.execQuery(ctx, userUpdateLastLoginQuery, id, tenant.ID); err != nil { + return errors.Wrap(err, fmt.Sprintf("failed to update last login for user ID: %d", id)) + } + return nil + } + + // If no tenant in context, get the user's tenant from DB and use that + tx, txErr := composables.UseTx(ctx) + if txErr != nil { + return errors.Wrap(txErr, "failed to get transaction") + } + + var tenantID string + err = tx.QueryRow(ctx, "SELECT tenant_id FROM users WHERE id = $1", id).Scan(&tenantID) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("failed to get tenant ID for user ID: %d", id)) + } + + if err := g.execQuery(ctx, userUpdateLastLoginQuery, id, tenantID); err != nil { return errors.Wrap(err, fmt.Sprintf("failed to update last login for user ID: %d", id)) } return nil } func (g *PgUserRepository) UpdateLastAction(ctx context.Context, id uint) error { - if err := g.execQuery(ctx, userUpdateLastActionQuery, id); err != nil { + // First check if we have a tenant in context + tenant, err := composables.UseTenant(ctx) + + // If tenant exists in context, use it + if err == nil { + if err := g.execQuery(ctx, userUpdateLastActionQuery, id, tenant.ID); err != nil { + return errors.Wrap(err, fmt.Sprintf("failed to update last action for user ID: %d", id)) + } + return nil + } + + // If no tenant in context, get the user's tenant from DB and use that + tx, txErr := composables.UseTx(ctx) + if txErr != nil { + return errors.Wrap(txErr, "failed to get transaction") + } + + var tenantID string + err = tx.QueryRow(ctx, "SELECT tenant_id FROM users WHERE id = $1", id).Scan(&tenantID) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("failed to get tenant ID for user ID: %d", id)) + } + + if err := g.execQuery(ctx, userUpdateLastActionQuery, id, tenantID); err != nil { return errors.Wrap(err, fmt.Sprintf("failed to update last action for user ID: %d", id)) } return nil } func (g *PgUserRepository) Delete(ctx context.Context, id uint) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return errors.Wrap(err, "failed to get tenant from context") + } + if err := g.execQuery(ctx, userRoleDeleteQuery, id); err != nil { return errors.Wrap(err, fmt.Sprintf("failed to delete roles for user ID: %d", id)) } @@ -410,7 +552,7 @@ func (g *PgUserRepository) Delete(ctx context.Context, id uint) error { if err := g.execQuery(ctx, userPermissionDeleteQuery, id); err != nil { return errors.Wrap(err, fmt.Sprintf("failed to delete permissions for user ID: %d", id)) } - if err := g.execQuery(ctx, userDeleteQuery, id); err != nil { + if err := g.execQuery(ctx, userDeleteQuery, id, tenant.ID); err != nil { return errors.Wrap(err, fmt.Sprintf("failed to delete user with ID: %d", id)) } return nil @@ -434,6 +576,7 @@ func (g *PgUserRepository) queryUsers(ctx context.Context, query string, args .. if err := rows.Scan( &u.ID, + &u.TenantID, &u.FirstName, &u.LastName, &u.MiddleName, @@ -514,6 +657,7 @@ func (g *PgUserRepository) rolePermissions(ctx context.Context, roleID uint) ([] var p models.Permission if err := rows.Scan( &p.ID, + &p.TenantID, &p.Name, &p.Resource, &p.Action, @@ -549,6 +693,7 @@ func (g *PgUserRepository) userRoles(ctx context.Context, userID uint) ([]role.R var r models.Role if err := rows.Scan( &r.ID, + &r.TenantID, &r.Name, &r.Description, &r.CreatedAt, @@ -630,6 +775,7 @@ func (g *PgUserRepository) userPermissions(ctx context.Context, userID uint) ([] var p models.Permission if err := rows.Scan( &p.ID, + &p.TenantID, &p.Name, &p.Resource, &p.Action, diff --git a/modules/core/module.go b/modules/core/module.go index 5e6d8497..782419e9 100644 --- a/modules/core/module.go +++ b/modules/core/module.go @@ -48,8 +48,13 @@ func (m *Module) Register(app application.Application) error { // Create repositories userRepo := persistence.NewUserRepository(uploadRepo) roleRepo := persistence.NewRoleRepository() + tenantRepo := persistence.NewTenantRepository() permRepo := persistence.NewPermissionRepository() + // Create services + tabService := services.NewTabService(persistence.NewTabRepository()) + tenantService := services.NewTenantService(tenantRepo) + app.RegisterServices( services.NewUploadService(uploadRepo, fsStorage, app.EventPublisher()), services.NewUserService(userRepo, app.EventPublisher()), @@ -59,6 +64,8 @@ func (m *Module) Register(app application.Application) error { services.NewAuthService(app), services.NewCurrencyService(persistence.NewCurrencyRepository(), app.EventPublisher()), services.NewRoleService(roleRepo, app.EventPublisher()), + tabService, + tenantService, services.NewPermissionService(permRepo, app.EventPublisher()), services.NewTabService(persistence.NewTabRepository()), services.NewTabService(persistence.NewTabRepository()), diff --git a/modules/core/presentation/controllers/group_controller.go b/modules/core/presentation/controllers/group_controller.go index 9f7dd2e3..5528e05a 100644 --- a/modules/core/presentation/controllers/group_controller.go +++ b/modules/core/presentation/controllers/group_controller.go @@ -194,11 +194,24 @@ func (c *GroupsController) Groups( params := composables.UsePaginated(r) search := r.URL.Query().Get("name") + tenant, err := composables.UseTenant(r.Context()) + if err != nil { + logger.Errorf("Error retrieving tenant from request context: %v", err) + http.Error(w, "Error retrieving tenant", http.StatusBadRequest) + return + } + findParams := &group.FindParams{ Limit: params.Limit, Offset: params.Offset, SortBy: group.SortBy{Fields: []group.Field{}}, Search: search, + Filters: []group.Filter{ + { + Column: group.TenantID, + Filter: repo.Eq(tenant.ID.String()), + }, + }, } if v := r.URL.Query().Get("CreatedAt.To"); v != "" { diff --git a/modules/core/presentation/controllers/login_controller.go b/modules/core/presentation/controllers/login_controller.go index 2f1b0302..552506ff 100644 --- a/modules/core/presentation/controllers/login_controller.go +++ b/modules/core/presentation/controllers/login_controller.go @@ -166,7 +166,9 @@ func (c *LoginController) Post(w http.ResponseWriter, r *http.Request) { if errors.Is(err, composables.ErrInvalidPassword) { shared.SetFlash(w, "error", []byte(composables.MustT(r.Context(), "Login.Errors.PasswordInvalid"))) } else { - shared.SetFlash(w, "error", []byte(composables.MustT(r.Context(), "Errors.Internal"))) + errMsg := fmt.Sprintf("Login error: %v", err) + configuration.Use().Logger().Error(errMsg) + shared.SetFlash(w, "error", []byte(errMsg)) } http.Redirect(w, r, fmt.Sprintf("/login?email=%s&next=%s", dto.Email, r.URL.Query().Get("next")), http.StatusFound) return diff --git a/modules/core/presentation/controllers/roles_controller.go b/modules/core/presentation/controllers/roles_controller.go index 2b322dbb..3d4cd1f5 100644 --- a/modules/core/presentation/controllers/roles_controller.go +++ b/modules/core/presentation/controllers/roles_controller.go @@ -18,6 +18,7 @@ import ( "github.com/iota-uz/iota-sdk/pkg/mapping" "github.com/iota-uz/iota-sdk/pkg/middleware" "github.com/iota-uz/iota-sdk/pkg/rbac" + "github.com/iota-uz/iota-sdk/pkg/repo" "github.com/iota-uz/iota-sdk/pkg/shared" "github.com/sirupsen/logrus" @@ -108,11 +109,23 @@ func (c *RolesController) List( params := composables.UsePaginated(r) search := r.URL.Query().Get("name") + tenant, err := composables.UseTenant(r.Context()) + if err != nil { + logger.Errorf("Error retrieving tenant from request context: %v", err) + http.Error(w, "Error retrieving tenant", http.StatusBadRequest) + return + } + // Create find params with search findParams := &role.FindParams{ - Limit: params.Limit, - Offset: params.Offset, - Filters: []role.Filter{}, + Limit: params.Limit, + Offset: params.Offset, + Filters: []role.Filter{ + { + Column: role.TenantID, + Filter: repo.Eq(tenant.ID.String()), + }, + }, } // Apply search filter if provided diff --git a/modules/core/presentation/controllers/user_controller.go b/modules/core/presentation/controllers/user_controller.go index a75581a4..4b6348cf 100644 --- a/modules/core/presentation/controllers/user_controller.go +++ b/modules/core/presentation/controllers/user_controller.go @@ -245,6 +245,13 @@ func (c *UsersController) Users( params := composables.UsePaginated(r) groupIDs := r.URL.Query()["groupID"] + tenant, err := composables.UseTenant(r.Context()) + if err != nil { + logger.Errorf("Error retrieving tenant from request: %v", err) + http.Error(w, "Error retrieving tenant", http.StatusBadRequest) + return + } + // Create find params findParams := &user.FindParams{ Limit: params.Limit, @@ -253,6 +260,12 @@ func (c *UsersController) Users( user.CreatedAt, }}, Search: r.URL.Query().Get("Search"), + Filters: []user.Filter{ + { + Column: user.TenantID, + Filter: repo.Eq(tenant.ID.String()), + }, + }, } if len(groupIDs) > 0 { diff --git a/modules/core/seed/seed_tenant.go b/modules/core/seed/seed_tenant.go new file mode 100644 index 00000000..f3bc2963 --- /dev/null +++ b/modules/core/seed/seed_tenant.go @@ -0,0 +1,41 @@ +package seed + +import ( + "context" + + "github.com/google/uuid" + "github.com/iota-uz/iota-sdk/modules/core/domain/entities/tenant" + "github.com/iota-uz/iota-sdk/modules/core/infrastructure/persistence" + "github.com/iota-uz/iota-sdk/pkg/application" + "github.com/iota-uz/iota-sdk/pkg/configuration" +) + +func CreateDefaultTenant(ctx context.Context, app application.Application) error { + conf := configuration.Use() + logger := conf.Logger() + tenantRepository := persistence.NewTenantRepository() + // Create a new tenant with a fixed UUID for the default tenant + defaultTenant := tenant.New( + "Default", + tenant.WithID(uuid.MustParse("00000000-0000-0000-0000-000000000001")), // Use a fixed UUID for default tenant + tenant.WithDomain("default.localhost"), + ) + existingTenants, err := tenantRepository.List(ctx) + if err != nil { + logger.Errorf("Failed to list tenants: %v", err) + return err + } + + if len(existingTenants) > 0 { + logger.Infof("Default tenant already exists") + return nil + } + + logger.Infof("Creating default tenant") + _, err = tenantRepository.Create(ctx, defaultTenant) + if err != nil { + logger.Errorf("Failed to create default tenant: %v", err) + return err + } + return nil +} diff --git a/modules/core/seed/seed_user.go b/modules/core/seed/seed_user.go index 8c3147e6..6b8e6dd0 100644 --- a/modules/core/seed/seed_user.go +++ b/modules/core/seed/seed_user.go @@ -4,12 +4,14 @@ import ( "context" "github.com/go-faster/errors" + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/core/domain/aggregates/role" "github.com/iota-uz/iota-sdk/modules/core/domain/aggregates/user" "github.com/iota-uz/iota-sdk/modules/core/domain/entities/tab" "github.com/iota-uz/iota-sdk/modules/core/infrastructure/persistence" "github.com/iota-uz/iota-sdk/pkg/application" + "github.com/iota-uz/iota-sdk/pkg/composables" "github.com/iota-uz/iota-sdk/pkg/configuration" "github.com/iota-uz/iota-sdk/pkg/repo" "github.com/iota-uz/iota-sdk/pkg/types" @@ -33,12 +35,17 @@ func UserSeedFunc(usr user.User) application.SeedFunc { } func (s *userSeeder) CreateUser(ctx context.Context, app application.Application) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return errors.Wrapf(err, "failed to get tenant from context") + } + r, err := s.getOrCreateRole(ctx, app) if err != nil { return err } - usr, err := s.getOrCreateUser(ctx, r) + usr, err := s.getOrCreateUser(ctx, r, tenant.ID) if err != nil { return err } @@ -65,14 +72,20 @@ func (s *userSeeder) getOrCreateRole(ctx context.Context, app application.Applic return matches[0], nil } + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrapf(err, "failed to get tenant from context") + } + newRole := role.New(adminRoleName, role.WithDescription(adminRoleDesc), - role.WithPermissions(app.RBAC().Permissions())) + role.WithPermissions(app.RBAC().Permissions()), + role.WithTenantID(tenant.ID)) logger.Infof("Creating role %s", adminRoleName) return roleRepository.Create(ctx, newRole) } -func (s *userSeeder) getOrCreateUser(ctx context.Context, r role.Role) (user.User, error) { +func (s *userSeeder) getOrCreateUser(ctx context.Context, r role.Role, tenantID uuid.UUID) (user.User, error) { uploadRepository := persistence.NewUploadRepository() userRepository := persistence.NewUserRepository(uploadRepository) foundUser, err := userRepository.GetByEmail(ctx, s.user.Email().Value()) @@ -86,8 +99,19 @@ func (s *userSeeder) getOrCreateUser(ctx context.Context, r role.Role) (user.Use return foundUser, nil } + newUser := user.New( + s.user.FirstName(), + s.user.LastName(), + s.user.Email(), + s.user.UILanguage(), + user.WithTenantID(tenantID), + user.WithPassword(s.user.Password()), + user.WithMiddleName(s.user.MiddleName()), + user.WithPhone(s.user.Phone()), + ) + logger.Infof("Creating user %s", s.user.Email().Value()) - return userRepository.Create(ctx, s.user.AddRole(r)) + return userRepository.Create(ctx, newUser.AddRole(r)) } func (s *userSeeder) createUserTabs( @@ -97,7 +121,13 @@ func (s *userSeeder) createUserTabs( ) error { tabsRepository := persistence.NewTabRepository() localizer := i18n.NewLocalizer(app.Bundle(), string(s.user.UILanguage())) - tabs := buildTabsFromNavItems(app.NavItems(localizer), usr.ID()) + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return errors.Wrapf(err, "failed to get tenant from context") + } + + tabs := buildTabsFromNavItems(app.NavItems(localizer), usr.ID(), tenant.ID) for _, t := range tabs { if err := tabsRepository.CreateOrUpdate(ctx, t); err != nil { @@ -107,7 +137,7 @@ func (s *userSeeder) createUserTabs( return nil } -func buildTabsFromNavItems(navItems []types.NavigationItem, userID uint) []*tab.Tab { +func buildTabsFromNavItems(navItems []types.NavigationItem, userID uint, tenantID uuid.UUID) []*tab.Tab { tabs := make([]*tab.Tab, 0, len(navItems)*4) var position uint = 1 @@ -117,6 +147,7 @@ func buildTabsFromNavItems(navItems []types.NavigationItem, userID uint) []*tab. tabs = append(tabs, &tab.Tab{ ID: position, UserID: userID, + TenantID: tenantID, Position: position, Href: item.Href, }) diff --git a/modules/core/services/auth_service.go b/modules/core/services/auth_service.go index 6eaf5de0..c90851fd 100644 --- a/modules/core/services/auth_service.go +++ b/modules/core/services/auth_service.go @@ -120,27 +120,58 @@ func (s *AuthService) newSessionToken() (string, error) { } func (s *AuthService) authenticate(ctx context.Context, u user.User) (*session.Session, error) { - ip, _ := composables.UseIP(ctx) - userAgent, _ := composables.UseUserAgent(ctx) + logger := configuration.Use().Logger() + logger.Infof("Creating session for user ID: %d, tenant ID: %d", u.ID(), u.TenantID()) + + // Get IP and user agent + ip, ok := composables.UseIP(ctx) + if !ok { + logger.Warnf("Could not get IP, using default") + ip = "0.0.0.0" + } + + userAgent, ok := composables.UseUserAgent(ctx) + if !ok { + logger.Warnf("Could not get User-Agent, using default") + userAgent = "Unknown" + } + + // Generate session token token, err := s.newSessionToken() if err != nil { + logger.Errorf("Failed to generate session token: %v", err) return nil, err } + + // Create session DTO sess := &session.CreateDTO{ Token: token, UserID: u.ID(), IP: ip, UserAgent: userAgent, + TenantID: u.TenantID(), // Ensure tenant ID is set in the session } + + // Update user last login if err := s.usersService.UpdateLastLogin(ctx, u.ID()); err != nil { + logger.Errorf("Failed to update last login: %v", err) return nil, err } + + // Update user last action if err := s.usersService.UpdateLastAction(ctx, u.ID()); err != nil { + logger.Errorf("Failed to update last action: %v", err) return nil, err } + + // Create the session + logger.Infof("Creating session in DB for user ID: %d, token: %s (partial)", u.ID(), token[:5]) if err := s.sessionService.Create(ctx, sess); err != nil { + logger.Errorf("Failed to create session in DB: %v", err) return nil, err } + + logger.Infof("Session created successfully") return sess.ToEntity(), nil } @@ -178,17 +209,28 @@ func (s *AuthService) CookieAuthenticateWithUserID(ctx context.Context, id uint, } func (s *AuthService) Authenticate(ctx context.Context, email, password string) (user.User, *session.Session, error) { + logger := configuration.Use().Logger() + logger.Infof("Authentication attempt for email: %s", email) + u, err := s.usersService.GetByEmail(ctx, email) if err != nil { + logger.Errorf("Failed to get user by email: %v", err) return nil, nil, err } + if !u.CheckPassword(password) { + logger.Errorf("Invalid password for user: %s", email) return nil, nil, composables.ErrInvalidPassword } + + logger.Infof("User authenticated, creating session for user ID: %d", u.ID()) sess, err := s.authenticate(ctx, u) if err != nil { + logger.Errorf("Failed to create session: %v", err) return nil, nil, err } + + logger.Infof("Session created successfully with token: %s (partial)", sess.Token[:5]) return u, sess, nil } diff --git a/modules/core/services/setup_test.go b/modules/core/services/setup_test.go index fc9a4a26..3496bfd3 100644 --- a/modules/core/services/setup_test.go +++ b/modules/core/services/setup_test.go @@ -55,6 +55,18 @@ func setupTest(t *testing.T) *testFixtures { t.Fatal(err) } + // Run migrations first to create all tables including tenants + if err := app.Migrations().Run(); err != nil { + t.Fatal(err) + } + + // Create a test tenant and add it to the context + tenant, err := testutils.CreateTestTenant(ctx, pool) + if err != nil { + t.Fatal(err) + } + ctx = composables.WithTenant(ctx, tenant) + return &testFixtures{ ctx: ctx, pool: pool, diff --git a/modules/core/services/tenant_service.go b/modules/core/services/tenant_service.go new file mode 100644 index 00000000..bc6acdfc --- /dev/null +++ b/modules/core/services/tenant_service.go @@ -0,0 +1,43 @@ +package services + +import ( + "context" + + "github.com/google/uuid" + "github.com/iota-uz/iota-sdk/modules/core/domain/entities/tenant" +) + +type TenantService struct { + repo tenant.Repository +} + +func NewTenantService(repo tenant.Repository) *TenantService { + return &TenantService{ + repo: repo, + } +} + +func (s *TenantService) GetByID(ctx context.Context, id uuid.UUID) (*tenant.Tenant, error) { + return s.repo.GetByID(ctx, id) +} + +func (s *TenantService) GetByDomain(ctx context.Context, domain string) (*tenant.Tenant, error) { + return s.repo.GetByDomain(ctx, domain) +} + +func (s *TenantService) Create(ctx context.Context, name, domain string) (*tenant.Tenant, error) { + t := tenant.New(name, tenant.WithDomain(domain)) + return s.repo.Create(ctx, t) +} + +func (s *TenantService) Update(ctx context.Context, t *tenant.Tenant) (*tenant.Tenant, error) { + return s.repo.Update(ctx, t) +} + +func (s *TenantService) Delete(ctx context.Context, id uuid.UUID) error { + return s.repo.Delete(ctx, id) +} + +func (s *TenantService) List(ctx context.Context) ([]*tenant.Tenant, error) { + return s.repo.List(ctx) +} diff --git a/modules/crm/infrastructure/persistence/chat_repository.go b/modules/crm/infrastructure/persistence/chat_repository.go index a829c90c..775dd666 100644 --- a/modules/crm/infrastructure/persistence/chat_repository.go +++ b/modules/crm/infrastructure/persistence/chat_repository.go @@ -20,8 +20,9 @@ var ( const ( selectChatQuery = ` - SELECT + SELECT c.id, + c.tenant_id, c.created_at, c.last_message_at, c.client_id @@ -32,21 +33,22 @@ const ( insertChatQuery = ` INSERT INTO chats ( + tenant_id, client_id, created_at - ) VALUES ($1, $2) RETURNING id + ) VALUES ($1, $2, $3) RETURNING id ` updateChatQuery = `UPDATE chats SET client_id = $1, created_at = $2, last_message_at = $3 - WHERE id = $4` + WHERE id = $4 AND tenant_id = $5` - deleteChatQuery = `DELETE FROM chats WHERE id = $1` + deleteChatQuery = `DELETE FROM chats WHERE id = $1 AND tenant_id = $2` selectMessagesQuery = ` - SELECT + SELECT m.id, m.chat_id, m.message, @@ -63,7 +65,7 @@ const ( selectMessageClientSender = `SELECT id, first_name, last_name FROM clients WHERE id = $1` selectMessageAttachmentsQuery = ` - SELECT + SELECT u.id AS upload_id, u.hash, u.path, @@ -88,12 +90,12 @@ const ( ) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id` updateMessageQuery = ` - UPDATE messages SET + UPDATE messages SET chat_id = $1, message = $2, sender_user_id = $3, sender_client_id = $4, - is_read = $5, + is_read = $5, read_at = $6 WHERE id = $7 ` @@ -125,6 +127,7 @@ func (g *ChatRepository) queryChats(ctx context.Context, query string, args ...i var c models.Chat if err := rows.Scan( &c.ID, + &c.TenantID, &c.CreatedAt, &c.LastMessageAt, &c.ClientID, @@ -264,6 +267,11 @@ func (g *ChatRepository) queryMessages(ctx context.Context, query string, args . func (g *ChatRepository) GetPaginated( ctx context.Context, params *chat.FindParams, ) ([]chat.Chat, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + sortFields := []string{} for _, f := range params.SortBy.Fields { switch f { @@ -276,14 +284,16 @@ func (g *ChatRepository) GetPaginated( } } - where, args := []string{"1 = 1"}, []interface{}{} + where, args := []string{"c.tenant_id = $1"}, []interface{}{tenant.ID} + if params.Search != "" { where = append( where, - "cl.first_name ILIKE $1 OR cl.last_name ILIKE $1 OR cl.middle_name ILIKE $1 OR cl.phone_number ILIKE $1", + fmt.Sprintf("cl.first_name ILIKE $%d OR cl.last_name ILIKE $%d OR cl.middle_name ILIKE $%d OR cl.phone_number ILIKE $%d", len(args)+1, len(args)+1, len(args)+1, len(args)+1), ) args = append(args, "%"+params.Search+"%") } + return g.queryChats( ctx, repo.Join( @@ -301,15 +311,26 @@ func (g *ChatRepository) Count(ctx context.Context) (int64, error) { if err != nil { return 0, errors.Wrap(err, "failed to get transaction") } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, errors.Wrap(err, "failed to get tenant from context") + } + var count int64 - if err := pool.QueryRow(ctx, countChatQuery).Scan(&count); err != nil { + if err := pool.QueryRow(ctx, countChatQuery+" WHERE tenant_id = $1", tenant.ID).Scan(&count); err != nil { return 0, errors.Wrap(err, "failed to count chats") } return count, nil } func (g *ChatRepository) GetAll(ctx context.Context) ([]chat.Chat, error) { - chats, err := g.queryChats(ctx, selectChatQuery) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + chats, err := g.queryChats(ctx, selectChatQuery+" WHERE c.tenant_id = $1", tenant.ID) if err != nil { return nil, errors.Wrap(err, "failed to get all chats") } @@ -317,7 +338,12 @@ func (g *ChatRepository) GetAll(ctx context.Context) ([]chat.Chat, error) { } func (g *ChatRepository) GetByID(ctx context.Context, id uint) (chat.Chat, error) { - chats, err := g.queryChats(ctx, selectChatQuery+" WHERE c.id = $1", id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + chats, err := g.queryChats(ctx, selectChatQuery+" WHERE c.id = $1 AND c.tenant_id = $2", id, tenant.ID) if err != nil { return nil, errors.Wrapf(err, "failed to get chat with id %d", id) } @@ -328,7 +354,12 @@ func (g *ChatRepository) GetByID(ctx context.Context, id uint) (chat.Chat, error } func (g *ChatRepository) GetByClientID(ctx context.Context, clientID uint) (chat.Chat, error) { - chats, err := g.queryChats(ctx, selectChatQuery+" WHERE c.client_id = $1", clientID) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + chats, err := g.queryChats(ctx, selectChatQuery+" WHERE c.client_id = $1 AND c.tenant_id = $2", clientID, tenant.ID) if err != nil { return nil, errors.Wrapf(err, "failed to get chat for client %d", clientID) } @@ -473,7 +504,13 @@ func (g *ChatRepository) Delete(ctx context.Context, id uint) error { if err != nil { return errors.Wrap(err, "failed to get transaction") } - if _, err := tx.Exec(ctx, deleteChatQuery, id); err != nil { + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return errors.Wrap(err, "failed to get tenant from context") + } + + if _, err := tx.Exec(ctx, deleteChatQuery, id, tenant.ID); err != nil { return errors.Wrapf(err, "failed to delete chat with id %d", id) } return nil diff --git a/modules/crm/infrastructure/persistence/client_repository.go b/modules/crm/infrastructure/persistence/client_repository.go index 8bb0025c..a685d8e0 100644 --- a/modules/crm/infrastructure/persistence/client_repository.go +++ b/modules/crm/infrastructure/persistence/client_repository.go @@ -20,7 +20,7 @@ var ( const ( selectClientQuery = ` - SELECT + SELECT c.id, c.first_name, c.last_name, @@ -161,9 +161,15 @@ func (g *ClientRepository) exists(ctx context.Context, id uint) (bool, error) { if err != nil { return false, err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return false, fmt.Errorf("failed to get tenant from context: %w", err) + } + var exists bool - q := "SELECT EXISTS(SELECT 1 FROM clients WHERE id = $1)" - if err := pool.QueryRow(ctx, q, id).Scan(&exists); err != nil { + q := "SELECT EXISTS(SELECT 1 FROM clients WHERE id = $1 AND tenant_id = $2)" + if err := pool.QueryRow(ctx, q, id, tenant.ID).Scan(&exists); err != nil { return false, err } return exists, nil @@ -172,7 +178,14 @@ func (g *ClientRepository) exists(ctx context.Context, id uint) (bool, error) { func (g *ClientRepository) GetPaginated( ctx context.Context, params *client.FindParams, ) ([]client.Client, error) { - where, args := []string{"1 = 1"}, []interface{}{} + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + // Start with tenant filter + where, args := []string{"c.tenant_id = $1"}, []interface{}{tenant.ID} + if params.CreatedAt.To != "" && params.CreatedAt.From != "" { where, args = append(where, fmt.Sprintf("c.created_at BETWEEN $%d and $%d", len(args)+1, len(args)+2)), append(args, params.CreatedAt.From, params.CreatedAt.To) } @@ -220,19 +233,35 @@ func (g *ClientRepository) Count(ctx context.Context) (int64, error) { if err != nil { return 0, err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get tenant from context: %w", err) + } + var count int64 - if err := pool.QueryRow(ctx, countClientQuery).Scan(&count); err != nil { + if err := pool.QueryRow(ctx, countClientQuery+" WHERE tenant_id = $1", tenant.ID).Scan(&count); err != nil { return 0, err } return count, nil } func (g *ClientRepository) GetAll(ctx context.Context) ([]client.Client, error) { - return g.queryClients(ctx, selectClientQuery) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + return g.queryClients(ctx, selectClientQuery+" WHERE c.tenant_id = $1", tenant.ID) } func (g *ClientRepository) GetByID(ctx context.Context, id uint) (client.Client, error) { - clients, err := g.queryClients(ctx, selectClientQuery+" WHERE c.id = $1", id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + clients, err := g.queryClients(ctx, selectClientQuery+" WHERE c.id = $1 AND c.tenant_id = $2", id, tenant.ID) if err != nil { return nil, err } @@ -243,7 +272,12 @@ func (g *ClientRepository) GetByID(ctx context.Context, id uint) (client.Client, } func (g *ClientRepository) GetByPhone(ctx context.Context, phoneNumber string) (client.Client, error) { - clients, err := g.queryClients(ctx, selectClientQuery+" WHERE c.phone_number = $1", phoneNumber) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + clients, err := g.queryClients(ctx, selectClientQuery+" WHERE c.phone_number = $1 AND c.tenant_id = $2", phoneNumber, tenant.ID) if err != nil { return nil, err } @@ -259,7 +293,13 @@ func (g *ClientRepository) Create(ctx context.Context, data client.Client) (clie return nil, err } + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + dbRow := ToDBClient(data) + dbRow.TenantID = tenant.ID.String() if data.Passport() != nil { p, err := g.passportRepo.Save(ctx, data.Passport()) @@ -273,6 +313,7 @@ func (g *ClientRepository) Create(ctx context.Context, data client.Client) (clie } fields := []string{ + "tenant_id", "first_name", "last_name", "middle_name", @@ -288,6 +329,7 @@ func (g *ClientRepository) Create(ctx context.Context, data client.Client) (clie } values := []interface{}{ + dbRow.TenantID, dbRow.FirstName, dbRow.LastName, dbRow.MiddleName, @@ -322,7 +364,14 @@ func (g *ClientRepository) Update(ctx context.Context, data client.Client) (clie return nil, err } + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + dbRow := ToDBClient(data) + dbRow.TenantID = tenant.ID.String() + if data.Passport() != nil { p, err := g.passportRepo.Save(ctx, data.Passport()) if err != nil { @@ -369,11 +418,11 @@ func (g *ClientRepository) Update(ctx context.Context, data client.Client) (clie } } - values = append(values, data.ID()) + values = append(values, data.ID(), tenant.ID) if _, err := tx.Exec( ctx, - repo.Update("clients", fields, fmt.Sprintf("id = $%d", len(values))), + repo.Update("clients", fields, fmt.Sprintf("id = $%d AND tenant_id = $%d", len(values)-1, len(values))), values..., ); err != nil { return nil, err @@ -399,8 +448,13 @@ func (g *ClientRepository) Delete(ctx context.Context, id uint) error { return err } + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + var passportID sql.NullString - err = tx.QueryRow(ctx, "SELECT passport_id FROM clients WHERE id = $1", id).Scan(&passportID) + err = tx.QueryRow(ctx, "SELECT passport_id FROM clients WHERE id = $1 AND tenant_id = $2", id, tenant.ID).Scan(&passportID) if err != nil && err != sql.ErrNoRows { return err } @@ -416,7 +470,7 @@ func (g *ClientRepository) Delete(ctx context.Context, id uint) error { } // Delete the client record - if _, err := tx.Exec(ctx, deleteClientQuery, id); err != nil { + if _, err := tx.Exec(ctx, deleteClientQuery+" AND tenant_id = $2", id, tenant.ID); err != nil { return err } diff --git a/modules/crm/infrastructure/persistence/client_repository_test.go b/modules/crm/infrastructure/persistence/client_repository_test.go index 00c7a957..61587a11 100644 --- a/modules/crm/infrastructure/persistence/client_repository_test.go +++ b/modules/crm/infrastructure/persistence/client_repository_test.go @@ -37,13 +37,24 @@ func createTestPassport() passport.Passport { func createTestClient(t *testing.T, withPassport bool) client.Client { t.Helper() - p, err := phone.NewFromE164("12345678901") + // Use a different phone number to avoid duplication + phoneNumber := "12345678901" + if withPassport { + phoneNumber = "98765432109" + } + p, err := phone.NewFromE164(phoneNumber) if err != nil { t.Fatal(err) } birthDate := time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC) - email, err := internet.NewEmail("john.doe@example.com") + + // Use a different email to avoid duplication + emailStr := "john.doe@example.com" + if withPassport { + emailStr = "john.smith@example.com" + } + email, err := internet.NewEmail(emailStr) if err != nil { t.Fatal(err) } @@ -513,8 +524,8 @@ func TestClientRepository_Update(t *testing.T) { corepersistence.NewPassportRepository(), ) - // Create a client without passport - p, err := phone.NewFromE164("12345678901") + // Create a client without passport - use a unique phone number + p, err := phone.NewFromE164("55555555555") if err != nil { t.Fatal(err) } @@ -584,12 +595,18 @@ func TestClientRepository_Update(t *testing.T) { // Create another client specifically for testing passport updates t.Run("Update with passport", func(t *testing.T) { + // Create a new client without passport with a unique phone number + newPhone, err := phone.NewFromE164("77777777777") + if err != nil { + t.Fatal(err) + } + // Create a new client without passport noPassportClient, err := client.New( "Alice", "Wonder", "", - client.WithPhone(p), + client.WithPhone(newPhone), client.WithEmail(email), client.WithPin(pin), client.WithGender(general.Female), diff --git a/modules/crm/infrastructure/persistence/crm_mappers.go b/modules/crm/infrastructure/persistence/crm_mappers.go index 3667b5fc..30ffd0c1 100644 --- a/modules/crm/infrastructure/persistence/crm_mappers.go +++ b/modules/crm/infrastructure/persistence/crm_mappers.go @@ -157,7 +157,11 @@ func ToDomainMessage( ) (chat.Message, error) { uploads := make([]upload.Upload, 0, len(dbUploads)) for _, u := range dbUploads { - uploads = append(uploads, corepersistence.ToDomainUpload(u)) + domainUpload, err := corepersistence.ToDomainUpload(u) + if err != nil { + return nil, err + } + uploads = append(uploads, domainUpload) } return chat.NewMessageWithID( dbRow.ID, diff --git a/modules/crm/infrastructure/persistence/models/models.go b/modules/crm/infrastructure/persistence/models/models.go index f79fde49..83b78e07 100644 --- a/modules/crm/infrastructure/persistence/models/models.go +++ b/modules/crm/infrastructure/persistence/models/models.go @@ -7,6 +7,7 @@ import ( type Client struct { ID uint + TenantID string FirstName string LastName sql.NullString MiddleName sql.NullString @@ -23,6 +24,7 @@ type Client struct { type Chat struct { ID uint + TenantID uint ClientID uint LastMessageAt sql.NullTime CreatedAt time.Time @@ -41,6 +43,7 @@ type Message struct { type MessageTemplate struct { ID uint + TenantID uint Template string CreatedAt time.Time } diff --git a/modules/crm/infrastructure/persistence/schema/crm-schema.sql b/modules/crm/infrastructure/persistence/schema/crm-schema.sql index f84d2cfd..94728564 100644 --- a/modules/crm/infrastructure/persistence/schema/crm-schema.sql +++ b/modules/crm/infrastructure/persistence/schema/crm-schema.sql @@ -1,5 +1,6 @@ CREATE TABLE clients ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, first_name varchar(255) NOT NULL, last_name varchar(255), middle_name varchar(255), @@ -11,7 +12,9 @@ CREATE TABLE clients ( passport_id uuid REFERENCES passports (id) ON DELETE SET NULL ON UPDATE CASCADE, pin varchar(128), -- Personal Identification Number created_at timestamp with time zone DEFAULT now(), - updated_at timestamp with time zone DEFAULT now() + updated_at timestamp with time zone DEFAULT now(), + UNIQUE (tenant_id, phone_number), + UNIQUE (tenant_id, email) ); CREATE TABLE client_contacts ( @@ -25,6 +28,7 @@ CREATE TABLE client_contacts ( CREATE TABLE chats ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, created_at timestamp(3) DEFAULT now() NOT NULL, client_id int NOT NULL REFERENCES clients (id) ON DELETE RESTRICT ON UPDATE CASCADE, last_message_at timestamp(3) DEFAULT now() @@ -50,6 +54,7 @@ CREATE TABLE message_media ( CREATE TABLE message_templates ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, template TEXT NOT NULL, created_at timestamp with time zone DEFAULT now() ); @@ -70,5 +75,11 @@ CREATE INDEX idx_clients_phone_number ON clients (phone_number); CREATE INDEX idx_clients_email ON clients (email); +CREATE INDEX idx_clients_tenant_id ON clients (tenant_id); + CREATE INDEX idx_client_contacts_client_id ON client_contacts (client_id); +CREATE INDEX idx_chats_tenant_id ON chats (tenant_id); + +CREATE INDEX idx_message_templates_tenant_id ON message_templates (tenant_id); + diff --git a/modules/crm/infrastructure/persistence/setup_test.go b/modules/crm/infrastructure/persistence/setup_test.go index 2dd6e44e..95f7ab67 100644 --- a/modules/crm/infrastructure/persistence/setup_test.go +++ b/modules/crm/infrastructure/persistence/setup_test.go @@ -47,11 +47,25 @@ func setupTest(t *testing.T) *testFixtures { }) ctx = composables.WithTx(ctx, tx) + + // Setup application and run migrations app, err := testutils.SetupApplication(pool, modules.BuiltInModules...) if err != nil { t.Fatal(err) } + // Run migrations first to create all tables including tenants + if err := app.Migrations().Run(); err != nil { + t.Fatal(err) + } + + // Create a test tenant and add it to the context + tenant, err := testutils.CreateTestTenant(ctx, pool) + if err != nil { + t.Fatal(err) + } + ctx = composables.WithTenant(ctx, tenant) + return &testFixtures{ ctx: ctx, pool: pool, diff --git a/modules/finance/domain/aggregates/expense_category/expense_category.go b/modules/finance/domain/aggregates/expense_category/expense_category.go index b9ed3f70..4aea1a1c 100644 --- a/modules/finance/domain/aggregates/expense_category/expense_category.go +++ b/modules/finance/domain/aggregates/expense_category/expense_category.go @@ -3,6 +3,7 @@ package category import ( "time" + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/core/domain/entities/currency" ) @@ -51,9 +52,16 @@ func WithUpdatedAt(updatedAt time.Time) Option { } } +func WithTenantID(tenantID uuid.UUID) Option { + return func(e *expenseCategory) { + e.tenantID = tenantID + } +} + // Interface type ExpenseCategory interface { ID() uint + TenantID() uuid.UUID Name() string Description() string Amount() float64 @@ -73,6 +81,7 @@ func New( ) ExpenseCategory { e := &expenseCategory{ id: 0, + tenantID: uuid.Nil, name: name, description: "", // description is optional amount: amount, @@ -88,6 +97,7 @@ func New( type expenseCategory struct { id uint + tenantID uuid.UUID name string description string amount float64 @@ -100,6 +110,10 @@ func (e *expenseCategory) ID() uint { return e.id } +func (e *expenseCategory) TenantID() uuid.UUID { + return e.tenantID +} + func (e *expenseCategory) Name() string { return e.name } @@ -118,6 +132,7 @@ func (e *expenseCategory) UpdateAmount(a float64) ExpenseCategory { a, e.currency, WithID(e.id), + WithTenantID(e.tenantID), WithDescription(e.description), WithCreatedAt(e.createdAt), WithUpdatedAt(time.Now()), diff --git a/modules/finance/domain/aggregates/money_account/account.go b/modules/finance/domain/aggregates/money_account/account.go index 126baf9a..b5dc8696 100644 --- a/modules/finance/domain/aggregates/money_account/account.go +++ b/modules/finance/domain/aggregates/money_account/account.go @@ -1,9 +1,11 @@ package moneyaccount import ( + "time" + + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/core/domain/entities/currency" "github.com/iota-uz/iota-sdk/modules/finance/domain/entities/transaction" - "time" ut "github.com/go-playground/universal-translator" "github.com/go-playground/validator/v10" @@ -12,6 +14,7 @@ import ( type Account struct { ID uint + TenantID uuid.UUID Name string AccountNumber string Description string @@ -63,6 +66,7 @@ func (p *UpdateDTO) ToEntity(id uint) (*Account, error) { } return &Account{ ID: id, + TenantID: uuid.Nil, Name: p.Name, AccountNumber: p.AccountNumber, Balance: p.Balance, diff --git a/modules/finance/domain/aggregates/money_account/account_dto.go b/modules/finance/domain/aggregates/money_account/account_dto.go index a8a59107..a674742b 100644 --- a/modules/finance/domain/aggregates/money_account/account_dto.go +++ b/modules/finance/domain/aggregates/money_account/account_dto.go @@ -5,6 +5,7 @@ import ( ut "github.com/go-playground/universal-translator" "github.com/go-playground/validator/v10" + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/core/domain/entities/currency" "github.com/iota-uz/iota-sdk/pkg/constants" ) @@ -44,6 +45,7 @@ func (p *CreateDTO) ToEntity() (*Account, error) { } return &Account{ ID: 0, + TenantID: uuid.Nil, Name: p.Name, AccountNumber: p.AccountNumber, Balance: p.Balance, diff --git a/modules/finance/domain/aggregates/payment/payment.go b/modules/finance/domain/aggregates/payment/payment.go index 91e5495a..60731de1 100644 --- a/modules/finance/domain/aggregates/payment/payment.go +++ b/modules/finance/domain/aggregates/payment/payment.go @@ -1,15 +1,20 @@ package payment import ( + "time" + + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/core/domain/aggregates/user" moneyaccount "github.com/iota-uz/iota-sdk/modules/finance/domain/aggregates/money_account" - "time" ) type Payment interface { ID() uint SetID(id uint) + TenantID() uuid.UUID + SetTenantID(id uuid.UUID) + Amount() float64 SetAmount(amount float64) diff --git a/modules/finance/domain/aggregates/payment/payment_dto.go b/modules/finance/domain/aggregates/payment/payment_dto.go index c8b29308..29e50403 100644 --- a/modules/finance/domain/aggregates/payment/payment_dto.go +++ b/modules/finance/domain/aggregates/payment/payment_dto.go @@ -1,13 +1,15 @@ package payment import ( + "time" + ut "github.com/go-playground/universal-translator" "github.com/go-playground/validator/v10" + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/core/domain/aggregates/user" "github.com/iota-uz/iota-sdk/modules/core/domain/value_objects/internet" moneyaccount "github.com/iota-uz/iota-sdk/modules/finance/domain/aggregates/money_account" "github.com/iota-uz/iota-sdk/pkg/shared" - "time" ) var validate = validator.New(validator.WithRequiredStructEnabled()) @@ -90,6 +92,7 @@ func (p *UpdateDTO) ToEntity(id uint) Payment { return NewWithID( id, + uuid.Nil, // TenantID will be set in repository p.Amount, 0, p.CounterpartyID, diff --git a/modules/finance/domain/aggregates/payment/payment_impl.go b/modules/finance/domain/aggregates/payment/payment_impl.go index a92e6753..9dbe4ad8 100644 --- a/modules/finance/domain/aggregates/payment/payment_impl.go +++ b/modules/finance/domain/aggregates/payment/payment_impl.go @@ -1,13 +1,16 @@ package payment import ( + "time" + + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/core/domain/aggregates/user" moneyaccount "github.com/iota-uz/iota-sdk/modules/finance/domain/aggregates/money_account" - "time" ) func NewWithID( id uint, + tenantID uuid.UUID, amount float64, transactionID, counterpartyID uint, comment string, @@ -17,6 +20,7 @@ func NewWithID( ) Payment { return &payment{ id: id, + tenantID: tenantID, amount: amount, account: account, transactionID: transactionID, @@ -40,6 +44,7 @@ func New( ) Payment { return NewWithID( 0, + uuid.Nil, amount, transactionID, counterpartyID, @@ -55,6 +60,7 @@ func New( type payment struct { id uint + tenantID uuid.UUID amount float64 transactionID uint counterpartyID uint @@ -139,3 +145,11 @@ func (p *payment) CreatedAt() time.Time { func (p *payment) UpdatedAt() time.Time { return p.updatedAt } + +func (p *payment) TenantID() uuid.UUID { + return p.tenantID +} + +func (p *payment) SetTenantID(id uuid.UUID) { + p.tenantID = id +} diff --git a/modules/finance/domain/entities/transaction/transaction.go b/modules/finance/domain/entities/transaction/transaction.go index bfe43689..4fb58c55 100644 --- a/modules/finance/domain/entities/transaction/transaction.go +++ b/modules/finance/domain/entities/transaction/transaction.go @@ -2,10 +2,13 @@ package transaction import ( "time" + + "github.com/google/uuid" ) type Transaction struct { ID uint + TenantID uuid.UUID Amount float64 OriginAccountID *uint DestinationAccountID *uint @@ -34,6 +37,7 @@ func NewDeposit( } return &Transaction{ ID: 0, + TenantID: uuid.Nil, Amount: amount, OriginAccountID: origAccID, DestinationAccountID: destAccID, @@ -63,6 +67,7 @@ func NewWithdrawal( } return &Transaction{ ID: 0, + TenantID: uuid.Nil, Amount: amount, OriginAccountID: origAccID, DestinationAccountID: destAccID, diff --git a/modules/finance/infrastructure/persistence/counterparty_repository.go b/modules/finance/infrastructure/persistence/counterparty_repository.go index b4640369..e3c96485 100644 --- a/modules/finance/infrastructure/persistence/counterparty_repository.go +++ b/modules/finance/infrastructure/persistence/counterparty_repository.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/iota-uz/iota-sdk/modules/finance/domain/entities/counterparty" "github.com/iota-uz/iota-sdk/modules/finance/infrastructure/persistence/models" "github.com/iota-uz/iota-sdk/pkg/composables" @@ -35,9 +36,10 @@ const ( legal_type, legal_address, created_at, - updated_at + updated_at, + tenant_id ) - VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id` + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id` updateCounterpartyQuery = ` UPDATE counterparty SET name = $1, tin = $2, type = $3, legal_type = $4, legal_address = $5, updated_at = $6 @@ -106,6 +108,12 @@ func (g *GormCounterpartyRepository) Create(ctx context.Context, data counterpar if err != nil { return nil, err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + args := []interface{}{ entity.Name, entity.Tin, @@ -114,6 +122,7 @@ func (g *GormCounterpartyRepository) Create(ctx context.Context, data counterpar entity.LegalAddress, entity.CreatedAt, entity.UpdatedAt, + tenant.ID, } row := tx.QueryRow(ctx, insertCounterpartyQuery, args...) var id uint diff --git a/modules/finance/infrastructure/persistence/expense_category_repository.go b/modules/finance/infrastructure/persistence/expense_category_repository.go index 185c87d2..175f2458 100644 --- a/modules/finance/infrastructure/persistence/expense_category_repository.go +++ b/modules/finance/infrastructure/persistence/expense_category_repository.go @@ -18,8 +18,9 @@ var ( const ( selectExpenseCategoryQuery = ` - SELECT + SELECT ec.id, + ec.tenant_id, ec.name, ec.description, ec.amount_currency_id, @@ -36,14 +37,15 @@ const ( countExpenseCategoryQuery = `SELECT COUNT(*) as count FROM expense_categories ec` insertExpenseCategoryQuery = ` INSERT INTO expense_categories ( - name, - description, - amount, + tenant_id, + name, + description, + amount, amount_currency_id ) - VALUES ($1, $2, $3, $4) RETURNING id` - updateExpenseCategoryQuery = `UPDATE expense_categories SET name = $1, description = $2, amount = $3, amount_currency_id = $4 WHERE id = $5` - deleteExpenseCategoryQuery = `DELETE FROM expense_categories WHERE id = $1` + VALUES ($1, $2, $3, $4, $5) RETURNING id` + updateExpenseCategoryQuery = `UPDATE expense_categories SET name = $1, description = $2, amount = $3, amount_currency_id = $4 WHERE id = $5 AND tenant_id = $6` + deleteExpenseCategoryQuery = `DELETE FROM expense_categories WHERE id = $1 AND tenant_id = $2` ) type GormExpenseCategoryRepository struct { @@ -64,9 +66,14 @@ func NewExpenseCategoryRepository() category.Repository { } } -func (g *GormExpenseCategoryRepository) buildCategoryFilters(params *category.FindParams) ([]string, []interface{}, error) { - where := []string{"1 = 1"} - args := []interface{}{} +func (g *GormExpenseCategoryRepository) buildCategoryFilters(ctx context.Context, params *category.FindParams) ([]string, []interface{}, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + where := []string{"ec.tenant_id = $1"} + args := []interface{}{tenant.ID.String()} for _, filter := range params.Filters { column, ok := g.fieldMap[filter.Column] @@ -111,6 +118,7 @@ func (g *GormExpenseCategoryRepository) queryCategories(ctx context.Context, que var c coremodels.Currency if err := rows.Scan( &ec.ID, + &ec.TenantID, &ec.Name, &ec.Description, &ec.AmountCurrencyID, @@ -141,7 +149,7 @@ func (g *GormExpenseCategoryRepository) queryCategories(ctx context.Context, que func (g *GormExpenseCategoryRepository) GetPaginated( ctx context.Context, params *category.FindParams, ) ([]category.ExpenseCategory, error) { - where, args, err := g.buildCategoryFilters(params) + where, args, err := g.buildCategoryFilters(ctx, params) if err != nil { return nil, fmt.Errorf("failed to build filters: %w", err) } @@ -171,7 +179,7 @@ func (g *GormExpenseCategoryRepository) Count(ctx context.Context, params *categ return 0, fmt.Errorf("failed to get transaction: %w", err) } - where, args, err := g.buildCategoryFilters(params) + where, args, err := g.buildCategoryFilters(ctx, params) if err != nil { return 0, fmt.Errorf("failed to build filters: %w", err) } @@ -190,17 +198,31 @@ func (g *GormExpenseCategoryRepository) Count(ctx context.Context, params *categ } func (g *GormExpenseCategoryRepository) GetAll(ctx context.Context) ([]category.ExpenseCategory, error) { - query := selectExpenseCategoryQuery - return g.queryCategories(ctx, query) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + query := repo.Join( + selectExpenseCategoryQuery, + repo.JoinWhere("ec.tenant_id = $1"), + ) + + return g.queryCategories(ctx, query, tenant.ID.String()) } func (g *GormExpenseCategoryRepository) GetByID(ctx context.Context, id uint) (category.ExpenseCategory, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + query := repo.Join( selectExpenseCategoryQuery, - repo.JoinWhere("ec.id = $1"), + repo.JoinWhere("ec.id = $1 AND ec.tenant_id = $2"), ) - categories, err := g.queryCategories(ctx, query, id) + categories, err := g.queryCategories(ctx, query, id, tenant.ID.String()) if err != nil { return nil, fmt.Errorf("failed to get expense category with ID: %d: %w", id, err) } @@ -215,11 +237,20 @@ func (g *GormExpenseCategoryRepository) Create(ctx context.Context, data categor if err != nil { return nil, fmt.Errorf("failed to get transaction: %w", err) } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + dbRow := toDBExpenseCategory(data) + dbRow.TenantID = tenant.ID.String() + var id uint if err := tx.QueryRow( ctx, insertExpenseCategoryQuery, + dbRow.TenantID, dbRow.Name, dbRow.Description, dbRow.Amount, @@ -235,7 +266,15 @@ func (g *GormExpenseCategoryRepository) Update(ctx context.Context, data categor if err != nil { return nil, fmt.Errorf("failed to get transaction: %w", err) } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + dbRow := toDBExpenseCategory(data) + dbRow.TenantID = tenant.ID.String() + if _, err := tx.Exec( ctx, updateExpenseCategoryQuery, @@ -244,6 +283,7 @@ func (g *GormExpenseCategoryRepository) Update(ctx context.Context, data categor dbRow.Amount, dbRow.AmountCurrencyID, data.ID(), + dbRow.TenantID, ); err != nil { return nil, fmt.Errorf("failed to update expense category with ID: %d: %w", data.ID(), err) } @@ -255,7 +295,13 @@ func (g *GormExpenseCategoryRepository) Delete(ctx context.Context, id uint) err if err != nil { return fmt.Errorf("failed to get transaction: %w", err) } - if _, err := tx.Exec(ctx, deleteExpenseCategoryQuery, id); err != nil { + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + if _, err := tx.Exec(ctx, deleteExpenseCategoryQuery, id, tenant.ID.String()); err != nil { return fmt.Errorf("failed to delete expense category with ID: %d: %w", id, err) } return nil diff --git a/modules/finance/infrastructure/persistence/expense_category_repository_test.go b/modules/finance/infrastructure/persistence/expense_category_repository_test.go index 57b4dcd8..d6a4068f 100644 --- a/modules/finance/infrastructure/persistence/expense_category_repository_test.go +++ b/modules/finance/infrastructure/persistence/expense_category_repository_test.go @@ -1,11 +1,12 @@ package persistence_test import ( + "testing" + "github.com/iota-uz/iota-sdk/modules/core/domain/entities/currency" corepersistence "github.com/iota-uz/iota-sdk/modules/core/infrastructure/persistence" - "github.com/iota-uz/iota-sdk/modules/finance/domain/aggregates/expense_category" + category "github.com/iota-uz/iota-sdk/modules/finance/domain/aggregates/expense_category" "github.com/iota-uz/iota-sdk/modules/finance/infrastructure/persistence" - "testing" ) func TestGormExpenseCategoryRepository_CRUD(t *testing.T) { diff --git a/modules/finance/infrastructure/persistence/finance_mappers.go b/modules/finance/infrastructure/persistence/finance_mappers.go index 85eaf69c..a1bbacdb 100644 --- a/modules/finance/infrastructure/persistence/finance_mappers.go +++ b/modules/finance/infrastructure/persistence/finance_mappers.go @@ -1,6 +1,7 @@ package persistence import ( + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/core/domain/aggregates/user" "github.com/iota-uz/iota-sdk/modules/core/domain/value_objects/country" "github.com/iota-uz/iota-sdk/modules/core/domain/value_objects/internet" @@ -20,6 +21,7 @@ import ( func toDBTransaction(entity *transaction.Transaction) *models.Transaction { return &models.Transaction{ ID: entity.ID, + TenantID: entity.TenantID.String(), Amount: entity.Amount, Comment: entity.Comment, AccountingPeriod: entity.AccountingPeriod, @@ -36,9 +38,14 @@ func toDomainTransaction(dbTransaction *models.Transaction) (*transaction.Transa if err != nil { return nil, err } + tenantID, err := uuid.Parse(dbTransaction.TenantID) + if err != nil { + return nil, err + } return &transaction.Transaction{ ID: dbTransaction.ID, + TenantID: tenantID, Amount: dbTransaction.Amount, TransactionType: _type, Comment: dbTransaction.Comment, @@ -53,6 +60,7 @@ func toDomainTransaction(dbTransaction *models.Transaction) (*transaction.Transa func toDBPayment(entity payment.Payment) (*models.Payment, *models.Transaction) { dbTransaction := &models.Transaction{ ID: entity.TransactionID(), + TenantID: entity.TenantID().String(), Amount: entity.Amount(), Comment: entity.Comment(), AccountingPeriod: entity.AccountingPeriod(), @@ -82,9 +90,14 @@ func toDomainPayment(dbPayment *models.Payment, dbTransaction *models.Transactio if err != nil { return nil, err } + tenantID, err := uuid.Parse(dbTransaction.TenantID) + if err != nil { + return nil, err + } return payment.NewWithID( dbPayment.ID, + tenantID, t.Amount, t.ID, dbPayment.CounterpartyID, @@ -106,6 +119,7 @@ func toDomainPayment(dbPayment *models.Payment, dbTransaction *models.Transactio func toDBExpenseCategory(entity category.ExpenseCategory) *models.ExpenseCategory { return &models.ExpenseCategory{ ID: entity.ID(), + TenantID: entity.TenantID().String(), Name: entity.Name(), Description: mapping.ValueToSQLNullString(entity.Description()), Amount: entity.Amount(), @@ -120,8 +134,15 @@ func toDomainExpenseCategory(dbCategory *models.ExpenseCategory, dbCurrency *cor if err != nil { return nil, err } + + tenantID, err := uuid.Parse(dbCategory.TenantID) + if err != nil { + return nil, err + } + opts := []category.Option{ category.WithID(dbCategory.ID), + category.WithTenantID(tenantID), category.WithCreatedAt(dbCategory.CreatedAt), category.WithUpdatedAt(dbCategory.UpdatedAt), } @@ -143,8 +164,14 @@ func toDomainMoneyAccount(dbAccount *models.MoneyAccount) (*moneyaccount.Account if err != nil { return nil, err } + tenantID, err := uuid.Parse(dbAccount.TenantID) + if err != nil { + return nil, err + } + return &moneyaccount.Account{ ID: dbAccount.ID, + TenantID: tenantID, Name: dbAccount.Name, AccountNumber: dbAccount.AccountNumber, Balance: dbAccount.Balance, @@ -158,6 +185,7 @@ func toDomainMoneyAccount(dbAccount *models.MoneyAccount) (*moneyaccount.Account func toDBMoneyAccount(entity *moneyaccount.Account) *models.MoneyAccount { return &models.MoneyAccount{ ID: entity.ID, + TenantID: entity.TenantID.String(), Name: entity.Name, AccountNumber: entity.AccountNumber, Balance: entity.Balance, @@ -170,12 +198,18 @@ func toDBMoneyAccount(entity *moneyaccount.Account) *models.MoneyAccount { } func toDomainExpense(dbExpense *models.Expense, dbTransaction *models.Transaction) (expense.Expense, error) { + tenantID, err := uuid.Parse(dbTransaction.TenantID) + if err != nil { + return nil, err + } + account := moneyaccount.Account{ID: *dbTransaction.OriginAccountID} //nolint:exhaustruct expenseCategory := category.New( "", // name - will be populated when actual category is fetched 0.0, // amount - will be populated when actual category is fetched nil, // currency - will be populated when actual category is fetched category.WithID(dbExpense.CategoryID), + category.WithTenantID(tenantID), category.WithCreatedAt(dbExpense.CreatedAt), category.WithUpdatedAt(dbExpense.UpdatedAt), ) diff --git a/modules/finance/infrastructure/persistence/models/models.go b/modules/finance/infrastructure/persistence/models/models.go index 738dd689..09882e6b 100644 --- a/modules/finance/infrastructure/persistence/models/models.go +++ b/modules/finance/infrastructure/persistence/models/models.go @@ -2,12 +2,14 @@ package models import ( "database/sql" - coremodels "github.com/iota-uz/iota-sdk/modules/core/infrastructure/persistence/models" "time" + + coremodels "github.com/iota-uz/iota-sdk/modules/core/infrastructure/persistence/models" ) type ExpenseCategory struct { ID uint + TenantID string Name string Description sql.NullString Amount float64 @@ -18,6 +20,7 @@ type ExpenseCategory struct { type MoneyAccount struct { ID uint + TenantID string Name string AccountNumber string Description string @@ -30,6 +33,7 @@ type MoneyAccount struct { type Transaction struct { ID uint + TenantID string Amount float64 OriginAccountID *uint DestinationAccountID *uint @@ -58,6 +62,7 @@ type Payment struct { type Counterparty struct { ID uint + TenantID uint Tin string Name string Type string diff --git a/modules/finance/infrastructure/persistence/money_account_repository.go b/modules/finance/infrastructure/persistence/money_account_repository.go index 98e3e641..bd4ab96c 100644 --- a/modules/finance/infrastructure/persistence/money_account_repository.go +++ b/modules/finance/infrastructure/persistence/money_account_repository.go @@ -3,12 +3,13 @@ package persistence import ( "context" "fmt" + "github.com/go-faster/errors" "github.com/iota-uz/iota-sdk/modules/finance/infrastructure/persistence/models" "github.com/iota-uz/iota-sdk/pkg/repo" coremodels "github.com/iota-uz/iota-sdk/modules/core/infrastructure/persistence/models" - "github.com/iota-uz/iota-sdk/modules/finance/domain/aggregates/money_account" + moneyaccount "github.com/iota-uz/iota-sdk/modules/finance/domain/aggregates/money_account" "github.com/iota-uz/iota-sdk/pkg/composables" "github.com/iota-uz/iota-sdk/pkg/mapping" ) @@ -20,6 +21,7 @@ var ( const ( findQuery = ` SELECT ma.id, + ma.tenant_id, ma.name, ma.account_number, ma.description, @@ -33,13 +35,14 @@ const ( c.created_at, c.updated_at FROM money_accounts ma LEFT JOIN currencies c ON c.code = ma.balance_currency_id` - countQuery = `SELECT COUNT(*) as count FROM money_accounts` + countQuery = `SELECT COUNT(*) as count FROM money_accounts WHERE tenant_id = $1` recalculateBalanceQuery = ` UPDATE money_accounts SET balance = (SELECT sum(t.amount) FROM transactions t WHERE origin_account_id = $1 OR destination_account_id = $2) - WHERE id = $3` + WHERE id = $3 AND tenant_id = $4` insertQuery = ` INSERT INTO money_accounts ( + tenant_id, name, account_number, description, @@ -48,13 +51,13 @@ const ( created_at, updated_at ) - VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id` + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id` updateQuery = ` UPDATE money_accounts SET name = $1, account_number = $2, description = $3, balance = $4, balance_currency_id = $5, updated_at = $6 - WHERE id = $7` - deleteRelatedQuery = `DELETE FROM transactions WHERE origin_account_id = $1 OR destination_account_id = $1;` - deleteQuery = `DELETE FROM money_accounts WHERE id = $1;` + WHERE id = $7 AND tenant_id = $8` + deleteRelatedQuery = `DELETE FROM transactions WHERE origin_account_id = $1 OR destination_account_id = $1 AND tenant_id = $2;` + deleteQuery = `DELETE FROM money_accounts WHERE id = $1 AND tenant_id = $2;` ) type GormMoneyAccountRepository struct{} @@ -64,14 +67,20 @@ func NewMoneyAccountRepository() moneyaccount.Repository { } func (g *GormMoneyAccountRepository) GetPaginated(ctx context.Context, params *moneyaccount.FindParams) ([]*moneyaccount.Account, error) { - var args []interface{} - where := []string{"1 = 1"} + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + where := []string{"ma.tenant_id = $1"} + args := []interface{}{tenant.ID} + if params.CreatedAt.To != "" && params.CreatedAt.From != "" { - where = append(where, fmt.Sprintf("wo.created_at BETWEEN $%d and $%d", len(where), len(where)+1)) + where = append(where, fmt.Sprintf("ma.created_at BETWEEN $%d and $%d", len(args)+1, len(args)+2)) args = append(args, params.CreatedAt.From, params.CreatedAt.To) } if params.Query != "" && params.Field != "" { - where = append(where, fmt.Sprintf("wo.%s::VARCHAR ILIKE $%d", params.Field, len(where))) + where = append(where, fmt.Sprintf("ma.%s::VARCHAR ILIKE $%d", params.Field, len(args)+1)) args = append(args, "%"+params.Query+"%") } q := repo.Join( @@ -83,23 +92,39 @@ func (g *GormMoneyAccountRepository) GetPaginated(ctx context.Context, params *m } func (g *GormMoneyAccountRepository) Count(ctx context.Context) (int64, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get tenant from context: %w", err) + } + tx, err := composables.UseTx(ctx) if err != nil { return 0, err } var count int64 - if err := tx.QueryRow(ctx, countQuery).Scan(&count); err != nil { + if err := tx.QueryRow(ctx, countQuery, tenant.ID).Scan(&count); err != nil { return 0, err } return count, nil } func (g *GormMoneyAccountRepository) GetAll(ctx context.Context) ([]*moneyaccount.Account, error) { - return g.queryAccounts(ctx, findQuery) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + query := repo.Join(findQuery, "WHERE ma.tenant_id = $1") + return g.queryAccounts(ctx, query, tenant.ID) } func (g *GormMoneyAccountRepository) GetByID(ctx context.Context, id uint) (*moneyaccount.Account, error) { - accounts, err := g.queryAccounts(ctx, repo.Join(findQuery, "WHERE id = $1"), id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + accounts, err := g.queryAccounts(ctx, repo.Join(findQuery, "WHERE ma.id = $1 AND ma.tenant_id = $2"), id, tenant.ID) if err != nil { return nil, err } @@ -110,7 +135,12 @@ func (g *GormMoneyAccountRepository) GetByID(ctx context.Context, id uint) (*mon } func (g *GormMoneyAccountRepository) RecalculateBalance(ctx context.Context, id uint) error { - err := g.execQuery(ctx, recalculateBalanceQuery, id, id, id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + err = g.execQuery(ctx, recalculateBalanceQuery, id, id, id, tenant.ID) if err != nil { return errors.Wrap(err, "failed to recalculate balance") } @@ -118,12 +148,19 @@ func (g *GormMoneyAccountRepository) RecalculateBalance(ctx context.Context, id } func (g *GormMoneyAccountRepository) Create(ctx context.Context, data *moneyaccount.Account) (*moneyaccount.Account, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + data.TenantID = tenant.ID entity := toDBMoneyAccount(data) tx, err := composables.UseTx(ctx) if err != nil { return nil, err } args := []interface{}{ + entity.TenantID, entity.Name, entity.AccountNumber, entity.Description, @@ -141,6 +178,12 @@ func (g *GormMoneyAccountRepository) Create(ctx context.Context, data *moneyacco } func (g *GormMoneyAccountRepository) Update(ctx context.Context, data *moneyaccount.Account) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + data.TenantID = tenant.ID dbAccount := toDBMoneyAccount(data) args := []interface{}{ dbAccount.Name, @@ -150,15 +193,21 @@ func (g *GormMoneyAccountRepository) Update(ctx context.Context, data *moneyacco dbAccount.BalanceCurrencyID, dbAccount.UpdatedAt, dbAccount.ID, + dbAccount.TenantID, } return g.execQuery(ctx, updateQuery, args...) } func (g *GormMoneyAccountRepository) Delete(ctx context.Context, id uint) error { - if err := g.execQuery(ctx, deleteRelatedQuery, id); err != nil { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + if err := g.execQuery(ctx, deleteRelatedQuery, id, tenant.ID); err != nil { return err } - return g.execQuery(ctx, deleteQuery, id) + return g.execQuery(ctx, deleteQuery, id, tenant.ID) } func (g *GormMoneyAccountRepository) queryAccounts(ctx context.Context, query string, args ...interface{}) ([]*moneyaccount.Account, error) { @@ -178,6 +227,7 @@ func (g *GormMoneyAccountRepository) queryAccounts(ctx context.Context, query st } if err := rows.Scan( &r.ID, + &r.TenantID, &r.Name, &r.AccountNumber, &r.Description, diff --git a/modules/finance/infrastructure/persistence/payment_repository.go b/modules/finance/infrastructure/persistence/payment_repository.go index e955ef2a..dd6a62e0 100644 --- a/modules/finance/infrastructure/persistence/payment_repository.go +++ b/modules/finance/infrastructure/persistence/payment_repository.go @@ -3,6 +3,7 @@ package persistence import ( "context" "fmt" + "github.com/go-faster/errors" "github.com/iota-uz/iota-sdk/modules/finance/infrastructure/persistence/models" "github.com/iota-uz/iota-sdk/pkg/repo" @@ -22,6 +23,7 @@ const ( p.created_at, p.updated_at, t.id, + t.tenant_id, t.amount, t.destination_account_id, t.origin_account_id, @@ -31,7 +33,7 @@ const ( t.comment, t.created_at FROM payments p LEFT JOIN transactions t ON t.id = p.transaction_id` - paymentCountQuery = `SELECT COUNT(*) as count FROM payments` + paymentCountQuery = `SELECT COUNT(*) as count FROM payments p LEFT JOIN transactions t ON t.id = p.transaction_id WHERE t.tenant_id = $1` paymentInsertQuery = ` INSERT INTO payments ( counterparty_id, @@ -40,8 +42,8 @@ const ( updated_at ) VALUES ($1, $2, $3, $4) RETURNING id` - paymentUpdateQuery = `UPDATE payments SET counterparty_id = $1, updated_at = $2 WHERE id = $5` - paymentDeleteRelatedQuery = `DELETE FROM transactions WHERE id = $1` + paymentUpdateQuery = `UPDATE payments SET counterparty_id = $1, updated_at = $2 WHERE id = $3` + paymentDeleteRelatedQuery = `DELETE FROM transactions WHERE id = $1 AND tenant_id = $2` paymentDeleteQuery = `DELETE FROM payments WHERE id = $1` ) @@ -52,14 +54,20 @@ func NewPaymentRepository() payment.Repository { } func (g *GormPaymentRepository) GetPaginated(ctx context.Context, params *payment.FindParams) ([]payment.Payment, error) { - var args []interface{} - where := []string{"1 = 1"} + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + where := []string{"t.tenant_id = $1"} + args := []interface{}{tenant.ID} + if params.CreatedAt.To != "" && params.CreatedAt.From != "" { - where = append(where, fmt.Sprintf("p.created_at BETWEEN $%d and $%d", len(where), len(where)+1)) + where = append(where, fmt.Sprintf("p.created_at BETWEEN $%d and $%d", len(args)+1, len(args)+2)) args = append(args, params.CreatedAt.From, params.CreatedAt.To) } if params.Query != "" && params.Field != "" { - where = append(where, fmt.Sprintf("p.%s::VARCHAR ILIKE $%d", params.Field, len(where))) + where = append(where, fmt.Sprintf("p.%s::VARCHAR ILIKE $%d", params.Field, len(args)+1)) args = append(args, "%"+params.Query+"%") } q := repo.Join( @@ -71,23 +79,39 @@ func (g *GormPaymentRepository) GetPaginated(ctx context.Context, params *paymen } func (g *GormPaymentRepository) Count(ctx context.Context) (int64, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get tenant from context: %w", err) + } + tx, err := composables.UseTx(ctx) if err != nil { return 0, err } var count int64 - if err := tx.QueryRow(ctx, paymentCountQuery).Scan(&count); err != nil { + if err := tx.QueryRow(ctx, paymentCountQuery, tenant.ID).Scan(&count); err != nil { return 0, err } return count, nil } func (g *GormPaymentRepository) GetAll(ctx context.Context) ([]payment.Payment, error) { - return g.queryPayments(ctx, paymentFindQuery) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + query := repo.Join(paymentFindQuery, "WHERE t.tenant_id = $1") + return g.queryPayments(ctx, query, tenant.ID) } func (g *GormPaymentRepository) GetByID(ctx context.Context, id uint) (payment.Payment, error) { - payments, err := g.queryPayments(ctx, repo.Join(paymentFindQuery, "WHERE p.id = $1"), id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + payments, err := g.queryPayments(ctx, repo.Join(paymentFindQuery, "WHERE p.id = $1 AND t.tenant_id = $2"), id, tenant.ID) if err != nil { return nil, errors.Wrap(err, "failed to get payment by id") } @@ -98,6 +122,14 @@ func (g *GormPaymentRepository) GetByID(ctx context.Context, id uint) (payment.P } func (g *GormPaymentRepository) Create(ctx context.Context, data payment.Payment) (payment.Payment, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + // Set tenant ID on the domain entity + data.SetTenantID(tenant.ID) + dbPayment, dbTransaction := toDBPayment(data) tx, err := composables.UseTx(ctx) if err != nil { @@ -106,11 +138,12 @@ func (g *GormPaymentRepository) Create(ctx context.Context, data payment.Payment if err := tx.QueryRow( ctx, transactionInsertQuery, + dbTransaction.TenantID, dbTransaction.Amount, dbTransaction.OriginAccountID, dbTransaction.DestinationAccountID, - dbTransaction.AccountingPeriod, dbTransaction.TransactionDate, + dbTransaction.AccountingPeriod, dbTransaction.TransactionType, dbTransaction.Comment, ).Scan(&dbPayment.TransactionID); err != nil { @@ -132,6 +165,14 @@ func (g *GormPaymentRepository) Create(ctx context.Context, data payment.Payment } func (g *GormPaymentRepository) Update(ctx context.Context, data payment.Payment) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + // Set tenant ID on the domain entity + data.SetTenantID(tenant.ID) + dbPayment, dbTransaction := toDBPayment(data) if err := g.execQuery( ctx, @@ -153,10 +194,16 @@ func (g *GormPaymentRepository) Update(ctx context.Context, data payment.Payment dbTransaction.TransactionType, dbTransaction.Comment, dbTransaction.ID, + dbTransaction.TenantID, ) } func (g *GormPaymentRepository) Delete(ctx context.Context, id uint) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + entity, err := g.GetByID(ctx, id) if err != nil { return err @@ -164,7 +211,7 @@ func (g *GormPaymentRepository) Delete(ctx context.Context, id uint) error { if err := g.execQuery(ctx, paymentDeleteQuery, id); err != nil { return err } - return g.execQuery(ctx, paymentDeleteRelatedQuery, entity.TransactionID()) + return g.execQuery(ctx, paymentDeleteRelatedQuery, entity.TransactionID(), tenant.ID) } func (g *GormPaymentRepository) queryPayments(ctx context.Context, query string, args ...interface{}) ([]payment.Payment, error) { @@ -187,6 +234,7 @@ func (g *GormPaymentRepository) queryPayments(ctx context.Context, query string, &paymentRow.CreatedAt, &paymentRow.UpdatedAt, &transactionRow.ID, + &transactionRow.TenantID, &transactionRow.Amount, &transactionRow.DestinationAccountID, &transactionRow.OriginAccountID, diff --git a/modules/finance/infrastructure/persistence/schema/finance-schema.sql b/modules/finance/infrastructure/persistence/schema/finance-schema.sql index a5d76354..b5d63323 100644 --- a/modules/finance/infrastructure/persistence/schema/finance-schema.sql +++ b/modules/finance/infrastructure/persistence/schema/finance-schema.sql @@ -1,12 +1,14 @@ CREATE TABLE counterparty ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, tin varchar(20), name varchar(255) NOT NULL, type VARCHAR(255) NOT NULL, -- customer, supplier, individual legal_type varchar(255) NOT NULL, -- LLC, JSC, etc. legal_address varchar(255), created_at timestamp with time zone DEFAULT now(), - updated_at timestamp with time zone DEFAULT now() + updated_at timestamp with time zone DEFAULT now(), + UNIQUE (tenant_id, tin) ); CREATE TABLE counterparty_contacts ( @@ -23,38 +25,45 @@ CREATE TABLE counterparty_contacts ( CREATE TABLE inventory ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, name varchar(255) NOT NULL, description text, currency_id varchar(3) REFERENCES currencies (code) ON DELETE SET NULL, price numeric(9, 2) NOT NULL, quantity int NOT NULL, created_at timestamp with time zone DEFAULT now(), - updated_at timestamp with time zone DEFAULT now() + updated_at timestamp with time zone DEFAULT now(), + UNIQUE (tenant_id, name) ); CREATE TABLE expense_categories ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, name varchar(255) NOT NULL, description text, amount numeric(9, 2) NOT NULL, amount_currency_id varchar(3) NOT NULL REFERENCES currencies (code) ON DELETE RESTRICT, created_at timestamp with time zone DEFAULT now(), - updated_at timestamp with time zone DEFAULT now() + updated_at timestamp with time zone DEFAULT now(), + UNIQUE (tenant_id, name) ); CREATE TABLE money_accounts ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, name varchar(255) NOT NULL, account_number varchar(255) NOT NULL, description text, balance numeric(9, 2) NOT NULL, balance_currency_id varchar(3) NOT NULL REFERENCES currencies (code) ON DELETE CASCADE, created_at timestamp with time zone DEFAULT now(), - updated_at timestamp with time zone DEFAULT now() + updated_at timestamp with time zone DEFAULT now(), + UNIQUE (tenant_id, account_number) ); CREATE TABLE transactions ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, amount numeric(9, 2) NOT NULL, origin_account_id int REFERENCES money_accounts (id) ON DELETE RESTRICT, destination_account_id int REFERENCES money_accounts (id) ON DELETE RESTRICT, @@ -89,15 +98,25 @@ CREATE INDEX payments_counterparty_id_idx ON payments (counterparty_id); CREATE INDEX payments_transaction_id_idx ON payments (transaction_id); +CREATE INDEX transactions_tenant_id_idx ON transactions (tenant_id); + CREATE INDEX transactions_destination_account_id_idx ON transactions (destination_account_id); CREATE INDEX transactions_origin_account_id_idx ON transactions (origin_account_id); +CREATE INDEX counterparty_tenant_id_idx ON counterparty (tenant_id); + CREATE INDEX counterparty_contacts_counterparty_id_idx ON counterparty_contacts (counterparty_id); CREATE INDEX counterparty_tin_idx ON counterparty (tin); +CREATE INDEX inventory_tenant_id_idx ON inventory (tenant_id); + CREATE INDEX inventory_currency_id_idx ON inventory (currency_id); +CREATE INDEX expense_categories_tenant_id_idx ON expense_categories (tenant_id); + +CREATE INDEX money_accounts_tenant_id_idx ON money_accounts (tenant_id); + CREATE INDEX money_accounts_balance_currency_id_idx ON money_accounts (balance_currency_id); diff --git a/modules/finance/infrastructure/persistence/setup_test.go b/modules/finance/infrastructure/persistence/setup_test.go index 426d7ed9..95f7ab67 100644 --- a/modules/finance/infrastructure/persistence/setup_test.go +++ b/modules/finance/infrastructure/persistence/setup_test.go @@ -2,13 +2,14 @@ package persistence_test import ( "context" + "os" + "testing" + "github.com/iota-uz/iota-sdk/modules" "github.com/iota-uz/iota-sdk/pkg/application" "github.com/iota-uz/iota-sdk/pkg/composables" "github.com/iota-uz/iota-sdk/pkg/testutils" "github.com/jackc/pgx/v5/pgxpool" - "os" - "testing" ) func TestMain(m *testing.M) { @@ -46,11 +47,25 @@ func setupTest(t *testing.T) *testFixtures { }) ctx = composables.WithTx(ctx, tx) + + // Setup application and run migrations app, err := testutils.SetupApplication(pool, modules.BuiltInModules...) if err != nil { t.Fatal(err) } + // Run migrations first to create all tables including tenants + if err := app.Migrations().Run(); err != nil { + t.Fatal(err) + } + + // Create a test tenant and add it to the context + tenant, err := testutils.CreateTestTenant(ctx, pool) + if err != nil { + t.Fatal(err) + } + ctx = composables.WithTenant(ctx, tenant) + return &testFixtures{ ctx: ctx, pool: pool, diff --git a/modules/finance/infrastructure/persistence/transaction_repository.go b/modules/finance/infrastructure/persistence/transaction_repository.go index 1e45b7fa..b8178e5f 100644 --- a/modules/finance/infrastructure/persistence/transaction_repository.go +++ b/modules/finance/infrastructure/persistence/transaction_repository.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/iota-uz/iota-sdk/modules/finance/domain/entities/transaction" "github.com/iota-uz/iota-sdk/modules/finance/infrastructure/persistence/models" "github.com/iota-uz/iota-sdk/pkg/composables" @@ -18,6 +19,7 @@ var ( const ( transactionFindQuery = ` SELECT id, + tenant_id, amount, origin_account_id, destination_account_id, @@ -27,9 +29,10 @@ const ( comment, created_at FROM transactions` - transactionCountQuery = `SELECT COUNT(*) as count FROM transactions` + transactionCountQuery = `SELECT COUNT(*) as count FROM transactions WHERE tenant_id = $1` transactionInsertQuery = ` INSERT INTO transactions ( + tenant_id, amount, origin_account_id, destination_account_id, @@ -38,7 +41,7 @@ const ( transaction_type, comment ) - VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id` + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id` transactionUpdateQuery = ` UPDATE transactions SET amount = $1, @@ -48,8 +51,8 @@ const ( accounting_period = $5, transaction_type = $6, comment = $7 - WHERE id = $8` - transactionDeleteQuery = `DELETE FROM transactions WHERE id = $1` + WHERE id = $8 AND tenant_id = $9` + transactionDeleteQuery = `DELETE FROM transactions WHERE id = $1 AND tenant_id = $2` ) type GormTransactionRepository struct{} @@ -59,10 +62,16 @@ func NewTransactionRepository() transaction.Repository { } func (g *GormTransactionRepository) GetPaginated(ctx context.Context, params *transaction.FindParams) ([]*transaction.Transaction, error) { - where := []string{"1 = 1"} - var args []interface{} + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + where := []string{"tenant_id = $1"} + args := []interface{}{tenant.ID} + if params.CreatedAt.To != "" && params.CreatedAt.From != "" { - where = append(where, fmt.Sprintf("created_at BETWEEN $%d and $%d", len(where), len(where)+1)) + where = append(where, fmt.Sprintf("created_at BETWEEN $%d and $%d", len(args)+1, len(args)+2)) args = append(args, params.CreatedAt.From, params.CreatedAt.To) } q := repo.Join( @@ -75,23 +84,39 @@ func (g *GormTransactionRepository) GetPaginated(ctx context.Context, params *tr } func (g *GormTransactionRepository) Count(ctx context.Context) (int64, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get tenant from context: %w", err) + } + tx, err := composables.UseTx(ctx) if err != nil { return 0, err } var count int64 - if err := tx.QueryRow(ctx, transactionCountQuery).Scan(&count); err != nil { + if err := tx.QueryRow(ctx, transactionCountQuery, tenant.ID).Scan(&count); err != nil { return 0, err } return count, nil } func (g *GormTransactionRepository) GetAll(ctx context.Context) ([]*transaction.Transaction, error) { - return g.queryTransactions(ctx, transactionFindQuery) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + query := repo.Join(transactionFindQuery, "WHERE tenant_id = $1") + return g.queryTransactions(ctx, query, tenant.ID) } func (g *GormTransactionRepository) GetByID(ctx context.Context, id uint) (*transaction.Transaction, error) { - transactions, err := g.queryTransactions(ctx, repo.Join(transactionFindQuery, "WHERE id = $1"), id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + transactions, err := g.queryTransactions(ctx, repo.Join(transactionFindQuery, "WHERE id = $1 AND tenant_id = $2"), id, tenant.ID) if err != nil { return nil, err } @@ -102,12 +127,19 @@ func (g *GormTransactionRepository) GetByID(ctx context.Context, id uint) (*tran } func (g *GormTransactionRepository) Create(ctx context.Context, data *transaction.Transaction) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + data.TenantID = tenant.ID entity := toDBTransaction(data) tx, err := composables.UseTx(ctx) if err != nil { return err } args := []interface{}{ + entity.TenantID, entity.Amount, entity.OriginAccountID, entity.DestinationAccountID, @@ -120,6 +152,12 @@ func (g *GormTransactionRepository) Create(ctx context.Context, data *transactio } func (g *GormTransactionRepository) Update(ctx context.Context, data *transaction.Transaction) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + data.TenantID = tenant.ID dbTransaction := toDBTransaction(data) args := []interface{}{ dbTransaction.Amount, @@ -130,12 +168,18 @@ func (g *GormTransactionRepository) Update(ctx context.Context, data *transactio dbTransaction.TransactionType, dbTransaction.Comment, dbTransaction.ID, + dbTransaction.TenantID, } return g.execQuery(ctx, transactionUpdateQuery, args...) } func (g *GormTransactionRepository) Delete(ctx context.Context, id uint) error { - return g.execQuery(ctx, transactionDeleteQuery, id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + return g.execQuery(ctx, transactionDeleteQuery, id, tenant.ID) } func (g *GormTransactionRepository) queryTransactions(ctx context.Context, query string, args ...interface{}) ([]*transaction.Transaction, error) { @@ -154,6 +198,7 @@ func (g *GormTransactionRepository) queryTransactions(ctx context.Context, query r := &models.Transaction{} if err := rows.Scan( &r.ID, + &r.TenantID, &r.Amount, &r.OriginAccountID, &r.DestinationAccountID, diff --git a/modules/finance/services/payment_service_test.go b/modules/finance/services/payment_service_test.go index e5a4a3a9..a3c4860e 100644 --- a/modules/finance/services/payment_service_test.go +++ b/modules/finance/services/payment_service_test.go @@ -2,10 +2,11 @@ package services_test import ( "context" - "github.com/iota-uz/iota-sdk/modules/core/domain/value_objects/country" "testing" "time" + "github.com/iota-uz/iota-sdk/modules/core/domain/value_objects/country" + "github.com/iota-uz/iota-sdk/modules" "github.com/iota-uz/iota-sdk/modules/core/domain/entities/permission" "github.com/iota-uz/iota-sdk/modules/core/domain/value_objects/tax" @@ -65,6 +66,18 @@ func setupTest(t *testing.T, permissions ...*permission.Permission) *testFixture publisher := eventbus.NewEventPublisher(logging.ConsoleLogger(logrus.WarnLevel)) app := setupApplication(t, pool, publisher) + // Run migrations to ensure all tables are created (including tenants table) + if err := app.Migrations().Run(); err != nil { + t.Fatal(err) + } + + // Create a test tenant and add it to the context (after migrations have created the table) + tenant, err := testutils.CreateTestTenant(ctx, pool) + if err != nil { + t.Fatal(err) + } + ctx = composables.WithTenant(ctx, tenant) + return &testFixtures{ ctx: ctx, pool: pool, @@ -115,6 +128,7 @@ func setupTestData(ctx context.Context, t *testing.T, f *testFixtures) { t.Fatal(err) } + // Create the counterparty - the repository itself will set the tenant ID _, err = counterpartyRepo.Create(ctx, counterparty.New( tin, "Test", diff --git a/modules/hrm/domain/aggregates/employee/employee.go b/modules/hrm/domain/aggregates/employee/employee.go index ba63544c..1ed120d4 100644 --- a/modules/hrm/domain/aggregates/employee/employee.go +++ b/modules/hrm/domain/aggregates/employee/employee.go @@ -1,9 +1,11 @@ package employee import ( - "github.com/iota-uz/iota-sdk/modules/core/domain/value_objects/internet" "time" + "github.com/google/uuid" + "github.com/iota-uz/iota-sdk/modules/core/domain/value_objects/internet" + "github.com/iota-uz/iota-sdk/modules/core/domain/entities/passport" "github.com/iota-uz/iota-sdk/modules/core/domain/value_objects/money" "github.com/iota-uz/iota-sdk/modules/core/domain/value_objects/tax" @@ -16,6 +18,7 @@ type Language interface { type Employee interface { ID() uint + TenantID() uuid.UUID FirstName() string LastName() string MiddleName() string @@ -45,6 +48,7 @@ type Employee interface { func NewWithID( id uint, + tenantID uuid.UUID, firstName, lastName, middleName, phone string, email internet.Email, salary money.Amount, @@ -59,6 +63,7 @@ func NewWithID( ) Employee { return &employee{ id: id, + tenantID: tenantID, firstName: firstName, lastName: lastName, middleName: middleName, @@ -91,6 +96,7 @@ func New( ) (Employee, error) { return &employee{ id: 0, + tenantID: uuid.Nil, // Will be set in repository firstName: firstName, lastName: lastName, middleName: middleName, @@ -111,6 +117,7 @@ func New( type employee struct { id uint + tenantID uuid.UUID firstName string lastName string middleName string @@ -134,6 +141,10 @@ func (e *employee) ID() uint { return e.id } +func (e *employee) TenantID() uuid.UUID { + return e.tenantID +} + func (e *employee) FirstName() string { return e.firstName } diff --git a/modules/hrm/domain/aggregates/employee/employee_update_dto.go b/modules/hrm/domain/aggregates/employee/employee_update_dto.go index 8d2a17f2..05a038e3 100644 --- a/modules/hrm/domain/aggregates/employee/employee_update_dto.go +++ b/modules/hrm/domain/aggregates/employee/employee_update_dto.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/core/domain/entities/currency" "github.com/iota-uz/iota-sdk/modules/core/domain/value_objects/internet" "github.com/iota-uz/iota-sdk/modules/core/domain/value_objects/money" @@ -119,6 +120,7 @@ func (d *UpdateDTO) ToEntity(id uint) (Employee, error) { } return NewWithID( id, + uuid.Nil, d.FirstName, d.LastName, d.MiddleName, diff --git a/modules/hrm/domain/entities/position/position.go b/modules/hrm/domain/entities/position/position.go index 1eaf8f39..ffea0ad6 100644 --- a/modules/hrm/domain/entities/position/position.go +++ b/modules/hrm/domain/entities/position/position.go @@ -6,6 +6,7 @@ import ( type Position struct { ID uint + TenantID string Name string Description string CreatedAt time.Time diff --git a/modules/hrm/infrastructure/persistence/employee_repository.go b/modules/hrm/infrastructure/persistence/employee_repository.go index 32027836..b9c18387 100644 --- a/modules/hrm/infrastructure/persistence/employee_repository.go +++ b/modules/hrm/infrastructure/persistence/employee_repository.go @@ -19,6 +19,7 @@ var ( const ( employeeFindQuery = ` SELECT e.id, + e.tenant_id, e.first_name, e.last_name, e.middle_name, @@ -40,12 +41,12 @@ const ( em.resignation_date FROM employees e LEFT JOIN employee_meta em ON e.id = em.employee_id` - employeeCountQuery = `SELECT COUNT(*) as count FROM employees` + employeeCountQuery = `SELECT COUNT(*) as count FROM employees WHERE tenant_id = $1` employeeInsertQuery = ` INSERT INTO employees ( - first_name, last_name, middle_name, email, phone, salary, salary_currency_id, + tenant_id, first_name, last_name, middle_name, email, phone, salary, salary_currency_id, hourly_rate, coefficient, avatar_id, created_at, updated_at - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING id` + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) RETURNING id` employeeMetaInsertQuery = ` INSERT INTO employee_meta ( employee_id, @@ -64,7 +65,7 @@ const ( SET first_name = $1, last_name = $2, middle_name = $3, email = $4, phone = $5, salary = $6, salary_currency_id = $7, hourly_rate = $8, coefficient = $9, avatar_id = $10, updated_at = $11 - WHERE id = $12` + WHERE id = $12 AND tenant_id = $13` employeeUpdateMetaQuery = ` UPDATE employee_meta @@ -72,7 +73,7 @@ const ( birth_date = $5, hire_date = $6, resignation_date = $7 WHERE employee_id = $8` - employeeDeleteQuery = `DELETE FROM employees WHERE id = $1` + employeeDeleteQuery = `DELETE FROM employees WHERE id = $1 AND tenant_id = $2` employeeMetaDeleteQuery = `DELETE FROM employee_meta WHERE employee_id = $1` ) @@ -83,11 +84,17 @@ func NewEmployeeRepository() employee.Repository { } func (g *GormEmployeeRepository) GetPaginated(ctx context.Context, params *employee.FindParams) ([]employee.Employee, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + var args []interface{} - where := []string{"1 = 1"} + where := []string{"e.tenant_id = $1"} + args = append(args, tenant.ID) if params.Query != "" && params.Field != "" { - where = append(where, fmt.Sprintf("e.%s::VARCHAR ILIKE $%d", params.Field, len(where))) + where = append(where, fmt.Sprintf("e.%s::VARCHAR ILIKE $%d", params.Field, len(args)+1)) args = append(args, "%"+params.Query+"%") } @@ -100,23 +107,36 @@ func (g *GormEmployeeRepository) GetPaginated(ctx context.Context, params *emplo } func (g *GormEmployeeRepository) Count(ctx context.Context) (int64, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get tenant from context: %w", err) + } + tx, err := composables.UseTx(ctx) if err != nil { return 0, err } var count int64 - if err := tx.QueryRow(ctx, employeeCountQuery).Scan(&count); err != nil { + if err := tx.QueryRow(ctx, employeeCountQuery, tenant.ID).Scan(&count); err != nil { return 0, err } return count, nil } func (g *GormEmployeeRepository) GetAll(ctx context.Context) ([]employee.Employee, error) { - return g.queryEmployees(ctx, employeeFindQuery) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + return g.queryEmployees(ctx, repo.Join(employeeFindQuery, "WHERE e.tenant_id = $1"), tenant.ID) } func (g *GormEmployeeRepository) GetByID(ctx context.Context, id uint) (employee.Employee, error) { - employees, err := g.queryEmployees(ctx, repo.Join(employeeFindQuery, "WHERE e.id = $1"), id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + employees, err := g.queryEmployees(ctx, repo.Join(employeeFindQuery, "WHERE e.id = $1 AND e.tenant_id = $2"), id, tenant.ID) if err != nil { return nil, errors.Wrap(err, "failed to get employee by id") } @@ -127,7 +147,14 @@ func (g *GormEmployeeRepository) GetByID(ctx context.Context, id uint) (employee } func (g *GormEmployeeRepository) Create(ctx context.Context, data employee.Employee) (employee.Employee, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + dbEmployee, dbMeta := toDBEmployee(data) + dbEmployee.TenantID = tenant.ID.String() // Set tenant ID from context + tx, err := composables.UseTx(ctx) if err != nil { return nil, err @@ -135,6 +162,7 @@ func (g *GormEmployeeRepository) Create(ctx context.Context, data employee.Emplo row := tx.QueryRow( ctx, employeeInsertQuery, + dbEmployee.TenantID, dbEmployee.FirstName, dbEmployee.LastName, dbEmployee.MiddleName, @@ -169,6 +197,11 @@ func (g *GormEmployeeRepository) Create(ctx context.Context, data employee.Emplo } func (g *GormEmployeeRepository) Update(ctx context.Context, data employee.Employee) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + dbEmployee, dbMeta := toDBEmployee(data) if err := g.execQuery( ctx, @@ -185,6 +218,7 @@ func (g *GormEmployeeRepository) Update(ctx context.Context, data employee.Emplo dbEmployee.AvatarID, dbEmployee.UpdatedAt, dbEmployee.ID, + tenant.ID, ); err != nil { return err } @@ -203,10 +237,15 @@ func (g *GormEmployeeRepository) Update(ctx context.Context, data employee.Emplo } func (g *GormEmployeeRepository) Delete(ctx context.Context, id uint) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + if err := g.execQuery(ctx, employeeMetaDeleteQuery, id); err != nil { return err } - return g.execQuery(ctx, employeeDeleteQuery, id) + return g.execQuery(ctx, employeeDeleteQuery, id, tenant.ID) } func (g *GormEmployeeRepository) queryEmployees(ctx context.Context, query string, args ...interface{}) ([]employee.Employee, error) { @@ -225,6 +264,7 @@ func (g *GormEmployeeRepository) queryEmployees(ctx context.Context, query strin var metaRow models.EmployeeMeta if err := rows.Scan( &employeeRow.ID, + &employeeRow.TenantID, &employeeRow.FirstName, &employeeRow.LastName, &employeeRow.MiddleName, diff --git a/modules/hrm/infrastructure/persistence/hrm_mappers.go b/modules/hrm/infrastructure/persistence/hrm_mappers.go index fa2c5739..def0fcec 100644 --- a/modules/hrm/infrastructure/persistence/hrm_mappers.go +++ b/modules/hrm/infrastructure/persistence/hrm_mappers.go @@ -1,6 +1,7 @@ package persistence import ( + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/core/domain/entities/currency" "github.com/iota-uz/iota-sdk/modules/core/domain/value_objects/country" "github.com/iota-uz/iota-sdk/modules/core/domain/value_objects/internet" @@ -33,8 +34,13 @@ func toDomainEmployee(dbEmployee *models.Employee, dbMeta *models.EmployeeMeta) if err != nil { return nil, err } + tenantID, err := uuid.Parse(dbEmployee.TenantID) + if err != nil { + return nil, err + } return employee.NewWithID( dbEmployee.ID, + tenantID, dbEmployee.FirstName, dbEmployee.LastName, dbEmployee.MiddleName.String, @@ -57,6 +63,7 @@ func toDBEmployee(entity employee.Employee) (*models.Employee, *models.EmployeeM salary := entity.Salary() dbEmployee := &models.Employee{ ID: entity.ID(), + TenantID: entity.TenantID().String(), FirstName: entity.FirstName(), LastName: entity.LastName(), MiddleName: mapping.ValueToSQLNullString(entity.MiddleName()), @@ -84,6 +91,7 @@ func toDBEmployee(entity employee.Employee) (*models.Employee, *models.EmployeeM func toDomainPosition(dbPosition *models.Position) (*position.Position, error) { return &position.Position{ ID: dbPosition.ID, + TenantID: dbPosition.TenantID, Name: dbPosition.Name, Description: dbPosition.Description.String, CreatedAt: dbPosition.CreatedAt, @@ -94,6 +102,7 @@ func toDomainPosition(dbPosition *models.Position) (*position.Position, error) { func toDBPosition(position *position.Position) *models.Position { return &models.Position{ ID: position.ID, + TenantID: position.TenantID, Name: position.Name, Description: mapping.ValueToSQLNullString(position.Description), CreatedAt: position.CreatedAt, diff --git a/modules/hrm/infrastructure/persistence/models/models.go b/modules/hrm/infrastructure/persistence/models/models.go index efeb145a..14203c71 100644 --- a/modules/hrm/infrastructure/persistence/models/models.go +++ b/modules/hrm/infrastructure/persistence/models/models.go @@ -7,6 +7,7 @@ import ( type Position struct { ID uint + TenantID string Name string Description sql.NullString CreatedAt time.Time @@ -15,6 +16,7 @@ type Position struct { type Employee struct { ID uint + TenantID string FirstName string LastName string MiddleName sql.NullString diff --git a/modules/hrm/infrastructure/persistence/position_repository.go b/modules/hrm/infrastructure/persistence/position_repository.go index ba232140..d85bf0b9 100644 --- a/modules/hrm/infrastructure/persistence/position_repository.go +++ b/modules/hrm/infrastructure/persistence/position_repository.go @@ -25,17 +25,22 @@ func NewPositionRepository() position.Repository { func (g *GormPositionRepository) GetPaginated( ctx context.Context, params *position.FindParams, ) ([]*position.Position, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + pool, err := composables.UseTx(ctx) if err != nil { return nil, err } - where, args := []string{"1 = 1"}, []interface{}{} + where, args := []string{"tenant_id = $1"}, []interface{}{tenant.ID} if params.ID != 0 { where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, params.ID) } rows, err := pool.Query(ctx, ` - SELECT id, name, description, created_at, updated_at FROM positions + SELECT id, tenant_id, name, description, created_at, updated_at FROM positions WHERE `+strings.Join(where, " AND ")+` `+repo.FormatLimitOffset(params.Limit, params.Offset)+` `, args...) @@ -48,6 +53,7 @@ func (g *GormPositionRepository) GetPaginated( var p models.Position if err := rows.Scan( &p.ID, + &p.TenantID, &p.Name, &p.Description, &p.CreatedAt, @@ -70,14 +76,19 @@ func (g *GormPositionRepository) GetPaginated( } func (g *GormPositionRepository) Count(ctx context.Context) (int64, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get tenant from context: %w", err) + } + pool, err := composables.UseTx(ctx) if err != nil { return 0, err } var count int64 if err := pool.QueryRow(ctx, ` - SELECT COUNT(*) as count FROM positions - `).Scan(&count); err != nil { + SELECT COUNT(*) as count FROM positions WHERE tenant_id = $1 + `, tenant.ID).Scan(&count); err != nil { return 0, err } return count, nil @@ -104,43 +115,63 @@ func (g *GormPositionRepository) GetByID(ctx context.Context, id int64) (*positi } func (g *GormPositionRepository) Create(ctx context.Context, data *position.Position) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + tx, err := composables.UseTx(ctx) if err != nil { return err } + dbRow := toDBPosition(data) + dbRow.TenantID = tenant.ID.String() + if err := tx.QueryRow(ctx, ` - INSERT INTO positions (name, description) VALUES ($1, $2) - `, dbRow.Name, dbRow.Description).Scan(&data.ID); err != nil { + INSERT INTO positions (tenant_id, name, description) VALUES ($1, $2, $3) RETURNING id + `, dbRow.TenantID, dbRow.Name, dbRow.Description).Scan(&data.ID); err != nil { return err } + data.TenantID = tenant.ID.String() return nil } func (g *GormPositionRepository) Update(ctx context.Context, data *position.Position) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + tx, err := composables.UseTx(ctx) if err != nil { return err } dbRow := toDBPosition(data) + if _, err := tx.Exec(ctx, ` - UPDATE positions + UPDATE positions SET name = $1, description = $2 - WHERE id = $3 - `, dbRow.Name, dbRow.Description, dbRow.ID); err != nil { + WHERE id = $3 AND tenant_id = $4 + `, dbRow.Name, dbRow.Description, dbRow.ID, tenant.ID); err != nil { return err } return nil } func (g *GormPositionRepository) Delete(ctx context.Context, id int64) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + tx, err := composables.UseTx(ctx) if err != nil { return err } if _, err := tx.Exec(ctx, ` - DELETE FROM positions WHERE id = $1 - `, id); err != nil { + DELETE FROM positions WHERE id = $1 AND tenant_id = $2 + `, id, tenant.ID); err != nil { return err } return nil diff --git a/modules/hrm/infrastructure/persistence/schema/hrm-schema.sql b/modules/hrm/infrastructure/persistence/schema/hrm-schema.sql index cc79b982..dc4cb6f0 100644 --- a/modules/hrm/infrastructure/persistence/schema/hrm-schema.sql +++ b/modules/hrm/infrastructure/persistence/schema/hrm-schema.sql @@ -1,17 +1,20 @@ CREATE TABLE positions ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, name varchar(255) NOT NULL, description text, created_at timestamp with time zone DEFAULT now(), - updated_at timestamp with time zone DEFAULT now() + updated_at timestamp with time zone DEFAULT now(), + UNIQUE (tenant_id, name) ); CREATE TABLE employees ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, first_name varchar(255) NOT NULL, last_name varchar(255) NOT NULL, middle_name varchar(255), - email varchar(255) NOT NULL UNIQUE, + email varchar(255) NOT NULL, phone varchar(255), salary numeric(9, 2) NOT NULL, salary_currency_id varchar(3) REFERENCES currencies (code) ON DELETE SET NULL, @@ -19,7 +22,9 @@ CREATE TABLE employees ( coefficient float NOT NULL, avatar_id int REFERENCES uploads (id) ON DELETE SET NULL, created_at timestamp with time zone DEFAULT now(), - updated_at timestamp with time zone DEFAULT now() + updated_at timestamp with time zone DEFAULT now(), + UNIQUE (tenant_id, email), + UNIQUE (tenant_id, phone) ); CREATE TABLE employee_positions ( @@ -49,3 +54,15 @@ CREATE TABLE employee_contacts ( updated_at timestamp with time zone DEFAULT now() ); +CREATE INDEX positions_tenant_id_idx ON positions (tenant_id); + +CREATE INDEX employees_tenant_id_idx ON employees (tenant_id); + +CREATE INDEX employees_first_name_idx ON employees (first_name); + +CREATE INDEX employees_last_name_idx ON employees (last_name); + +CREATE INDEX employees_email_idx ON employees (email); + +CREATE INDEX employees_phone_idx ON employees (phone); + diff --git a/modules/logging/infrastructure/persistence/schema/logging-schema.sql b/modules/logging/infrastructure/persistence/schema/logging-schema.sql index a7e24916..f58107c8 100644 --- a/modules/logging/infrastructure/persistence/schema/logging-schema.sql +++ b/modules/logging/infrastructure/persistence/schema/logging-schema.sql @@ -1,5 +1,6 @@ CREATE TABLE authentication_logs ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, user_id integer NOT NULL CONSTRAINT fk_user_id REFERENCES users (id) ON DELETE CASCADE, ip varchar(255) NOT NULL, user_agent varchar(255) NOT NULL, @@ -8,6 +9,7 @@ CREATE TABLE authentication_logs ( CREATE TABLE action_logs ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, method varchar(255) NOT NULL, path varchar(255) NOT NULL, user_id int REFERENCES users (id) ON DELETE SET NULL, @@ -18,8 +20,12 @@ CREATE TABLE action_logs ( created_at timestamp with time zone DEFAULT now() ); +CREATE INDEX action_logs_tenant_id_idx ON action_logs (tenant_id); + CREATE INDEX action_log_user_id_idx ON action_logs (user_id); +CREATE INDEX authentication_logs_tenant_id_idx ON authentication_logs (tenant_id); + CREATE INDEX authentication_logs_user_id_idx ON authentication_logs (user_id); CREATE INDEX authentication_logs_created_at_idx ON authentication_logs (created_at); diff --git a/modules/warehouse/domain/aggregates/order/order.go b/modules/warehouse/domain/aggregates/order/order.go index 7f604486..49ebd395 100644 --- a/modules/warehouse/domain/aggregates/order/order.go +++ b/modules/warehouse/domain/aggregates/order/order.go @@ -1,19 +1,23 @@ package order import ( + "time" + + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/warehouse/domain/aggregates/position" "github.com/iota-uz/iota-sdk/modules/warehouse/domain/aggregates/product" - "time" ) type Order interface { ID() uint + TenantID() uuid.UUID Type() Type Status() Status Items() []Item CreatedAt() time.Time SetID(id uint) + SetTenantID(id uuid.UUID) AddItem(position *position.Position, products ...*product.Product) error Complete() error diff --git a/modules/warehouse/domain/aggregates/order/order_impl.go b/modules/warehouse/domain/aggregates/order/order_impl.go index 342ec352..032ed3fe 100644 --- a/modules/warehouse/domain/aggregates/order/order_impl.go +++ b/modules/warehouse/domain/aggregates/order/order_impl.go @@ -1,15 +1,18 @@ package order import ( + "time" + + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/warehouse/domain/aggregates/position" "github.com/iota-uz/iota-sdk/modules/warehouse/domain/aggregates/product" - "time" ) func New(orderType Type, status Status) Order { return &orderImpl{ _type: orderType, status: status, + tenantID: uuid.Nil, // Will be set in repository items: make([]Item, 0), createdAt: time.Now(), } @@ -18,6 +21,7 @@ func New(orderType Type, status Status) Order { func NewWithID(id uint, orderType Type, status Status, createdAt time.Time) Order { return &orderImpl{ id: id, + tenantID: uuid.Nil, // Will be set in repository _type: orderType, status: status, items: make([]Item, 0), @@ -27,6 +31,7 @@ func NewWithID(id uint, orderType Type, status Status, createdAt time.Time) Orde type orderImpl struct { id uint + tenantID uuid.UUID _type Type status Status items []Item @@ -37,10 +42,18 @@ func (o *orderImpl) SetID(id uint) { o.id = id } +func (o *orderImpl) SetTenantID(id uuid.UUID) { + o.tenantID = id +} + func (o *orderImpl) ID() uint { return o.id } +func (o *orderImpl) TenantID() uuid.UUID { + return o.tenantID +} + func (o *orderImpl) Type() Type { return o._type } diff --git a/modules/warehouse/domain/aggregates/position/position.go b/modules/warehouse/domain/aggregates/position/position.go index 0b4b7a00..d3646978 100644 --- a/modules/warehouse/domain/aggregates/position/position.go +++ b/modules/warehouse/domain/aggregates/position/position.go @@ -1,13 +1,16 @@ package position import ( + "time" + + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/core/domain/entities/upload" "github.com/iota-uz/iota-sdk/modules/warehouse/domain/entities/unit" - "time" ) type Position struct { ID uint + TenantID uuid.UUID Title string Barcode string UnitID uint diff --git a/modules/warehouse/domain/aggregates/product/product.go b/modules/warehouse/domain/aggregates/product/product.go index b43c0e1d..b1526638 100644 --- a/modules/warehouse/domain/aggregates/product/product.go +++ b/modules/warehouse/domain/aggregates/product/product.go @@ -1,8 +1,10 @@ package product import ( - "github.com/iota-uz/iota-sdk/modules/warehouse/domain/aggregates/position" "time" + + "github.com/google/uuid" + "github.com/iota-uz/iota-sdk/modules/warehouse/domain/aggregates/position" ) func New(rfid string, positionID uint, status Status, position *position.Position) *Product { @@ -18,6 +20,7 @@ func New(rfid string, positionID uint, status Status, position *position.Positio type Product struct { ID uint + TenantID uuid.UUID PositionID uint Rfid string Status Status diff --git a/modules/warehouse/domain/entities/inventory/inventory.go b/modules/warehouse/domain/entities/inventory/inventory.go index 9294df80..a88f8eca 100644 --- a/modules/warehouse/domain/entities/inventory/inventory.go +++ b/modules/warehouse/domain/entities/inventory/inventory.go @@ -1,14 +1,17 @@ package inventory import ( - "github.com/iota-uz/iota-sdk/modules/core/domain/aggregates/user" "time" + "github.com/google/uuid" + "github.com/iota-uz/iota-sdk/modules/core/domain/aggregates/user" + "github.com/iota-uz/iota-sdk/modules/warehouse/domain/aggregates/position" ) type Check struct { ID uint + TenantID uuid.UUID Status Status Name string Results []*CheckResult @@ -23,6 +26,7 @@ type Check struct { func (c *Check) AddResult(positionID uint, expected, actual int) { c.Results = append(c.Results, &CheckResult{ PositionID: positionID, + TenantID: c.TenantID, ExpectedQuantity: expected, ActualQuantity: actual, Difference: expected - actual, @@ -39,6 +43,7 @@ type Position struct { type CheckResult struct { ID uint + TenantID uuid.UUID PositionID uint Position *position.Position ExpectedQuantity int diff --git a/modules/warehouse/domain/entities/unit/unit.go b/modules/warehouse/domain/entities/unit/unit.go index bfd38ee5..c73980dd 100644 --- a/modules/warehouse/domain/entities/unit/unit.go +++ b/modules/warehouse/domain/entities/unit/unit.go @@ -2,10 +2,13 @@ package unit import ( "time" + + "github.com/google/uuid" ) type Unit struct { ID uint + TenantID uuid.UUID Title string ShortTitle string CreatedAt time.Time diff --git a/modules/warehouse/infrastructure/persistence/inventory_repository.go b/modules/warehouse/infrastructure/persistence/inventory_repository.go index 5c38df04..9cdb50ac 100644 --- a/modules/warehouse/infrastructure/persistence/inventory_repository.go +++ b/modules/warehouse/infrastructure/persistence/inventory_repository.go @@ -5,9 +5,10 @@ import ( "database/sql" "errors" "fmt" - "github.com/iota-uz/iota-sdk/pkg/repo" "strings" + "github.com/iota-uz/iota-sdk/pkg/repo" + "github.com/iota-uz/iota-sdk/modules/core/domain/aggregates/user" "github.com/iota-uz/iota-sdk/modules/warehouse/domain/aggregates/position" "github.com/iota-uz/iota-sdk/modules/warehouse/domain/entities/inventory" @@ -40,7 +41,13 @@ func (g *GormInventoryRepository) GetPaginated( if err != nil { return nil, err } - where, args := []string{"1 = 1"}, []interface{}{} + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + where, args := []string{"ic.tenant_id = $1"}, []interface{}{tenant.ID} if params.ID != 0 { where, args = append(where, fmt.Sprintf("ic.id = $%d", len(args)+1)), append(args, params.ID) } @@ -58,7 +65,7 @@ func (g *GormInventoryRepository) GetPaginated( } rows, err := pool.Query(ctx, ` - SELECT ic.id, status, name, ic.created_at, ic.finished_at, ic.created_by_id, ic.finished_by_id + SELECT ic.id, ic.tenant_id, status, name, ic.created_at, ic.finished_at, ic.created_by_id, ic.finished_by_id FROM inventory_checks ic WHERE `+strings.Join(where, " AND ")+` ORDER BY id DESC @@ -76,6 +83,7 @@ func (g *GormInventoryRepository) GetPaginated( var finishedByID sql.NullInt32 if err := rows.Scan( &check.ID, + &check.TenantID, &check.Status, &check.Name, &check.CreatedAt, @@ -128,13 +136,21 @@ func (g *GormInventoryRepository) Positions(ctx context.Context) ([]*inventory.P if err != nil { return nil, err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + var entities []*models.InventoryPosition sql := ` SELECT warehouse_positions.id, warehouse_positions.title, COUNT(warehouse_products.id) quantity, array_agg(warehouse_products.rfid) rfid_tags - FROM warehouse_positions JOIN warehouse_products ON warehouse_positions.id = warehouse_products.position_id + FROM warehouse_positions + JOIN warehouse_products ON warehouse_positions.id = warehouse_products.position_id + WHERE warehouse_positions.tenant_id = $1 GROUP BY warehouse_positions.id; ` - rows, err := tx.Query(ctx, sql) + rows, err := tx.Query(ctx, sql, tenant.ID) if err != nil { return nil, err } @@ -160,10 +176,16 @@ func (g *GormInventoryRepository) Count(ctx context.Context) (uint, error) { if err != nil { return 0, err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get tenant from context: %w", err) + } + var count uint if err := pool.QueryRow(ctx, ` - SELECT COUNT(*) as count FROM inventory_checks - `).Scan(&count); err != nil { + SELECT COUNT(*) as count FROM inventory_checks WHERE tenant_id = $1 + `, tenant.ID).Scan(&count); err != nil { return 0, err } return count, nil @@ -214,22 +236,31 @@ func (g *GormInventoryRepository) Create(ctx context.Context, data *inventory.Ch if err != nil { return err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + // Set tenant ID in domain entity + data.TenantID = tenant.ID + dbRow, err := mappers.ToDBInventoryCheck(data) if err != nil { return err } if err := tx.QueryRow(ctx, ` - INSERT INTO inventory_checks (status, name, created_by_id) - VALUES ($1, $2, $3) RETURNING id - `, dbRow.Status, dbRow.Name, dbRow.CreatedByID).Scan(&data.ID); err != nil { + INSERT INTO inventory_checks (tenant_id, status, name, created_by_id) + VALUES ($1, $2, $3, $4) RETURNING id + `, dbRow.TenantID, dbRow.Status, dbRow.Name, dbRow.CreatedByID).Scan(&data.ID); err != nil { return err } if results := dbRow.Results; results != nil { for _, result := range results { if _, err := tx.Exec(ctx, ` - INSERT INTO inventory_check_results (inventory_check_id, position_id, expected_quantity, actual_quantity, difference) VALUES ($1, $2, $3, $4, $5) - `, data.ID, result.PositionID, result.ExpectedQuantity, result.ActualQuantity, result.Difference); err != nil { + INSERT INTO inventory_check_results (tenant_id, inventory_check_id, position_id, expected_quantity, actual_quantity, difference) VALUES ($1, $2, $3, $4, $5, $6) + `, tenant.ID, data.ID, result.PositionID, result.ExpectedQuantity, result.ActualQuantity, result.Difference); err != nil { return err } } @@ -242,14 +273,23 @@ func (g *GormInventoryRepository) Update(ctx context.Context, data *inventory.Ch if err != nil { return err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + // Set tenant ID in domain entity + data.TenantID = tenant.ID + dbRow, err := mappers.ToDBInventoryCheck(data) if err != nil { return err } if _, err := tx.Exec(ctx, ` UPDATE inventory_checks ic SET name = COALESCE(NULLIF($1, ''), ic.name) - WHERE ic.id = $2 - `, dbRow.Name, dbRow.ID); err != nil { + WHERE ic.id = $2 AND ic.tenant_id = $3 + `, dbRow.Name, dbRow.ID, dbRow.TenantID); err != nil { return err } return nil @@ -260,7 +300,13 @@ func (g *GormInventoryRepository) Delete(ctx context.Context, id uint) error { if err != nil { return err } - if _, err := tx.Exec(ctx, `DELETE FROM inventory_checks WHERE id = $1`, id); err != nil { + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + if _, err := tx.Exec(ctx, `DELETE FROM inventory_checks WHERE id = $1 AND tenant_id = $2`, id, tenant.ID); err != nil { return err } return nil @@ -280,7 +326,13 @@ func (g *GormInventoryRepository) getCheckResults( if err != nil { return nil, err } - where, args := []string{"1 = 1"}, []interface{}{} + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + where, args := []string{"icr.tenant_id = $1"}, []interface{}{tenant.ID} if params.id != 0 { where, args = append(where, fmt.Sprintf("ic.id = $%d", len(args)+1)), append(args, params.id) } @@ -294,7 +346,7 @@ func (g *GormInventoryRepository) getCheckResults( } rows, err := pool.Query(ctx, ` - SELECT id, inventory_check_id, position_id, expected_quantity, actual_quantity, difference, created_at + SELECT id, tenant_id, inventory_check_id, position_id, expected_quantity, actual_quantity, difference, created_at FROM inventory_check_results icr WHERE `+strings.Join(where, " AND ")+` ORDER BY id DESC`, args...) @@ -307,6 +359,7 @@ func (g *GormInventoryRepository) getCheckResults( var result models.InventoryCheckResult if err := rows.Scan( &result.ID, + &result.TenantID, &result.InventoryCheckID, &result.PositionID, &result.ExpectedQuantity, diff --git a/modules/warehouse/infrastructure/persistence/mappers/order_mappers.go b/modules/warehouse/infrastructure/persistence/mappers/order_mappers.go index ca4527bd..b79bf58c 100644 --- a/modules/warehouse/infrastructure/persistence/mappers/order_mappers.go +++ b/modules/warehouse/infrastructure/persistence/mappers/order_mappers.go @@ -1,6 +1,7 @@ package mappers import ( + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/warehouse/domain/aggregates/order" "github.com/iota-uz/iota-sdk/modules/warehouse/infrastructure/persistence/models" ) @@ -19,6 +20,7 @@ func ToDBOrder(entity order.Order) (*models.WarehouseOrder, []*models.WarehouseP dbOrder := &models.WarehouseOrder{ ID: entity.ID(), + TenantID: entity.TenantID().String(), Status: string(entity.Status()), Type: string(entity.Type()), CreatedAt: entity.CreatedAt(), @@ -35,5 +37,11 @@ func ToDomainOrder(dbOrder *models.WarehouseOrder) (order.Order, error) { if err != nil { return nil, err } - return order.NewWithID(dbOrder.ID, orderType, status, dbOrder.CreatedAt), nil + tenantID, err := uuid.Parse(dbOrder.TenantID) + if err != nil { + return nil, err + } + orderEntity := order.NewWithID(dbOrder.ID, orderType, status, dbOrder.CreatedAt) + orderEntity.SetTenantID(tenantID) + return orderEntity, nil } diff --git a/modules/warehouse/infrastructure/persistence/mappers/warehouse_mappers.go b/modules/warehouse/infrastructure/persistence/mappers/warehouse_mappers.go index 93d2f516..7f2f1c72 100644 --- a/modules/warehouse/infrastructure/persistence/mappers/warehouse_mappers.go +++ b/modules/warehouse/infrastructure/persistence/mappers/warehouse_mappers.go @@ -15,6 +15,7 @@ import ( func ToDBUnit(unit *unit.Unit) *models.WarehouseUnit { return &models.WarehouseUnit{ ID: unit.ID, + TenantID: unit.TenantID.String(), Title: unit.Title, ShortTitle: unit.ShortTitle, CreatedAt: unit.CreatedAt, @@ -22,19 +23,25 @@ func ToDBUnit(unit *unit.Unit) *models.WarehouseUnit { } } -func ToDomainUnit(dbUnit *models.WarehouseUnit) *unit.Unit { +func ToDomainUnit(dbUnit *models.WarehouseUnit) (*unit.Unit, error) { + tenantID, err := uuid.Parse(dbUnit.TenantID) + if err != nil { + return nil, err + } return &unit.Unit{ ID: dbUnit.ID, + TenantID: tenantID, Title: dbUnit.Title, ShortTitle: dbUnit.ShortTitle, CreatedAt: dbUnit.CreatedAt, UpdatedAt: dbUnit.UpdatedAt, - } + }, nil } func ToDBProduct(entity *product.Product) (*models.WarehouseProduct, error) { return &models.WarehouseProduct{ ID: entity.ID, + TenantID: entity.TenantID.String(), PositionID: entity.PositionID, Rfid: mapping.ValueToSQLNullString(entity.Rfid), Status: string(entity.Status), @@ -56,8 +63,13 @@ func ToDomainProduct( if err != nil { return nil, err } + tenantID, err := uuid.Parse(dbProduct.TenantID) + if err != nil { + return nil, err + } return &product.Product{ ID: dbProduct.ID, + TenantID: tenantID, PositionID: dbProduct.PositionID, Rfid: dbProduct.Rfid.String, Position: pos, @@ -69,16 +81,29 @@ func ToDomainProduct( func ToDomainPosition(dbPosition *models.WarehousePosition, dbUnit *models.WarehouseUnit) (*position.Position, error) { // TODO: decouple - images := make([]upload.Upload, len(dbPosition.Images)) - for i, img := range dbPosition.Images { - images[i] = persistence.ToDomainUpload(&img) + images := make([]upload.Upload, 0, len(dbPosition.Images)) + for _, img := range dbPosition.Images { + domainUpload, err := persistence.ToDomainUpload(&img) + if err != nil { + return nil, err + } + images = append(images, domainUpload) + } + unit, err := ToDomainUnit(dbUnit) + if err != nil { + return nil, err + } + tenantID, err := uuid.Parse(dbPosition.TenantID) + if err != nil { + return nil, err } return &position.Position{ ID: dbPosition.ID, + TenantID: tenantID, Title: dbPosition.Title, Barcode: dbPosition.Barcode, UnitID: uint(dbPosition.UnitID.Int32), - Unit: ToDomainUnit(dbUnit), + Unit: unit, Images: images, CreatedAt: dbPosition.CreatedAt, UpdatedAt: dbPosition.UpdatedAt, @@ -97,6 +122,7 @@ func ToDBPosition(entity *position.Position) (*models.WarehousePosition, []*mode } dbPosition := &models.WarehousePosition{ ID: entity.ID, + TenantID: entity.TenantID.String(), Title: entity.Title, Barcode: entity.Barcode, UnitID: mapping.ValueToSQLNullInt32(int32(entity.UnitID)), @@ -124,8 +150,13 @@ func ToDomainInventoryCheck(dbInventoryCheck *models.InventoryCheck) (*inventory if err != nil { return nil, err } + tenantID, err := uuid.Parse(dbInventoryCheck.TenantID) + if err != nil { + return nil, err + } check := &inventory.Check{ ID: dbInventoryCheck.ID, + TenantID: tenantID, Status: status, Name: dbInventoryCheck.Name, Results: results, @@ -152,6 +183,7 @@ func ToDomainInventoryCheck(dbInventoryCheck *models.InventoryCheck) (*inventory func ToDBInventoryCheckResult(result *inventory.CheckResult) (*models.InventoryCheckResult, error) { return &models.InventoryCheckResult{ ID: result.ID, + TenantID: result.TenantID.String(), PositionID: result.PositionID, ExpectedQuantity: result.ExpectedQuantity, ActualQuantity: result.ActualQuantity, @@ -165,8 +197,13 @@ func ToDomainInventoryCheckResult(result *models.InventoryCheckResult) (*invento // if err != nil { // return nil, err // } + tenantID, err := uuid.Parse(result.TenantID) + if err != nil { + return nil, err + } return &inventory.CheckResult{ ID: result.ID, + TenantID: tenantID, PositionID: result.PositionID, // Position: pos, ExpectedQuantity: result.ExpectedQuantity, @@ -183,6 +220,7 @@ func ToDBInventoryCheck(check *inventory.Check) (*models.InventoryCheck, error) } return &models.InventoryCheck{ ID: check.ID, + TenantID: check.TenantID.String(), Status: string(check.Status), Name: check.Name, Results: results, diff --git a/modules/warehouse/infrastructure/persistence/models/models.go b/modules/warehouse/infrastructure/persistence/models/models.go index 0c5a9d5f..13e3bcae 100644 --- a/modules/warehouse/infrastructure/persistence/models/models.go +++ b/modules/warehouse/infrastructure/persistence/models/models.go @@ -2,13 +2,16 @@ package models import ( "database/sql" + "time" + coremodels "github.com/iota-uz/iota-sdk/modules/core/infrastructure/persistence/models" + "github.com/lib/pq" - "time" ) type WarehouseUnit struct { ID uint + TenantID string Title string ShortTitle string CreatedAt time.Time @@ -17,6 +20,7 @@ type WarehouseUnit struct { type InventoryCheck struct { ID uint + TenantID string Status string Name string Results []*InventoryCheckResult `gorm:"foreignKey:InventoryCheckID"` @@ -37,6 +41,7 @@ type InventoryPosition struct { type InventoryCheckResult struct { ID uint + TenantID string InventoryCheckID uint PositionID uint Position *WarehousePosition @@ -48,6 +53,7 @@ type InventoryCheckResult struct { type WarehouseOrder struct { ID uint + TenantID string Type string Status string CreatedAt time.Time @@ -60,6 +66,7 @@ type WarehouseOrderItem struct { type WarehousePosition struct { ID uint + TenantID string Title string Barcode string UnitID sql.NullInt32 @@ -70,6 +77,7 @@ type WarehousePosition struct { type WarehouseProduct struct { ID uint + TenantID string PositionID uint Rfid sql.NullString Status string diff --git a/modules/warehouse/infrastructure/persistence/order_repository.go b/modules/warehouse/infrastructure/persistence/order_repository.go index 12d0a60e..61bb9d11 100644 --- a/modules/warehouse/infrastructure/persistence/order_repository.go +++ b/modules/warehouse/infrastructure/persistence/order_repository.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/iota-uz/iota-sdk/modules/warehouse/domain/aggregates/order" "github.com/iota-uz/iota-sdk/modules/warehouse/domain/aggregates/product" "github.com/iota-uz/iota-sdk/modules/warehouse/infrastructure/persistence/mappers" @@ -18,53 +19,56 @@ var ( const ( orderFindQuery = ` - SELECT id, type, status, created_at + SELECT id, tenant_id, type, status, created_at FROM warehouse_orders wo` orderCountQuery = ` - SELECT COUNT(*) as count + SELECT COUNT(*) as count FROM warehouse_orders` orderInsertQuery = ` - INSERT INTO warehouse_orders (type, status, created_at) - VALUES ($1, $2, $3) + INSERT INTO warehouse_orders (tenant_id, type, status, created_at) + VALUES ($1, $2, $3, $4) RETURNING id` orderItemInsertQuery = ` - INSERT INTO warehouse_order_items (warehouse_order_id, warehouse_product_id) - VALUES ($1, $2) + INSERT INTO warehouse_order_items (warehouse_order_id, warehouse_product_id) + VALUES ($1, $2) ON CONFLICT DO NOTHING` orderUpdateQuery = ` - UPDATE warehouse_orders wo - SET + UPDATE warehouse_orders wo + SET type = COALESCE(NULLIF($1, ''), wo.type), status = COALESCE(NULLIF($2, ''), wo.status) - WHERE wo.id = $3` + WHERE wo.id = $3 AND wo.tenant_id = $4` orderItemsDeleteQuery = ` - DELETE FROM warehouse_order_items + DELETE FROM warehouse_order_items WHERE warehouse_order_id = $1` orderDeleteQuery = ` - DELETE FROM warehouse_orders - WHERE id = $1` + DELETE FROM warehouse_orders + WHERE id = $1 AND tenant_id = $2` selectOrderProductsQuery = ` - SELECT - wp.id, + SELECT + wp.id, + wp.tenant_id, wp.position_id, wp.rfid, wp.status, - wp.created_at, + wp.created_at, wp.updated_at, p.id, + p.tenant_id, p.title, p.barcode, p.unit_id, p.created_at, p.updated_at, wu.id, + wu.tenant_id, wu.title, wu.short_title, wu.created_at, @@ -74,14 +78,14 @@ const ( LEFT JOIN warehouse_units wu ON wu.id = p.unit_id` insertOrderProductsQuery = ` - INSERT INTO warehouse_products (position_id, rfid, status, created_at) - VALUES ($1, $2, $3, $4) + INSERT INTO warehouse_products (tenant_id, position_id, rfid, status, created_at) + VALUES ($1, $2, $3, $4, $5) RETURNING id` updateOrderProductsQuery = ` - UPDATE warehouse_products + UPDATE warehouse_products SET position_id = $1, rfid = $2, status = $3 - WHERE id = $4` + WHERE id = $4 AND tenant_id = $5` ) type GormOrderRepository struct { @@ -95,7 +99,12 @@ func NewOrderRepository(productRepo product.Repository) order.Repository { } func (g *GormOrderRepository) GetPaginated(ctx context.Context, params *order.FindParams) ([]order.Order, error) { - where, args := []string{"1 = 1"}, []interface{}{} + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + where, args := []string{"wo.tenant_id = $1"}, []interface{}{tenant.ID} if params.CreatedAt.To != "" && params.CreatedAt.From != "" { where, args = append(where, fmt.Sprintf("wo.created_at BETWEEN $%d and $%d", len(args)+1, len(args)+2)), append(args, params.CreatedAt.From, params.CreatedAt.To) } @@ -119,23 +128,36 @@ func (g *GormOrderRepository) GetPaginated(ctx context.Context, params *order.Fi } func (g *GormOrderRepository) Count(ctx context.Context) (int64, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get tenant from context: %w", err) + } + tx, err := composables.UseTx(ctx) if err != nil { return 0, err } var count int64 - if err := tx.QueryRow(ctx, orderCountQuery).Scan(&count); err != nil { + if err := tx.QueryRow(ctx, orderCountQuery+" WHERE tenant_id = $1", tenant.ID).Scan(&count); err != nil { return 0, err } return count, nil } func (g *GormOrderRepository) GetAll(ctx context.Context) ([]order.Order, error) { - return g.queryOrders(ctx, orderFindQuery) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + return g.queryOrders(ctx, orderFindQuery+" WHERE wo.tenant_id = $1", tenant.ID) } func (g *GormOrderRepository) GetByID(ctx context.Context, id uint) (order.Order, error) { - orders, err := g.queryOrders(ctx, orderFindQuery+" WHERE wo.id = $1", id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + orders, err := g.queryOrders(ctx, orderFindQuery+" WHERE wo.id = $1 AND wo.tenant_id = $2", id, tenant.ID) if err != nil { return nil, err } @@ -146,18 +168,31 @@ func (g *GormOrderRepository) GetByID(ctx context.Context, id uint) (order.Order } func (g *GormOrderRepository) Create(ctx context.Context, data order.Order) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + tx, err := composables.UseTx(ctx) if err != nil { return err } + + // Set tenant ID in domain entity + data.SetTenantID(tenant.ID) + dbOrder, dbProducts, err := mappers.ToDBOrder(data) if err != nil { return err } + // Make sure tenant ID is set in DB model + dbOrder.TenantID = tenant.ID.String() + if err := tx.QueryRow( ctx, orderInsertQuery, + dbOrder.TenantID, dbOrder.Type, dbOrder.Status, dbOrder.CreatedAt, @@ -166,9 +201,13 @@ func (g *GormOrderRepository) Create(ctx context.Context, data order.Order) erro } for _, p := range dbProducts { + // Set tenant ID in product + p.TenantID = tenant.ID.String() + if err := tx.QueryRow( ctx, insertOrderProductsQuery, + p.TenantID, p.PositionID, p.Rfid, p.Status, @@ -193,21 +232,34 @@ func (g *GormOrderRepository) Create(ctx context.Context, data order.Order) erro } func (g *GormOrderRepository) Update(ctx context.Context, data order.Order) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + tx, err := composables.UseTx(ctx) if err != nil { return err } + + // Set tenant ID in domain entity + data.SetTenantID(tenant.ID) + dbOrder, dbProducts, err := mappers.ToDBOrder(data) if err != nil { return err } + // Make sure tenant ID is set in DB model + dbOrder.TenantID = tenant.ID.String() + if _, err := tx.Exec( ctx, orderUpdateQuery, dbOrder.Type, dbOrder.Status, dbOrder.ID, + dbOrder.TenantID, ); err != nil { return err } @@ -228,6 +280,9 @@ func (g *GormOrderRepository) Update(ctx context.Context, data order.Order) erro } for _, product := range dbProducts { + // Set tenant ID + product.TenantID = tenant.ID.String() + if _, err := tx.Exec( ctx, updateOrderProductsQuery, @@ -235,6 +290,7 @@ func (g *GormOrderRepository) Update(ctx context.Context, data order.Order) erro product.Rfid, product.Status, product.ID, + product.TenantID, ); err != nil { return err } @@ -244,11 +300,16 @@ func (g *GormOrderRepository) Update(ctx context.Context, data order.Order) erro } func (g *GormOrderRepository) Delete(ctx context.Context, id uint) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + tx, err := composables.UseTx(ctx) if err != nil { return err } - if _, err := tx.Exec(ctx, orderDeleteQuery, id); err != nil { + if _, err := tx.Exec(ctx, orderDeleteQuery, id, tenant.ID); err != nil { return err } return nil @@ -275,18 +336,21 @@ func (g *GormOrderRepository) queryProducts(ctx context.Context, query string, a if err := rows.Scan( &wp.ID, + &wp.TenantID, &wp.PositionID, &wp.Rfid, &wp.Status, &wp.CreatedAt, &wp.UpdatedAt, &pos.ID, + &pos.TenantID, &pos.Title, &pos.Barcode, &pos.UnitID, &pos.CreatedAt, &pos.UpdatedAt, &wu.ID, + &wu.TenantID, &wu.Title, &wu.ShortTitle, &wu.CreatedAt, @@ -326,6 +390,7 @@ func (g *GormOrderRepository) queryOrders(ctx context.Context, query string, arg var o models.WarehouseOrder if err := rows.Scan( &o.ID, + &o.TenantID, &o.Type, &o.Status, &o.CreatedAt, diff --git a/modules/warehouse/infrastructure/persistence/position_repository.go b/modules/warehouse/infrastructure/persistence/position_repository.go index 6f09f22b..dfe3c9cd 100644 --- a/modules/warehouse/infrastructure/persistence/position_repository.go +++ b/modules/warehouse/infrastructure/persistence/position_repository.go @@ -21,25 +21,27 @@ var ( const ( selectPositionQuery = ` - SELECT + SELECT wp.id, wp.title, wp.barcode, wp.unit_id, wp.created_at, wp.updated_at, + wp.tenant_id, wu.id, wu.title, wu.short_title, wu.created_at, - wu.updated_at + wu.updated_at, + wu.tenant_id FROM warehouse_positions wp JOIN warehouse_units wu ON wp.unit_id = wu.id` selectPositionIdQuery = `SELECT id FROM warehouse_positions` countPositionQuery = `SELECT COUNT(*) FROM warehouse_positions` - insertPositionQuery = `INSERT INTO warehouse_positions (title, barcode, unit_id, created_at) VALUES ($1, $2, $3, $4) RETURNING id` + insertPositionQuery = `INSERT INTO warehouse_positions (title, barcode, unit_id, created_at, tenant_id) VALUES ($1, $2, $3, $4, $5) RETURNING id` insertPositionImageQuery = `INSERT INTO warehouse_position_images (warehouse_position_id, upload_id) VALUES` - updatePositionQuery = `UPDATE warehouse_positions SET title = $1, barcode = $2, unit_id = $3 WHERE id = $4` - deletePositionQuery = `DELETE FROM warehouse_positions WHERE id = $1` + updatePositionQuery = `UPDATE warehouse_positions SET title = $1, barcode = $2, unit_id = $3 WHERE id = $4 AND tenant_id = $5` + deletePositionQuery = `DELETE FROM warehouse_positions WHERE id = $1 AND tenant_id = $2` deletePositionImagesQuery = `DELETE FROM warehouse_position_images WHERE warehouse_position_id = $1` ) @@ -74,11 +76,13 @@ func (g *GormPositionRepository) queryPositions(ctx context.Context, query strin &p.UnitID, &p.CreatedAt, &p.UpdatedAt, + &p.TenantID, &u.ID, &u.Title, &u.ShortTitle, &u.CreatedAt, &u.UpdatedAt, + &u.TenantID, ); err != nil { return nil, err } @@ -97,7 +101,12 @@ func (g *GormPositionRepository) queryPositions(ctx context.Context, query strin func (g *GormPositionRepository) GetPaginated( ctx context.Context, params *position.FindParams, ) ([]*position.Position, error) { - where, args := []string{"1 = 1"}, []interface{}{} + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + where, args := []string{"wp.tenant_id = $1"}, []interface{}{tenant.ID} if params.CreatedAt.To != "" && params.CreatedAt.From != "" { where, args = append(where, fmt.Sprintf("wp.created_at BETWEEN $%d and $%d", len(args)+1, len(args)+2)), append(args, params.CreatedAt.From, params.CreatedAt.To) @@ -135,15 +144,26 @@ func (g *GormPositionRepository) Count(ctx context.Context) (int64, error) { if err != nil { return 0, err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, errors.Wrap(err, "failed to get tenant from context") + } + var count int64 - if err := tx.QueryRow(ctx, countPositionQuery).Scan(&count); err != nil { + if err := tx.QueryRow(ctx, countPositionQuery+" WHERE tenant_id = $1", tenant.ID).Scan(&count); err != nil { return 0, err } return count, nil } func (g *GormPositionRepository) GetAll(ctx context.Context) ([]*position.Position, error) { - positions, err := g.queryPositions(ctx, selectPositionQuery) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + positions, err := g.queryPositions(ctx, selectPositionQuery+" WHERE wp.tenant_id = $1", tenant.ID) if err != nil { return nil, errors.Wrap(err, "failed to get all positions") } @@ -155,7 +175,13 @@ func (g *GormPositionRepository) GetAllPositionIds(ctx context.Context) ([]uint, if err != nil { return make([]uint, 0), err } - rows, err := pool.Query(ctx, selectPositionIdQuery) + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + rows, err := pool.Query(ctx, selectPositionIdQuery+" WHERE tenant_id = $1", tenant.ID) if err != nil { return nil, err } @@ -173,7 +199,12 @@ func (g *GormPositionRepository) GetAllPositionIds(ctx context.Context) ([]uint, } func (g *GormPositionRepository) GetByID(ctx context.Context, id uint) (*position.Position, error) { - positions, err := g.queryPositions(ctx, repo.Join(selectPositionQuery, "WHERE wp.id = $1"), id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + positions, err := g.queryPositions(ctx, repo.Join(selectPositionQuery, "WHERE wp.id = $1 AND wp.tenant_id = $2"), id, tenant.ID) if err != nil { return nil, err } @@ -184,7 +215,12 @@ func (g *GormPositionRepository) GetByID(ctx context.Context, id uint) (*positio } func (g *GormPositionRepository) GetByIDs(ctx context.Context, ids []uint) ([]*position.Position, error) { - positions, err := g.queryPositions(ctx, repo.Join(selectPositionQuery, "WHERE wp.id = ANY($1)"), ids) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + positions, err := g.queryPositions(ctx, repo.Join(selectPositionQuery, "WHERE wp.id = ANY($1) AND wp.tenant_id = $2"), ids, tenant.ID) if err != nil { return nil, err } @@ -192,7 +228,12 @@ func (g *GormPositionRepository) GetByIDs(ctx context.Context, ids []uint) ([]*p } func (g *GormPositionRepository) GetByBarcode(ctx context.Context, barcode string) (*position.Position, error) { - positions, err := g.queryPositions(ctx, repo.Join(selectPositionQuery, "WHERE wp.barcode = $1"), barcode) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get tenant from context") + } + + positions, err := g.queryPositions(ctx, repo.Join(selectPositionQuery, "WHERE wp.barcode = $1 AND wp.tenant_id = $2"), barcode, tenant.ID) if err != nil { return nil, err } @@ -224,7 +265,16 @@ func (g *GormPositionRepository) Create(ctx context.Context, data *position.Posi if err != nil { return err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return errors.Wrap(err, "failed to get tenant from context") + } + positionRow, junctionRows := mappers.ToDBPosition(data) + positionRow.TenantID = tenant.ID.String() + data.TenantID = tenant.ID + if err := tx.QueryRow( ctx, insertPositionQuery, @@ -232,6 +282,7 @@ func (g *GormPositionRepository) Create(ctx context.Context, data *position.Posi positionRow.Barcode, positionRow.UnitID, positionRow.CreatedAt, + positionRow.TenantID, ).Scan(&data.ID); err != nil { return err } @@ -254,7 +305,16 @@ func (g *GormPositionRepository) Update(ctx context.Context, data *position.Posi if err != nil { return err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return errors.Wrap(err, "failed to get tenant from context") + } + positionRow, junctionRows := mappers.ToDBPosition(data) + positionRow.TenantID = tenant.ID.String() + data.TenantID = tenant.ID + if _, err := tx.Exec( ctx, updatePositionQuery, @@ -262,6 +322,7 @@ func (g *GormPositionRepository) Update(ctx context.Context, data *position.Posi positionRow.Barcode, positionRow.UnitID, positionRow.ID, + positionRow.TenantID, ); err != nil { return err } @@ -287,7 +348,13 @@ func (g *GormPositionRepository) Delete(ctx context.Context, id uint) error { if err != nil { return err } - if _, err := tx.Exec(ctx, deletePositionQuery, id); err != nil { + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return errors.Wrap(err, "failed to get tenant from context") + } + + if _, err := tx.Exec(ctx, deletePositionQuery, id, tenant.ID); err != nil { return err } return nil diff --git a/modules/warehouse/infrastructure/persistence/position_repository_test.go b/modules/warehouse/infrastructure/persistence/position_repository_test.go index 45a78ae3..df26caa7 100644 --- a/modules/warehouse/infrastructure/persistence/position_repository_test.go +++ b/modules/warehouse/infrastructure/persistence/position_repository_test.go @@ -1,14 +1,16 @@ package persistence_test import ( + "testing" + "time" + "github.com/gabriel-vasile/mimetype" "github.com/go-faster/errors" + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules/core/domain/entities/upload" core "github.com/iota-uz/iota-sdk/modules/core/infrastructure/persistence" "github.com/iota-uz/iota-sdk/modules/warehouse/infrastructure/persistence" "github.com/iota-uz/utils/random" - "testing" - "time" "github.com/iota-uz/iota-sdk/modules/warehouse/domain/aggregates/position" "github.com/iota-uz/iota-sdk/modules/warehouse/domain/entities/unit" @@ -40,6 +42,7 @@ func BenchmarkGormPositionRepository_Create(b *testing.B) { f.ctx, upload.NewWithID( 0, + uuid.Nil, // tenant_id will be set correctly in repository random.String(32, random.LowerCharSet), "image.png", "image.png", @@ -100,6 +103,7 @@ func TestGormPositionRepository_CRUD(t *testing.T) { f.ctx, upload.NewWithID( 1, + uuid.Nil, // tenant_id will be set correctly in repository "hash", "url", "image.png", diff --git a/modules/warehouse/infrastructure/persistence/product_repository.go b/modules/warehouse/infrastructure/persistence/product_repository.go index a813eb8f..7836e3ca 100644 --- a/modules/warehouse/infrastructure/persistence/product_repository.go +++ b/modules/warehouse/infrastructure/persistence/product_repository.go @@ -4,9 +4,10 @@ import ( "context" "errors" "fmt" - "github.com/iota-uz/iota-sdk/modules/warehouse/infrastructure/persistence/mappers" "strings" + "github.com/iota-uz/iota-sdk/modules/warehouse/infrastructure/persistence/mappers" + "github.com/iota-uz/iota-sdk/modules/warehouse/domain/aggregates/product" "github.com/iota-uz/iota-sdk/modules/warehouse/infrastructure/persistence/models" "github.com/iota-uz/iota-sdk/pkg/composables" @@ -18,20 +19,23 @@ var ( const ( productFindQuery = ` - SELECT - wp.id, + SELECT + wp.id, + wp.tenant_id, wp.position_id, wp.rfid, wp.status, - wp.created_at, + wp.created_at, wp.updated_at, p.id, + p.tenant_id, p.title, p.barcode, p.unit_id, p.created_at, p.updated_at, wu.id, + wu.tenant_id, wu.title, wu.short_title, wu.created_at, @@ -44,27 +48,27 @@ const ( SELECT COUNT(DISTINCT wp.id) FROM warehouse_products wp` productInsertQuery = ` - INSERT INTO warehouse_products (position_id, rfid, status, created_at) - VALUES ($1, $2, $3, $4) + INSERT INTO warehouse_products (tenant_id, position_id, rfid, status, created_at) + VALUES ($1, $2, $3, $4, $5) RETURNING id` productUpdateQuery = ` - UPDATE warehouse_products + UPDATE warehouse_products SET position_id = $1, rfid = $2, status = $3 - WHERE id = $4` + WHERE id = $4 AND tenant_id = $5` productUpdateStatusQuery = ` - UPDATE warehouse_products + UPDATE warehouse_products SET status = $1 - WHERE id = ANY($2)` + WHERE id = ANY($2) AND tenant_id = $3` productDeleteQuery = ` - DELETE FROM warehouse_products - WHERE id = $1` + DELETE FROM warehouse_products + WHERE id = $1 AND tenant_id = $2` productBulkDeleteQuery = ` - DELETE FROM warehouse_products - WHERE id = ANY($1)` + DELETE FROM warehouse_products + WHERE id = ANY($1) AND tenant_id = $2` ) type GormProductRepository struct { @@ -75,7 +79,12 @@ func NewProductRepository() product.Repository { } func (g *GormProductRepository) GetPaginated(ctx context.Context, params *product.FindParams) ([]*product.Product, error) { - where, args := []string{"1 = 1"}, []interface{}{} + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + where, args := []string{"wp.tenant_id = $1"}, []interface{}{tenant.ID} if params.OrderID != 0 { where = append(where, fmt.Sprintf( @@ -131,7 +140,12 @@ func (g *GormProductRepository) GetPaginated(ctx context.Context, params *produc } func (g *GormProductRepository) Count(ctx context.Context, opts *product.CountParams) (int64, error) { - where, args := []string{"1 = 1"}, []interface{}{} + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get tenant from context: %w", err) + } + + where, args := []string{"tenant_id = $1"}, []interface{}{tenant.ID} if opts.PositionID != 0 { where = append(where, fmt.Sprintf("position_id = $%d", len(args)+1)) @@ -166,11 +180,20 @@ func (g *GormProductRepository) FindByPositionID(ctx context.Context, opts *prod } func (g *GormProductRepository) GetAll(ctx context.Context) ([]*product.Product, error) { - return g.queryProducts(ctx, productFindQuery) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + return g.queryProducts(ctx, productFindQuery+" WHERE wp.tenant_id = $1", tenant.ID) } func (g *GormProductRepository) GetByID(ctx context.Context, id uint) (*product.Product, error) { - products, err := g.queryProducts(ctx, productFindQuery+" WHERE wp.id = $1", id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + products, err := g.queryProducts(ctx, productFindQuery+" WHERE wp.id = $1 AND wp.tenant_id = $2", id, tenant.ID) if err != nil { return nil, err } @@ -181,7 +204,12 @@ func (g *GormProductRepository) GetByID(ctx context.Context, id uint) (*product. } func (g *GormProductRepository) GetByRfid(ctx context.Context, rfid string) (*product.Product, error) { - products, err := g.queryProducts(ctx, productFindQuery+" WHERE wp.rfid = $1", rfid) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tenant from context: %w", err) + } + + products, err := g.queryProducts(ctx, productFindQuery+" WHERE wp.rfid = $1 AND wp.tenant_id = $2", rfid, tenant.ID) if err != nil { return nil, err } @@ -203,14 +231,23 @@ func (g *GormProductRepository) Create(ctx context.Context, data *product.Produc return err } + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + dbProduct, err := mappers.ToDBProduct(data) if err != nil { return err } + dbProduct.TenantID = tenant.ID.String() + data.TenantID = tenant.ID + if err := tx.QueryRow( ctx, productInsertQuery, + dbProduct.TenantID, dbProduct.PositionID, dbProduct.Rfid, dbProduct.Status, @@ -242,31 +279,55 @@ func (g *GormProductRepository) CreateOrUpdate(ctx context.Context, data *produc } func (g *GormProductRepository) Update(ctx context.Context, data *product.Product) error { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + dbProduct, err := mappers.ToDBProduct(data) if err != nil { return err } + + dbProduct.TenantID = tenant.ID.String() + data.TenantID = tenant.ID + return g.execQuery( ctx, productUpdateQuery, dbProduct.PositionID, dbProduct.Rfid, dbProduct.Status, - dbProduct.UpdatedAt, dbProduct.ID, + dbProduct.TenantID, ) } func (g *GormProductRepository) UpdateStatus(ctx context.Context, ids []uint, status product.Status) error { - return g.execQuery(ctx, productUpdateStatusQuery, status, ids) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + return g.execQuery(ctx, productUpdateStatusQuery, status, ids, tenant.ID) } func (g *GormProductRepository) Delete(ctx context.Context, id uint) error { - return g.execQuery(ctx, productDeleteQuery, id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + return g.execQuery(ctx, productDeleteQuery, id, tenant.ID) } func (g *GormProductRepository) BulkDelete(ctx context.Context, ids []uint) error { - return g.execQuery(ctx, productBulkDeleteQuery, ids) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant from context: %w", err) + } + + return g.execQuery(ctx, productBulkDeleteQuery, ids, tenant.ID) } func (g *GormProductRepository) queryProducts(ctx context.Context, query string, args ...interface{}) ([]*product.Product, error) { @@ -290,18 +351,21 @@ func (g *GormProductRepository) queryProducts(ctx context.Context, query string, if err := rows.Scan( &wp.ID, + &wp.TenantID, &wp.PositionID, &wp.Rfid, &wp.Status, &wp.CreatedAt, &wp.UpdatedAt, &pos.ID, + &pos.TenantID, &pos.Title, &pos.Barcode, &pos.UnitID, &pos.CreatedAt, &pos.UpdatedAt, &wu.ID, + &wu.TenantID, &wu.Title, &wu.ShortTitle, &wu.CreatedAt, diff --git a/modules/warehouse/infrastructure/persistence/schema/warehouse-schema.sql b/modules/warehouse/infrastructure/persistence/schema/warehouse-schema.sql index eaead993..938c1de5 100644 --- a/modules/warehouse/infrastructure/persistence/schema/warehouse-schema.sql +++ b/modules/warehouse/infrastructure/persistence/schema/warehouse-schema.sql @@ -1,19 +1,24 @@ CREATE TABLE warehouse_units ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, title varchar(255) NOT NULL, -- Kilogram, Piece, etc. short_title varchar(255) NOT NULL, -- kg, pcs, etc. created_at timestamp with time zone DEFAULT now(), - updated_at timestamp with time zone DEFAULT now() + updated_at timestamp with time zone DEFAULT now(), + UNIQUE (tenant_id, title), + UNIQUE (tenant_id, short_title) ); CREATE TABLE warehouse_positions ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, title varchar(255) NOT NULL, - barcode varchar(255) NOT NULL UNIQUE, + barcode varchar(255) NOT NULL, description text, unit_id int REFERENCES warehouse_units (id) ON DELETE SET NULL, created_at timestamp with time zone DEFAULT now(), - updated_at timestamp with time zone DEFAULT now() + updated_at timestamp with time zone DEFAULT now(), + UNIQUE (tenant_id, barcode) ); CREATE TABLE warehouse_position_images ( @@ -24,15 +29,18 @@ CREATE TABLE warehouse_position_images ( CREATE TABLE warehouse_products ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, position_id int NOT NULL REFERENCES warehouse_positions (id) ON DELETE CASCADE, - rfid varchar(255) NULL UNIQUE, + rfid varchar(255) NULL, status varchar(255) NOT NULL, created_at timestamp with time zone DEFAULT now(), - updated_at timestamp with time zone DEFAULT now() + updated_at timestamp with time zone DEFAULT now(), + UNIQUE (tenant_id, rfid) ); CREATE TABLE warehouse_orders ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, type VARCHAR(255) NOT NULL, status varchar(255) NOT NULL, created_at timestamp with time zone DEFAULT now() @@ -46,6 +54,7 @@ CREATE TABLE warehouse_order_items ( CREATE TABLE inventory_checks ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, status varchar(255) NOT NULL, name varchar(255) NOT NULL, type VARCHAR(255) NOT NULL, @@ -57,6 +66,7 @@ CREATE TABLE inventory_checks ( CREATE TABLE inventory_check_results ( id serial PRIMARY KEY, + tenant_id uuid NOT NULL REFERENCES tenants (id) ON DELETE CASCADE, inventory_check_id int NOT NULL REFERENCES inventory_checks (id) ON DELETE CASCADE, position_id int NOT NULL REFERENCES warehouse_positions (id) ON DELETE CASCADE, expected_quantity int NOT NULL, @@ -65,3 +75,15 @@ CREATE TABLE inventory_check_results ( created_at timestamp with time zone DEFAULT now() ); +CREATE INDEX warehouse_units_tenant_id_idx ON warehouse_units (tenant_id); + +CREATE INDEX warehouse_positions_tenant_id_idx ON warehouse_positions (tenant_id); + +CREATE INDEX warehouse_products_tenant_id_idx ON warehouse_products (tenant_id); + +CREATE INDEX warehouse_orders_tenant_id_idx ON warehouse_orders (tenant_id); + +CREATE INDEX inventory_checks_tenant_id_idx ON inventory_checks (tenant_id); + +CREATE INDEX inventory_check_results_tenant_id_idx ON inventory_check_results (tenant_id); + diff --git a/modules/warehouse/infrastructure/persistence/setup_test.go b/modules/warehouse/infrastructure/persistence/setup_test.go index 5ce7dd09..b8f191c9 100644 --- a/modules/warehouse/infrastructure/persistence/setup_test.go +++ b/modules/warehouse/infrastructure/persistence/setup_test.go @@ -2,10 +2,11 @@ package persistence_test import ( "context" - "github.com/iota-uz/utils/random" "os" "testing" + "github.com/iota-uz/utils/random" + "github.com/jackc/pgx/v5/pgxpool" "github.com/iota-uz/iota-sdk/modules" @@ -52,11 +53,24 @@ func setupBenchmark(b *testing.B) *testFixtures { ctx = composables.WithTx(ctx, tx) ctx = composables.WithSession(ctx, &session.Session{}) + // Setup application and run migrations app, err := testutils.SetupApplication(pool, modules.BuiltInModules...) if err != nil { b.Fatal(err) } + // Run migrations first to create all tables including tenants + if err := app.Migrations().Run(); err != nil { + b.Fatal(err) + } + + // Create a test tenant and add it to the context + tenant, err := testutils.CreateTestTenant(ctx, pool) + if err != nil { + b.Fatal(err) + } + ctx = composables.WithTenant(ctx, tenant) + return &testFixtures{ ctx: ctx, pool: pool, @@ -87,11 +101,24 @@ func setupTest(t *testing.T) *testFixtures { ctx = composables.WithTx(ctx, tx) ctx = composables.WithSession(ctx, &session.Session{}) + // Setup application and run migrations app, err := testutils.SetupApplication(pool, modules.BuiltInModules...) if err != nil { t.Fatal(err) } + // Run migrations first to create all tables including tenants + if err := app.Migrations().Run(); err != nil { + t.Fatal(err) + } + + // Create a test tenant and add it to the context + tenant, err := testutils.CreateTestTenant(ctx, pool) + if err != nil { + t.Fatal(err) + } + ctx = composables.WithTenant(ctx, tenant) + return &testFixtures{ ctx: ctx, pool: pool, diff --git a/modules/warehouse/infrastructure/persistence/unit_repository.go b/modules/warehouse/infrastructure/persistence/unit_repository.go index d783328a..b8b334c3 100644 --- a/modules/warehouse/infrastructure/persistence/unit_repository.go +++ b/modules/warehouse/infrastructure/persistence/unit_repository.go @@ -3,6 +3,7 @@ package persistence import ( "context" "errors" + "github.com/iota-uz/iota-sdk/modules/warehouse/domain/entities/unit" "github.com/iota-uz/iota-sdk/modules/warehouse/infrastructure/persistence/mappers" "github.com/iota-uz/iota-sdk/modules/warehouse/infrastructure/persistence/models" @@ -15,11 +16,11 @@ var ( ) const ( - selectUnitsQuery = `SELECT id, title, short_title, created_at, updated_at FROM warehouse_units` + selectUnitsQuery = `SELECT id, title, short_title, created_at, updated_at, tenant_id FROM warehouse_units` countUnitsQuery = `SELECT COUNT(*) FROM warehouse_units` - insertUnitQuery = `INSERT INTO warehouse_units (title, short_title, created_at) VALUES ($1, $2, $3) RETURNING id` - updateUnitQuery = `UPDATE warehouse_units SET title = $1, short_title = $2, updated_at = $3 WHERE id = $4` - deleteUnitQuery = `DELETE FROM warehouse_units WHERE id = $1` + insertUnitQuery = `INSERT INTO warehouse_units (title, short_title, created_at, tenant_id) VALUES ($1, $2, $3, $4) RETURNING id` + updateUnitQuery = `UPDATE warehouse_units SET title = $1, short_title = $2, updated_at = $3 WHERE id = $4 AND tenant_id = $5` + deleteUnitQuery = `DELETE FROM warehouse_units WHERE id = $1 AND tenant_id = $2` ) type GormUnitRepository struct{} @@ -29,12 +30,19 @@ func NewUnitRepository() unit.Repository { } func (g *GormUnitRepository) GetPaginated(ctx context.Context, params *unit.FindParams) ([]*unit.Unit, error) { + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, err + } + return g.queryUnits( ctx, repo.Join( selectUnitsQuery, + "WHERE tenant_id = $1", repo.FormatLimitOffset(params.Limit, params.Offset), ), + tenant.ID, ) } @@ -43,15 +51,26 @@ func (g *GormUnitRepository) Count(ctx context.Context) (uint, error) { if err != nil { return 0, err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return 0, err + } + var count uint - if err := pool.QueryRow(ctx, countUnitsQuery).Scan(&count); err != nil { + if err := pool.QueryRow(ctx, countUnitsQuery+" WHERE tenant_id = $1", tenant.ID).Scan(&count); err != nil { return 0, err } return count, nil } func (g *GormUnitRepository) GetAll(ctx context.Context) ([]*unit.Unit, error) { - units, err := g.queryUnits(ctx, selectUnitsQuery) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, err + } + + units, err := g.queryUnits(ctx, selectUnitsQuery+" WHERE tenant_id = $1", tenant.ID) if err != nil { return nil, err } @@ -60,7 +79,12 @@ func (g *GormUnitRepository) GetAll(ctx context.Context) ([]*unit.Unit, error) { } func (g *GormUnitRepository) GetByID(ctx context.Context, id uint) (*unit.Unit, error) { - units, err := g.queryUnits(ctx, selectUnitsQuery+" WHERE id = $1", id) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, err + } + + units, err := g.queryUnits(ctx, selectUnitsQuery+" WHERE id = $1 AND tenant_id = $2", id, tenant.ID) if err != nil { return nil, err } @@ -73,7 +97,12 @@ func (g *GormUnitRepository) GetByID(ctx context.Context, id uint) (*unit.Unit, } func (g *GormUnitRepository) GetByTitleOrShortTitle(ctx context.Context, name string) (*unit.Unit, error) { - units, err := g.queryUnits(ctx, selectUnitsQuery+" WHERE title = $1 OR short_title = $1", name) + tenant, err := composables.UseTenant(ctx) + if err != nil { + return nil, err + } + + units, err := g.queryUnits(ctx, selectUnitsQuery+" WHERE (title = $1 OR short_title = $1) AND tenant_id = $2", name, tenant.ID) if err != nil { return nil, err } @@ -89,13 +118,22 @@ func (g *GormUnitRepository) Create(ctx context.Context, data *unit.Unit) error if err != nil { return err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return err + } + dbRow := mappers.ToDBUnit(data) + dbRow.TenantID = tenant.ID.String() + if err := tx.QueryRow( ctx, insertUnitQuery, dbRow.Title, dbRow.ShortTitle, dbRow.CreatedAt, + dbRow.TenantID, ).Scan(&data.ID); err != nil { return err } @@ -124,7 +162,15 @@ func (g *GormUnitRepository) Update(ctx context.Context, data *unit.Unit) error if err != nil { return err } + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return err + } + dbRow := mappers.ToDBUnit(data) + dbRow.TenantID = tenant.ID.String() + if _, err := tx.Exec( ctx, updateUnitQuery, @@ -132,6 +178,7 @@ func (g *GormUnitRepository) Update(ctx context.Context, data *unit.Unit) error dbRow.ShortTitle, dbRow.UpdatedAt, dbRow.ID, + dbRow.TenantID, ); err != nil { return err } @@ -143,7 +190,13 @@ func (g *GormUnitRepository) Delete(ctx context.Context, id uint) error { if err != nil { return err } - if _, err := tx.Exec(ctx, deleteUnitQuery, id); err != nil { + + tenant, err := composables.UseTenant(ctx) + if err != nil { + return err + } + + if _, err := tx.Exec(ctx, deleteUnitQuery, id, tenant.ID); err != nil { return err } return nil @@ -169,11 +222,15 @@ func (g *GormUnitRepository) queryUnits(ctx context.Context, query string, args &u.ShortTitle, &u.CreatedAt, &u.UpdatedAt, + &u.TenantID, ); err != nil { return nil, err } - domainUnit := mappers.ToDomainUnit(&u) + domainUnit, err := mappers.ToDomainUnit(&u) + if err != nil { + return nil, err + } units = append(units, domainUnit) } diff --git a/modules/warehouse/services/orderservice/setup_test.go b/modules/warehouse/services/orderservice/setup_test.go index b933ad42..f427718b 100644 --- a/modules/warehouse/services/orderservice/setup_test.go +++ b/modules/warehouse/services/orderservice/setup_test.go @@ -60,11 +60,24 @@ func setupTest(t *testing.T) *testFixtures { ctx = composables.WithTx(ctx, tx) ctx = composables.WithSession(ctx, &session.Session{}) + // Setup application and run migrations app, err := testutils.SetupApplication(pool, modules.BuiltInModules...) if err != nil { t.Fatal(err) } + // Run migrations first to create all tables including tenants + if err := app.Migrations().Run(); err != nil { + t.Fatal(err) + } + + // Create a test tenant and add it to the context + tenant, err := testutils.CreateTestTenant(ctx, pool) + if err != nil { + t.Fatal(err) + } + ctx = composables.WithTenant(ctx, tenant) + return &testFixtures{ ctx: ctx, pool: pool, diff --git a/modules/warehouse/services/positionservice/position_service_test.go b/modules/warehouse/services/positionservice/position_service_test.go index 49c8316d..5fe5f2b4 100644 --- a/modules/warehouse/services/positionservice/position_service_test.go +++ b/modules/warehouse/services/positionservice/position_service_test.go @@ -1,115 +1,12 @@ package positionservice_test import ( - "context" - "log" - "os" "testing" - "github.com/jackc/pgx/v5/pgxpool" - "github.com/xuri/excelize/v2" - - "github.com/iota-uz/iota-sdk/modules" - "github.com/iota-uz/iota-sdk/modules/core/domain/entities/session" "github.com/iota-uz/iota-sdk/modules/warehouse/infrastructure/persistence" - "github.com/iota-uz/iota-sdk/modules/warehouse/permissions" "github.com/iota-uz/iota-sdk/modules/warehouse/services/positionservice" - "github.com/iota-uz/iota-sdk/pkg/application" - "github.com/iota-uz/iota-sdk/pkg/composables" - "github.com/iota-uz/iota-sdk/pkg/testutils" -) - -var ( - TestFilePath = "test.xlsx" - Data = []map[string]interface{}{ - {"A1": "Наименование", "B1": "Код в справочнике", "C1": "Ед. изм.", "D1": "Количество"}, - {"A2": "Дрель Молоток N.C.V (900W)", "B2": "3241324132", "C2": "шт", "D2": 10}, - {"A3": "Дрель Молоток N.C.V (900W)", "B3": "9230891234", "C3": "шт", "D3": 10}, - {"A4": "Дрель Молоток N.C.V (900W)", "B4": "3242198021", "C4": "шт", "D4": 3}, - } - TotalProducts = 23 ) -func TestMain(m *testing.M) { - if err := os.Chdir("../../../../"); err != nil { - panic(err) - } - if err := createTestFile(TestFilePath); err != nil { - panic(err) - } - code := m.Run() - if err := os.Remove(TestFilePath); err != nil { - log.Println(err) - } - os.Exit(code) -} - -// testFixtures contains common test dependencies -type testFixtures struct { - ctx context.Context - pool *pgxpool.Pool - app application.Application -} - -// setupTest creates all necessary dependencies for tests -func setupTest(t *testing.T) *testFixtures { - t.Helper() - - testutils.CreateDB(t.Name()) - pool := testutils.NewPool(testutils.DbOpts(t.Name())) - - ctx := composables.WithUser(context.Background(), testutils.MockUser( - permissions.PositionCreate, - permissions.PositionRead, - permissions.ProductCreate, - permissions.ProductRead, - permissions.UnitCreate, - permissions.UnitRead, - )) - tx, err := pool.Begin(ctx) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := tx.Commit(ctx); err != nil { - t.Fatal(err) - } - pool.Close() - }) - - ctx = composables.WithTx(ctx, tx) - ctx = composables.WithSession(ctx, &session.Session{}) - - app, err := testutils.SetupApplication(pool, modules.BuiltInModules...) - if err != nil { - t.Fatal(err) - } - - return &testFixtures{ - ctx: ctx, - pool: pool, - app: app, - } -} - -func createTestFile(path string) error { - f := excelize.NewFile() - defer func() { - if err := f.Close(); err != nil { - log.Println(err) - } - }() - for _, v := range Data { - for k, val := range v { - if err := f.SetCellValue("Sheet1", k, val); err != nil { - return err - } - } - } - return f.SaveAs(path) -} - func TestPositionService_LoadFromFilePath(t *testing.T) { t.Parallel() f := setupTest(t) @@ -128,31 +25,45 @@ func TestPositionService_LoadFromFilePath(t *testing.T) { if err != nil { t.Error(err) } - if len(positions) != len(Data)-1 { - t.Fatalf("expected %d, got %d", len(Data)-1, len(positions)) - } - if positions[0].Title != Data[1]["A2"] { - t.Errorf("expected %s, got %s", Data[1]["A2"], positions[0].Title) + if len(positions) != 3 { + t.Errorf("expected 3 position, got %d", len(positions)) } - if positions[0].Barcode != Data[1]["B2"] { - t.Errorf("expected %s, got %s", Data[1]["B2"], positions[0].Barcode) + found := false + for _, pos := range positions { + if pos.Title == "Дрель Молоток N.C.V (900W)" { + found = true + break + } + } + if !found { + t.Errorf("position with title 'Дрель Молоток N.C.V (900W)' not found") } units, err := unitRepo.GetAll(f.ctx) if err != nil { t.Error(err) } + if len(units) != 1 { - t.Errorf("expected %d, got %d", 1, len(units)) + t.Errorf("expected 1 unit, got %d", len(units)) + } + + if units[0].Title != "шт" { + t.Errorf("expected title %s, got %s", "шт", units[0].Title) + } + + if units[0].ShortTitle != "шт" { + t.Errorf("expected short title %s, got %s", "шт", units[0].ShortTitle) } products, err := productRepo.GetAll(f.ctx) if err != nil { t.Error(err) } + if len(products) != TotalProducts { - t.Errorf("expected %d, got %d", TotalProducts, len(products)) + t.Errorf("expected %d products, got %d", TotalProducts, len(products)) } } diff --git a/modules/warehouse/services/positionservice/setup_test.go b/modules/warehouse/services/positionservice/setup_test.go new file mode 100644 index 00000000..e37180dc --- /dev/null +++ b/modules/warehouse/services/positionservice/setup_test.go @@ -0,0 +1,127 @@ +package positionservice_test + +import ( + "context" + "log" + "os" + "testing" + + "github.com/iota-uz/iota-sdk/modules" + "github.com/iota-uz/iota-sdk/modules/core/domain/entities/session" + "github.com/iota-uz/iota-sdk/modules/warehouse/permissions" + "github.com/iota-uz/iota-sdk/pkg/application" + "github.com/iota-uz/iota-sdk/pkg/composables" + "github.com/iota-uz/iota-sdk/pkg/testutils" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/xuri/excelize/v2" +) + +var ( + TestFilePath = "test.xlsx" + Data = []map[string]interface{}{ + {"A1": "Наименование", "B1": "Код в справочнике", "C1": "Ед. изм.", "D1": "Количество"}, + {"A2": "Дрель Молоток N.C.V (900W)", "B2": "3241324132", "C2": "шт", "D2": 10}, + {"A3": "Дрель Молоток N.C.V (900W)", "B3": "9230891234", "C3": "шт", "D3": 10}, + {"A4": "Дрель Молоток N.C.V (900W)", "B4": "3242198021", "C4": "шт", "D4": 3}, + } + TotalProducts = 23 +) + +func TestMain(m *testing.M) { + if err := os.Chdir("../../../../"); err != nil { + panic(err) + } + + // Create the test file for position service tests + if err := createTestFile(TestFilePath); err != nil { + panic(err) + } + + code := m.Run() + + // Clean up the test file + if err := os.Remove(TestFilePath); err != nil { + log.Println("Failed to remove test file:", err) + } + + os.Exit(code) +} + +// testFixtures contains common test dependencies +type testFixtures struct { + ctx context.Context + pool *pgxpool.Pool + app application.Application +} + +// setupTest creates all necessary dependencies for tests +func setupTest(t *testing.T) *testFixtures { + t.Helper() + + testutils.CreateDB(t.Name()) + pool := testutils.NewPool(testutils.DbOpts(t.Name())) + + ctx := composables.WithUser(context.Background(), testutils.MockUser( + permissions.PositionCreate, + permissions.PositionRead, + permissions.ProductCreate, + permissions.ProductRead, + permissions.UnitCreate, + permissions.UnitRead, + )) + tx, err := pool.Begin(ctx) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := tx.Commit(ctx); err != nil { + t.Fatal(err) + } + pool.Close() + }) + + ctx = composables.WithTx(ctx, tx) + ctx = composables.WithSession(ctx, &session.Session{}) + + // Setup application and run migrations + app, err := testutils.SetupApplication(pool, modules.BuiltInModules...) + if err != nil { + t.Fatal(err) + } + + // Run migrations first to create all tables including tenants + if err := app.Migrations().Run(); err != nil { + t.Fatal(err) + } + + // Create a test tenant and add it to the context + tenant, err := testutils.CreateTestTenant(ctx, pool) + if err != nil { + t.Fatal(err) + } + ctx = composables.WithTenant(ctx, tenant) + + return &testFixtures{ + ctx: ctx, + pool: pool, + app: app, + } +} + +func createTestFile(path string) error { + f := excelize.NewFile() + defer func() { + if err := f.Close(); err != nil { + log.Println(err) + } + }() + for _, v := range Data { + for k, val := range v { + if err := f.SetCellValue("Sheet1", k, val); err != nil { + return err + } + } + } + return f.SaveAs(path) +} diff --git a/pkg/composables/tenant.go b/pkg/composables/tenant.go new file mode 100644 index 00000000..d153ad05 --- /dev/null +++ b/pkg/composables/tenant.go @@ -0,0 +1,39 @@ +package composables + +import ( + "context" + "errors" + + "github.com/google/uuid" + "github.com/iota-uz/iota-sdk/pkg/constants" +) + +var ( + ErrNoTenantFound = errors.New("no tenant found in context") +) + +type Tenant struct { + ID uuid.UUID + Name string + Domain string +} + +func UseTenant(ctx context.Context) (*Tenant, error) { + t, ok := ctx.Value(constants.TenantKey).(*Tenant) + if !ok { + return nil, ErrNoTenantFound + } + return t, nil +} + +func MustUseTenant(ctx context.Context) *Tenant { + t, err := UseTenant(ctx) + if err != nil { + panic(err) + } + return t +} + +func WithTenant(ctx context.Context, tenant *Tenant) context.Context { + return context.WithValue(ctx, constants.TenantKey, tenant) +} diff --git a/pkg/constants/middleware.go b/pkg/constants/middleware.go index 2ec44ee9..537c1776 100644 --- a/pkg/constants/middleware.go +++ b/pkg/constants/middleware.go @@ -20,6 +20,7 @@ const ( RequestStart ContextKey = "requestStart" LocalizerKey ContextKey = "localizer" PageContext ContextKey = "pageContext" + TenantKey ContextKey = "tenant" ) var Validate = validator.New(validator.WithRequiredStructEnabled()) diff --git a/pkg/middleware/auth.go b/pkg/middleware/auth.go index 45071a38..60c35d08 100644 --- a/pkg/middleware/auth.go +++ b/pkg/middleware/auth.go @@ -4,8 +4,10 @@ import ( "context" "errors" "fmt" + "log" "net/http" + "github.com/google/uuid" "github.com/gorilla/mux" "github.com/iota-uz/iota-sdk/modules/core/services" @@ -51,6 +53,57 @@ func Authorize() mux.MiddlewareFunc { next.ServeHTTP(w, r) return } + + // Now that we have the session, let's ensure we have the tenant in the context + // First check if we already have a tenant + _, tenantErr := composables.UseTenant(ctx) + if tenantErr != nil { + // First try to get tenant from session (already loaded from DB) + if sess.TenantID != uuid.Nil { + // Get tenant info directly + tx, txErr := composables.UseTx(ctx) + if txErr == nil { + var name string + var domain string + err := tx.QueryRow(ctx, "SELECT name, domain FROM tenants WHERE id = $1 LIMIT 1", sess.TenantID.String()).Scan(&name, &domain) + if err == nil { + // Add tenant to context from session + t := &composables.Tenant{ + ID: sess.TenantID, + Name: name, + Domain: domain, + } + ctx = context.WithValue(ctx, constants.TenantKey, t) + } + } + } else { + // Fallback: use direct database query to get the tenant ID for the user + tx, txErr := composables.UseTx(ctx) + if txErr == nil { + var tenantIDStr string + err := tx.QueryRow(ctx, "SELECT tenant_id FROM users WHERE id = $1 LIMIT 1", sess.UserID).Scan(&tenantIDStr) + if err == nil && tenantIDStr != "" { + tenantID, uuidErr := uuid.Parse(tenantIDStr) + if uuidErr == nil { + // Now query for the tenant info + var name string + var domain string + err := tx.QueryRow(ctx, "SELECT name, domain FROM tenants WHERE id = $1 LIMIT 1", tenantIDStr).Scan(&name, &domain) + if err == nil { + // Add tenant to context + t := &composables.Tenant{ + ID: tenantID, + Name: name, + Domain: domain, + } + ctx = context.WithValue(ctx, constants.TenantKey, t) + } + } + } + } + } + } + params, ok := composables.UseParams(ctx) if !ok { panic("params not found. Add RequestParams middleware up the chain") @@ -82,7 +135,24 @@ func ProvideUser() mux.MiddlewareFunc { next.ServeHTTP(w, r) return } + // Set the user in context ctx = context.WithValue(ctx, constants.UserKey, u) + + // Check if we already have a tenant in context + _, tenantErr := composables.UseTenant(ctx) + if tenantErr != nil { + // If not, get it from the user's tenant ID + tenantService := app.Service(services.TenantService{}).(*services.TenantService) + t, err := tenantService.GetByID(ctx, u.TenantID()) + if err != nil { + log.Printf("Error retrieving tenant: %v", err) + // Don't add tenant to context if we couldn't get it + } else { + // Add tenant to context + ctx = context.WithValue(ctx, constants.TenantKey, t) + } + } + next.ServeHTTP(w, r.WithContext(ctx)) }, ) diff --git a/pkg/testutils/utils.go b/pkg/testutils/utils.go index bd6d9ebb..857f07f3 100644 --- a/pkg/testutils/utils.go +++ b/pkg/testutils/utils.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/google/uuid" "github.com/iota-uz/iota-sdk/modules" "github.com/iota-uz/iota-sdk/modules/core/domain/aggregates/role" "github.com/iota-uz/iota-sdk/modules/core/domain/aggregates/user" @@ -38,6 +39,7 @@ func MockUser(permissions ...*permission.Permission) user.User { role.WithPermissions(permissions), role.WithCreatedAt(time.Now()), role.WithUpdatedAt(time.Now()), + role.WithTenantID(uuid.Nil), // tenant_id will be set correctly in repository ) email, err := internet.NewEmail("test@example.com") @@ -88,6 +90,38 @@ func DefaultParams() *composables.Params { } } +// CreateTestTenant creates a test tenant for testing +func CreateTestTenant(ctx context.Context, pool *pgxpool.Pool) (*composables.Tenant, error) { + testTenant := &composables.Tenant{ + ID: uuid.MustParse("00000000-0000-0000-0000-000000000001"), + Name: "Test Tenant", + Domain: "test.com", + } + + // Try to insert the tenant - if it fails because table doesn't exist, we'll catch it + _, err := pool.Exec(ctx, "INSERT INTO tenants (id, name, domain, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (id) DO NOTHING", + testTenant.ID, + testTenant.Name, + testTenant.Domain, + time.Now(), + time.Now(), + ) + + // If there's no error, we're done + if err == nil { + return testTenant, nil + } + + // If the error is not about missing table, return it + if !strings.Contains(err.Error(), "relation") && !strings.Contains(err.Error(), "does not exist") { + return nil, err + } + + // We can simply return the tenant object for test context + // The actual tenant will be created when migrations run + return testTenant, nil +} + func CreateDB(name string) { c := configuration.Use() db, err := sql.Open("postgres", c.Database.ConnectionString())