diff --git a/docs/tutorial/pytorch.qmd b/docs/tutorial/pytorch.qmd index 0cc5845..24a55a4 100644 --- a/docs/tutorial/pytorch.qmd +++ b/docs/tutorial/pytorch.qmd @@ -122,43 +122,24 @@ To get started, let's split this single dataset into two: a _training_ set and a Because the order of rows in an Ibis table is undefined, we need a unique key to split the data reproducibly. [It is permissible for airlines to use the same flight number for different routes, as long as the flights do not operate on the same day. This means that the combination of the flight number and the date of travel is always unique.](https://www.euclaim.com/blog/flight-numbers-explained#:~:text=Can%20flight%20numbers%20be%20reused,of%20travel%20is%20always%20unique.) ```{python} -flight_data_with_unique_key = flight_data.mutate( - unique_key=ibis.literal(",").join( - [flight_data.carrier, flight_data.flight.cast(str), flight_data.date.cast(str)] - ) -) -flight_data_with_unique_key -``` - -```{python} -flight_data_with_unique_key.group_by("unique_key").mutate( - count=flight_data_with_unique_key.count() -).filter(ibis._["count"] > 1) -``` - -```{python} -import random - -# Fix the random numbers by setting the seed -# This enables the analysis to be reproducible when random numbers are used -random.seed(222) - -# Put 3/4 of the data into the training set -random_key = str(random.getrandbits(256)) -data_split = flight_data_with_unique_key.mutate( - train=(flight_data_with_unique_key.unique_key + random_key).hash().abs() % 4 < 3 -) +import ibis_ml as ml # Create data frames for the two sets: -train_data = data_split[data_split.train].drop("unique_key", "train") -test_data = data_split[~data_split.train].drop("unique_key", "train") +train_data, test_data = ml.train_test_split( + flight_data, + unique_key=["carrier", "flight", "date"], + # Put 3/4 of the data into the training set + test_size=0.25, + num_buckets=4, + # Fix the random numbers by setting the seed + # This enables the analysis to be reproducible when random numbers are used + random_seed=222, +) ``` ## Create features ```{python} -import ibis_ml as ml - flights_rec = ml.Recipe( ml.ExpandDate("date", components=["dow", "month"]), ml.Drop("date"), diff --git a/docs/tutorial/scikit-learn.qmd b/docs/tutorial/scikit-learn.qmd index 060f4df..30f1958 100644 --- a/docs/tutorial/scikit-learn.qmd +++ b/docs/tutorial/scikit-learn.qmd @@ -121,43 +121,24 @@ To get started, let's split this single dataset into two: a _training_ set and a Because the order of rows in an Ibis table is undefined, we need a unique key to split the data reproducibly. [It is permissible for airlines to use the same flight number for different routes, as long as the flights do not operate on the same day. This means that the combination of the flight number and the date of travel is always unique.](https://www.euclaim.com/blog/flight-numbers-explained#:~:text=Can%20flight%20numbers%20be%20reused,of%20travel%20is%20always%20unique.) ```{python} -flight_data_with_unique_key = flight_data.mutate( - unique_key=ibis.literal(",").join( - [flight_data.carrier, flight_data.flight.cast(str), flight_data.date.cast(str)] - ) -) -flight_data_with_unique_key -``` - -```{python} -flight_data_with_unique_key.group_by("unique_key").mutate( - count=flight_data_with_unique_key.count() -).filter(ibis._["count"] > 1) -``` - -```{python} -import random - -# Fix the random numbers by setting the seed -# This enables the analysis to be reproducible when random numbers are used -random.seed(222) - -# Put 3/4 of the data into the training set -random_key = str(random.getrandbits(256)) -data_split = flight_data_with_unique_key.mutate( - train=(flight_data_with_unique_key.unique_key + random_key).hash().abs() % 4 < 3 -) +import ibis_ml as ml # Create data frames for the two sets: -train_data = data_split[data_split.train].drop("unique_key", "train") -test_data = data_split[~data_split.train].drop("unique_key", "train") +train_data, test_data = ml.train_test_split( + flight_data, + unique_key=["carrier", "flight", "date"], + # Put 3/4 of the data into the training set + test_size=0.25, + num_buckets=4, + # Fix the random numbers by setting the seed + # This enables the analysis to be reproducible when random numbers are used + random_seed=222, +) ``` ## Create features ```{python} -import ibis_ml as ml - flights_rec = ml.Recipe( ml.ExpandDate("date", components=["dow", "month"]), ml.Drop("date"), diff --git a/docs/tutorial/xgboost.qmd b/docs/tutorial/xgboost.qmd index aa2736b..4ec77ad 100644 --- a/docs/tutorial/xgboost.qmd +++ b/docs/tutorial/xgboost.qmd @@ -121,43 +121,24 @@ To get started, let's split this single dataset into two: a _training_ set and a Because the order of rows in an Ibis table is undefined, we need a unique key to split the data reproducibly. [It is permissible for airlines to use the same flight number for different routes, as long as the flights do not operate on the same day. This means that the combination of the flight number and the date of travel is always unique.](https://www.euclaim.com/blog/flight-numbers-explained#:~:text=Can%20flight%20numbers%20be%20reused,of%20travel%20is%20always%20unique.) ```{python} -flight_data_with_unique_key = flight_data.mutate( - unique_key=ibis.literal(",").join( - [flight_data.carrier, flight_data.flight.cast(str), flight_data.date.cast(str)] - ) -) -flight_data_with_unique_key -``` - -```{python} -flight_data_with_unique_key.group_by("unique_key").mutate( - count=flight_data_with_unique_key.count() -).filter(ibis._["count"] > 1) -``` - -```{python} -import random - -# Fix the random numbers by setting the seed -# This enables the analysis to be reproducible when random numbers are used -random.seed(222) - -# Put 3/4 of the data into the training set -random_key = str(random.getrandbits(256)) -data_split = flight_data_with_unique_key.mutate( - train=(flight_data_with_unique_key.unique_key + random_key).hash().abs() % 4 < 3 -) +import ibis_ml as ml # Create data frames for the two sets: -train_data = data_split[data_split.train].drop("unique_key", "train") -test_data = data_split[~data_split.train].drop("unique_key", "train") +train_data, test_data = ml.train_test_split( + flight_data, + unique_key=["carrier", "flight", "date"], + # Put 3/4 of the data into the training set + test_size=0.25, + num_buckets=4, + # Fix the random numbers by setting the seed + # This enables the analysis to be reproducible when random numbers are used + random_seed=222, +) ``` ## Create features ```{python} -import ibis_ml as ml - flights_rec = ml.Recipe( ml.ExpandDate("date", components=["dow", "month"]), ml.Drop("date"), diff --git a/examples/Preprocess your data with recipes.ipynb b/examples/Preprocess your data with recipes.ipynb index 7ac842f..87393d2 100644 --- a/examples/Preprocess your data with recipes.ipynb +++ b/examples/Preprocess your data with recipes.ipynb @@ -243,16 +243,16 @@ "┡━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩\n", "│ timeint64stringstringint64int64stringdateint64timestamp(6) │\n", "├──────────┼────────┼────────┼────────┼──────────┼──────────┼─────────┼────────────┼───────────┼─────────────────────┤\n", - "│ 10:45:0067EWR ORD 120719UA 2013-02-1402013-02-14 15:00:00 │\n", - "│ 10:48:00373LGA FLL 1791076B6 2013-02-1402013-02-14 15:00:00 │\n", - "│ 10:48:00764EWR IAH 2071400UA 2013-02-1402013-02-14 15:00:00 │\n", - "│ 10:51:002044LGA MIA 1711096DL 2013-02-1402013-02-14 16:00:00 │\n", - "│ 10:51:002171LGA DCA 40214US 2013-02-1402013-02-14 16:00:00 │\n", - "│ 10:57:001275JFK SLC 2861990DL 2013-02-1402013-02-14 16:00:00 │\n", - "│ 10:57:00366LGA STL 135888WN 2013-02-1402013-02-14 16:00:00 │\n", - "│ 10:57:001550EWR SFO 3382565UA 2013-02-1402013-02-14 15:00:00 │\n", - "│ 10:58:004694EWR MKE 113725EV 2013-02-1402013-02-14 15:00:00 │\n", - "│ 10:58:001647LGA ATL 117762DL 2013-02-1402013-02-14 16:00:00 │\n", + "│ 05:57:00461LGA ATL 100762DL 2013-06-2602013-06-26 10:00:00 │\n", + "│ 05:58:004424EWR RDU 63416EV 2013-06-2602013-06-26 10:00:00 │\n", + "│ 05:58:006177EWR IAD 45212EV 2013-06-2602013-06-26 10:00:00 │\n", + "│ 06:00:00731LGA DTW 78502DL 2013-06-2602013-06-26 10:00:00 │\n", + "│ 06:01:00684EWR LAX 3162454UA 2013-06-2602013-06-26 10:00:00 │\n", + "│ 06:01:00301LGA ORD 164733AA 2013-06-2612013-06-26 10:00:00 │\n", + "│ 06:01:001837LGA MIA 1481096AA 2013-06-2602013-06-26 10:00:00 │\n", + "│ 06:01:001279LGA MEM 128963DL 2013-06-2602013-06-26 10:00:00 │\n", + "│ 06:02:001691JFK LAX 3092475UA 2013-06-2602013-06-26 10:00:00 │\n", + "│ 06:04:001447JFK CLT 75541US 2013-06-2602013-06-26 10:00:00 │\n", "│ │\n", "└──────────┴────────┴────────┴────────┴──────────┴──────────┴─────────┴────────────┴───────────┴─────────────────────┘\n", "\n" @@ -263,16 +263,16 @@ "┡━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩\n", "│ \u001b[2mtime\u001b[0m │ \u001b[2mint64\u001b[0m │ \u001b[2mstring\u001b[0m │ \u001b[2mstring\u001b[0m │ \u001b[2mint64\u001b[0m │ \u001b[2mint64\u001b[0m │ \u001b[2mstring\u001b[0m │ \u001b[2mdate\u001b[0m │ \u001b[2mint64\u001b[0m │ \u001b[2mtimestamp(6)\u001b[0m │\n", "├──────────┼────────┼────────┼────────┼──────────┼──────────┼─────────┼────────────┼───────────┼─────────────────────┤\n", - "│ \u001b[35m10:45:00\u001b[0m │ \u001b[1;36m67\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mORD \u001b[0m │ \u001b[1;36m120\u001b[0m │ \u001b[1;36m719\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-02-14\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-02-14 15:00:00\u001b[0m │\n", - "│ \u001b[35m10:48:00\u001b[0m │ \u001b[1;36m373\u001b[0m │ \u001b[32mLGA \u001b[0m │ \u001b[32mFLL \u001b[0m │ \u001b[1;36m179\u001b[0m │ \u001b[1;36m1076\u001b[0m │ \u001b[32mB6 \u001b[0m │ \u001b[35m2013-02-14\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-02-14 15:00:00\u001b[0m │\n", - "│ \u001b[35m10:48:00\u001b[0m │ \u001b[1;36m764\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mIAH \u001b[0m │ \u001b[1;36m207\u001b[0m │ \u001b[1;36m1400\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-02-14\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-02-14 15:00:00\u001b[0m │\n", - "│ \u001b[35m10:51:00\u001b[0m │ \u001b[1;36m2044\u001b[0m │ \u001b[32mLGA \u001b[0m │ \u001b[32mMIA \u001b[0m │ \u001b[1;36m171\u001b[0m │ \u001b[1;36m1096\u001b[0m │ \u001b[32mDL \u001b[0m │ \u001b[35m2013-02-14\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-02-14 16:00:00\u001b[0m │\n", - "│ \u001b[35m10:51:00\u001b[0m │ \u001b[1;36m2171\u001b[0m │ \u001b[32mLGA \u001b[0m │ \u001b[32mDCA \u001b[0m │ \u001b[1;36m40\u001b[0m │ \u001b[1;36m214\u001b[0m │ \u001b[32mUS \u001b[0m │ \u001b[35m2013-02-14\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-02-14 16:00:00\u001b[0m │\n", - "│ \u001b[35m10:57:00\u001b[0m │ \u001b[1;36m1275\u001b[0m │ \u001b[32mJFK \u001b[0m │ \u001b[32mSLC \u001b[0m │ \u001b[1;36m286\u001b[0m │ \u001b[1;36m1990\u001b[0m │ \u001b[32mDL \u001b[0m │ \u001b[35m2013-02-14\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-02-14 16:00:00\u001b[0m │\n", - "│ \u001b[35m10:57:00\u001b[0m │ \u001b[1;36m366\u001b[0m │ \u001b[32mLGA \u001b[0m │ \u001b[32mSTL \u001b[0m │ \u001b[1;36m135\u001b[0m │ \u001b[1;36m888\u001b[0m │ \u001b[32mWN \u001b[0m │ \u001b[35m2013-02-14\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-02-14 16:00:00\u001b[0m │\n", - "│ \u001b[35m10:57:00\u001b[0m │ \u001b[1;36m1550\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mSFO \u001b[0m │ \u001b[1;36m338\u001b[0m │ \u001b[1;36m2565\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-02-14\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-02-14 15:00:00\u001b[0m │\n", - "│ \u001b[35m10:58:00\u001b[0m │ \u001b[1;36m4694\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mMKE \u001b[0m │ \u001b[1;36m113\u001b[0m │ \u001b[1;36m725\u001b[0m │ \u001b[32mEV \u001b[0m │ \u001b[35m2013-02-14\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-02-14 15:00:00\u001b[0m │\n", - "│ \u001b[35m10:58:00\u001b[0m │ \u001b[1;36m1647\u001b[0m │ \u001b[32mLGA \u001b[0m │ \u001b[32mATL \u001b[0m │ \u001b[1;36m117\u001b[0m │ \u001b[1;36m762\u001b[0m │ \u001b[32mDL \u001b[0m │ \u001b[35m2013-02-14\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-02-14 16:00:00\u001b[0m │\n", + "│ \u001b[35m05:57:00\u001b[0m │ \u001b[1;36m461\u001b[0m │ \u001b[32mLGA \u001b[0m │ \u001b[32mATL \u001b[0m │ \u001b[1;36m100\u001b[0m │ \u001b[1;36m762\u001b[0m │ \u001b[32mDL \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │\n", + "│ \u001b[35m05:58:00\u001b[0m │ \u001b[1;36m4424\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mRDU \u001b[0m │ \u001b[1;36m63\u001b[0m │ \u001b[1;36m416\u001b[0m │ \u001b[32mEV \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │\n", + "│ \u001b[35m05:58:00\u001b[0m │ \u001b[1;36m6177\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mIAD \u001b[0m │ \u001b[1;36m45\u001b[0m │ \u001b[1;36m212\u001b[0m │ \u001b[32mEV \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │\n", + "│ \u001b[35m06:00:00\u001b[0m │ \u001b[1;36m731\u001b[0m │ \u001b[32mLGA \u001b[0m │ \u001b[32mDTW \u001b[0m │ \u001b[1;36m78\u001b[0m │ \u001b[1;36m502\u001b[0m │ \u001b[32mDL \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │\n", + "│ \u001b[35m06:01:00\u001b[0m │ \u001b[1;36m684\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mLAX \u001b[0m │ \u001b[1;36m316\u001b[0m │ \u001b[1;36m2454\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │\n", + "│ \u001b[35m06:01:00\u001b[0m │ \u001b[1;36m301\u001b[0m │ \u001b[32mLGA \u001b[0m │ \u001b[32mORD \u001b[0m │ \u001b[1;36m164\u001b[0m │ \u001b[1;36m733\u001b[0m │ \u001b[32mAA \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m1\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │\n", + "│ \u001b[35m06:01:00\u001b[0m │ \u001b[1;36m1837\u001b[0m │ \u001b[32mLGA \u001b[0m │ \u001b[32mMIA \u001b[0m │ \u001b[1;36m148\u001b[0m │ \u001b[1;36m1096\u001b[0m │ \u001b[32mAA \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │\n", + "│ \u001b[35m06:01:00\u001b[0m │ \u001b[1;36m1279\u001b[0m │ \u001b[32mLGA \u001b[0m │ \u001b[32mMEM \u001b[0m │ \u001b[1;36m128\u001b[0m │ \u001b[1;36m963\u001b[0m │ \u001b[32mDL \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │\n", + "│ \u001b[35m06:02:00\u001b[0m │ \u001b[1;36m1691\u001b[0m │ \u001b[32mJFK \u001b[0m │ \u001b[32mLAX \u001b[0m │ \u001b[1;36m309\u001b[0m │ \u001b[1;36m2475\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │\n", + "│ \u001b[35m06:04:00\u001b[0m │ \u001b[1;36m1447\u001b[0m │ \u001b[32mJFK \u001b[0m │ \u001b[32mCLT \u001b[0m │ \u001b[1;36m75\u001b[0m │ \u001b[1;36m541\u001b[0m │ \u001b[32mUS \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │\n", "│ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │\n", "└──────────┴────────┴────────┴────────┴──────────┴──────────┴─────────┴────────────┴───────────┴─────────────────────┘" ] @@ -376,146 +376,29 @@ { "cell_type": "code", "execution_count": 8, - "id": "732624f4-a2af-4c6e-b29d-4fb7cb5fc99e", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
┏━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓\n",
-       "┃ dep_time  flight  origin  dest    air_time  distance  carrier  date        arr_delay  time_hour            unique_key         ┃\n",
-       "┡━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩\n",
-       "│ timeint64stringstringint64int64stringdateint64timestamp(6)string             │\n",
-       "├──────────┼────────┼────────┼────────┼──────────┼──────────┼─────────┼────────────┼───────────┼─────────────────────┼────────────────────┤\n",
-       "│ 05:57:00461LGA   ATL   100762DL     2013-06-2602013-06-26 10:00:00DL,461,2013-06-26  │\n",
-       "│ 05:58:004424EWR   RDU   63416EV     2013-06-2602013-06-26 10:00:00EV,4424,2013-06-26 │\n",
-       "│ 05:58:006177EWR   IAD   45212EV     2013-06-2602013-06-26 10:00:00EV,6177,2013-06-26 │\n",
-       "│ 06:00:00731LGA   DTW   78502DL     2013-06-2602013-06-26 10:00:00DL,731,2013-06-26  │\n",
-       "│ 06:01:00684EWR   LAX   3162454UA     2013-06-2602013-06-26 10:00:00UA,684,2013-06-26  │\n",
-       "│ 06:01:00301LGA   ORD   164733AA     2013-06-2612013-06-26 10:00:00AA,301,2013-06-26  │\n",
-       "│ 06:01:001837LGA   MIA   1481096AA     2013-06-2602013-06-26 10:00:00AA,1837,2013-06-26 │\n",
-       "│ 06:01:001279LGA   MEM   128963DL     2013-06-2602013-06-26 10:00:00DL,1279,2013-06-26 │\n",
-       "│ 06:02:001691JFK   LAX   3092475UA     2013-06-2602013-06-26 10:00:00UA,1691,2013-06-26 │\n",
-       "│ 06:04:001447JFK   CLT   75541US     2013-06-2602013-06-26 10:00:00US,1447,2013-06-26 │\n",
-       "│                   │\n",
-       "└──────────┴────────┴────────┴────────┴──────────┴──────────┴─────────┴────────────┴───────────┴─────────────────────┴────────────────────┘\n",
-       "
\n" - ], - "text/plain": [ - "┏━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mdep_time\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mflight\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1morigin\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mdest\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mair_time\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mdistance\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mcarrier\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mdate\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1marr_delay\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mtime_hour\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1munique_key\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩\n", - "│ \u001b[2mtime\u001b[0m │ \u001b[2mint64\u001b[0m │ \u001b[2mstring\u001b[0m │ \u001b[2mstring\u001b[0m │ \u001b[2mint64\u001b[0m │ \u001b[2mint64\u001b[0m │ \u001b[2mstring\u001b[0m │ \u001b[2mdate\u001b[0m │ \u001b[2mint64\u001b[0m │ \u001b[2mtimestamp(6)\u001b[0m │ \u001b[2mstring\u001b[0m │\n", - "├──────────┼────────┼────────┼────────┼──────────┼──────────┼─────────┼────────────┼───────────┼─────────────────────┼────────────────────┤\n", - "│ \u001b[35m05:57:00\u001b[0m │ \u001b[1;36m461\u001b[0m │ \u001b[32mLGA \u001b[0m │ \u001b[32mATL \u001b[0m │ \u001b[1;36m100\u001b[0m │ \u001b[1;36m762\u001b[0m │ \u001b[32mDL \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │ \u001b[32mDL,461,2013-06-26 \u001b[0m │\n", - "│ \u001b[35m05:58:00\u001b[0m │ \u001b[1;36m4424\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mRDU \u001b[0m │ \u001b[1;36m63\u001b[0m │ \u001b[1;36m416\u001b[0m │ \u001b[32mEV \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │ \u001b[32mEV,4424,2013-06-26\u001b[0m │\n", - "│ \u001b[35m05:58:00\u001b[0m │ \u001b[1;36m6177\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mIAD \u001b[0m │ \u001b[1;36m45\u001b[0m │ \u001b[1;36m212\u001b[0m │ \u001b[32mEV \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │ \u001b[32mEV,6177,2013-06-26\u001b[0m │\n", - "│ \u001b[35m06:00:00\u001b[0m │ \u001b[1;36m731\u001b[0m │ \u001b[32mLGA \u001b[0m │ \u001b[32mDTW \u001b[0m │ \u001b[1;36m78\u001b[0m │ \u001b[1;36m502\u001b[0m │ \u001b[32mDL \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │ \u001b[32mDL,731,2013-06-26 \u001b[0m │\n", - "│ \u001b[35m06:01:00\u001b[0m │ \u001b[1;36m684\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mLAX \u001b[0m │ \u001b[1;36m316\u001b[0m │ \u001b[1;36m2454\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │ \u001b[32mUA,684,2013-06-26 \u001b[0m │\n", - "│ \u001b[35m06:01:00\u001b[0m │ \u001b[1;36m301\u001b[0m │ \u001b[32mLGA \u001b[0m │ \u001b[32mORD \u001b[0m │ \u001b[1;36m164\u001b[0m │ \u001b[1;36m733\u001b[0m │ \u001b[32mAA \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m1\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │ \u001b[32mAA,301,2013-06-26 \u001b[0m │\n", - "│ \u001b[35m06:01:00\u001b[0m │ \u001b[1;36m1837\u001b[0m │ \u001b[32mLGA \u001b[0m │ \u001b[32mMIA \u001b[0m │ \u001b[1;36m148\u001b[0m │ \u001b[1;36m1096\u001b[0m │ \u001b[32mAA \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │ \u001b[32mAA,1837,2013-06-26\u001b[0m │\n", - "│ \u001b[35m06:01:00\u001b[0m │ \u001b[1;36m1279\u001b[0m │ \u001b[32mLGA \u001b[0m │ \u001b[32mMEM \u001b[0m │ \u001b[1;36m128\u001b[0m │ \u001b[1;36m963\u001b[0m │ \u001b[32mDL \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │ \u001b[32mDL,1279,2013-06-26\u001b[0m │\n", - "│ \u001b[35m06:02:00\u001b[0m │ \u001b[1;36m1691\u001b[0m │ \u001b[32mJFK \u001b[0m │ \u001b[32mLAX \u001b[0m │ \u001b[1;36m309\u001b[0m │ \u001b[1;36m2475\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │ \u001b[32mUA,1691,2013-06-26\u001b[0m │\n", - "│ \u001b[35m06:04:00\u001b[0m │ \u001b[1;36m1447\u001b[0m │ \u001b[32mJFK \u001b[0m │ \u001b[32mCLT \u001b[0m │ \u001b[1;36m75\u001b[0m │ \u001b[1;36m541\u001b[0m │ \u001b[32mUS \u001b[0m │ \u001b[35m2013-06-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-06-26 10:00:00\u001b[0m │ \u001b[32mUS,1447,2013-06-26\u001b[0m │\n", - "│ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │\n", - "└──────────┴────────┴────────┴────────┴──────────┴──────────┴─────────┴────────────┴───────────┴─────────────────────┴────────────────────┘" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "flight_data_with_unique_key = flight_data.mutate(\n", - " unique_key=ibis.literal(\",\").join(\n", - " [flight_data.carrier, flight_data.flight.cast(str), flight_data.date.cast(str)]\n", - " )\n", - ")\n", - "flight_data_with_unique_key" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "c9cd58ce-dc2d-4e4e-8b4a-51100fe1182c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
┏━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━┓\n",
-       "┃ dep_time  flight  origin  dest    air_time  distance  carrier  date        arr_delay  time_hour            unique_key          cnt   ┃\n",
-       "┡━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━┩\n",
-       "│ timeint64stringstringint64int64stringdateint64timestamp(6)stringint64 │\n",
-       "├──────────┼────────┼────────┼────────┼──────────┼──────────┼─────────┼────────────┼───────────┼─────────────────────┼────────────────────┼───────┤\n",
-       "│ 19:59:001022EWR   IAH   1671400UA     2013-09-1402013-09-14 23:00:00UA,1022,2013-09-142 │\n",
-       "│ 20:00:001022EWR   IAH   1861400UA     2013-09-1402013-09-14 00:00:00UA,1022,2013-09-142 │\n",
-       "│ 19:12:001023LGA   ORD   112733UA     2013-05-2902013-05-29 23:00:00UA,1023,2013-05-292 │\n",
-       "│ 21:16:001023EWR   IAH   1751400UA     2013-05-2902013-05-29 01:00:00UA,1023,2013-05-292 │\n",
-       "│ 15:18:001052EWR   IAH   1741400UA     2013-08-2702013-08-27 19:00:00UA,1052,2013-08-272 │\n",
-       "│ 21:22:001052EWR   IAH   1731400UA     2013-08-2702013-08-27 01:00:00UA,1052,2013-08-272 │\n",
-       "│ 18:39:001053EWR   CLE   72404UA     2013-12-2002013-12-20 23:00:00UA,1053,2013-12-202 │\n",
-       "│ 19:27:001053EWR   CLE   69404UA     2013-12-2002013-12-20 00:00:00UA,1053,2013-12-202 │\n",
-       "│ 20:16:001071EWR   BQN   1961585UA     2013-02-2602013-02-26 01:00:00UA,1071,2013-02-262 │\n",
-       "│ 17:20:001071EWR   PHX   2812133UA     2013-02-2602013-02-26 22:00:00UA,1071,2013-02-262 │\n",
-       "│  │\n",
-       "└──────────┴────────┴────────┴────────┴──────────┴──────────┴─────────┴────────────┴───────────┴─────────────────────┴────────────────────┴───────┘\n",
-       "
\n" - ], - "text/plain": [ - "┏━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mdep_time\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mflight\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1morigin\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mdest\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mair_time\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mdistance\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mcarrier\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mdate\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1marr_delay\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mtime_hour\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1munique_key\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mcnt\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━┩\n", - "│ \u001b[2mtime\u001b[0m │ \u001b[2mint64\u001b[0m │ \u001b[2mstring\u001b[0m │ \u001b[2mstring\u001b[0m │ \u001b[2mint64\u001b[0m │ \u001b[2mint64\u001b[0m │ \u001b[2mstring\u001b[0m │ \u001b[2mdate\u001b[0m │ \u001b[2mint64\u001b[0m │ \u001b[2mtimestamp(6)\u001b[0m │ \u001b[2mstring\u001b[0m │ \u001b[2mint64\u001b[0m │\n", - "├──────────┼────────┼────────┼────────┼──────────┼──────────┼─────────┼────────────┼───────────┼─────────────────────┼────────────────────┼───────┤\n", - "│ \u001b[35m19:59:00\u001b[0m │ \u001b[1;36m1022\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mIAH \u001b[0m │ \u001b[1;36m167\u001b[0m │ \u001b[1;36m1400\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-09-14\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-09-14 23:00:00\u001b[0m │ \u001b[32mUA,1022,2013-09-14\u001b[0m │ \u001b[1;36m2\u001b[0m │\n", - "│ \u001b[35m20:00:00\u001b[0m │ \u001b[1;36m1022\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mIAH \u001b[0m │ \u001b[1;36m186\u001b[0m │ \u001b[1;36m1400\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-09-14\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-09-14 00:00:00\u001b[0m │ \u001b[32mUA,1022,2013-09-14\u001b[0m │ \u001b[1;36m2\u001b[0m │\n", - "│ \u001b[35m19:12:00\u001b[0m │ \u001b[1;36m1023\u001b[0m │ \u001b[32mLGA \u001b[0m │ \u001b[32mORD \u001b[0m │ \u001b[1;36m112\u001b[0m │ \u001b[1;36m733\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-05-29\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-05-29 23:00:00\u001b[0m │ \u001b[32mUA,1023,2013-05-29\u001b[0m │ \u001b[1;36m2\u001b[0m │\n", - "│ \u001b[35m21:16:00\u001b[0m │ \u001b[1;36m1023\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mIAH \u001b[0m │ \u001b[1;36m175\u001b[0m │ \u001b[1;36m1400\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-05-29\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-05-29 01:00:00\u001b[0m │ \u001b[32mUA,1023,2013-05-29\u001b[0m │ \u001b[1;36m2\u001b[0m │\n", - "│ \u001b[35m15:18:00\u001b[0m │ \u001b[1;36m1052\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mIAH \u001b[0m │ \u001b[1;36m174\u001b[0m │ \u001b[1;36m1400\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-08-27\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-08-27 19:00:00\u001b[0m │ \u001b[32mUA,1052,2013-08-27\u001b[0m │ \u001b[1;36m2\u001b[0m │\n", - "│ \u001b[35m21:22:00\u001b[0m │ \u001b[1;36m1052\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mIAH \u001b[0m │ \u001b[1;36m173\u001b[0m │ \u001b[1;36m1400\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-08-27\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-08-27 01:00:00\u001b[0m │ \u001b[32mUA,1052,2013-08-27\u001b[0m │ \u001b[1;36m2\u001b[0m │\n", - "│ \u001b[35m18:39:00\u001b[0m │ \u001b[1;36m1053\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mCLE \u001b[0m │ \u001b[1;36m72\u001b[0m │ \u001b[1;36m404\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-12-20\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-12-20 23:00:00\u001b[0m │ \u001b[32mUA,1053,2013-12-20\u001b[0m │ \u001b[1;36m2\u001b[0m │\n", - "│ \u001b[35m19:27:00\u001b[0m │ \u001b[1;36m1053\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mCLE \u001b[0m │ \u001b[1;36m69\u001b[0m │ \u001b[1;36m404\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-12-20\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-12-20 00:00:00\u001b[0m │ \u001b[32mUA,1053,2013-12-20\u001b[0m │ \u001b[1;36m2\u001b[0m │\n", - "│ \u001b[35m20:16:00\u001b[0m │ \u001b[1;36m1071\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mBQN \u001b[0m │ \u001b[1;36m196\u001b[0m │ \u001b[1;36m1585\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-02-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-02-26 01:00:00\u001b[0m │ \u001b[32mUA,1071,2013-02-26\u001b[0m │ \u001b[1;36m2\u001b[0m │\n", - "│ \u001b[35m17:20:00\u001b[0m │ \u001b[1;36m1071\u001b[0m │ \u001b[32mEWR \u001b[0m │ \u001b[32mPHX \u001b[0m │ \u001b[1;36m281\u001b[0m │ \u001b[1;36m2133\u001b[0m │ \u001b[32mUA \u001b[0m │ \u001b[35m2013-02-26\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[35m2013-02-26 22:00:00\u001b[0m │ \u001b[32mUA,1071,2013-02-26\u001b[0m │ \u001b[1;36m2\u001b[0m │\n", - "│ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │ \u001b[2m…\u001b[0m │\n", - "└──────────┴────────┴────────┴────────┴──────────┴──────────┴─────────┴────────────┴───────────┴─────────────────────┴────────────────────┴───────┘" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "flight_data_with_unique_key.group_by(\"unique_key\").mutate(\n", - " count=flight_data_with_unique_key.count()\n", - ").filter(ibis._[\"count\"] > 1)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, "id": "6be459de-73cd-4d6e-a195-41b9e5c481a6", - "metadata": {}, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, "outputs": [], "source": [ - "import random\n", - "\n", - "# Fix the random numbers by setting the seed\n", - "# This enables the analysis to be reproducible when random numbers are used\n", - "random.seed(222)\n", - "\n", - "# Put 3/4 of the data into the training set\n", - "random_key = str(random.getrandbits(256))\n", - "data_split = flight_data_with_unique_key.mutate(\n", - " train=(flight_data_with_unique_key.unique_key + random_key).hash().abs() % 4 < 3\n", - ")\n", + "import ibis_ml as ml\n", "\n", "# Create data frames for the two sets:\n", - "train_data = data_split[data_split.train].drop(\"unique_key\", \"train\")\n", - "test_data = data_split[~data_split.train].drop(\"unique_key\", \"train\")" + "train_data, test_data = ml.train_test_split(\n", + " flight_data,\n", + " unique_key=[\"carrier\", \"flight\", \"date\"],\n", + " # Put 3/4 of the data into the training set\n", + " test_size=0.25,\n", + " num_buckets=4,\n", + " # Fix the random numbers by setting the seed\n", + " # This enables the analysis to be reproducible when random numbers are used\n", + " random_seed=222,\n", + ")" ] }, { @@ -528,13 +411,11 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "id": "a223b57d-31b7-4ad1-88fd-a216de7da01a", "metadata": {}, "outputs": [], "source": [ - "import ibis_ml as ml\n", - "\n", "flights_rec = ml.Recipe(\n", " ml.ExpandDate(\"date\", components=[\"dow\", \"month\"]),\n", " ml.Drop(\"date\"),\n", @@ -560,14 +441,14 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "id": "161b43a0-a3fc-4da3-a5ab-810b234bae32", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "042b3b33d16f421f9c6242642a03c55b", + "model_id": "a80c19c4c8664af0b2e916e2cf36158a", "version_major": 2, "version_minor": 0 }, @@ -575,7 +456,7 @@ "RadioButtons(description='Library:', index=2, options=('scikit-learn', 'XGBoost', 'skorch (PyTorch)'), value='…" ] }, - "execution_count": 12, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -609,7 +490,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 11, "id": "dc04f24e-c8cb-4580-b502-a9410c64a126", "metadata": {}, "outputs": [], @@ -667,7 +548,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 12, "id": "42ac1426-0561-4a8b-a949-127b2b0c4f01", "metadata": {}, "outputs": [ @@ -677,16 +558,16 @@ "text": [ " epoch train_loss valid_acc valid_loss dur\n", "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m4.4971\u001b[0m \u001b[32m0.8388\u001b[0m \u001b[35m2.5698\u001b[0m 1.0492\n", - " 2 \u001b[36m4.4671\u001b[0m 0.8388 2.5698 1.0529\n", - " 3 \u001b[36m4.4625\u001b[0m 0.8388 2.5698 1.0129\n", - " 4 \u001b[36m4.4451\u001b[0m 0.8388 2.5698 1.0246\n", - " 5 4.4488 0.8388 2.5698 1.0251\n", - " 6 4.4553 0.8388 2.5698 0.9891\n", - " 7 4.4630 0.8388 2.5698 1.0836\n", - " 8 4.4847 0.8388 2.5698 1.2162\n", - " 9 4.4798 0.8388 2.5698 1.2594\n", - " 10 4.4799 0.8388 2.5698 1.0920\n" + " 1 \u001b[36m2.4584\u001b[0m \u001b[32m0.8386\u001b[0m \u001b[35m2.5725\u001b[0m 0.9928\n", + " 2 \u001b[36m2.4424\u001b[0m 0.8386 2.5725 0.8958\n", + " 3 \u001b[36m2.4395\u001b[0m 0.8386 2.5725 0.9216\n", + " 4 2.4404 0.8386 2.5725 0.8905\n", + " 5 2.4411 0.8386 2.5725 0.8881\n", + " 6 2.4434 0.8386 2.5725 0.8884\n", + " 7 2.4442 0.8386 2.5725 0.9096\n", + " 8 \u001b[36m2.4391\u001b[0m 0.8386 2.5725 1.0850\n", + " 9 2.4432 0.8386 2.5725 0.9073\n", + " 10 \u001b[36m2.4354\u001b[0m 0.8386 2.5725 0.9601\n" ] }, { @@ -1142,7 +1023,7 @@ " DropZeroVariance(everything(), tolerance=0.0001),\n", " MutateAt(cols(('dep_time',)), ((_.hour() * 60) + _.minute())),\n", " MutateAt(timestamp(), _.epoch_seconds()),\n", - " Cast(numeric(), 'float32'))
ExpandDate(cols(('date',)), components=['dow', 'month'])
Drop(cols(('date',)))
TargetEncode(nominal(), smooth=0.0)
DropZeroVariance(everything(), tolerance=0.0001)
MutateAt(cols(('dep_time',)), ((_.hour() * 60) + _.minute()))
MutateAt(timestamp(), _.epoch_seconds())
Cast(numeric(), 'float32')
<class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
+       "       Cast(numeric(), 'float32'))
ExpandDate(cols(('date',)), components=['dow', 'month'])
Drop(cols(('date',)))
TargetEncode(nominal(), smooth=0.0)
DropZeroVariance(everything(), tolerance=0.0001)
MutateAt(cols(('dep_time',)), ((_.hour() * 60) + _.minute()))
MutateAt(timestamp(), _.epoch_seconds())
Cast(numeric(), 'float32')
<class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
        "  module_=MyModule(\n",
        "    (dense0): Linear(in_features=10, out_features=10, bias=True)\n",
        "    (nonlin): ReLU()\n",
@@ -1177,7 +1058,7 @@
        "))])"
       ]
      },
-     "execution_count": 14,
+     "execution_count": 12,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -1200,17 +1081,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 13,
    "id": "be3ff129-d56f-4441-acbc-da7d6cd93d19",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "0.8385534190130481"
+       "0.8390849833968762"
       ]
      },
-     "execution_count": 15,
+     "execution_count": 13,
      "metadata": {},
      "output_type": "execute_result"
     }
diff --git a/ibis_ml/utils/_split.py b/ibis_ml/utils/_split.py
index 8b157f6..5290637 100644
--- a/ibis_ml/utils/_split.py
+++ b/ibis_ml/utils/_split.py
@@ -98,6 +98,7 @@ def train_test_split(
         }
     )
 
-    return table[table[train_flag]].drop([combined_key, train_flag]), table[
-        ~table[train_flag]
-    ].drop([combined_key, train_flag])
+    return (
+        table.filter(table[train_flag]).drop([combined_key, train_flag]),
+        table.filter(~table[train_flag]).drop([combined_key, train_flag]),
+    )