GlobalPointer
Formula:
$$
p(h_s, h_e, t) = q_{s,t}^T k_{e,t}
$$

$$
q_{s,t} = W_{s,t} h_s + b_{s,t}
$$

$$
k_{e,t} = W_{e,t} h_e + b_{e,t}
$$

Core idea
Similar to attention's scoring mechanism: recognizing multiple entity types is cast as a multi-head setup, where each head corresponds to one entity-recognition task, and the attention score (QK) is used as the span score.
To capture the key information of the distance between start and end, the author adds Rotary Position Embedding (RoPE).
Core code
```python
class GlobalPointer(Module):
    """Global pointer module.
    Scores every (start, end) pair of the sequence as a whole span.
    """
    def __init__(self, heads, head_size, hidden_size, RoPE=True):
        super(GlobalPointer, self).__init__()
        self.heads = heads
        self.head_size = head_size
        self.RoPE = RoPE
        # each head represents one entity class; head_size * 2 because both Q and K are produced
        self.dense = nn.Linear(hidden_size, self.head_size * self.heads * 2)

    def forward(self, inputs, mask=None):
        inputs = self.dense(inputs)
        inputs = torch.split(inputs, self.head_size * 2, dim=-1)
        inputs = torch.stack(inputs, dim=-2)
        qw, kw = inputs[..., :self.head_size], inputs[..., self.head_size:]
        # RoPE encoding
        if self.RoPE:
            pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
            cos_pos = pos[..., None, 1::2].repeat(1, 1, 1, 2)
            sin_pos = pos[..., None, ::2].repeat(1, 1, 1, 2)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 4)
            qw2 = torch.reshape(qw2, qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], 4)
            kw2 = torch.reshape(kw2, kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        # inner product: (batch, heads, seq_len, seq_len)
        logits = torch.einsum('bmhd,bnhd->bhmn', qw, kw)
        # mask out padding and the lower triangle (start must not come after end)
        logits = add_mask_tril(logits, mask)
        return logits / self.head_size ** 0.5
```
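A minimal usage sketch (the batch size, sequence length, and entity-type count are illustrative; `SinusoidalPositionEmbedding` and `add_mask_tril` are the helpers referenced above):

```python
import torch

hidden = torch.randn(2, 128, 768)              # e.g. BERT last hidden state: (batch, seq_len, hidden)
mask = torch.ones(2, 128, dtype=torch.long)    # 1 = real token, 0 = padding

gp = GlobalPointer(heads=5, head_size=64, hidden_size=768, RoPE=True)
logits = gp(hidden, mask)                      # (batch, heads, seq_len, seq_len)

# with the multi-label loss used in the original GlobalPointer, cells with score > 0
# are decoded as entities; each row of `spans` is (batch_idx, type_idx, start, end)
spans = torch.nonzero(logits > 0)
```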
TPLinker (Multi-head)
Formula
$$
p(h_s, h_e, t) = W_t \cdot h_{s,e} + b_t
$$

$$
h_{s,e} = \tanh(W_h \cdot [h_s; h_e] + b_h)
$$

Core idea
Compared with GlobalPointer, multi-head selection is additive, whereas GlobalPointer is multiplicative.
Core code
```python
class MutiHeadSelection(Module):
    def __init__(self, hidden_size, c_size, abPosition=False, rePosition=False, maxlen=None, max_relative=None):
        super(MutiHeadSelection, self).__init__()
        self.hidden_size = hidden_size
        self.c_size = c_size
        self.abPosition = abPosition
        self.rePosition = rePosition
        self.Wh = nn.Linear(hidden_size * 2, self.hidden_size)
        self.Wo = nn.Linear(self.hidden_size, self.c_size)
        if self.rePosition:
            self.relative_positions_encoding = relative_position_encoding(max_length=maxlen,
                                                                          depth=2 * hidden_size,
                                                                          max_relative_position=max_relative)

    def forward(self, inputs, mask=None):
        input_length = inputs.shape[1]
        batch_size = inputs.shape[0]
        if self.abPosition:
            # because the combination is additive, RoPE cannot be used here,
            # so absolute sinusoidal position embeddings are added directly
            inputs = SinusoidalPositionEmbedding(self.hidden_size, 'add')(inputs)
        x1 = torch.unsqueeze(inputs, 1)
        x2 = torch.unsqueeze(inputs, 2)
        x1 = x1.repeat(1, input_length, 1, 1)
        x2 = x2.repeat(1, 1, input_length, 1)
        concat_x = torch.cat([x2, x1], dim=-1)
        # unlike the original TPLinker paper, building the pair matrix by repeat + concat keeps the computation parallelizable
        if self.rePosition:
            # with relative position encoding, the encodings are simply added onto the pair matrix
            relations_keys = self.relative_positions_encoding[:input_length, :input_length, :].to(inputs.device)
            concat_x += relations_keys
        hij = torch.tanh(self.Wh(concat_x))
        logits = self.Wo(hij)
        # reorder to (batch, c_size, seq_len, seq_len), mask padding and the lower triangle, and return
        logits = logits.permute(0, 3, 1, 2)
        logits = add_mask_tril(logits, mask)
        return logits
```
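A quick shape check (a sketch with illustrative sizes, using the default flags):

```python
import torch

m = MutiHeadSelection(hidden_size=768, c_size=5)
x = torch.randn(2, 64, 768)                  # (batch, seq_len, hidden)
mask = torch.ones(2, 64, dtype=torch.long)
logits = m(x, mask)                          # (batch, c_size, seq_len, seq_len)
# the pairwise repeat + concat builds a (batch, seq_len, seq_len, 2 * hidden) tensor,
# so memory grows quadratically with sequence length
```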
Tencent Multi-head
Formula
$$
p(h_s, h_e, t) = U \cdot \tanh(V s_{s,e})
$$

$$
s_{s,e} = [h_s; h_e; h_s - h_e; h_s \cdot h_e]
$$

Compared with TPLinker, more interaction terms are added: $h_s - h_e$ and $h_s \cdot h_e$ (difference and element-wise product).
Core idea
Proposes negative sampling, a span-annotation-based training scheme for coping with missing entity annotations.
Core code
```python
def generate_whole_label(self, positions, length):
    """
    Negative sampling: randomly sample k (text length * sampling rate) negative spans
    """
    neg_positions = []
    neg_num = int(length * self.neg_rate) + 1
    # candidate spans, i.e. all spans except the positive ones
    candies = flat_list([[(i, j) for j in range(i, length) if (i, j) not in positions] for i in range(length)])
    if len(candies) > 0:
        sample_num = min(neg_num, len(candies))
        assert sample_num > 0
        # randomly sample the negative spans
        np.random.shuffle(candies)
        for i, j in candies[:sample_num]:
            neg_positions.append((i, j))
    return neg_positions
```
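For illustration only, a hypothetical standalone variant of the same sampling logic (no `self`, no `flat_list`), with a tiny worked example:

```python
import numpy as np

def sample_negatives(positions, length, neg_rate=0.3):
    # candidate spans = all (i, j) with i <= j that are not gold spans
    candies = [(i, j) for i in range(length) for j in range(i, length)
               if (i, j) not in positions]
    neg_num = int(length * neg_rate) + 1
    np.random.shuffle(candies)
    return candies[:min(neg_num, len(candies))]

# one gold span (2, 4) in a length-6 sentence -> roughly two negative spans are drawn
print(sample_negatives({(2, 4)}, 6))   # e.g. [(0, 3), (5, 5)]
```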
```python
class TxMutihead(Module):
    def __init__(self, hidden_size, c_size, abPosition=False, rePosition=False, maxlen=None, max_relative=None):
        super(TxMutihead, self).__init__()
        self.hidden_size = hidden_size
        self.c_size = c_size
        self.abPosition = abPosition
        self.rePosition = rePosition
        self.Wh = nn.Linear(hidden_size * 4, self.hidden_size)
        self.Wo = nn.Linear(self.hidden_size, self.c_size)
        if self.rePosition:
            self.relative_positions_encoding = relative_position_encoding(max_length=maxlen,
                                                                          depth=4 * hidden_size,
                                                                          max_relative_position=max_relative)

    def forward(self, inputs, mask=None):
        input_length = inputs.shape[1]
        batch_size = inputs.shape[0]
        if self.abPosition:
            # because the combination is additive, RoPE cannot be used here,
            # so absolute sinusoidal position embeddings are added directly
            inputs = SinusoidalPositionEmbedding(self.hidden_size, 'add')(inputs)
        x1 = torch.unsqueeze(inputs, 1)
        x2 = torch.unsqueeze(inputs, 2)
        x1 = x1.repeat(1, input_length, 1, 1)
        x2 = x2.repeat(1, 1, input_length, 1)
        # concatenate [h_s; h_e; h_s - h_e; h_s * h_e] for every (start, end) pair
        concat_x = torch.cat([x2, x1, x2 - x1, x2.mul(x1)], dim=-1)
        if self.rePosition:
            relations_keys = self.relative_positions_encoding[:input_length, :input_length, :].to(inputs.device)
            concat_x += relations_keys
        hij = torch.tanh(self.Wh(concat_x))
        logits = self.Wo(hij)
        logits = logits.permute(0, 3, 1, 2)
        logits = add_mask_tril(logits, mask)
        return logits
```
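For intuition, a rough sketch (not from the original repository) of how the gold spans and the sampled negatives might enter the loss, assuming class 0 stands for "not an entity" and `logits` is the per-sentence output of `TxMutihead`:

```python
import torch
import torch.nn.functional as F

def span_loss(logits, pos_spans, neg_spans):
    """logits: (c_size, seq_len, seq_len); pos_spans: {(start, end): type_id}; neg_spans: [(start, end)]."""
    cells, labels = [], []
    for (i, j), t in pos_spans.items():      # gold spans keep their type label
        cells.append(logits[:, i, j])
        labels.append(t)
    for i, j in neg_spans:                   # sampled negatives are labelled as non-entity
        cells.append(logits[:, i, j])
        labels.append(0)
    # cross-entropy only over the selected cells, not the full seq_len x seq_len grid
    return F.cross_entropy(torch.stack(cells), torch.tensor(labels))
```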
Deep Biaffine
Formula
$$
p(h_s, h_e, t) = h_s^T U_t h_e + W_t [h_s; h_e] + b
$$

Core idea

Deep Biaffine combines the additive and multiplicative forms.
Core code
```python
class Biaffine(Module):
    def __init__(self, in_size, out_size, Position=False):
        super(Biaffine, self).__init__()
        self.out_size = out_size
        self.weight1 = Parameter(torch.Tensor(in_size, out_size, in_size))
        self.weight2 = Parameter(torch.Tensor(2 * in_size + 1, out_size))
        self.Position = Position
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
        torch.nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))

    def forward(self, inputs, mask=None):
        input_length = inputs.shape[1]
        hidden_size = inputs.shape[-1]
        if self.Position:
            # add absolute position embeddings; through the bilinear product they translate into relative position information
            inputs = SinusoidalPositionEmbedding(hidden_size, 'add')(inputs)
        x1 = torch.unsqueeze(inputs, 1)
        x2 = torch.unsqueeze(inputs, 2)
        x1 = x1.repeat(1, input_length, 1, 1)
        x2 = x2.repeat(1, 1, input_length, 1)
        concat_x = torch.cat([x2, x1], dim=-1)
        # append a constant 1 so weight2 also carries the bias term
        concat_x = torch.cat([concat_x, torch.ones_like(concat_x[..., :1])], dim=-1)
        # bilinear term: h_s^T U_t h_e for every (start, end) pair and every type
        logits_1 = torch.einsum('bxi,ioj,byj -> bxyo', inputs, self.weight1, inputs)
        # affine term: W_t [h_s; h_e; 1]
        logits_2 = torch.einsum('bijy,yo -> bijo', concat_x, self.weight2)
        logits = logits_1 + logits_2
        logits = logits.permute(0, 3, 1, 2)
        logits = add_mask_tril(logits, mask)
        return logits
```
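A minimal usage sketch (sizes are illustrative):

```python
import torch

biaff = Biaffine(in_size=128, out_size=5)
x = torch.randn(2, 32, 128)                 # encoder output: (batch, seq_len, in_size)
mask = torch.ones(2, 32, dtype=torch.long)
logits = biaff(x, mask)                     # (batch, out_size, seq_len, seq_len)
```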