Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
51 changes: 51 additions & 0 deletions tests/causal_prediction/test_algorithm_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch

from dowhy.causal_prediction.algorithms.utils import gaussian_kernel, mmd_compute, my_cdist


class TestAlgorithmUtils:
    """Unit tests for the tensor helpers used by the causal prediction algorithms."""

    def test_my_cdist(self):
        """``my_cdist`` returns pairwise squared Euclidean distances."""
        # Squared Euclidean distances between x1 and x2
        x1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        x2 = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
        distances = my_cdist(x1, x2)
        expected = torch.tensor([[1.0, 1.0], [13.0, 5.0]])
        assert torch.allclose(distances, expected, rtol=1e-5)

        # Single vector case
        x1 = torch.tensor([[1.0, 2.0]])
        x2 = torch.tensor([[1.0, 1.0]])
        distances = my_cdist(x1, x2)
        expected = torch.tensor([[1.0]])
        assert torch.allclose(distances, expected, rtol=1e-5)

    def test_gaussian_kernel(self):
        """``gaussian_kernel`` yields values in [0, 1] and is symmetric on equal inputs."""
        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        y = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
        gamma = 1.0
        kernel = gaussian_kernel(x, y, gamma)

        # Kernel values are exp(-gamma * squared distance)
        assert kernel.shape == (2, 2)
        assert torch.all(kernel >= 0) and torch.all(kernel <= 1)

        # Pin the actual values, not only their range: the squared distances
        # for these inputs are the ones checked in test_my_cdist.
        squared_distances = torch.tensor([[1.0, 1.0], [13.0, 5.0]])
        expected = torch.exp(-gamma * squared_distances)
        assert torch.allclose(kernel, expected, rtol=1e-5)

        # Symmetry for same input
        kernel_xx = gaussian_kernel(x, x, gamma)
        assert torch.allclose(kernel_xx, kernel_xx.t(), rtol=1e-5)

    def test_mmd_compute(self):
        """``mmd_compute`` is non-negative and vanishes for identical inputs."""
        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        y = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
        gamma = 1.0

        # MMD^2 = mean(K(x, x)) + mean(K(y, y)) - 2 * mean(K(x, y))
        mmd_gaussian = mmd_compute(x, y, "gaussian", gamma)
        assert mmd_gaussian >= 0

        # MMD for identical distributions should be zero.
        # A relative tolerance is meaningless when the expected value is 0
        # (rtol * |0| == 0), so compare with an explicit absolute tolerance.
        mmd_same = mmd_compute(x, x, "gaussian", gamma)
        assert torch.allclose(mmd_same, torch.tensor(0.0), atol=1e-6)

        # 'other' kernel: sum of mean squared difference of means and covariances
        mmd_other = mmd_compute(x, y, "other", gamma)
        assert mmd_other >= 0
86 changes: 86 additions & 0 deletions tests/causal_prediction/test_causal_prediction_algorithms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest
import pytorch_lightning as pl
import torch
from torch.utils.data import TensorDataset

from dowhy.causal_prediction.algorithms.cacm import CACM
from dowhy.causal_prediction.algorithms.erm import ERM
from dowhy.causal_prediction.dataloaders.get_data_loader import get_loaders
from dowhy.causal_prediction.models.networks import MLP, Classifier
from dowhy.datasets import linear_dataset


class LinearTensorDataset:
    """Collection of per-environment ``TensorDataset`` objects drawn from ``linear_dataset``.

    Every environment holds three aligned tensors ``(x, y, a)``:
    ``x`` is the treatment as a float32 column, ``y`` is the outcome as a long
    tensor, and ``a`` contains the common causes as float32 attributes.
    Indexing returns the dataset of one environment; ``len`` is the number of
    environments.
    """

    # NOTE(review): presumably read by the data-loader helper as the worker count - confirm.
    N_WORKERS = 0

    def __init__(self, n_envs, n_samples, input_shape, num_classes):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.datasets = []

        for _ in range(n_envs):
            generated = linear_dataset(
                beta=10,
                num_common_causes=2,
                num_instruments=0,
                num_samples=n_samples,
                treatment_is_binary=True,
                outcome_is_binary=True,
            )
            frame = generated["df"]

            # Treatment is the model input, outcome is the label.
            treatment_values = frame[generated["treatment_name"]].values
            outcome_values = frame[generated["outcome_name"]].values
            features = torch.tensor(treatment_values, dtype=torch.float32).reshape(-1, 1)
            labels = torch.tensor(outcome_values, dtype=torch.long)

            # Common causes act as the per-sample attributes.
            attribute_values = frame[generated["common_causes_names"]].values
            attributes = torch.tensor(attribute_values, dtype=torch.float32)

            self.datasets.append(TensorDataset(features, labels, attributes))

    def __getitem__(self, index):
        return self.datasets[index]

    def __len__(self):
        return len(self.datasets)


@pytest.mark.usefixtures("fixed_seed")
@pytest.mark.parametrize(
    "algorithm_cls, algorithm_kwargs",
    [
        (ERM, {}),
        (CACM, {"gamma": 1e-2, "attr_types": ["causal"], "lambda_causal": 10.0}),
    ],
)
def test_causal_prediction_training_and_eval(algorithm_cls, algorithm_kwargs, fixed_seed):
    """Train each algorithm briefly on synthetic linear data and sanity-check test metrics.

    Environments 0 and 1 are used for training, 2 for validation and 3 for
    testing. Accuracy and loss thresholds are only checked for result entries
    that actually report ``test_acc`` / ``test_loss``.
    """
    # Synthetic multi-environment data and its loaders
    envs = LinearTensorDataset(n_envs=4, n_samples=1000, input_shape=(1,), num_classes=2)
    split = get_loaders(envs, train_envs=[0, 1], batch_size=64, val_envs=[2], test_envs=[3])

    # Network: MLP featurizer followed by a classification head
    hidden_width = 128
    depth = 4
    dropout = 0.1
    feature_dim = hidden_width
    backbone = MLP(envs.input_shape[0], feature_dim, hidden_width, depth, dropout)
    head = Classifier(backbone.n_outputs, envs.num_classes)
    network = torch.nn.Sequential(backbone, head)

    # Short CPU-only training run
    learner = algorithm_cls(network, lr=1e-3, **algorithm_kwargs)
    trainer = pl.Trainer(
        accelerator="cpu",
        devices=1,
        max_epochs=5,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(learner, split["train_loaders"], split["val_loaders"])

    # Evaluate on the held-out environment
    results = trainer.test(learner, dataloaders=split["test_loaders"])
    assert isinstance(results, list)
    assert len(results) > 0
    for r in results:
        if "test_acc" in r:
            assert r["test_acc"] > 0.7, f"Test accuracy too low: {r['test_acc']}"
        if "test_loss" in r:
            assert r["test_loss"] < 1.0, f"Test loss too high: {r['test_loss']}"
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@

import numpy
import pytest
import torch


@pytest.fixture
def fixed_seed():
    """Seed the ``rand``, NumPy and PyTorch random generators with 0.

    NOTE(review): ``rand`` is imported above the visible part of this file;
    presumably it is the standard ``random`` module - confirm.
    """
    seed = 0
    for seed_fn in (rand.seed, numpy.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Also seed every CUDA generator when the torch build exposes ``torch.cuda``.
    if hasattr(torch, "cuda"):
        torch.cuda.manual_seed_all(seed)