Skip to content

luo3300612/Visualizer

Folders and files

NameName
Last commit message
Last commit date

Latest commit

Jun 9, 2022
d3358cf · Jun 9, 2022

History

15 Commits
Aug 9, 2021
Aug 9, 2021
Aug 9, 2021
Aug 9, 2021
Jun 9, 2022
Jun 9, 2022
Aug 9, 2021
Aug 9, 2021
Aug 9, 2021

Repository files navigation

Visualizer

Visualizer是一个辅助深度学习模型中Attention模块可视化的小工具,主要功能是帮助取出嵌套在模型深处的Attention Map

为什么需要Visualizer?

为了可视化Attention Map,你是否有以下苦恼

  • Return大法好:通过return将嵌套在模型深处的Attention Map一层层地返回回来,然后训练模型的时候又不得不还原
  • 全局大法好:使用全局变量在Attention函数中直接记录Attention Map,结果训练的时候忘改回来导致OOM

不管你有没有,反正我有

咨询了专业人士的意见后,发现pytorch有个hook可以取出中间结果,大概查了一下,发现确实可以取出中间变量,但需要进行如下类似的hook注册

handle = net.conv2.register_forward_hook(hook)

进行这样操作的前提是我们知道要取出来的模块名,但是Transformer类模型一般是这样定义的(以Vit为例)

class VisionTransformer(nn.Module):
    def __init__(self, *args, **kwargs):
        ...
        self.blocks = nn.Sequential(*[Block(...) for i in range(depth)])
        ...

然后每个Block中都有一个Attention

class Block(nn.Module):
    def __init__(self, *args, **kwargs):
        ...
        self.attn = Attention(...)
        ...

如果要使用hooks其中的问题就是

  1. 嵌套太深,模块名不清晰,我们根本不知道我们要取的attention map怎么以model.bla.bla.bla这样一直点出来!
  2. 一般来说,Transformer中attention map每层都有一个,一个个注册实在太麻烦了

所以我就思考并查找能否通过更简洁的方法来得到Attention Map(尤其是Transformer的),而visualizer就是其中的一种,它具有以下特点

  • 精准直接,你可以取出任何变量名的模型中间结果
  • 快捷方便,同时取出Transformer类模型中的所有attention map
  • 非侵入式,你无须修改函数内的任何一行代码
  • 训练-测试一致,可视化完成后,你无须在训练时再将代码改回来

用法

安装

pip install bytecode
python setup.py install

安装完成后,只需要用get_local装饰一下Attention的函数,forward之后就可以拿到函数内与装饰器参数同名的局部变量啦~

Usage1

比如说,我想要函数里的attention_map变量: 在模型文件里,我们这么写

from visualizer import get_local
@get_local('attention_map')
def your_attention_function(*args, **kwargs):
    ...
    attention_map = ... 
    ...
    return ...

然后在可视化代码里,我们这么写

from visualizer import get_local
get_local.activate() # 激活装饰器
from ... import model # 被装饰的模型一定要在装饰器激活之后导入!!

# load model and data
...
out = model(data)

cache = get_local.cache # ->  {'your_attention_function': [attention_map]}

最终就会以字典形式存在get_local.cache里,其中key是你的函数名,value就是一个存储attention_map的列表

Usage2

使用Pytorch时我们往往会将模块定义成一个类,此时也是一样只要装饰类内计算出attention_map的函数即可

from visualizer import get_local

class Attention(nn.Module):
    def __init__(self):
        ...
    
    @get_local('attn_map')
    def forward(self, x):
        ...
        attn_map = ...
        ...
        return ...

其他细节请参考demo.ipynb文件

可视化结果

这里是部分可视化vit_small的结果,全部内容在demo.ipynb文件里

因为普通Vit所有Attention map都是在Attention.forward中计算出来的,所以只要简单地装饰一下这个函数,我们就可以同时取出vit中12层Transformer的所有Attention Map!

一个Head的结果

a head

一层所有Heads的结果

heads

某个grid的Attention Map

grid2grid

注意

  • 想要可视化的变量在函数内部不能被后续的同名变量覆盖了,因为get_local取的是对应名称变量在函数中的最终值
  • 进行可视化时,get_local.activate()一定要在导入模型完成,因为python装饰器是在导入时执行的
  • 训练时你不需要修改/删除任何代码,即不用删掉装饰函数的代码,因为在get_local.activate()没有执行的情况下,attention函数不会被装饰,故没有任何性能损失(同上一点,因为python装饰器是在导入时执行的)

其他

当然,其实get_local本身可以取出任何一个函数中某个局部变量的最终值,所以它应该还有其他更有趣的用途

references

About

assistant tools for attention visualization in deep learning

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published