diff --git a/examples/tutorials/AATestTutorial.ipynb b/examples/tutorials/AATestTutorial.ipynb index 654c3049..12cf6add 100644 --- a/examples/tutorials/AATestTutorial.ipynb +++ b/examples/tutorials/AATestTutorial.ipynb @@ -428,6 +428,27 @@ "res.resume" ] }, + { + "cell_type": "markdown", + "id": "55c32466", + "metadata": {}, + "source": [ + "**Interpretation of AA test results**\n", + "\n", + "Each row in the table corresponds to a target feature being tested for equality between the control and test groups. Two statistical tests are used:\n", + "\n", + "- **TTest**: tests if means are statistically different.\n", + "- **KSTest**: tests if distributions differ.\n", + "\n", + "The `OK` / `NOT OK` labels show whether the difference is statistically significant. A `NOT OK` result indicates a possible imbalance.\n", + "\n", + "Typical threshold:\n", + "- If p-value < 0.05 → `NOT OK` (statistically significant difference)\n", + "- If p-value ≥ 0.05 → `OK` (no significant difference)\n", + "\n", + "If any metric has a `NOT OK` status in the `AA test` column, it means at least one iteration showed significant difference.\n" + ] + }, { "cell_type": "code", "execution_count": 5, @@ -506,6 +527,21 @@ "res.aa_score" ] }, + { + "cell_type": "markdown", + "id": "eb0ce07b", + "metadata": {}, + "source": [ + "**Interpreting `aa_score`**\n", + "\n", + "This output shows p-values and the overall pass/fail status for each test type and feature. 
A high p-value (close to 1.0) means the test passed — the groups are similar.\n", + "\n", + "- `score`: p-value of the statistical test.\n", + "- `pass`: True if no iterations showed significant differences.\n", + "\n", + "Note: Even if the average p-value is high, the `pass` might still be False if at least one of the iterations had a p-value < 0.05.\n" + ] + }, { "cell_type": "code", "execution_count": 6, @@ -726,6 +762,18 @@ "res.best_split" ] }, + { + "cell_type": "markdown", + "id": "a225e982", + "metadata": {}, + "source": [ + "**About `best_split`**\n", + "\n", + "This shows the best found split of the dataset, where control and test groups are as similar as possible in terms of target metrics.\n", + "\n", + "You can use this split for future modeling or as a validation check before proceeding to actual experiments.\n" + ] + }, { "cell_type": "code", "execution_count": 7, @@ -824,6 +872,22 @@ "res.best_split_statistic" ] }, + { + "cell_type": "markdown", + "id": "ef1986ae", + "metadata": {}, + "source": [ + "**Understanding `best_split_statistic`**\n", + "\n", + "This table contains detailed statistics for the best (most balanced) split found across all iterations. You can compare:\n", + "\n", + "- Mean values in control vs test group.\n", + "- Absolute and relative differences.\n", + "- p-values for both tests.\n", + "\n", + "Ideally, all rows should have `OK` in both TTest and KSTest columns, and small difference values (<1%)." + ] + }, { "cell_type": "code", "execution_count": 8, @@ -2085,12 +2149,16 @@ "source": [ "# AA Test with stratification\n", "\n", - "Depending on your requirements it is possible to stratify the data. You can set `stratification=True` and `StratificationRole` in `Dataset` to run it with stratification. " + "Depending on your requirements it is possible to stratify the data. 
You can set `stratification=True` and `StratificationRole` in `Dataset` to run it with stratification.\n", + "\n", + "Stratified AA tests ensure that both groups (control/test) have the same proportions of categories (e.g. same % of genders or regions). This prevents imbalances in categorical features that can distort results.\n", + "\n", + "Make sure to assign `StratificationRole` to relevant columns in your dataset before enabling stratification." ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "da9ab2f374ce1273", "metadata": { "ExecuteTime": { @@ -5337,6 +5405,20 @@ "source": [ "res.best_split_statistic" ] + }, + { + "cell_type": "markdown", + "id": "d3dd84bc", + "metadata": {}, + "source": [ + "## Common issues and tips\n", + "\n", + "- **Missing roles**: Make sure all target variables are assigned `TargetRole`. Columns without roles may cause silent failure.\n", + "- **Stratification**: If your dataset contains categorical features (e.g. `gender`, `region`) that may affect the outcome, use `StratificationRole` and enable `stratification=True` in `AATest(...)`.\n", + "- **Imbalanced categories**: If some categories have too few samples, stratified splits may become unstable. Consider filtering or merging rare categories.\n", + "- **Random fluctuations**: On small datasets, it's normal to see occasional `NOT OK` results. Use more iterations (e.g. `n_iterations=50`) for stability.\n", + "- **Missing values**: NaNs in stratification columns may be treated as separate categories. Clean or fill missing values before stratified AA tests." 
+ ] } ], "metadata": { diff --git a/examples/tutorials/ABTestTutorial.ipynb b/examples/tutorials/ABTestTutorial.ipynb index 497a8b56..3864aed3 100644 --- a/examples/tutorials/ABTestTutorial.ipynb +++ b/examples/tutorials/ABTestTutorial.ipynb @@ -59,6 +59,19 @@ { "cell_type": "code", "execution_count": 2, + "id": "547f5448", + "metadata": {}, + "outputs": [], + "source": [ + "from hypex.utils.tutorial_data_creation import DataGenerator\n", + "import numpy as np\n", + "import pandas as pd\n", + "from scipy import stats" + ] + }, + { + "cell_type": "code", + "execution_count": 3, "id": "904175ab484d1690", "metadata": { "ExecuteTime": { @@ -67,420 +80,30 @@ }, "collapsed": false }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_idsignup_monthtreatpre_spendspost_spendsagegenderindustry
0000488.0414.444444NaNME-commerce
1181512.5462.22222226.0NaNE-commerce
2271483.0479.44444425.0MLogistics
3300501.5424.33333339.0ME-commerce
4411543.0514.55555618.0FE-commerce
...........................
99959995101538.5450.44444442.0MLogistics
9996999600500.5430.88888926.0FLogistics
9997999731473.0534.11111122.0FE-commerce
9998999821495.0523.22222267.0FE-commerce
9999999971508.0475.88888938.0FE-commerce
\n", - "

10000 rows × 8 columns

\n", - "
" - ], - "text/plain": [ - " user_id signup_month treat pre_spends post_spends age gender \\\n", - "0 0 0 0 488.0 414.444444 NaN M \n", - "1 1 8 1 512.5 462.222222 26.0 NaN \n", - "2 2 7 1 483.0 479.444444 25.0 M \n", - "3 3 0 0 501.5 424.333333 39.0 M \n", - "4 4 1 1 543.0 514.555556 18.0 F \n", - "... ... ... ... ... ... ... ... \n", - "9995 9995 10 1 538.5 450.444444 42.0 M \n", - "9996 9996 0 0 500.5 430.888889 26.0 F \n", - "9997 9997 3 1 473.0 534.111111 22.0 F \n", - "9998 9998 2 1 495.0 523.222222 67.0 F \n", - "9999 9999 7 1 508.0 475.888889 38.0 F \n", - "\n", - " industry \n", - "0 E-commerce \n", - "1 E-commerce \n", - "2 Logistics \n", - "3 E-commerce \n", - "4 E-commerce \n", - "... ... \n", - "9995 Logistics \n", - "9996 Logistics \n", - "9997 E-commerce \n", - "9998 E-commerce \n", - "9999 E-commerce \n", - "\n", - "[10000 rows x 8 columns]" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ + "gen1 = DataGenerator(\n", + " n_samples=2000,\n", + " distributions={\n", + " \"X1\": {\"type\": \"normal\", \"mean\": 0, \"std\": 1},\n", + " \"X2\": {\"type\": \"bernoulli\", \"p\": 0.5},\n", + " \"y0\": {\"type\": \"normal\", \"mean\": 5, \"std\": 1},\n", + " },\n", + " time_correlations={\"X1\": 0.2, \"X2\": 0.1, \"y0\": 0.6},\n", + " effect_size=2.0,\n", + " seed=7\n", + ")\n", + "df = gen1.generate()\n", + "df = df.drop(columns=['y0', 'z', 'U', 'D', 'y1', 'y0_lag_2'])\n", + "\n", "data = Dataset(\n", " roles={\n", - " \"user_id\": InfoRole(int),\n", - " \"treat\": TreatmentRole(),\n", - " \"pre_spends\": TargetRole(),\n", - " \"post_spends\": TargetRole(),\n", - " \"gender\": TargetRole()\n", - " }, data=\"data.csv\",\n", - ")\n", - "data" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "ec0659f2c8de40d9", - "metadata": { - "ExecuteTime": { - "end_time": "2024-08-26T13:14:12.745242Z", - "start_time": "2024-08-26T13:14:12.713074Z" - } - }, - "outputs": [ - { - 
"data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_idsignup_monthtreatpre_spendspost_spendsagegenderindustry
0001488.0414.444444NaNME-commerce
1181512.5462.22222226.0NaNE-commerce
2271483.0479.44444425.0MLogistics
3301501.5424.33333339.0ME-commerce
4410543.0514.55555618.0FE-commerce
...........................
99959995101538.5450.44444442.0MLogistics
9996999601500.5430.88888926.0FLogistics
9997999731473.0534.11111122.0FE-commerce
9998999821495.0523.22222267.0FE-commerce
9999999972508.0475.88888938.0FE-commerce
\n", - "

10000 rows × 8 columns

\n", - "
" - ], - "text/plain": [ - " user_id signup_month treat pre_spends post_spends age gender \\\n", - "0 0 0 1 488.0 414.444444 NaN M \n", - "1 1 8 1 512.5 462.222222 26.0 NaN \n", - "2 2 7 1 483.0 479.444444 25.0 M \n", - "3 3 0 1 501.5 424.333333 39.0 M \n", - "4 4 1 0 543.0 514.555556 18.0 F \n", - "... ... ... ... ... ... ... ... \n", - "9995 9995 10 1 538.5 450.444444 42.0 M \n", - "9996 9996 0 1 500.5 430.888889 26.0 F \n", - "9997 9997 3 1 473.0 534.111111 22.0 F \n", - "9998 9998 2 1 495.0 523.222222 67.0 F \n", - "9999 9999 7 2 508.0 475.888889 38.0 F \n", - "\n", - " industry \n", - "0 E-commerce \n", - "1 E-commerce \n", - "2 Logistics \n", - "3 E-commerce \n", - "4 E-commerce \n", - "... ... \n", - "9995 Logistics \n", - "9996 Logistics \n", - "9997 E-commerce \n", - "9998 E-commerce \n", - "9999 E-commerce \n", - "\n", - "[10000 rows x 8 columns]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data[\"treat\"] = [random.choice([0, 1, 2]) for _ in range(len(data))]\n", - "data" + " \"d\": TreatmentRole(),\n", + " \"y\": TargetRole(),\n", + " },\n", + " data=df,\n", + " default_role=InfoRole()\n", + " )" ] }, { @@ -506,14 +129,13 @@ { "data": { "text/plain": [ - "{'user_id': Info(),\n", - " 'treat': Treatment(),\n", - " 'pre_spends': Target(),\n", - " 'post_spends': Target(),\n", - " 'gender': Target(),\n", - " 'signup_month': Default(),\n", - " 'age': Default(),\n", - " 'industry': Default()}" + "{'d': Treatment(),\n", + " 'y': Target(),\n", + " 'X1': Info(),\n", + " 'X1_lag': Info(),\n", + " 'X2': Info(),\n", + " 'X2_lag': Info(),\n", + " 'y0_lag_1': Info()}" ] }, "execution_count": 4, @@ -615,64 +237,25 @@ " \n", " \n", " 0\n", - " pre_spends\n", + " y\n", " 1\n", - " 487.071536\n", - " 487.020348\n", - " -0.051188\n", - " -0.010509\n", - " NOT OK\n", - " 0.911224\n", - " \n", - " \n", - " 1\n", - " pre_spends\n", - " 2\n", - " 487.071536\n", - " 487.191596\n", - " 0.120060\n", - " 
0.024649\n", - " NOT OK\n", - " 0.795599\n", - " \n", - " \n", - " 2\n", - " post_spends\n", - " 1\n", - " 451.697086\n", - " 452.914905\n", - " 1.217820\n", - " 0.269610\n", - " NOT OK\n", - " 0.207300\n", - " \n", - " \n", - " 3\n", - " post_spends\n", - " 2\n", - " 451.697086\n", - " 451.862460\n", - " 0.165374\n", - " 0.036612\n", - " NOT OK\n", - " 0.863482\n", + " 4.815482\n", + " 7.827936\n", + " 3.012454\n", + " 62.557684\n", + " OK\n", + " 1.895971e-157\n", " \n", " \n", "\n", "" ], "text/plain": [ - " feature group control mean test mean difference difference % \\\n", - "0 pre_spends 1 487.071536 487.020348 -0.051188 -0.010509 \n", - "1 pre_spends 2 487.071536 487.191596 0.120060 0.024649 \n", - "2 post_spends 1 451.697086 452.914905 1.217820 0.269610 \n", - "3 post_spends 2 451.697086 451.862460 0.165374 0.036612 \n", + " feature group control mean test mean difference difference % TTest pass \\\n", + "0 y 1 4.815482 7.827936 3.012454 62.557684 OK \n", "\n", - " TTest pass TTest p-value \n", - "0 NOT OK 0.911224 \n", - "1 NOT OK 0.795599 \n", - "2 NOT OK 0.207300 \n", - "3 NOT OK 0.863482 " + " TTest p-value \n", + "0 1.895971e-157 " ] }, "execution_count": 6, @@ -684,21 +267,6 @@ "result.resume" ] }, - { - "cell_type": "markdown", - "id": "2e226d84456a869b", - "metadata": {}, - "source": [ - "The method sizes shows the statistics on the groups of the data.\n", - "\n", - "The columns are:\n", - "- `control size`: the size of the control group.\n", - "- `test size`: the size of the test group.\n", - "- `control size %`: the share of the control group in the whole dataset.\n", - "- `test size %`: the share of the test group in the whole dataset.\n", - "- `group`: name of the test group." 
- ] - }, { "cell_type": "code", "execution_count": 7, @@ -741,28 +309,19 @@ " \n", " \n", " 1\n", - " 3313\n", - " 3391\n", - " 49\n", - " 50\n", + " 1352\n", + " 648\n", + " 67.6\n", + " 32.4\n", " 1\n", " \n", - " \n", - " 2\n", - " 3313\n", - " 3296\n", - " 50\n", - " 49\n", - " 2\n", - " \n", " \n", "\n", "" ], "text/plain": [ " control size test size control size % test size % group\n", - "1 3313 3391 49 50 1\n", - "2 3313 3296 50 49 2" + "1 1352 648 67.6 32.4 1" ] }, "execution_count": 7, @@ -784,6 +343,52 @@ "start_time": "2024-08-26T13:14:13.018409Z" } }, + "outputs": [ + { + "data": { + "text/plain": [ + "\"There was less than three groups or multitest method wasn't provided\"" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result.multitest" + ] + }, + { + "cell_type": "markdown", + "id": "c6161398", + "metadata": {}, + "source": [ + "## CUPED: Classic Covariate Adjustment for Variance Reduction\n", + "\n", + "CUPED (Controlled Experiments Using Pre-Experiment Data) — это классический метод снижения дисперсии в A/B тестах. Он использует исторические или вспомогательные признаки (например, лаги таргета), чтобы скорректировать целевую переменную и повысить статистическую мощность теста.\n", + "\n", + "В HypEx для применения CUPED достаточно указать соответствующие признаки через параметр `cuped_features` в `ABTest` или напрямую в трансформере CUPEDTransformer. 
В результате создаётся новая колонка (например, `y_cuped`), которая автоматически используется для анализа.\n", + "\n", + "Пример запуска ABTest с CUPED-коррекцией:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fbc1569d", + "metadata": {}, + "outputs": [], + "source": [ + "test = ABTest(cuped_features={'y': 'y0_lag_1'})\n", + "result = test.execute(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9ef9b808", + "metadata": {}, "outputs": [ { "data": { @@ -806,75 +411,160 @@ " \n", " \n", " \n", - " field\n", - " test\n", - " old p-value\n", - " new p-value\n", - " correction\n", - " rejected\n", + " feature\n", " group\n", + " control mean\n", + " test mean\n", + " difference\n", + " difference %\n", + " TTest pass\n", + " TTest p-value\n", " \n", " \n", " \n", " \n", " 0\n", - " pre_spends\n", - " TTest\n", - " 0.911224\n", - " 1.000000\n", - " 0.911224\n", - " False\n", + " y\n", " 1\n", + " 4.924818\n", + " 7.599815\n", + " 2.674998\n", + " 54.316686\n", + " OK\n", + " 6.901190e-188\n", " \n", - " \n", - " 1\n", - " post_spends\n", - " TTest\n", - " 0.795599\n", - " 1.000000\n", - " 0.795599\n", - " False\n", - " 1\n", + " \n", + "\n", + "" + ], + "text/plain": [ + " feature group control mean test mean difference difference % TTest pass \\\n", + "0 y 1 4.924818 7.599815 2.674998 54.316686 OK \n", + "\n", + " TTest p-value \n", + "0 6.901190e-188 " + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result.resume" + ] + }, + { + "cell_type": "markdown", + "id": "e23e18fc", + "metadata": {}, + "source": [ + "## CUPAC: Advanced Covariate Adjustment\n", + "\n", + "CUPAC (Covariate-Updated Pre-Analysis Correction) is an advanced method for variance reduction in A/B testing. It extends the CUPED approach by allowing flexible model selection (linear, ridge, lasso, or CatBoost regression) to adjust the target variable using historical or auxiliary features. 
This can lead to more accurate and powerful statistical tests.\n", + "\n", + "To use CUPAC in HypEx, specify the `cupac_params` argument in `ABTest`, including the target and covariate columns, and optionally the model type. The result is a new target column (e.g., `y_cupac`) automatically added to your dataset and used in the analysis.\n", + "\n", + "Below is an example of running ABTest with CUPAC adjustment:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8406f80d", + "metadata": {}, + "outputs": [], + "source": [ + "# Run ABTest with CUPAC adjustment\n", + "test = ABTest(cupac_params={\n", + " 'y': ['y0_lag_1', 'X1_lag', 'X2_lag'],\n", + " 'model': 'linear' # Options: 'linear', 'ridge', 'lasso', 'catboost', None\n", + " # If 'model' is None or not specified, CUPAC will try all and select the best by variance reduction.\n", + "})\n", + "result = test.execute(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "978bc0bf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
featuregroupcontrol meantest meandifferencedifference %TTest passTTest p-value
2pre_spendsTTest0.2073000.8292010.250000False20y14.8154827.8279363.01245462.557684OK1.895971e-157
3post_spendsTTest0.8634821.0000000.863482False21y_cupac15.0497237.3392102.28948645.338850OK1.176433e-160
\n", "
" ], "text/plain": [ - " field test old p-value new p-value correction rejected group\n", - "0 pre_spends TTest 0.911224 1.000000 0.911224 False 1\n", - "1 post_spends TTest 0.795599 1.000000 0.795599 False 1\n", - "2 pre_spends TTest 0.207300 0.829201 0.250000 False 2\n", - "3 post_spends TTest 0.863482 1.000000 0.863482 False 2" + " feature group control mean test mean difference difference % \\\n", + "0 y 1 4.815482 7.827936 3.012454 62.557684 \n", + "1 y_cupac 1 5.049723 7.339210 2.289486 45.338850 \n", + "\n", + " TTest pass TTest p-value \n", + "0 OK 1.895971e-157 \n", + "1 OK 1.176433e-160 " ] }, - "execution_count": 8, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "result.multitest" + "result.resume" ] }, { @@ -889,7 +579,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 13, "id": "a40f5762f0b37a0a", "metadata": { "ExecuteTime": { @@ -918,7 +608,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 14, "id": "89a8898c35681e97", "metadata": { "ExecuteTime": { @@ -959,132 +649,35 @@ " TTest p-value\n", " UTest pass\n", " UTest p-value\n", - " Chi2Test pass\n", - " Chi2Test p-value\n", " \n", " \n", " \n", " \n", " 0\n", - " pre_spends\n", - " 1\n", - " 487.071536\n", - " 487.020348\n", - " -0.051188\n", - " -0.010509\n", - " NOT OK\n", - " 0.911224\n", - " NOT OK\n", - " 0.764231\n", - " NaN\n", - " NaN\n", - " \n", - " \n", - " 1\n", - " pre_spends\n", - " 2\n", - " 487.071536\n", - " 487.191596\n", - " 0.120060\n", - " 0.024649\n", - " NOT OK\n", - " 0.795599\n", - " NOT OK\n", - " 0.752229\n", - " NaN\n", - " NaN\n", - " \n", - " \n", - " 2\n", - " post_spends\n", + " y\n", " 1\n", - " 451.697086\n", - " 452.914905\n", - " 1.217820\n", - " 0.269610\n", - " NOT OK\n", - " 0.207300\n", - " NOT OK\n", - " 0.457447\n", - " NaN\n", - " NaN\n", - " \n", - " \n", - " 3\n", - " post_spends\n", - " 2\n", - " 451.697086\n", - " 451.862460\n", - " 0.165374\n", - " 
0.036612\n", - " NOT OK\n", - " 0.863482\n", - " NOT OK\n", - " 0.572854\n", - " NaN\n", - " NaN\n", - " \n", - " \n", - " 4\n", - " gender\n", - " 1\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NOT OK\n", - " 0.945581\n", - " \n", - " \n", - " 5\n", - " gender\n", - " 2\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NOT OK\n", - " 0.858201\n", + " 4.815482\n", + " 7.827936\n", + " 3.012454\n", + " 62.557684\n", + " OK\n", + " 1.895971e-157\n", + " OK\n", + " 1.114725e-110\n", " \n", " \n", "\n", "" ], "text/plain": [ - " feature group control mean test mean difference difference % \\\n", - "0 pre_spends 1 487.071536 487.020348 -0.051188 -0.010509 \n", - "1 pre_spends 2 487.071536 487.191596 0.120060 0.024649 \n", - "2 post_spends 1 451.697086 452.914905 1.217820 0.269610 \n", - "3 post_spends 2 451.697086 451.862460 0.165374 0.036612 \n", - "4 gender 1 NaN NaN NaN NaN \n", - "5 gender 2 NaN NaN NaN NaN \n", + " feature group control mean test mean difference difference % TTest pass \\\n", + "0 y 1 4.815482 7.827936 3.012454 62.557684 OK \n", "\n", - " TTest pass TTest p-value UTest pass UTest p-value Chi2Test pass \\\n", - "0 NOT OK 0.911224 NOT OK 0.764231 NaN \n", - "1 NOT OK 0.795599 NOT OK 0.752229 NaN \n", - "2 NOT OK 0.207300 NOT OK 0.457447 NaN \n", - "3 NOT OK 0.863482 NOT OK 0.572854 NaN \n", - "4 NaN NaN NaN NaN NOT OK \n", - "5 NaN NaN NaN NaN NOT OK \n", - "\n", - " Chi2Test p-value \n", - "0 NaN \n", - "1 NaN \n", - "2 NaN \n", - "3 NaN \n", - "4 0.945581 \n", - "5 0.858201 " + " TTest p-value UTest pass UTest p-value \n", + "0 1.895971e-157 OK 1.114725e-110 " ] }, - "execution_count": 10, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -1095,7 +688,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 15, "id": "1da993761313d8d8", "metadata": { "ExecuteTime": { @@ -1107,132 +700,11 @@ 
"outputs": [ { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
fieldtestold p-valuenew p-valuecorrectionrejectedgroup
0pre_spendsTTest0.9112241.00.911224False1
1post_spendsTTest0.7955991.00.795599False1
2pre_spendsTTest0.2073001.00.207300False2
3post_spendsTTest0.8634821.00.863482False2
4pre_spendsUTest0.7642311.00.764231False1
5post_spendsUTest0.7522291.00.752229False1
6pre_spendsUTest0.4574471.00.457447False2
7post_spendsUTest0.5728541.00.572854False2
\n", - "
" - ], "text/plain": [ - " field test old p-value new p-value correction rejected group\n", - "0 pre_spends TTest 0.911224 1.0 0.911224 False 1\n", - "1 post_spends TTest 0.795599 1.0 0.795599 False 1\n", - "2 pre_spends TTest 0.207300 1.0 0.207300 False 2\n", - "3 post_spends TTest 0.863482 1.0 0.863482 False 2\n", - "4 pre_spends UTest 0.764231 1.0 0.764231 False 1\n", - "5 post_spends UTest 0.752229 1.0 0.752229 False 1\n", - "6 pre_spends UTest 0.457447 1.0 0.457447 False 2\n", - "7 post_spends UTest 0.572854 1.0 0.572854 False 2" + "\"There was less than three groups or multitest method wasn't provided\"" ] }, - "execution_count": 11, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -1243,7 +715,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 16, "id": "c11137e6c10eb0dc", "metadata": { "ExecuteTime": { @@ -1284,31 +756,22 @@ " \n", " \n", " 1\n", - " 3313\n", - " 3391\n", - " 49\n", - " 50\n", + " 1352\n", + " 648\n", + " 67.6\n", + " 32.4\n", " 1\n", " \n", - " \n", - " 2\n", - " 3313\n", - " 3296\n", - " 50\n", - " 49\n", - " 2\n", - " \n", " \n", "\n", "" ], "text/plain": [ " control size test size control size % test size % group\n", - "1 3313 3391 49 50 1\n", - "2 3313 3296 50 49 2" + "1 1352 648 67.6 32.4 1" ] }, - "execution_count": 12, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -1329,7 +792,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "5921c9e2", "metadata": { "ExecuteTime": { @@ -1345,7 +808,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 18, "id": "952d21c6", "metadata": { "ExecuteTime": { @@ -1388,67 +851,28 @@ " \n", " \n", " 0\n", - " pre_spends\n", + " y\n", " 1\n", - " 487.071536\n", - " 487.020348\n", - " -0.051188\n", - " -0.010509\n", - " NOT OK\n", - " 0.911224\n", - " \n", - " \n", - " 1\n", - " pre_spends\n", - " 2\n", - " 487.071536\n", - " 487.191596\n", - " 0.120060\n", - " 
0.024649\n", - " NOT OK\n", - " 0.795599\n", - " \n", - " \n", - " 2\n", - " post_spends\n", - " 1\n", - " 451.697086\n", - " 452.914905\n", - " 1.217820\n", - " 0.269610\n", - " NOT OK\n", - " 0.207300\n", - " \n", - " \n", - " 3\n", - " post_spends\n", - " 2\n", - " 451.697086\n", - " 451.862460\n", - " 0.165374\n", - " 0.036612\n", - " NOT OK\n", - " 0.863482\n", + " 4.815482\n", + " 7.827936\n", + " 3.012454\n", + " 62.557684\n", + " OK\n", + " 1.895971e-157\n", " \n", " \n", "\n", "" ], "text/plain": [ - " feature group control mean test mean difference difference % \\\n", - "0 pre_spends 1 487.071536 487.020348 -0.051188 -0.010509 \n", - "1 pre_spends 2 487.071536 487.191596 0.120060 0.024649 \n", - "2 post_spends 1 451.697086 452.914905 1.217820 0.269610 \n", - "3 post_spends 2 451.697086 451.862460 0.165374 0.036612 \n", + " feature group control mean test mean difference difference % TTest pass \\\n", + "0 y 1 4.815482 7.827936 3.012454 62.557684 OK \n", "\n", - " TTest pass TTest p-value \n", - "0 NOT OK 0.911224 \n", - "1 NOT OK 0.795599 \n", - "2 NOT OK 0.207300 \n", - "3 NOT OK 0.863482 " + " TTest p-value \n", + "0 1.895971e-157 " ] }, - "execution_count": 14, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -1459,7 +883,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 19, "id": "ad59dec9", "metadata": { "ExecuteTime": { @@ -1499,31 +923,22 @@ " \n", " \n", " 1\n", - " 3313\n", - " 3391\n", - " 49\n", - " 50\n", + " 1352\n", + " 648\n", + " 67.6\n", + " 32.4\n", " 1\n", " \n", - " \n", - " 2\n", - " 3313\n", - " 3296\n", - " 50\n", - " 49\n", - " 2\n", - " \n", " \n", "\n", "" ], "text/plain": [ " control size test size control size % test size % group\n", - "1 3313 3391 49 50 1\n", - "2 3313 3296 50 49 2" + "1 1352 648 67.6 32.4 1" ] }, - "execution_count": 15, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -1534,7 +949,7 @@ }, { "cell_type": "code", - 
"execution_count": 16, + "execution_count": 20, "id": "7849230a", "metadata": { "ExecuteTime": { @@ -1545,88 +960,11 @@ "outputs": [ { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
fieldtestold p-valuenew p-valuecorrectionrejectedgroup
0pre_spendsTTest0.9112241.0000000.911224False1
1post_spendsTTest0.7955991.0000000.795599False1
2pre_spendsTTest0.2073000.8292010.250000False2
3post_spendsTTest0.8634821.0000000.863482False2
\n", - "
" - ], "text/plain": [ - " field test old p-value new p-value correction rejected group\n", - "0 pre_spends TTest 0.911224 1.000000 0.911224 False 1\n", - "1 post_spends TTest 0.795599 1.000000 0.795599 False 1\n", - "2 pre_spends TTest 0.207300 0.829201 0.250000 False 2\n", - "3 post_spends TTest 0.863482 1.000000 0.863482 False 2" + "\"There was less than three groups or multitest method wasn't provided\"" ] }, - "execution_count": 16, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } diff --git a/examples/tutorials/MatchingTutorial.ipynb b/examples/tutorials/MatchingTutorial.ipynb index f40aef23..9f0008d9 100644 --- a/examples/tutorials/MatchingTutorial.ipynb +++ b/examples/tutorials/MatchingTutorial.ipynb @@ -48,10 +48,10 @@ "\n", "It is important to mark the data fields by assigning the appropriate roles:\n", "\n", - "* FeatureRole: a role for columns that contain features or predictor variables. Our split will be based on them. Applied by default if the role is not specified for the column.\n", - "* TreatmentRole: a role for columns that show the treatment or intervention.\n", - "* TargetRole: a role for columns that show the target or outcome variable.\n", - "* InfoRole: a role for columns that contain information about the data, such as user IDs." + "* **FeatureRole**: columns with features or predictor variables. Matching is based on these. Applied by default if the role is not specified for the column.\n", + "* **TreatmentRole**: column indicating the treatment or intervention (should be binary: 0/1 or True/False).\n", + "* **TargetRole**: column with the target or outcome variable (numeric, e.g., spend, conversion).\n", + "* **InfoRole**: columns with information about the data, such as user IDs (should be unique identifiers).\n" ] }, { @@ -314,11 +314,15 @@ }, "source": [ "## Simple Matching \n", - "Now matching has 4 steps: \n", - "1. Dummy Encoder \n", - "2. Process Mahalanobis distance \n", - "3. 
Two sides pairs searching by faiss \n", - "4. Metrics (ATT, ATC, ATE) estimation depends on your data " + "Matching consists of 4 main steps: \n", + "1. **Dummy Encoder**: Converts categorical features to numeric (one-hot encoding).\n", + "2. **Process Mahalanobis distance**: Calculates distances between units using all features (default is Mahalanobis, can be changed).\n", + "3. **Two sides pairs searching by faiss**: Finds the best matches between treated and control units using fast nearest neighbor search.\n", + "4. **Metrics (ATT, ATC, ATE) estimation**: Calculates the effect based on matched pairs.\n", + "\n", + "> **Common issues:**\n", + "- If you get errors about categorical features, check that all non-numeric columns are intended as features and will be encoded.\n", + "- If matching quality is poor, try changing the distance metric or reviewing your feature selection." ] }, { @@ -353,17 +357,17 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. 
Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. 
Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n" ] } @@ -380,9 +384,11 @@ "collapsed": false }, "source": [ - "**ATT** shows the difference in treated group. \n", - "**ATC** shows the difference in untreated group. \n", - "**ATE** shows the weighted average difference between ATT and ATC. " + "**ATT** (Average Treatment effect on the Treated): the estimated effect of the treatment for those who actually received it.\n", + "\n", + "**ATC** (Average Treatment effect on the Controls): the estimated effect if the control group had received the treatment.\n", + "\n", + "**ATE** (Average Treatment Effect): the overall average effect, combining ATT and ATC, weighted by group sizes.\n" ] }, { @@ -969,17 +975,17 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. 
Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. 
Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n" ] } @@ -1082,7 +1088,11 @@ "id": "3ad7a444", "metadata": {}, "source": [ - "We can change **metric** and do estimation again." 
+ "We can change the **metric** parameter to estimate different effects:\n", + "- `'att'`: effect for treated group (default)\n", + "- `'atc'`: effect for control group\n", + "- `'ate'`: average effect for all\n", + "- `'auto'`: automatically selects based on data" ] }, { @@ -1095,17 +1105,17 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. 
Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. 
Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n" ] } @@ -1304,17 +1314,17 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. 
Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. 
Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n" ] } @@ -1825,7 +1835,12 @@ "id": "a60205ca", "metadata": {}, "source": [ - "Finally, we may search pairs in L2 distance. " + "Finally, you can change the distance metric used for matching. By default, Mahalanobis distance is used, but you can also use L2 (Euclidean) distance.\n", + "\n", + "- **Mahalanobis**: Takes into account correlations between features; recommended for most cases.\n", + "- **L2 (Euclidean)**: Simpler, may work well if features are uncorrelated and similarly scaled.\n", + "\n", + "> **Tip:** If matching quality is poor or you get warnings about singular matrices, try switching the distance metric." ] }, { @@ -1838,15 +1853,15 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. 
Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n", - "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:344: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. 
Don't supply a list with a single grouper to avoid this warning.\n", + "/home/anathema/HypEx/hypex/dataset/backends/pandas_backend.py:337: FutureWarning: In a future version of pandas, a length 1 tuple will be returned when iterating over a groupby with a grouper equal to a list of length 1. Don't supply a list with a single grouper to avoid this warning.\n", " return list(groups)\n" ] } diff --git a/hypex/ab.py b/hypex/ab.py index 6aa81b85..d8be8401 100644 --- a/hypex/ab.py +++ b/hypex/ab.py @@ -9,6 +9,7 @@ from .ui.ab import ABOutput from .ui.base import ExperimentShell from .utils import ABNTestMethodsEnum +from .transformers import CUPEDTransformer class ABTest(ExperimentShell): @@ -39,29 +40,14 @@ class ABTest(ExperimentShell): # A/B test with multiple statistical tests ab_test = ABTest( additional_tests=["t-test", "chi2-test"], - multitest_method="bonferroni" + multitest_method="bonferroni", + cuped_feature = "feature_name" ) results = ab_test.execute(data) """ @staticmethod - def _make_experiment(additional_tests, multitest_method): - """Creates an experiment configuration with specified statistical tests. - - Args: - Args: - additional_tests (Union[str, List[str], None], optional): Statistical test(s) to run in addition to - the default group difference calculation. Valid options are "t-test", "u-test", and "chi2-test". - Can be a single test name or list of test names. Defaults to ["t-test"]. - multitest_method (str, optional): Method to use for multiple testing correction. Valid options are: - "bonferroni", "sidak", "holm-sidak", "holm", "simes-hochberg", "hommel", "fdr_bh", "fdr_by", - "fdr_tsbh", "fdr_tsbhy", "quantile". Defaults to "holm". - For more information refer to the statsmodels documentation: - - - Returns: - Experiment: Configured experiment object with specified tests and correction method. 
- """ + def _make_experiment(additional_tests, multitest_method, cuped_features=None, cupac_params=None): test_mapping = { "t-test": TTest(compare_by="groups", grouping_role=TreatmentRole()), "u-test": UTest(compare_by="groups", grouping_role=TreatmentRole()), @@ -76,22 +62,34 @@ def _make_experiment(additional_tests, multitest_method): ) for i in additional_tests: on_role_executors += [test_mapping[i]] - return Experiment( - executors=[ - GroupSizes(grouping_role=TreatmentRole()), - OnRoleExperiment( - executors=on_role_executors, - role=TargetRole(), - ), - ABAnalyzer( - multitest_method=( - ABNTestMethodsEnum(multitest_method) - if multitest_method - else None - ) - ), - ] - ) + + + + if cuped_features and cupac_params: + raise ValueError("You can use only one transformer: either CUPED or CUPACExecutor, not both.") + transformers = [] + if cuped_features: + transformers.append(CUPEDTransformer(cuped_features=cuped_features)) + elif cupac_params: + from .ml import CUPACExecutor + transformers.append(CUPACExecutor(cupac_features=cupac_params)) + + executors = [ + GroupSizes(grouping_role=TreatmentRole()), + OnRoleExperiment( + executors=on_role_executors, + role=TargetRole(), + ), + ABAnalyzer( + multitest_method=( + ABNTestMethodsEnum(multitest_method) + if multitest_method + else None + ) + ), + ] + + return Experiment(transformer=transformers, executors=executors) def __init__( self, @@ -117,9 +115,28 @@ def __init__( | None ) = "holm", t_test_equal_var: bool | None = None, + cuped_features: dict[str, str] | None = None, + cupac_params: dict | None = None, ): + """ + Args: + additional_tests: Statistical test(s) to run in addition to the default group difference calculation. Valid options are "t-test", "u-test", and "chi2-test". Can be a single test name or list of test names. Defaults to ["t-test"]. + multitest_method: Method to use for multiple testing correction. 
Valid options are: "bonferroni", "sidak", "holm-sidak", "holm", "simes-hochberg", "hommel", "fdr_bh", "fdr_by", "fdr_tsbh", "fdr_tsbhy", "quantile". Defaults to "holm". + t_test_equal_var: Whether to use equal variance in t-test (optional). + cuped_features: dict[str, str] — Dictionary {target_feature: pre_target_feature} for CUPED. Only dict is allowed. + cupac_params: dict — Parameters for CUPACML, e.g. {"target1": ["cov1", "cov2"], ...}. + You can also specify a model for adjustment: + Supported models for 'model' parameter: + 'linear' - LinearRegression (sklearn) + 'ridge' - Ridge regression + 'lasso' - Lasso regression + 'catboost' - CatBoostRegressor (if installed) + If 'model' is None or not specified, CUPAC will try all and select the best by variance reduction. + Raises: + ValueError: If both cuped_features and cupac_params are specified. + """ super().__init__( - experiment=self._make_experiment(additional_tests, multitest_method), + experiment=self._make_experiment(additional_tests, multitest_method, cuped_features, cupac_params), output=ABOutput(), ) if t_test_equal_var is not None: diff --git a/hypex/experiments/base.py b/hypex/experiments/base.py index 8c8dffa7..39dbcf9f 100644 --- a/hypex/experiments/base.py +++ b/hypex/experiments/base.py @@ -59,7 +59,13 @@ def _set_value(self, data: ExperimentData, value, key=None) -> ExperimentData: return data.set_value(ExperimentDataEnum.analysis_tables, self.id, value) def execute(self, data: ExperimentData) -> ExperimentData: - experiment_data = deepcopy(data) if self.transformer else data + if self.transformer: + experiment_data = deepcopy(data) + for transformer in self.transformer: + experiment_data = transformer.execute(experiment_data) + else: + experiment_data = data + for executor in self.executors: executor.key = self.key experiment_data = executor.execute(experiment_data) diff --git a/hypex/ml/__init__.py b/hypex/ml/__init__.py index 218c62ca..fcf72485 100644 --- a/hypex/ml/__init__.py +++ 
b/hypex/ml/__init__.py @@ -1,3 +1,4 @@ from .faiss import FaissNearestNeighbors +from .cupac import CUPACExecutor -__all__ = ["FaissNearestNeighbors"] +__all__ = ["FaissNearestNeighbors", "CUPACExecutor"] diff --git a/hypex/ml/cupac.py b/hypex/ml/cupac.py new file mode 100644 index 00000000..ff1a2bcf --- /dev/null +++ b/hypex/ml/cupac.py @@ -0,0 +1,165 @@ +from typing import Any, Dict, Optional +import numpy as np +from ..dataset.dataset import Dataset, ExperimentData +from ..dataset.roles import TargetRole +from ..executor import MLExecutor +from ..utils import ExperimentDataEnum + + +class CUPACExecutor(MLExecutor): + def __init__( + self, + cupac_features: Dict[str, list], + key: Any = "", + models: Optional[Dict[str, Any]] = None, + n_folds: int = 5, + random_state: Optional[int] = None, + ): + super().__init__(target_role=TargetRole(), key=key) + self.cupac_features = cupac_features + self.models = models + self.n_folds = n_folds + self.random_state = random_state + self.best_model = None + self.best_model_name = None + self.best_score = None + self.variance_reduction = None + self.feature_importances_ = None + self.is_fitted = False + self.model_results_ = {} + + @classmethod + def _inner_function( + cls, + data: Dataset, + cupac_features: Dict[str, list], + models: Optional[Dict[str, Any]] = None, + n_folds: int = 5, + random_state: Optional[int] = None, + **kwargs, + ) -> Dict[str, np.ndarray]: + instance = cls( + cupac_features=cupac_features, + models=models, + n_folds=n_folds, + random_state=random_state, + ) + instance.fit(data) + return instance.predict(data) + + def fit(self, X: Dataset) -> "CUPACExecutor": + import pandas as pd + from sklearn.linear_model import LinearRegression, Ridge, Lasso + try: + from catboost import CatBoostRegressor + except ImportError: + CatBoostRegressor = None + + # Supported models + all_models = { + "linear": LinearRegression(), + "ridge": Ridge(alpha=0.5), + "lasso": Lasso(alpha=0.01, max_iter=10000), + } + if 
CatBoostRegressor: + all_models["catboost"] = CatBoostRegressor( + iterations=100, + depth=4, + learning_rate=0.1, + silent=True, + random_state=self.random_state, + allow_writing_files=False, + ) + + # Check for explicit model selection + explicit_model = None + if "model" in self.cupac_features: + model_name = self.cupac_features["model"].lower() + if model_name not in all_models: + raise ValueError(f"Unknown model '{model_name}'. Supported: {list(all_models.keys())}") + explicit_model = all_models[model_name] + + df = X.data.copy() + self.fitted_models = {} + self.best_model_names = {} + for target_col, covariates in self.cupac_features.items(): + if target_col == "model": + continue + X_cov = df[covariates] + y = df[target_col] + from sklearn.model_selection import KFold + kf = KFold(n_splits=self.n_folds, shuffle=True, random_state=self.random_state) + if explicit_model is not None: + # Use only the specified model + model = explicit_model.__class__(**explicit_model.get_params()) + model.fit(X_cov, y) + self.fitted_models[target_col] = model + self.best_model_names[target_col] = model_name + else: + # Auto-select best model + best_score = -np.inf + best_model = None + best_model_name = None + for name, model in all_models.items(): + fold_var_reductions = [] + for train_idx, val_idx in kf.split(X_cov): + X_train, X_val = X_cov.iloc[train_idx], X_cov.iloc[val_idx] + y_train, y_val = y.iloc[train_idx], y.iloc[val_idx] + m = model.__class__(**model.get_params()) + m.fit(X_train, y_train) + pred = m.predict(X_val) + fold_var_reductions.append(self._calculate_variance_reduction(y_val, pred)) + score = np.nanmean(fold_var_reductions) + if score > best_score: + best_score = score + best_model = model.__class__(**model.get_params()) + best_model_name = name + best_model.fit(X_cov, y) + self.fitted_models[target_col] = best_model + self.best_model_names[target_col] = best_model_name + self.is_fitted = True + return self + + def predict(self, X: Dataset) -> Dict[str, 
np.ndarray]:
+        df = X.data.copy()
+        result = {}
+        for target_col, covariates in self.cupac_features.items():
+            if target_col == "model":
+                continue
+            model = self.fitted_models.get(target_col)
+            if model is None:
+                raise RuntimeError(f"Model for {target_col} not fitted. Call fit() first.")
+            X_cov = df[covariates]
+            y = df[target_col]
+            pred = model.predict(X_cov)
+            y_adj = y - pred + np.mean(y)
+            result[f"{target_col}_cupac"] = y_adj
+        return result
+
+    @staticmethod
+    def _calculate_variance_reduction(y, pred):
+        pred_centered = pred - np.mean(pred)
+        if np.var(pred_centered) < 1e-10:
+            return 0.0
+        theta = np.cov(y, pred_centered)[0, 1] / np.var(pred_centered)
+        y_adj = y - theta * pred_centered
+        return max(0, (1 - np.var(y_adj) / np.var(y)) * 100)
+
+    def execute(self, data: ExperimentData) -> ExperimentData:
+        self.fit(data.ds)
+        cupac_result = self.predict(data.ds)
+        for col, values in cupac_result.items():
+            ds_ml = Dataset.from_dict(
+                {col: values},
+                roles={col: TargetRole()},
+                index=data.ds.index,
+            )
+            data.set_value(
+                ExperimentDataEnum.ml,
+                executor_id=col,
+                value=ds_ml,
+                role=TargetRole(),
+            )
+            # Add the column to the main Dataset and assign it the TargetRole
+            data.ds.add_column(values, {col: TargetRole()})
+        return data
diff --git a/hypex/transformers/__init__.py b/hypex/transformers/__init__.py
index 03167a90..0baa5cac 100644
--- a/hypex/transformers/__init__.py
+++ b/hypex/transformers/__init__.py
@@ -3,6 +3,7 @@
 from .filters import ConstFilter, CorrFilter, CVFilter, NanFilter, OutliersFilter
 from .na_filler import NaFiller
 from .shuffle import Shuffle
+from .cuped import CUPEDTransformer
 
 __all__ = [
     "CVFilter",
diff --git a/hypex/transformers/cuped.py b/hypex/transformers/cuped.py
new file mode 100644
index 00000000..274ea806
--- /dev/null
+++ b/hypex/transformers/cuped.py
@@ -0,0 +1,48 @@
+from typing import Any, Sequence
+from ..dataset.dataset import Dataset, ExperimentData
+from ..dataset.roles import TargetRole, PreTargetRole
from typing import Any, Sequence

from ..dataset.dataset import Dataset, ExperimentData
from ..dataset.roles import TargetRole, PreTargetRole
from ..utils.adapter import Adapter
from .abstract import Transformer


class CUPEDTransformer(Transformer):
    """Transformer that applies the CUPED variance-reduction method.

    For each (target, covariate) pair the target is replaced by
    ``y - theta * (x - mean(x))`` with ``theta = Cov(y, x) / Var(x)``.
    """

    def __init__(
        self,
        cuped_features: dict[str, str],
        key: Any = "",
    ):
        """
        Transformer applying the CUPED method.

        Args:
            cuped_features: Mapping {target_feature: pre_target_feature}.
            key: Executor key forwarded to the base Transformer.
        """
        super().__init__(key=key)
        self.cuped_features = cuped_features

    @staticmethod
    def _inner_function(
        data: Dataset,
        cuped_features: dict[str, str],
    ) -> Dataset:
        """Adjust each target column in place: y_adj = y - theta * (x - mean(x))."""
        # cuped_features: {target_col: covariate_col}
        for target_col, covariate_col in cuped_features.items():
            # Work on the underlying pandas Series for the computation.
            target_series = data.data[target_col]
            covariate_series = data.data[covariate_col]
            cov_xy = (
                data.data[[target_col, covariate_col]]
                .cov()
                .loc[target_col, covariate_col]
            )
            # CUPED theta must be cov(y, x) / var(x) (the OLS slope of y on x).
            # Dividing by std_y * std_x yields the correlation coefficient,
            # which under-adjusts whenever std_y != std_x.
            theta = cov_xy / covariate_series.var()
            data[target_col] = target_series - theta * (
                covariate_series - covariate_series.mean()
            )
            data = data.astype({target_col: data.roles[target_col].data_type or float})
        return data

    def execute(self, data: ExperimentData) -> ExperimentData:
        """Run the CUPED adjustment over ``data.ds`` and return the updated data."""
        result = data.copy(
            data=self.calc(
                data=data.ds,
                cuped_features=self.cuped_features,
            )
        )
        return result
ROOT = Path("").absolute().parents[0]
sys.path.append(str(ROOT))


class DataGenerator:
    """Synthetic data generator with two lags for Y and a controllable
    correlation structure (normal / bernoulli / gamma marginals)."""

    def __init__(
        self,
        n_samples=2000,
        distributions=None,
        time_correlations=None,
        effect_size=5.0,
        seed=None,
    ):
        default_dists = {
            "X1": {"type": "normal", "mean": 1, "std": 2},
            "X2": {"type": "bernoulli", "p": 0.4},
            "y0": {"type": "normal", "mean": 10, "std": 3},
        }
        default_rhos = {"X1": 0.7, "X2": 0.6, "y0": 0.8}
        self.n_samples = n_samples
        self.distributions = distributions or default_dists
        self.time_correlations = time_correlations or default_rhos
        self.effect_size = effect_size
        self.seed = seed
        # NOTE: seeds the global numpy RNG (legacy API) for reproducibility.
        np.random.seed(seed)

    def _generate_bernoulli_pair(self, p, rho):
        """Draw (current, lag) Bernoulli(p) pairs with correlation rho."""
        # Feasible correlation range for two Bernoulli(p) marginals.
        limit = min(p / (1 - p), (1 - p) / p)
        if abs(rho) > limit:
            raise ValueError(f"Impossible correlation {rho} for p={p}")
        shared = rho * p * (1 - p)
        # Joint distribution over states encoded as 2*current + lag.
        joint = [
            (1 - p) * (1 - p) + shared,  # state 0: (current=0, lag=0)
            (1 - p) * p - shared,        # state 1: (current=0, lag=1)
            p * (1 - p) - shared,        # state 2: (current=1, lag=0)
            p * p + shared,              # state 3: (current=1, lag=1)
        ]
        states = np.random.choice(4, size=self.n_samples, p=joint)
        lag = np.isin(states, (1, 3))
        current = np.isin(states, (2, 3))
        return current.astype(int), lag.astype(int)

    def _generate_correlated_pair(self, dist_type, params, rho, U_vector=0):
        """Draw a correlated (current, lag) pair for the given marginal type."""
        if dist_type == "bernoulli":
            # U_vector is not applied to the discrete marginal.
            return self._generate_bernoulli_pair(params["p"], rho)
        if dist_type == "normal":
            var = params["std"] ** 2
            cov = [[var, rho * var], [rho * var, var]]
            draws = np.random.multivariate_normal(
                [params["mean"]] * 2, cov, self.n_samples
            )
            return draws.T + U_vector
        if dist_type == "gamma":
            # Gaussian copula: correlate in normal space, map through the CDF.
            z = np.random.multivariate_normal(
                [0, 0], [[1, rho], [rho, 1]], self.n_samples
            )
            u = stats.norm.cdf(z)

            def to_gamma(q):
                return stats.gamma.ppf(q, a=params["shape"], scale=params["scale"])

            return to_gamma(u[:, 0]), to_gamma(u[:, 1])
        raise ValueError(f"Unsupported distribution: {dist_type}")

    def _generate_correlated_chain(self, params, rho, n_points, U=0):
        """Draw n_points jointly normal series with AR(1)-style correlation."""
        mean, std = params["mean"], params["std"]
        idx = np.arange(n_points)
        # Covariance: sigma^2 * rho^|i - j|.
        cov = (std ** 2) * rho ** np.abs(idx[:, None] - idx[None, :])
        return np.random.multivariate_normal(
            [mean] * n_points, cov, self.n_samples
        ).T + U

    def generate(self):
        """Draw one synthetic dataset and return it as a DataFrame."""
        n = self.n_samples
        data = {"z": np.random.binomial(1, 0.5, n)}
        data["U"] = np.random.normal(0, 1, n)
        # Treatment assignment depends on the instrument z and confounder U.
        propensity = np.clip(0.3 + 0.4 * data["z"] + 0.3 * data["U"], 0, 1)
        data["D"] = np.random.binomial(1, propensity)
        data["d"] = data["D"] * data["z"]
        for name in ("X1", "X2"):
            spec = self.distributions[name]
            data[name], data[f"{name}_lag"] = self._generate_correlated_pair(
                spec["type"], spec, self.time_correlations[name], data["U"]
            )
        y_spec = self.distributions["y0"]
        y_rho = self.time_correlations["y0"]
        if y_spec["type"] == "normal":
            chain = self._generate_correlated_chain(y_spec, y_rho, 3, data["U"])
            data["y0"] = chain[2]
            data["y0_lag_1"] = chain[1]
            data["y0_lag_2"] = chain[0]
        else:
            data["y0"], data["y0_lag_1"] = self._generate_correlated_pair(
                y_spec["type"], y_spec, y_rho
            )
            data["y0_lag_2"], _ = self._generate_correlated_pair(
                y_spec["type"], y_spec, y_rho
            )
        noise = np.random.normal(0, 2, n)
        data["y1"] = data["y0"] + self.effect_size + 1.5 * data["U"] + noise
        data["y"] = np.where(data["d"] == 1, data["y1"], data["y0"])
        return pd.DataFrame(data)
import pandas as pd

from hypex.dataset.dataset import ExperimentData, ExperimentDataEnum
from hypex.dataset import Dataset, InfoRole, TargetRole, TreatmentRole

# Smoke test: manual CUPED adjustment on a hypex Dataset.
initial_data = pd.DataFrame({
    "target_col": [1.0, 2.0, 3.0],
    "covariate_col": [0.5, 1.5, 2.5]
})

data = Dataset(
    roles={
        "covariate_col": TargetRole(),
        "target_col": TargetRole(),
    },
    data=initial_data,
    default_role=InfoRole()
)

target_col = 'target_col'
covariate_col = 'covariate_col'

data = ExperimentData(data=data)

cov_xy = data.ds[[target_col, covariate_col]].cov().loc[target_col, covariate_col]
std_x = data.ds[covariate_col].std()

# Compute theta. CUPED theta = cov(y, x) / var(x) (the OLS slope of y on x);
# the previous cov_xy / (std_y * std_x) was the correlation coefficient.
theta = cov_xy / (std_x * std_x)

# Apply CUPED: y_adj = y - theta * (x - mean(x)). The previous version
# dropped the `data.ds[target_col] -` term and printed only the correction.
# NOTE(review): .get_values(0, 0) presumably extracts the scalar from the
# 1x1 Dataset produced above — confirm against the Dataset API.
adjusted_values = data.ds[target_col] - theta.get_values(0, 0) * (
    data.ds[covariate_col] - data.ds[covariate_col].mean()
)
print(adjusted_values)