MNIST3_numpy手写全连接神经网络

2023-11-02 14:32

本文主要是介绍MNIST3_numpy手写全连接神经网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 一、链式求导
  • 二、numpy layer和反向传播
    • 反向传播
  • 三、MNIST训练及测试

一、链式求导

在这里插入图片描述

二、numpy layer和反向传播

全部脚本见笔者github: numpynn.py


import numpy as npclass npLayer():def __init__(self, n_input, n_out, activation=None, weights=None,bias=None):self.weights = weights if weights is not None else np.random.randn(n_input, n_out) * np.sqrt(1 / n_out)self.bias = bias if bias is not None else np.random.randn(n_out) * 0.1self.activation = activation self.last_activation = None self.error = None self.delta = None def activate(self, x):# 前向传播r = np.dot(x, self.weights) + self.biasself.last_activation = self.apply_activation(r)return self.last_activation def apply_activation(self, r):# 计算激活函数的输出if self.activation is None:return relif self.activation == 'relu':return np.maximum(r, 0)elif self.activation == 'tanh':return np.tanh(r)elif self.activation == 'sigmoid':return 1/(1 + np.exp(-r))return rdef apply_activation_derivative(self, act_r):# 计算激活函数的导数if self.activation is None:return np.ones_like(act_r)elif self.activation == 'relu':return (act_r > 0) * 1elif self.activation == 'tanh':return 1 - act_r ** 2elif self.activation == 'sigmoid':return act_r * (1 - act_r)return act_rdef __call__(self, x):return self.activate(x)

反向传播

    def backpropagation(self, x, y, learning_rate):# 反向传播算法实现## 从后向前计算梯度 output = self.feed_forward(x) # 最后层输出layer_len = len(self._layers)for i in reversed(range(layer_len)):layer = self._layers[i] # 如果是输出层if layer  == self._layers[-1]:delta_i = layer.apply_activation_derivative(output)layer.error = output - ylayer.delta = layer.error * delta_ielse:next_layer = self._layers[i + 1]delta_i = layer.apply_activation_derivative(layer.last_activation)layer.error = np.dot(next_layer.weights, next_layer.delta)layer.delta = layer.error * delta_i# 梯度下降for i in range(layer_len):layer = self._layers[i]o_i = np.atleast_2d(x if i == 0 else self._layers[i - 1].last_activation)layer.weights -= layer.delta * o_i.T * learning_rate

三、MNIST训练及测试


if __name__ == '__main__':mnistdf = get_ministdata()te_index = mnistdf.sample(frac=0.8).index.tolist()mnist_te = mnistdf.loc[te_index, :]mnist_tr = mnistdf.loc[~mnistdf.index.isin(te_index), :]x_tr, y_tr = mnist_tr.iloc[:, :-1].values, mnist_tr.iloc[:, -1].valuesx_te, y_te = mnist_te.iloc[:, :-1].values, mnist_te.iloc[:, -1].valuesprint(x_te.shape)nn = NeuralNetwork()nn.add_layer(npLayer(784, 128, 'relu')) nn.add_layer(npLayer(128, 10, 'sigmoid'))st = time.perf_counter()mses, accs = nn.train(x_tr, x_te, y_tr, y_te, 0.01, 150)cost_ = time.perf_counter() - stprint(f'cost: {cost_:.2f}s',accs)
 ================================================================================
Epoch: # 85, MSE: 0.00713
Accuracy: 93.93 % ================================================================================
Epoch: # 90, MSE: 0.00654
Accuracy: 94.09 % ================================================================================
Epoch: # 95, MSE: 0.00600
Accuracy: 94.27 % ================================================================================
Epoch: # 100, MSE: 0.00558
Accuracy: 94.41 % ================================================================================
Epoch: # 105, MSE: 0.00514
Accuracy: 94.53 % ================================================================================
Epoch: # 110, MSE: 0.00479
Accuracy: 94.65 % ================================================================================
Epoch: # 115, MSE: 0.00447
Accuracy: 94.75 % ================================================================================
Epoch: # 120, MSE: 0.00417
Accuracy: 94.84 % ================================================================================
Epoch: # 125, MSE: 0.00393
Accuracy: 94.93 % ================================================================================
Epoch: # 130, MSE: 0.00370
Accuracy: 94.98 % ================================================================================
Epoch: # 135, MSE: 0.00350
Accuracy: 95.03 %================================================================================
Epoch: # 140, MSE: 0.00332
Accuracy: 95.08 %================================================================================
Epoch: # 145, MSE: 0.00316
Accuracy: 95.12 %================================================================================
Epoch: # 150, MSE: 0.00303
Accuracy: 95.14 %
cost: 1104.11s [0.2034285714285714, 0.5135714285714286, 0.5907142857142857, 0.6798928571428572, 0.74375, 0.7954285714285715
, 0.8364821428571428, 0.863125, 0.8833571428571428, 0.8975178571428571, 0.9077857142857142, 0.9149285714285714, 0.9213214285714286
, 0.9264821428571427, 0.9302142857142858, 0.9336071428571429, 0.9372678571428571, 0.9392857142857143, 0.9408928571428572, 0.9427321428571429
, 0.9440535714285714, 0.94525, 0.9465178571428572, 0.9475178571428572, 0.9483571428571429, 0.9493035714285715, 0.9498214285714286
, 0.9502857142857143, 0.95075, 0.9511607142857144, 0.9513571428571429]

这篇关于MNIST3_numpy手写全连接神经网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/331270

相关文章

java.sql.SQLTransientConnectionException连接超时异常原因及解决方案

《java.sql.SQLTransientConnectionException连接超时异常原因及解决方案》:本文主要介绍java.sql.SQLTransientConnectionExcep... 目录一、引言二、异常信息分析三、可能的原因3.1 连接池配置不合理3.2 数据库负载过高3.3 连接泄漏

Mac电脑如何通过 IntelliJ IDEA 远程连接 MySQL

《Mac电脑如何通过IntelliJIDEA远程连接MySQL》本文详解Mac通过IntelliJIDEA远程连接MySQL的步骤,本文通过图文并茂的形式给大家介绍的非常详细,感兴趣的朋友跟... 目录MAC电脑通过 IntelliJ IDEA 远程连接 mysql 的详细教程一、前缀条件确认二、打开 ID

Go语言连接MySQL数据库执行基本的增删改查

《Go语言连接MySQL数据库执行基本的增删改查》在后端开发中,MySQL是最常用的关系型数据库之一,本文主要为大家详细介绍了如何使用Go连接MySQL数据库并执行基本的增删改查吧... 目录Go语言连接mysql数据库准备工作安装 MySQL 驱动代码实现运行结果注意事项Go语言执行基本的增删改查准备工作

python连接sqlite3简单用法完整例子

《python连接sqlite3简单用法完整例子》SQLite3是一个内置的Python模块,可以通过Python的标准库轻松地使用,无需进行额外安装和配置,:本文主要介绍python连接sqli... 目录1. 连接到数据库2. 创建游标对象3. 创建表4. 插入数据5. 查询数据6. 更新数据7. 删除

在 Spring Boot 中连接 MySQL 数据库的详细步骤

《在SpringBoot中连接MySQL数据库的详细步骤》本文介绍了SpringBoot连接MySQL数据库的流程,添加依赖、配置连接信息、创建实体类与仓库接口,通过自动配置实现数据库操作,... 目录一、添加依赖二、配置数据库连接三、创建实体类四、创建仓库接口五、创建服务类六、创建控制器七、运行应用程序八

解决hive启动时java.net.ConnectException:拒绝连接的问题

《解决hive启动时java.net.ConnectException:拒绝连接的问题》Hadoop集群连接被拒,需检查集群是否启动、关闭防火墙/SELinux、确认安全模式退出,若问题仍存,查看日志... 目录错误发生原因解决方式1.关闭防火墙2.关闭selinux3.启动集群4.检查集群是否正常启动5.

在Linux系统上连接GitHub的方法步骤(适用2025年)

《在Linux系统上连接GitHub的方法步骤(适用2025年)》在2025年,使用Linux系统连接GitHub的推荐方式是通过SSH(SecureShell)协议进行身份验证,这种方式不仅安全,还... 目录步骤一:检查并安装 Git步骤二:生成 SSH 密钥步骤三:将 SSH 公钥添加到 github

Redis客户端连接机制的实现方案

《Redis客户端连接机制的实现方案》本文主要介绍了Redis客户端连接机制的实现方案,包括事件驱动模型、非阻塞I/O处理、连接池应用及配置优化,具有一定的参考价值,感兴趣的可以了解一下... 目录1. Redis连接模型概述2. 连接建立过程详解2.1 连php接初始化流程2.2 关键配置参数3. 最大连

C#连接SQL server数据库命令的基本步骤

《C#连接SQLserver数据库命令的基本步骤》文章讲解了连接SQLServer数据库的步骤,包括引入命名空间、构建连接字符串、使用SqlConnection和SqlCommand执行SQL操作,... 目录建议配合使用:如何下载和安装SQL server数据库-CSDN博客1. 引入必要的命名空间2.

Java通过驱动包(jar包)连接MySQL数据库的步骤总结及验证方式

《Java通过驱动包(jar包)连接MySQL数据库的步骤总结及验证方式》本文详细介绍如何使用Java通过JDBC连接MySQL数据库,包括下载驱动、配置Eclipse环境、检测数据库连接等关键步骤,... 目录一、下载驱动包二、放jar包三、检测数据库连接JavaJava 如何使用 JDBC 连接 mys