作者:周子宜,李姝
编辑:李慧
在「Meet AI4S」系列直播第三期中,周子宜博士以「蛋白质语言模型的小样本学习方法」为题,分享了团队的最新研究成果,以下为演讲精华实录。
在「Meet AI4S」系列直播第三期中,我们有幸邀请到了上海交通大学自然科学研究院 & 上海国家应用数学中心博士后周子宜,他所在的上海交通大学洪亮课题组研究方向主要为 AI 蛋白和药物设计、分子生物物理。该课题组研究成果颇丰,截止目前共发表研究论文 77 篇,其中多篇登顶 Nature 期刊。
本次分享,周子宜博士以「蛋白质语言模型的小样本学习方法」为题,分享了团队的最新研究成果,并探讨了 AI 辅助定向进化的新思路。
HyperAI超神经在不违原意的前提下,对周子宜博士的本次深度分享进行了整理汇总。点击查看完整直播回放 ⬇️
大家好,我是来自上海交通大学自然科学研究院、洪亮教授课题组的博士后周子宜,今天为大家分享我们课题组最近发在 Nature Communications 上的一项工作,即运用小样本学习方法来提升蛋白质语言模型的性能。
今天我分享的内容主要分为 4 部分:蛋白质语言模型的研究背景、研究成果 FSFP 方法介绍、FSFP 方法的评估以及总结与未来研究展望。
蛋白质是生物功能的主要载体,也是生命活动的执行者。天然氨基酸经过脱水缩合反应 (Dehydration Condensation) 形成蛋白质的残基序列 (Residue Sequence),随后折叠成为三级结构。改变蛋白质的氨基酸种类会影响其结构和功能。
由于天然蛋白质往往难以满足工业或医疗的需求,因此蛋白质工程希望通过对蛋白质进行突变,从而提升蛋白质的功能属性,如催化活性、稳定性、结合能力等。
我们通常将蛋白质功能属性的量化称为 Fitness。定向进化是现在主流的蛋白质工程方法。它依赖随机突变和高通量实验,寻找高 Fitness 的突变体,但实验成本较高。针对于此,我今天分享的主题是如何用 AI 方法来预测 Fitness,从而降低实验成本。
PLM 的架构
我们知道,以 ChatGPT 为代表的语言模型非常强大,能够进行高质量的文本理解和生成。这些语言模型通过在海量文本上预训练,能够学习文本的统计规律,掌握基本的语法和上下文中单词的语义。那么,是否可以在海量蛋白质序列上类似地训练蛋白质语言模型呢?答案是肯定的。
蛋白质语言模型 PLM 主要有 3 类作用。首先,PLM 能建模蛋白质序列的共进化信息,学习残基之间的相互依赖关系和进化约束,就好比自然语言LM能够学习文本的语法一样。PLM 可以用这种能力去估计哪些突变是有害的或者有利的,从而能预测突变的 Fitness。
其次,除了 Fitness 预测之外,PLM 还可以计算蛋白质的向量表征,这些表征可用于结构预测或蛋白质挖掘,经过微调后还能进行功能预测。
最后,PLM 可以像 ChatGPT 一样进行有条件的蛋白质生成,实现从头设计 (de novo) 蛋白质。
PLM 的架构与自然语言 LM 相似,分为自回归 (Autoregressive) 模型和遮蔽 (Masked) 模型。这两种模型的网络结构都采用 Transformer,由自注意力机制 (Self-attention) 和全连接层组成,主要区别在于预训练目标。
自回归模型的预训练目标是按照顺序从左到右生成下一个氨基酸,而遮蔽模型的目标是还原被随机遮蔽的氨基酸,类似于完形填空。由于自回归模型在预测每个氨基酸时只能依赖左侧已生成的序列,因此其注意力是单向的。而遮蔽模型在预测时可以看到被遮蔽位置两侧的氨基酸,因此其注意力是双向的。
目前,PLM 的研究热点主要分为 2 个方向。首先是检索增强型 (Retrieval-Augmented) PLM,这类模型在训练或预测时,将当前蛋白质的多序列比对 (MSA) 作为额外输入,通过检索得到的信息提升预测性能。比如,MSA Transformer 和 Tranception 就是典型的此类模型。
其次是多模态 (Multi-Modal) PLM,这类模型除了蛋白质序列外,还将蛋白质的结构或其他信息作为额外输入,以增强模型的表征能力。例如,我们课题组今年投稿的 ProSST 模型,将蛋白质结构量化为结构 token 序列,并与氨基酸序列一起输入 Transformer 模型,通过分离注意力机制来融合这两类信息。另一个例子是同期的模型 ESM-3,它考虑的信息更丰富,包括氨基酸类型、完整三级结构、三级结构 token、二级结构、溶剂可及表面积 (SASA),以及蛋白质和残基的功能描述共 7 种输入。
接下来继续讨论 Fitness 预测问题。由于 PLM 可以建模蛋白质序列的概率分布,因此无需标注数据,即可直接用于突变的 Fitness 预测,这种方法被称为零样本预测 (zero-shot prediction) 或无监督预测。
具体而言,PLM 通过计算突变体与野生型之间的对数似然比 (log-likelihood ratio) 来为突变打分。对于自回归模型而言,序列的概率 P 就是生成每个氨基酸的概率连乘。突变的打分可以通过突变体的 logP 减去野生型的 logP 来获得。直观来讲,就是比较突变出现的概率相对于野生型高多少,进而评估突变的影响,这是一种经验性的评估方法。
对于遮蔽模型来说,无法直接计算整条序列的概率,但是它可以先把某个点位遮蔽掉,然后去估计这个点位上氨基酸的概率分布。因此对于每个突变位置,可以用遮蔽之后预测的突变氨基酸的 logP 减掉野生型氨基酸的 logP,再将所有位置的差值相加,得到突变体的打分。
此外,由于 PLM 提供了蛋白质序列的向量表征,当有足够的实验数据时,还可以对它们进行微调,实现有监督的 Fitness 预测。
具体做法是在 PLM 的最后一层特征后加上一个用于预测 Fitness 的输出层(如注意力机制或多层感知机 MLP),并使用 Fitness Label 进行全量或者部分训练。例如,ECNet 在大模型特征基础上加入了 MSA 特征,通过 LSTM 进行融合,进行有监督训练。我们课题组去年开发的 SESNet 模型则融合了 ESM-1b 的序列特征、ESM-IF的结构特征,以及 MSA 特征,进行有监督的 Fitness 预测。
在介绍 FSFP 方法之前,需要先明确小样本学习在 Fitness 预测中的重要性。尽管无监督方法不需要用标注数据来训练,但其 zero-shot 打分的准确性较低。此外,因为基于对数似然比的打分只能反映蛋白质的某些自然规律,它们也难以有效预测蛋白质的非天然属性。
另一方面,虽然监督学习方法很准确,但由于 PLM 参数量巨大,它们需要大规模的实验数据来训练才能显著提升性能。监督学习模型的评估一般是把已有的高通量数据集 8:2 切分,而 80% 的训练集可能已经包含了上万的数据了,这在实践当中获取成本很高。
针对这一问题,我们提出了 FSFP 方法,这是一种适用于 PLM 的小样本学习方法。该方法能够利用少量训练样本(几十个),显著提升 PLM 的 Fitness 预测性能。同时,FSFP 方法具有较强的灵活性,可应用于不同的 PLM。
以往的监督学习方法都是将 Fitness 预测视为回归问题,即通过计算模型输出与 Fitness Label 之间的均方误差 (MSE) 来优化模型。然而,在小样本条件下,回归模型非常容易过拟合,训练损失下降极快。因此,我们转变了思路,不去做回归,而是去做排序学习,只要求排序准确,不强求拟合精确数值。
这种方法有两大优势。首先,排序本身符合蛋白质工程的基本需求,只需衡量突变的相对有效性即可。其次,相比预测绝对数值,排序任务更加简单。
在训练迭代中,我们将采样到的一组突变体根据标签倒序排列,然后根据模型对这些突变体的预测值计算排序损失——ListMLE。模型预测值的排序越接近真实排序,损失就越小。其中,我们使用基于对数似然比的 zero-shot 打分函数作为模型对突变的打分函数 f。这样做的目的是以 zero-shot 打分作为起点,用训练数据逐渐去修正它来提升性能,而不需要重新初始化一个模块,从而降低训练难度。
由于 PLM 的参数量通常高达数亿个,用极少的数据进行全量微调必然会导致过拟合。因此,我们引入了第二项技术 LoRA,来限制模型的可训练参数数量。
LoRA 在 Transformer 每个块的全连接层插入一对可训练的秩分解矩阵,保持预训练参数不动。因为秩分解矩阵很小,可训练参数量能够降低到原来的 1.84%。尽管可训练参数量减少了,但由于对 Transformer 每层都进行了微调,模型的学习能力仍能得到保证。
为了避免过拟合,我们不仅使用了更好的损失函数,还通过 LoRA 技术限制了可训练参数量。然而,即便如此,若在小样本训练数据上的训练迭代步数过多,还是存在过拟合的风险。因此,我们希望通过较少的训练迭代步数快速提升模型性能。基于这一需求,我们采用了第三项技术——元学习。元学习的基本思想是,首先让模型在某些辅助任务上积累经验,获得一个初始模型,然后利用该初始模型快速适应新任务。
如下图所示,这是一个基于元学习的图像分类示例。假设目标任务是训练一个模型去做马分类,但是马的标注数据比较少。因此,我们可以先找一些数据量多的辅助任务,比如猫分类、狗分类等等,在这些辅助任务上用元学习算法进行训练,学习如何去学习新的任务,得到一个元学习器。然后以这个元学习器作为初始模型,用少量马的标注数据训练若干步,就可以快速获得一个马分类器。显然,元学习能奏效的前提是采用的辅助任务要跟目标任务足够接近。
如何将元学习应用到 Fitness 预测的场景?首先我们的目标任务是对目标蛋白质的突变做 Fitness 排序,而要训练的模型是采用了 LoRA 技术的 PLM。
我们采用了两种策略来构造辅助任务。第一种是根据与目标蛋白的相似度,去已有的 DMS 数据库里找相似蛋白的突变实验数据集,选出来前两个数据集分别作为 2 个辅助任务。这样做的出发点是考虑到相似蛋白的 Fitness Landscape 也是接近的。
第二种策略是使用 MSA 模型对目标蛋白的候选突变进行评分,形成伪标签数据集,并将其作为第 3 个辅助任务。之所以选择 MSA 模型,是因为 MSA 模型的突变预测效果通常不逊于 PLM,我们希望通过 MSA 进行数据增强,充分发挥 PLM 的表征能力。
我们采用的元学习算法是 MAML,它的训练目标是使得元学习器用某个辅助任务的训练数据微调 k 步以后,测试损失尽可能小,这样在目标任务上微调 k 步以后也能大致收敛。
我们的 Benchmark 数据来源于 ProteinGym,最初包含 87 个 DMS 数据集,现已更新至 217 个。其中 87 个 DMS 对应的蛋白质大致分为四类:真核生物、原核生物、人类和病毒,总共涵盖了约 1,500 万条突变和对应的 Fitness。
对于每个数据集,我们随机选取 20、40、60、80 和 100 个单点突变作为小样本训练集,其余突变则作为测试集。需要说明的是,我们没有使用额外的验证集来做 early stop,而是在训练集上通过交叉验证来估计训练步数。
此前提到元学习需要 3 个辅助任务,其中 2 个辅助任务是根据和目标蛋白的相似度从 DMS 数据库里检索出来的。在某一数据集上进行训练时,我们从 ProteinGym 的其余数据集中进行检索,假设它就是数据库。
如下面右图所示,将 ProteinGym 中每个蛋白质分别作为查询 (query),找出来的最相似蛋白质的相似度分布,分别通过 MMseqs2 和 FoldSeek检索得到。可以看到最相似蛋白的序列或者结构相似度平均在 0.5 左右。第 3 个辅助任务涉及使用 MSA 模型对突变进行打分。我们选择了 GEMME 模型,该模型基于 MSA 构建进化树,在进化树上计算各个点位的保守性来给突变打分。
评估指标使用了 Spearman/Pearson 系数以及 NDCG,这些是 Fitness 预测任务中常见的评估标准。最终的评估得分是在 87 个数据集上的平均得分。
如下图所示,左图中 x 轴代表训练集的大小,y 轴代表 Spearman 系数,每条线对应不同的模型配置。最上方的线代表完整版 FSFP 模型;第二条线表示将元学习的第 3 个辅助任务替换为相似蛋白的 DMS 数据,而不使用 MSA,可以看出移除 MSA 信息后模型性能有所下降;第三条线表示不使用元学习,仅依靠排序学习和 LoRA,Spearman 系数进一步下降。
绿色线条代表此前发表在 NBT 上的岭回归模型,它是目前少有的适用于小样本场景的基线模型;灰色虚线表示 ESM-2 的 zero-shot 得分;最底部两条线则表示使用传统回归方法训练 ESM-2 的结果。
整体来看,在仅有 20 个训练样本时,我们的方法相较于 zero-shot 提升了 10 个点的 Spearman,且各个模块均对模型性能起到了积极作用。右图展示了在 87 个数据集上,相较于 zero-shot 的性能提升分布,训练集大小为 40 个样本。可以看到我们的方法在大多数数据集上都能提升模型性能,部分数据集的提升甚至超过 40 个点,表现得比基线更加稳定。
元学习的目的是使 PLM 能够在目标任务上通过少量迭代快速收敛。下面通过几个示例进行说明。
以下 3 张图表展示了在 3 个数据集上使用 40 个训练样本微调的训练曲线。x 轴表示训练步数,y 轴表示测试集上的 Spearman 系数。顶端橙色和红色的线都是用元学习训练过的模型,前者用了 MSA 构建辅助任务,后者则没用。黄色的线表示仅使用排序学习和 LoRA 而不使用元学习的模型。
可以看到,经过元学习训练的模型在目标蛋白质上能够更快速地提升性能,且在 20 步以内即可达到较高的分数,有时甚至不加微调的初始模型表现就已经较好。这表明元学习得到了有效的初始模型。而下方基于 MSE 的模型表现较差,且过拟合较快,难以超越 zero-shot 方法。
我们选择了 3 个典型的 PLM ,分别是 ESM-1v、ESM-2 和 SaProt。前 2 个模型仅使用蛋白质的序列信息,而 SaProt 则结合了蛋白质的三级结构 token。
左侧折线图展示了不同训练集大小下,预测单点突变效果的 Spearman得分,同一种颜色代表同一个模型,点的不同形状代表不同的训练方法。上方的圆点表示 FSFP 方法,下方倒三角表示岭回归,虚线表示模型的zero-shot性能。紫色线则代表 GEMME 模型,它不是 PLM,但是岭回归方法用可以和它结合。可以看出,FSFP 方法可以稳定地提升各个 PLM 的性能,而且远好于岭回归和对应模型的 zero-shot。
第二张柱状图展示了在不同数据集上使用 3 种策略 (zero-shot、岭回归和FSFP) 所获得的最高分的数量。FSFP 在大多数数据集中表现最佳。右侧两张图展示了预测多点突变的性能,涉及的多点突变数据集共有 11 个,得到的结论与单点突变类似。然而,岭回归模型此处的方差较大,表明它对数据切分比较敏感。
随后,我们评估了 FSFP 的外推性能,即专门评估在训练集里没见过的突变点位上的预测性能。在这种情况下,测试集会比原来小很多,而且测试集会随着训练集变大发生明显变化,所以表里面 zero-shot 性能不再是一条直线了。这种设置比较有挑战性,可以看到左侧单点突变上岭回归的性能几乎没法超过 zero-shot,但是 FSFP 仍然能稳定地提升性能。右侧多点突变的测试结果也同样表明我们的训练方法具有较好的泛化能力。
此外,我们还用 FSFP 做了一个蛋白质改造的案例。目标蛋白质为 Phi29,这是一种 DNA 聚合酶,我们希望通过单点突变来改善它的 Tm。
实验流程如下:首先使用 ESM-1v 对饱和单点突变进行 zero-shot 打分,选择得分排名前 20 的突变并进行湿实验测量 Tm;然后将这 20 条实验数据作为训练集,利用 FSFP 对 ESM-1v 进行训练,用训练后的模型对饱和单点突变再次打分,重新选择前 20 条突变进行测试。
右图展示了前后两轮实验的 Tm 分布对比。第一轮的 20 个突变中有 7 个为阳性,第二轮增加至 12 个,且平均 Tm 提升了 1 度。其中,第二轮找出来的阳性突变中有 9 个是新的。虽然阳性率和平均 Tm 有所提高,但可惜最高 Tm 并没有提升,因为第二轮得到的 Tm 最高的突变仍然是第一轮结果中已存在的。不过,由于获得了更多阳性单点突变,接下来可以尝试组合这些点位进行高点突变实验,进一步提升 Tm。
FSFP 是一种针对 PLM 的小样本学习策略,能够利用少量(几十个)有标注训练样本显著提升 PLM 在突变效果预测中的表现,并能灵活地应用于多种不同的 PLM 上。实验表明,FSFP 的设计具有合理性:
* 排序学习满足了蛋白质工程中对突变排序的基本需求,降低了训练难度;
* LoRA 通过控制 PLM 的可训练参数量,降低了过拟合风险;
* 元学习可以为模型提供良好的初始参数,使模型能够快速迁移至目标任务。
最后,我们来讨论 AI 辅助定向进化的未来方向。AI 辅助定向进化的一般流程是从一组初始突变开始,通过湿实验获得它们的 Fitness Label,并利用实验反馈的标注数据训练机器学习模型,随后根据模型的预测选择下一轮要测试的突变,反复迭代。
FSFP 主要解决了模型在每一轮实验迭代中的小样本训练问题,提高模型的预测准确性。然而,我们尚未讨论如何有效选择下一轮要测试的突变,亦即下一轮要新加入的训练样本。在之前 Phi29 蛋白质改造的例子中,我们直接选择了模型打分最高的前 20 个突变,然而在多轮迭代的场景中,贪心选择策略不一定是最好的方法,它容易陷入局部最优。因此,必须在探索与利用之间找到平衡。
事实上,迭代选择测试样本来标注、并逐步扩充训练数据的过程是一个主动学习问题,这在蛋白质工程领域已有一定研究进展。例如,定向进化领域的权威科学家 Frances H. Arnold 在她的文章「Active Learning-Assisted Directed Evolution」中探讨了相关问题。
论文地址:
https://www.biorxiv.org/content/10.1101/2024.07.27.605457v1.full.pdf
我们可以通过不确定性量化技术,来评估模型对每个突变体打分的不确定性。基于这些不确定性,测试样本的选择策略会更加多样化。常用的一种策略是 UCB 方法,它通过挑选模型预测不确定性最高的突变样本进行下一轮标注,即优先选择预测方差最大的样本。这与人类的学习过程类似:如果我们对某些知识点掌握不足或存在不确定,就会重点加强学习。