spikingjelly学习-训练网络

2024-04-10 04:04

本文主要是介绍spikingjelly学习-训练网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

【MNIST数据集包含若干尺寸为28*28的8位灰度图像,总共有0~9共10个类别。以MNIST的分类为例,一个简单的单层ANN网络如下

我们也可以用完全类似结构的SNN来进行分类任务。就这个网络而言,只需要先去掉所有的激活函数,再将尖峰神经元添加到原来激活函数的位置,这里我们选择的是LIF神经元。神经元之间的连接层需要用
spikingjelly.activation_based.layer包装:

在 spikingjelly 中,我们约定,只能输出脉冲,即0或1的神经元,都可以称之为“脉冲神经元”。使用脉冲神经元的网络,进而也可以称之为脉冲神经元网络(Spiking Neural Networks, SNNs)。这里使用了 neuron.IFNode() 来构建 IF 神经元层,该神经元层有如下构造函数:
  1. v_threshold – 神经元的阈值电压
  2. v_reset – 神经元的重置电压。
  3. surrogate_function – 反向传播时用来计算脉冲函数梯度的替代函数
    神经元的数量是在初始化或调用 reset() 函数重新初始化后,根据第一次接收的输入的 shape 自动决定的。此处则是10个神经元。其中膜电位衰减常数 需要通过参数tau设置,替代函数这里选择surrogate.ATan。
    然后是训练SNN网络,指定好训练参数如学习率等以及若干其他配置优化器默认使用Adam,以及使用泊松编码器,在每次输入图片时进行脉冲编码。

【训练代码的编写需要遵循以下三个要点:
 脉冲神经元的输出是二值的,而直接将单次运行的结果用于分类极易受到编码带来的噪声干扰。因此一般认为脉冲网络的输出是输出层一段时间内的发放频率(或称发放率),发放率的高低表示该类别的响应大小。因此网络需要运行一段时间,即使用T个时刻后的平均发放率作为分类依据。
 我们希望的理想结果是除了正确的神经元以最高频率发放,其他神经元保持静默。常常采用交叉熵损失或者MSE损失,这里我们使用实际效果更好的MSE损失。
 每次网络仿真结束后,需要重置网络状态

 # 保存绘图用数据net.eval()# 注册钩子output_layer = net.layer[-1] # 输出层output_layer.v_seq = []output_layer.s_seq = []def save_hook(m, x, y):m.v_seq.append(m.v.unsqueeze(0))m.s_seq.append(y.unsqueeze(0))output_layer.register_forward_hook(save_hook)with torch.no_grad():img, label = test_dataset[0]img = img.to(args.device)out_fr = 0.for t in range(args.T):encoded_img = encoder(img)out_fr += net(encoded_img)out_spikes_counter_frequency = (out_fr / args.T).cpu().numpy()print(f'Firing rate: {out_spikes_counter_frequency}')output_layer.v_seq = torch.cat(output_layer.v_seq)output_layer.s_seq = torch.cat(output_layer.s_seq)v_t_array = output_layer.v_seq.cpu().numpy().squeeze()  # v_t_array[i][j]表示神经元i在j时刻的电压值np.save("v_t_array.npy",v_t_array)s_t_array = output_layer.s_seq.cpu().numpy().squeeze()  # s_t_array[i][j]表示神经元i在j时刻释放的脉冲,为0或1np.save("s_t_array.npy",s_t_array)

在这里插入图片描述
【在PyTorch中,钩子(hooks)是一种强大的工具,允许你在模型的前向传播(forward pass)或反向传播(backward pass)过程中插入自定义操作。这些操作可以用于调试、可视化、保存中间状态等目的,而不需要修改模型的定义。
钩子的类型
前向钩子(Forward Hooks):在层的前向传播执行完毕后立即执行。它们通常用于检查、修改或记录从层输出的数据。
反向钩子(Backward Hooks):在层的梯度计算过程中执行。它们用于检查或修改梯度值。
这段代码中的钩子使用
在提供的代码段中,使用了一个前向钩子(save_hook)来保存神经网络某层在前向传播过程中的电压值(v)和脉冲值(s)。
这个钩子函数save_hook接收三个参数:
m:注册钩子的模块(在这个例子中是输出层)。
x:输入到该模块的数据。
y:从该模块输出的数据。
在钩子函数内部,它将模块m的电压值v和输出脉冲y保存到列表中。这里使用unsqueeze(0)是为了增加一个批次维度,使得每次迭代的数据可以被堆叠起来。
钩子的注册
这行代码将save_hook函数注册为output_layer(网络的最后一层)的前向钩子。这意味着每当output_layer完成前向传播时,save_hook函数都会被调用。
数据的保存
在所有测试图像通过网络并且钩子函数被调用之后,v_seq和s_seq列表中的数据被合并(使用torch.cat)并转换为NumPy数组,然后通过np.save保存到文件中。这些文件包含了在整个测试集上,输出层神经元的电压值和脉冲发放情况,可以用于进一步的分析和可视化。】
这段代码通过注册一个前向钩子来捕获并保存神经网络最后一层在前向传播过程中的电压和脉冲数据。这种方法非常有用,因为它允许在不修改网络结构的情况下收集内部状态信息,对于理解和分析网络的行为非常有帮助。

这篇关于spikingjelly学习-训练网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/890024

相关文章

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

Linux网络配置之网桥和虚拟网络的配置指南

《Linux网络配置之网桥和虚拟网络的配置指南》这篇文章主要为大家详细介绍了Linux中配置网桥和虚拟网络的相关方法,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 一、网桥的配置在linux系统中配置一个新的网桥主要涉及以下几个步骤:1.为yum仓库做准备,安装组件epel-re

python如何下载网络文件到本地指定文件夹

《python如何下载网络文件到本地指定文件夹》这篇文章主要为大家详细介绍了python如何实现下载网络文件到本地指定文件夹,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下...  在python中下载文件到本地指定文件夹可以通过以下步骤实现,使用requests库处理HTTP请求,并结合o

重新对Java的类加载器的学习方式

《重新对Java的类加载器的学习方式》:本文主要介绍重新对Java的类加载器的学习方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、介绍1.1、简介1.2、符号引用和直接引用1、符号引用2、直接引用3、符号转直接的过程2、加载流程3、类加载的分类3.1、显示

Linux高并发场景下的网络参数调优实战指南

《Linux高并发场景下的网络参数调优实战指南》在高并发网络服务场景中,Linux内核的默认网络参数往往无法满足需求,导致性能瓶颈、连接超时甚至服务崩溃,本文基于真实案例分析,从参数解读、问题诊断到优... 目录一、问题背景:当并发连接遇上性能瓶颈1.1 案例环境1.2 初始参数分析二、深度诊断:连接状态与

Qt实现网络数据解析的方法总结

《Qt实现网络数据解析的方法总结》在Qt中解析网络数据通常涉及接收原始字节流,并将其转换为有意义的应用层数据,这篇文章为大家介绍了详细步骤和示例,感兴趣的小伙伴可以了解下... 目录1. 网络数据接收2. 缓冲区管理(处理粘包/拆包)3. 常见数据格式解析3.1 jsON解析3.2 XML解析3.3 自定义

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen

Linux系统配置NAT网络模式的详细步骤(附图文)

《Linux系统配置NAT网络模式的详细步骤(附图文)》本文详细指导如何在VMware环境下配置NAT网络模式,包括设置主机和虚拟机的IP地址、网关,以及针对Linux和Windows系统的具体步骤,... 目录一、配置NAT网络模式二、设置虚拟机交换机网关2.1 打开虚拟机2.2 管理员授权2.3 设置子

揭秘Python Socket网络编程的7种硬核用法

《揭秘PythonSocket网络编程的7种硬核用法》Socket不仅能做聊天室,还能干一大堆硬核操作,这篇文章就带大家看看Python网络编程的7种超实用玩法,感兴趣的小伙伴可以跟随小编一起... 目录1.端口扫描器:探测开放端口2.简易 HTTP 服务器:10 秒搭个网页3.局域网游戏:多人联机对战4.