问题 (Problem)

传统的大型语言模型(LLMs)在少样本学习(few-shot learning)上表现出色,但这通常依赖于巨大的参数量来存储世界知识。这引发了一个核心问题:强大的少样本学习能力是否必须与庞大的模型参数(即内置记忆)绑定?

这篇论文旨在探讨是否可以将模型的“记忆”(知识存储)与“推理”(泛化能力)解耦。作者假设,通过将知识存储外包给一个外部的、可检索的知识库,模型可以将更多参数用于学习推理和泛化能力,从而在拥有较少参数的情况下,在知识密集型任务(如问答、事实核查)上实现卓越的少样本学习性能。

本文的目标是设计并训练一个精心构建的检索增强语言模型——ATLAS,验证其在知识密集型任务上,仅用少量样本就能超越巨大参数量模型的潜力。


方法 (Method)

ATLAS 遵循一个统一的“文本到文本”(text-to-text)框架,其中所有任务都被建模为:输入一个文本查询(query),生成一个文本输出(output)。其核心是一个由**检索器(Retriever)语言模型(Language Model)**组成的双模块架构。

模型架构 (Architecture)

当处理一个任务时,模型首先使用检索器从大规模语料库中召回 top-k 个最相关的文档,然后将这些文档连同原始查询一起输入语言模型,最终生成答案。

  • 检索器 (Retriever):基于 Contriever 模型,这是一个使用双编码器(dual-encoder)架构的稠密检索器。

    • 它使用一个 Transformer 编码器独立地将查询(query)和文档(document)编码成向量(通过对最后一层输出进行平均池化)。
    • 查询和文档之间的相关性分数通过计算它们向量表示的点积得到。
    • Contriever 的一个优点是它可以通过无监督的对比学习进行预训练。
  • 语言模型 (Language Model):基于 T5 (Seq2Seq) 架构,并采用了 Fusion-in-Decoder (FiD) 的思想。

    • 处理方式:为了高效处理多个文档,模型将原始查询与每个检索到的文档拼接,然后让编码器(Encoder)独立地处理每一个 [query, document] 对。
    • 融合方式:所有文档经过编码器后的表征(representations)被拼接(concatenate)成一个长序列。解码器(Decoder)在这个长序列上执行交叉注意力(cross-attention),从而在生成答案时融合所有文档的信息。
    • 优势:这种方式避免了将所有文档拼接成一个超长序列输入编码器,从而绕开了 Transformer 自注意力机制的二次方复杂度问题,扩展性更好。

下图清晰地展示了 ATLAS 的工作流程:
Figure1
图1 解读:该图展示了 ATLAS 在预训练和少样本微调两个阶段的工作模式。

  • 预训练 (Pretraining):以“遮盖语言建模 (Masked Language Modelling)”为例,模型输入一个带 <MASK> 的句子,然后从知识库中检索相关文档(例如关于“百慕大三角”的传说),最终目标是预测出被遮盖的词(“western part”)。
  • 少样本学习 (Few-shot):以下游任务为例,如“事实核查”或“问答”。模型接收一个输入(如一个待核查的陈述或一个问题),检索相关文档,并基于输入和文档内容生成最终输出(“False” 或 “Western part of the North Atlantic Ocean”)。

检索器的训练目标 (Training Objectives for the Retriever)

本文的核心创新之一在于如何联合训练检索器和语言模型,特别是如何让语言模型为检索器提供监督信号,而无需人工标注的“查询-相关文档”对。论文探讨了四种损失函数:

  1. 注意力蒸馏 (Attention Distillation, ADist)

    • 思想:语言模型解码器在生成答案时,对不同文档的交叉注意力分数可以被视为该文档重要性的代理指标。
    • 改进:传统的注意力蒸馏只考虑注意力权重 αn\alpha_n。本文提出,一个词元(token)的实际贡献度应同时考虑其注意力权重和其值向量(value vector)的范数,即 αnvn2\alpha_n |v_n|_2
    • 损失函数:将上述贡献度在所有头、层和词元上聚合,为每个文档计算一个总分,并通过 Softmax 得到目标概率分布 pATTNp_{ATTN}。然后,最小化 pATTNp_{ATTN} 与检索器输出的概率分布 pRETRp_{RETR} 之间的 KL 散度。
    • 公式
      • 检索器分布:pRETR(dq)=exp(s(d,q)/θ)k=1Kexp(s(dk,q)/θ)p_{RETR}(d|q) = \frac{\exp(s(d,q)/\theta)}{\sum_{k=1}^{K}\exp(s(d_{k},q)/\theta)}
      • 损失函数:L=KL(pATTNpRETR)=k=1KpATTN(dk)log(pATTN(dk)pRETR(dk))\mathcal{L} = KL(p_{ATTN} || p_{RETR}) = \sum_{k=1}^{K} p_{ATTN}(d_k) \log \left( \frac{p_{ATTN}(d_k)}{p_{RETR}(d_k)} \right)
    • 注意:在计算时,会对 pATTNp_{ATTN} 施加 STOP_GRADIENT,确保损失只用于更新检索器参数。
  2. 端到端训练 (EMDR²)

    • 思想:受期望最大化(EM)算法启发,将被检索的文档视为隐变量(latent variables)。
    • 损失函数:最大化生成正确答案 aa 的边际对数似然。
    • 公式L=log[k=1KpLM(aq,dk)pRETR(dkq)]\mathcal{L} = \log\left[\sum_{k=1}^{K}p_{LM}(a|q,d_{k})p_{RETR}(d_{k}|q)\right]
    • 注意:同样对语言模型概率 pLMp_{LM} 施加 STOP_GRADIENT
  3. 困惑度蒸馏 (Perplexity Distillation, PDist)

    • 思想:训练检索器去预测“哪个文档能最大程度地降低语言模型生成正确答案时的困惑度(Perplexity)”。
    • 损失函数:为每个文档 dkd_k 计算语言模型生成答案的对数概率 logpLM(adk,q)\log p_{LM}(a|d_k, q),将其通过 Softmax 归一化后作为目标分布 yky_k。然后最小化 yky_kpRETRp_{RETR} 的 KL 散度。
    • 公式
      • 目标分布:yk=exp(logpLM(adk,q))i=1Kexp(logpLM(adi,q))y_{k} = \frac{\exp(\log p_{LM}(a|d_{k},q))}{\sum_{i=1}^{K}\exp(\log p_{LM}(a|d_{i},q))}
      • 损失函数:L=KL(ypRETR)\mathcal{L} = KL(y || p_{RETR})
    • 结论:实验表明 PDist 效果好、计算稳定,被选为最终模型使用的目标。
  4. 留一法困惑度蒸馏 (LOOP)

    • 思想:一个文档的重要性体现在“当从上下文中移除该文档时,模型预测的性能会下降多少”。
    • 损失函数:对于每个文档 dkd_k,计算在缺少它的情况下(即使用其他 K1K-1 个文档)生成答案的负对数概率 logpLM(aDKdk,q)-\log p_{LM}(a|\mathcal{D}_K \setminus {d_k}, q) 作为其重要性分数。将这些分数通过 Softmax 归一化得到目标分布 pLOOPp_{LOOP},并最小化其与 pRETRp_{RETR} 的 KL 散度。
    • 公式
      • 目标分布:pLOOP(dk)=exp(logpLM(aDK\dk,q))i=1Kexp(logpLM(aDK\di,q))p_{LOOP}(d_{k}) = \frac{\exp(-\log p_{LM}(a|\mathcal{D}_{K}\backslash{d_{k}},q))}{\sum_{i=1}^{K}\exp(-\log p_{LM}(a|\mathcal{D}_{K}\backslash{d_{i}},q))}
    • 特点:计算成本更高,但其评估方式(使用 K-1 个文档)更接近语言模型的实际工作状态。

预训练任务 (Pretext Tasks)

为了让模型学习到如何利用检索到的知识,研究者们设计了三种无监督的联合预训练任务:

  1. 前缀语言建模 (Prefix LM):将一段文本一分为二,用前半段作为查询,预测后半段。
  2. 遮盖语言建模 (Masked LM):T5 风格的任务,随机遮盖文本中的一些片段(span),用带遮盖的文本作为查询,预测被遮盖的内容。实验表明这是效果最好的任务
  3. 标题到章节生成 (Title to section generation):用维基百科文章的标题和章节标题作为查询,生成该章节的内容。

高效的检索器微调 (Efficient Retriever Fine-tuning)

在下游任务上微调时,如果检索器的文档编码器(document encoder)更新了,整个知识库的索引(数百万甚至上亿个文档的向量)都需要重新计算,这非常耗时。为此,论文提出了三种策略:

  1. 全量索引更新 (Full index update):每隔 R 个训练步就重新计算一次整个索引。开销大。
  2. 重排序 (Re-ranking):用旧的(stale)索引先检索出 L (L > K) 个候选文档,然后用更新后的文档编码器只对这 L 个文档重新编码和排序,取 top-K 送给语言模型。开销显著降低
  3. 仅查询侧微调 (Query-side fine-tuning):在微调时,冻结文档编码器的参数,只更新查询编码器(query encoder)。这样文档索引就无需更新。实验证明,这种方法在少样本场景下效果非常好,甚至能防止过拟合,是少样本微调的首选策略。

Baseline (对比模型)

  • 闭卷模型 (Closed-book T5):使用相同规模的 T5 模型,但不进行检索。这是证明检索增强有效性的关键对照组。
  • 其他大型语言模型:在少样本评测中,与 GPT-3 (175B)、Gopher (280B)、Chinchilla (70B)、PaLM (540B) 等参数量远超 ATLAS (11B) 的模型进行比较。
  • 其他检索增强模型:在全数据集(full-dataset)设定下,与 FiD、SEAL 等先进的检索模型进行比较。

数据集 (Datasets)

  • 知识库/预训练语料:2021年12月的英文维基百科快照(3700万个段落)和 Common Crawl (CCNet) 网络文本(3.5亿个段落)。
  • 评测基准
    • MMLU:包含57个领域的人类考试多项选择题,用于评估模型的广博知识和推理能力。
    • KILT:包含11个知识密集型任务的数据集集合,如 Natural Questions (NQ), TriviaQA, FEVER, Wizard of Wikipedia (WoW) 等。
    • 原始开放域问答集:原始版本的 NQ 和 TriviaQA。
    • TempLAMA:一个专门构建的、用于测试模型对时间敏感信息处理能力的数据集,其中问题的答案会随时间变化(例如,2017年和2020年的答案不同)。

可复现性 (Reproducibility)

  • 代码与模型:论文明确指出,代码、预训练的 ATLAS 模型检查点和相关数据都在 GitHub 上开源,可复现性非常高。
    • GitHub: facebookresearch/atlas
  • 算力与配置
    • 模型规模:从 770M 到 11B 参数不等,主要报告的是 11B 模型的结果。
    • 硬件需求:训练和运行 ATLAS 需要相当大的计算资源。特别是其知识库索引,即使在半精度(fp16)下,维基百科索引也需要 49GB 的 GPU 显存,而混合索引则需要 587GB。
    • 解决方案:论文展示了通过乘积量化 (Product Quantization) 技术可以大幅压缩索引。例如,混合索引可以从 587GB 压缩到 50GB,而下游任务性能下降很小,这使得在单张 80GB 的 GPU 上部署成为可能。

图4 解读:该图展示了在 64-shot NQ 任务上,索引大小与模型性能的关系。
Figure4

  • 上排:维基百科 + CC 混合索引。下排:仅维基百科索引。
  • 左列:检索召回率 (Recall@50)。右列:问答准确率 (Exact Match)。
  • 结论:可以将索引压缩一个数量级(例如从 ~500GB 压缩到 ~50GB,或从 ~50GB 压缩到 ~4GB),而召回率和最终准确率几乎没有损失。只有在极度压缩时,性能才会显著下降。

可改进的几个点 (Potential Improvements)

  1. 更复杂的推理模式:当前模型对于一次检索就能解决的问题表现优异,但对于需要**多步推理(multi-hop reasoning)**的复杂问题,一次性检索可能不足。未来的工作可以探索迭代式检索和推理的框架。
  2. 检索与语言模型的融合方式:Fusion-in-Decoder 是一种有效的融合策略,但信息主要在解码器层面融合。探索更深层次的融合机制,让检索到的信息在模型的每一层都与原始查询进行交互,可能会带来性能提升。
  3. 索引内容与结构:目前的索引是扁平的段落集合。引入结构化知识(如知识图谱)或不同粒度的文本(句子、篇章)可能会让检索更精准。
  4. 训练信号的探索:尽管论文提出了四种为检索器提供监督信号的方法,但它们之间的性能差异不大。这表明可能还存在更优、更直接的训练范式有待发掘。

可以被引用的一些结论 (Key Takeaways / Citable Conclusions)

  1. 少样本学习的核心结论:通过检索增强,模型可以在参数量远小于传统 LLMs 的情况下,在知识密集型任务上取得SOTA(state-of-the-art)的少样本学习效果。记忆可以被有效外包
  2. 惊人的性能数据
    • Natural Questions (64-shot) 任务上,11B 参数的 ATLAS 准确率超过 42%,比 540B 参数的 PaLM 高出 3%,而参数量仅为其 1/50
    • TriviaQA (64-shot) 任务上,ATLAS 取得了 84.7% 的高准确率。
  3. 联合预训练的重要性检索器和语言模型的联合预训练是实现强大少样本能力的关键。只在微调阶段引入检索的模型性能远不如联合预训练过的模型。
  4. 模型的可更新性 (Updatability):ATLAS 展示了卓越的可更新能力。通过简单地更换知识库索引(例如,从2017年维基百科换成2020年),模型无需重新训练就能正确回答与时间相关的问题,这是纯参数模型无法做到的。
  5. 高效微调策略:在少样本微调时,仅更新查询编码器 (Query-side fine-tuning) 是一种计算高效且性能优异的策略,它避免了昂贵的索引重建。
  6. 可解释性与实用性
    • 通过分析检索到的文档,可以理解模型的决策依据。例如,在 MMLU 任务上,当正确答案在检索文档中出现次数越多,模型准确率越高。
    • 通过索引压缩技术,可以显著降低模型的部署成本,使其更具实用价值。