This repository contains the code for the paper: Start Smart: Leveraging Gradients For Enhancing Mask-based XAI Methods
Authors: Buelent Uendes, Shujian Yu, and Mark Hoogendoorn
Mask-based explanation methods offer a powerful framework for interpreting deep learning model predictions across diverse data modalities, such as images and time series. The central idea is to identify an instance-dependent mask that minimizes the performance drop of the resulting masked input. Different objectives for learning such masks have been proposed, all of which, in our view, can be unified under an information-theoretic framework that balances the performance degradation of the masked input against the complexity of the resulting masked representation. Typically, these methods initialize the masks either uniformly or as all-ones. In this paper, we argue that an effective mask initialization strategy is as important as the development of novel learning objectives, particularly given the significant computational cost of existing mask-based explanation methods. To this end, we introduce a new gradient-based initialization technique called StartGrad, which is the first initialization method specifically designed for mask-based post-hoc explainability methods. Compared to commonly used strategies, StartGrad is provably superior at initialization in striking the aforementioned trade-off. Despite its simplicity, our experiments demonstrate that StartGrad enhances the optimization process of various state-of-the-art mask-based explanation methods by reaching target metrics faster and, in some cases, boosting their overall performance.
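To make the trade-off described above concrete, here is a toy, pure-Python sketch. Everything in it is an assumption for illustration: a linear model f(x) = w · x, masked-out entries replaced by 0, the squared output change as the performance-degradation term, an L1 norm as the complexity term, and a hypothetical `gradient_init` that min-max scales |gradient × input|. It is not the StartGrad rule from the paper or from src/; it only shows how a gradient-informed starting mask can score differently on such an objective than all-ones or constant initializations.

```python
from typing import List

Vector = List[float]


def model(w: Vector, x: Vector) -> float:
    """Toy linear 'classifier' score f(x) = sum_i w_i * x_i (illustrative assumption)."""
    return sum(wi * xi for wi, xi in zip(w, x))


def mask_objective(mask: Vector, w: Vector, x: Vector, lam: float = 0.1) -> float:
    """Squared output change of the masked input plus an L1 penalty on the mask."""
    # Masked-out entries are replaced by 0 (a simplifying baseline choice).
    masked_x = [mi * xi for mi, xi in zip(mask, x)]
    distortion = (model(w, x) - model(w, masked_x)) ** 2
    complexity = sum(abs(mi) for mi in mask)
    return distortion + lam * complexity


def gradient_init(w: Vector, x: Vector) -> Vector:
    """Hypothetical gradient-based init: min-max scaled |gradient * input|."""
    # For the linear toy model, the gradient of f with respect to x is simply w.
    scores = [abs(wi * xi) for wi, xi in zip(w, x)]
    lo, hi = min(scores), max(scores)
    if hi == lo:
        # Degenerate case: no ranking information, fall back to an all-ones mask.
        return [1.0] * len(scores)
    return [(s - lo) / (hi - lo) for s in scores]


if __name__ == "__main__":
    w = [2.0, 0.0, -1.0, 0.5]
    x = [1.0, 3.0, 1.0, 2.0]
    inits = {
        "all-ones": [1.0] * len(x),
        "constant 0.5": [0.5] * len(x),
        "gradient": gradient_init(w, x),
    }
    for name, m in inits.items():
        print(f"{name:>12}: mask={m}, objective={mask_objective(m, w, x):.2f}")
```

On this hand-picked input, the gradient-based mask leaves the model output unchanged while being sparser than the all-ones mask, so its objective is lower (0.20 vs. 0.40); the constant 0.5 mask scores 1.20 because it also changes the output. This is a single toy case, not evidence for the paper's claims.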
- Please first create and activate a virtualenv via:
virtualenv <NAME_OF_YOUR_VIRTUALENV>
source <NAME_OF_YOUR_VIRTUALENV>/bin/activate
- Install pytorch_wavelets via:
cd ./src
git clone https://github.com/fbcotter/pytorch_wavelets
cd pytorch_wavelets
pip install .
- Now install the required packages and the project itself via:
cd ../../
pip3 install -r requirements.txt
pip3 install -e .
- Lastly, install pyShearLab via:
pip3 install https://github.com/stefanloock/pyshearlab/archive/master.zip

The project was run and tested on both macOS Sonoma and Ubuntu Linux with Python 3.9.
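After installation, a quick standard-library check can confirm that the key packages resolve in the active virtualenv. The import names below (`torch`, `pytorch_wavelets`, `pyshearlab`) are assumptions about how these packages are imported; adjust the tuple if your install differs.

```python
import importlib.util
import sys
from typing import Dict, Iterable

# Assumed top-level import names of the dependencies installed above.
REQUIRED_MODULES = ("torch", "pytorch_wavelets", "pyshearlab")


def check_modules(names: Iterable[str]) -> Dict[str, bool]:
    """Map each top-level module name to whether it can be found without importing it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    if sys.version_info[:2] != (3, 9):
        print(f"Note: the project was tested with Python 3.9, running {sys.version.split()[0]}")
    status = check_modules(REQUIRED_MODULES)
    missing = [name for name, ok in status.items() if not ok]
    print("All modules found." if not missing else f"Missing modules: {missing}")
```

`find_spec` only locates the module, so the check is cheap and does not trigger heavy imports such as initializing PyTorch.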
The repository is structured as follows:
- data/ contains the source code as well as the datasets used in this study.
- utils/ contains helper functions for the project.
- src/ contains the main source code for both the vision and time-series experiments.
- config_environment/ contains configuration files for the experiments.
- experiments/ contains the experiment scripts and bash scripts.
- pics/ contains images used in the documentation.
This project uses W&B for experiment tracking and model management.
Experiment configurations are in the config_environment folder. The config files correspond to the default settings used in the paper. Please see the paper for more information on the hyperparameters.
- For ImageNet experiments: download the validation set and place it in data/ImageNet/validation_set/
- For quick testing: use the custom images provided in the repository.
All implementations are in src/vision/:
- PixelMask (pixel_RDE.py)
- WaveletX (waveletX.py)
- ShearletX (shearletX.py)
- Gradient-based methods (saliency_methods.py): Integrated Gradients, SmoothGrad, Grad-CAM
- Configuration files are in config_environment/vision/<METHOD>/
  - Note: StartGrad initialization is labeled as 'saliency' in the configs.
  - Example: ShearletX with StartGrad uses hparams_shearletX_saliency.yaml
- Example command:
cd ./experiments/vision
python3 main_vision.py --method shearletX,shearletX_saliency,shearletX_uniform --folder Custom --input kobe.jpg --iterations 5 --pretrained_model resnet18 --seed 123

The command above stores a visual comparison between the original image and the masked image in figures/shearletX/<seed>/<model>/Comparison_explainers_<method>_<seed>.png
For detailed parameter descriptions, see main_vision.py.
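The naming conventions above can be captured in a small path helper. This is a hypothetical sketch, not code from the repository: it assumes that a method string such as `shearletX_saliency` splits into a base method and an initialization suffix, that the `<METHOD>` config folder is named after the base method, that config files are named `hparams_<method>.yaml`, and that `<method>` in the figure name is the same method string. Check main_vision.py for the actual rules.

```python
from pathlib import Path
from typing import Tuple

# Assumed initialization suffixes; 'saliency' is the StartGrad label in the configs.
KNOWN_INITS = ("saliency", "uniform")


def split_method(method: str) -> Tuple[str, str]:
    """Split e.g. 'shearletX_saliency' into ('shearletX', 'saliency')."""
    base, sep, suffix = method.rpartition("_")
    if sep and suffix in KNOWN_INITS:
        return base, suffix
    # No recognized suffix: treat the whole string as the base method.
    return method, ""


def config_path(method: str, root: Path = Path("config_environment/vision")) -> Path:
    # Hypothetical rule inferred from the example hparams_shearletX_saliency.yaml.
    base, _ = split_method(method)
    return root / base / f"hparams_{method}.yaml"


def comparison_figure(method: str, seed: int, model: str, root: Path = Path("figures")) -> Path:
    # Fills the documented template
    # figures/shearletX/<seed>/<model>/Comparison_explainers_<method>_<seed>.png
    base, _ = split_method(method)
    return root / base / str(seed) / model / f"Comparison_explainers_{method}_{seed}.png"


if __name__ == "__main__":
    print(config_path("shearletX_saliency"))
    print(comparison_figure("shearletX", 123, "resnet18"))
```

Using `rpartition` with an explicit suffix list keeps base names that themselves contain underscores intact.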
- Results & Figures
The results folder contains the result files, stored as CSV files, that can be used to generate the plots in the paper. Important: some of the CSV files in the resnet18 folders are zipped, so they need to be unzipped first.
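A minimal standard-library sketch for extracting those archives in place. It assumes the results folder sits at `results/` in the repository root and that every `.zip` below it should be extracted next to itself; the function name is hypothetical.

```python
import zipfile
from pathlib import Path
from typing import List


def unzip_results(root: Path = Path("results")) -> List[Path]:
    """Extract every .zip archive below `root` into the archive's own folder.

    Returns the paths of the extracted members. Only use this on archives
    you trust, as the zipfile documentation advises for extractall().
    """
    extracted: List[Path] = []
    # Collect archives first so newly extracted files do not affect the scan.
    for archive in sorted(root.rglob("*.zip")):
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(archive.parent)
            extracted.extend(archive.parent / name for name in zf.namelist())
    return extracted


if __name__ == "__main__":
    for path in unzip_results():
        print(path)
```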
To generate the main figures, please have a look at the bash scripts in the experiments/vision folder.
Two synthetic datasets are implemented: the state dataset and the switch-feature dataset. The scripts to generate them can be found in utils/time_series/.
In this repo, we use the ExtremalMask method introduced by Enguehard (2023) as the time-series mask-based explainer. The implementation can be found in src/time_series/timeseries_mask_explainer.py.
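The time-series example command passes --mode preservation_game. As a rough illustration of what preservation and deletion games mean for a mask-based explainer, here is a toy sketch with a fixed baseline. The function names, the fixed zero baseline, the L1 penalty, and the negative sign used for the deletion game are assumptions; ExtremalMask itself learns the replacement values instead of fixing them, so see src/time_series/timeseries_mask_explainer.py for the actual objective.

```python
from typing import Callable, List

Vector = List[float]
Model = Callable[[Vector], float]


def perturb(x: Vector, mask: Vector, baseline: Vector, mode: str) -> Vector:
    """Blend the input and a baseline according to the mask.

    preservation: mask=1 keeps x_t, mask=0 replaces it by the baseline.
    deletion:     mask=1 replaces x_t by the baseline, mask=0 keeps it.
    """
    if mode == "preservation":
        return [m * xt + (1 - m) * bt for m, xt, bt in zip(mask, x, baseline)]
    if mode == "deletion":
        return [(1 - m) * xt + m * bt for m, xt, bt in zip(mask, x, baseline)]
    raise ValueError(f"unknown mode: {mode!r}")


def game_loss(f: Model, x: Vector, mask: Vector, baseline: Vector, mode: str, lam: float = 0.1) -> float:
    # Preservation: a small mask should keep the output close to f(x).
    # Deletion (assumed sign convention): a small mask should move the output
    # as far as possible from f(x), so the squared change is subtracted.
    change = (f(x) - f(perturb(x, mask, baseline, mode))) ** 2
    sign = 1.0 if mode == "preservation" else -1.0
    return sign * change + lam * sum(abs(m) for m in mask)


if __name__ == "__main__":
    f = lambda series: series[2]  # toy model that only looks at time step 2
    x = [0.0, 0.0, 5.0, 0.0]
    mask = [0.0, 0.0, 1.0, 0.0]
    baseline = [0.0] * len(x)
    for mode in ("preservation", "deletion"):
        print(mode, game_loss(f, x, mask, baseline, mode))
```

Selecting only the time step the toy model uses gives a small loss in the preservation game (the output is kept) and a strongly negative loss in the deletion game (removing that step changes the output the most).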
Two classifiers are implemented, an LSTM and a GRU architecture. The corresponding code can be found in src/time_series/XAI_classifier.py.
- Configuration files are in config_environment/time_series/<OBJECTIVE_FORMULATION>/
  - Note: StartGrad initialization is labeled as 'gradient' in the configs.
  - Example: hparams_extrema_gradient.yaml
- Example command:
cd ./experiments/time_series
python3 main_time_series.py --iterations 500 --epochs 50 --mode preservation_game --model_type GRU --dataset state --plot_average

If the dataset has not yet been generated for the given seed, the script first creates it before training the time-series classifier and fitting the mask-based XAI method.
- Results
  - Performance results are saved to results/time_series/extremal/<seed>/<model_type>/<mode>/<fold_number>
  - Average performance plots are generated and saved to figures/time_series/<dataset>/extremal/<seed_number>/<model_type>/<mode>
For detailed parameter descriptions, see main_time_series.py.
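The two output templates above can be filled in with a small helper, for example to collect all fold directories of a run. This is a hypothetical sketch: it assumes the output roots are `results/` and `figures/` relative to the working directory, that seeds and folds are integers, and that fold directories are named by their fold number.

```python
from pathlib import Path
from typing import List


def results_dir(seed: int, model_type: str, mode: str, fold: int, root: Path = Path("results")) -> Path:
    # results/time_series/extremal/<seed>/<model_type>/<mode>/<fold_number>
    return root / "time_series" / "extremal" / str(seed) / model_type / mode / str(fold)


def figures_dir(dataset: str, seed: int, model_type: str, mode: str, root: Path = Path("figures")) -> Path:
    # figures/time_series/<dataset>/extremal/<seed_number>/<model_type>/<mode>
    return root / "time_series" / dataset / "extremal" / str(seed) / model_type / mode


def existing_folds(seed: int, model_type: str, mode: str, root: Path = Path("results")) -> List[Path]:
    """Return the existing fold directories of one run, sorted by fold number."""
    mode_dir = results_dir(seed, model_type, mode, 0, root).parent
    if not mode_dir.is_dir():
        return []
    folds = [p for p in mode_dir.iterdir() if p.is_dir() and p.name.isdigit()]
    return sorted(folds, key=lambda p: int(p.name))


if __name__ == "__main__":
    print(figures_dir("state", 123, "GRU", "preservation_game"))
    print(existing_folds(123, "GRU", "preservation_game"))
```

Sorting by the integer fold number keeps fold 10 after fold 2, which plain string sorting would not.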
If you found this work useful in your research, please consider citing:
@inproceedings{buendes2025startgrad,
title={Start Smart: Leveraging Gradients For Enhancing Mask-based XAI Methods},
author={Uendes, Buelent and Yu, Shujian and Hoogendoorn, Mark},
booktitle={Proceedings of the 13th International Conference on Learning Representations},
year={2025}
}

This work is funded by Stress in Action. The research project Stress in Action is financially supported by the Dutch Research Council and the Dutch Ministry of Education, Culture and Science (NWO gravitation grant number 024.005.010).
For the vision experiments, part of the code relies on the ShearletX repository accompanying the paper Explaining Image Classifiers with Multiscale Directional Image Representation. For the time-series experiments, we relied on the time interpret repository for the implementation of the ExtremalMask model and the generation of the synthetic datasets.
