{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [ "# Attribution experiments\n", "\n", "1. ferret `Benchmark` on the `cardiffnlp/twitter-xlm-roberta-base-sentiment` classifier.\n", "2. inseq saliency on GPT-2's next-token prediction for a capital-city prompt." ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Needs the ferret-xai package (`pip install ferret-xai`). In the `eva` env, `ferret`\n",
  "# resolved to an unrelated single-file module (site-packages/ferret.py) that has no\n",
  "# `Benchmark`, so this import raised ImportError.\n",
  "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
  "from ferret import Benchmark\n",
  "\n",
  "model_name = \"cardiffnlp/twitter-xlm-roberta-base-sentiment\"\n",
  "# Distinct names so this classifier is not confused with the inseq GPT-2 `model` used later.\n",
  "sentiment_model = AutoModelForSequenceClassification.from_pretrained(model_name)\n",
  "sentiment_tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
  "\n",
  "bench = Benchmark(sentiment_model, sentiment_tokenizer)\n",
  "explanations = bench.explain(\"You look stunning!\", target=1)\n",
  "evaluations = bench.evaluate_explanations(explanations, target=1)\n",
  "\n",
  "bench.show_evaluation_table(evaluations)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import inseq" ] },
 { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", "Unused arguments during attribution: {'n_steps': 500, 'internal_batch_size': 50}\n", "Attributing with saliency...: 100%|██████████| 45/45 [00:00<00:00, 50.01it/s]\n" ] }, { "data": { "text/html": [ "<br/><b>0th instance:</b><br/>\n", "<html>\n", "<div id=\"usieyiccwpusnzjovqwp_viz_container\">\n", " <div id=\"usieyiccwpusnzjovqwp_content\" style=\"padding:15px;border-style:solid;margin:5px;\">\n", " <div id = \"usieyiccwpusnzjovqwp_saliency_plot_container\" class=\"usieyiccwpusnzjovqwp_viz_container\" style=\"display:block\">\n", " \n", "<div id=\"qahftrxckafnafqofpva_saliency_plot\" class=\"qahftrxckafnafqofpva_viz_content\">\n", " <div 
style=\"margin:5px;font-family:sans-serif;font-weight:bold;\">\n", " <span style=\"font-size: 20px;\">Target Saliency Heatmap</span>\n", " <br>\n", " x: Generated tokens, y: Attributed tokens\n", " </div>\n", " \n", "<table border=\"1\" cellpadding=\"5\" cellspacing=\"5\"\n", " style=\"overflow-x:scroll;display:block;\">\n", " <tr><th></th>\n", "<th>ĠTokyo</th></tr><tr><th>When</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.1959595959595959)\">0.036</th></tr><tr><th>Ġmy</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.10136660724896006)\">0.02</th></tr><tr><th>Ġflight</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.17231134878193693)\">0.031</th></tr><tr><th>Ġlanded</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.10136660724896006)\">0.02</th></tr><tr><th>Ġin</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.15654585066349747)\">0.029</th></tr><tr><th>ĠJapan</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.44032481679540497)\">0.081</th></tr><tr><th>,</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.14866310160427795)\">0.028</th></tr><tr><th>ĠI</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.03830461477520289)\">0.008</th></tr><tr><th>Ġconverted</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.06983561101208159)\">0.014</th></tr><tr><th>Ġmy</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.03830461477520289)\">0.008</th></tr><tr><th>Ġcurrency</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.23537334125569415)\">0.043</th></tr><tr><th>Ġand</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.06983561101208159)\">0.013</th></tr><tr><th>Ġslowly</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.06195286195286207)\">0.012</th></tr><tr><th>Ġfell</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.03830461477520289)\">0.008</th></tr><tr><th>Ġasleep</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.09348385818974037)\">0.017</th></tr><tr><th>.</th><th style=\"background:rgba(255.0, 13.0, 87.0, 
0.08560110913052081)\">0.016</th></tr><tr><th>Ġ(</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.06983561101208159)\">0.014</th></tr><tr><th>I</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.030421865715983164)\">0.006</th></tr><tr><th>Ġhad</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.030421865715983164)\">0.006</th></tr><tr><th>Ġa</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.04618736383442265)\">0.009</th></tr><tr><th>Ġterrifying</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.08560110913052081)\">0.016</th></tr><tr><th>Ġdream</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.06195286195286207)\">0.012</th></tr><tr><th>Ġabout</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.03830461477520289)\">0.008</th></tr><tr><th>Ġmy</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.030421865715983164)\">0.007</th></tr><tr><th>Ġgrandmother</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.12501485442661908)\">0.023</th></tr><tr><th>,</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.04618736383442265)\">0.01</th></tr><tr><th>Ġbut</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.030421865715983164)\">0.006</th></tr><tr><th>Ġthat</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.03830461477520289)\">0.008</th></tr><tr><th>âĢ</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.10924935630817992)\">0.02</th></tr><tr><th>Ļ</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.04618736383442265)\">0.009</th></tr><tr><th>s</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.030421865715983164)\">0.006</th></tr><tr><th>Ġa</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.05407011289364243)\">0.01</th></tr><tr><th>Ġstory</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.10136660724896006)\">0.019</th></tr><tr><th>Ġfor</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.03830461477520289)\">0.008</th></tr><tr><th>Ġanother</th><th style=\"background:rgba(255.0, 13.0, 87.0, 
0.030421865715983164)\">0.006</th></tr><tr><th>Ġtime</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.03830461477520289)\">0.008</th></tr><tr><th>).</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.10924935630817992)\">0.02</th></tr><tr><th>ĠI</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.07771836007130124)\">0.015</th></tr><tr><th>Ġwas</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.13289760348583876)\">0.025</th></tr><tr><th>Ġstaying</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.18019409784115661)\">0.033</th></tr><tr><th>Ġin</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.14866310160427795)\">0.028</th></tr><tr><th>Ġthe</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.2117250940780353)\">0.038</th></tr><tr><th>Ġcapital</th><th style=\"background:rgba(255.0, 13.0, 87.0, 1.0)\">0.182</th></tr><tr><th>Ġof</th><th style=\"background:rgba(255.0, 13.0, 87.0, 0.3536145771439889)\">0.065</th></tr><tr><th>ĠTokyo</th><th style=\"background:rgba(0.0, 0.0, 0.0, 0.0)\"></th></tr></table>\n", "</div>\n", "\n", " </div>\n", " </div>\n", "</div>\n", "</html>\n" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
  "model = inseq.load_model(\"gpt2\", \"saliency\")  # alternative method: \"integrated_gradients\"\n",
  "\n",
  "# Probe prompt; `{country}` is the only slot that changes between runs.\n",
  "PROMPT_TEMPLATE = (\n",
  "    \"When my flight landed in {country}, I converted my currency and slowly fell asleep. \"\n",
  "    \"(I had a terrifying dream about my grandmother, but that’s a story for another time). \"\n",
  "    \"I was staying in the capital of\"\n",
  ")\n",
  "\n",
  "\n",
  "def attribute_capital(country: str):\n",
  "    \"\"\"Attribute GPT-2's next-token prediction for PROMPT_TEMPLATE filled in with `country`.\n",
  "\n",
  "    The prompt is rebuilt on every call, so each country gets its own input text.\n",
  "    `n_steps` / `internal_batch_size` are kept for methods such as integrated_gradients;\n",
  "    with \"saliency\" inseq warns that they are unused.\n",
  "    \"\"\"\n",
  "    text = PROMPT_TEMPLATE.format(country=country)\n",
  "    return model.attribute(\n",
  "        text,\n",
  "        generation_args={\"max_new_tokens\": 1},\n",
  "        n_steps=500,\n",
  "        internal_batch_size=50,\n",
  "    )\n",
  "\n",
  "\n",
  "attribute_capital(\"Japan\").show()"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Same probe, different country\n", "\n", "Reassigning `country` alone would not change a prompt that an f-string has already built, so the prompt is rebuilt per call inside `attribute_capital`." ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "attribute_capital(\"China\").show()" ] }
], "metadata": { "kernelspec": { "display_name": "eva", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.0" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }