nn.Sequential与tensorflow的Sequential对比

2024-04-21 08:04

本文主要是介绍nn.Sequential与tensorflow的Sequential对比,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

nn.Sequential() 是 PyTorch 深度学习框架中的一个类,用于按顺序容器化模块。nn.Sequential 是一个有序的容器,它包含多个网络层,数据会按照在构造函数中传入顺序依次通过每个层。在 nn.Sequential 中,不需要定义 forward 方法,因为当你调用它时,会按照顺序调用每个子模块。

import torch  
import torch.nn as nn  # 创建一个简单的神经网络  
model = nn.Sequential(  nn.Linear(784, 128),  # 输入层到隐藏层,784个输入节点,128个输出节点  nn.ReLU(),           # 激活函数  nn.Linear(128, 10)   # 隐藏层到输出层,128个输入节点,10个输出节点  
)  # 创建一个随机的输入张量  
input_tensor = torch.randn(1, 784)  # 将输入张量传递给模型  
output_tensor = model(input_tensor)  print(output_tensor.size())  # 输出应为 torch.Size([1, 10])
import tensorflow as tf  
from tensorflow.keras.datasets import mnist  
from tensorflow.keras.models import Sequential  
from tensorflow.keras.layers import Dense, Flatten  
from tensorflow.keras.losses import SparseCategoricalCrossentropy  
from tensorflow.keras.optimizers import Adam  # 加载MNIST数据集  
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()  # 归一化像素值到0-1之间  
train_images = train_images / 255.0  
test_images = test_images / 255.0  # 将图像的形状展平,以便可以输入到神经网络中  
train_images = train_images[..., tf.newaxis]  
test_images = test_images[..., tf.newaxis]  # 创建Sequential模型  
model = Sequential([  Flatten(input_shape=(28, 28, 1)),  # 将28x28像素的图像展平为一维张量  Dense(128, activation='relu'),      # 全连接层,128个神经元,使用ReLU激活函数  Dense(10)                           # 输出层,10个神经元(对应10个数字类别)  
])  # 编译模型  
model.compile(optimizer=Adam(),  loss=SparseCategoricalCrossentropy(from_logits=True),  metrics=['accuracy'])  # 训练模型  
model.fit(train_images, train_labels, epochs=5)  # 评估模型  
test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)  print('\nTest accuracy:', test_acc)

self.conv 通常指的是类的一个属性,这个属性通常是一个卷积层(Convolutional Layer)

import torch  
import torch.nn as nn  
import torch.nn.functional as F  class SimpleCNN(nn.Module):  def __init__(self):  super(SimpleCNN, self).__init__()  # 定义一个卷积层,输入通道数为3(例如RGB图像),输出通道数为16,卷积核大小为3x3  self.conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)  # 可以添加其他层...  def forward(self, x):  # 在前向传播中使用卷积层  x = self.conv(x)  # 可以添加其他操作,例如激活函数、池化等...  x = F.relu(x)  # 返回处理后的输出  return x  # 创建模型实例  
model = SimpleCNN()  # 假设我们有一个输入张量 input_tensor,其形状为 [batch_size, 3, height, width]  
# 我们可以将其传递给模型进行前向传播  
output_tensor = model(input_tensor)

这篇关于nn.Sequential与tensorflow的Sequential对比的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/922564

相关文章

MySQL中EXISTS与IN用法使用与对比分析

《MySQL中EXISTS与IN用法使用与对比分析》在MySQL中,EXISTS和IN都用于子查询中根据另一个查询的结果来过滤主查询的记录,本文将基于工作原理、效率和应用场景进行全面对比... 目录一、基本用法详解1. IN 运算符2. EXISTS 运算符二、EXISTS 与 IN 的选择策略三、性能对比

PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例

《PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例》词嵌入解决NLP维度灾难,捕捉语义关系,PyTorch的nn.Embedding模块提供灵活实现,支持参数配置、预训练及变长... 目录一、词嵌入(Word Embedding)简介为什么需要词嵌入?二、PyTorch中的nn.Em

详解MySQL中JSON数据类型用法及与传统JSON字符串对比

《详解MySQL中JSON数据类型用法及与传统JSON字符串对比》MySQL从5.7版本开始引入了JSON数据类型,专门用于存储JSON格式的数据,本文将为大家简单介绍一下MySQL中JSON数据类型... 目录前言基本用法jsON数据类型 vs 传统JSON字符串1. 存储方式2. 查询方式对比3. 索引

SpringBoot中六种批量更新Mysql的方式效率对比分析

《SpringBoot中六种批量更新Mysql的方式效率对比分析》文章比较了MySQL大数据量批量更新的多种方法,指出REPLACEINTO和ONDUPLICATEKEY效率最高但存在数据风险,MyB... 目录效率比较测试结构数据库初始化测试数据批量修改方案第一种 for第二种 case when第三种

Python中Tensorflow无法调用GPU问题的解决方法

《Python中Tensorflow无法调用GPU问题的解决方法》文章详解如何解决TensorFlow在Windows无法识别GPU的问题,需降级至2.10版本,安装匹配CUDA11.2和cuDNN... 当用以下代码查看GPU数量时,gpuspython返回的是一个空列表,说明tensorflow没有找到

关于MyISAM和InnoDB对比分析

《关于MyISAM和InnoDB对比分析》:本文主要介绍关于MyISAM和InnoDB对比分析,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录开篇:从交通规则看存储引擎选择理解存储引擎的基本概念技术原理对比1. 事务支持:ACID的守护者2. 锁机制:并发控制的艺

CSS中的Static、Relative、Absolute、Fixed、Sticky的应用与详细对比

《CSS中的Static、Relative、Absolute、Fixed、Sticky的应用与详细对比》CSS中的position属性用于控制元素的定位方式,不同的定位方式会影响元素在页面中的布... css 中的 position 属性用于控制元素的定位方式,不同的定位方式会影响元素在页面中的布局和层叠关

Linux中的more 和 less区别对比分析

《Linux中的more和less区别对比分析》在Linux/Unix系统中,more和less都是用于分页查看文本文件的命令,但less是more的增强版,功能更强大,:本文主要介绍Linu... 目录1. 基础功能对比2. 常用操作对比less 的操作3. 实际使用示例4. 为什么推荐 less?5.

基于Go语言实现Base62编码的三种方式以及对比分析

《基于Go语言实现Base62编码的三种方式以及对比分析》Base62编码是一种在字符编码中使用62个字符的编码方式,在计算机科学中,,Go语言是一种静态类型、编译型语言,它由Google开发并开源,... 目录一、标准库现状与解决方案1. 标准库对比表2. 解决方案完整实现代码(含边界处理)二、关键实现细

PostgreSQL 序列(Sequence) 与 Oracle 序列对比差异分析

《PostgreSQL序列(Sequence)与Oracle序列对比差异分析》PostgreSQL和Oracle都提供了序列(Sequence)功能,但在实现细节和使用方式上存在一些重要差异,... 目录PostgreSQL 序列(Sequence) 与 oracle 序列对比一 基本语法对比1.1 创建序