Commit

visualization update
jeonsworld committed Nov 10, 2020
1 parent 4ca4bb6 commit 1b1eecc
Showing 4 changed files with 410 additions and 19 deletions.
9 changes: 9 additions & 0 deletions README.md
100644 → 100755
@@ -80,6 +80,15 @@ To verify that the converted model weight is correct, we simply compare it with
 | imagenet21k | ViT-B_16 | CIFAR-100 | 1000/100 | 0.9115 | 0.9216 |
 
+
+## Visualization
+ViT consists of a standard Transformer encoder, and each encoder block consists of Self-Attention and MLP modules.
+The attention map for the input image can be visualized from the self-attention scores.
+
+Please refer to [visualize_attention_map.ipynb](./visualize_attention_map.ipynb).
+
+![fig3](./img/figure3.png)
+
 
 ## Reference
 * [Google ViT](https://github.com/google-research/vision_transformer)
 * [Pytorch Image Models(timm)](https://github.com/rwightman/pytorch-image-models)
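The Visualization section above defers to the notebook for details. Below is a rough usage sketch, not part of this commit, of how the `vis` flag added in `models/modeling.py` later in this diff might be used to pull out the per-layer attention weights; `CONFIGS`, the checkpoint file name, the example image path, and the normalization constants are assumptions rather than values taken from this commit.

```python
# Hedged usage sketch (not part of this commit): build the model with vis=True so
# that forward() also returns the per-layer self-attention weights.
# CONFIGS, the checkpoint file name, and the normalization values are assumptions
# taken from elsewhere in this repository, not from this diff.
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from models.modeling import CONFIGS, VisionTransformer

config = CONFIGS["ViT-B_16"]
model = VisionTransformer(config, img_size=224, num_classes=21843, vis=True)
model.load_from(np.load("ViT-B_16.npz"))  # converted .npz checkpoint (assumed path)
model.eval()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
x = transform(Image.open("example.jpg").convert("RGB")).unsqueeze(0)  # (1, 3, 224, 224)

with torch.no_grad():
    logits, attn_weights = model(x)

# attn_weights is a list with one tensor per encoder layer, each of shape
# (batch, num_heads, num_tokens, num_tokens), where num_tokens = 196 patches + 1 [CLS].
```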
Binary file added img/figure3.png
44 changes: 25 additions & 19 deletions models/modeling.py
100644 → 100755
@@ -48,8 +48,9 @@ def swish(x):
 
 
 class Attention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, vis):
         super(Attention, self).__init__()
+        self.vis = vis
         self.num_attention_heads = config.transformer["num_heads"]
         self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
@@ -81,6 +82,7 @@ def forward(self, hidden_states):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
         attention_probs = self.softmax(attention_scores)
+        weights = attention_probs if self.vis else None
         attention_probs = self.attn_dropout(attention_probs)
 
         context_layer = torch.matmul(attention_probs, value_layer)
@@ -89,7 +91,7 @@ def forward(self, hidden_states):
         context_layer = context_layer.view(*new_context_layer_shape)
         attention_output = self.out(context_layer)
         attention_output = self.proj_dropout(attention_output)
-        return attention_output
+        return attention_output, weights
 
 
 class Mlp(nn.Module):
@@ -150,25 +152,25 @@ def forward(self, x):
 
 
 class Block(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, vis):
         super(Block, self).__init__()
         self.hidden_size = config.hidden_size
         self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
         self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
         self.ffn = Mlp(config)
-        self.attn = Attention(config)
+        self.attn = Attention(config, vis)
 
     def forward(self, x):
         h = x
         x = self.attention_norm(x)
-        x = self.attn(x)
+        x, weights = self.attn(x)
         x = x + h
 
         h = x
         x = self.ffn_norm(x)
         x = self.ffn(x)
         x = x + h
-        return x
+        return x, weights
 
     def load_from(self, weights, n_block):
         ROOT = f"Transformer/encoderblock_{n_block}"
@@ -209,53 +211,57 @@ def load_from(self, weights, n_block):
 
 
 class Encoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, vis):
         super(Encoder, self).__init__()
+        self.vis = vis
         self.layer = nn.ModuleList()
         self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
         for _ in range(config.transformer["num_layers"]):
-            layer = Block(config)
+            layer = Block(config, vis)
             self.layer.append(copy.deepcopy(layer))
 
     def forward(self, hidden_states):
+        attn_weights = []
         for layer_block in self.layer:
-            hidden_states = layer_block(hidden_states)
+            hidden_states, weights = layer_block(hidden_states)
+            if self.vis:
+                attn_weights.append(weights)
         encoded = self.encoder_norm(hidden_states)
-        return encoded
+        return encoded, attn_weights
 
 
 class Transformer(nn.Module):
-    def __init__(self, config, img_size):
+    def __init__(self, config, img_size, vis):
         super(Transformer, self).__init__()
         self.embeddings = Embeddings(config, img_size=img_size)
-        self.encoder = Encoder(config)
+        self.encoder = Encoder(config, vis)
 
     def forward(self, input_ids):
         embedding_output = self.embeddings(input_ids)
-        encoded = self.encoder(embedding_output)
-        return encoded
+        encoded, attn_weights = self.encoder(embedding_output)
+        return encoded, attn_weights
 
 
 class VisionTransformer(nn.Module):
-    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False):
+    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
         super(VisionTransformer, self).__init__()
         self.num_classes = num_classes
         self.zero_head = zero_head
         self.classifier = config.classifier
 
-        self.transformer = Transformer(config, img_size)
+        self.transformer = Transformer(config, img_size, vis)
         self.head = Linear(config.hidden_size, num_classes)
 
     def forward(self, x, labels=None):
-        x = self.transformer(x)[:, 0]
-        logits = self.head(x)
+        x, attn_weights = self.transformer(x)
+        logits = self.head(x[:, 0])
 
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
             return loss
         else:
-            return logits
+            return logits, attn_weights
 
     def load_from(self, weights):
         with torch.no_grad():
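One practical consequence of the `VisionTransformer.forward` change above, sketched briefly for orientation: with `labels` the model still returns a scalar loss, but without `labels` it now returns a tuple, so existing inference code needs to unpack it.

```python
# Sketch of the changed return convention (based on the diff above):
loss = model(x, labels=y)          # training path: scalar cross-entropy loss, as before
logits, attn_weights = model(x)    # inference path: now a tuple; attn_weights is an
                                   # empty list unless the model was built with vis=True
```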
376 changes: 376 additions & 0 deletions visualize_attention_map.ipynb

Large diffs are not rendered by default.
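The notebook diff is not rendered here, so as orientation only: one common way to collapse the per-layer attention weights into a single map is attention rollout (Abnar & Zuidema, 2020), which averages the heads, adds the identity to account for the residual connections, and multiplies the per-layer matrices. The sketch below follows that recipe and may not match what visualize_attention_map.ipynb actually does.

```python
# Attention-rollout-style reduction (a sketch of one common approach, not
# necessarily what the notebook does). attn_weights is the list returned by the
# model when vis=True: one (batch, heads, tokens, tokens) tensor per layer.
import torch

def attention_rollout(attn_weights):
    rollout = None
    for layer_attn in attn_weights:
        attn = layer_attn.mean(dim=1)                 # average heads -> (batch, tokens, tokens)
        attn = attn + torch.eye(attn.size(-1))        # identity for the residual connection
        attn = attn / attn.sum(dim=-1, keepdim=True)  # re-normalize rows
        rollout = attn if rollout is None else torch.matmul(attn, rollout)
    return rollout                                    # (batch, tokens, tokens)

# For ViT-B_16 at 224x224 there are 196 patch tokens plus one [CLS] token, so the
# [CLS] row of the rollout (minus the [CLS] column) reshapes to a 14x14 grid:
# rollout = attention_rollout(attn_weights)
# mask = rollout[0, 0, 1:].reshape(14, 14)
```

The resulting 14×14 mask can then be resized to the input resolution and overlaid on the image.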
