Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 48 additions & 9 deletions autokeras/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import keras
import numpy as np
import pandas as pd

import autokeras as ak

Expand Down Expand Up @@ -45,15 +46,53 @@
"embark_town": "categorical",
"alone": "categorical",
}
# Remote locations of the Titanic train and eval CSV files.
TRAIN_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/train.csv"
TEST_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/eval.csv"

# Fetch both files through keras.utils.get_file, naming each local file after
# the last path component of its URL, and keep the returned local paths.
TRAIN_CSV_PATH, TEST_CSV_PATH = (
    keras.utils.get_file(fname=os.path.basename(_url), origin=_url)
    for _url in (TRAIN_DATA_URL, TEST_DATA_URL)
)

# Download Titanic dataset from OpenML and split into train/test
import logging
import shutil
import ssl
import urllib.error
import urllib.request

TITANIC_DATA_URL = "https://www.openml.org/data/get_csv/16826755/phpMYEkMl"

# The raw CSV is cached under ~/.keras/datasets so it is downloaded only once.
_cache_dir = os.path.expanduser(os.path.join("~", ".keras", "datasets"))
os.makedirs(_cache_dir, exist_ok=True)
_titanic_data_path = os.path.join(_cache_dir, "titanic.csv")


def _open_titanic_url(url):
    """Open `url`, verifying the server certificate whenever possible.

    Verification is attempted first. Only when the failure is specifically a
    certificate verification error is the request retried without
    verification (the previous unconditional behaviour), and that downgrade
    is logged so it does not go unnoticed.

    Args:
        url: The HTTPS URL to open.

    Returns:
        The response object returned by `urllib.request.urlopen`.

    Raises:
        urllib.error.URLError: For any failure other than certificate
            verification, or if the unverified retry fails as well.
    """
    try:
        return urllib.request.urlopen(
            url, context=ssl.create_default_context()
        )
    except urllib.error.URLError as err:
        if not isinstance(err.reason, ssl.SSLCertVerificationError):
            raise
        logging.getLogger(__name__).warning(
            "TLS certificate verification failed for %s; retrying without "
            "verification. Fix the local CA bundle to avoid this insecure "
            "fallback.",
            url,
        )
        insecure_context = ssl.create_default_context()
        # check_hostname must be disabled before verify_mode can be CERT_NONE.
        insecure_context.check_hostname = False
        insecure_context.verify_mode = ssl.CERT_NONE
        return urllib.request.urlopen(url, context=insecure_context)


def _download_to_cache(url, destination):
    """Stream `url` into `destination` without leaving a partial file.

    The payload is written to a temporary sibling file and moved into place
    with `os.replace` only after the download has completed, so an
    interrupted download can never be mistaken for a valid cached file.

    Args:
        url: The URL to download.
        destination: The final path of the cached file.
    """
    # The PID keeps concurrent processes from writing to the same temp file.
    partial_path = "{}.{}.part".format(destination, os.getpid())
    try:
        with _open_titanic_url(url) as response:
            with open(partial_path, "wb") as out_file:
                shutil.copyfileobj(response, out_file)
        os.replace(partial_path, destination)
    finally:
        # After a successful replace the temp path is gone; anything still
        # there is the leftover of a failed download.
        if os.path.exists(partial_path):
            os.remove(partial_path)


if not os.path.exists(_titanic_data_path):
    _download_to_cache(TITANIC_DATA_URL, _titanic_data_path)

# Columns retained from the raw file, in the order they are written out.
_columns_to_keep = [
    "sex",
    "age",
    "n_siblings_spouses",
    "parch",
    "fare",
    "class",
    "deck",
    "embark_town",
    "alone",
    "survived",
]

# Read the cached CSV and reshape it into the expected column layout.
_df = (
    pd.read_csv(_titanic_data_path)
    # Map the source column names onto the expected ones.
    .rename(
        columns={
            "pclass": "class",
            "sibsp": "n_siblings_spouses",
            "cabin": "deck",
            "embarked": "embark_town",
        }
    )
    # "alone" is the string "True"/"False": whether the siblings/spouses
    # count plus the parch count equals zero.
    .assign(
        alone=lambda frame: frame["n_siblings_spouses"]
        .add(frame["parch"])
        .eq(0)
        .astype(str)
    )
    # Drop everything else and fix the column order.
    .loc[:, _columns_to_keep]
)

# Destinations of the two splits inside the cache directory.
TRAIN_CSV_PATH = os.path.join(_cache_dir, "titanic_train.csv")
TEST_CSV_PATH = os.path.join(_cache_dir, "titanic_test.csv")

# The first 80% of the rows (in file order, no shuffling) form the training
# split; the remaining rows form the test split.
_train_size = int(0.8 * len(_df))
_train_df, _test_df = _df.iloc[:_train_size], _df.iloc[_train_size:]

# Write both splits out without the index column.
_train_df.to_csv(TRAIN_CSV_PATH, index=False)
_test_df.to_csv(TEST_CSV_PATH, index=False)


def generate_data(num_instances=100, shape=(32, 32, 3)):
Expand Down