diff --git a/dbmigrator/tests/data/md/20170810093842_create_a_table.py b/dbmigrator/tests/data/md/20170810093842_create_a_table.py new file mode 100644 index 0000000..2152a09 --- /dev/null +++ b/dbmigrator/tests/data/md/20170810093842_create_a_table.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + + +def up(cursor): + cursor.execute('CREATE TABLE a_table (name TEXT)') + + +def down(cursor): + cursor.execute('DROP TABLE a_table') diff --git a/dbmigrator/tests/test_cli.py b/dbmigrator/tests/test_cli.py index f947d31..1fb638b 100644 --- a/dbmigrator/tests/test_cli.py +++ b/dbmigrator/tests/test_cli.py @@ -144,6 +144,29 @@ def test_wide(self): \n""", stdout) self.assertEqual('', stderr) + def test_migrations_directory_and_context(self): + testing.install_test_packages() + + cmd = ['--db-connection-string', testing.db_connection_string] + md = os.path.join(testing.test_data_path, 'md') + self.target(cmd + ['init']) + with testing.captured_output() as (out, err): + self.target(cmd + [ + '--context', 'package-a', '--context', 'package-b', + '--migrations-directory', md, 'list']) + + stdout = out.getvalue() + + # Assert package-a migrations are in + self.assertIn('20160228202637 add_table', stdout) + self.assertIn('20160228212456 cool_stuff', stdout) + + # Assert package-b migrations are in + self.assertIn('20160228210326 initial_data', stdout) + + # Assert migrations directory migrations are in + self.assertIn('20170810093842 create_a_table', stdout) + class InitTestCase(BaseTestCase): def test_multiple_contexts(self): diff --git a/dbmigrator/utils.py b/dbmigrator/utils.py index caa4ee6..b4762ba 100644 --- a/dbmigrator/utils.py +++ b/dbmigrator/utils.py @@ -58,7 +58,7 @@ def get_settings_from_entry_points(settings, contexts): context, __package__).values() for entry_point in entry_points: setting_name = entry_point.name - if settings.get(setting_name): + if not isinstance(settings.get(setting_name, []), list): # don't overwrite settings given from the CLI continue @@ -76,7 +76,12 @@ def get_settings_from_entry_points(settings, contexts): context_settings[setting_name] = value for name, value in context_settings.items(): - settings[name] = value + if isinstance(settings.get(name), list): + if not isinstance(value, list): + value = [value] + settings[name] += value + else: + settings[name] = value def get_settings_from_config(filename, config_names, settings):