How to add Custom Generators
As this short tutorial shows, building custom counterfactual generators is straightforward. We hope that this will facilitate contributions from the community.
Generic generator with dropout
To illustrate how custom generators can be implemented we will consider a simple example of a generator that extends the functionality of our GenericGenerator. We have noted elsewhere that the effectiveness of counterfactual explanations depends to some degree on the quality of the fitted model. Another, perhaps trivial, thing to note is that counterfactual explanations are not unique: there are potentially many valid counterfactual paths. One interesting (or silly) idea following these two observations might be to introduce some form of regularization in the counterfactual search. For example, we could use dropout to randomly switch features on and off in each iteration. Without dwelling further on the usefulness of this idea, let us see how it can be implemented.
The first code chunk below implements two important steps: 1) create an abstract subtype of AbstractGradientBasedGenerator and 2) create a constructor similar to that of the GenericGenerator, but with one additional field for the probability of dropout.
using CounterfactualExplanations
using Flux

# Abstract subtype:
abstract type AbstractDropoutGenerator <: AbstractGradientBasedGenerator end
# Constructor:
struct DropoutGenerator <: AbstractDropoutGenerator
    loss::Function # loss function
    penalty::Function # penalty function
    λ::AbstractFloat # strength of penalty
    latent_space::Bool # whether to search in a latent space
    opt::Any # optimizer
    generative_model_params::NamedTuple # parameters of the generative model
    p_dropout::AbstractFloat # dropout rate
end
# Instantiate:
generator = DropoutGenerator(
    Flux.logitbinarycrossentropy,
    CounterfactualExplanations.Objectives.distance_l1,
    0.1,
    false,
    Flux.Optimise.Descent(0.1),
    (;),
    0.5,
)
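The reason for introducing an abstract subtype first is Julia's multiple dispatch: a method written for the abstract dropout type takes precedence over the generic one for every concrete generator that subtypes it. The following is a minimal, package-free sketch of that mechanism. All type and function names in it (AbstractToyGenerator, ToyDropoutGenerator, describe, and so on) are hypothetical stand-ins and not part of CounterfactualExplanations.jl.

```julia
# Hypothetical stand-ins for a generator type hierarchy (illustrative names only).
abstract type AbstractToyGenerator end
abstract type AbstractToyDropoutGenerator <: AbstractToyGenerator end

struct ToyGenerator <: AbstractToyGenerator end

struct ToyDropoutGenerator <: AbstractToyDropoutGenerator
    p_dropout::Float64 # dropout rate
end

# Fallback method: applies to any generator.
describe(::AbstractToyGenerator) = "plain gradient step"

# More specific method: dispatch selects this one for dropout generators.
describe(g::AbstractToyDropoutGenerator) =
    "gradient step with dropout (p = $(g.p_dropout))"

plain = describe(ToyGenerator())
dropped = describe(ToyDropoutGenerator(0.5))
```

Because the specialised method is attached to the abstract type rather than to the concrete struct, any further generator that subtypes it would inherit the same behaviour without additional code.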
Next, we define how feature perturbations are generated for our dropout generator: in particular, we extend the relevant function with a method that implements the dropout logic.
using CounterfactualExplanations.Generators
using StatsBase
function Generators.generate_perturbations(
    generator::AbstractDropoutGenerator,
    ce::CounterfactualExplanation
)
    s′ = deepcopy(ce.s′)
    new_s′ = Generators.propose_state(generator, ce)
    Δs′ = new_s′ - s′ # gradient step
    # Dropout: set a randomly chosen share of the perturbations to zero
    set_to_zero = sample(
        1:length(Δs′),
        round(Int, generator.p_dropout * length(Δs′)),
        replace=false
    )
    Δs′[set_to_zero] .= 0
    return Δs′
end
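To see the dropout step in isolation, the sketch below applies the same masking logic to a plain vector, without any dependency on the package. The helper dropout_step is hypothetical, and it uses randperm from the Random standard library as a stand-in for StatsBase's sample with replace=false, so that the example runs without extra packages.

```julia
using Random

# Hypothetical helper (not part of CounterfactualExplanations.jl):
# zero out a share `p` of the entries of a proposed step `Δ`.
function dropout_step(rng::AbstractRNG, Δ::AbstractVector, p::Real)
    Δ = copy(Δ)                               # leave the caller's vector untouched
    n_drop = round(Int, p * length(Δ))        # number of entries to switch off
    idx = randperm(rng, length(Δ))[1:n_drop]  # distinct indices, drawn without replacement
    Δ[idx] .= 0
    return Δ
end

rng = MersenneTwister(42)
Δ = [0.3, -0.1, 0.7, 0.2]
Δ_dropped = dropout_step(rng, Δ, 0.5)
n_zero = count(iszero, Δ_dropped) # 2 of the 4 entries are zeroed
```

Which entries are zeroed depends on the random number generator, but their number is fixed by the dropout rate. Note that round uses Julia's default rounding mode, in which ties go to the nearest even integer, so for example a rate of 0.5 applied to five features drops two of them rather than three.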
Finally, we proceed to generate counterfactuals in the same way we always do:
# Data and Classifier (`counterfactual_data` is assumed to have been loaded beforehand):
M = fit_model(counterfactual_data, :DeepEnsemble)
# Factual and Target:
yhat = predict_label(M, counterfactual_data)
target = 2 # target label
candidates = findall(vec(yhat) .!= target)
chosen = rand(candidates)
x = select_factual(counterfactual_data, chosen)
# Counterfactual search:
ce = generate_counterfactual(
    x, target, counterfactual_data, M, generator;
    num_counterfactuals=5
)
plot(ce)