Simple-STNDT使用Transformer进行Spike信号的表征学习(一)数据处理篇

本文主要是介绍Simple-STNDT使用Transformer进行Spike信号的表征学习(一)数据处理篇,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

    • 1.数据处理部分
      • 1.1 下载数据集
      • 1.2 数据集预处理
      • 1.3 划分train-val并创建Dataset对象
      • 1.4 掩码mask操作

数据、评估标准见NLB2021
https://neurallatents.github.io/

以下代码依据
https://github.com/trungle93/STNDT

原代码使用了 Ray+Config文件进行了参数搜索,库依赖较多,数据流过程不明显,代码冗杂,这里进行了抽丝剥茧,将其中最核心的部分提取出来。

1.数据处理部分

1.1 下载数据集

需要依赖 pip install dandi
downald.py

root = "D:/NeuralLatent/"
def downald_data():from dandi.download import downloaddownload("https://dandiarchive.org/dandiset/000128", root)download("https://dandiarchive.org/dandiset/000138", root)download("https://dandiarchive.org/dandiset/000139", root)download("https://dandiarchive.org/dandiset/000140", root)download("https://dandiarchive.org/dandiset/000129", root)download("https://dandiarchive.org/dandiset/000127", root)download("https://dandiarchive.org/dandiset/000130", root)

1.2 数据集预处理

需要依赖官方工具包pip install nlb_tools
主要是加载锋值序列数据,将其采样为5ms的时间槽
preprocess.py

## 以下为参数示例
# data_path = root + "/000129/sub-Indy/"
# dataset_name = "mc_rtt"
## 注意 "./data" 必须提前创建好from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors, combine_h5def preprocess(data_path, dataset_name=None):dataset = NWBDataset(datapath)bin_width = 5dataset.resample(bin_width)make_train_input_tensors(dataset, dataset_name=dataset_name, trial_split="train", include_behavior=True, include_forward_pred=True, save_file=True,save_path=f"./data/{dataset_name}_train.h5")make_eval_input_tensors(dataset, dataset_name=dataset_name, trial_split="val", save_file=True, save_path=f"./data/{dataset_name}_val.h5")combine_h5([f"./data/{dataset_name}_train.h5", f"./data/{dataset_name}_val.h5"], save_path=f"./data/{dataset_name}_full.h5")## './data/mc_rtt_full.h5' 将成为后续的主要分析数据

1.3 划分train-val并创建Dataset对象

读取'./data/mc_rtt_full.h5'中的数据并创建dataset
dataset.py

import h5py
import numpy as np
import torch
from torch.utils import data
# data_path = "./data/mc_rtt_full.h5"class SpikesDataset(data.Dataset):def __init__(self, spikes, heldout_spikes, forward_spikes) -> None:self.spikes = spikesself.heldout_spikes = heldout_spikesself.forward_spikes = forward_spikesdef __len__(self):return self.spikes.size(0)def __getitem__(self, index):r"""Return spikes and rates, shaped T x N (num_neurons)"""return self.spikes[index], self.heldout_spikes[index], self.forward_spikes[index]def make_datasets(data_path):with h5py.File(data_path, 'r') as h5file:h5dict = {key: h5file[key][()] for key in h5file.keys()}if 'eval_spikes_heldin' in h5dict: # NLB dataget_key = lambda key: h5dict[key].astype(np.float32)train_data = get_key('train_spikes_heldin')train_data_fp = get_key('train_spikes_heldin_forward')train_data_heldout_fp = get_key('train_spikes_heldout_forward')train_data_all_fp = np.concatenate([train_data_fp, train_data_heldout_fp], -1)valid_data = get_key('eval_spikes_heldin')train_data_heldout = get_key('train_spikes_heldout')if 'eval_spikes_heldout' in h5dict:valid_data_heldout = get_key('eval_spikes_heldout')else:valid_data_heldout = np.zeros((valid_data.shape[0], valid_data.shape[1], train_data_heldout.shape[2]), dtype=np.float32)if 'eval_spikes_heldin_forward' in h5dict:valid_data_fp = get_key('eval_spikes_heldin_forward')valid_data_heldout_fp = get_key('eval_spikes_heldout_forward')valid_data_all_fp = np.concatenate([valid_data_fp, valid_data_heldout_fp], -1)else:valid_data_all_fp = np.zeros((valid_data.shape[0], train_data_fp.shape[1], valid_data.shape[2] + valid_data_heldout.shape[2]), dtype=np.float32)train_dataset = SpikesDataset(torch.tensor(train_data).long(),            # [810, 120, 98]torch.tensor(train_data_heldout).long(),    # [810, 120, 32]torch.tensor(train_data_all_fp).long(),     # [810, 40, 130])val_dataset = SpikesDataset(torch.tensor(valid_data).long(),            # [810, 120, 98]torch.tensor(valid_data_heldout).long(),    # [810, 120, 32]torch.tensor(valid_data_all_fp).long(),     # [810, 40, 130])return train_dataset, val_dataset

1.4 掩码mask操作

dataset.py

# Some infeasibly high spike count
UNMASKED_LABEL = -100def mask_batch(batch, heldout_spikes, forward_spikes):batch = batch.clone() # make sure we don't corrupt the input data (which is stored in memory)mask_ratio = 0.31254mask_random_ratio = 0.876mask_token_ratio = 0.527labels = batch.clone()mask_probs = torch.full(labels.shape, mask_ratio)# If we want any tokens to not get masked, do it here (but we don't currently have any)mask = torch.bernoulli(mask_probs)mask = mask.bool()labels[~mask] = UNMASKED_LABEL  # No ground truth for unmasked - use this to mask loss# We use random assignment so the model learns embeddings for non-mask tokens, and must rely on context# Most times, we replace tokens with MASK tokenindices_replaced = torch.bernoulli(torch.full(labels.shape, mask_token_ratio)).bool() & maskbatch[indices_replaced] = 0# Random % of the time, we replace masked input tokens with random value (the rest are left intact)indices_random = torch.bernoulli(torch.full(labels.shape, mask_random_ratio)).bool() & mask & ~indices_replacedrandom_spikes = torch.randint(batch.max(), labels.shape, dtype=torch.long)batch[indices_random] = random_spikes[indices_random]# heldout spikes are all maskedbatch = torch.cat([batch, torch.zeros_like(heldout_spikes)], -1)labels = torch.cat([labels, heldout_spikes.to(batch.device)], -1)batch = torch.cat([batch, torch.zeros_like(forward_spikes)], 1)labels = torch.cat([labels, forward_spikes.to(batch.device)], 1)# Leave the other 10% alonereturn batch, labels

下一篇: https://blog.csdn.net/weixin_46866349/article/details/139906187

这篇关于Simple-STNDT使用Transformer进行Spike信号的表征学习(一)数据处理篇的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1088377

相关文章

Java中流式并行操作parallelStream的原理和使用方法

《Java中流式并行操作parallelStream的原理和使用方法》本文详细介绍了Java中的并行流(parallelStream)的原理、正确使用方法以及在实际业务中的应用案例,并指出在使用并行流... 目录Java中流式并行操作parallelStream0. 问题的产生1. 什么是parallelS

Linux join命令的使用及说明

《Linuxjoin命令的使用及说明》`join`命令用于在Linux中按字段将两个文件进行连接,类似于SQL的JOIN,它需要两个文件按用于匹配的字段排序,并且第一个文件的换行符必须是LF,`jo... 目录一. 基本语法二. 数据准备三. 指定文件的连接key四.-a输出指定文件的所有行五.-o指定输出

Linux jq命令的使用解读

《Linuxjq命令的使用解读》jq是一个强大的命令行工具,用于处理JSON数据,它可以用来查看、过滤、修改、格式化JSON数据,通过使用各种选项和过滤器,可以实现复杂的JSON处理任务... 目录一. 简介二. 选项2.1.2.2-c2.3-r2.4-R三. 字段提取3.1 普通字段3.2 数组字段四.

Linux kill正在执行的后台任务 kill进程组使用详解

《Linuxkill正在执行的后台任务kill进程组使用详解》文章介绍了两个脚本的功能和区别,以及执行这些脚本时遇到的进程管理问题,通过查看进程树、使用`kill`命令和`lsof`命令,分析了子... 目录零. 用到的命令一. 待执行的脚本二. 执行含子进程的脚本,并kill2.1 进程查看2.2 遇到的

详解SpringBoot+Ehcache使用示例

《详解SpringBoot+Ehcache使用示例》本文介绍了SpringBoot中配置Ehcache、自定义get/set方式,并实际使用缓存的过程,文中通过示例代码介绍的非常详细,对大家的学习或者... 目录摘要概念内存与磁盘持久化存储:配置灵活性:编码示例引入依赖:配置ehcache.XML文件:配置

Java 虚拟线程的创建与使用深度解析

《Java虚拟线程的创建与使用深度解析》虚拟线程是Java19中以预览特性形式引入,Java21起正式发布的轻量级线程,本文给大家介绍Java虚拟线程的创建与使用,感兴趣的朋友一起看看吧... 目录一、虚拟线程简介1.1 什么是虚拟线程?1.2 为什么需要虚拟线程?二、虚拟线程与平台线程对比代码对比示例:三

k8s按需创建PV和使用PVC详解

《k8s按需创建PV和使用PVC详解》Kubernetes中,PV和PVC用于管理持久存储,StorageClass实现动态PV分配,PVC声明存储需求并绑定PV,通过kubectl验证状态,注意回收... 目录1.按需创建 PV(使用 StorageClass)创建 StorageClass2.创建 PV

Redis 基本数据类型和使用详解

《Redis基本数据类型和使用详解》String是Redis最基本的数据类型,一个键对应一个值,它的功能十分强大,可以存储字符串、整数、浮点数等多种数据格式,本文给大家介绍Redis基本数据类型和... 目录一、Redis 入门介绍二、Redis 的五大基本数据类型2.1 String 类型2.2 Hash

Redis中Hash从使用过程到原理说明

《Redis中Hash从使用过程到原理说明》RedisHash结构用于存储字段-值对,适合对象数据,支持HSET、HGET等命令,采用ziplist或hashtable编码,通过渐进式rehash优化... 目录一、开篇:Hash就像超市的货架二、Hash的基本使用1. 常用命令示例2. Java操作示例三

Linux创建服务使用systemctl管理详解

《Linux创建服务使用systemctl管理详解》文章指导在Linux中创建systemd服务,设置文件权限为所有者读写、其他只读,重新加载配置,启动服务并检查状态,确保服务正常运行,关键步骤包括权... 目录创建服务 /usr/lib/systemd/system/设置服务文件权限:所有者读写js,其他