Receptance Weighted Key Value(RWKV)是pengbo提出的一个新的语言模型架构,它使用了线性的注意力机制,把Transformer的高效并行训练与RNN的高效推理相结合,使得模型在训练期间可以并行,并在推理的时候保持恒定的计算和内存复杂度。目前RWKV的社区已经非常火了,我们从huggingface上可以看到RWKV已经训练了多个百亿参数的模型,特别是RWKV World模型支持世界所有语言的生成+对话+任务+代码,功能十分全面。此外还有很多开发者基于RWKV的微调模型。
RWKV论文原文:https://arxiv.org/abs/2305.13048
这里列出一些我看到的学习资料,方便感兴趣的读者学习:
- [双字] 在{Transformer}时代, {RWKV}是RNN的[文艺复兴]–论文详解(https://www.bilibili.com/video/BV11N411C76z/?spm_id_from=333.337.search-card.all.click&vd_source=4dffb0fbabed4311f4318e8c6d253a10)
- 野心勃勃的RNN——RWKV语言模型及其100行代码极简实现(https://zhuanlan.zhihu.com/p/620469303)
- github仓库(https://github.com/BlinkDL/RWKV-LM)
- rwkv论文原理解读(https://www.zhihu.com/question/602564718)
- RWKV的微调教学,以及RWKV World:支持世界所有语言的生成+对话+任务+代码(https://zhuanlan.zhihu.com/p/638326262)
- RWKV:用RNN达到Transformer性能,且支持并行模式和长程记忆,既快又省显存,已在14B参数规模检验(https://zhuanlan.zhihu.com/p/599150009)
- 谈谈 RWKV 系列的 prompt 设计,模型选择,解码参数设置(https://zhuanlan.zhihu.com/p/639629050)
- RWKV进展:一键生成论文,纯CPU高速INT4,纯CUDA脱离pytorch,ctx8192不耗显存不变慢(https://zhuanlan.zhihu.com/p/626083366)
- 开源1.5/3/7B中文小说模型:显存3G就能跑7B模型,几行代码即可调用(https://zhuanlan.zhihu.com/p/609154637)
- 发布几个RWKV的Chat模型(包括英文和中文)7B/14B欢迎大家玩(https://zhuanlan.zhihu.com/p/618011122)
- 实例:手写 CUDA 算子,让 Pytorch 提速 20 倍(某特殊算子)(https://zhuanlan.zhihu.com/p/476297195)
- BlinkDL/RWKV-World-7B gradio demo(https://huggingface.co/spaces/BlinkDL/RWKV-World-7B/tree/main)
- ChatRWKV(有可用猫娘模型!)微调/部署/使用/训练资源合集(https://zhuanlan.zhihu.com/p/616351661)
- pengbo的专栏(https://www.zhihu.com/people/bopengbopeng/posts)
原理推荐看第一个和第二个链接,其它的有选择观看,我这里就以ChatRWKV项目的解析为例来入门RWKV。
下面这个文件 https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py 以150行代码实现了RWKV-4-Pile-430M这个模型,是学习RWKV的最佳代码,所以让这一节就是来逐步解析一下这个代码。分析代码之前先对RWKV这个名字的含义和组成RWKV模型2个关键的元素Time Mixing和Channel Mixing简单描述一下,详细的原理还是请参考原始论文和第一节学习资料的第一个视频链接和第四个原理和公式详解的文字链接。
论文中对名字有说明:
- R: Receptance vector acting as the acceptance of past information. 类似于LSTM的“门控单元”
- W: Weight is the positional weight decay vector. A trainable model parameter. 可学习的位置权重衰减向量,什么叫“位置权重衰减”看下面的公式(14)
- K: Key is a vector analogous to K in traditional attention. 与传统自注意力机制
- V : Value is a vector analogous to V in traditional attention. 与传统自注意力机制相同
RWKV Block又主要由Time Mixing和Channel Mixing组成。
这里的表示当前时刻,看成当前的token,而看成前一个token,、、的计算与传统Attention机制类似,通过将当前输入token与前一时刻输入token做线性插值,体现了recurrence的特性。然后的计算则是对应注意力机制的实现,这个实现也是一个过去时刻信息与当前时刻信息的线性插值,注意到这里是指数形式并且当前token和之前的所有token都有一个指数衰减求和的关系,也正是因为这样让拥有了线性attention的特性。
然后RWKV模型里面除了使用Time Mixing建模这种Token间的关系之外,在Token内对应的隐藏层维度上RWKV也进行了建模,即通过Channel Mixing模块。
Channel Mixing的意思就是在特征维度上做融合。假设特征向量维度是d,那么每一个维度的元素都要接收其他维度的信息,来更新它自己。特征向量的每个维度就是一个“channel”(通道)。
下图展示了RWKV模型整体的结构:
这里提到的token shift就是上面对r, k, v计算的时候类似于卷积滑窗的过程。然后我们可以看到当前的token不仅仅剋呀通过Time Mixing的token shit和隐藏状态States(即)和之前的token建立联系,也可以通过Channel Mixing的token shift和之前的token建立联系,类似于拥有了全局感受野。
初始化部分
首先来看RWKV模型初始化部分以及最开始的一些准备工作:
这部分除了准备一些模型执行需要的超参数之外,还对RWKV模型进行了初始化,值得注意的是在初始化过程中会加载RWKV模型的权重到w这个字典里面,然后遍历字典w的所有键。注释中的例子"blocks.0.att.time_first" => self.w.blocks[0].att.time_first" 说明了代码的目标:将点分隔的键转换为嵌套的属性访问。后面推理的时候将会直接访问self.w这个处理之后的权重对象。
RWKV 模型通道融合函数(Channel mixing)
参考Channel Mixing的公式来看:
在channel_mixing函数里面,对应当前token的词嵌入向量,表示前一个token的词嵌入向量。剩下的变量都是RWKV的可学习参数。然后代码里面会动态更新state,让总是当前token的前一个token的词嵌入。
RWKV Time mixing函数
仍然是要对照公式来看:
关于RWKV 的attention部分()计算如果你有细节不清楚,建议观看一下这个视频:解密RWKV线性注意力的进化过程(https://www.bilibili.com/video/BV1zW4y1D7Qg/?spm_id_from=333.337.search-card.all.click&vd_source=4dffb0fbabed4311f4318e8c6d253a10) 。
RWKV model forward函数
从这里可以看到RWKV的state只和层数和嵌入层维度有关系,和序列长度是无关的,这是推理时相比于RNN的核心优点。
采样函数
这个函数的作用是根据给定的概率分布 out 生成一个随机样本,通过调整温度 temperature 和顶部概率 top_p 来控制生成的样本的多样性和稳定性。
生成文本的流程
这段代码的目的是使用 RWKV_RNN 模型生成文本,模型通过 sample_logits 函数生成随机 token,然后将其传递给模型进行预测,并根据预测结果生成下一个 token,不断重复这个过程直到生成指定数量的 tokens。生成的文本会被打印出来。
ChatRWKV的README中提到v2版本实现了一些新功能,建议我们使用,v2版本的代码在 https://github.com/BlinkDL/ChatRWKV/tree/main/v2 。我们这一节就对这部分的代码做一个解读。
https://github.com/BlinkDL/ChatRWKV/blob/main/v2/chat.py是ChatRWKV v2的核心实现,我们直接来看这个文件。
这段代码设置了必要的依赖项,配置了GPU使用环境,并提供了有关RWKV语言模型及其精度选项的信息。
这段代码相当于配置文件,来设置RWKV聊天系统的运行参数和模型信息。根据选择的语言不同,会加载相应的模型,并设置相应的参数。
这段代码的主要功能是加载和配置RWKV模型,并准备生成文本所需的参数和工具。继续解析:
这段代码定义了以RNN模式运行RWKV的函数以及保存和恢复模型状态的函数,最后还定义了一个修复tokens的工具函数用于对world模型和其它模型分别进行处理,继续解析:
总的来看,这段代码是一个循环,用于ChatRWKV系统生成回复消息和更新状态信息。它根据输入消息和模型的状态信息进行一个RNN模式的推理,生成一个回复的token和一个新的状态,然后将生成的回复显示出来,生成的状态则可以在下一次生成中继续使用。代码还包括处理特殊命令以及加载和保存状态信息的逻辑。
- 直接说话:聊天,用 + 换个回答
- reset: 通过发送+reset消息,您可以重置聊天。
- 加载新的提示: 使用+prompt {path}可以加载一个新的提示,其中{path}是提示文件的路径。
- 消息类型:
- +gen {text}: 基于{text}生成新的内容。
- +i {instruction}: 根据给定的指令{instruction}产生一个响应。
- +qq {question}: 对于给定的问题{question},生成一个答案。
- +qa {text}: 将{text}作为一个问题,并生成一个答案。
- +++: 继续写下去。
- ++: 换个写法。
除了这些指令之外,还可以调整生成参数:
- -temp=: 调整生成的温度。温度影响生成文本的随机性。较高的温度将导致更多的随机输出,而较低的温度将导致更确定的输出。
- -top_p=: 调整Top-P采样。Top-P采样是选择词汇中概率最高的一部分词进行采样的方法。