Reference
In this reference, you will find a detailed overview of the package API.
Reference guides are technical descriptions of the machinery and how to operate it. Reference material is information-oriented.
Source: Diátaxis
In other words, you come here because you want to take a very close look at the code 🧐.
Content
Exported functions
CounterfactualExplanations.RawOutputArrayType (Type)
RawOutputArrayType
A type union for the allowed types for the output array y.
CounterfactualExplanations.RawTargetType (Type)
RawTargetType
A type union for the allowed types for the target variable.
CounterfactualExplanations.flux_training_params (Constant)
flux_training_params
The default training parameters for FluxModel and other Flux-based models.
CounterfactualExplanations.AbstractConvergence (Type)
An abstract type that serves as the base type for convergence objects.
CounterfactualExplanations.AbstractCounterfactualExplanation (Type)
Base type for counterfactual explanations.
CounterfactualExplanations.AbstractFittedModel (Type)
Base type for fitted models.
CounterfactualExplanations.AbstractGenerator (Type)
An abstract type that serves as the base type for counterfactual generators.
CounterfactualExplanations.CounterfactualExplanation (Type)
A struct that collects all information relevant to a specific counterfactual explanation for a single individual.
CounterfactualExplanations.CounterfactualExplanation (Method)
function CounterfactualExplanation(;
    x::AbstractArray,
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractFittedModel,
    generator::Generators.AbstractGenerator,
    num_counterfactuals::Int = 1,
    initialization::Symbol = :add_perturbation,
    convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
)
Outer method to construct a CounterfactualExplanation structure.
CounterfactualExplanations.EncodedOutputArrayType (Type)
EncodedOutputArrayType
Type of the encoded output array.
CounterfactualExplanations.EncodedTargetType (Type)
EncodedTargetType
Type of the encoded target variable.
CounterfactualExplanations.OutputEncoder (Type)
OutputEncoder
The OutputEncoder takes a raw output array (y) and encodes it.
CounterfactualExplanations.OutputEncoder (Method)
(encoder::OutputEncoder)(ynew::RawTargetType)
When called on a new value ynew, the OutputEncoder encodes it based on the initial encoding.
CounterfactualExplanations.OutputEncoder (Method)
(encoder::OutputEncoder)()
When called without arguments, the OutputEncoder returns the encoded output array.
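As an illustration of what such an encoder does, the following sketch one-hot encodes a raw label vector against its sorted unique levels and then encodes a new value using the same levels. The helpers onehot_encode and encode_new are hypothetical; the package's OutputEncoder may use a different scheme (for example, for binary targets).

```julia
# Hypothetical sketch: one-hot encode a raw label vector `y` against its sorted
# unique levels. This is not the package's OutputEncoder, only the idea behind it.
function onehot_encode(y::AbstractVector)
    levels = sort(unique(y))                      # class labels in a fixed order
    Y = zeros(Int, length(levels), length(y))     # one row per class, one column per sample
    for (j, label) in enumerate(y)
        Y[findfirst(==(label), levels), j] = 1
    end
    return Y, levels
end

# A new value is encoded with the levels learned from the original array.
function encode_new(ynew, levels)
    v = zeros(Int, length(levels))
    v[findfirst(==(ynew), levels)] = 1
    return v
end

Y, levels = onehot_encode(["cat", "dog", "cat", "bird"])
println(levels)                     # ["bird", "cat", "dog"]
println(encode_new("dog", levels))  # [0, 0, 1]
```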
CounterfactualExplanations.EvoTreeModel (Function)
EvoTreeModel
Exposes the EvoTreeModel from the EvoTreesExt extension.
CounterfactualExplanations.LaplaceReduxModel (Function)
LaplaceReduxModel
Exposes the LaplaceReduxModel from the LaplaceReduxExt extension.
CounterfactualExplanations.NeuroTreeModel (Function)
NeuroTreeModel
Exposes the NeuroTreeModel from the NeuroTreeExt extension.
CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(
    x::Base.Iterators.Zip,
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractFittedModel,
    generator::AbstractGenerator;
    kwargs...,
)
Overloads the generate_counterfactual method to accept a zip of factuals x and return a vector of counterfactuals.
CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(
    x::Matrix,
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractFittedModel,
    generator::AbstractGenerator;
    num_counterfactuals::Int=1,
    initialization::Symbol=:add_perturbation,
    convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
    timeout::Union{Nothing,Real}=nothing,
)
The core function that is used to run counterfactual search for a given factual x, target, counterfactual data, model and generator. Keyword arguments can be used to specify the number of counterfactuals, the initialization method, the convergence criterion (which determines, for example, the desired threshold for the predicted target class probability and the maximum number of iterations) and a timeout.
Arguments
- x::Matrix: Factual data point.
- target::RawTargetType: Target class.
- data::CounterfactualData: Counterfactual data.
- M::Models.AbstractFittedModel: Fitted model.
- generator::AbstractGenerator: Generator.
- num_counterfactuals::Int=1: Number of counterfactuals to generate for the factual.
- initialization::Symbol=:add_perturbation: Initialization method. By default, the initialization is done by adding a small random perturbation to the factual to achieve more robustness.
- convergence::Union{AbstractConvergence,Symbol}=:decision_threshold: Convergence criterion. By default, convergence is based on the decision threshold. Possible values are :decision_threshold, :max_iter, :generator_conditions or a concrete convergence object (e.g. DecisionThresholdConvergence).
- timeout::Union{Nothing,Real}=nothing: Timeout in seconds.
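To make the search procedure concrete, here is a minimal, self-contained sketch of a gradient-based search in the spirit of the WachterGenerator for a binary logistic model p(y = 1 | x) = σ(w'x + b). The function search_counterfactual, the step size η, the penalty weight λ and the squared L2 penalty are assumptions for illustration only; in the package, the generator, the fitted model M and the convergence object take over these roles.

```julia
using LinearAlgebra

σ(z) = 1 / (1 + exp(-z))   # logistic link

# Hypothetical sketch: gradient descent on  -log p(target | x′) + λ‖x′ - x‖²
# for a binary logistic model with target class 1.
function search_counterfactual(x, w, b; λ=0.1, η=0.5, threshold=0.5, max_iter=1000)
    x′ = copy(x)                                   # start at the factual
    for step in 1:max_iter
        p = σ(dot(w, x′) + b)
        p >= threshold && return x′, step - 1      # decision threshold reached
        grad = -(1 - p) .* w .+ 2λ .* (x′ .- x)    # gradient of the objective w.r.t. x′
        x′ .-= η .* grad                           # plain gradient descent step
    end
    return x′, max_iter                            # stopped by the iteration limit
end

w, b = [1.0, -1.0], 0.0
x = [-1.0, 1.0]                    # factual: σ(-2) ≈ 0.12, so predicted class is 0
x′, steps = search_counterfactual(x, w, b)
println(σ(dot(w, x′) + b) >= 0.5)  # true
```

The penalty term pulls x′ back towards x, so the search stops close to the decision boundary rather than far beyond it.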
Examples
Generic generator
julia> using CounterfactualExplanations
julia> using TaijaData
# Counterfactual data and model:
julia> counterfactual_data = CounterfactualData(load_linearly_separable()...);
julia> M = fit_model(counterfactual_data, :Linear);
julia> target = 2;
julia> factual = 1;
julia> chosen = rand(findall(predict_label(M, counterfactual_data) .== factual));
julia> x = select_factual(counterfactual_data, chosen);
# Search:
julia> generator = Generators.GenericGenerator();
julia> ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
CounterfactualExplanation
Convergence: ✅ after 7 steps.
Broadcasting
The generate_counterfactual method can also be broadcast over a tuple containing an array. This allows for generating counterfactuals for multiple factuals in a single call.
julia> chosen = rand(findall(predict_label(M, counterfactual_data) .== factual), 5);
julia> xs = select_factual(counterfactual_data, chosen);
julia> ces = generate_counterfactual.(xs, target, counterfactual_data, M, generator)
5-element Vector{CounterfactualExplanation}:
CounterfactualExplanation
Convergence: ✅ after 7 steps.
CounterfactualExplanation
Convergence: ✅ after 7 steps.
CounterfactualExplanation
Convergence: ✅ after 8 steps.
CounterfactualExplanation
Convergence: ✅ after 6 steps.
CounterfactualExplanation
Convergence: ✅ after 7 steps.
CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(
    x::Matrix,
    target::RawTargetType,
    data::DataPreprocessing.CounterfactualData,
    M::Models.AbstractFittedModel,
    generator::Generators.GrowingSpheresGenerator;
    num_counterfactuals::Int=1,
    convergence::Union{AbstractConvergence,Symbol}=Convergence.DecisionThresholdConvergence(;
        decision_threshold=(1 / length(data.y_levels)), max_iter=1000
    ),
    kwrgs...,
)
Overloads the generate_counterfactual method for the GrowingSpheresGenerator generator.
CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(x::Tuple{<:AbstractArray}, args...; kwargs...)
Overloads the generate_counterfactual method to accept a tuple containing an array. This allows for broadcasting over Zip iterators.
CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(
    x::Vector{<:Matrix},
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractFittedModel,
    generator::AbstractGenerator;
    kwargs...,
)
Overloads the generate_counterfactual method to accept a vector of factuals x and return a vector of counterfactuals.
CounterfactualExplanations.get_target_index (Method)
get_target_index(y_levels, target)
Utility that returns the index of target in y_levels.
CounterfactualExplanations.path (Method)
path(ce::CounterfactualExplanation)
A convenience method that returns the entire counterfactual path.
CounterfactualExplanations.target_probs (Function)
target_probs(
    ce::CounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)
Returns the predicted probability of the target class for x. If x is nothing, the predicted probability corresponding to the counterfactual value is returned.
CounterfactualExplanations.terminated (Method)
terminated(ce::CounterfactualExplanation)
A convenience method that checks if the counterfactual search has terminated.
CounterfactualExplanations.total_steps (Method)
total_steps(ce::CounterfactualExplanation)
A convenience method that returns the total number of steps of the counterfactual search.
CounterfactualExplanations.update! (Method)
update!(ce::CounterfactualExplanation)
An important subroutine that updates the counterfactual explanation. It takes a snapshot of the current counterfactual search state and passes it to the generator. Based on the current state, the generator generates perturbations. Various constraints are then applied to the proposed vector of feature perturbations. Finally, the counterfactual search state is updated.
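The constraint step mentioned above can be pictured with a short sketch. It assumes that mutability is encoded per feature with the symbols :both, :increase, :decrease and :none, and that the domain is a simple box [lower, upper]; the function constrained_update is hypothetical and only illustrates the idea, not the package's update!.

```julia
# Hypothetical sketch of one update step: apply a proposed perturbation Δ while
# respecting per-feature mutability and box-shaped domain constraints.
function constrained_update(x′, Δ, mutability, lower, upper)
    Δ = copy(Δ)                          # do not mutate the caller's proposal
    for i in eachindex(Δ)
        m = mutability[i]
        if m == :none
            Δ[i] = 0.0                   # immutable feature
        elseif m == :increase
            Δ[i] = max(Δ[i], 0.0)        # may only grow
        elseif m == :decrease
            Δ[i] = min(Δ[i], 0.0)        # may only shrink
        end                              # :both leaves Δ[i] unchanged
    end
    return clamp.(x′ .+ Δ, lower, upper) # domain constraints
end

x′ = [0.5, 0.5, 0.5, 0.5]
Δ = [0.7, -0.2, -0.2, 0.2]
x_next = constrained_update(x′, Δ, [:both, :increase, :decrease, :none], 0.0, 1.0)
println(x_next)   # [1.0, 0.5, 0.3, 0.5]
```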
CounterfactualExplanations.Convergence.convergence_catalogue (Constant)
convergence_catalogue
A dictionary containing all convergence criteria.
CounterfactualExplanations.Convergence.DecisionThresholdConvergence (Type)
DecisionThresholdConvergence
Convergence criterion based on the target class probability threshold. The search stops when the target class probability exceeds the predefined threshold.
Fields
- decision_threshold::AbstractFloat: The predefined threshold for the target class probability.
- max_iter::Int: The maximum number of iterations.
- min_success_rate::AbstractFloat: The minimum success rate for the target class probability.
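A minimal sketch of how such a criterion can be evaluated, assuming the target class probability of each counterfactual is already available and that reaching max_iter also ends the search. ThresholdConvergenceSketch and converged_sketch are hypothetical names, not part of the package.

```julia
# Hypothetical sketch of a decision-threshold convergence check.
struct ThresholdConvergenceSketch
    decision_threshold::Float64
    max_iter::Int
    min_success_rate::Float64
end

# `target_probs` holds the predicted target-class probability of each counterfactual.
function converged_sketch(conv::ThresholdConvergenceSketch, target_probs, iteration)
    iteration >= conv.max_iter && return true    # iteration limit reached
    success_rate = count(>=(conv.decision_threshold), target_probs) / length(target_probs)
    return success_rate >= conv.min_success_rate
end

conv = ThresholdConvergenceSketch(0.5, 100, 0.75)
println(converged_sketch(conv, [0.9, 0.6, 0.4, 0.7], 10))   # true  (3 of 4 reach 0.5)
println(converged_sketch(conv, [0.9, 0.3, 0.4, 0.7], 10))   # false (2 of 4)
println(converged_sketch(conv, [0.1, 0.1, 0.1, 0.1], 100))  # true  (max_iter reached)
```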
CounterfactualExplanations.Convergence.GeneratorConditionsConvergence (Type)
GeneratorConditionsConvergence
Convergence criterion for counterfactual explanations based on the generator conditions. The search stops when the gradients of the search objective are below a certain threshold and the generator conditions are satisfied.
Fields
- decision_threshold::AbstractFloat: The threshold for the decision probability.
- gradient_tol::AbstractFloat: The tolerance for the gradients of the search objective.
- max_iter::Int: The maximum number of iterations.
- min_success_rate::AbstractFloat: The minimum success rate for the generator conditions (across counterfactuals).
CounterfactualExplanations.Convergence.GeneratorConditionsConvergence (Method)
GeneratorConditionsConvergence(; decision_threshold=0.5, gradient_tol=1e-2, max_iter=100, min_success_rate=0.75, y_levels=nothing)
Outer constructor for GeneratorConditionsConvergence.
CounterfactualExplanations.Convergence.MaxIterConvergence (Type)
MaxIterConvergence
Convergence criterion based on the maximum number of iterations.
Fields
- max_iter::Int: The maximum number of iterations.
CounterfactualExplanations.Convergence.converged (Function)
converged(
    convergence::InvalidationRateConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)
Checks if the counterfactual search has converged when the convergence criterion is the invalidation rate.
CounterfactualExplanations.Convergence.converged (Function)
converged(
    convergence::MaxIterConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)
Checks if the counterfactual search has converged when the convergence criterion is the maximum number of iterations. This means the counterfactual search will not terminate until the maximum number of iterations has been reached, regardless of the other convergence criteria.
CounterfactualExplanations.Convergence.converged (Function)
converged(
    convergence::GeneratorConditionsConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)
Checks if the counterfactual search has converged when the convergence criterion is generator_conditions.
CounterfactualExplanations.Convergence.converged (Function)
converged(
    convergence::DecisionThresholdConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)
Checks if the counterfactual search has converged when the convergence criterion is the decision threshold.
CounterfactualExplanations.Convergence.get_convergence_type (Method)
get_convergence_type(convergence::AbstractConvergence)
Returns the convergence object.
CounterfactualExplanations.Convergence.get_convergence_type (Method)
get_convergence_type(convergence::Symbol)
Returns the convergence object from the dictionary of default convergence types.
CounterfactualExplanations.Convergence.hinge_loss (Method)
hinge_loss(convergence::InvalidationRateConvergence, ce::AbstractCounterfactualExplanation)
Calculates the hinge loss of a counterfactual explanation.
Arguments
- convergence::InvalidationRateConvergence: The convergence criterion to use.
- ce::AbstractCounterfactualExplanation: The counterfactual explanation to calculate the hinge loss for.
Returns
The hinge loss of the counterfactual explanation.
CounterfactualExplanations.Convergence.invalidation_rate (Method)
invalidation_rate(ce::AbstractCounterfactualExplanation)
Calculates the invalidation rate of a counterfactual explanation.
Arguments
- ce::AbstractCounterfactualExplanation: The counterfactual explanation to calculate the invalidation rate for.
- kwargs: Additional keyword arguments to pass to the function.
Returns
The invalidation rate of the counterfactual explanation.
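The docstring does not state how the invalidation rate is defined. One common definition in the algorithmic recourse literature is the probability that a counterfactual loses its target label when it is perturbed by small Gaussian noise. The sketch below estimates that quantity by Monte Carlo sampling; invalidation_rate_mc, the noise scale σ and the toy classifier are assumptions, and the package may compute the quantity differently (for example, analytically).

```julia
using Random

# Hypothetical Monte Carlo estimate: the share of noisy copies x′ + ε, with
# ε ~ N(0, σ²I), that are no longer classified as the target.
function invalidation_rate_mc(predict, x′, target; σ=0.1, n=10_000, rng=Random.default_rng())
    flips = 0
    for _ in 1:n
        x̃ = x′ .+ σ .* randn(rng, length(x′))   # noisy copy of the counterfactual
        flips += predict(x̃) != target          # count label changes
    end
    return flips / n
end

predict(x) = sum(x) > 0 ? 1 : 0                 # toy linear classifier
println(invalidation_rate_mc(predict, [0.0, 0.001], 1; rng=MersenneTwister(42)))  # close to 0.5
println(invalidation_rate_mc(predict, [2.0, 2.0], 1; rng=MersenneTwister(42)))    # 0.0
```

A counterfactual sitting right on the decision boundary is invalidated about half of the time, while one far inside the target region is essentially never invalidated.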
CounterfactualExplanations.Convergence.threshold_reached (Function)
threshold_reached(ce::AbstractCounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing)
Determines if the predefined threshold for the target class probability has been reached.
CounterfactualExplanations.Evaluation.default_measures (Constant)
The default evaluation measures.
CounterfactualExplanations.Evaluation.Benchmark (Type)
A container for benchmarks of counterfactual explanations. Instead of subtyping DataFrame, it contains a DataFrame of evaluation measures (see this discussion for why we don't subtype DataFrame directly).
CounterfactualExplanations.Evaluation.Benchmark (Method)
(bmk::Benchmark)(; agg=mean)
Returns a DataFrame containing evaluation measures aggregated by num_counterfactual.
CounterfactualExplanations.Evaluation.benchmark (Method)
benchmark(
    data::CounterfactualData;
    models::Dict{<:Any,<:Any}=standard_models_catalogue,
    generators::Union{Nothing,Dict{<:Any,<:AbstractGenerator}}=nothing,
    measure::Union{Function,Vector{Function}}=default_measures,
    n_individuals::Int=5,
    suppress_training::Bool=false,
    factual::Union{Nothing,RawTargetType}=nothing,
    target::Union{Nothing,RawTargetType}=nothing,
    store_ce::Bool=false,
    parallelizer::Union{Nothing,AbstractParallelizer}=nothing,
    kwrgs...,
)
Runs the benchmarking exercise as follows:
- Randomly choose a factual and target label unless specified.
- If no pretrained models are provided, it is assumed that a dictionary of callable model objects is provided (by default using the standard_models_catalogue).
- Each of these models is then trained on the data.
- For each model separately, choose n_individuals randomly from the non-target (factual) class. For each generator, create a benchmark as in benchmark(xs::Union{AbstractArray,Base.Iterators.Zip}).
- Finally, concatenate the results.
If vertical_splits is set to an integer, the computations are split vertically into vertical_splits chunks. In this case, the results are stored in a temporary directory and concatenated afterwards.
CounterfactualExplanations.Evaluation.benchmark (Method)
benchmark(
    x::Union{AbstractArray,Base.Iterators.Zip},
    target::RawTargetType,
    data::CounterfactualData;
    models::Dict{<:Any,<:AbstractFittedModel},
    generators::Dict{<:Any,<:AbstractGenerator},
    measure::Union{Function,Vector{Function}}=default_measures,
    xids::Union{Nothing,AbstractArray}=nothing,
    dataname::Union{Nothing,Symbol,String}=nothing,
    verbose::Bool=true,
    store_ce::Bool=false,
    parallelizer::Union{Nothing,AbstractParallelizer}=nothing,
    kwrgs...,
)
First generates counterfactual explanations for the factual x, the target and data using each of the provided models and generators. Then generates a Benchmark for the vector of counterfactual explanations as in benchmark(counterfactual_explanations::Vector{CounterfactualExplanation}).
CounterfactualExplanations.Evaluation.benchmark (Method)
benchmark(
    counterfactual_explanations::Vector{CounterfactualExplanation};
    meta_data::Union{Nothing,<:Vector{<:Dict}}=nothing,
    measure::Union{Function,Vector{Function}}=default_measures,
    store_ce::Bool=false,
)
Generates a Benchmark for a vector of counterfactual explanations. Optionally, meta_data describing each individual counterfactual explanation can be supplied. This should be a vector of dictionaries of the same length as the vector of counterfactuals. If no meta_data is supplied, it will be automatically inferred. All measure functions are applied to each counterfactual explanation. If store_ce=true, the counterfactual explanations are stored in the benchmark.
CounterfactualExplanations.Evaluation.evaluate (Function)
evaluate(
    ce::CounterfactualExplanation;
    measure::Union{Function,Vector{Function}}=default_measures,
    agg::Function=mean,
    report_each::Bool=false,
    output_format::Symbol=:Vector,
    pivot_longer::Bool=true
)
Computes the evaluation measures for the counterfactual explanation. By default, no meta data is reported. For report_meta=true, meta data is automatically inferred, unless this is overwritten by meta_data. The optional meta_data argument should be a vector of dictionaries of the same length as the vector of counterfactual explanations.
CounterfactualExplanations.Evaluation.redundancy (Method)
redundancy(ce::CounterfactualExplanation)
Computes the feature redundancy: that is, the number of features that remain unchanged from their original, factual values.
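A sketch of the measure as described, comparing factual and counterfactual feature by feature. The absolute tolerance atol is an assumption added to guard against floating-point noise; redundancy_sketch is a hypothetical helper.

```julia
# Hypothetical sketch: number of features that remain (numerically) unchanged
# between the factual x and the counterfactual x′.
redundancy_sketch(x, x′; atol=1e-8) = count(isapprox.(x, x′; atol=atol))

x = [1.0, 2.0, 3.0, 4.0]
x′ = [1.0, 2.5, 3.0, 3.0]
println(redundancy_sketch(x, x′))   # 2
```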
CounterfactualExplanations.Evaluation.validity (Method)
validity(ce::CounterfactualExplanation; γ=0.5)
Checks if the counterfactual search has been successful with respect to the probability threshold γ. In case multiple counterfactuals were generated, the function returns the proportion of successful counterfactuals.
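The proportion described above can be sketched as follows, assuming the predicted target class probability of each counterfactual is given as a vector. validity_sketch is a hypothetical helper, not the package function.

```julia
# Hypothetical sketch: share of counterfactuals whose predicted target-class
# probability reaches the threshold γ (1.0 or 0.0 for a single counterfactual).
validity_sketch(target_probs; γ=0.5) = count(>=(γ), target_probs) / length(target_probs)

println(validity_sketch([0.8]))                  # 1.0
println(validity_sketch([0.8, 0.3, 0.6, 0.45]))  # 0.5
```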
CounterfactualExplanations.DataPreprocessing.CounterfactualData (Method)
CounterfactualData(
    X::AbstractMatrix,
    y::RawOutputArrayType;
    mutability::Union{Vector{Symbol},Nothing}=nothing,
    domain::Union{Any,Nothing}=nothing,
    features_categorical::Union{Vector{Vector{Int}},Nothing}=nothing,
    features_continuous::Union{Vector{Int},Nothing}=nothing,
    input_encoder::Union{Nothing,InputTransformer,TypedInputTransformer}=nothing,
)
This outer constructor method prepares features X and labels y to be used with the package. Mutability and domain constraints can be added for the features. The function also accepts arguments that specify which features are categorical and which are continuous. These arguments are currently not used.
Examples
using CounterfactualExplanations
using TaijaData
X, y = load_linearly_separable()
counterfactual_data = CounterfactualData(X, y)
CounterfactualExplanations.DataPreprocessing.CounterfactualData (Method)
function CounterfactualData(
    X::Tables.MatrixTable,
    y::RawOutputArrayType;
    kwrgs...
)
Outer constructor method that accepts a Tables.MatrixTable. By default, the indices of categorical and continuous features are automatically inferred from the features' scitype.
CounterfactualExplanations.DataPreprocessing.apply_domain_constraints (Method)
apply_domain_constraints(counterfactual_data::CounterfactualData, x::AbstractArray)
A subroutine that is used to apply the predetermined domain constraints.
CounterfactualExplanations.DataPreprocessing.select_factual (Method)
select_factual(counterfactual_data::CounterfactualData, index::Int)
A convenience method that can be used to access the feature matrix.
CounterfactualExplanations.DataPreprocessing.select_factual (Method)
select_factual(counterfactual_data::CounterfactualData, index::Union{Vector{Int},UnitRange{Int}})
A convenience method that can be used to access the feature matrix.
CounterfactualExplanations.DataPreprocessing.transformable_features (Method)
transformable_features(counterfactual_data::CounterfactualData, input_encoder::Any)
By default, all continuous features are transformable. This function returns the indices of all continuous features.
CounterfactualExplanations.DataPreprocessing.transformable_features (Method)
transformable_features(
    counterfactual_data::CounterfactualData, input_encoder::Type{ZScoreTransform}
)
Returns the indices of all continuous features that can be transformed. Constant features are excluded, because ZScoreTransform returns NaN for them.
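The reason constant features cannot be standardised is easy to see: their standard deviation is zero, so every z-score is 0/0. The sketch below assumes features are stored in rows and samples in columns; zscore_sketch and the filtering comprehension are hypothetical illustrations, not the package code.

```julia
using Statistics

X = [1.0 3.0 5.0;    # feature 1
     5.0 5.0 5.0;    # feature 2 (constant)
     2.0 4.0 9.0]    # feature 3
continuous = [1, 2, 3]

zscore_sketch(v) = (v .- mean(v)) ./ std(v)   # standardise one feature

println(zscore_sketch(X[2, :]))               # [NaN, NaN, NaN]
transformable = [j for j in continuous if std(X[j, :]) > 0]
println(transformable)                        # [1, 3]
```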
CounterfactualExplanations.DataPreprocessing.transformable_features (Method)
transformable_features(counterfactual_data::CounterfactualData)
Dispatches the transformable_features function to the appropriate method based on the type of the dt field.
CounterfactualExplanations.Models.all_models_catalogue (Constant)
all_models_catalogue
A dictionary containing both differentiable and non-differentiable machine learning models.
CounterfactualExplanations.Models.mlj_models_catalogue (Constant)
mlj_models_catalogue
A dictionary containing all machine learning models from the MLJ model registry that the package supports.
CounterfactualExplanations.Models.standard_models_catalogue (Constant)
standard_models_catalogue
A dictionary containing all differentiable machine learning models.
CounterfactualExplanations.Models.AbstractDifferentiableModel (Type)
Base type for differentiable models.
CounterfactualExplanations.Models.FluxEnsemble (Type)
FluxEnsemble <: AbstractFluxModel
Constructor for deep ensembles trained in Flux.jl.
CounterfactualExplanations.Models.FluxModel (Type)
FluxModel <: AbstractFluxModel
Constructor for models trained in Flux.jl.
CounterfactualExplanations.Models.FluxModel (Method)
FluxModel(data::CounterfactualData; kwargs...)
Constructs a multi-layer perceptron (MLP).
CounterfactualExplanations.Models.DecisionTreeModel (Method)
DecisionTreeModel(data::CounterfactualData; kwargs...)
Constructs a new TreeModel object wrapped around a decision tree from the data in a CounterfactualData object. Not called by the user directly.
Arguments
- data::CounterfactualData: The CounterfactualData object containing the data to be used for training the model.
Returns
- model::TreeModel: A TreeModel object.
CounterfactualExplanations.Models.Linear (Method)
Linear(data::CounterfactualData; kwargs...)
Constructs a model with one linear layer. If the output is binary, this corresponds to logistic regression, since model outputs are passed through the sigmoid function. If the output is multi-class, this corresponds to multinomial logistic regression, since model outputs are passed through the softmax function.
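The link functions mentioned above can be sketched in a few lines. The names linear_probs, sigmoid and softmax are defined here for illustration (they are not imported from Flux), and the weights are arbitrary.

```julia
# Hypothetical sketch: one linear layer followed by the link function described
# above (sigmoid for a single output unit, softmax for several).
sigmoid(z) = 1 / (1 + exp(-z))
softmax(z) = exp.(z .- maximum(z)) ./ sum(exp.(z .- maximum(z)))   # numerically stable

linear_probs(W, b, x) = size(W, 1) == 1 ? sigmoid.(W * x .+ b) : softmax(W * x .+ b)

x = [1.0, 2.0]
p_bin = linear_probs([0.5 -0.25], [0.0], x)                    # binary: one output unit
p_multi = linear_probs([1.0 0.0; 0.0 1.0; 0.0 0.0], zeros(3), x) # three classes
println(p_bin)          # [0.5]
println(sum(p_multi))   # ≈ 1.0
```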
CounterfactualExplanations.Models.RandomForestModel (Method)
RandomForestModel(data::CounterfactualData; kwargs...)
Constructs a new TreeModel object wrapped around a random forest from the data in a CounterfactualData object. Not called by the user directly.
Arguments
- data::CounterfactualData: The CounterfactualData object containing the data to be used for training the model.
Returns
- model::TreeModel: A TreeModel object.
CounterfactualExplanations.Models.fit_model (Function)
fit_model(
    counterfactual_data::CounterfactualData, model::Symbol=:MLP;
    kwrgs...
)
Fits one of the available default models to the counterfactual_data. The model argument can be used to specify the desired model. The available values correspond to the keys of the all_models_catalogue dictionary.
CounterfactualExplanations.Models.logits (Method)
logits(M::AbstractFittedModel, X::AbstractArray)
Generic method that is compulsory for all models. It returns the raw model predictions. In classification this is sometimes referred to as logits: the non-normalized predictions that are fed into a link function to produce predicted probabilities. In regression (not currently implemented) raw outputs typically correspond to final outputs. In other words, there is typically no normalization involved.
CounterfactualExplanations.Models.logits (Method)
logits(M::TreeModel, X::AbstractArray)
Calculates the logit scores output by the model M for the input data X.
Arguments
- M::TreeModel: The model selected by the user.
- X::AbstractArray: The feature vector for which the logit scores are calculated.
Returns
- logits::Matrix: A matrix of logits for each output class for each data point in X.
Example
logits = Models.logits(M, x) # calculates the logit scores for each output class for the data point x
CounterfactualExplanations.Models.model_evaluation (Method)
model_evaluation(M::AbstractFittedModel, test_data::CounterfactualData)
Helper function to compute an evaluation measure for an AbstractFittedModel on a (test) data set. By default, it computes the accuracy. Any other measure, e.g. the F-score from the StatisticalMeasures package, can be passed as an argument. Currently, only measures applicable to classification tasks are supported.
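For reference, here is what the default measure and an F-score compute for a binary task. accuracy and f1_score are hypothetical helpers written from the standard definitions, not the StatisticalMeasures implementations.

```julia
# Hypothetical sketch of two classification measures.
accuracy(y, ŷ) = count(y .== ŷ) / length(y)

# Binary F1 score: harmonic mean of precision and recall for the positive class.
function f1_score(y, ŷ; positive=1)
    tp = count((y .== positive) .& (ŷ .== positive))   # true positives
    fp = count((y .!= positive) .& (ŷ .== positive))   # false positives
    fn = count((y .== positive) .& (ŷ .!= positive))   # false negatives
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)
end

y = [1, 0, 1, 1, 0]
ŷ = [1, 0, 0, 1, 1]
println(accuracy(y, ŷ))   # 0.6
println(f1_score(y, ŷ))   # ≈ 0.667
```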
CounterfactualExplanations.Models.predict_label (Method)
predict_label(M::AbstractFittedModel, counterfactual_data::CounterfactualData, X::AbstractArray)
Returns the predicted output label for a given model M, data set counterfactual_data and input data X.
CounterfactualExplanations.Models.predict_label (Method)
predict_label(M::AbstractFittedModel, counterfactual_data::CounterfactualData)
Returns the predicted output labels for all data points of data set counterfactual_data for a given model M.
CounterfactualExplanations.Models.predict_label (Method)
predict_label(M::TreeModel, X::AbstractArray)
Returns the predicted label for X.
Arguments
- M::TreeModel: The model selected by the user.
- X::AbstractArray: The input array for which the label is predicted.
Returns
- labels::AbstractArray: The predicted label for each data point in X.
Example
label = Models.predict_label(M, x) # returns the predicted label for each data point in x
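Predicted labels are typically obtained by taking, for each data point, the class with the highest predicted probability. The sketch assumes probabilities are arranged as classes × data points and that the row order matches y_levels; predict_label_sketch is hypothetical.

```julia
# Hypothetical sketch: map each column of a (classes × data points) probability
# matrix to the label with the highest probability.
function predict_label_sketch(P::AbstractMatrix, y_levels)
    return [y_levels[argmax(P[:, j])] for j in axes(P, 2)]
end

P = [0.7 0.2 0.1;
     0.3 0.8 0.9]   # 2 classes × 3 data points
println(predict_label_sketch(P, ["no", "yes"]))   # ["no", "yes", "yes"]
```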
CounterfactualExplanations.Models.predict_proba (Method)
predict_proba(M::AbstractFittedModel, counterfactual_data::CounterfactualData, X::Union{Nothing,AbstractArray})
Returns the predicted output probabilities for a given model M, data set counterfactual_data and input data X.
CounterfactualExplanations.Models.probs (Method)
probs(M::AbstractFittedModel, X::AbstractArray)
Generic method that is compulsory for all models. It returns the normalized model predictions, so the predicted probabilities in the case of classification. In regression (not currently implemented) this method is redundant.
CounterfactualExplanations.Models.probs (Method)
probs(M::TreeModel, X::AbstractArray{<:Number, 2})
Calculates the probability scores for each output class for the two-dimensional input data matrix X.
Arguments
- M::TreeModel: The TreeModel.
- X::AbstractArray: The feature matrix for which the predictions are made.
Returns
- p::Matrix: A matrix of probability scores for each output class for each data point in X.
Example
probabilities = Models.probs(M, X) # calculates the probability scores for each output class for each data point in X.
CounterfactualExplanations.Models.probs (Method)
probs(M::TreeModel, X::AbstractArray{<:Number, 1})
Works the same way as the probs(M::TreeModel, X::AbstractArray{<:Number, 2}) method above, but handles 1-dimensional rather than 2-dimensional input data.
CounterfactualExplanations.Generators.generator_catalogue (Constant)
A dictionary containing the constructors of all available counterfactual generators.
CounterfactualExplanations.Generators.AbstractGradientBasedGenerator (Type)
AbstractGradientBasedGenerator
An abstract type that serves as the base type for gradient-based counterfactual generators.
CounterfactualExplanations.Generators.AbstractNonGradientBasedGenerator (Type)
AbstractNonGradientBasedGenerator
An abstract type that serves as the base type for non-gradient-based counterfactual generators.
CounterfactualExplanations.Generators.FeatureTweakGenerator (Type)
Feature Tweak counterfactual generator class.
CounterfactualExplanations.Generators.FeatureTweakGenerator (Method)
FeatureTweakGenerator(; penalty::Union{Nothing,Function,Vector{Function}}=Objectives.distance_l2, ϵ::AbstractFloat=0.1)
Constructs a new Feature Tweak Generator object.
By default, uses the L2-norm as the penalty to measure the distance between the counterfactual and the factual. According to the paper by Tolomei et al., another recommended choice for the penalty in addition to the L2-norm is the L0-norm. The L0-norm simply minimizes the number of features that are changed through the tweak.
Arguments
- penalty::Union{Nothing,Function,Vector{Function}}: The penalty function to use for the generator. Defaults to distance_l2.
- ϵ::AbstractFloat: The tolerance value for the feature tweaks. Described at length in Tolomei et al. (https://arxiv.org/pdf/1706.06691.pdf). Defaults to 0.1.
Returns
- generator::FeatureTweakGenerator: A non-gradient-based generator that can be used to generate counterfactuals using the feature tweak method.
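The core of the method of Tolomei et al. is the ϵ-tweak along a single root-to-leaf path of a tree that leads to the target class: every split condition the factual violates is satisfied by moving the feature just past the threshold by ϵ. The sketch below covers only this step; representing a path as (feature, threshold, goes_left) tuples is an assumption, and tweak_along_path is hypothetical.

```julia
# Hypothetical sketch of the ϵ-tweak for one root-to-leaf path. Each condition is
# (feature, threshold, goes_left); goes_left = true means the path requires
# x[feature] <= threshold, otherwise x[feature] > threshold.
function tweak_along_path(x, path; ϵ=0.1)
    x′ = copy(x)
    for (i, θ, goes_left) in path
        if goes_left && !(x′[i] <= θ)
            x′[i] = θ - ϵ        # move just below the split threshold
        elseif !goes_left && !(x′[i] > θ)
            x′[i] = θ + ϵ        # move just above the split threshold
        end                      # conditions already satisfied are left alone
    end
    return x′
end

x = [2.0, 0.5, 1.0]
path = [(1, 1.5, true), (2, 0.0, false), (3, 3.0, true)]
println(tweak_along_path(x, path))   # [1.4, 0.5, 1.0]
```

In the full method, this tweak is repeated for every such path in every tree of the ensemble, and the candidate that changes the ensemble prediction at the lowest penalty is kept.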
CounterfactualExplanations.Generators.GradientBasedGenerator (Type)
Base class for gradient-based counterfactual generators.
CounterfactualExplanations.Generators.GradientBasedGenerator (Method)
GradientBasedGenerator(;
    loss::Union{Nothing,Function}=nothing,
    penalty::Penalty=nothing,
    λ::Union{Nothing,AbstractFloat,Vector{AbstractFloat}}=nothing,
    latent_space::Bool=false,
    opt::Flux.Optimise.AbstractOptimiser=Flux.Descent(),
    generative_model_params::NamedTuple=(;),
)
Default outer constructor for GradientBasedGenerator.
Arguments
- loss::Union{Nothing,Function}=nothing: The loss function used by the model.
- penalty::Penalty=nothing: A penalty function for the generator to penalize counterfactuals too far from the original point.
- λ::Union{Nothing,AbstractFloat,Vector{AbstractFloat}}=nothing: The weight of the penalty function.
- latent_space::Bool=false: Whether to use the latent space of a generative model to generate counterfactuals.
- opt::Flux.Optimise.AbstractOptimiser=Flux.Descent(): The optimizer to use for the generator.
- generative_model_params::NamedTuple: The parameters of the generative model associated with the generator.
Returns
- generator::GradientBasedGenerator: A gradient-based counterfactual generator.
CounterfactualExplanations.Generators.GrowingSpheresGenerator (Type)
Growing Spheres counterfactual generator class.
CounterfactualExplanations.Generators.GrowingSpheresGenerator (Method)
GrowingSpheresGenerator(; n::Int=100, η::Float64=0.1, kwargs...)
Constructs a new Growing Spheres Generator object.
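The docstring only lists the constructor arguments. Assuming n is the number of samples drawn per layer and η the initial radius and layer width, as in the original Growing Spheres algorithm of Laugel et al., the search can be sketched as follows. sample_layer and growing_spheres are hypothetical helpers, and the sketch assumes the factual is not already classified as the target.

```julia
using LinearAlgebra, Random

# Draw n points uniformly from the spherical shell a <= ‖z - x‖ <= b.
function sample_layer(rng, x, a, b, n)
    d = length(x)
    return map(1:n) do _
        u = randn(rng, d)
        u ./= norm(u)                                  # uniform random direction
        r = (a^d + rand(rng) * (b^d - a^d))^(1 / d)    # radius uniform in the shell's volume
        x .+ r .* u
    end
end

function growing_spheres(predict, x, target; η=0.1, n=100, max_layers=1_000,
                         rng=Random.default_rng())
    # Shrink the initial ball until none of its samples is classified as `target`.
    while any(z -> predict(z) == target, sample_layer(rng, x, 0.0, η, n))
        η /= 2
    end
    a, b = η, 2η
    for _ in 1:max_layers
        hits = filter(z -> predict(z) == target, sample_layer(rng, x, a, b, n))
        isempty(hits) || return argmin(z -> norm(z .- x), hits)   # closest hit
        a, b = b, b + η                                 # move one layer outward
    end
    return nothing                                      # no counterfactual found
end

predict(z) = z[1] + z[2] > 1 ? 1 : 0                    # toy classifier
x′ = growing_spheres(predict, [0.0, 0.0], 1; rng=MersenneTwister(0))
println(predict(x′))                                    # 1
```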
CounterfactualExplanations.Generators.JSMADescent (Type)
An optimisation rule that can be used to implement a Jacobian-based Saliency Map Attack.
CounterfactualExplanations.Generators.JSMADescent (Method)
Outer constructor for the JSMADescent rule.
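The docstring does not describe the update rule. As a rough illustration only, the following simplified JSMA-style step changes just the single feature with the largest absolute gradient by a fixed amount η and marks it as used, instead of moving all features along the gradient. jsma_step! and the exclusion of already-changed features are assumptions, not a description of JSMADescent.

```julia
# Hypothetical sketch of a simplified JSMA-style update on x, given the gradient
# of the search objective and a mask of features that were already changed.
function jsma_step!(x, grad, changed::BitVector; η=0.1)
    scores = abs.(grad)
    scores[changed] .= -Inf          # exclude features already perturbed
    i = argmax(scores)               # most salient remaining feature
    x[i] -= η * sign(grad[i])        # fixed-size step in the descent direction
    changed[i] = true
    return x
end

x = [0.0, 0.0, 0.0]
changed = falses(3)
jsma_step!(x, [0.2, -0.9, 0.5], changed)
println(x)   # [0.0, 0.1, 0.0]
jsma_step!(x, [0.2, -0.9, 0.5], changed)
println(x)   # [0.0, 0.1, -0.1]
```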
CounterfactualExplanations.Generators.CLUEGenerator (Method)
Constructor for CLUEGenerator.
CounterfactualExplanations.Generators.ClaPROARGenerator (Method)
Constructor for ClaPROARGenerator.
CounterfactualExplanations.Generators.DiCEGenerator (Method)
Constructor for DiCEGenerator.
CounterfactualExplanations.Generators.GenericGenerator (Method)
Constructor for GenericGenerator.
CounterfactualExplanations.Generators.GravitationalGenerator (Method)
Constructor for GravitationalGenerator.
CounterfactualExplanations.Generators.GreedyGenerator (Method)
Constructor for GreedyGenerator.
CounterfactualExplanations.Generators.ProbeGenerator (Method)
Constructor for ProbeGenerator.
CounterfactualExplanations.Generators.REVISEGenerator (Method)
Constructor for REVISEGenerator.
CounterfactualExplanations.Generators.WachterGenerator (Method)
Constructor for WachterGenerator.
CounterfactualExplanations.Generators.conditions_satisfied (Method)
conditions_satisfied(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)
The default method to check if all conditions for convergence of the counterfactual search have been satisfied for gradient-based generators. By default, gradient-based search is considered to have converged as soon as the proposed changes for all features are smaller than one percent of the respective feature's standard deviation.
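The default condition described above can be sketched as follows, assuming the per-feature standard deviation is computed from a training matrix with features in rows. conditions_satisfied_sketch is hypothetical.

```julia
using Statistics

# Hypothetical sketch: converged once every proposed change |Δᵢ| is below
# 1% of the standard deviation of feature i.
function conditions_satisfied_sketch(Δ, X; tol=0.01)
    σ = vec(std(X; dims=2))          # per-feature std (features in rows)
    return all(abs.(Δ) .< tol .* σ)
end

X = [0.0 1.0 2.0 3.0;                # feature 1 (std ≈ 1.29)
     0.0 10.0 20.0 30.0]             # feature 2 (std ≈ 12.9)
println(conditions_satisfied_sketch([0.001, 0.05], X))   # true
println(conditions_satisfied_sketch([0.1, 0.05], X))     # false
```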
CounterfactualExplanations.Generators.feature_tweaking!
Method: feature_tweaking!(ce::AbstractCounterfactualExplanation)
Returns a counterfactual instance of ce.x based on the ensemble of classifiers provided.
Arguments
- ce::AbstractCounterfactualExplanation: The counterfactual explanation object.
Returns
- ce::AbstractCounterfactualExplanation: The counterfactual explanation object.
Example
ce = feature_tweaking!(ce) # returns a counterfactual inside the ce.s′ field based on the ensemble of classifiers provided
CounterfactualExplanations.Generators.generate_perturbations
Method: generate_perturbations(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)
The default method to generate feature perturbations for gradient-based generators through simple gradient descent.
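The sketch below shows what a gradient-descent perturbation amounts to on a toy objective. The objective f(s) = sum((s .- t).^2), its analytic gradient, the step size η and every function name here are hypothetical. The package obtains gradients of its own search objective instead.

```julia
# Hypothetical sketch: gradient-descent perturbations for a toy objective
# f(s) = sum((s .- t).^2), whose gradient is 2 .* (s .- t).
toy_gradient(s, t) = 2 .* (s .- t)

# The proposed perturbation is a step of size η against the gradient.
perturbation(s, t; η=0.1) = -η .* toy_gradient(s, t)

function descend(s0, t; η=0.1, steps=100)
    s = copy(s0)
    for _ in 1:steps
        s .+= perturbation(s, t; η=η)    # apply the proposed feature change
    end
    return s
end

s = descend([0.0, 0.0], [1.0, -2.0])
println(s)  # very close to the toy target [1.0, -2.0]
```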
CounterfactualExplanations.Generators.hinge_loss
Method: hinge_loss(convergence::AbstractConvergence, ce::AbstractCounterfactualExplanation)
The default hinge loss for any convergence criterion. It can be overridden inside the Convergence module as part of the definition of specific convergence criteria.
CounterfactualExplanations.Generators.@objective
Macro: @objective(generator, ex)
A macro that can be used to define the counterfactual search objective.
CounterfactualExplanations.Generators.@search_feature_space
Macro: @search_feature_space(generator)
A simple macro that can be used to specify feature space search.
CounterfactualExplanations.Generators.@search_latent_space
Macro: @search_latent_space(generator)
A simple macro that can be used to specify latent space search.
CounterfactualExplanations.Generators.@with_optimiser
Macro: @with_optimiser(generator, optimiser)
A simple macro that can be used to specify the optimiser to be used.
CounterfactualExplanations.Objectives.ddp_diversity
Method: ddp_diversity(ce::AbstractCounterfactualExplanation; perturbation_size=1e-5)
Evaluates how diverse the counterfactuals are using a Determinantal Point Process (DPP).
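A DPP diversity score is commonly computed in the style of DiCE as the determinant of a kernel matrix with entries K[i, j] = 1 / (1 + distance(cᵢ, cⱼ)). The sketch below follows that formulation on a plain matrix of counterfactuals stored as columns. The function name is hypothetical. It is also an assumption that perturbation_size is a small value added to the diagonal for numerical stability.

```julia
using LinearAlgebra

# Hypothetical DiCE-style DPP diversity: K[i, j] = 1 / (1 + ‖cᵢ - cⱼ‖) and the
# score is det(K). Counterfactuals are the columns of C. A small diagonal
# perturbation (assumed role of perturbation_size) keeps det(K) numerically stable.
function dpp_diversity_sketch(C::AbstractMatrix; perturbation_size=1e-5)
    n = size(C, 2)
    K = [1 / (1 + norm(C[:, i] - C[:, j])) for i in 1:n, j in 1:n]
    return det(K + perturbation_size * I)
end

similar_cfs = [0.0 0.1; 0.0 0.1]    # two nearly identical counterfactuals
diverse_cfs = [0.0 5.0; 0.0 5.0]    # two well-separated counterfactuals
println(dpp_diversity_sketch(similar_cfs) < dpp_diversity_sketch(diverse_cfs))  # true
```

Identical counterfactuals make K close to singular, so the score drops towards zero as diversity vanishes.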
CounterfactualExplanations.Objectives.distance
Method: distance(ce::AbstractCounterfactualExplanation, p::Real=2)
Computes the distance of the counterfactual to the original factual.
CounterfactualExplanations.Objectives.distance_l0
Method: distance_l0(ce::AbstractCounterfactualExplanation)
Computes the L0 distance of the counterfactual to the original factual.
CounterfactualExplanations.Objectives.distance_l1
Method: distance_l1(ce::AbstractCounterfactualExplanation)
Computes the L1 distance of the counterfactual to the original factual.
CounterfactualExplanations.Objectives.distance_l2
Method: distance_l2(ce::AbstractCounterfactualExplanation)
Computes the L2 (Euclidean) distance of the counterfactual to the original factual.
CounterfactualExplanations.Objectives.distance_linf
Method: distance_linf(ce::AbstractCounterfactualExplanation)
Computes the L-inf distance of the counterfactual to the original factual.
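To make the four norms concrete, the sketch below applies them to a factual vector and a counterfactual vector directly. The package functions operate on a counterfactual explanation object instead. The _sketch function names are illustrative stand-ins, not package API.

```julia
using LinearAlgebra

# Illustrative stand-ins for the distance objectives, applied to plain vectors.
distance_l0_sketch(x, x′) = count(!iszero, x′ .- x)     # number of changed features
distance_l1_sketch(x, x′) = norm(x′ .- x, 1)            # sum of absolute changes
distance_l2_sketch(x, x′) = norm(x′ .- x, 2)            # Euclidean distance
distance_linf_sketch(x, x′) = norm(x′ .- x, Inf)        # largest single change

x  = [1.0, 2.0, 3.0]   # factual
x′ = [1.0, 5.0, -1.0]  # counterfactual: features 2 and 3 changed
println((distance_l0_sketch(x, x′), distance_l1_sketch(x, x′),
         distance_l2_sketch(x, x′), distance_linf_sketch(x, x′)))  # (2, 7.0, 5.0, 4.0)
```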
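CounterfactualExplanations.Objectives.distance_mad
Method: distance_mad(ce::AbstractCounterfactualExplanation; agg=mean)
Computes the distance of the counterfactual to the original factual using the measure proposed by Wachter et al. (2017). This measure scales the absolute change in each feature by that feature's median absolute deviation (MAD).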
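The sketch below follows the Wachter et al. (2017) formulation: the sum over features of the absolute change divided by the feature's MAD. The reference data matrix X (features in rows, samples in columns) and the helper names are hypothetical. The package's agg keyword is not reproduced here.

```julia
using Statistics

# Per-feature median absolute deviation over a hypothetical reference matrix X.
function mad_per_feature(X::AbstractMatrix)
    return [median(abs.(row .- median(row))) for row in eachrow(X)]
end

# MAD-weighted L1 distance between factual x and counterfactual x′.
function distance_mad_sketch(x, x′, X)
    return sum(abs.(x′ .- x) ./ mad_per_feature(X))
end

X = [1.0 2.0 3.0 4.0 5.0;        # feature 1: MAD = 1
     10.0 20.0 30.0 40.0 50.0]   # feature 2: MAD = 10
x  = [3.0, 30.0]
x′ = [4.0, 50.0]
println(distance_mad_sketch(x, x′, X))  # 1/1 + 20/10 = 3.0
```

Dividing by the MAD puts features measured on very different scales on a comparable footing.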
CounterfactualExplanations.Objectives.predictive_entropy
Method: predictive_entropy(ce::AbstractCounterfactualExplanation; agg=Statistics.mean)
Computes the predictive entropy of the counterfactuals. See https://arxiv.org/abs/1406.2541 for an explanation.
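Predictive entropy is the Shannon entropy H(p) = -Σ pᵢ log pᵢ of the predicted class probabilities. The sketch below computes it per counterfactual and aggregates with agg. It assumes a probability matrix P with classes in rows and one column per counterfactual. The function names are hypothetical.

```julia
using Statistics

# Shannon entropy of one probability vector; terms with p = 0 contribute 0.
entropy_column(p) = -sum(pᵢ > 0 ? pᵢ * log(pᵢ) : 0.0 for pᵢ in p)

# Entropy of each column of P, aggregated with `agg`.
function predictive_entropy_sketch(P::AbstractMatrix; agg=mean)
    return agg([entropy_column(p) for p in eachcol(P)])
end

P = [0.5 1.0;
     0.5 0.0]   # column 1: maximally uncertain; column 2: certain
println(predictive_entropy_sketch(P))  # mean(log(2), 0) ≈ 0.3466
```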
Flux.Losses.logitbinarycrossentropy
Method: Flux.Losses.logitbinarycrossentropy(ce::AbstractCounterfactualExplanation)
Simply extends the logitbinarycrossentropy method to work with objects of type AbstractCounterfactualExplanation.
Flux.Losses.logitcrossentropy
Method: Flux.Losses.logitcrossentropy(ce::AbstractCounterfactualExplanation)
Simply extends the logitcrossentropy method to work with objects of type AbstractCounterfactualExplanation.
Flux.Losses.mse
Method: Flux.Losses.mse(ce::AbstractCounterfactualExplanation)
Simply extends the mse method to work with objects of type AbstractCounterfactualExplanation.
Internal functions
CounterfactualExplanations.FluxModelParams
Type: FluxModelParams
Default MLP training parameters.
Base.Broadcast.broadcastable
Method: Treat AbstractFittedModel as scalar when broadcasting.
Base.Broadcast.broadcastable
Method: Treat AbstractGenerator as scalar when broadcasting.
CounterfactualExplanations.adjust_shape!
Method: adjust_shape!(ce::CounterfactualExplanation)
A convenience method that adjusts the dimensions of the counterfactual state and related fields.
CounterfactualExplanations.adjust_shape
Method: adjust_shape(ce::CounterfactualExplanation, x::AbstractArray)
A convenience method that adjusts the dimensions of x.
CounterfactualExplanations.already_in_target_class
Method: already_in_target_class(ce::CounterfactualExplanation)
Checks whether the factual is already in the target class.
CounterfactualExplanations.apply_domain_constraints!
Method: apply_domain_constraints!(ce::CounterfactualExplanation)
Wrapper function that applies underlying domain constraints.
CounterfactualExplanations.apply_mutability
Method: apply_mutability(ce::CounterfactualExplanation, Δs′::AbstractArray)
A subroutine that applies mutability constraints to the proposed vector of feature perturbations.
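The sketch below shows one way to apply mutability constraints to a perturbation vector. This reference names only :both (mutable in both directions). The symbols :increase, :decrease and :none, as well as the function name, are assumptions made for illustration.

```julia
# Sketch of mutability constraints on a perturbation vector Δs′. The symbols
# :increase, :decrease and :none are assumptions; only :both is named in this
# reference.
function apply_mutability_sketch(Δs′::AbstractVector, mutability::AbstractVector{Symbol})
    Δ = copy(Δs′)
    for (j, m) in enumerate(mutability)
        if m == :none
            Δ[j] = 0                  # immutable feature: no change allowed
        elseif m == :increase
            Δ[j] = max(Δ[j], 0)       # keep only non-negative changes
        elseif m == :decrease
            Δ[j] = min(Δ[j], 0)       # keep only non-positive changes
        end                           # :both leaves the change untouched
    end
    return Δ
end

Δ = apply_mutability_sketch([0.5, -0.3, 0.2, -0.4], [:both, :none, :increase, :decrease])
println(Δ)  # [0.5, 0.0, 0.2, -0.4]
```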
CounterfactualExplanations.counterfactual
Method: counterfactual(ce::CounterfactualExplanation)
A convenience method that returns the counterfactual.
CounterfactualExplanations.counterfactual_label
Method: counterfactual_label(ce::CounterfactualExplanation)
A convenience method that returns the predicted label of the counterfactual.
CounterfactualExplanations.counterfactual_label_path
Method: counterfactual_label_path(ce::CounterfactualExplanation)
Returns the counterfactual labels for each step of the search.
CounterfactualExplanations.counterfactual_probability
Function: counterfactual_probability(ce::CounterfactualExplanation)
A convenience method that computes the class probabilities of the counterfactual.
CounterfactualExplanations.counterfactual_probability_path
Method: counterfactual_probability_path(ce::CounterfactualExplanation)
Returns the counterfactual probabilities for each step of the search.
CounterfactualExplanations.decode_array
Method: decode_array(dt::GenerativeModels.AbstractGenerativeModel, x::AbstractArray)
Helper function to decode an array x using a data transform dt::GenerativeModels.AbstractGenerativeModel.
CounterfactualExplanations.decode_array
Method: decode_array(dt::MultivariateStats.AbstractDimensionalityReduction, x::AbstractArray)
Helper function to decode an array x using a data transform dt::MultivariateStats.AbstractDimensionalityReduction.
CounterfactualExplanations.decode_array
Method: decode_array(dt::Nothing, x::AbstractArray)
Helper function to decode an array x using a data transform dt::Nothing. This is a no-op.
CounterfactualExplanations.decode_array
Method: decode_array(dt::StatsBase.AbstractDataTransform, x::AbstractArray)
Helper function to decode an array x using a data transform dt::StatsBase.AbstractDataTransform.
CounterfactualExplanations.decode_state
Function: decode_state(ce::CounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing)
Applies all the applicable decoding functions:
- If applicable, map the state variable back from the latent space to the feature space.
- If and where applicable, inverse-transform features.
- Reconstruct all categorical encodings.
Finally, the decoded counterfactual is returned.
CounterfactualExplanations.decode_state!
Function: decode_state!(ce::CounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing)
In-place version of decode_state.
CounterfactualExplanations.encode_array
Method: encode_array(dt::GenerativeModels.AbstractGenerativeModel, x::AbstractArray)
Helper function to encode an array x using a data transform dt::GenerativeModels.AbstractGenerativeModel.
CounterfactualExplanations.encode_array
Method: encode_array(dt::MultivariateStats.AbstractDimensionalityReduction, x::AbstractArray)
Helper function to encode an array x using a data transform dt::MultivariateStats.AbstractDimensionalityReduction.
CounterfactualExplanations.encode_array
Method: encode_array(dt::Nothing, x::AbstractArray)
Helper function to encode an array x using a data transform dt::Nothing. This is a no-op.
CounterfactualExplanations.encode_array
Method: encode_array(dt::StatsBase.AbstractDataTransform, x::AbstractArray)
Helper function to encode an array x using a data transform dt::StatsBase.AbstractDataTransform.
CounterfactualExplanations.encode_state
Function: encode_state(ce::CounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing)
Applies all required encodings to x:
- If applicable, it maps x to the latent space learned by the generative model.
- If and where applicable, it rescales features.
Finally, it returns the encoded state variable.
CounterfactualExplanations.encode_state!
Function: encode_state!(ce::CounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing)
In-place version of encode_state.
CounterfactualExplanations.factual
Method: factual(ce::CounterfactualExplanation)
A convenience method to retrieve the factual x.
CounterfactualExplanations.factual_label
Method: factual_label(ce::CounterfactualExplanation)
A convenience method to get the predicted label associated with the factual.
CounterfactualExplanations.factual_probability
Method: factual_probability(ce::CounterfactualExplanation)
A convenience method to compute the class probabilities of the factual.
CounterfactualExplanations.find_potential_neighbours
Method: find_potential_neighbours(ce::AbstractCounterfactualExplanation)
Finds potential neighbors for the selected factual data point.
CounterfactualExplanations.get_meta
Method: get_meta(ce::CounterfactualExplanation)
Returns metadata for a counterfactual explanation.
CounterfactualExplanations.guess_likelihood
Method: guess_likelihood(y::RawOutputArrayType)
Guesses the likelihood based on the scientific type of the output array. Returns a symbol indicating the guessed likelihood and the scientific type of the output array.
CounterfactualExplanations.guess_loss
Method: guess_loss(ce::CounterfactualExplanation)
Guesses the loss function to be used for the counterfactual search when the likelihood field is specified for the AbstractFittedModel instance and no loss function was explicitly declared for the AbstractGenerator instance.
CounterfactualExplanations.initialize!
Method: initialize!(ce::CounterfactualExplanation)
Initializes the counterfactual explanation. This method is called by the constructor. It does the following:
- Creates a dictionary to store information about the search.
- Initializes the counterfactual state.
- Initializes the search path.
- Initializes the loss.
CounterfactualExplanations.initialize_state!
Method: initialize_state!(ce::CounterfactualExplanation)
Initializes the starting point for the factual(s) in-place.
CounterfactualExplanations.initialize_state
Method: initialize_state(ce::CounterfactualExplanation)
Initializes the starting point for the factual(s):
- If ce.initialization is set to :identity or counterfactuals are searched in a latent space, then nothing is done.
- If ce.initialization is set to :add_perturbation, then a random perturbation is added to the factual following Slack et al. (2021): https://arxiv.org/abs/2106.02666. The authors show that this improves adversarial robustness.
CounterfactualExplanations.output_dim
Method: output_dim(ce::CounterfactualExplanation)
A convenience method that returns the output dimension of the predictive model.
CounterfactualExplanations.reset!
Method: reset!(flux_training_params::FluxModelParams)
Restores the default parameter values.
CounterfactualExplanations.steps_exhausted
Method: steps_exhausted(ce::CounterfactualExplanation)
A convenience method that checks whether the maximum number of iterations has been exhausted.
CounterfactualExplanations.target_probs_path
Method: target_probs_path(ce::CounterfactualExplanation)
Returns the target probabilities for each step of the search.
CounterfactualExplanations.Evaluation.distance_measures
Constant: All distance measures.
Base.vcat
Method: Base.vcat(bmk1::Benchmark, bmk2::Benchmark)
Vertically concatenates two Benchmark objects.
CounterfactualExplanations.Evaluation.compute_measure
Method: compute_measure(ce::CounterfactualExplanation, measure::Function, agg::Function)
Computes a single measure for a counterfactual explanation. The measure is applied to the counterfactual explanation ce and aggregated using the aggregation function agg.
CounterfactualExplanations.Evaluation.to_dataframe
Method: to_dataframe(ce::CounterfactualExplanation, measure::Vector{Function}, agg::Function, report_each::Bool, pivot_longer::Bool, store_ce::Bool)
Evaluates a counterfactual explanation and returns a dataframe of evaluation measures.
CounterfactualExplanations.Evaluation.validity_strict
Method: validity_strict(ce::CounterfactualExplanation)
Checks whether the counterfactual search has been strictly valid, in the sense that it has converged with respect to the pre-specified target probability γ.
CounterfactualExplanations.DataPreprocessing.InputTransformer
Type: InputTransformer
Abstract type for data transformers. This can be any of the following:
- StatsBase.AbstractDataTransform: A data transformation object from the StatsBase package.
- MultivariateStats.AbstractDimensionalityReduction: A dimensionality reduction object from the MultivariateStats package.
- GenerativeModels.AbstractGenerativeModel: A generative model object from the GenerativeModels module.
CounterfactualExplanations.DataPreprocessing.TypedInputTransformer
Type: TypedInputTransformer
Abstract type for data transformers.
Base.Broadcast.broadcastable
Method: Treat CounterfactualData as scalar when broadcasting.
CounterfactualExplanations.DataPreprocessing._subset
Method: _subset(data::CounterfactualData, idx::Vector{Int})
Creates a subset of the data.
CounterfactualExplanations.DataPreprocessing.convert_to_1d
Method: convert_to_1d(y::Matrix, y_levels::AbstractArray)
Helper function to convert a one-hot encoded matrix to a vector of labels. This is necessary because MLJ models require the labels to be represented as a vector, but the synthetic datasets in this package hold the labels in one-hot encoded form.
Arguments
- y::Matrix: The one-hot encoded matrix.
- y_levels::AbstractArray: The levels of the categorical variable.
Returns
- labels: A vector of labels.
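The sketch below shows the conversion as an argmax lookup on each sample's one-hot column. It assumes one column per sample and one row per class, which is the usual Flux layout. The package's own orientation may differ, and the function name is hypothetical.

```julia
# Sketch of one-hot to label conversion, assuming one column per sample and one
# row per class; each column's position of the 1 indexes into y_levels.
convert_to_1d_sketch(y::AbstractMatrix, y_levels::AbstractVector) =
    [y_levels[argmax(col)] for col in eachcol(y)]

y = [1 0 0;
     0 0 1;
     0 1 0]                     # three samples, three classes
labels = convert_to_1d_sketch(y, ["a", "b", "c"])
println(labels)                 # ["a", "c", "b"]
```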
CounterfactualExplanations.DataPreprocessing.fit_transformer
Method: fit_transformer(data::CounterfactualData, input_encoder::Nothing; kwargs...)
Fits a transformer to the data. This is a no-op if input_encoder is Nothing.
CounterfactualExplanations.DataPreprocessing.fit_transformer
Method: fit_transformer(data::CounterfactualData, input_encoder::Type{GenerativeModels.AbstractGenerativeModel}; kwargs...)
Fits a transformer to the data for a GenerativeModels.AbstractGenerativeModel object.
CounterfactualExplanations.DataPreprocessing.fit_transformer
Method: fit_transformer(data::CounterfactualData, input_encoder::Type{MultivariateStats.AbstractDimensionalityReduction}; kwargs...)
Fits a transformer to the data for a MultivariateStats.AbstractDimensionalityReduction object.
CounterfactualExplanations.DataPreprocessing.fit_transformer
Method: fit_transformer(data::CounterfactualData, input_encoder::Type{StatsBase.AbstractDataTransform}; kwargs...)
Fits a transformer to the data for a StatsBase.AbstractDataTransform object.
CounterfactualExplanations.DataPreprocessing.fit_transformer
Method: fit_transformer(data::CounterfactualData, input_encoder::InputTransformer; kwargs...)
Fits a transformer to the data for an InputTransformer object. This is a no-op.
CounterfactualExplanations.DataPreprocessing.input_dim
Method: input_dim(counterfactual_data::CounterfactualData)
Helper function that returns the input dimension (number of features) of the data.
CounterfactualExplanations.DataPreprocessing.mutability_constraints
Method: mutability_constraints(counterfactual_data::CounterfactualData)
A convenience function that returns the mutability constraints. If none were specified, all features are assumed to be mutable in :both directions.
CounterfactualExplanations.DataPreprocessing.preprocess_data_for_mlj
Method: preprocess_data_for_mlj(data::CounterfactualData)
Helper function to preprocess data::CounterfactualData for MLJ models.
Arguments
- data::CounterfactualData: The data to be preprocessed.
Returns
- (df_x, y): A tuple containing the preprocessed data, with df_x being a DataFrame object and y being a categorical vector.
Example
X, y = preprocess_data_for_mlj(data)
CounterfactualExplanations.DataPreprocessing.reconstruct_cat_encoding
Method: reconstruct_cat_encoding(counterfactual_data::CounterfactualData, x::Vector)
Reconstructs the categorical encoding for a single instance.
CounterfactualExplanations.DataPreprocessing.subsample
Method: subsample(data::CounterfactualData, n::Int)
Helper function to randomly subsample data::CounterfactualData.
CounterfactualExplanations.DataPreprocessing.train_test_split
Method: train_test_split(data::CounterfactualData; test_size=0.2, keep_class_ratio=false)
Splits the data into a train set and a test set.
Arguments
- data::CounterfactualData: The data to be split.
- test_size=0.2: Proportion of the data to be used for testing.
- keep_class_ratio=false: Decides whether to sample equally from each class, or keep their relative size.
Returns
- (train_data::CounterfactualData, test_data::CounterfactualData): A tuple containing the train and test splits.
Example
train, test = train_test_split(data, test_size=0.1, keep_class_ratio=true)
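Keeping the relative class sizes amounts to a stratified split, in which the same test fraction is drawn from each class separately. The sketch below works on a plain label vector and returns index vectors rather than CounterfactualData objects. Both the function name and the per-class rounding rule are assumptions of the example.

```julia
using Random

# Sketch of a stratified split: draw round(test_size * n_class) test indices from
# each class separately, so class proportions are preserved in both splits.
function stratified_split_indices(y::AbstractVector; test_size=0.2, rng=Random.default_rng())
    train_idx, test_idx = Int[], Int[]
    for class in unique(y)
        idx = shuffle(rng, findall(==(class), y))   # this class's indices, shuffled
        n_test = round(Int, test_size * length(idx))
        append!(test_idx, idx[1:n_test])
        append!(train_idx, idx[n_test+1:end])
    end
    return train_idx, test_idx
end

y = vcat(fill(1, 80), fill(2, 20))                  # imbalanced labels: 80/20
train_idx, test_idx = stratified_split_indices(y; test_size=0.1, rng=MersenneTwister(42))
println(count(==(1), y[test_idx]), " / ", count(==(2), y[test_idx]))  # 8 / 2
```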
CounterfactualExplanations.DataPreprocessing.unpack_data
Method: unpack_data(data::CounterfactualData)
Helper function that unpacks data.
CounterfactualExplanations.Models.AbstractCustomDifferentiableModel
Type: Base type for custom differentiable models.
CounterfactualExplanations.Models.AbstractFluxModel
Type: Base type for differentiable models written in Flux.
CounterfactualExplanations.Models.AbstractMLJModel
Type: Base type for differentiable models from the MLJ library.
CounterfactualExplanations.Models.AbstractNonDifferentiableJuliaModel
Type: Base type for non-differentiable models written in pure Julia.
CounterfactualExplanations.Models.AbstractNonDifferentiableModel
Type: Base type for non-differentiable models.
CounterfactualExplanations.Models.FluxEnsembleParams
Type: FluxEnsembleParams
Default Deep Ensemble training parameters.
CounterfactualExplanations.Models.TreeModel
Type: TreeModel <: AbstractNonDifferentiableJuliaModel
Constructor for tree-based models from the MLJ library.
Arguments
- model::Any: The model selected by the user. Must be a model from the MLJ library.
- likelihood::Symbol: The likelihood of the model. Must be one of [:classification_binary, :classification_multi].
Returns
- TreeModel: A tree-based model from the MLJ library wrapped inside the TreeModel type.
CounterfactualExplanations.Models.TreeModel
Method: Outer constructor method for TreeModel.
CounterfactualExplanations.Models.binary_to_onehot
Method: binary_to_onehot(p)
Helper function to turn a dummy-encoded variable into a one-hot encoded variable.
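The sketch below shows the basic idea of expanding a dummy-encoded vector into two one-hot rows. It assumes that p holds 0/1 values or positive-class probabilities for n samples, and that the negative class goes in the first row. Both assumptions, and the function name, are hypothetical and may not match the package's layout.

```julia
# Sketch, assuming p holds dummy-encoded values (or positive-class probabilities)
# for n samples and that the negative class occupies row 1 of the result.
binary_to_onehot_sketch(p::AbstractVector) = vcat((1 .- p)', p')

onehot = binary_to_onehot_sketch([0, 1, 1])
println(onehot)  # [1 0 0; 0 1 1]
```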
CounterfactualExplanations.Models.build_ensemble
Method: build_ensemble(K::Int; kw=(input_dim=2, n_hidden=32, output_dim=1))
Helper function that builds an ensemble of K models.
CounterfactualExplanations.Models.build_mlp
Method: build_mlp()
Helper function to build a simple MLP.
Examples
nn = build_mlp()
CounterfactualExplanations.Models.data_loader
Method: data_loader(data::CounterfactualData)
Prepares counterfactual data for training in Flux.
CounterfactualExplanations.Models.get_individual_classifiers
Method: get_individual_classifiers(M::TreeModel)
Returns the individual classifiers in the forest. If the input is a decision tree, the method returns the decision tree itself inside an array.
Arguments
- M::TreeModel: The model selected by the user.
Returns
- classifiers::AbstractArray: An array of individual classifiers in the forest.
Example
classifiers = Models.get_individual_classifiers(M) # returns the individual classifiers in the forest
CounterfactualExplanations.Models.train
Method: train(M::TreeModel, data::CounterfactualData; kwargs...)
Fits the model M to the data in the CounterfactualData object. This method is not called by the user directly.
Arguments
- M::TreeModel: The wrapper for a TreeModel.
- data::CounterfactualData: The CounterfactualData object containing the data to be used for training the model.
Returns
- M::TreeModel: The fitted TreeModel.
CounterfactualExplanations.Models.train
Method: train(M::FluxEnsemble, data::CounterfactualData; kwargs...)
Wrapper function to retrain FluxEnsemble.
CounterfactualExplanations.Models.train
Method: train(M::FluxModel, data::CounterfactualData; kwargs...)
Wrapper function to retrain FluxModel.
CounterfactualExplanations.GenerativeModels.AbstractGMParams
Type: Base type for generative model hyperparameter containers.
CounterfactualExplanations.GenerativeModels.AbstractGenerativeModel
Type: Base type for generative models.
CounterfactualExplanations.GenerativeModels.Encoder
Type: Encoder
Constructs the encoder part of the VAE: a simple Flux neural network with one hidden layer and two linear output layers for the first two moments of the latent distribution.
CounterfactualExplanations.GenerativeModels.VAE
Type: VAE <: AbstractGenerativeModel
Constructs the Variational Autoencoder. The VAE is a subtype of AbstractGenerativeModel. Any (sub-)type of AbstractGenerativeModel is accepted by latent space generators.
CounterfactualExplanations.GenerativeModels.VAE
Method: VAE(input_dim; kws...)
Outer method for instantiating a VAE.
CounterfactualExplanations.GenerativeModels.VAEParams
Type: VAEParams <: AbstractGMParams
The default VAE parameters describing both the encoder/decoder architecture and the training process.
Base.rand
Function: Random.rand(encoder::Encoder, x, device=cpu)
Draws random samples from the latent distribution.
CounterfactualExplanations.GenerativeModels.Decoder
Method: Decoder(input_dim::Int, latent_dim::Int, hidden_dim::Int; activation=relu)
The default decoder architecture is just a Flux Chain with one hidden layer and a linear output layer.
CounterfactualExplanations.GenerativeModels.decode
Method: decode(generative_model::VAE, x::AbstractArray)
Decodes an array x using the VAE decoder.
CounterfactualExplanations.GenerativeModels.encode
Method: encode(generative_model::VAE, x::AbstractArray)
Encodes an array x using the VAE encoder by sampling from the latent distribution. It first passes x through the encoder to obtain the mean and log-variance of the latent distribution, then samples from that distribution using the reparameterization trick. See Random.rand(encoder::Encoder, x, device=cpu) for more details.
CounterfactualExplanations.GenerativeModels.get_data
Method: get_data(X::AbstractArray, batch_size)
Prepares data for mini-batch training.
CounterfactualExplanations.GenerativeModels.reconstruct
Function: reconstruct(generative_model::VAE, x, device=cpu)
Implements a full pass of some input x through the VAE: x ↦ x̂.
CounterfactualExplanations.GenerativeModels.reparameterization_trick
Function: reparameterization_trick(μ, logσ, device=cpu)
Helper function that implements the reparameterization trick: z ∼ 𝒩(μ, σ²) ⇔ z = μ + σ ⊙ ε, ε ∼ 𝒩(0, I).
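The reparameterization trick moves the randomness into ε, so z stays a differentiable function of μ and σ. The sketch below treats logσ as the log of the standard deviation (σ = exp.(logσ)), following the argument name. If an encoder outputs the log-variance instead, σ would be exp.(logvar ./ 2). The device argument is omitted and the function name is hypothetical.

```julia
using Random, Statistics

# Sketch of z = μ + σ ⊙ ε with ε ∼ 𝒩(0, I), treating logσ as log standard deviation.
function reparameterization_trick_sketch(μ, logσ; rng=Random.default_rng())
    ε = randn(rng, size(μ)...)        # standard normal noise with the shape of μ
    return μ .+ exp.(logσ) .* ε       # element-wise scale and shift
end

μ = fill(2.0, 100_000)
logσ = fill(log(0.5), 100_000)
z = reparameterization_trick_sketch(μ, logσ; rng=MersenneTwister(1))
println((mean(z), std(z)))            # approximately (2.0, 0.5)
```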
CounterfactualExplanations.Generators.Penalty
Type: Type union for acceptable argument types for the penalty field of GradientBasedGenerator.
CounterfactualExplanations.Generators._replace_nans
Function: _replace_nans(Δs′::AbstractArray, old_new::Pair=(NaN => 0))
Helper function to deal with exploding gradients. This is only a temporary fix and will be improved.
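The sketch below shows the kind of NaN replacement this helper describes. Base.replace compares elements with isequal, and isequal(NaN, NaN) is true, so NaN entries can be matched directly even though NaN == NaN is false. The function name is a hypothetical stand-in.

```julia
# Sketch of NaN replacement with Base.replace (which matches via isequal).
replace_nans_sketch(Δs′::AbstractArray, old_new::Pair=(NaN => 0)) = replace(Δs′, old_new)

Δs′ = [0.3, NaN, -1.2, NaN]
println(replace_nans_sketch(Δs′))     # [0.3, 0.0, -1.2, 0.0]
```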
CounterfactualExplanations.Generators.calculate_delta
Method: calculate_delta(ce::AbstractCounterfactualExplanation, penalty::Vector{Function})
Calculates the penalty for the proposed feature tweak.
Arguments
- ce::AbstractCounterfactualExplanation: The counterfactual explanation object.
Returns
- delta::Float64: The calculated penalty for the proposed feature tweak.
CounterfactualExplanations.Generators.esatisfactory_instance
Method: esatisfactory_instance(generator::FeatureTweakGenerator, x::AbstractArray, paths::Dict{String, Dict{String, Any}})
Returns an epsilon-satisfactory counterfactual for x based on the paths provided.
Arguments
- generator::FeatureTweakGenerator: The feature tweak generator.
- x::AbstractArray: The factual instance.
- paths::Dict{String, Dict{String, Any}}: A list of paths to the leaves of the tree to be used for tweaking the feature.
Returns
- esatisfactory::AbstractArray: The epsilon-satisfactory instance.
Example
esatisfactory = esatisfactory_instance(generator, x, paths) # returns an epsilon-satisfactory counterfactual for x based on the paths provided
CounterfactualExplanations.Generators.feature_selection!
Method: feature_selection!(ce::AbstractCounterfactualExplanation)
Performs feature selection to find the dimension with the closest (but not equal) values between the ce.x (factual) and ce.s′ (counterfactual) arrays.
Arguments
- ce::AbstractCounterfactualExplanation: An instance of the AbstractCounterfactualExplanation type representing the counterfactual explanation.
Returns
- nothing
The function iteratively modifies the ce.s′ counterfactual array by updating its elements to match the corresponding elements in the ce.x factual array, one dimension at a time, until the predicted label of the modified ce.s′ matches the predicted label of the ce.x array.
CounterfactualExplanations.Generators.find_closest_dimension
Method: find_closest_dimension(factual, counterfactual)
Finds the dimension with the closest (but not equal) values between the factual and counterfactual arrays.
Arguments
- factual: The factual array.
- counterfactual: The counterfactual array.
Returns
- closest_dimension: The index of the dimension with the closest values.
The function iterates over the indices of the factual array and calculates the absolute difference between the corresponding elements in the factual and counterfactual arrays. It returns the index of the dimension with the smallest difference, excluding dimensions where the values in factual and counterfactual are equal.
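The search described above can be sketched as a single pass that tracks the smallest non-zero absolute difference. The function name is hypothetical. Returning nothing when every dimension already matches is also an assumption of this example.

```julia
# Sketch: index of the smallest non-zero absolute difference between two arrays;
# returns `nothing` if all values are equal (an assumption of this example).
function find_closest_dimension_sketch(factual, counterfactual)
    closest_dimension, min_diff = nothing, Inf
    for i in eachindex(factual)
        diff = abs(factual[i] - counterfactual[i])
        if diff != 0 && diff < min_diff   # skip dimensions that already match
            closest_dimension, min_diff = i, diff
        end
    end
    return closest_dimension
end

println(find_closest_dimension_sketch([1.0, 2.0, 3.0], [1.0, 2.5, 5.0]))  # 2
```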
CounterfactualExplanations.Generators.find_counterfactual
Method: find_counterfactual(model, factual_class, counterfactual_data, counterfactual_candidates)
Finds the index of the first counterfactual by predicting labels.
Arguments
- model: The fitted model used for prediction.
- target_class: Expected target class.
- counterfactual_data: Data required for counterfactual generation.
- counterfactual_candidates: The array of counterfactual candidates.
Returns
- counterfactual: The index of the first counterfactual found.
CounterfactualExplanations.Generators.growing_spheres_generation!
Method: growing_spheres_generation!(ce::AbstractCounterfactualExplanation)
Generates counterfactual candidates using the growing spheres generation algorithm.
Arguments
- ce::AbstractCounterfactualExplanation: An instance of the AbstractCounterfactualExplanation type representing the counterfactual explanation.
Returns
- nothing
This function applies the growing spheres generation algorithm to generate counterfactual candidates. It starts by generating random points uniformly on a sphere and gradually reduces the search space until no counterfactuals are found. It then expands the search space until at least one counterfactual is found or the maximum number of iterations is reached.
The algorithm iteratively generates counterfactual candidates and predicts their labels using the model stored in ce.M. It checks whether any of the predicted labels differ from the factual class. Reducing the search space means halving the search radius, while expanding it means increasing the search radius.
CounterfactualExplanations.Generators.h
Method: h(generator::AbstractGenerator, ce::AbstractCounterfactualExplanation)
Dispatches to the appropriate complexity function for any generator.
CounterfactualExplanations.Generators.h
Method: h(generator::AbstractGenerator, penalty::Function, ce::AbstractCounterfactualExplanation)
Overloads the h function for the case where a single penalty function is provided.
CounterfactualExplanations.Generators.h
Method: h(generator::AbstractGenerator, penalty::Nothing, ce::AbstractCounterfactualExplanation)
Overloads the h function for the case where no penalty is provided.
CounterfactualExplanations.Generators.h
Method: h(generator::AbstractGenerator, penalty::Tuple, ce::AbstractCounterfactualExplanation)
Overloads the h function for the case where a single penalty function is provided with additional keyword arguments.
CounterfactualExplanations.Generators.hyper_sphere_coordinates
Method: hyper_sphere_coordinates(n_search_samples::Int, instance::Vector{Float64}, low::Int, high::Int; p_norm::Int=2)
Generates candidate counterfactuals using the growing spheres method based on hyper-sphere coordinates.
The implementation follows the Random Point Picking over a sphere algorithm described in the paper "Learning Counterfactual Explanations for Tabular Data" by Pawelczyk, Broelemann & Kasneci (2020), presented at The Web Conference 2020 (WWW). It ensures that points are sampled uniformly at random using insights from: http://mathworld.wolfram.com/HyperspherePointPicking.html
The growing spheres method was originally proposed in the paper "Comparison-based Inverse Classification for Interpretability in Machine Learning" by Thibaut Laugel et al. (2018), presented at the International Conference on Information Processing and Management of Uncertainty in Knowledge-Based Systems (2018).
Arguments
- n_search_samples::Int: The number of search samples (int > 0).
- instance::AbstractArray: The input point array.
- low::AbstractFloat: The lower bound (float >= 0, l < h).
- high::AbstractFloat: The upper bound (float >= 0, h > l).
- p_norm::Integer: The norm parameter (int >= 1).
Returns
- candidate_counterfactuals::Array: An array of candidate counterfactuals.
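The sketch below illustrates hypersphere point picking for one spherical layer low ≤ r ≤ high around an instance. Directions come from Gaussian draws normalised to unit p-norm, which gives uniform directions for p = 2. Radii are drawn as r = u^(1/d) with u uniform on [low^d, high^d], so the points spread evenly over the layer's volume. The function name, the column-per-candidate layout and the radius rule are assumptions of this example rather than a description of the package's code.

```julia
using LinearAlgebra, Random

# Sketch of hypersphere point picking in a spherical layer around `instance`.
# Returns a d × n matrix with one candidate per column.
function hyper_sphere_sketch(n::Int, instance::AbstractVector, low::Real, high::Real;
                             p_norm::Real=2, rng=Random.default_rng())
    d = length(instance)
    Z = randn(rng, d, n)
    Z ./= [norm(c, p_norm) for c in eachcol(Z)]'    # unit-norm directions
    u = low^d .+ rand(rng, n) .* (high^d - low^d)   # uniform in r^d ...
    r = u .^ (1 / d)                                # ... so r is volume-uniform
    return instance .+ Z .* r'
end

C = hyper_sphere_sketch(1000, [1.0, -1.0], 0.5, 1.0; rng=MersenneTwister(0))
dists = [norm(c .- [1.0, -1.0]) for c in eachcol(C)]
println(extrema(dists))   # all distances lie within [0.5, 1.0]
```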
CounterfactualExplanations.Generators.propose_state
Method: propose_state(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)
Proposes a new state based on backpropagation.
CounterfactualExplanations.Generators.search_path
Function: search_path(tree::Union{DecisionTree.Leaf, DecisionTree.Node}, target::RawTargetType, path::AbstractArray)
Returns a path index list with the inequality symbols, thresholds and feature indices.
Arguments
- tree::Union{DecisionTree.Leaf, DecisionTree.Node}: The root node of a decision tree.
- target::RawTargetType: The target class.
- path::AbstractArray: A list containing the paths found thus far.
Returns
- paths::AbstractArray: A list of paths to the leaves of the tree to be used for tweaking the feature.
Example
paths = search_path(tree, target) # returns a list of paths to the leaves of the tree to be used for tweaking the feature
CounterfactualExplanations.Generators.total_loss
Method: total_loss(ce::AbstractCounterfactualExplanation)
Computes the total loss of a counterfactual explanation with respect to the search objective.
CounterfactualExplanations.Generators.β
β Methodβ(generator::AbstractGenerator, ce::AbstractCounterfactualExplanation)
Dispatches to the appropriate loss function for any generator.
CounterfactualExplanations.Generators.β
β Methodβ(generator::AbstractGenerator, loss::Function, ce::AbstractCounterfactualExplanation)
Overloads the β
function for the case where a single loss function is provided.
CounterfactualExplanations.Generators.β
β Methodβ(generator::AbstractGenerator, loss::Nothing, ce::AbstractCounterfactualExplanation)
Overloads the β
function for the case where no loss function is provided.
CounterfactualExplanations.Generators.βh
β Methodβh(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)
The default method to compute the gradient of the complexity penalty at the current counterfactual state for gradient-based generators. It assumes that Zygote.jl has gradient access.
If no penalty is provided, it returns 0.0. For expressions that are constant in the input, Zygote returns nothing instead of a zero gradient, so a manual step is needed to override this behaviour. See https://discourse.julialang.org/t/zygote-gradient/26715.
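Here is a sketch of the override described above. To avoid depending on Zygote, the AD call is replaced by a function argument (`ad_gradient`) that mimics Zygote's behaviour of returning `nothing` for a constant penalty. All names are hypothetical. The sketch shows only the `nothing`-to-zeros substitution and the 0.0 return for a missing penalty.

```julia
# Hypothetical sketch. `ad_gradient(penalty, x)` stands in for an AD call such as
# Zygote's; for a penalty that is constant in x, Zygote reports `nothing`.
function penalty_gradient_sketch(penalty, x::AbstractArray, ad_gradient::Function)
    # No penalty configured: return 0.0, as the reference describes.
    penalty === nothing && return 0.0
    g = ad_gradient(penalty, x)
    # Replace a `nothing` gradient with an explicit zero array of the same shape.
    return g === nothing ? zero(x) : g
end

const_penalty(x) = 1.0
sq_penalty(x) = sum(abs2, x)
# Stand-ins for what an AD backend would return for each penalty.
constant_grad(p, x) = nothing
squared_norm_grad(p, x) = 2 .* x

x = [1.0, -2.0]
penalty_gradient_sketch(nothing, x, constant_grad)          # 0.0
penalty_gradient_sketch(const_penalty, x, constant_grad)    # [0.0, 0.0]
penalty_gradient_sketch(sq_penalty, x, squared_norm_grad)   # [2.0, -4.0]
```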
CounterfactualExplanations.Generators.∂ℓ
Method
∂ℓ(generator::AbstractGradientBasedGenerator, M::Union{Models.LogisticModel, Models.BayesianLogisticModel}, ce::AbstractCounterfactualExplanation)
The default method to compute the gradient of the loss function at the current counterfactual state for gradient-based generators. It assumes that Zygote.jl has gradient access.
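For a logistic model, the input gradient of the binary cross-entropy loss has a closed form, ∂ℓ/∂x = (σ(wᵀx + b) − y)·w. That makes it a convenient check on what automatic differentiation should return. The sketch below assumes a binary logistic model with weights `w`, bias `b`, and target label `y`. It computes the gradient analytically rather than through Zygote.jl and compares it against central finite differences. The function names are hypothetical.

```julia
using LinearAlgebra

σ(z) = 1 / (1 + exp(-z))

# Binary cross-entropy of a logistic model at input x against target label y ∈ {0, 1}.
logistic_bce(x, w, b, y) = -(y * log(σ(dot(w, x) + b)) + (1 - y) * log(1 - σ(dot(w, x) + b)))

# Closed-form gradient of `logistic_bce` with respect to the input x.
logistic_loss_gradient(x, w, b, y) = (σ(dot(w, x) + b) - y) .* w

# Central finite differences, used only to check the closed form.
function finite_diff_gradient(f, x; ε=1e-6)
    g = zeros(length(x))
    for i in eachindex(x)
        e = zeros(length(x)); e[i] = ε
        g[i] = (f(x .+ e) - f(x .- e)) / (2ε)
    end
    return g
end

w, b, y = [1.0, -2.0], 0.5, 1.0
x = [0.3, 0.1]
analytic = logistic_loss_gradient(x, w, b, y)
numeric = finite_diff_gradient(z -> logistic_bce(z, w, b, y), x)
isapprox(analytic, numeric; atol=1e-6)  # true
```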
CounterfactualExplanations.Generators.∇
Method
∇(generator::AbstractGradientBasedGenerator, M::Models.AbstractDifferentiableModel, ce::AbstractCounterfactualExplanation)
The default method to compute the gradient of the counterfactual search objective for gradient-based generators. It computes the weighted sum of the partial derivatives and assumes that Zygote.jl has gradient access. If the counterfactual is generated using Probe, the hinge loss is added to the gradient.
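The weighted sum of partial derivatives can be sketched as the loss gradient plus λ-weighted penalty gradients. This is an illustration under that assumption only: the name `objective_gradient_sketch` and its arguments are hypothetical, and the Probe-specific hinge-loss term is not modelled.

```julia
# Hypothetical sketch: ∇ = ∂ℓ + Σᵢ λᵢ * ∂hᵢ, given precomputed partial derivatives.
function objective_gradient_sketch(loss_grad::AbstractArray,
                                   penalty_grads::Vector{<:AbstractArray},
                                   λs::Vector{<:Real})
    length(penalty_grads) == length(λs) ||
        throw(ArgumentError("need one weight per penalty gradient"))
    total = float.(loss_grad)        # fresh array, so the input is not mutated
    for (λ, g) in zip(λs, penalty_grads)
        total .+= λ .* g
    end
    return total
end

objective_gradient_sketch([1.0, 0.0], [[0.0, 2.0]], [0.5])  # [1.0, 1.0]
```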
CounterfactualExplanations.Objectives.NeedsNeighbours
Type
Penalties that need access to neighbors in the target class.
CounterfactualExplanations.Objectives.NoPenaltyRequirements
Type
By default, penalties have no extra requirements.
CounterfactualExplanations.Objectives.PenaltyRequirements
Type
A base type for a style of process, i.e. the trait that describes what extra inputs a penalty requires.
CounterfactualExplanations.Objectives.PenaltyRequirements
Method
The distance_from_target method needs neighbors in the target class.
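The types above match the trait-based dispatch pattern that their names suggest. An abstract trait type with singleton subtypes, plus a trait function that defaults to "no requirements" and is overridden for the penalty that needs neighbours. The sketch below re-creates that pattern with toy names. `ToyPenaltyRequirements`, `requirements`, `requirement_label`, and the stand-in penalties are all hypothetical and not the package's definitions.

```julia
# Hypothetical re-creation of a trait hierarchy like the one documented above.
abstract type ToyPenaltyRequirements end
struct ToyNeedsNeighbours <: ToyPenaltyRequirements end
struct ToyNoPenaltyRequirements <: ToyPenaltyRequirements end

# Two stand-in penalties.
toy_distance_from_target(x) = 0.0
toy_l2_penalty(x) = sum(abs2, x)

# Trait function: by default a penalty has no extra requirements ...
requirements(::Type) = ToyNoPenaltyRequirements()
# ... but the distance-based penalty needs target-class neighbours.
requirements(::Type{typeof(toy_distance_from_target)}) = ToyNeedsNeighbours()

# Dispatch on the trait value instead of on the penalty function itself.
requirement_label(f::Function) = requirement_label(requirements(typeof(f)), f)
requirement_label(::ToyNeedsNeighbours, f) = "needs neighbours"
requirement_label(::ToyNoPenaltyRequirements, f) = "no extra requirements"

requirement_label(toy_distance_from_target)  # "needs neighbours"
requirement_label(toy_l2_penalty)            # "no extra requirements"
```

Because the trait is a value, supporting a new penalty with extra needs only takes one more `requirements` method. Functions that consume penalties do not change.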
CounterfactualExplanations.Objectives.distance_from_target
Method
distance_from_target(
ce::AbstractCounterfactualExplanation;
K::Int=50
)
Computes the distance of the counterfactual from a point in the target domain.
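One way to read this penalty is as the average distance between the counterfactual and up to `K` observed points from the target class. The following minimal sketch uses that reading. It assumes the `K` nearest target-class points are used (the package may sample them instead), the Euclidean norm, and one sample per column. The function name is hypothetical.

```julia
using LinearAlgebra, Statistics

# Hypothetical sketch. `X` holds one sample per column, `y` the labels,
# `xcf` the counterfactual, `target` the target label.
function distance_from_target_sketch(xcf::AbstractVector, X::AbstractMatrix,
                                     y::AbstractVector, target; K::Int=50)
    targets = X[:, y .== target]                    # target-class samples only
    isempty(targets) && throw(ArgumentError("no samples in the target class"))
    dists = [norm(xcf .- c) for c in eachcol(targets)]
    k = min(K, length(dists))                       # cannot use more points than exist
    return mean(partialsort(dists, 1:k))            # mean over the k nearest
end

X = [0.0 1.0 5.0 0.0;
     0.0 0.0 5.0 3.0]
y = [1, 1, 0, 1]
distance_from_target_sketch([0.0, 0.0], X, y, 1; K=2)  # mean of 0.0 and 1.0 = 0.5
```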
CounterfactualExplanations.Objectives.model_loss_penalty
Method
function model_loss_penalty(
ce::AbstractCounterfactualExplanation;
agg=mean
)
Additional penalty for ClaPROARGenerator.
CounterfactualExplanations.Objectives.needs_neighbours
Method
needs_neighbours(ce::AbstractCounterfactualExplanation)
Check if a counterfactual explanation needs access to neighbors in the target class.
CounterfactualExplanations.Objectives.needs_neighbours
Method
needs_neighbours(gen::AbstractGenerator)
Check if a generator needs access to neighbors in the target class.
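The two needs_neighbours methods suggest a simple delegation: an explanation defers to its generator, and the generator reports whether any of its penalties needs target-class neighbours. The sketch below shows that delegation with toy types. `ToyGen`, `ToyExplanation`, `penalty_needs_neighbours`, and `needs_neighbours_sketch` are hypothetical. The per-penalty check is a plain identity test standing in for a trait lookup such as PenaltyRequirements.

```julia
# Hypothetical stand-ins; the package's generator and explanation types differ.
struct ToyGen
    penalties::Vector{Function}
end

struct ToyExplanation
    generator::ToyGen
end

# Stand-in penalties: only the first one needs target-class neighbours.
neighbour_penalty(x) = 0.0
plain_penalty(x) = sum(abs2, x)

# Per-penalty check (stand-in for a trait lookup).
penalty_needs_neighbours(f::Function) = f === neighbour_penalty

# A generator needs neighbours if any of its penalties does.
needs_neighbours_sketch(gen::ToyGen) = any(penalty_needs_neighbours, gen.penalties)
# An explanation defers to its generator.
needs_neighbours_sketch(ce::ToyExplanation) = needs_neighbours_sketch(ce.generator)

needs_neighbours_sketch(ToyExplanation(ToyGen([neighbour_penalty, plain_penalty])))  # true
needs_neighbours_sketch(ToyGen([plain_penalty]))                                     # false
```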