V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
Richard14
V2EX  ›  程序员

深度学习中位置编码的本质是不是就是一层 nn.Parameter()而已?

  •  
  •   Richard14 · 2022-07-25 19:44:01 +08:00 · 1377 次点击
    这是一个创建于 854 天前的主题,其中的信息可能已经有所发展或是发生改变。

    近日阅读 Bert 源码,读到其中所谓“可训练式”的位置编码的部分,似乎具体实现就是初始化一个可训练的,然后把它加到输入上,这就算是可训练的位置编码了。

    def __init__(self):
        super()
        positional_encoding = nn.Parameter()
    def forward(self, x):
        x += positional_encoding
    

    我不太能确定是否就是这么简单,还是我理解错了。

    另外我注意到按照经典 Bert 的实现:

    BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(...)
          ...
      (pooler): BertPooler(...)
    

    实际上它只进行了一次 positional_embedding ,也就是在输入的时候,这是否意味着实际上所谓的能把握序列信息的也只有第一层,后面的 bert 层是无法识别位置信息的?

    但是这么说的话,按照我上一个帖子的内容 https://v2ex.com/t/868398#reply5 ,bert 的输出根据设置一般有两部分,一部分是类似原始输入序列长度对应长度的数组,表示每个字经过特征提取后的消息,另一部分是一个表示全句子整体信息的层。理论上前者应该与输入是一一对应的,但是它都不能捕捉位置,这不是很矛盾吗

    7 条回复    2022-07-27 07:32:13 +08:00
    ox180
        1
    ox180  
       2022-07-26 08:59:35 +08:00
    1 、position_embeddings 是把句子的位置信息加进去,self-attention 是为了学习语义信息。如果不管句子顺序如何,后面 self-attention 输出都一样。另外 position_embeddings 是 bert 的做法,transformer 是用的正余弦编码,关于位置编码,你还可以了解下绝对位置编码和相对位置编码。

    2 、bert 输出有两个,一个是[CLS]位置所代表的句子信息,其实你看了源码就知道,这个无非是多增加其他层进行训练了罢了,你可以不用,直接用 seq_output 。捕捉位置在最开始 position_embeddings 已经学到了。
    ecwu
        2
    ecwu  
       2022-07-26 11:41:45 +08:00   ❤️ 1
    - 位置编码在输入时加在了词嵌入中,模型里的 Transformer Block 都有残差链接,这样位置的信息也可以传递到后面的层,被后面的层“把握”。

    - 输出的“整体信息”和每个输入 token 的 embedding ( embedding 也就是你说的特征提取后的信息)都在一个输出层上。一般认为插入在句子输入最前面的 [CLS] token 对应的 embedding 包含了后面输入句子的全部信息,这里的原因是在 BERT 的 NSP 预训练任务时,会拿 [CLS] 位置的 embedding 来预测输入的两句话的先后关系,这样 Self-Attention 的过程就会把后面的句子的信息集中到 [CLS] 的位置的 embedding 中。所以加入的 CLS token 并不是说人为加入了一个全局信息。

    - 如果你要把 BERT 用在自己的回归任务上,可以只将预训练的 BERT 当作一个获取词嵌入的工具。也就是在 BERT layer 的输出给到回归任务的输入。但具体用 BERT layer 的全局 embedding ([CLS] 位置输出),还是取输入 token embedding 的平均,都可以尝试。
    Richard14
        3
    Richard14  
    OP
       2022-07-27 00:22:57 +08:00
    @ecwu 好的,我原先不是很确定 bert 的输出含义,比如 bert layer*n 结束后它输出是一个 cls+n 个 token 的特征信息,有一种模糊的感觉是这样但是不能确定,网上信息里对大体原理讲述的比较多,涉及具体行为的比较少,尤其涉及具体预训练细节的几乎没有。

    按照你的说法我的一个想法是,如果 bert 的输出可以认为是一个高级版 word2vec 的话,所有 token embedding 取平均感觉逻辑上不太能说得通,也许我应该测试在输出结束后再接一层 rnn 之类的。。如果不接 rnn 的话,是不是应该尝试将结果再进行位置编码再进入 mlp ,因为 bert 输出的 token 有应该不是前后顺序完全不影响的吧。。是不是还是应该有位置因素
    ecwu
        4
    ecwu  
       2022-07-27 00:59:17 +08:00 via Android   ❤️ 1
    @Richard14 你可以理解 BERT 给出的 embedding 是高级版 w2v (严谨点是叫 contextual word embedding ,也就是同一个词,在不同的上下文里,embedding 是不同的,不同于 w2v 或者 GloVe 学习完就是固定的)

    取平均来获得输入的全局的表示确实会损失隐式信息,但是 CLS 位置 embedding 是通过 self-attention 获得的,本质上就是对 token embedding 的加权平均。所以用 CLS 还是取平均,需要看具体的任务是干什么。

    如果你是对输入句子做分类或输出浮点数,你可以考虑直接拿 CLS 位置的 embedding 给到 MLP 。如果是继续生成内容,可以去了解下 Seq2seq 架构。

    最后你提到的 RNN 或者 MLP + 位置编码的想法。我个人认为 RNN 可以尝试。而 MLP 方案,你的输入会过于巨大( 768 * token 长度),不太可行。
    Richard14
        5
    Richard14  
    OP
       2022-07-27 06:26:30 +08:00
    @ecwu 谢谢,很有帮助,确实是太大了
    Richard14
        6
    Richard14  
    OP
       2022-07-27 06:32:09 +08:00
    @ecwu 对了大佬我还想问一下关于预训练,因为我的文本是脱敏的没法直接用成品我需要自己训练,我没太搞懂多任务训练训练是实践上怎么结合起来的。它原论文有 mlm 和 nsp 两种方式,正确的做法是比如我先构建模型和对应 mlm 的输出,把它训练到类似收敛,然后再把输出层换成 nsp 的再重新训练到收敛,这样先后训练算是经过两个预训练吗,还是说它有什么交替训练的办法。如果有顺序的话会不会导致结果差异
    ecwu
        7
    ecwu  
       2022-07-27 07:32:13 +08:00 via Android   ❤️ 1
    @Richard14 不同预训练任务是替换不同的输出层,这里你可以参考下原论文。预训练任务的顺序会导致模型效果的差异。

    使用 HuggingFace 来训练自己的模型可以参考 https://stackoverflow.com/questions/65646925/how-to-train-bert-from-scratch-on-a-new-domain-for-both-mlm-and-nsp
    关于   ·   帮助文档   ·   博客   ·   API   ·   FAQ   ·   实用小工具   ·   3096 人在线   最高记录 6679   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 31ms · UTC 14:11 · PVG 22:11 · LAX 06:11 · JFK 09:11
    Developed with CodeLauncher
    ♥ Do have faith in what you're doing.