.. -*- mode: rst -*-
|pypi_version|_ |pypi_downloads|_
.. |pypi_version| image:: https://img.shields.io/pypi/v/explainable-transformers.svg
.. _pypi_version: https://pypi.python.org/pypi/explainable-transformers/
.. |pypi_downloads| image:: https://pepy.tech/badge/explainable-transformers/month
.. _pypi_downloads: https://pepy.tech/project/explainable-transformers
.. image:: artwork/cover.png
:alt: Vision Transformers explanation
========================
explainable-transformers
========================
Explanation and interpretation techniques for Transformer-based architectures.
------------
Installation
------------
Requirements:
* opencv-python
* numpy
* torch
* tqdm
.. code:: bash
pip install explainable-transformers
--------------
Usage examples
--------------
Please see **notebook/** for complete examples of how to create visual representations of the explanations.
For Vision Transformers, use *VisionTransformerWrapper*, passing in a PyTorch model.
.. code:: python

    import torch
    import torch.nn as nn
    from PIL import Image
    from transformers import ViTModel

    # import the explainer module
    from explainable_transformers.image_explainer import VisionTransformerWrapper

    # define the classification head on top of the pre-trained backbone
    class PreTrainedViT(nn.Module):
        def __init__(self, vit_model, hidden_size, output_dim):
            ...

        def forward(self, x):
            ...

    # load the pre-trained backbone
    pretrained_vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k',
                                                    add_pooling_layer=False,
                                                    output_attentions=True)
    model = PreTrainedViT(pretrained_vit_model, hidden_size=768, output_dim=10)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # create the ViT wrapper and register hooks on the 12 attention layers
    vit_wrapper = VisionTransformerWrapper(model, device, num_attn_layers=12)
    vit_wrapper.register_hook()

    # explain a prediction using .generate_visualization(img)
    image = Image.open('images/dogbird.png')
    processed_image = transform(image)  # preprocessing transform (see sketch below)
    cat_exp, _ = vit_wrapper.generate_visualization(processed_image)
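The snippet above leaves the classification head and the preprocessing ``transform`` undefined. A minimal sketch of both, assuming a single linear head over the ``[CLS]`` token and the standard 224x224 ViT normalization (the head architecture and transform values are illustrative, not part of the library):

.. code:: python

    import torch.nn as nn
    from torchvision import transforms

    class PreTrainedViT(nn.Module):
        """Single linear classification layer on top of the [CLS] token."""
        def __init__(self, vit_model, hidden_size, output_dim):
            super().__init__()
            self.vit_model = vit_model
            self.classifier = nn.Linear(hidden_size, output_dim)

        def forward(self, x):
            # last_hidden_state has shape (batch, num_patches + 1, hidden_size)
            hidden_states = self.vit_model(x).last_hidden_state
            return self.classifier(hidden_states[:, 0, :])  # logits from the [CLS] token

    # a typical preprocessing pipeline for this checkpoint
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])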
For text Transformers, we currently need to know how the attention component is organized inside the model, i.e., the module path to each layer's attention dropout.
.. code:: python

    from transformers import BertTokenizer, BertForSequenceClassification

    from explainable_transformers.utils import *
    from explainable_transformers import NLPTransformerWrapper

    # For text, we provide the NLP wrapper. The attention component is located
    # via a module-path pattern in which '#' is replaced by the layer index:
    #   - BERT or RoBERTa: '.encoder.layer.#.attention.self.dropout'
    #   - XLNet:           '.layer.#.rel_attn.dropout'
    NUM_LAYERS = 12
    nlp_wrapper = NLPTransformerWrapper(model, device, NUM_LAYERS, 'bert', 'classifier',
                                        '.encoder.layer.#.attention.self.dropout')
    nlp_wrapper.register_hook()

    # input_ids and attention_mask come from the tokenizer (see sketch below);
    # class_index selects the class to be explained
    explanation = nlp_wrapper.generate_explanation(input_ids, attention_mask,
                                                   class_index=true_class,
                                                   start_layer=NUM_LAYERS - 1)
    explanation = explanation.detach().cpu().numpy()
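The example above assumes that ``model``, ``device``, ``input_ids``, ``attention_mask``, and ``true_class`` have already been prepared. A minimal sketch of that setup, assuming a ``bert-base-uncased`` sequence classifier (the checkpoint name, example sentence, and class index are placeholders, not part of the library):

.. code:: python

    import torch
    from transformers import BertTokenizer, BertForSequenceClassification

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    model.to(device).eval()

    # tokenize a single sentence; the wrapper consumes input_ids and attention_mask
    encoded = tokenizer('The movie was surprisingly good.', return_tensors='pt',
                        padding=True, truncation=True)
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    # index of the class to explain (e.g., the ground-truth or predicted label)
    true_class = 1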
--------
Citation
--------
Please cite the respective authors if you use any of these techniques.
Currently, we provide a **PyTorch** implementation of the following approach:
*Transformer Interpretability Beyond Attention Visualization* (`paper <https://arxiv.org/abs/2012.09838>`_):
1. Text Transformers: BERT, RoBERTa, and XLNet
2. Vision Transformers
.. code:: bibtex
@InProceedings{Chefer_2021_CVPR,
author = {Chefer, Hila and Gur, Shir and Wolf, Lior},
title = {Transformer Interpretability Beyond Attention Visualization},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2021},
pages = {782-791}
}
-------
License
-------
*explainable-transformers* is released under the 3-clause BSD license and builds on other open-source implementations, notably `Chefer's <https://github.com/hila-chefer/Transformer-Explainability>`_.
We also use `nlp_understanding <https://github.com/ENSAE-CKW/nlp_understanding>`_ for generating the heatmaps.
E-mail me (wilson_jr at outlook dot com) if you would like to contribute.