REVISEGenerator

REVISE is a Latent Space generator introduced by Joshi et al. (2019).

Description

The current consensus in the literature is that Counterfactual Explanations should be realistic: the generated counterfactuals should look like they were generated by the data-generating process (DGP) that governs the problem at hand. With respect to Algorithmic Recourse, it is certainly true that counterfactuals should be realistic in order to be actionable for individuals.[1] To address this need, researchers have come up with various approaches in recent years. Among the most popular approaches is Latent Space Search, which was first proposed in Joshi et al. (2019): instead of traversing the feature space directly, this approach relies on a separate generative model that learns a latent space representation of the DGP. Assuming the generative model is well-specified, access to the learned latent embeddings of the data comes with two advantages:

  1. Since the learned DGP is encoded in the latent space, the generated counterfactuals will respect the learned representation of the data. In practice, this means that counterfactuals will be realistic.
  2. The latent space is typically a compressed (i.e. lower dimensional) version of the feature space. This makes the counterfactual search less costly.

There are also certain disadvantages:

  1. Learning generative models is (typically) an expensive task, which may well outweigh the benefits associated with ultimately traversing a lower dimensional space.
  2. If the generative model is poorly specified, this will affect the quality of the counterfactuals.[2]

Nonetheless, traversing latent embeddings is a powerful idea that may be very useful depending on the specific context. This tutorial introduces the concept and shows how it is implemented in this package.
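To make the idea concrete, the latent space search described above can be written as an optimization problem. The formulation below is a sketch in our own notation, not necessarily the exact objective used in this package: M is the classifier, x the factual and t the target class, while the decoder g, the penalty strength λ and the generic terms yloss and cost are symbols assumed here for illustration.

```latex
\begin{aligned}
z^{*} &= \arg\min_{z} \; \text{yloss}\big(M(g(z)),\, t\big) + \lambda \, \text{cost}\big(x,\, g(z)\big), \\
x^{\prime} &= g(z^{*}).
\end{aligned}
```

The search variable is the latent code z rather than the feature vector itself, so every candidate counterfactual x′ = g(z) lies in the range of the decoder. This is why the quality of the counterfactuals depends so heavily on how well the generative model is specified.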

Usage

The approach can be used in our package as follows:

generator = REVISEGenerator()
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
plot(ce)

Worked 2D Example

Below we load 2D data, train a VAE on it, and plot the original samples against their reconstructions.

counterfactual_data = load_overlapping()
X = counterfactual_data.X
y = counterfactual_data.y
input_dim = size(X, 1)
using CounterfactualExplanations.GenerativeModels: VAE, train!, reconstruct
vae = VAE(input_dim; nll=Flux.Losses.mse, epochs=100, λ=0.01, latent_dim=2, hidden_dim=32)
flux_training_params.verbose = true
train!(vae, X, y)
X̂ = reconstruct(vae, X)[1]
p0 = scatter(X[1, :], X[2, :], color=:blue, label="Original", xlab="x₁", ylab="x₂")
scatter!(X̂[1, :], X̂[2, :], color=:orange, label="Reconstructed", xlab="x₁", ylab="x₂")
p1 = scatter(X[1, :], X̂[1, :], color=:purple, label="", xlab="x₁", ylab="x̂₁")
p2 = scatter(X[2, :], X̂[2, :], color=:purple, label="", xlab="x₂", ylab="x̂₂")
plt2 = plot(p1,p2, layout=(1,2), size=(800, 400))
plot(p0, plt2, layout=(2,1), size=(800, 600))

Next, we train a simple MLP for the classification task. Then we determine a target and factual class for our counterfactual search and select a random factual instance to explain.

M = fit_model(counterfactual_data, :MLP)
target = 2
factual = 1
chosen = rand(findall(predict_label(M, counterfactual_data) .== factual))
x = select_factual(counterfactual_data, chosen)

Finally, we generate and visualize the counterfactual:

# Search:
generator = REVISEGenerator()
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
plot(ce)

3D Example

To illustrate the notion of Latent Space search, let’s look at an example involving 3-dimensional input data, which we can still visualize. The code chunk below loads the data and implements the counterfactual search.

# Data and Classifier:
counterfactual_data = load_blobs(k=3)
X = counterfactual_data.X
ys = counterfactual_data.output_encoder.labels.refs
M = fit_model(counterfactual_data, :MLP)

# Randomly selected factual:
x = select_factual(counterfactual_data,rand(1:size(counterfactual_data.X,2)))
y = predict_label(M, counterfactual_data, x)[1]
target = counterfactual_data.y_levels[counterfactual_data.y_levels .!= y][1]

# Generate recourse (reuses the `generator` defined in the 2D example above):
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)

The figure below demonstrates the idea of searching counterfactuals in a lower-dimensional latent space: on the left, we can see the counterfactual search in the 3-dimensional feature space, while on the right we can see the corresponding search in the latent space.
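Independently of the package, the mechanics of such a search can be reproduced in a few lines of plain Julia. The sketch below is a toy illustration under strong assumptions: a hand-picked linear "decoder" from a 2-dimensional latent space to 3-dimensional features stands in for a trained VAE, a fixed logistic classifier stands in for the MLP, and the gradient is derived by hand. All names (W, b, w, c, decode, prob, search_latent) are hypothetical and not part of the package API.

```julia
# Toy latent space search: 3-D features, 2-D latent space.
# Everything here is hypothetical and only illustrates the idea.

# Linear "decoder" g(z) = W*z + b mapping a latent code to features.
W = [1.0 0.0; 0.0 1.0; 1.0 1.0]
b = [0.0, 0.0, 0.0]
decode(z) = W * z .+ b

# Logistic classifier on the features: probability of the target class.
w = [1.0, 1.0, 1.0]
c = 0.0
σ(a) = 1 / (1 + exp(-a))
prob(x) = σ(sum(w .* x) + c)

# Gradient descent on z for the objective
#   -log(prob(g(z))) + λ * ‖g(z) - x‖²,
# stopping once the target probability reaches the threshold γ.
function search_latent(x, z0; λ=0.001, η=0.1, γ=0.95, max_iter=1000)
    z = copy(z0)
    for _ in 1:max_iter
        x̂ = decode(z)
        p = prob(x̂)
        p >= γ && break
        # Gradient with respect to the decoded features ...
        ∇x = -(1 - p) .* w .+ 2λ .* (x̂ .- x)
        # ... pulled back to the latent space via the chain rule.
        z .-= η .* (W' * ∇x)
    end
    return z, decode(z)
end

x = [-1.0, -1.0, -2.0]   # factual in feature space
z0 = [-1.0, -1.0]        # its latent code, since decode(z0) == x
z_cf, x_cf = search_latent(x, z0)
```

Only the two latent coordinates are updated, yet the returned x_cf is a valid 3-dimensional point that lies on the surface spanned by the decoder. In the package, the decoder is a learned generative model and gradients are obtained automatically rather than by hand.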

MNIST data

Let’s carry the ideas introduced above over to a more complex example. The code below loads MNIST data as well as a pre-trained classifier and generative model for the data.

using CounterfactualExplanations.Models: load_mnist_mlp, load_mnist_ensemble, load_mnist_vae
counterfactual_data = load_mnist()
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
input_dim, n_obs = size(counterfactual_data.X)
M = load_mnist_mlp()
vae = load_mnist_vae()

The F1-score of our pre-trained image classifier on test data is: 0.94

Before continuing, we supply the pre-trained generative model to our data container:

counterfactual_data.generative_model = vae # assign generative model

Now let’s define a factual and target label:

# Randomly selected factual:
Random.seed!(2023)
factual_label = 8
x = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = 3
factual = predict_label(M, counterfactual_data, x)[1]

Using REVISE, we are going to turn a randomly drawn 8 into a 3.

The API call is the same as always:

γ = 0.95
# Define generator:
generator = REVISEGenerator(opt=Flux.Adam(0.5))
# Generate recourse:
ce = generate_counterfactual(x, target, counterfactual_data, M, generator; decision_threshold=γ)

The chart below shows the results:

References

Joshi, Shalmali, Oluwasanmi Koyejo, Warut Vijitbenjaronk, Been Kim, and Joydeep Ghosh. 2019. “Towards Realistic Individual Recourse and Actionable Explanations in Black-Box Decision Making Systems.” https://arxiv.org/abs/1907.09615.

[1] In general, we believe that there may be a trade-off between creating counterfactuals that respect the DGP and counterfactuals that reflect the behaviour of the black-box model in question, both accurately and completely.

[2] We believe that there is another potentially crucial disadvantage of relying on a separate generative model: it reallocates the task of learning realistic explanations for the data from the black-box model to the generative model.