|
| 1 | +import subprocess |
| 2 | +import time |
| 3 | + |
| 4 | +from datafusion import SessionContext |
| 5 | +from datafusion_table_providers import mysql |
| 6 | + |
| 7 | +def run_docker_container(): |
| 8 | + """Run the Docker container with the MySQL image""" |
| 9 | + result = subprocess.run( |
| 10 | + ["docker", "run", "--name", "mysql", "-e", "MYSQL_ROOT_PASSWORD=password", "-e", "MYSQL_DATABASE=mysql_db", |
| 11 | + "-p", "3306:3306", "-d", "mysql:9.0"], |
| 12 | + stdout=subprocess.PIPE, |
| 13 | + stderr=subprocess.PIPE |
| 14 | + ) |
| 15 | + if result.returncode != 0: |
| 16 | + print(f"Failed to start MySQL container: {result.stderr.decode()}") |
| 17 | + |
| 18 | +def create_table_and_insert_data(): |
| 19 | + """Create a table and insert data into MySQL""" |
| 20 | + sql_commands = """ |
| 21 | + CREATE TABLE companies ( |
| 22 | + id INT PRIMARY KEY, |
| 23 | + name VARCHAR(100) |
| 24 | + ); |
| 25 | + |
| 26 | + INSERT INTO companies (id, name) VALUES (1, 'Acme Corporation'); |
| 27 | + """ |
| 28 | + |
| 29 | + # Execute the SQL commands inside the Docker container |
| 30 | + result = subprocess.run( |
| 31 | + ["docker", "exec", "-i", "mysql", "mysql", "-uroot", "-ppassword", "mysql_db"], |
| 32 | + input=sql_commands.encode(), # Pass SQL commands to stdin |
| 33 | + stdout=subprocess.PIPE, |
| 34 | + stderr=subprocess.PIPE |
| 35 | + ) |
| 36 | + |
| 37 | + # Check if the SQL execution was successful |
| 38 | + if result.returncode != 0: |
| 39 | + print(f"Error executing SQL commands: {result.stderr.decode()}") |
| 40 | + else: |
| 41 | + print(f"SQL commands executed successfully:\n{result.stdout.decode()}") |
| 42 | + |
| 43 | +def stop_and_remove_container(): |
| 44 | + """Stop and remove the MySQL container after use""" |
| 45 | + subprocess.run(["docker", "stop", "mysql"]) |
| 46 | + subprocess.run(["docker", "rm", "mysql"]) |
| 47 | + print("MySQL container stopped and removed.") |
| 48 | + |
| 49 | + |
| 50 | +class TestMySQLIntegration: |
| 51 | + @classmethod |
| 52 | + def setup_class(self): |
| 53 | + run_docker_container() |
| 54 | + time.sleep(30) |
| 55 | + create_table_and_insert_data() |
| 56 | + time.sleep(10) |
| 57 | + self.ctx = SessionContext() |
| 58 | + connection_param = { |
| 59 | + "connection_string": "mysql://root:password@localhost:3306/mysql_db", |
| 60 | + "sslmode": "disabled"} |
| 61 | + self.pool = mysql.MySQLTableFactory(connection_param) |
| 62 | + |
| 63 | + @classmethod |
| 64 | + def teardown_class(self): |
| 65 | + stop_and_remove_container() |
| 66 | + |
| 67 | + def test_get_tables(self): |
| 68 | + """Test retrieving tables from the database""" |
| 69 | + tables = self.pool.tables() |
| 70 | + assert isinstance(tables, list) |
| 71 | + assert len(tables) == 1 |
| 72 | + assert tables == ["companies"] |
| 73 | + |
| 74 | + def test_query_companies(self): |
| 75 | + """Test querying companies table with SQL""" |
| 76 | + table_name = "companies" |
| 77 | + self.ctx.register_table_provider(table_name, self.pool.get_table("companies")) |
| 78 | + query = "SELECT * FROM companies" |
| 79 | + df = self.ctx.sql(query).collect() |
| 80 | + assert df is not None |
| 81 | + name_column = df[0]['name'] |
| 82 | + assert str(name_column[0]) == "Acme Corporation" |
0 commit comments