CTRL 论文+代码+实践笔记

paper: [1909.05858] CTRL: A Conditional Transformer Language Model for Controllable Generation

code: salesforce/ctrl: Conditional Transformer Language Model for Controllable Generation

核心思想:借鉴多任务,将文本标签作为输入的一部分(放在开头)控制文本生成。

Abstract

文本生成最大的问题是难以对其进行控制,本文发布了一个 1.6 billion 参数的条件 transformer language model,训练能够 govern 风格、内容、特定任务行为等的控制代码。控制代码来自与原始文本共现的结构,保留了无监督学习的优点,同时提供对文本生成更明确的控制。这些控制代码还允许 CTRL 预测训练数据的哪些部分最有可能给出序列。

Introduction

在足够大的模型和足够多的数据下,生成模型能够学到足够强大的分布生成高质量的 sample。在图像领域,2014 年 Goodfellow 的 Gan 大放异彩。在自然语言领域,语言模型经常被训练成特定任务的条件模型;也被用来学习词向量、文档向量、上下文向量等进而在迁移学习中使用;语言模型本身也会迁移到新任务上精调。但生成并没有在这些任务中被限制,典型的文本生成过程也仅仅被粗糙地引导或者直接指定开头。这自然就有个问题:怎样才能对生成过程进行更明确的控制?

本文受图像生成和近期的文本生成及多任务学习启发,训练了一个多控制代码的条件模型。该模型能够给在控制代码(指定 domain, style, topics, dates, entities, relationships between entities, plot points, task-related behavior)条件下生成文本。控制代码来自与原始文本共现的结构,比如 Wiki、评论这种可以指定一个 domain-related 控制代码,小语料还会有一个 subdomain 的信息。文本能够通过控制内容或改变 domain 进而以更加可预测的方式生成。

因为所有控制代码都可以追溯到训练数据的特定子集,所以还可以用来预测最有可能给出序列的训练数据的子集。这些控制代码还允许直接包含特定任务的数据,进而在生成中使特定任务的行为控制代码能够与内容相关的控制代码结合。

Language Model

p(x)=i=1np(xix<i)p(x)=\prod_{i=1}^{n} p\left(x_{i} | x_{<i}\right)

L(D)=k=1Dlogpθ(xikx<ik)\mathcal{L}(D)=-\sum_{k=1}^{|D|} \log p_{\theta}\left(x_{i}^{k} | x_{<i}^{k}\right)

x 是一个序列,训练后的模型自然可以生成特定长度的序列。

Language Modeling With CTRL

本文的模型在控制代码下学习分布。

p(xc)=i=1np(xix<i,c)L(D)=k=1Dlogpθ(xikx<ik,ck)p(x | c)=\prod_{i=1}^{n} p\left(x_{i} | x_{<i}, c\right) \quad \mathcal{L}(D)=-\sum_{k=1}^{|D|} \log p_{\theta}\left(x_{i}^{k} | x_{<i}^{k}, c^{k}\right)

其实就是在预先考虑控制代码的基础上进行训练。

一个单独的序列包括 n 个 token(n 个 d 维向量),每个 token 的向量是 embedding 和 sinusoidal positional embedding 二者之和。序列的向量矩阵为 n×d,可以进一步连 attention。

每一层包括两个 block,第一个 block 是一个 k heads 的 multi-head attention,使用 mask:

 Attention (X,Y,Z)=softmax(mask(XY)d)Z MultiHead (X,k)=[h1;;hk]Wo where hj= Attention (XWj1,XWj2,XWj3)\begin{aligned} \text { Attention }(X, Y, Z) &=\operatorname{softmax}\left(\frac{\operatorname{mask}\left(X Y^{\top}\right)}{\sqrt{d}}\right) Z \\ \text { MultiHead }(X, k) &=\left[h_{1} ; \cdots ; h_{k}\right] W_{o} \\ \text { where } h_{j} &=\text { Attention }\left(X W_{j}^{1}, X W_{j}^{2}, X W_{j}^{3}\right) \end{aligned}

第二个 block 是一个 ReLU 激活的前馈网络:

FF(X)=max(0,XU)VF F(X)=\max (0, X U) V

每个 block 执行 normalization 然后是一个 residual 连接:

Block1:

Xi= LayerNorm (Xi)Hi= MultiHead (Xi)+Xi\begin{aligned} \overline{X}_{i} &=\text { LayerNorm }\left(X_{i}\right) \\ H_{i} &=\text { MultiHead }\left(\overline{X}_{i}\right)+\overline{X}_{i} \end{aligned}

Block2:

Hi= LayerNorm (Hi)Xi+1=FF(Hi)+Hi\begin{aligned} \overline{H}_{i} &=\text { LayerNorm }\left(H_{i}\right) \\ X_{i+1} &=\operatorname{FF}\left(\overline{H}_{i}\right)+\overline{H}_{i} \end{aligned}

每个 token 的 score:

 Scores (X0)= LayerNorm (Xl)Wvocab\text { Scores }\left(X_{0}\right)=\text { LayerNorm }\left(X_{l}\right) W_{\text {vocab}}

训练时,scores 作为 cross-entropy loss function 的输入;生成时,scores 与最终的 token(softmax 后的)相关,然后产生用于采用新 token 的分布。

Data

140G

Experimental Settings

  • Use fastBPE tokenize
  • Vocabulary: 250K tokens, includes sub-word tokens
  • the first token of each sequence is the domain
  • model dimension d = 1280
  • inner dimension f = 8192, 48 layers and 16 heads per layer
  • dropout: 0.1
  • token embeddings tied with the final output embedding layer
  • batch size: 1024
  • 800k iterations
  • Adagrad with a linear warmup from 0 to 0.05 over 25k steps
  • norm of gradients clipped to 0.25

Controllable Generation

Sampling

从语言模型生成文本时一般会用到 temperature-controlled stochastic sampling 方法,同时,每次生成 token 时在 top-k(而不是所有词表)中取。

pi=exp(xi/T)jexp(xj/T)p_{i}=\frac{\exp \left(x_{i} / T\right)}{\sum_{j} \exp \left(x_{j} / T\right)}

  • T -> 0 近似贪婪分布,放大了峰值

  • T -> ∞ 使得分布更加平坦

k 是启发式的(自适应),xi 是每个 token 的 score;如果下个词的 confidence 比较高,k 就小一些。

在有多个非零的高概率候选 token 时,不采用模型,而是 “贪婪” 地选择下一个 token。对可能会产生的重复 token,文章提出一种新的 sample 方法,既能够近似贪婪 sampling,又能够对重复进行惩罚。惩罚的方法是对已产生的 tokens 进行打折(不在训练中使用),给定一列生成的 tokens g:

pi=exp(xi/(TI(ig))jexp(xj/(TI(jg))I(c)=θ if c is True else 1p_{i}=\frac{\exp \left(x_{i} /(T \cdot I(i \in g))\right.}{\sum_{j} \exp \left(x_{j} /(T \cdot I(j \in g))\right.} \quad I(c)=\theta \text { if } c \text { is True else } 1

θ≈1.2 能够取得不错的平衡。

Control Codes

  • Style by domain: Wiki,Books,Reviews,Horror,Relationships,Legal
  • More complex control codes:
    • Science Title, Politics Title, Running Text, Horror Text, Reviews Rating
    • 不同的 Link 代表不同的特征(domain, subdomain, entities, entity relations, and even dates)
  • Triggering specific tasks:问答、翻译
  • Zero-shot code-mixing

详细可参考文中的 sample。

Source Attribution

根据之前的定义,给定 domain control code 的先验 p©,有:

pθ(cx)pθ(xc)p(c)p_{\theta}(c | x) \propto p_{\theta}(x | c) p(c)

为了避免带来的不良影响,文中采用统一的先验。

模型固有地依赖于原始的关联进行预测, 它并不关心关联是否正确或好坏(事实表明,相互矛盾的主张经常出现在相同的上下文中)。CTRL 证明了特定的领域更可能包含与给定陈述相似的语言。

Others

  • Related Work
    • 语言模型:词向量、上下文词向量、注意力机制
    • 多任务学习
    • Sampling 方法和覆盖机制:聚焦在减少重复(替换为连贯的文本)
  • Future directions
    • 更多和更细粒度的控制
    • 扩展到 NLP 的其他领域
    • 分析语言模型和语料的关系
    • 使人与语言模型的接口更加明确和直观

Practice

Usage

  • 安装依赖:Tensorflow1.14 或 PyTorch,glample/fastBPE: Fast BPE

  • 使用 Tensorflow 需要修复 keras.pypatch -b <path_to_tensorflow_estimator_package>/python/estimator/keras.py estimator.patch

  • 获取模型:gs://sf-ctrl/seqlen256_v1.ckpt/ or gs://sf-ctrl/seqlen512_v1.ckpt/ or gs://sf-ctrl/seqlen256_36layers_v0.ckpt/.

    • gsutilgsutil -m cp -r gs://sf-ctrl/seqlen256_v1.ckpt/ .
    • 没有:根据这个链接,要翻墙,且记得创建一个文件夹,比如:mkdir -p seqlen256_v1.ckpt。用 wget 貌似不用翻墙也可以直接下载。注意:模型有十几个 G。
  • 运行:generation.pysource_attribution.py

    • generation.py 提示用户输入文本,然后生成接下来的部分
    • source_attribution.py 提示用户输入文本,然后列出排序的 domain 和文本基于 domain 的 ppl
  • 精调:

    • 修复 keras.py
    • 获取数据并转为 TFRecords:python make_tf_records.py --text_file YOUR_FILE --control_code YOUR_CTRL_CODE --sequence_len 256CTRL_CODE 是一个词表中的 token,会被添加到每条数据中;sequence_len 需要和训练时的一样。
    • 训练:python training.py --model_dir <path_to_model>.ckpt/ --iterations <number_of_iterations>,如果数据很少,可以调低 iterations
    • 生成:python training.py --model_dir seqlen256_v1.ckpt/ --iterations 250

Code

source_attribution.py 比较简单,导进来模型后,根据输入的文本循环计算不同 domain 的 ppl,然后从大到小排序输出。代码是这样子:

1
2
3
4
5
6
7
8
9
10
11
12
13
ppls = {}
xent = 0
# 计算其中一个 domain 的 ppl
# text[1:] 把 domain 去掉
for sequence_idx, token_idx in enumerate(text[1:]):
token = idx2word[token_idx]

# compute the probability of this token
Z = np.exp(token_scores[sequence_idx]).sum()
token_prob = np.exp(token_scores[sequence_idx, token_idx]) / Z
xent -= np.log(token_prob) / len(text[1:])

ppls[domain] = round(np.exp(xent), 6)

ppl 用了下面这个公式(这里底数用了自然对数):

PP(S)=21Nlog(P(wi))P P(S)=2^{-\frac{1}{N} \sum \log \left(P\left(w_{i}\right)\right)}

generation.py 是生成的代码,有几个需要注意的:

  • 如果输入的 token 超过了 seq_length(如上面的 256),则划过前面若干个 token 让剩下的在范围内:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    # text 为输入的文本 ids
    for token in range(len(text)-1, args.generate_num-1):
    # get the logits from the prediction function
    # the logic here is a bit convoluted because we are allowing generation past 512 tokens
    # this is done by sliding the window over (past 512 tokens) and continuing prediction
    if token <= seq_length:
    prompt_logits = predict_fn({'input_1':tokens_generated[:, :seq_length]})['tied_embedding_softmax'].squeeze() / (temperature if temperature>0 else 1.)
    _token = token if token < seq_length else -1
    else:
    _token = -1
    # slide 在这里
    end = token + 1
    start = token - seq_length + 2
    prompt_logits = predict_fn({'input_1':np.hstack((tokens_generated[:,0:1], tokens_generated[:,start:end]))})['tied_embedding_softmax'].squeeze() / (temperature if temperature>0 else 1.)
  • 计算 logits 使用了 temperature-controlled stochastic sampling 方法

  • 使用了惩罚:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    # if penalty (for repetition) is non-zero,
    # discount the logits from already generated tokens
    if penalty>0:
    penalized_so_far = set()
    for _ in range(token+1):
    generated_token = tokens_generated[0][_]
    # don't penalize newlines
    # you could also choose not to penalize frequent words
    # (which incidentally are sorted in the vocab file)
    # but I don't do that
    # if it prints too many new lines instead of continuing generating text,
    # you might want to comment this out
    if idx2word[generated_token] == '\n':
    continue
    if generated_token in penalized_so_far:
    continue
    penalized_so_far.add(generated_token)
    prompt_logits[_token][generated_token] /= penalty
  • 选择 next token 时有三种方法:top-k,启发式 k,所有的 token

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    # compute probabilities from logits
    prompt_probs = np.exp(prompt_logits[_token])
    prompt_probs = prompt_probs / sum(prompt_probs)
    # 从小到大排列再反转
    pruned_list = np.argsort(prompt_probs)[::-1]
    # if you are using nucleus prob, then compute the nucleus probability size
    if nucleusprob > 0.:
    minimum_topk = 1
    nucleus = max(np.where(np.cumsum(np.sort(prompt_probs)[::-1])>nucleusprob)[0][0], minimum_topk)
    elif topk > 0:
    # we are over-loading notation here
    # if you choose to specify a topk instead of a nucleus,
    # we will hardcode the nucleus to be just that
    nucleus = topk
    else:
    # if you specify neither nucleus or topk,
    # then we will use the whole list
    nucleus = len(pruned_list)

    pruned_list = pruned_list[:nucleus]

最后也是最重要的——模型: transformer.py

  • Encoder: lookup -> sqrt -> add position -> Dropout -> EncoderLayer -> LayerNormalization

  • EncoderLayer:

    • Block1: LayerNormalization -> MultiHeadAttention -> Dropout -> +x = out1
    • Block2: out1 -> LayerNormalization -> FFN -> Dropout -> +out1 = out2
  • MultiHeadAttention: v,k,q = LayerNormalization(x) -> scaled_dot_product_attention

精调的 train.py 是在已有的模型基础上 train 的,不是特别复杂,在 training_utils 下面,这里略过了。

Summary

相信通过以上内容能够对论文和代码有个基本的了解,论文其实还好,代码的细节还是不少的(比如上面提到的),相比而言 model 本身倒是没有什么特别的。如果要迁移到中文,Input 需要做一些调整,再就是需要从头到尾训模型。作者已经发布的 256 seqlen 的有十几个 G,有 GPU 或 TPU 的可以尝试跑一个。

本文最大的特色体现在 “CTRL”,使用 control code 控制文本生成,控制代码可以是主题、实体、关系、特定任务等等。其实它的本质与之前的 Bert 类似:多任务 + 语言模型;这里的多任务可以看作是一个多分类任务。不过本文的切入角度是 “控制文本生成”,虽然是以类别标签的方式,但不得不说这是一个不错的创新点。