Skip to content

SDiT: Spiking Diffusion Model with Transformer

Updated: at 14:57

摘要:脉冲神经网络(Spiking Neural Networks, SNNs)由于其低功耗和生物可解释性等特性,被认为在高能效计算方面具有巨大潜力。然而,在图像生成任务中对 SNN 的探索仍然非常有限,目前尚未有一种统一且有效的 SNN 生成模型结构被提出。为此,本文提出了一种新的基于脉冲神经网络的扩散模型体系结构。我们使用 Transformer 替代了主流扩散模型中常用的 U-net 结构,从而在较低的计算成本和更短的采样时间下生成质量更高的图像,旨在为基于 SNN 的生成模型研究提供一个实证性的基准。基于 MNIST、Fashion-MNIST 和 CIFAR-10 数据集的实验表明,与现有的 SNN 生成模型相比,我们的工作在性能上具有很强的竞争力。

1. Intro

之前的SNN图像生成,如Spiking-GAN、FSVAE、SDDPM之类的工作普遍没有达到ANN的水平。

这篇工作提出基于RWKV注意力的Diffusion Transformer模型,并且引入了重构模块,旨在补充Neuron触发后丢失的信息。

2. Preliminary

RWKV Attention:

对于位于时间步t t 的输入向量X X,首先进行线性变换:rt=XWr,kt=XWk,vt=XWvr_t = XW_r , k_t = XW_k , v_t = XW_v,其中 WrW_r , WkW_k , WvW_v 是线性变换矩阵。XX 是通过加权原始输入和前一个时间步的输入(称为 token-shift)得到的。rtr_t 被称为接收矩阵,其中的每个元素表示接收到的过去信息。KK V V 类似于自注意力中的键(Key)和值(Value)矩阵。

wkvt=i=1t1e(t1i)w+kivi+eu+ktvti=1t1e(t1i)w+kiw_k v_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} \odot v_i + e^{u + k_t} \odot v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i}}

其中,kik_i kt k_t i i 行的向量,viv_i vt v_t i i 行的向量,\odot 是 Hadamard 积(逐元素乘积),ww 是一个可学习的位置衰减向量,uu 防止w w 退化。

然后进行Time Mixing:

ot=Wo(σ(rt)wkvt) o_t = W_o(\sigma(r_t) \odot w_k v_t)

和Channel Mixing:

ot=σ(rt)(Wvmax(kt,0)2) o_t = \sigma(r_t) \odot (W_v \max(k_t, 0)^2)

其中,WoW_o WvW_v 是线性变换,σ\sigma 是作用于rt r_t 的 Sigmoid 函数。通过对rt r_t 应用 Sigmoid,可以消除不必要的历史信息,充当一个“遗忘门”,从而获得输出。

时间混合可以简单地看作是标准自注意力的替代,而通道混合可以看作是前馈网络(Feed-Forward Network, FFN)层的替代。

3. Proposed Method

A. Overview of SDiT Architecture

image.png

如图 1 所示,网络架构概览如下:输入图像在进入模型前,首先经过 Patch Embedding,然后进行位置嵌入(Position Embedding)。去噪时间步通过时间嵌入(Time Embedding)输入模型。受 U-ViT [6] 启发,模型由三个阶段的 Spiking Transformer Blocks 组成。第一阶段的输入块(Input Blocks)将输出分为两条路径——一条直接连接到下一块,另一条作为跳跃连接到输出块(Output Blocks)的对应块。Spiking Transformer Blocks 的输出将被输入最终层(Final Layer),并映射回原始图像大小。经过 Patch 重构和卷积层后,最终生成的噪声预测图像被输出。

B. Embedding

划分Patch + Conv映射到高维 + Position Embedding。

C. Spiking Transformer Block

image.png

给定输入向量xRB×N×D x \in \mathbb{R}^{B \times N \times D},其中B B 是批量大小, NN 是 Patch 序列长度,DD 是特征维度。

定义用于跳跃连接的函数 F 如下:

x=F(x,xskip)x = F(x, x_{\text{skip}}) F(x,xskip)={([x;xskip])Wskip,若 xskipNonex,否则F(x, x_{\text{skip}}) = \begin{cases} ([x; x_{\text{skip}}])W_{\text{skip}}, & \text{若 } x_{\text{skip}} \neq \text{None} \\ x, & \text{否则} \end{cases}

其中,[][ \cdot ] 表示拼接操作,WskipR2D×DW_{\text{skip}} \in \mathbb{R}^{2D \times D} 是线性变换矩阵。对于输入块和中间块,xskipx_{\text{skip}} 为 None;对于输出块,xskipx_{\text{skip}} 是来自对应输入块的输出。

引入一个重构 Token,记为zR1×N×D z \in \mathbb{R}^{1 \times N \times D},并在拼接之前进行平铺操作:

x^=[x;z1B]RB×2N×D\hat{x} = [x; z \otimes 1_B] \in \mathbb{R}^{B \times 2N \times D}

其中,\otimes 表示克罗内克积,1B1_B 是全 1 向量。

拼接后的x^ \hat{x} 依次输入时间混合(Time-Mixing)和通道混合(Channel-Mixing)模块。经过 Leaky Integrate-and-Fire (LIF) 神经元后,添加残差连接:

xAttn=x^+LIF(TimeMixing(LN(x^)))x_{\text{Attn}} = \hat{x} + \text{LIF}(\text{TimeMixing}(\text{LN}(\hat{x}))) xFFN=xAttn+LIF(ChannelMixing(LN(xAttn)))x_{\text{FFN}} = x_{\text{Attn}} + \text{LIF}(\text{ChannelMixing}(\text{LN}(x_{\text{Attn}})))

其中,LN\text{LN} 表示 LayerNorm。

最终,通道混合的输出进行以下操作:

y=G(xFFN)y = G(x_{\text{FFN}})

其中,GG 表示重构模块中的操作。

D. Reconstruction Module

为弥补这种损失,我们设计了重构 Token 和重构模块。重构 Token 由可学习参数组成。通过嵌入后,重构 Token 可以表示脉冲神经元的内在动态信息。在重构模块中,原始输出被重新调整,以减轻通过脉冲神经元时的信息损失。在第 IV 节中展示了其有效性。

对拼接向量计算后,嵌入的重构 Token 从输出向量中分离:

y,z=split(xFFN)y{\prime}, z{\prime} = \text{split}(x_{\text{FFN}})

其中 yRB×N×D,zRB×N×D,splity{\prime} \in \mathbb{R}^{B \times N \times D} , z{\prime} \in \mathbb{R}^{B \times N \times D} ,\text{split}表示向量分离操作。

由于重构 Token 的维度与输入相同,其特征维度首先线性变换以匹配 Patch 数量:

zD=zWDRB×N×Nz{\prime}_D = z{\prime}W_D \in \mathbb{R}^{B \times N \times N}

其中 WDRD×NW_D \in \mathbb{R}^{D \times N}

然后交换 Patch 和特征维度,并再次线性变换以匹配特征维度:

zN=zDWNRB×N×Dz{\prime}_N = z{\prime}_D W_N \in \mathbb{R}^{B \times N \times D}

其中WNRN×D W_N \in \mathbb{R}^{N \times D}

最终,将获得的向量与分离向量逐元素相乘并相加,以作为补充信息:

y=y+zNyy = y{\prime} + z{\prime}_N \odot y{\prime}

其中\odot表示 Hadamard 积。

E. Final Layer

沿特征维度做线性变换,匹配原大小,然后将Patch重建,并且过一个3*3 conv提升图形质量。

4. Experiment

A. Experiment Settings

评估指标 :我们使用 Fréchet Inception Distance (FID)评估生成图像的质量,并使用 Inception Score (IS)评估样本的多样性。在计算 FID 时,从数据集中采样 50,000 张图像,并生成 50,000 张图像用于计算两者之间的 FID。在计算 MNIST 和 Fashion-MNIST 数据集的 IS 时,由于这些数据分布与最初用于提出和验证 IS 的 ImageNet 存在显著差异,因此 IS 分数的对比意义有限。 实现细节 :我们将 MNIST、Fashion-MNIST 和 CIFAR-10 的输入图像统一标准化为 28×28 的尺寸。对于 MNIST 和 Fashion-MNIST,模型架构包括 2 个输入块(Input Blocks)和输出块(Output Blocks),1 个中间块(Mid Block),隐藏维度为 384。对于 CIFAR-10,则扩展为 4 个输入块和输出块、1 个中间块,隐藏维度为 512。优化器使用 AdamW,学习率为1×104 1 \times 10^{-4},MNIST 和 Fashion-MNIST 的模型训练 1600 轮,CIFAR-10 的模型训练 2000 轮。我们的 SNN 模型基于 SpikingJelly 框架实现,实验在 4×NVIDIA 4090 GPUs 上进行。

B. Comparisons

image.png

image.png

CIFAR10上比SDDPM差一些。

C. Ablation Study

image.png

有无Reconstruction Module。

image.png

RWKV的AC数比纯Transformer Backbone的更少,参数量上有优势。

5. Discussion

尽管 SDiT 表现相对强大,但仍然存在一些局限性。SDiT 在 MNIST 和 Fashion-MNIST 数据集上的性能显著超越其他方法,但在 CIFAR-10 数据集上的结果未能达到最新技术水平。我们将这种不足部分归因于 Vision Transformer 在有限数据上训练的效率较低。MNIST 和 Fashion-MNIST 的图像编码相对简单的视觉信息,而 CIFAR-10 图像的内在复杂性和较低分辨率掩盖了 SDiT 学习表示中的细粒度边缘细节。 关于自注意力机制与 SNNs 的集成,我们计划在未来工作中引入更细粒度的方案,以进一步减少自注意力在 SNN 框架中的信息损失。 另一个局限性是 SDiT 的嵌入组件未完全基于 SNN 实现。目前的 SDiT 由 ANN 和 SNN 的混合架构组成。当图像尺寸增加时,ANN 模块的参数和计算量也会增加,从而削弱 SNNs 的低功耗特性。未来的研究将探索完全基于 SNN 框架的设计,尝试构建纯 SNN 架构的模型。

6. Conclusion

在本研究中,我们提出了一种新型生成脉冲神经网络架构 SDiT,它结合了 SNNs 的低功耗特性与卓越的生成能力。作为在 SNN 领域中引入 Transformer 作为扩散模型骨干的全新尝试,SDiT 在生成图像质量方面提升了当前技术水平,并在多个数据集上表现出色。我们希望该研究能够为 SNN 生成模型领域提供新视角,并推动未来相关领域的发展。

超级短的文章,感觉就只是发在Arxiv上占个坑。Reconstruction Module感觉就是插入了一个捕捉高维信息的旁路,一路跟着推完一边之后把这个旁路的信息fuse到原有的信息中,相当于对原信息做一个补正。


Previous Post
Segment Anything
Next Post
2024