Skip to content

Commit 649b3c3

Browse files
Docstring for datasets.linear_dataset fn(#774)
* follow up on open issue on the same topic * removed two files from this PR Signed-off-by: Amit Sharma <[email protected]> * added docstring to datasets * updated docstring to avoid format error Signed-off-by: Amit Sharma <[email protected]> * updated black error Signed-off-by: Amit Sharma <[email protected]> Signed-off-by: Amit Sharma <[email protected]> Co-authored-by: Amit Sharma <[email protected]>
1 parent ec04c5e commit 649b3c3

File tree

1 file changed

+151
-0
lines changed

1 file changed

+151
-0
lines changed

dowhy/datasets.py

+151
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,157 @@ def linear_dataset(
9292
stddev_outcome_noise=0.01,
9393
one_hot_encode=False,
9494
):
95+
"""
96+
Generate a synthetic dataset with a known effect size.
97+
98+
This function generates a pandas DataFrame with num_samples records. The variables follow a naming convention where the first letter indicates the variable's role in the causal graph, followed by a sequence number.
99+
100+
:param beta: coefficient of the treatment(s) ('v?') in the generating equation of the outcome ('y').
101+
:type beta: int or list/ndarray of length num_treatments of type int
102+
:param num_common_causes: Number of variables affecting both the treatment and the outcome [w -> v; w -> y]
103+
:type num_common_causes: int
104+
:param num_samples: Number of records to generate
105+
:type num_samples: int
106+
:param num_instruments: Number of instrumental variables [z -> v], defaults to 0
107+
:type num_instruments: int
108+
:param num_effect_modifiers: Number of effect modifiers, variables affecting only the outcome [x -> y], defaults to 0
109+
:type num_effect_modifiers: int
110+
:param num_treatments: Number of treatment variables [v], defaults to 1
111+
:type num_treatments: int
112+
:param num_frontdoor_variables: Number of frontdoor mediating variables [v -> FD -> y], defaults to 0
113+
:type num_frontdoor_variables: int
114+
:param treatment_is_binary: Cannot be True if treatment_is_category is True, defaults to True
115+
:type treatment_is_binary: bool
116+
:param treatment_is_category: Cannot be True if treatment_is_binary is True, defaults to False
117+
:type treatment_is_category: bool
118+
:param outcome_is_binary: Whether the outcome variable is binary, defaults to False
119+
:type outcome_is_binary: bool
120+
:param stochastic_discretization: if False, quartiles are used when discretised variables are specified. They can be one-hot encoded, defaults to True
121+
:type stochastic_discretization: bool
122+
:param num_discrete_common_causes: Number of discrete common causes of the total num_common_causes, defaults to 0
123+
:type num_discrete_common_causes: int
124+
:param num_discrete_instruments: Number of discrete instrumental variables of the total num_instruments, defaults to 0
125+
:type num_discrete_instruments: int
126+
:param num_discrete_effect_modifiers: Number of discrete effect modifiers of the total num_effect_modifiers, defaults to 0
127+
:type num_discrete_effect_modifiers: int
128+
:param stddev_treatment_noise: defaults to 1
129+
:type stddev_treatment_noise: float
130+
:param stddev_outcome_noise: defaults to 0.01
131+
:type stddev_outcome_noise: float
132+
:param one_hot_encode: defaults to False
133+
:type one_hot_encode: bool
134+
135+
:returns: Dictionary with a pandas DataFrame and a few other metadata variables.
136+
"df": pd.DataFrame
137+
with num_samples records. The variables follow a naming convention where the first letter indicates the variable's role in the causal graph, followed by a sequence number.
138+
139+
v variables - are the treatments. They can be binary or continuous. In the continuous case, abs(*beta*) defines their magnitude;
140+
141+
y - is the outcome variable. The generating equation is,
142+
y = normal(0, stddev_outcome_noise) + t @ beta [where @ is numpy matrix multiplication, allowing beta to be a vector]
143+
144+
W variables - common causes of both the treatment and the outcome; they are iid. If continuous, they are Norm(mu = Unif(-1,1), sigma = 1)
145+
146+
Z variables - Instrumental variables. Each one affects all treatments, i.e., if there is one instrument and two treatments, then z0->v0, z0->v1
147+
148+
X variables - effect modifiers. If continuous, they are Norm(mu = Unif(-1,1), sigma = 1)
149+
150+
FD variables - Front door variables, v0->FD0->y
151+
152+
"treatment_name": str/list(str)
153+
"outcome_name": str
154+
"common_causes_names": str/list(str)
155+
"instrument_names": str/list(str)
156+
"effect_modifier_names": str/list(str)
157+
"frontdoor_variables_names": str/list(str)
158+
"dot_graph": dot_graph,
159+
"gml_graph": gml_graph,
160+
"ate": float, the true ate in the dataset
161+
:rtype: dict
162+
163+
Examples
164+
********
165+
.. code-block:: python
166+
import networkx as nx
167+
import matplotlib.pyplot as plt
168+
import pandas as pd
169+
import numpy as np
170+
import dowhy.datasets
171+
172+
def plot_gml(gml_graph):
173+
G = nx.parse_gml(gml_graph)
174+
pos=nx.spring_layout(G)
175+
nx.draw_networkx(G, pos, with_labels=True, node_size=1000, node_color="darkorange")
176+
return(plt.show())
177+
178+
def describe_synthetic_data(synthetic_data):
179+
if (synthetic_data['gml_graph'] != None) :
180+
plot_gml(synthetic_data["gml_graph"])
181+
synthetic_data_df=synthetic_data["df"]
182+
print('------- Variables --------')
183+
print('Treatment vars:' , synthetic_data['treatment_name'])
184+
print('Outcome vars:' , synthetic_data['outcome_name'])
185+
print('Common causes vars:' , synthetic_data['common_causes_names'])
186+
print('Instrument vars:' , synthetic_data['instrument_names'])
187+
print('Effect Modifier vars:', synthetic_data['effect_modifier_names'])
188+
print('Frontdoor vars:' , synthetic_data['frontdoor_variables_names'])
189+
print('Outcome vars:', synthetic_data['outcome_name'])
190+
print('-------- Corr -------')
191+
print(synthetic_data_df.corr())
192+
print('------- Head --------')
193+
return(synthetic_data_df)
194+
195+
# Create a dataset with 10 observations, one binary treatment, and a continuous outcome affected by one common cause
196+
synthetic_data = dowhy.datasets.linear_dataset(beta = 100,
197+
num_common_causes = 1,
198+
num_samples =10
199+
)
200+
describe_synthetic_data(synthetic_data).head()
201+
202+
# Two continuous treatments, no common cause, an instrumental variable and two effect modifiers - linearly added appropriately
203+
synthetic_data = dowhy.datasets.linear_dataset(
204+
beta = 100,
205+
num_common_causes = 0,
206+
num_samples = 20,
207+
num_instruments = 1,
208+
num_effect_modifiers = 2,
209+
num_treatments = 2,
210+
num_frontdoor_variables = 0,
211+
treatment_is_binary = False,
212+
treatment_is_category = False,
213+
outcome_is_binary = False,
214+
stochastic_discretization = True,
215+
num_discrete_common_causes = 0,
216+
num_discrete_instruments = 0,
217+
num_discrete_effect_modifiers = 0,
218+
stddev_treatment_noise = 1,
219+
stddev_outcome_noise = 0.01,
220+
one_hot_encode = False
221+
)
222+
describe_synthetic_data(synthetic_data).head()
223+
224+
# One Hot Encoding
225+
synthetic_data = dowhy.datasets.linear_dataset(
226+
beta = 100,
227+
num_common_causes = 2,
228+
num_samples = 20,
229+
num_instruments = 1,
230+
num_effect_modifiers = 1,
231+
num_treatments = 1,
232+
num_frontdoor_variables = 1,
233+
treatment_is_binary = False,
234+
treatment_is_category = False,
235+
outcome_is_binary = False,
236+
stochastic_discretization = True,
237+
num_discrete_common_causes = 1, #of the total num_common_causes
238+
num_discrete_instruments = 1,
239+
num_discrete_effect_modifiers = 1,
240+
stddev_treatment_noise = 1,
241+
stddev_outcome_noise = 0.01,
242+
one_hot_encode = True
243+
)
244+
describe_synthetic_data(synthetic_data).head()
245+
"""
95246
assert not (treatment_is_binary and treatment_is_category)
96247
W, X, Z, FD, c1, c2, ce, cz, cfd1, cfd2 = [None] * 10
97248
W_with_dummy, X_with_categorical = (None, None)

0 commit comments

Comments
 (0)