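LLMs / FlashAttention-2: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning", translation and commentary
Overview: FlashAttention-2 optimizes the algorithm, parallelism, and work partitioning of the original FlashAttention to make attention computation significantly faster, which helps enable models and applications with longer sequences.
The cost of attention in long-sequence models is quadratic in the sequence length, which limits how far the input sequence length can be scaled. The original FlashAttention already delivered a 2-4x speedup, but it left headroom and fell well short of optimized matrix-multiply (GEMM) performance. FlashAttention-2 is rewritten from scratch on top of Nvidia's CUTLASS library, with more complete parallelism and better work partitioning.
>> FlashAttention-2 reduces non-matmul operations so that more time is spent on matmul, raising GPU utilization.
>> FlashAttention-2 additionally parallelizes over the sequence length dimension.
>> FlashAttention-2 improves how work is partitioned among the warps within a thread block, reducing the shared memory reads/writes caused by synchronization between warps.
>> FlashAttention-2 supports head dimensions up to 256 and multi-query attention variants (MQA/GQA), so it covers more models.
>> Benchmarks show FlashAttention-2 is about 2x faster than FlashAttention on A100 (up to 230 TFLOPs/s); the same implementation reaches up to 335 TFLOPs/s on H100.
>> When training GPT-style models end-to-end, FlashAttention-2 gives up to a 1.3x training speedup.
Future work will continue to optimize FlashAttention-2 for new hardware and data types.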
"FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning": translation and commentary
Link
Blog post: FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning | Princeton NLP Group
Date
July 17, 2023
Author
Tri Dao (Google Scholar), PhD in computer science from Stanford University, Chief Scientist at Together.AI
Scaling up Transformer context length is a challenge, yet language models need longer context: GPT-4 (32k), MPT (65k), Claude (100k)
Just within the last year, there have been several language models with much longer context than before: GPT-4 with context length 32k, MosaicML’s MPT with context length 65k, and Anthropic’s Claude with context length 100k. Emerging use cases such as long document querying and story writing have demonstrated a need for models with such long context. Scaling up the context length of Transformers is a challenge, since the attention layer at their heart has runtime and memory requirements that are quadratic in the input sequence length.
FlashAttention, released in 2022 (speeds up attention and reduces its memory footprint): 2-4x faster than optimized baselines
A year ago, we released FlashAttention, a new algorithm to speed up attention and reduce its memory footprint—without any approximation. We’ve been very happy to see FlashAttention being adopted by many organizations and research labs to speed up their training & inference (see this page for a partial list). Even though FlashAttention was already 2-4x faster than optimized baselines at the time of its release, it still has quite a bit of headroom. FlashAttention is still not nearly as fast as optimized matrix-multiply (GEMM) operations, reaching only 25-40% of the theoretical maximum FLOPs/s (e.g. up to 124 TFLOPs/s on A100 GPU).
Today: the official release of FlashAttention-2
In the past few months, we’ve been working on the next version, FlashAttention-2, that makes FlashAttention even better. Rewritten completely from scratch to use the primitives from Nvidia’s CUTLASS 3.x and its core library CuTe, FlashAttention-2 is about 2x faster than its previous version, reaching up to 230 TFLOPs/s on A100 GPUs. When used end-to-end to train GPT-style language models, we reach a training speed of up to 225 TFLOPs/s (72% model FLOP utilization). In this blogpost, we describe some of the bottlenecks of FlashAttention, and how we use better parallelism and work partitioning to get significant speedup.
FlashAttention-2 is available at: flash-attention
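The throughput figures quoted so far can be related to each other with a little arithmetic. The sketch below assumes the A100 peak of 312 TFLOPs/s for FP16/BF16 matmul that the post cites in a later section; the helper name `utilization` is made up for illustration.

```python
# Peak dense FP16/BF16 matmul throughput of an A100 in TFLOPs/s
# (the figure the post quotes in its section on non-matmul FLOPs).
A100_PEAK_TFLOPS = 312.0


def utilization(achieved_tflops: float, peak_tflops: float = A100_PEAK_TFLOPS) -> float:
    """Fraction of the theoretical peak throughput that was achieved."""
    return achieved_tflops / peak_tflops


for name, tflops in [
    ("FlashAttention, attention kernel", 124.0),
    ("FlashAttention-2, attention kernel", 230.0),
    ("FlashAttention-2, end-to-end GPT training", 225.0),
]:
    print(f"{name}: {tflops:.0f} TFLOPs/s = {utilization(tflops):.0%} of peak")
```

Under that assumption, 124 TFLOPs/s sits at the top of the 25-40% range quoted for FlashAttention, and 225 TFLOPs/s matches the 72% model FLOP utilization quoted for end-to-end training.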
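1. FlashAttention Recap
Strengths of FlashAttention: an algorithm that reorders the attention computation and uses classical techniques (tiling, recomputation) for a 2-4x speedup, while reducing memory usage from quadratic to linear in sequence length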
FlashAttention is an algorithm that reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. Tiling means that we load blocks of inputs from HBM (GPU memory) to SRAM (fast cache), perform attention with respect to that block, and update the output in HBM. By not writing the large intermediate attention matrices to HBM, we reduce the amount of memory reads/writes, which brings 2-4x wallclock time speedup.
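Tiling with softmax rescaling can be sketched in NumPy. This is a simplified illustration rather than the actual CUDA kernel: it tiles only over K and V (the real kernel also tiles Q), it runs on the CPU, and the function names and `block_size` are chosen here for the example.

```python
import numpy as np


def reference_attention(q, k, v):
    """Standard attention: materializes the full (n_q, n_k) score matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def tiled_attention(q, k, v, block_size=16):
    """Blockwise attention with online softmax rescaling.

    Only one (n_q, block_size) tile of scores exists at a time, so the extra
    memory grows linearly with the sequence length instead of quadratically.
    """
    n_q, d = q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full((n_q, 1), -np.inf)      # running row-wise max of the scores
    l = np.zeros((n_q, 1))              # running softmax denominator
    o = np.zeros((n_q, v.shape[-1]))    # running, already normalized output
    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size]   # stands in for an HBM -> SRAM load
        v_blk = v[start:start + block_size]
        s = (q @ k_blk.T) * scale
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        p = np.exp(s - m_new)
        alpha = np.exp(m - m_new)             # rescales the old statistics
        l_new = alpha * l + p.sum(axis=-1, keepdims=True)
        # Bring the previous output to the new max and denominator, add this block.
        o = (l * alpha * o + p @ v_blk) / l_new
        m, l = m_new, l_new
    return o


rng = np.random.default_rng(0)
q = rng.standard_normal((64, 32))
k = rng.standard_normal((100, 32))
v = rng.standard_normal((100, 32))
print(np.allclose(tiled_attention(q, k, v), reference_attention(q, k, v)))
```

The blockwise result agrees with the reference up to floating-point rounding, which is what "no approximation" means here: the rescaling by `alpha` corrects earlier blocks whenever a later block raises the row maximum.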
Drawbacks of FlashAttention: low occupancy (suboptimal work partitioning on the GPU) and unnecessary shared memory reads/writes
However, FlashAttention still has some inefficiency due to suboptimal work partitioning between different thread blocks and warps on the GPU, causing either low occupancy or unnecessary shared memory reads/writes.
Diagram of FlashAttention forward pass: with tiling and softmax rescaling, we operate by blocks and avoid having to read/write from HBM, while obtaining the correct output with no approximation.
2. FlashAttention-2: Better Algorithm, Parallelism, and Work Partitioning
(1) Fewer non-matmul FLOPs: modern GPUs make matmul much faster, so spend as much time as possible on matmul FLOPs
We tweak the algorithm from FlashAttention to reduce the number of non-matmul FLOPs. This is important because modern GPUs have specialized compute units (e.g., Tensor Cores on Nvidia GPUs) that make matmul much faster. As an example, the A100 GPU has a max theoretical throughput of 312 TFLOPs/s of FP16/BF16 matmul, but only 19.5 TFLOPs/s of non-matmul FP32. Another way to think about this is that each non-matmul FLOP is 16x more expensive than a matmul FLOP. To maintain high throughput, we want to spend as much time on matmul FLOPs as possible. We rewrite the online softmax trick used in FlashAttention to reduce the number of rescaling ops, as well as bound-checking and causal masking operations, without changing the output.
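One way to cut rescaling work, described in the FlashAttention-2 technical report, is to keep an unnormalized output accumulator inside the loop and divide by the softmax denominator only once at the end. The NumPy sketch below illustrates that idea on the CPU; the function name and `block_size` are chosen for this example, and the first lines simply reproduce the 16x cost ratio from the two peak throughputs quoted above.

```python
import numpy as np

MATMUL_TFLOPS = 312.0      # A100 FP16/BF16 matmul peak quoted in the post
NON_MATMUL_TFLOPS = 19.5   # A100 non-matmul FP32 peak quoted in the post
relative_cost = MATMUL_TFLOPS / NON_MATMUL_TFLOPS   # 16.0


def attention_deferred_normalization(q, k, v, block_size=16):
    """Online-softmax attention that normalizes the output only once."""
    n_q, d = q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full((n_q, 1), -np.inf)        # running row-wise max
    l = np.zeros((n_q, 1))                # running softmax denominator
    acc = np.zeros((n_q, v.shape[-1]))    # unnormalized output accumulator
    for start in range(0, k.shape[0], block_size):
        s = (q @ k[start:start + block_size].T) * scale
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        p = np.exp(s - m_new)
        alpha = np.exp(m - m_new)
        l = alpha * l + p.sum(axis=-1, keepdims=True)
        # One rescale of the accumulator per block and no division in the loop.
        acc = alpha * acc + p @ v[start:start + block_size]
        m = m_new
    return acc / l                        # single normalization at the end
```

Compared with a variant that renormalizes the output after every block, the per-block division disappears while the result stays the same up to floating-point rounding.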
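(2) Better Parallelism: FlashAttention-2 additionally parallelizes over the sequence length dimension, which speeds up the long-sequence regime
The first version of FlashAttention parallelizes mainly over batch size and number of heads, which leaves the GPU's multiprocessors underused for long sequences (where the batch size or the number of heads is usually small). FlashAttention-2 additionally parallelizes over the sequence length dimension to make better use of the multiprocessors, which gives a clear speedup for long sequences.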
The first version of FlashAttention parallelizes over batch size and number of heads. We use 1 thread block to process one attention head, and there are overall (batch_size * number of heads) thread blocks. Each thread block is scheduled to run on a streaming multiprocessor (SM), and there are 108 of these SMs on an A100 GPU for example. This scheduling is efficient when this number is large (say >= 80), since we can effectively use almost all of the compute resources on the GPU.
In the case of long sequences (which usually means small batch sizes or small number of heads), to make better use of the multiprocessors on the GPU, we now additionally parallelize over the sequence length dimension. This results in significant speedup for this regime.
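The effect on GPU occupancy can be estimated by counting thread blocks. In the sketch below, the tile size `block_m` (query rows handled by one thread block) and the example shapes are hypothetical values chosen for illustration; only the 108 SMs come from the text.

```python
import math

A100_NUM_SMS = 108   # streaming multiprocessors on an A100, as stated above


def thread_blocks_v1(batch_size, num_heads):
    """FlashAttention: one thread block per (batch element, attention head)."""
    return batch_size * num_heads


def thread_blocks_v2(batch_size, num_heads, seqlen, block_m=128):
    """FlashAttention-2: also split the sequence into blocks of block_m rows.

    block_m is a hypothetical tile size used only for this estimate.
    """
    return batch_size * num_heads * math.ceil(seqlen / block_m)


# A long-sequence setting: batch size 1, 16 heads, 16k tokens.
v1 = thread_blocks_v1(1, 16)
v2 = thread_blocks_v2(1, 16, 16384)
print(f"v1: {v1} thread blocks for {A100_NUM_SMS} SMs "
      f"({min(v1, A100_NUM_SMS) / A100_NUM_SMS:.0%} of SMs have work)")
print(f"v2: {v2} thread blocks for {A100_NUM_SMS} SMs "
      f"({min(v2, A100_NUM_SMS) / A100_NUM_SMS:.0%} of SMs have work)")
```

With only batch and head parallelism this setting produces far fewer thread blocks than there are SMs, which is the low-occupancy case that parallelizing over the sequence length addresses.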
(3) Better Work Partitioning
FlashAttention-2 improves how work is partitioned among the warps within a thread block. The original version splits K and V across warps, so the warps must write intermediate results to shared memory and synchronize. The new version splits Q across warps, and each warp multiplies its slice with the shared K and V to get its slice of the output directly. This removes the shared memory reads/writes that the synchronization required, which improves speed.
Even within each thread block, we also have to decide how to partition the work between different warps (a group of 32 threads working together). We typically use 4 or 8 warps per thread block, and the partitioning scheme is described below. We improve this partitioning in FlashAttention-2 to reduce the amount of synchronization and communication between different warps, resulting in less shared memory reads/writes.
Comparison: FlashAttention (splits K and V across warps, which must synchronize and write out intermediate results) vs. FlashAttention-2 (splits Q across warps, each of which multiplies directly with the shared K and V to get its output)
For each block, FlashAttention splits K and V across 4 warps while keeping Q accessible by all warps. This is referred to as the “sliced-K” scheme. However, this is inefficient since all warps need to write their intermediate results out to shared memory, synchronize, then add up the intermediate results. These shared memory reads/writes slow down the forward pass in FlashAttention.
In FlashAttention-2, we instead split Q across 4 warps while keeping K and V accessible by all warps. After each warp performs matrix multiply to get a slice of Q K^T, they just need to multiply with the shared slice of V to get their corresponding slice of the output. There is no need for communication between warps. The reduction in shared memory reads/writes yields speedup.
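The difference between the two schemes can be mimicked on the CPU with NumPy, where each loop iteration plays the role of one warp. This is only an analogy for the data dependencies, not GPU code; the function names are made up for the example.

```python
import numpy as np


def softmax_rows(s):
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return p / p.sum(axis=-1, keepdims=True)


def split_q(q, k, v, num_warps=4):
    """FlashAttention-2 style: each warp owns a slice of Q and sees all of K, V."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    out_slices = []
    for q_w in np.array_split(q, num_warps, axis=0):
        # Each warp finishes its own output rows; nothing is exchanged.
        out_slices.append(softmax_rows((q_w @ k.T) * scale) @ v)
    return np.concatenate(out_slices, axis=0)


def sliced_k(q, k, v, num_warps=4):
    """FlashAttention style: each warp owns a slice of K, V and sees all of Q."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    partials = []
    for k_w, v_w in zip(np.array_split(k, num_warps), np.array_split(v, num_warps)):
        s = (q @ k_w.T) * scale
        m = s.max(axis=-1, keepdims=True)
        p = np.exp(s - m)
        partials.append((m, p.sum(axis=-1, keepdims=True), p @ v_w))
    # Cross-warp reduction: stands in for "write to shared memory,
    # synchronize, then add up the intermediate results".
    m_all = np.max([m for m, _, _ in partials], axis=0)
    l_all = sum(np.exp(m - m_all) * l for m, l, _ in partials)
    acc = sum(np.exp(m - m_all) * o for m, _, o in partials)
    return acc / l_all
```

Both functions return the same attention output, but `sliced_k` needs the final reduction over every warp's partial results, while `split_q` only concatenates independent slices.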
3. New features (head dimensions up to 256, multi-query attention): compatible with more models, and MQA/GQA support further reduces KV cache size and speeds up inference
FlashAttention-2 raises the supported head dimension to 256, which makes it compatible with more models. It also supports multi-query attention and grouped-query attention, variants that shrink the KV cache and can significantly raise inference throughput.
FlashAttention only supported head dimensions up to 128, which works for most models but a few were left out. FlashAttention-2 now supports head dimension up to 256, which means that models such as GPT-J, CodeGen and CodeGen2, and StableDiffusion 1.x can use FlashAttention-2 to get speedup and memory saving.
This new version also supports multi-query attention (MQA) as well as grouped-query attention (GQA). These are variants of attention where multiple heads of query attend to the same head of key and value, in order to reduce the size of the KV cache during inference, which can lead to significantly higher inference throughput.
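A minimal NumPy sketch of grouped-query attention shows how query heads share key/value heads and why the KV cache shrinks. The function names, tensor layout, and the cache-size formula (one K and one V tensor per layer and sequence, 2 bytes per element) are assumptions made for this illustration, not the library's interface.

```python
import numpy as np


def grouped_query_attention(q, k, v):
    """q: (n_q_heads, seq, d); k, v: (n_kv_heads, seq, d).

    n_kv_heads == 1 is MQA; n_kv_heads == n_q_heads is ordinary multi-head attention.
    """
    n_q_heads, seq, d = q.shape
    n_kv_heads = k.shape[0]
    assert n_q_heads % n_kv_heads == 0, "query heads must divide evenly into groups"
    group = n_q_heads // n_kv_heads
    out = np.empty((n_q_heads, seq, v.shape[-1]))
    for h in range(n_q_heads):
        kv = h // group   # query heads in the same group share one K/V head
        s = q[h] @ k[kv].T / np.sqrt(d)
        p = np.exp(s - s.max(axis=-1, keepdims=True))
        out[h] = (p / p.sum(axis=-1, keepdims=True)) @ v[kv]
    return out


def kv_cache_bytes(n_kv_heads, seq_len, head_dim, bytes_per_elem=2):
    """KV cache size for one layer and one sequence: a K tensor plus a V tensor."""
    return 2 * n_kv_heads * seq_len * head_dim * bytes_per_elem


for n_kv in (32, 8, 1):   # multi-head, grouped-query, multi-query
    mib = kv_cache_bytes(n_kv, seq_len=4096, head_dim=128) / 2**20
    print(f"{n_kv:2d} KV heads: {mib:.0f} MiB per layer")
```

The cache size scales with the number of key/value heads rather than the number of query heads, so going from 32 to 8 KV heads cuts it by 4x in this example.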
4. Attention Benchmark: FlashAttention-2 is around 2x faster than FlashAttention and up to 9x faster than a standard PyTorch implementation, and gives a 1.3x speedup when training GPT-style models
In benchmarks of attention forward + backward speed across settings, FlashAttention-2 is around 2x faster than FlashAttention and up to 9x faster than a standard PyTorch implementation. It reaches up to 230 TFLOPs/s on A100 GPUs and up to 335 TFLOPs/s on H100 GPUs, and delivers a 1.3x speedup in end-to-end training of GPT-style models.
We measure the runtime of different attention methods on an A100 80GB SXM4 GPU for different settings (without / with causal mask, head dimension 64 or 128). We see that FlashAttention-2 is around 2x faster than FlashAttention (as well as its other implementations in the xformers library and in Triton). Compared to a standard attention implementation in PyTorch, FlashAttention-2 can be up to 9x faster.
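A measured runtime turns into a TFLOPs/s figure once the attention FLOPs are counted. The sketch below follows the counting convention stated in the FlashAttention-2 technical report (forward FLOPs = 4 * seqlen^2 * head dimension * number of heads, halved with a causal mask, backward = 2.5x forward); the function names and the 10 ms runtime in the usage lines are hypothetical.

```python
def attention_flops(batch, seqlen, nheads, headdim, causal=False, mode="fwd_bwd"):
    """FLOP count for attention under the convention described above."""
    fwd = 4 * batch * seqlen**2 * nheads * headdim
    if causal:
        fwd //= 2   # roughly half of the score matrix is masked out
    factor = {"fwd": 1.0, "bwd": 2.5, "fwd_bwd": 3.5}[mode]
    return fwd * factor


def tflops_per_s(flops, seconds):
    """Convert a FLOP count and a wall-clock time into TFLOPs/s."""
    return flops / seconds / 1e12


# Hypothetical measurement: 10 ms for forward + backward on 4 sequences of 4k tokens.
flops = attention_flops(batch=4, seqlen=4096, nheads=16, headdim=128)
print(f"{tflops_per_s(flops, seconds=10e-3):.0f} TFLOPs/s")
```

Because the causal setting is credited with half the FLOPs, a causal kernel has to skip the masked blocks, and not just zero them, to report a comparable TFLOPs/s number.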
Attention forward + backward speed on A100 GPU: up to 230 TFLOPs/s
Just running the same implementation on H100 GPUs (using no special instructions to make use of new hardware features such as TMA and 4th-gen Tensor Cores), we obtain up to 335 TFLOPs/s.
Attention forward + backward speed on H100 GPU: up to 335 TFLOPs/s
When used to train a GPT-style model end-to-end, FlashAttention-2 helps achieve up to 225 TFLOPs/s on A100 GPU (72% model FLOPs utilization). This is a 1.3x end-to-end speedup over an already very optimized model with FlashAttention.
*Baseline: Megatron-LM without FlashAttention. Megatron-LM now has an option to use FlashAttention. We plan to integrate FlashAttention-2 to Megatron-LM in the near future.
5. Discussion and Future Work: FlashAttention-2 makes it possible to train longer-context models at the same cost (8k to 16k); future work targets more devices and data types
FlashAttention-2 runs about 2x faster, so longer-context models can be trained at the same cost. Planned future work extends it to more devices and data types, and combining its low-level optimizations with algorithmic changes may enable training on far longer sequences than before.
FlashAttention-2 is 2x faster than FlashAttention, which means that e.g. we can train models with 16k context for the same price as previously training an 8k context model. We’re excited about how this can be used to understand long books and reports, high resolution images, audio and video. FlashAttention-2 will also speed up training, finetuning, and inference of existing models.
In the near future, we plan to collaborate with folks to make FlashAttention widely applicable in different kinds of devices (e.g. H100 GPUs, AMD GPUs), as well as new data types such as fp8. As an immediate next step, we plan to optimize FlashAttention-2 for H100 GPUs to use new hardware features (TMA, 4th-gen Tensor Cores, fp8). Combining the low-level optimizations in FlashAttention-2 with high-level algorithmic changes (e.g. local, dilated, block-sparse attention) could allow us to train AI models with much longer context. We’re also excited to work with compiler researchers to make these optimization techniques easily programmable.
6. Acknowledgement
We thank Phil Tillet and Daniel Haziza, who have implemented versions of FlashAttention in Triton and the xformers library. FlashAttention-2 was motivated by exchange of ideas between different ways that attention could be implemented. We are grateful to the Nvidia CUTLASS team (especially Vijay Thakkar, Haicheng Wu, and Andrew Kerr) for their CUTLASS library, in particular the CUTLASS 3.x release, which provides clean abstractions and powerful building blocks for the implementation of FlashAttention-2. We thank Driss Guessous for integrating FlashAttention to PyTorch. FlashAttention-2 has benefited from helpful discussions with Phil Wang, Markus Rabe, James Bradbury, Young-Jun Ko, Julien Launay, Daniel Hesslow, Michaël Benesty, Horace He, Ashish Vaswani, and Erich Elsen. Thanks to Stanford CRFM and Stanford NLP for the compute support. We thank Dan Fu and Christopher Ré for their collaboration, constructive feedback, and constant encouragement on this line of work of designing hardware-efficient algorithms. We thank Albert Gu and Beidi Chen for their helpful suggestions on early drafts of the FlashAttention-2 technical report.