DeBERTa 论文+代码笔记

Paper:[2006.03654] DeBERTa: Decoding-enhanced BERT with Disentangled Attention

Code:microsoft/DeBERTa: The implementation of DeBERTa

核心思想:增加位置-内容与内容-位置的自注意力增强位置和内容之间的依赖,用 EMD 缓解 BERT 预训练和精调因为 MASK 造成的不匹配问题。

What

动机和核心问题

  • 一组词的 Attention 不光取决于内容,还和它们的相对位置有关。比如 deep learning 挨在一起时的依赖关系比不在一起时要强(其实中文就更加如此了,所以感觉对中文可能更加有效)。
  • 解决预训练和精调的不匹配问题(精调时没有 MASK)。

针对以上问题有针对性地提出解决方案:

  • Disentangled Attention:增加计算 “位置-内容” 和 “内容-位置” 注意力。
  • Enhanced Mask Decoder:用 EMD 来代替原 BERT 的 SoftMax 层预测遮盖的 Token。因为我们在精调时一般会在 BERT 的输出后接一个特定任务的 Decoder,但是在预训练时却并没有这个 Decoder;所以本文在预训练时用一个两层的 Transformer decoder 和一个 SoftMax 作为 Decoder。

模型和算法

Disentangled Attention

模型架构和 BERT 类似,主要区别就是 Attention 分数的计算额外增加了位置信息。

假设 k 是最大相对距离,δ(i,j) 是 token i 到 j 的相对位置,定义如下:

δ(i,j)={0 for ijk2k1 for ijkij+k others \delta(i, j)=\left\{\begin{array}{rcc} 0 & \text { for } & i-j \leqslant-k \\ 2 k-1 & \text { for } & i-j \geqslant k \\ i-j+k & \text { others } \end{array}\right.

默认是 512,也就是说相对距离的范围从 -512 到 512。

Qc=HWq,c,Kc=HWk,c,Vc=HWv,c,Qr=PWq,r,Kr=PWk,rQ_{c}=H W_{q, c}, K_{c}=H W_{k, c}, V_{c}=H W_{v, c}, Q_{r}=P W_{q, r}, K_{r}=P W_{k, r}

除了正常的 QKV 外,另外还定义了 Qr 和 Kr,P 表示所有层之间共享的相对位置 Embedding。

A~i,j=QicKjc(a) content-to-content +QicKδ(i,j)r(b) content-to-position +KjcQδ(j,i)r(c) position-to-content \tilde{A}_{i, j}= \underbrace{Q_{i}^{c} K_{j}^{c \top}}_{(\mathrm{a}) \text { content-to-content }} + \underbrace{Q_{i}^{c} K_{\delta(i, j)}^{r}{\top}}_{(\mathrm{b}) \text { content-to-position }} + \underbrace{K_{j}^{c} Q_{\delta(j, i)}^{r}{\top}}_{(\mathrm{c}) \text { position-to-content }}

Aij 表示 Token i 到 j 的 Attention Score 。最后的 Attention 就是上面的三项,进一步计算 H:

Ho=softmax(A~3d)Vc\boldsymbol{H}_{o}=\operatorname{softmax}\left(\frac{\tilde{\boldsymbol{A}}}{\sqrt{3 d}}\right) \boldsymbol{V}_{c}

在具体实现时做了一些优化,也就是重复使用相对位置 Embedding,因为每次使用的其实都是整体的一个子集。

代码和详细的注释如下(做了一些简化):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# From https://github.com/microsoft/DeBERTa
class DisentangledSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.in_proj = torch.nn.Linear(config.hidden_size, self.all_head_size*3, bias=False)
self.q_bias = torch.nn.Parameter(
torch.zeros((self.all_head_size), dtype=torch.float))
self.v_bias = torch.nn.Parameter(
torch.zeros((self.all_head_size), dtype=torch.float))
self.pos_att_type = ["p2c", "c2p"]
self.max_relative_positions = config.max_relative_positions
self.pos_dropout = StableDropout(config.hidden_dropout_prob)
self.pos_proj = torch.nn.Linear(config.hidden_size, self.all_head_size, bias=False)
self.pos_q_proj = torch.nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = StableDropout(config.attention_probs_dropout_prob)

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(self, hidden_states, attention_mask,
return_att=False, query_states=None,
relative_pos=None, rel_embeddings=None):
""" Call the module
Args:
hidden_states (:obj:`torch.FloatTensor`):
Input states to the module usally the output from previous layer,
it will be the Q,K and V in `Attention(Q,K,V)`

attention_mask (:obj:`torch.ByteTensor`):
An attention mask matrix of shape [`B`, `N`, `N`]
where `B` is the batch size,
`N` is the maxium sequence length in which element [i,j] = `1` means
the `i` th token in the input can attend to the `j` th token.

return_att (:obj:`bool`, optional):
Whether return the attention maxitrix.

query_states (:obj:`torch.FloatTensor`, optional):
The `Q` state in `Attention(Q,K,V)`.

relative_pos (:obj:`torch.LongTensor`):
The relative position encoding between the tokens in the sequence.
It's of shape [`B`, `N`, `N`] with values ranging in
[`-max_relative_positions`, `max_relative_positions`].

rel_embeddings (:obj:`torch.FloatTensor`):
The embedding of relative distances.
It's a tensor of shape
[:math:`2 \\times \\text{max_relative_positions}`, `hidden_size`].
"""
# (batch_size, seq_len, hidden_size * 3)
qp = self.in_proj(hidden_states)
# (batch_size, num_attention_heads, seq_len, 3 * attention_head_size).chunk(3, dim=-1) =>
# (batch_size, num_attention_heads, seq_len, attention_head_size)
query_layer,key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1)

query_layer += self.transpose_for_scores(self.q_bias.unsqueeze(0).unsqueeze(0))
value_layer += self.transpose_for_scores(self.v_bias.unsqueeze(0).unsqueeze(0))

rel_att = None
# Take the dot product between "query" and "key" to get the raw attention scores.
scale_factor = 1
if 'c2p' in self.pos_att_type:
scale_factor += 1
if 'p2c' in self.pos_att_type:
scale_factor += 1
if 'p2p' in self.pos_att_type:
scale_factor += 1
scale = math.sqrt(query_layer.size(-1)*scale_factor)
query_layer = query_layer/scale
# (batch_size, num_attention_heads, query_size, key_size)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

# 本文定义的额外计算 Attention 分数
rel_embeddings = self.pos_dropout(rel_embeddings)
# (batch_size, num_attention_heads, query_size, key_size)
rel_att = self.disentangled_att_bias(query_layer, key_layer,
relative_pos, rel_embeddings,
scale_factor)

attention_scores = attention_scores + rel_att

attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
attention_probs = self.dropout(attention_probs)

context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
context_layer = context_layer.view(*new_context_layer_shape)

return (context_layer, attention_probs)

def disentangled_att_bias(self,
query_layer,
key_layer,
relative_pos,
rel_embeddings,
scale_factor):
# query_layer: (batch_size, num_attention_heads, query_seq_len, attention_head_size)
# key_layer: like query_layer
# relative_pos: (1, query_size, key_size)
# rel_embeddings: (max_relative_positions*2, hidden_size)
# scale_factor: 3

# (1, query_size, key_size) => (1, 1, query_size, key_size)
relative_pos = relative_pos.unsqueeze(1)

# an int number
att_span = min(max(query_layer.size(-2), key_layer.size(-2)),
self.max_relative_positions)
relative_pos = relative_pos.long().to(query_layer.device)
# (1, att_span*2, hidden_size)
# 层间共享的 P
rel_embeddings = rel_embeddings[
self.max_relative_positions - att_span:
self.max_relative_positions + att_span, :].unsqueeze(0)

# 位置 Kr
if 'c2p' in self.pos_att_type:
# without bias
# (1, att_span*2, hidden_size)
pos_key_layer = self.pos_proj(rel_embeddings)
# (1, num_attention_heads, att_span*2, attention_head_size)
pos_key_layer = self.transpose_for_scores(pos_key_layer)
# 位置 Qr
if 'p2c' in self.pos_att_type:
# with bias
# (1, att_span*2, hidden_size)
pos_query_layer = self.pos_q_proj(rel_embeddings)
# (1, num_attention_heads, att_span*2, attention_head_size)
pos_query_layer = self.transpose_for_scores(pos_query_layer)

score = 0
# content->position
if 'c2p' in self.pos_att_type:
# (batch_size, num_attention_heads, query_size, att_span * 2)
c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
# (1, 1, query_size, key_size) # i-j+k, [0, 2*k)
c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span*2-1)
# (batch_size, num_attention_heads, query_size, key_size)
# expand(batch_size, num_attention_heads, query_size, key_size)
c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_pos.expand(
[query_layer.size(0),
query_layer.size(1),
query_layer.size(2),
relative_pos.size(-1)]))
score += c2p_att

# position->content
if 'p2c' in self.pos_att_type:
pos_query_layer /= math.sqrt(pos_query_layer.size(-1)*scale_factor)
# j-i+k, [0, 2*k), δ(j,i)
p2c_pos = torch.clamp(-relative_pos + att_span, 0, att_span*2-1)
# (batch_size, num_attention_heads, key_size, att_span * 2)
p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
# expand(batch_size, num_attention_heads, key_size, query_size)
# transpose to => (batch_size, num_attention_heads, query_size, key_size)
p2c_att = torch.gather(p2c_att, dim=-1, index=p2c_pos.expand(
[key_layer.size(0),
key_layer.size(1),
key_layer.size(2),
relative_pos.size(-2)])).transpose(-1, -2)
# expand 里面稍微改了一下,以前是这样的:
# [query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)]
score += p2c_att
return score

发现里面并没有对 c2p 做 scale,不知道是不是忘记了还是其他原因,有知道的小伙伴还望指教一下。

这里最重要的就是这个 disentangled_att_bias 函数,其他的都和 BERT 基本一致。有几个特别需要注意的点(论文中也提到了):

  • p2c 的时候,是 δ(j,i) 而不是 δ(i,j),所以需要稍微做个变换,原因论文也解释了,因为 i 是 position,j 是 content,所以要反过来。也就是说 p2c 实际计算的是 j 位置的 key content 与 query position i 的 Attention。
  • 相对位置的 Embedding 是层间共享的,每次取了其中一个子集。

其实由于是 SelfAttention,query_size, key_size, seq_len 都是相等的,因为 qkv 其实都是它自己。所以代码其实可以更加精炼。另外值得注意的是本文实现的几个辅助工具,比如 StableDropout, MaskedLayerNorm, XSoftmax 等,主要是对自带的模块做了一些优化。

Enhanced Mask Decoder

前面提过这个主要针对的是预训练和精调阶段的不匹配,所以把 BERT 的 SoftMax 替换为 EMD,其实也就是包含一个或多个 Transformer 层(完全成标配的感觉)再接 SoftMax。

另外本文还做了一个改动,就是在将 Encoder 的输出喂进 Decoder 时,将 MASK Token 中 10% 不改变的 Token 编码换成了他们的绝对位置 Embedding,然后再用 MLM 预测。因为这些 Token 虽然不会造成预训练和精调阶段的不匹配,但是却导致了信息泄露——Token 以本身为条件进行预测。

特点和创新

  • 关注了位置和内容之间的注意力。
  • 通过在预训练中加入 Decoder 避免了预训练和精调 MASK 的不匹配问题。

How

如何训练使用

数据方面和 BERT 一致,训练代码官方并未提供,不过使用倒是提供了开箱即用的方法,可以用 Docker,甚至 pip 安装。具体可以参考官方 GitHub。不过并没有提供中文的模型,要想训练只能自己辛苦一下了。由于使用了 EMD,还是有点好奇下游任务中的使用细节,就简单看了下分类任务,发现和正常的 BERT 并无二样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# From https://github.com/microsoft/DeBERTa
class SequenceClassificationModel(NNModule):
def __init__(self, config, num_labels=2, drop_out=None, pre_trained=None):
super().__init__(config)
self.num_labels = num_labels
self.bert = DeBERTa(config, pre_trained=pre_trained)
self.config = config
pool_config = PoolConfig(self.config)
output_dim = self.bert.config.hidden_size
self.pooler = ContextPooler(pool_config)
output_dim = self.pooler.output_dim()
self.classifier = nn.Linear(output_dim, num_labels)
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
self.dropout = StableDropout(drop_out)
self.apply(self.init_weights)
self.bert.apply_state()

def forward(self,
input_ids, type_ids=None, input_mask=None,
labels=None, position_ids=None, **kwargs):
encoder_layers = self.bert(
input_ids, type_ids, input_mask,
position_ids=position_ids, output_all_encoded_layers=True)
pooled_output = self.pooler(encoder_layers[-1])
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)

因为 self.bert 出来的其实就是 Encoder 的所有输出和 Attention(设置为 True 时),和 EMD 没有关系,这里池化用的是 Encoder Layer 最后一层的输出,size 为 (batch_size, seq_len, hidden_size)

数据和实验

Large

Base

Ablation

结果表明每个点(EMD,C2P 和 P2C)都挺重要的。

Attention

  • DeBERTa 并没有观察到明显的对角线效应(Token 注意到自己),这归功于 EMD。
  • RoBERTa 出现的竖直条纹主要由高频虚词引起,DeBERTa 的主要出现在第一列,表示 [CLS]。因此对于一个好的预训练模型,强调 [CLS] 是可取的,因为它的向量通常用作下游任务中整个输入序列的上下文表示。

Discussion

从 2018 年开始的基于 Transformer 架构的大型预训练模型:GPT, BERT, RoBERTa, XLNet, UniLM, ELECTRA, T5, ALUM, StructBERT, ERINE。

两篇已有的根据相对位置计算注意力权重的文献:

  • Cheng-Zhi Anna Huang, Ashish Vaswani, Jakob Uszkoreit, Ian Simon, Curtis Hawthorne, Noam Shazeer, Andrew M Dai, Matthew D Hoffman, Monica Dinculescu, and Douglas Eck. Music transformer: Generating music with long-term structure. 2018.
  • Peter Shaw, Jakob Uszkoreit, and Ashish Vaswani. Self-attention with relative position representations. In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 2 (Short Papers), pages 464–468, 2018.

本文两个主要创新点虽说都是对现有的一点补充,但感觉还是挺有意义的,尤其是对类似中文、日文这种没分词使用字符作为 Token 的语言,唯一的遗憾是都增加了额外的复杂度(个人对这点很是看重)。另外,Self-Attention 本来就有 c2c,再同时增加 p2c 和 c2p 究竟会不会太过,本文(当然)没有探讨。所以,可能又是一次锦上添花罢了。