上下文窗口(context window)是指语言模型在进行预测或生成文本时,所考虑的前一个词元(token)或文本片段的大小范围。
在语言模型中,上下文窗口对于理解和生成与特定上下文相关的文本至关重要。较大的上下文窗口可以提供更丰富的语义信息、消除歧义、处理上下文依赖性,并帮助模型生成连贯、准确的文本,还能更好地捕捉语言的上下文相关性,使得模型能够根据前文来做出更准确的预测或生成。
最新发布的语言大模型的上下文窗口越来越大。本文详细探讨了大型上下文窗口的技术可能性,尤其分析了将上下文长度增加到100K背后的六大优化技巧。本文作者Galina Alperovich是英国数据安全服务商Soveren的机器学习负责人。
(以下内容由OneFlow编译,转载请联系OneFlow获得授权。来源:https://medium.com/gopenai/how-to-speed-up-llms-and-use-100k-context-window-all-tricks-in-one-place-ffd40577b4c )
最近有几个新的语言大模型(LLM)发布,这些模型可以使用非常大的上下文窗口,例如65K词元(MosaicML的MPT-7B-StoryWriter-65k+)和100K词元的上下文窗口(Antropic)。在Palm-2技术报告中,谷歌并没有透露具体上下文大小,但表示他们“显著增加了模型的上下文长度”。
相比之下,当前GPT-4模型可以使用32K输入词元的上下文长度,而大多数开源LLM的上下文长度为2K词元。
如此大的上下文长度意味着提示(prompt)可以达到一本书的大小。《了不起的盖茨比》有72K词元,210页,按1.7分钟/页的阅读速度计算,需要6小时的阅读时间。因此,模型可以扫描并保留此数量的“自定义”信息来处理查询!
我想要弄清楚大型上下文窗口的技术可能性。本文搜集了一些零散信息,内容如下:
-
为何上下文长度如此重要,且能在LLM中起到举足轻重的作用? -
处理大型上下文长度时,原始Transformer架构的主要局限性是什么? -
Transformer架构的计算复杂度 -
目前有哪些可以加速Transformer并将上下文长度增加到100K的优化技术?
重点概览
-
第一个问题是注意力层(attention layer)计算的二次方时间(Quadratic time)和空间复杂度,即输入词元数量n。
-
当嵌入大小d>n时,第二个问题是嵌入大小d的线性层的二次方时间复杂度。
-
第三个问题是原始架构中使用的位置正弦嵌入(Positional Sinusoidal Embedding )。
-
在Transformer架构中,可学习(learnable)矩阵权重的形状与输入词元n的数量无关。
-
因此,在2K上下文长度中训练的Transformer可以使用任意长度的词元,甚至是100K词元。但如果不是在100K词元上训练出来的,那么该模型在100K词元的推理过程中不会产生有意义的推理结果。
-
由于n、d相关的二次复杂度,在巨型语料库上训练Vanilla Transformer,并且只在较大的上下文长度上训练是不可行的。据估计,在2K上下文长度上训练LLaMA的费用约为300万美元,因此,100K的花费约为1.5亿美元。
-
一种选择是,可以在2K词元上下文中训练模型,然后在更长的上下文词元(例如65K)中微调。但由于位置正弦编码(Positional Sinusoidal Encoding)的存在,这不适用于原始Transformer模型。
-
[技巧1] 为解决此问题,可删除位置正弦编码并使用ALiBi,这一简单位置嵌入不会影响准确性。然后可以在2K词元上训练,在100K词元上微调。
-
[技巧2] 无需计算所有词元间的注意力分数(attention scores)。某些词元比其他词元更重要,因此可使用稀疏注意力。这将提升训练和推理速度。
-
[技巧3] Flash Attention有效地实现了GPU的注意力层。它使用切片(tiling)技术,避免生成不适合GPU SRAM容量的大型中间矩阵(n,n)。这将提升训练和推理速度。
-
[技巧4] 选择多查询注意力(Multi-Query attention),而非多头注意力。这意味着线性投影K和V时,可在跨所有注意力头(head)中共享权重。这极大地加快了增量(incremental)推理速度。
-
[技巧5] 条件计算(Conditional computation)避免将所有模型参数应用于输入序列中的所有词元。CoLT5仅对最重要的词元应用重量级计算,并使用较轻量级的层处理其余词元。这将加速训练和推理。
-
[技巧6] 为适应大型上下文,需要GPU中有大量RAM,因此人们使用80GB的A100 GPU。
为何上下文长度如此重要?
-
尝试总结技巧和复杂的链式提示。
-
维护向量数据库以保留自定义文档的嵌入,然后通过相似性指标在它们之间展开“搜索”。
-
尽可能使用自定义数据微调LLM(并非所有商业LLM都允许自定义微调,对开源LLM进行自定义微调并不常见)。
-
为特定数据开发定制小型LLM(同样,这并非常规任务)
原始Transformer和上下文长度
多头注意力回顾
多头注意力(Multi-Head Attention)
1. 我们有一个查找嵌入层,用于接收词元作为输入,并返回大小为(1,d)的向量。因此,对于一个由n个词元组成的序列,我们得到大小为(n,d)的文本嵌入矩阵X,然后将其与位置正弦嵌入相加。
2. 多头注意力层旨在为词元序列计算新的嵌入表示,该词元序列可以被视为对原始文本编码X,但需要,(1)根据词元间相对于上下文的重要性进行加权,(2)根据词元的相对位置进行加权。
3. 我们使用h个注意力头对嵌入矩阵X(n×d)进行并行处理。为了使所有的注意力头都得到Q、K和V,我们需要对X进行线性投影,将其分别投影到k、k和v维度。为此,可以通过将X分别与形状为(d,k)、(d,k)和(d,v)的h个矩阵相乘来实现。你可将其理解为用(n,d)乘以(h,d,k)、(h,d,k)和(h,d,v)。
4. 注意力头返回大小为(n,v)的h个注意力分数矩阵。然后,我们将来自所有注意力头(n,h*v)的片段进行连接,并对其进行线性投影,为后续步骤做准备。
缩放点积注意力(Scaled Dot-Product Attention)
Transformer的复杂度和上下文长度
1. 线性投影得到Q,K,V:大小为(n,d)的嵌入矩阵乘以h个可学习矩阵(d,k),(d,k)和(d,v)。因此,复杂度约为O(nd²)
2. 将Q与变换后的K相乘,然后再乘以V:(n,k)*(k,n)=(n,n),以及(n,n)*(n,v)=(n,v)。复杂度约为O(n²d)。
假设token数量为n时,注意力的复杂度为O(n²d + nd²),需要进行M次迭代来进行训练。如果我们将上下文长度从n增加到p*n,由于上下文长度变大,所需的迭代次数将变为M/p(这里简单假设它是线性的,实际情况可能会高点或低点,具体取决于任务)。现在我们有两个方程式: (1)n的复杂度为M * (n²d + nd²) (2)pn的复杂度为M/p * ((pn)²d + (pn)d²) 经过一系列简化和除法,得到比值(2)/(1)的近似为 (d + p*n)/(d + n)。 如果 d << n,将n增加p倍将导致迭代次数增加约p倍。 如果 d ~ n,将n增加p倍将导致迭代次数增加约p/2倍。
Transformer训练阶段和推理阶段的区别
增加上下文长度的优化技术
[技巧1] 更好的位置编码——ALiBi
[技巧2] 稀疏注意力机制
[技巧3] FlashAttention——用于GPU的注意力层高效实现
1. S = Q*K
2. P = softmax(S)
3. O = P*V
[技巧4] 多查询注意力(Multi-Query Attention,MQA)
[技巧5] 条件计算
[技巧6] 大型内存GPU
结论
-
[1] Introducing 100K Context Windows by Antropic (https://www.anthropic.com/index/100k-context-windows ) -
[2] MPT-7B by MosaicML (https://www.mosaicml.com/blog/mpt-7b ) -
[3] Palm-2 Technical report by Google (https://ai.google/static/documents/palm2techreport.pdf ) -
[4] ALiBI: Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation (https://arxiv.org/abs/2108.12409 ) -
[5] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (https://arxiv.org/abs/2205.14135 ) -
[6] Multi-Query attention: Fast Transformer Decoding: One Write-Head is All You Need (https://arxiv.org/pdf/1911.02150.pdf ) -
[8] Attention is All You Need (https://arxiv.org/abs/1706.03762 ) -
[9] Video on Positional Sinusoidal Embedding (https://www.youtube.com/watch?v=dichIcUZfOw&ab_channel=HeduAI ) -
[10] Overview of the FlashAttention paper (https://shreyansh26.github.io/post/2023-03-26_flash-attention/ ) -
[11] Sliding Window Attention (https://paperswithcode.com/method/sliding-window-attention ) -
[12] Constructing Transformers For Longer Sequences with Sparse Attention Methods (https://shreyansh26.github.io/post/2023-03-26_flash-attention/ ) -
[13] FlashAttention implementation in Triton language(file:///C:/Users/Administrator/Desktop/%E4%B8%8B%E7%8F%AD%E4%BA%A4.docx#L584 ) -
[14] How to Accelerate HuggingFace Throughput by 193% with Triton and ClearML (https://clear.ml/blog/increase-huggingface-triton-throughput-by-193/ ) -
[15] ClearML Serving (https://github.com/allegroai/clearml-serving ) -
[16] Analyzing the Pros and Cons of NVIDIA Triton Inference Server vs. Other Inference Engines (https://ts2.space/en/nvidia-triton-inference-server-vs-other-inference-engines-which-is-best-for-your-project/ ) -
[17] COLT5: Faster Long-Range Transformers with Conditional Computation (https://arxiv.org/pdf/2303.09752.pdf ) -
[18] LongT5: Efficient Text-To-Text Transformer for Long Sequences(https://arxiv.org/abs/2112.07916 ) -
[19] PaLM(https://arxiv.org/pdf/2204.02311.pdf ) -
[20] BigBird attention mechanism (https://arxiv.org/abs/2007.14062