# LSTM+CRF序列标注

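| 输入序列 | 清 | 华 | 大 | 学 | 坐 | 落 | 于 | 首 | 都 | 北 | 京 |
| :------: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| 输出标注 | B | I | I | I | O | O | O | O | O | B | I |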

## 条件随机场(Conditional Random Field, CRF)

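| 输入序列 | 清 | 华 | 大 | 学 |   |
| :------: | :-: | :-: | :-: | :-: | :-: |
| 输出标注 | B | I | I | I | √ |
| 输出标注 | O | I | I | I | × |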

设 $x=\{x_0, \dots, x_n\}$ 为输入序列，$y=\{y_0, \dots, y_n\}$（$y \in Y$）为输出的标注序列，其中 $n$ 为序列的最大长度，$Y$ 表示 $x$ 对应的所有可能的输出序列集合。则输出序列 $y$ 的概率为：

\begin{align}P(y|x) = \frac{\exp(\text{Score}(x, y))}{\sum_{y' \in Y} \exp(\text{Score}(x, y'))} \qquad (1)\end{align}

设 $x_i$, $y_i$ 为序列的第 $i$ 个Token和对应的标签，则 $\text{Score}$ 需要能够在计算 $x_i$ 到 $y_i$ 的映射的同时，捕获相邻标签 $y_{i-1}$ 和 $y_{i}$ 之间的关系，因此我们定义两个概率函数：

1. 发射概率函数$$\psi_\text{EMIT}$$：表示$$x_i \rightarrow y_i$$的概率。

2. 转移概率函数$$\psi_\text{TRANS}$$：表示$$y_{i-1} \rightarrow y_i$$的概率。

则可以得到Score的计算公式：

\begin{align}\text{Score}(x,y) = \sum_i \left[ \log \psi_\text{EMIT}(x_i \rightarrow y_i) + \log \psi_\text{TRANS}(y_{i-1} \rightarrow y_i) \right] \qquad (2)\end{align}

设 $\textbf{P}$ 为存储标签间转移得分的矩阵，$h_i$ 为编码层对第 $i$ 个Token输出的各标签发射得分，则Score的计算公式可以转化为：

\begin{align}\text{Score}(x,y) = \sum_i \left[ h_i[y_i] + \textbf{P}_{y_{i-1}, y_{i}} \right] \qquad (3)\end{align}

\begin{align}\text{Loss} = -\log(P(y|x)) \qquad (4)\end{align}

\begin{align}\text{Loss} &= -\log\left(\frac{\exp(\text{Score}(x, y))}{\sum_{y' \in Y} \exp(\text{Score}(x, y'))}\right) \\
&= \log\left(\sum_{y' \in Y} \exp(\text{Score}(x, y'))\right) - \text{Score}(x, y) \qquad (5)\end{align}

### Score计算

[1]:

def compute_score(emissions, tags, seq_ends, mask, trans, start_trans, end_trans):
    """Score the given tag paths (the numerator term of the CRF likelihood).

    The shapes below are the ones annotated in the notebook.

    Args:
        emissions: emission scores, (seq_length, batch_size, num_tags).
        tags: tag indices of the paths to score, (seq_length, batch_size).
        seq_ends: per-sample index of the last step; used to index the
            first axis of ``tags``.
        mask: (seq_length, batch_size); a step whose mask is 0 adds nothing.
        trans: transition scores, indexed as ``trans[prev_tag, cur_tag]``.
        start_trans: start scores, indexed by tag.
        end_trans: end scores, indexed by tag.

    Returns:
        One path score per batch element, shape (batch_size,).
    """
    seq_length, batch_size = tags.shape
    batch_index = mnp.arange(batch_size)

    # Start transition plus the first emission. The first step is added
    # unconditionally, i.e. it is not multiplied by the mask.
    score = start_trans[tags[0]]
    score += emissions[0, batch_index, tags[0]]

    for step in range(1, seq_length):
        prev_tags, cur_tags, step_mask = tags[step - 1], tags[step], mask[step]

        # Transition prev -> cur, counted only where the mask is 1.
        score += trans[prev_tags, cur_tags] * step_mask

        # Emission of the current tag, counted only where the mask is 1.
        score += emissions[step, batch_index, cur_tags] * step_mask

    # End transition from the tag found at each sample's last step.
    last_tags = tags[seq_ends, batch_index]
    score += end_trans[last_tags]

    return score


### Normalizer计算

$\log\left(\sum_{y'_{0,i} \in Y} \exp(\text{Score}_i)\right) = \log\left(\sum_{y'_{0,i-1} \in Y} \exp(\text{Score}_{i-1} + h_{i} + \textbf{P})\right) \qquad (6)$

$\log\left(\sum_{y'_{0,i} \in Y} \exp(\text{Score}_i)\right) = \log\left(\sum_{y'_{0,i-1} \in Y} \exp(\text{Score}_{i-1})\right) + h_{i} + \textbf{P} \qquad (7)$

[2]:

def compute_normalizer(emissions, mask, trans, start_trans, end_trans):
    """Compute the log-sum-exp of the scores of all tag paths (the normalizer).

    The shapes below are the ones annotated in the notebook.

    Args:
        emissions: emission scores, (seq_length, batch_size, num_tags).
        mask: (seq_length, batch_size); where it is 0 the running score is
            carried over unchanged.
        trans: transition scores, (num_tags, num_tags), indexed as
            ``trans[prev_tag, cur_tag]``.
        start_trans: start scores, (num_tags,).
        end_trans: end scores, (num_tags,).

    Returns:
        The log normalizer of each batch element, shape (batch_size,).
    """
    seq_length = emissions.shape[0]

    # Start transition plus the first emission.
    # shape: (batch_size, num_tags)
    score = start_trans + emissions[0]

    for i in range(1, seq_length):
        # Put the previous tag on its own axis so it broadcasts against trans.
        # shape: (batch_size, num_tags, 1)
        broadcast_score = score.expand_dims(2)

        # Put the current tag on the last axis.
        # shape: (batch_size, 1, num_tags)
        broadcast_emissions = emissions[i].expand_dims(1)

        # Score of every (previous tag, current tag) pair at step i.
        # shape: (batch_size, num_tags, num_tags)
        next_score = broadcast_score + trans + broadcast_emissions

        # Marginalise over the previous tag.
        # shape: (batch_size, num_tags)
        next_score = ops.logsumexp(next_score, axis=1)

        # Only advance the score where the mask is 1.
        # shape: (batch_size, num_tags)
        score = mnp.where(mask[i].expand_dims(1), next_score, score)

    # Add the end transition, then marginalise over the final tag.
    # shape: (batch_size, num_tags) -> (batch_size,)
    score += end_trans
    return ops.logsumexp(score, axis=1)


### Viterbi算法

$P_{0,i} = \max(P_{0, i-1}) + P_{i-1, i}$

[3]:

def viterbi_decode(emissions, mask, trans, start_trans, end_trans):
    """Run the forward pass of Viterbi decoding.

    The shapes below are the ones annotated in the notebook.

    Args:
        emissions: emission scores, (seq_length, batch_size, num_tags).
        mask: (seq_length, batch_size); where it is 0 the running score is
            carried over unchanged.
        trans: transition scores, indexed as ``trans[prev_tag, cur_tag]``.
        start_trans: start scores, indexed by tag.
        end_trans: end scores, indexed by tag.

    Returns:
        A ``(score, history)`` pair. ``score`` holds, per sample and final
        tag, the best path score including the end transition. ``history``
        is a tuple with one entry per step from 1 onwards, each holding the
        best previous tag for every current tag.
    """
    seq_length = emissions.shape[0]

    # Start transition plus the first emission.
    score = start_trans + emissions[0]
    history = ()

    for i in range(1, seq_length):
        # Previous tag on axis 1, current tag on axis 2.
        broadcast_score = score.expand_dims(2)
        broadcast_emission = emissions[i].expand_dims(1)
        next_score = broadcast_score + trans + broadcast_emission

        # Remember, for every current tag, which previous tag scored best.
        indices = next_score.argmax(axis=1)
        history += (indices,)

        # Keep the best score per current tag; advance only where mask is 1.
        next_score = next_score.max(axis=1)
        score = mnp.where(mask[i].expand_dims(1), next_score, score)

    score += end_trans

    return score, history

def post_decode(score, history, seq_length):
    """Backtrack through the Viterbi history to get the best tag sequences.

    Args:
        score: final scores returned by ``viterbi_decode``; indexed by
            sample, then by tag.
        history: tuple of back-pointers returned by ``viterbi_decode``;
            each entry is indexed by sample, then by current tag.
        seq_length: per-sample sequence lengths; its first dimension gives
            the batch size.

    Returns:
        A list with one list of integer tag indices per sample, in forward
        order.
    """
    batch_size = seq_length.shape[0]
    last_steps = seq_length - 1

    decoded = []
    for sample in range(batch_size):
        # Start from the tag with the highest final score.
        tag = int(score[sample].argmax(axis=0).asnumpy())
        backward_path = [tag]

        # Follow the back-pointers, using only the first last_steps[sample]
        # history entries for this sample.
        valid_history = history[:last_steps[sample]]
        for backpointers in reversed(valid_history):
            tag = int(backpointers[sample][tag].asnumpy())
            backward_path.append(tag)

        # The path was collected end-to-start; flip it.
        decoded.append(backward_path[::-1])

    return decoded


### CRF层

[4]:

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore.common.initializer import initializer, Uniform

def sequence_mask(seq_length, max_length, batch_first=False):
    """Build an int64 mask that is 1 at positions before each sequence's length.

    Args:
        seq_length: tensor of sequence lengths.
        max_length: number of positions to generate.
        batch_first: if true, return the mask as computed (lengths on the
            leading axes, positions on the last axis); otherwise swap
            axes 0 and 1 before returning.

    Returns:
        The mask cast to ``ms.int64``.
    """
    positions = mnp.arange(0, max_length, 1, seq_length.dtype)
    # A trailing axis of size 1 lets the lengths broadcast against positions.
    lengths_column = seq_length.view(seq_length.shape + (1,))
    mask = (positions < lengths_column).astype(ms.int64)
    return mask if batch_first else mask.swapaxes(0, 1)

class CRF(nn.Cell):
    """Linear-chain conditional random field layer.

    Called with ``tags`` it returns the negative log-likelihood of those
    tags, reduced according to ``reduction``. Called without ``tags`` it
    returns the ``(score, history)`` pair produced by ``viterbi_decode``.

    Args:
        num_tags: number of distinct tags; must be positive.
        batch_first: whether inputs are laid out (batch, seq, ...) rather
            than (seq, batch, ...).
        reduction: one of ``'none'``, ``'sum'``, ``'mean'``, ``'token_mean'``.

    Raises:
        ValueError: if ``num_tags`` is not positive or ``reduction`` is not
            one of the accepted values.
    """

    def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.reduction = reduction
        # Learnable start, end and tag-to-tag transition scores.
        self.start_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions')
        self.end_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions')
        self.transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions')

    def construct(self, emissions, tags=None, seq_length=None):
        """Decode when ``tags`` is None, otherwise compute the loss."""
        if tags is None:
            return self._decode(emissions, seq_length)
        return self._forward(emissions, tags, seq_length)

    def _forward(self, emissions, tags=None, seq_length=None):
        """Return the negative log-likelihood of ``tags`` given ``emissions``."""
        if self.batch_first:
            batch_size, max_length = tags.shape
            # The scoring helpers expect the sequence axis first.
            emissions = emissions.swapaxes(0, 1)
            tags = tags.swapaxes(0, 1)
        else:
            max_length, batch_size = tags.shape

        if seq_length is None:
            # No lengths given: treat every sequence as full length.
            seq_length = mnp.full((batch_size,), max_length, ms.int64)

        # shape: (max_length, batch_size)
        mask = sequence_mask(seq_length, max_length)

        # shape: (batch_size,)
        numerator = compute_score(emissions, tags, seq_length - 1, mask, self.transitions, self.start_transitions, self.end_transitions)
        # shape: (batch_size,)
        denominator = compute_normalizer(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)
        # Normalizer minus path score is the negative log-likelihood.
        # shape: (batch_size,)
        nll = denominator - numerator

        if self.reduction == 'none':
            return nll
        if self.reduction == 'sum':
            return nll.sum()
        if self.reduction == 'mean':
            return nll.mean()
        # 'token_mean': average over the number of unmasked positions.
        return nll.sum() / mask.astype(emissions.dtype).sum()

    def _decode(self, emissions, seq_length=None):
        """Run Viterbi over ``emissions`` and return ``(score, history)``."""
        if self.batch_first:
            batch_size, max_length = emissions.shape[:2]
            emissions = emissions.swapaxes(0, 1)
        else:
            # Sequence-first layout: the leading axis is the sequence length.
            max_length, batch_size = emissions.shape[:2]

        if seq_length is None:
            seq_length = mnp.full((batch_size,), max_length, ms.int64)

        # shape: (max_length, batch_size)
        mask = sequence_mask(seq_length, max_length)

        return viterbi_decode(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)


## BiLSTM+CRF模型

nn.Embedding -> nn.LSTM -> nn.Dense -> CRF


[5]:

class BiLSTM_CRF(nn.Cell):
    """Sequence tagger: nn.Embedding -> bidirectional nn.LSTM -> nn.Dense -> CRF.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding (the LSTM input size).
        hidden_dim: combined size of the two LSTM directions; each direction
            uses ``hidden_dim // 2`` units.
        num_tags: number of output tags.
        padding_idx: token index passed to the embedding as its padding index.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_tags, padding_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, batch_first=True)
        self.hidden2tag = nn.Dense(hidden_dim, num_tags, 'he_uniform')
        self.crf = CRF(num_tags, batch_first=True)

    def construct(self, inputs, seq_length, tags=None):
        """Return the CRF loss when ``tags`` is given, else the decode output."""
        embeds = self.embedding(inputs)
        outputs, _ = self.lstm(embeds, seq_length=seq_length)
        # Project the LSTM outputs to per-tag emission scores.
        feats = self.hidden2tag(outputs)

        crf_outs = self.crf(feats, tags, seq_length)
        return crf_outs


[6]:

# Model sizes for the toy example.
embedding_dim = 16
hidden_dim = 32

# Two hand-labelled sentences: (tokens, BIO tags), one tag per token.
training_data = [(
    "清 华 大 学 坐 落 于 首 都 北 京".split(),
    "B I I I O O O O O B I".split()
), (
    "重 庆 是 一 个 魔 幻 城 市".split(),
    "B I O O O O O O O".split()
)]

# Vocabulary. Index 0 is reserved for the padding token, which
# prepare_sequence looks up when padding shorter sentences.
word_to_idx = {"<pad>": 0}
for sentence, _ in training_data:
    for word in sentence:
        # Each new token gets the next free index.
        word_to_idx.setdefault(word, len(word_to_idx))

tag_to_idx = {"B": 0, "I": 1, "O": 2}

[7]:

# Show the vocabulary size (number of entries in word_to_idx).
len(word_to_idx)

[7]:

21


[8]:

# Build the tagger from the vocabulary and tag-set sizes, and a plain SGD
# optimizer over its trainable parameters.
model = BiLSTM_CRF(
    vocab_size=len(word_to_idx),
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    num_tags=len(tag_to_idx),
)
optimizer = nn.SGD(
    params=model.trainable_params(),
    learning_rate=0.01,
    weight_decay=1e-4,
)

[9]:

# Differentiate the model output with respect to the optimizer's parameters.
grad_fn = ms.value_and_grad(model, None, optimizer.parameters)


def train_step(data, seq_length, label):
    """Run one optimisation step and return the loss.

    Args:
        data: token indices passed to the model as ``inputs``.
        seq_length: sequence lengths passed to the model.
        label: tag indices passed to the model as ``tags``.

    Returns:
        The loss value computed for this batch before the parameter update.
    """
    loss, grads = grad_fn(data, seq_length, label)
    optimizer(grads)
    return loss


[10]:

def prepare_sequence(seqs, word_to_idx, tag_to_idx):
    """Turn (tokens, tags) pairs into padded index tensors.

    Every sentence is padded to the length of the longest one: tokens with
    ``word_to_idx['<pad>']`` and labels with ``tag_to_idx['O']``.

    Args:
        seqs: iterable of ``(tokens, tags)`` pairs.
        word_to_idx: mapping from token to index; must contain ``'<pad>'``.
        tag_to_idx: mapping from tag to index; must contain ``'O'``.

    Returns:
        Three int64 tensors: padded token indices, padded label indices and
        the unpadded length of each sentence.
    """
    longest = max(len(tokens) for tokens, _ in seqs)

    token_rows, label_rows, lengths = [], [], []
    for tokens, tag_names in seqs:
        pad_count = longest - len(tokens)
        lengths.append(len(tokens))

        token_ids = [word_to_idx[token] for token in tokens]
        token_ids.extend([word_to_idx['<pad>'] for _ in range(pad_count)])
        token_rows.append(token_ids)

        label_ids = [tag_to_idx[name] for name in tag_names]
        label_ids.extend([tag_to_idx['O'] for _ in range(pad_count)])
        label_rows.append(label_ids)

    inputs = ms.Tensor(token_rows, ms.int64)
    labels = ms.Tensor(label_rows, ms.int64)
    seq_lengths = ms.Tensor(lengths, ms.int64)
    return inputs, labels, seq_lengths

[11]:

# Convert the toy corpus into padded tensors, then show their shapes.
data, label, seq_length = prepare_sequence(training_data, word_to_idx, tag_to_idx)
data.shape, label.shape, seq_length.shape

[11]:

((2, 11), (2, 11), (2,))


[12]:

from tqdm import tqdm

# Train for a fixed number of steps on the single toy batch, showing the
# most recent loss on the progress bar.
steps = 500
with tqdm(total=steps) as progress:
    for _ in range(steps):
        loss = train_step(data, seq_length, label)
        progress.set_postfix(loss=loss)
        progress.update(1)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:23<00:00, 21.13it/s, loss=0.3487625]


[13]:

# Calling the model without tags takes the decode path, which returns the
# Viterbi scores and back-pointer history.
score, history = model(data, seq_length)
score

[13]:

Tensor(shape=[2, 3], dtype=Float32, value=
[[ 3.15928860e+01,  3.63119812e+01,  3.17248516e+01],
[ 2.81416149e+01,  2.61749763e+01,  3.24760780e+01]])


[14]:

# Backtrack through the history to get the best tag indices per sentence.
predict = post_decode(score, history, seq_length)
predict

[14]:

[[0, 1, 1, 1, 2, 2, 2, 2, 2, 0, 1], [0, 1, 2, 2, 2, 2, 2, 2, 2]]


[15]:

idx_to_tag = {idx: tag for tag, idx in tag_to_idx.items()}

def sequence_to_tag(sequences, idx_to_tag):
outputs = []
for seq in sequences:
outputs.append([idx_to_tag[i] for i in seq])
return outputs

[16]:

# Convert the predicted tag indices back to tag names.
sequence_to_tag(predict, idx_to_tag)

[16]:

[['B', 'I', 'I', 'I', 'O', 'O', 'O', 'O', 'O', 'B', 'I'],
['B', 'I', 'O', 'O', 'O', 'O', 'O', 'O', 'O']]