Title: | Interpretable Matching for Causal Inference |
---|---|
Description: | Efficient implementations of the algorithms in the Almost-Matching-Exactly framework for interpretable matching in causal inference. These algorithms match units via a learned, weighted Hamming distance that determines which covariates are more important to match on. For more information and examples, see the Almost-Matching-Exactly website. |
Authors: | Vittorio Orlandi [aut, cre], Sudeepa Roy [aut], Cynthia Rudin [aut], Alexander Volfovsky [aut] |
Maintainer: | Vittorio Orlandi <[email protected]> |
License: | MIT + file LICENSE |
Version: | 2.1.1.9000 |
Built: | 2024-11-09 04:08:12 UTC |
Source: | https://github.com/vittorioorlandi/flame |
Almost Matching Exactly (AME) Algorithms for Discrete, Observational Data
    FLAME(data, holdout = 0.1, C = 0.1, treated_column_name = "treated",
          outcome_column_name = "outcome", weights = NULL, PE_method = "ridge",
          user_PE_fit = NULL, user_PE_fit_params = NULL, user_PE_predict = NULL,
          user_PE_predict_params = NULL, replace = FALSE, estimate_CATEs = FALSE,
          verbose = 2, return_pe = FALSE, return_bf = FALSE,
          early_stop_iterations = Inf, early_stop_epsilon = 0.25,
          early_stop_control = 0, early_stop_treated = 0, early_stop_pe = Inf,
          early_stop_bf = 0, missing_data = c("none", "drop", "keep", "impute"),
          missing_holdout = c("none", "drop", "impute"),
          missing_data_imputations = 1, missing_holdout_imputations = 5,
          impute_with_treatment = TRUE, impute_with_outcome = FALSE)

    DAME(data, holdout = 0.1, treated_column_name = "treated",
         outcome_column_name = "outcome", weights = NULL, PE_method = "ridge",
         n_flame_iters = 0, user_PE_fit = NULL, user_PE_fit_params = NULL,
         user_PE_predict = NULL, user_PE_predict_params = NULL, replace = FALSE,
         estimate_CATEs = FALSE, verbose = 2, return_pe = FALSE, return_bf = FALSE,
         early_stop_iterations = Inf, early_stop_epsilon = 0.25,
         early_stop_control = 0, early_stop_treated = 0, early_stop_pe = Inf,
         early_stop_bf = 0, missing_data = c("none", "drop", "keep", "impute"),
         missing_holdout = c("none", "drop", "impute"),
         missing_data_imputations = 1, missing_holdout_imputations = 5,
         impute_with_treatment = TRUE, impute_with_outcome = FALSE)

    ## S3 method for class 'ame'
    print(x, digits = getOption("digits"), linewidth = 80, ...)
data |
Data to be matched. Either a data frame or a path to a .csv file
to be read into a data frame. Treatment must be described by a logical or
binary numeric column with name `treated_column_name`. |
holdout |
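Holdout data to be used to compute predictive error, if `weights` is not supplied. Either a data frame (or a path to a .csv file) or a number between 0 and 1, in which case that proportion of `data` is set aside as holdout and not matched. Defaults to 0.1. |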
C |
A finite, positive scalar denoting the tradeoff between BF and PE in the FLAME algorithm. Higher C prioritizes more matches and lower C prioritizes not dropping important covariates. Defaults to 0.1. |
treated_column_name |
Name of the treatment column in `data`. Defaults to "treated". |
outcome_column_name |
Name of the outcome column in `data` (and in `holdout`, if supplied). Defaults to "outcome". |
weights |
A positive numeric vector of covariate importances. Supplying this argument means PE is not computed: instead, the weights determine the dropping order, with covariate subsets of lower weight dropped first. The weight of a covariate subset is the sum of the weights of its constituent covariates. Ties are broken at random. |
PE_method |
Denotes how predictive error (PE) is to be computed. Either
a string – one of "ridge" (default) or "xgb" – or a function. If "ridge",
ridge regression is used to fit an outcome regression model via
|
user_PE_fit |
Deprecated; use argument 'PE_method' instead. An optional
function supplied by the user that can be used instead of those allowed for
by `PE_method`. |
user_PE_fit_params |
Deprecated; use argument 'PE_method' instead. A
named list of optional parameters to be used by `user_PE_fit`. |
user_PE_predict |
Deprecated; use argument 'PE_method' instead. An
optional function supplied by the user that can be used to generate
predictions from the output of `user_PE_fit`. |
user_PE_predict_params |
Deprecated; use argument 'PE_method' instead. A
named list of optional parameters to be used by `user_PE_predict`. |
replace |
A logical scalar. If `TRUE`, units may be matched more than once, on different sets of covariates. Defaults to `FALSE`. |
estimate_CATEs |
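A logical scalar. If `TRUE`, a CATE estimate is computed for each unit and stored in the returned data. Defaults to `FALSE`. |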
verbose |
Controls how FLAME displays progress while running. If 0, no output. If 1, only outputs the stopping condition. If 2, outputs the iteration and number of unmatched units every 5 iterations, and the stopping condition. If 3, outputs the iteration and number of unmatched units every iteration, and the stopping condition. Defaults to 2. |
return_pe |
A logical scalar. If `TRUE`, the predictive error at each iteration is returned. Defaults to `FALSE`. |
return_bf |
A logical scalar. If `TRUE`, the balancing factor at each iteration is returned. Defaults to `FALSE`. |
early_stop_iterations |
A positive integer, denoting an upper bound
on the number of matching rounds to be performed. If 1, one round of
exact matching is performed before stopping. Defaults to `Inf`. |
early_stop_epsilon |
A nonnegative numeric. If fixed covariate weights
are passed via |
early_stop_control, early_stop_treated
|
If the proportion of unmatched control (respectively, treated) units falls below this value, the matching algorithm will stop. Defaults to 0. |
early_stop_pe |
Deprecated. A positive numeric. If FLAME attempts to
drop a covariate that would lead to a PE above this value, FLAME stops.
Defaults to `Inf`. |
early_stop_bf |
Deprecated. A numeric value between 0 and 2. If FLAME attempts to drop a covariate that would lead to a BF below this value, FLAME stops. Defaults to 0. |
missing_data |
Specifies how to handle missingness in `data`: one of "none", "drop", "keep", or "impute". |
missing_holdout |
Specifies how to handle missingness in `holdout`: one of "none", "drop", or "impute". |
missing_data_imputations |
Defunct. If |
missing_holdout_imputations |
If `missing_holdout = "impute"`, the number of imputed `holdout` datasets to generate. Defaults to 5. |
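impute_with_treatment, impute_with_outcome
|
If `TRUE`, treatment (respectively, outcome) information is used when imputing missing covariate values. These default to `TRUE` and `FALSE`, respectively. |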
n_flame_iters |
Specifies that this many iterations of FLAME should be run before switching to DAME. This can be used to speed up the matching procedure as FLAME rapidly eliminates irrelevant covariates, after which DAME will make higher quality matches on the remaining variables. |
x |
An object of class `ame`. |
digits |
Number of significant digits for printing the average treatment effect. |
linewidth |
Maximum number of characters on line; output will be wrapped accordingly. |
... |
Additional arguments to be passed to other methods. |
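The `weights` rule above (lower-weight covariate subsets dropped first, a subset's weight equal to the sum of its covariates' weights, ties broken at random) can be sketched in a few lines of base R. This is an illustrative sketch under stated assumptions, not the package's internal code; the weight vector `w`, the candidate list `subsets`, and `drop_order` are hypothetical names.

```r
# Hypothetical covariate importances (higher = more important to keep)
w <- c(x1 = 3, x2 = 1, x3 = 1)

# Hypothetical candidate covariate subsets that could be dropped
subsets <- list("x1", "x2", "x3", c("x2", "x3"))

# Weight of a subset = sum of the weights of its covariates
subset_w <- vapply(subsets, function(s) sum(w[s]), numeric(1))

# Lower weight is dropped first; ties are broken by a random draw
set.seed(3)
drop_order <- order(subset_w, runif(length(subset_w)))
subsets[drop_order]
```

Whichever of `x2` and `x3` wins the tie-break is dropped first, then `{x2, x3}`, and `x1` last.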
An object of type `ame`, which by default is a list of 4 entries:

1. The original data frame with several modifications:
   - An extra logical column, `data$matched`, that indicates whether or not a unit was matched.
   - An extra numeric column, `data$weight`, that denotes on how many different sets of covariates a unit was matched. This will only be greater than 1 when `replace = TRUE`.
   - The columns denoting treatment and outcome are moved after all covariate columns.
   - If `replace` is `FALSE`, a column containing a matched group identifier for each unit.
   - If `estimate_CATEs = TRUE`, a column containing the CATE estimate for each unit.
2. A list whose $i$'th entry contains the indices of units in the main matched group of the $i$'th unit.
3. A list whose $i$'th entry contains the covariate set not matched on in the $i$'th iteration.
4. A list containing miscellaneous information about the data and matching specifications. Primarily for use by `*.ame` methods.
FLAME and DAME are matching algorithms for
observational causal inference on data with discrete (categorical)
covariates. They match units that share identical values of certain
covariates, as follows. The algorithms first make any possible exact
matches; that is, they match units that share identical values of all
covariates (this is possible because covariates are discrete). They then
iteratively drop a set of covariates and make any possible matches on the
remaining covariates, until stopping. For each unit, DAME solves an
optimization problem that finds the highest quality set of covariates the
unit can be matched to others on, where quality is determined by how well
that set of covariates predicts the outcome. FLAME approximates the solution to the problem solved by DAME; at each step, it drops the covariate leading to the smallest drop in match quality $\mathrm{MQ}$, defined as $\mathrm{MQ} = C \cdot \mathrm{BF} - \mathrm{PE}$. Here, $\mathrm{PE}$ denotes the predictive error, which measures how important the dropped covariate is for predicting the outcome. The balancing factor $\mathrm{BF}$ measures the number of matches formed by dropping that covariate. In this way, FLAME encourages matching on covariates more important to the outcome while also making many matches. The hyperparameter $C$ controls the balance between these two objectives. In both cases, a machine learning algorithm trained on a holdout dataset is responsible for learning the quality / importance of covariates. For more details on the algorithms, please see the vignette, the FLAME paper, and/or the DAME paper.
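The following base R sketch scores each candidate covariate by $\mathrm{MQ} = C \cdot \mathrm{BF} - \mathrm{PE}$ and picks the one to drop. It rests on assumptions that are not taken from the package: BF is computed as the proportion of treated plus the proportion of control units that have an exact match on the remaining covariates (a reading consistent with `early_stop_bf` lying between 0 and 2), PE is the holdout mean squared error of separate treated and control outcome models, plain `lm()` stands in for the default ridge regression, and BF is computed over all units rather than only those still unmatched. The helpers `make_data`, `balancing_factor`, and `predictive_error` are hypothetical.

```r
# Minimal sketch of one FLAME-style step (hypothetical helpers throughout)
set.seed(1)
make_data <- function(n) {
  X <- data.frame(x1 = rbinom(n, 1, 0.5), x2 = rbinom(n, 1, 0.5),
                  x3 = rbinom(n, 1, 0.5))
  X$treated <- rbinom(n, 1, 0.5)
  # Hypothetical outcome model: x3 is irrelevant by construction
  X$outcome <- 5 * X$x1 + 2 * X$x2 + X$treated + rnorm(n)
  X
}
D <- make_data(500)  # data to be matched
H <- make_data(500)  # holdout data

# Assumed BF: share of treated plus share of control units whose values of
# `covs` also occur among units of the other arm (i.e. that can be matched)
balancing_factor <- function(D, covs) {
  key <- do.call(paste, c(D[covs], sep = "_"))
  matched <- key %in% key[D$treated == 1] & key %in% key[D$treated == 0]
  mean(matched[D$treated == 1]) + mean(matched[D$treated == 0])
}

# Assumed PE: holdout MSE of separate linear models for treated and control
predictive_error <- function(H, covs) {
  f <- reformulate(covs, response = "outcome")
  sum(vapply(c(0, 1), function(arm) {
    fit <- lm(f, data = H[H$treated == arm, ])
    mean(residuals(fit)^2)
  }, numeric(1)))
}

C <- 0.1
covs <- c("x1", "x2", "x3")
# MQ after dropping each candidate covariate
mq <- vapply(covs, function(v) {
  keep <- setdiff(covs, v)
  C * balancing_factor(D, keep) - predictive_error(H, keep)
}, numeric(1))
names(which.max(mq))  # covariate that would be dropped next
```

Because `x3` does not enter the simulated outcome, dropping it barely changes PE, so it receives the highest MQ.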
By default, both FLAME and DAME stop when (1) all covariates have been dropped or (2) all treatment or control units have been matched. This behavior can be modified by the arguments whose prefix is "early_stop". With the exception of `early_stop_iterations`, all of these rules are checked before the offending covariate set is dropped. For example, if `early_stop_control = 0.2` and, at the current iteration, dropping the covariate leading to the highest match quality is associated with an unmatched control proportion of 0.1, FLAME will stop without dropping this covariate.
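A minimal sketch of the control/treated stopping rule, evaluated before a drop is made. It assumes "falls below" means strictly less than the threshold; `should_stop` is a hypothetical helper, not a package function.

```r
# Hypothetical check applied *before* a candidate covariate is dropped
should_stop <- function(prop_unmatched_control, prop_unmatched_treated,
                        early_stop_control = 0, early_stop_treated = 0) {
  prop_unmatched_control < early_stop_control ||
    prop_unmatched_treated < early_stop_treated
}

# The scenario above: the drop would leave 10% of controls unmatched,
# below the 0.2 threshold, so matching stops without dropping it
should_stop(prop_unmatched_control = 0.1, prop_unmatched_treated = 0.4,
            early_stop_control = 0.2)
```

With both thresholds at their default of 0, a proportion can never fall strictly below the threshold, so only the default stopping rules apply.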
FLAME and DAME offer functionality for handling missing data in the covariates, for both the `data` and `holdout` sets. This functionality can be specified via the arguments whose prefix is "missing" or "impute". It allows for ignoring missing data, imputing it, or (for `data`) not matching on missing values. If `data` is imputed, imputation is done once and the matching algorithm is run on the imputed dataset. If `holdout` is imputed, the predictive error at an iteration is the average of the predictive errors across all imputed `holdout` datasets. Units with missingness in the treatment or outcome will be dropped.
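To illustrate how PE is averaged over imputed `holdout` sets, here is a base R sketch. It swaps in a naive random hot-deck imputation (each missing value replaced by a randomly drawn observed value) and an `lm()` outcome model purely for illustration; the package's actual imputation and PE methods may differ, and `impute_once` and `pe_one` are hypothetical helpers.

```r
set.seed(2)
n <- 200
# Hypothetical holdout set with binary covariates and a numeric outcome
H <- data.frame(x1 = rbinom(n, 1, 0.5), x2 = rbinom(n, 1, 0.5))
H$outcome <- 3 * H$x1 - H$x2 + rnorm(n)
H$x1[sample(n, 20)] <- NA  # introduce missing covariate values

# Hypothetical imputation: replace each NA by a random observed value
impute_once <- function(H) {
  for (v in c("x1", "x2")) {
    miss <- is.na(H[[v]])
    if (any(miss)) {
      H[[v]][miss] <- sample(H[[v]][!miss], sum(miss), replace = TRUE)
    }
  }
  H
}

# PE on one completed holdout set (in-sample MSE of a linear model)
pe_one <- function(H) mean(residuals(lm(outcome ~ x1 + x2, data = H))^2)

M <- 5  # number of imputed holdout datasets
pes <- vapply(seq_len(M), function(m) pe_one(impute_once(H)), numeric(1))
mean(pes)  # the PE used at this iteration
```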
    ## Not run:
    data <- gen_data()
    holdout <- gen_data()

    # FLAME with replacement, stopping after dropping a single covariate
    FLAME_out <- FLAME(data = data, holdout = holdout, replace = TRUE,
                       early_stop_iterations = 2)

    # Use a linear model to compute predictive error. Call DAME without
    # replacement, returning predictive error at each iteration.
    my_PE <- function(X, Y) {
      return(lm(Y ~ ., as.data.frame(cbind(X, Y = Y)))$fitted.values)
    }
    DAME_out <- DAME(data = data, holdout = holdout, PE_method = my_PE,
                     return_pe = TRUE)
    ## End(Not run)
These functions are deprecated and will be made defunct in a later release. See `summary.ame` for average treatment effect estimates and their variances.
    ATE(ame_out)
    ATT(ame_out)
    ATC(ame_out)
ame_out |
An object of class `ame`. |
`ATE`, `ATT`, and `ATC` estimate the average treatment effect (ATE), average treatment effect on the treated (ATT), and average treatment effect on the controls (ATC), respectively, of a matched dataset.
The ATE is estimated as the average CATE estimate across all matched units in the data, while the ATT and ATC average only across matched treated or matched control units, respectively.
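A toy computation of the three averages, using a made-up data frame whose `treated`, `matched`, and `CATE` columns mimic those described for the returned data; the values are illustrative only and this is not the package's implementation.

```r
# Hypothetical matched data (values made up for illustration)
d <- data.frame(
  treated = c(1, 0, 1, 0, 1),
  matched = c(TRUE, TRUE, TRUE, FALSE, TRUE),
  CATE    = c(2, 1, 3, NA, 4)
)
m <- d[d$matched, ]                  # keep matched units only
ATE <- mean(m$CATE)                  # all matched units: (2 + 1 + 3 + 4) / 4
ATT <- mean(m$CATE[m$treated == 1])  # matched treated: (2 + 3 + 4) / 3
ATC <- mean(m$CATE[m$treated == 0])  # matched controls: 1
c(ATE = ATE, ATT = ATT, ATC = ATC)
```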
`CATE` returns an estimate of the conditional average treatment effect for the subgroup defined by `units`.
CATE(units, ame_out)
units |
A vector of units whose CATE estimates are desired. |
ame_out |
An object of class `ame`. |
This function returns CATE estimates and the estimated variances of such
estimates for units
of interest. The CATE of a given unit is estimated
by the difference between the average treated and the average control outcome
in that unit's main matched group. As shown by Morucci 2021, under standard
regularity conditions, such an estimator is asymptotically normal and unbiased
for the true CATE, with a variance that can be estimated by the sum of the
variance of treated and control outcomes in the matched group, each normalized
by the number of treated and control units in the matched group, respectively.
Note that CATEs cannot be estimated for unmatched units, and estimator variances cannot be computed for units whose main matched group contains only a single treated or control unit. Note also that these CATE estimates differ from those used to compute average treatment effects in `print.ame` and `summary.ame`, and from those returned in `ame_out$data$CATE` if `estimate_CATEs = TRUE`. For a treated (control) unit $i$, the latter estimate the treated (control) outcome conditional on $X_i$ simply as $Y_i$, and do not average across other treated (control) units in the matched group as is done here. This averaging is necessary in order to compute variance estimates. The different estimates can always be compared manually; they are the same in expectation (assuming mean-zero noise), and we expect them to be similar in practice in the absence of large noise.
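The matched-group estimator and its variance described above can be sketched as follows. This assumes the within-arm variance is the usual sample variance (`var()`, denominator $n - 1$); `mg_cate` is a hypothetical helper and the outcome values are made up.

```r
# Sketch: CATE and variance within one main matched group
mg_cate <- function(y, treated) {
  yt <- y[treated == 1]
  yc <- y[treated == 0]
  est <- mean(yt) - mean(yc)  # difference of arm means
  # var() of a single value is NA, so a group with only one treated or
  # one control unit yields an NA variance, as described above
  v <- var(yt) / length(yt) + var(yc) / length(yc)
  c(CATE = est, variance = v)
}

# Hypothetical group: two treated and three control units
mg_cate(y = c(3, 5, 1, 2, 3), treated = c(1, 1, 0, 0, 0))
# Same idea with a single treated unit: the variance is NA
mg_cate(y = c(3, 1, 2, 3), treated = c(1, 0, 0, 0))
```

In the first group the estimate is $4 - 2 = 2$ and the variance is $2/2 + 1/3$.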
Lastly, note that the `units` argument refers to units with respect to `rownames(ame_out$data)`. Typically, this will also correspond to the indexing of the data (i.e. passing `units = 3` will return the estimate for the 3rd unit in the matching data). However, if a separate holdout set was not passed to the matching algorithm or if the original matching data had rownames other than `1:nrow(data)`, then this is not the case.
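The distinction between rownames and row positions is plain base R behavior, shown here on a made-up data frame in which two rows have been set aside (as happens when part of `data` is used as holdout). No package functions are involved.

```r
# Made-up data frame; rows 2 and 5 are set aside
df <- data.frame(x = c(10, 20, 30, 40, 50, 60))
matched_part <- df[-c(2, 5), , drop = FALSE]
rownames(matched_part)    # "1" "3" "4" "6"
matched_part["3", "x"]    # 30: the unit whose rowname is "3"
matched_part[3, "x"]      # 40: the third remaining row
```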
A matrix whose columns correspond to CATE estimates and their variances and whose rows correspond to queried units. `NA`s therein correspond to inestimable quantities.
    ## Not run:
    data <- gen_data()
    holdout <- gen_data()
    FLAME_out <- FLAME(data = data, holdout = holdout)
    CATE(1:5, FLAME_out)
    ## End(Not run)
`gen_data` generates toy data that can be used to explore FLAME and DAME functionality.
gen_data(n = 250, p = 5, write = FALSE, path = getwd(), filename = "AME.csv")
n |
Number of units desired in the data set. Defaults to 250. |
p |
Number of covariates in the data set. Must be greater than 2. Defaults to 5. |
write |
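A logical scalar. If `TRUE`, the generated data is written to a .csv file given by `path` and `filename`. Defaults to `FALSE`. |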
path |
The path to the location where the data should be written if
`write = TRUE`. Defaults to the current working directory. |
filename |
The name of the file to which the data should be written if
`write = TRUE`. Defaults to "AME.csv". |
`gen_data` simulates data in the format accepted by `FLAME` and `DAME`. Covariates and treatment are both independently generated according to a Bernoulli(0.5) distribution. The outcome depends on the covariates only through the first three; thus, the value of `p` must be at least 3 and any additional covariates beyond the third are irrelevant.
A data frame that may be passed to `FLAME` or `DAME`. Covariates are numeric, treatment is binary numeric, and outcome is numeric.
`MG` returns the matched groups of the supplied units.
MG(units, ame_out, multiple = FALSE, id_only = FALSE, index_only)
units |
A vector of units whose matched groups are desired. |
ame_out |
An object of class `ame`. |
multiple |
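A logical scalar. If `TRUE`, all matched groups a unit is part of are returned; if `FALSE`, only its main matched group. Defaults to `FALSE`. |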
id_only |
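A logical scalar. If `TRUE`, only the IDs of matched units are returned; otherwise, their treatment, outcome, and matched-on covariates are returned as a data frame. Defaults to `FALSE`. |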
index_only |
Defunct. Use 'id_only' instead. |
The `units` argument refers to units with respect to `rownames(ame_out$data)`. Typically, this will also correspond to the indexing of the data (i.e. passing `units = 3` will return the matched group of the 3rd unit in the matching data). However, if a separate holdout set was not passed to the matching algorithm or if the original matching data had rownames other than `1:nrow(data)`, then this is not the case.
The `multiple` argument toggles whether only a unit's main matched group (MMG) or all matched groups a unit is part of should be returned. A unit's MMG contains its highest quality matches (that is, the units with which it first matched in the sequence of considered covariate sets). If the original call that generated `ame_out` specified `replace = FALSE`, then units are part of only one matched group (which is also their MMG) and `multiple` must be set to `FALSE`.
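A sketch of the MMG idea, assuming matched groups are stored in the order they were formed (earliest, and therefore highest quality, first). `groups` and `groups_of` are hypothetical and do not reflect how `ame_out` stores matched groups internally.

```r
# Hypothetical matched groups, listed in the order they were formed
groups <- list(c(1, 4), c(2, 5, 6), c(1, 2, 7))

# All matched groups containing a unit; NULL if the unit is unmatched
groups_of <- function(unit, groups) {
  hits <- Filter(function(g) unit %in% g, groups)
  if (length(hits) == 0) NULL else hits
}

all_mgs <- groups_of(1, groups)  # list(c(1, 4), c(1, 2, 7))
mmg <- all_mgs[[1]]              # main matched group: the first one formed
groups_of(3, groups)             # NULL: unit 3 was never matched
```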
A list of length `length(units)`, each entry of which corresponds to a different unit in `units`. For matched units, if `multiple = FALSE`, each entry is either (1) a data frame containing the treatment and outcome information of members of the matched group, along with the covariates they were matched on, if `id_only = FALSE`, or (2) a vector of the IDs of matched units if `id_only = TRUE`. If `multiple = TRUE`, each entry of the returned list is a list containing the previously described information, but with each entry corresponding to a different matched group. In either case, entries corresponding to unmatched units are `NULL`.
    ## Not run:
    data <- gen_data()
    holdout <- gen_data()
    FLAME_out <- FLAME(data = data, holdout = holdout, replace = TRUE)
    # Only the main matched group of unit 1
    MG(1, FLAME_out, multiple = FALSE)
    # All matched groups of unit 1
    MG(1, FLAME_out, multiple = TRUE)
    ## End(Not run)
Plot information about numbers of covariates matched on, CATE estimates, and covariate set dropping order after a call to `FLAME` or `DAME`.
    ## S3 method for class 'ame'
    plot(x, which_plots = c(1, 2, 3, 4), ...)
x |
An object of class `ame`. |
which_plots |
A vector describing which plots should be displayed. See details. |
... |
Additional arguments to be passed on to other methods. |
`plot.ame` displays four plots by default. The first contains
information on the number of covariates that matched groups were formed on,
and thereby gives some indication of the quality of matched groups across the
matched data. The second plots matched group sizes against CATEs, which can
be useful for determining whether higher quality matched groups yield
different treatment effect estimates than lower quality ones. The third plots
a density estimate of the estimated CATE distribution. The fourth displays a
heatmap showing which covariates were dropped (shown in black) throughout the
matching procedure.
These methods create and print objects of class `summary.ame` containing information on the numbers of units matched by the AME algorithm, matched groups formed, and, if applicable, average treatment effects.
    ## S3 method for class 'ame'
    summary(object, ...)

    ## S3 method for class 'summary.ame'
    print(x, digits = 3, ...)
object |
An object of class `ame`. |
... |
Additional arguments to be passed on to other methods. |
x |
An object of class `summary.ame`. |
digits |
Number of significant digits for printing the average treatment effect estimates and their variances. |
The average treatment effect (ATE) is estimated as the average CATE estimate across all matched units in the data, while the average treatment effect on the treated (ATT) and average treatment effect on controls (ATC) average only across matched treated or matched control units, respectively. Variances of these estimates are computed as in Abadie, Drukker, Herr, and Imbens (The Stata Journal, 2004), assuming a constant treatment effect and homoscedasticity. Note that the implemented estimator is not asymptotically normal, so in particular, asymptotically valid confidence intervals or hypothesis tests cannot be based on it. In the future, the estimation procedure will be changed to employ the nonparametric regression bias-adjusted estimator of Abadie and Imbens (2011), which is asymptotically normal.
A list of type `summary.ame` with the following entries:

1. A list with the number and median size of matched groups formed, along with two of the highest quality matched groups formed. Quality is determined first by the number of covariates matched on and second by matched group size.
2. A matrix detailing the number of treated and control units matched.
3. If the matching data had a continuous outcome, estimates of the ATE, ATT, and ATC and the corresponding variances of the estimates.
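The matched-group statistics and quality ordering described above can be sketched in base R. The group IDs, the per-group covariate counts in `n_covs`, and the other names are hypothetical, and the package's own computation may differ.

```r
# Hypothetical matched group ID of each matched unit
mg_id <- c(1, 1, 2, 2, 2, 3, 3, 3, 3)
# Hypothetical number of covariates each group was matched on
n_covs <- c("1" = 5, "2" = 4, "3" = 4)

sizes <- table(mg_id)
length(sizes)              # number of matched groups: 3
median(as.vector(sizes))   # median matched group size: 3

# Quality: more covariates matched on first, then larger group size
size_vec <- as.vector(sizes[names(n_covs)])
ranked <- names(n_covs)[order(-n_covs, -size_vec)]
head(ranked, 2)            # two highest quality groups: "1", "3"
```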