Performance Benchmarks
In the previous tutorial, we saw how counterfactual explanations can be evaluated. An important follow-up task is to compare the performance of different counterfactual generators. Researchers can use benchmarks to test new ideas they want to implement, and practitioners can use them to find the right counterfactual generator for their specific use case. In this tutorial, we will see how to run benchmarks for counterfactual generators.
Post Hoc Benchmarking
We begin by continuing the discussion from the previous tutorial: suppose you have generated multiple counterfactual explanations for multiple individuals, like below:
# Factual and target:
n_individuals = 5
# Randomly pick individuals that are currently predicted to be in the factual class:
ids = rand(findall(predict_label(M, counterfactual_data) .== factual), n_individuals)
# Retrieve their factual feature values:
xs = select_factual(counterfactual_data, ids)
# Generate five counterfactuals for each individual:
ces = generate_counterfactual(xs, target, counterfactual_data, M, generator; num_counterfactuals=5)
You may be interested in comparing the outcomes across individuals. To benchmark the various counterfactual explanations using default evaluation measures, you can simply proceed as follows:
bmk = benchmark(ces)
Under the hood, the benchmark(counterfactual_explanations::Vector{CounterfactualExplanation}) method uses evaluate(counterfactual_explanations::Vector{CounterfactualExplanation}) to generate a Benchmark object, which contains the evaluation in its most granular form as a DataFrame.
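If you need the granular DataFrame directly, it can presumably also be accessed as a field of the Benchmark object; the field name evaluation below is an assumption about the package internals rather than part of the documented interface:
bmk.evaluation  # granular evaluation data (field name is an assumption)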
Working with Benchmarks
For convenience, the DataFrame containing the evaluation can be returned by simply calling the Benchmark object. By default, the evaluation measures are aggregated across id (in line with the default behaviour of evaluate).
bmk()
15×7 DataFrame
 Row │ sample  variable    value    generator                          model   ⋯
     │ Int64   String      Float64  Symbol                             Symbol  ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │      1  distance    3.17243  GradientBasedGenerator(nothing, …  FluxMod ⋯
   2 │      1  redundancy  0.0      GradientBasedGenerator(nothing, …  FluxMod
   3 │      1  validity    1.0      GradientBasedGenerator(nothing, …  FluxMod
   4 │      2  distance    3.07148  GradientBasedGenerator(nothing, …  FluxMod
   5 │      2  redundancy  0.0      GradientBasedGenerator(nothing, …  FluxMod ⋯
   6 │      2  validity    1.0      GradientBasedGenerator(nothing, …  FluxMod
   7 │      3  distance    3.62159  GradientBasedGenerator(nothing, …  FluxMod
   8 │      3  redundancy  0.0      GradientBasedGenerator(nothing, …  FluxMod
   9 │      3  validity    1.0      GradientBasedGenerator(nothing, …  FluxMod ⋯
  10 │      4  distance    2.62783  GradientBasedGenerator(nothing, …  FluxMod
  11 │      4  redundancy  0.0      GradientBasedGenerator(nothing, …  FluxMod
  12 │      4  validity    1.0      GradientBasedGenerator(nothing, …  FluxMod
  13 │      5  distance    2.91985  GradientBasedGenerator(nothing, …  FluxMod ⋯
  14 │      5  redundancy  0.0      GradientBasedGenerator(nothing, …  FluxMod
  15 │      5  validity    1.0      GradientBasedGenerator(nothing, …  FluxMod
                                                            3 columns omitted
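The aggregation function itself can presumably be swapped out through the same agg keyword that is used below to disable aggregation; treat the following as an assumption about the interface rather than documented behaviour:
using Statistics

# Assumed usage: aggregate each measure with the median instead of the default
bmk(agg=median)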
To retrieve the granular dataset, simply do:
bmk(agg=nothing)
75×8 DataFrame
 Row │ sample  num_counterfactual  variable    value    generator               ⋯
     │ Int64   Int64               String      Float64  Symbol                  ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │      1                   1  distance    3.15903  GradientBasedGenerator  ⋯
   2 │      1                   2  distance    3.16773  GradientBasedGenerator
   3 │      1                   3  distance    3.17011  GradientBasedGenerator
   4 │      1                   4  distance    3.20239  GradientBasedGenerator
   5 │      1                   5  distance    3.16291  GradientBasedGenerator  ⋯
   6 │      1                   1  redundancy  0.0      GradientBasedGenerator
   7 │      1                   2  redundancy  0.0      GradientBasedGenerator
   8 │      1                   3  redundancy  0.0      GradientBasedGenerator
   9 │      1                   4  redundancy  0.0      GradientBasedGenerator  ⋯
  10 │      1                   5  redundancy  0.0      GradientBasedGenerator
  11 │      1                   1  validity    1.0      GradientBasedGenerator
  ⋮  │   ⋮            ⋮               ⋮           ⋮                ⋮            ⋱
  66 │      5                   1  redundancy  0.0      GradientBasedGenerator
  67 │      5                   2  redundancy  0.0      GradientBasedGenerator  ⋯
  68 │      5                   3  redundancy  0.0      GradientBasedGenerator
  69 │      5                   4  redundancy  0.0      GradientBasedGenerator
  70 │      5                   5  redundancy  0.0      GradientBasedGenerator
  71 │      5                   1  validity    1.0      GradientBasedGenerator  ⋯
  72 │      5                   2  validity    1.0      GradientBasedGenerator
  73 │      5                   3  validity    1.0      GradientBasedGenerator
  74 │      5                   4  validity    1.0      GradientBasedGenerator
  75 │      5                   5  validity    1.0      GradientBasedGenerator  ⋯
                                                4 columns and 54 rows omitted
Since benchmarks return a DataFrame object on call, post-processing is straightforward. For example, we could use Tidier.jl:
using Tidier
@chain bmk() begin
@filter(variable == "distance")
@select(sample, variable, value)
end
5×3 DataFrame
 Row │ sample  variable  value
     │ Int64   String    Float64
─────┼───────────────────────────
   1 │      1  distance  3.17243
   2 │      2  distance  3.07148
   3 │      3  distance  3.62159
   4 │      4  distance  2.62783
   5 │      5  distance  2.91985
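If you would rather avoid the extra dependency, the same query can be written with DataFrames.jl directly. This is a minimal equivalent sketch:
using DataFrames

# Keep only the distance rows and the columns of interest:
df = bmk()
select(filter(:variable => ==("distance"), df), [:sample, :variable, :value])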
Metadata for Counterfactual Explanations
Benchmarks always report metadata for each counterfactual explanation, which is automatically inferred by default. The default metadata concerns the explained model and the employed generator. In the current example, we used the same model and generator for each individual:
@chain bmk() begin
@group_by(sample)
@select(sample, model, generator)
@summarize(model=unique(model),generator=unique(generator))
@ungroup
end
5×3 DataFrame
 Row │ sample  model                              generator                     ⋯
     │ Int64   Symbol                             Symbol                        ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │      1  FluxModel(Chain(Dense(2 => 2)), …  GradientBasedGenerator(nothi  ⋯
   2 │      2  FluxModel(Chain(Dense(2 => 2)), …  GradientBasedGenerator(nothi
   3 │      3  FluxModel(Chain(Dense(2 => 2)), …  GradientBasedGenerator(nothi
   4 │      4  FluxModel(Chain(Dense(2 => 2)), …  GradientBasedGenerator(nothi
   5 │      5  FluxModel(Chain(Dense(2 => 2)), …  GradientBasedGenerator(nothi  ⋯
                                                                 1 column omitted
Metadata can also be provided as an optional keyword argument.
meta_data = Dict(
:generator => "Generic",
:model => "MLP",
)
meta_data = [meta_data for i in 1:length(ces)]
bmk = benchmark(ces; meta_data=meta_data)
@chain bmk() begin
@group_by(sample)
@select(sample, model, generator)
@summarize(model=unique(model),generator=unique(generator))
@ungroup
end
5×3 DataFrame
 Row │ sample  model   generator
     │ Int64   String  String
─────┼───────────────────────────
   1 │      1  MLP     Generic
   2 │      2  MLP     Generic
   3 │      3  MLP     Generic
   4 │      4  MLP     Generic
   5 │      5  MLP     Generic
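Because the metadata is supplied as one dictionary per counterfactual explanation, it can also differ across explanations. The following hypothetical variant attaches a distinct label to each explanation:
# Hypothetical: a separate metadata dictionary for every explanation
meta_data = [Dict(:generator => "Generic", :model => "MLP (run $i)") for i in eachindex(ces)]
bmk = benchmark(ces; meta_data=meta_data)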
Ad Hoc Benchmarking
So far we have assumed the following workflow:
- Fit some machine learning model.
- Generate counterfactual explanations for some individual(s) (generate_counterfactual).
- Evaluate and benchmark them (benchmark(ces::Vector{CounterfactualExplanation})).
In many cases, it may be preferable to combine these steps. To this end, we have added support for two scenarios of Ad Hoc Benchmarking.
Pre-trained Models
In the first scenario, it is assumed that the machine learning models have been pre-trained and so the workflow can be summarized as follows:
- Fit some machine learning model(s).
- Generate counterfactual explanations and benchmark them.
We suspect that this is the most common workflow for practitioners who are interested in benchmarking counterfactual explanations for pre-trained machine learning models. Let's go through this workflow using a simple example. We first train some models and store them in a dictionary:
models = Dict(
:MLP => fit_model(counterfactual_data, :MLP),
:Linear => fit_model(counterfactual_data, :Linear),
)
Next, we store the counterfactual generators of interest in a dictionary as well:
generators = Dict(
:Generic => GenericGenerator(),
:Gravitational => GravitationalGenerator(),
:Wachter => WachterGenerator(),
:ClaPROAR => ClaPROARGenerator(),
)
Then we can run a benchmark for individual(s) x, a pre-specified target and counterfactual_data as follows:
bmk = benchmark(x, target, counterfactual_data; models=models, generators=generators)
In this case, metadata is automatically inferred from the dictionaries:
@chain bmk() begin
@filter(variable == "distance")
@select(sample, variable, value, model, generator)
end
8×5 DataFrame
 Row │ sample  variable  value    model   generator
     │ Int64   String    Float64  Symbol  Symbol
─────┼──────────────────────────────────────────────────
   1 │      1  distance  3.23559  Linear  Gravitational
   2 │      1  distance  3.40924  Linear  ClaPROAR
   3 │      1  distance  3.08311  Linear  Generic
   4 │      1  distance  3.1338   Linear  Wachter
   5 │      1  distance  4.44266  MLP     Gravitational
   6 │      1  distance  4.67161  MLP     ClaPROAR
   7 │      1  distance  4.98131  MLP     Generic
   8 │      1  distance  4.32344  MLP     Wachter
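Because model and generator metadata are inferred automatically, comparing generators is now just a matter of post-processing. For example, the following query (a sketch along the lines of the Tidier.jl queries above) reports the smallest distance achieved for each model:
@chain bmk() begin
    @filter(variable == "distance")
    @group_by(model)
    @summarize(min_distance=minimum(value))
    @ungroup
end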
Everything at once
Researchers, in particular, may be interested in combining all steps into one. This is the second scenario of Ad Hoc Benchmarking:
- Fit some machine learning model(s), generate counterfactual explanations and benchmark them.
It involves calling benchmark directly on counterfactual data (the only positional argument):
bmk = benchmark(counterfactual_data)
This will use the default models from standard_models_catalogue and train them on the data. All available generators from generator_catalogue will also be used:
@chain bmk() begin
@filter(variable == "validity")
@select(sample, variable, value, model, generator)
end
165×5 DataFrame
 Row │ sample  variable  value    model   generator
     │ Int64   String    Float64  Symbol  Symbol
─────┼─────────────────────────────────────────────────────
   1 │      1  validity  1.0      Linear  gravitational
   2 │      2  validity  1.0      Linear  gravitational
   3 │      3  validity  1.0      Linear  gravitational
   4 │      4  validity  1.0      Linear  gravitational
   5 │      5  validity  1.0      Linear  gravitational
   6 │      1  validity  1.0      Linear  growing_spheres
   7 │      2  validity  1.0      Linear  growing_spheres
   8 │      3  validity  1.0      Linear  growing_spheres
   9 │      4  validity  1.0      Linear  growing_spheres
  10 │      5  validity  1.0      Linear  growing_spheres
  11 │      1  validity  1.0      Linear  revise
  ⋮  │   ⋮        ⋮        ⋮        ⋮            ⋮
 156 │     11  validity  1.0      MLP     generic
 157 │     12  validity  1.0      MLP     generic
 158 │     13  validity  1.0      MLP     generic
 159 │     14  validity  1.0      MLP     generic
 160 │     15  validity  1.0      MLP     generic
 161 │     11  validity  1.0      MLP     greedy
 162 │     12  validity  1.0      MLP     greedy
 163 │     13  validity  1.0      MLP     greedy
 164 │     14  validity  1.0      MLP     greedy
 165 │     15  validity  1.0      MLP     greedy
                                          144 rows omitted
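To see up front which models and generators this involves, you can inspect the two catalogues directly; this assumes both dictionaries are accessible under the names used above:
# Assumed to be accessible from CounterfactualExplanations:
keys(standard_models_catalogue)
keys(generator_catalogue)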
Optionally, you can instead provide a dictionary of models and generators as before. Each value in the models dictionary should be one of two things (both options are illustrated in the sketch below):
- Either an object M of type AbstractFittedModel that implements the Models.train method.
- Or a DataType that can be called on CounterfactualData to create an object M as in (a).
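For instance, the following sketch re-uses fit_model from the earlier example for option (a) and passes the Linear model type for option (b); the generators dictionary is the one defined above:
models = Dict(
    :MLP => fit_model(counterfactual_data, :MLP),  # (a) a pre-fitted model
    :Linear => Linear,                             # (b) a DataType called on the data
)
bmk = benchmark(counterfactual_data; models=models, generators=generators)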
Multiple Datasets
Benchmarks are run on single instances of type CounterfactualData. This is our design choice for two reasons:
- We want to avoid the loops inside the benchmark method(s) from getting too nested and convoluted.
- While it is straightforward to infer metadata for models and generators, this is not the case for datasets.
Fortunately, it is very easy to run benchmarks for multiple datasets anyway, since Benchmark instances can be concatenated. To see how, let's consider an example involving multiple datasets, models and generators:
# Data:
datasets = Dict(
:moons => load_moons(),
:circles => load_circles(),
)
# Models:
models = Dict(
:MLP => FluxModel,
:Linear => Linear,
)
# Generators:
generators = Dict(
:Generic => GenericGenerator(),
:Greedy => GreedyGenerator(),
)
Then we can simply loop over the datasets and eventually concatenate the results like so:
using CounterfactualExplanations.Evaluation: distance_measures
bmks = []
for (dataname, dataset) in datasets
bmk = benchmark(dataset; models=models, generators=generators, measure=distance_measures, verbose=true)
push!(bmks, bmk)
end
bmk = vcat(bmks[1], bmks[2]; ids=collect(keys(datasets)))
When ids are supplied, a new id column is added to the evaluation data frame that contains unique identifiers for the different benchmarks. The optional idcol_name argument can be used to specify the name for that indicator column (it defaults to "dataset"; an example is given at the end of this tutorial):
@chain bmk() begin
@group_by(dataset, generator)
@filter(model == :MLP)
@filter(variable == "distance_l1")
@summarize(L1_norm=mean(value))
@ungroup
end
4×3 DataFrame
 Row │ dataset  generator  L1_norm
     │ Symbol   Symbol     Float32
─────┼──────────────────────────────
   1 │ circles  Generic    2.71561
   2 │ circles  Greedy     0.596901
   3 │ moons    Generic    1.30436
   4 │ moons    Greedy     0.742734
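As a final note, the idcol_name argument mentioned above could be used as follows; this is a hypothetical call building on the vcat example, and the column name "data" is arbitrary:
# Hypothetical: rename the indicator column from the default "dataset" to "data"
bmk_named = vcat(bmks[1], bmks[2]; ids=collect(keys(datasets)), idcol_name="data")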