MNIST3_numpy手写全连接神经网络

2023-11-02 14:32

本文主要是介绍MNIST3_numpy手写全连接神经网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 一、链式求导
  • 二、numpy layer和反向传播
    • 反向传播
  • 三、MNIST训练及测试

一、链式求导

在这里插入图片描述

二、numpy layer和反向传播

全部脚本见笔者github: numpynn.py


import numpy as npclass npLayer():def __init__(self, n_input, n_out, activation=None, weights=None,bias=None):self.weights = weights if weights is not None else np.random.randn(n_input, n_out) * np.sqrt(1 / n_out)self.bias = bias if bias is not None else np.random.randn(n_out) * 0.1self.activation = activation self.last_activation = None self.error = None self.delta = None def activate(self, x):# 前向传播r = np.dot(x, self.weights) + self.biasself.last_activation = self.apply_activation(r)return self.last_activation def apply_activation(self, r):# 计算激活函数的输出if self.activation is None:return relif self.activation == 'relu':return np.maximum(r, 0)elif self.activation == 'tanh':return np.tanh(r)elif self.activation == 'sigmoid':return 1/(1 + np.exp(-r))return rdef apply_activation_derivative(self, act_r):# 计算激活函数的导数if self.activation is None:return np.ones_like(act_r)elif self.activation == 'relu':return (act_r > 0) * 1elif self.activation == 'tanh':return 1 - act_r ** 2elif self.activation == 'sigmoid':return act_r * (1 - act_r)return act_rdef __call__(self, x):return self.activate(x)

反向传播

    def backpropagation(self, x, y, learning_rate):# 反向传播算法实现## 从后向前计算梯度 output = self.feed_forward(x) # 最后层输出layer_len = len(self._layers)for i in reversed(range(layer_len)):layer = self._layers[i] # 如果是输出层if layer  == self._layers[-1]:delta_i = layer.apply_activation_derivative(output)layer.error = output - ylayer.delta = layer.error * delta_ielse:next_layer = self._layers[i + 1]delta_i = layer.apply_activation_derivative(layer.last_activation)layer.error = np.dot(next_layer.weights, next_layer.delta)layer.delta = layer.error * delta_i# 梯度下降for i in range(layer_len):layer = self._layers[i]o_i = np.atleast_2d(x if i == 0 else self._layers[i - 1].last_activation)layer.weights -= layer.delta * o_i.T * learning_rate

三、MNIST训练及测试


if __name__ == '__main__':mnistdf = get_ministdata()te_index = mnistdf.sample(frac=0.8).index.tolist()mnist_te = mnistdf.loc[te_index, :]mnist_tr = mnistdf.loc[~mnistdf.index.isin(te_index), :]x_tr, y_tr = mnist_tr.iloc[:, :-1].values, mnist_tr.iloc[:, -1].valuesx_te, y_te = mnist_te.iloc[:, :-1].values, mnist_te.iloc[:, -1].valuesprint(x_te.shape)nn = NeuralNetwork()nn.add_layer(npLayer(784, 128, 'relu')) nn.add_layer(npLayer(128, 10, 'sigmoid'))st = time.perf_counter()mses, accs = nn.train(x_tr, x_te, y_tr, y_te, 0.01, 150)cost_ = time.perf_counter() - stprint(f'cost: {cost_:.2f}s',accs)
 ================================================================================
Epoch: # 85, MSE: 0.00713
Accuracy: 93.93 % ================================================================================
Epoch: # 90, MSE: 0.00654
Accuracy: 94.09 % ================================================================================
Epoch: # 95, MSE: 0.00600
Accuracy: 94.27 % ================================================================================
Epoch: # 100, MSE: 0.00558
Accuracy: 94.41 % ================================================================================
Epoch: # 105, MSE: 0.00514
Accuracy: 94.53 % ================================================================================
Epoch: # 110, MSE: 0.00479
Accuracy: 94.65 % ================================================================================
Epoch: # 115, MSE: 0.00447
Accuracy: 94.75 % ================================================================================
Epoch: # 120, MSE: 0.00417
Accuracy: 94.84 % ================================================================================
Epoch: # 125, MSE: 0.00393
Accuracy: 94.93 % ================================================================================
Epoch: # 130, MSE: 0.00370
Accuracy: 94.98 % ================================================================================
Epoch: # 135, MSE: 0.00350
Accuracy: 95.03 %================================================================================
Epoch: # 140, MSE: 0.00332
Accuracy: 95.08 %================================================================================
Epoch: # 145, MSE: 0.00316
Accuracy: 95.12 %================================================================================
Epoch: # 150, MSE: 0.00303
Accuracy: 95.14 %
cost: 1104.11s [0.2034285714285714, 0.5135714285714286, 0.5907142857142857, 0.6798928571428572, 0.74375, 0.7954285714285715
, 0.8364821428571428, 0.863125, 0.8833571428571428, 0.8975178571428571, 0.9077857142857142, 0.9149285714285714, 0.9213214285714286
, 0.9264821428571427, 0.9302142857142858, 0.9336071428571429, 0.9372678571428571, 0.9392857142857143, 0.9408928571428572, 0.9427321428571429
, 0.9440535714285714, 0.94525, 0.9465178571428572, 0.9475178571428572, 0.9483571428571429, 0.9493035714285715, 0.9498214285714286
, 0.9502857142857143, 0.95075, 0.9511607142857144, 0.9513571428571429]

这篇关于MNIST3_numpy手写全连接神经网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/331270

相关文章

java连接opcua的常见问题及解决方法

《java连接opcua的常见问题及解决方法》本文将使用EclipseMilo作为示例库,演示如何在Java中使用匿名、用户名密码以及证书加密三种方式连接到OPCUA服务器,若需要使用其他SDK,原理... 目录一、前言二、准备工作三、匿名方式连接3.1 匿名方式简介3.2 示例代码四、用户名密码方式连接4

MySQL 表的内外连接案例详解

《MySQL表的内外连接案例详解》本文给大家介绍MySQL表的内外连接,结合实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录表的内外连接(重点)内连接外连接表的内外连接(重点)内连接内连接实际上就是利用where子句对两种表形成的笛卡儿积进行筛选,我

Apache 高级配置实战之从连接保持到日志分析的完整指南

《Apache高级配置实战之从连接保持到日志分析的完整指南》本文带你从连接保持优化开始,一路走到访问控制和日志管理,最后用AWStats来分析网站数据,对Apache配置日志分析相关知识感兴趣的朋友... 目录Apache 高级配置实战:从连接保持到日志分析的完整指南前言 一、Apache 连接保持 - 性

电脑蓝牙连不上怎么办? 5 招教你轻松修复Mac蓝牙连接问题的技巧

《电脑蓝牙连不上怎么办?5招教你轻松修复Mac蓝牙连接问题的技巧》蓝牙连接问题是一些Mac用户经常遇到的常见问题之一,在本文章中,我们将提供一些有用的提示和技巧,帮助您解决可能出现的蓝牙连接问... 蓝牙作为一种流行的无线技术,已经成为我们连接各种设备的重要工具。在 MAC 上,你可以根据自己的需求,轻松地

宝塔安装的MySQL无法连接的情况及解决方案

《宝塔安装的MySQL无法连接的情况及解决方案》宝塔面板是一款流行的服务器管理工具,其中集成的MySQL数据库有时会出现连接问题,本文详细介绍两种最常见的MySQL连接错误:“1130-Hostisn... 目录一、错误 1130:Host ‘xxx.xxx.xxx.xxx’ is not allowed

MySQL 多表连接操作方法(INNER JOIN、LEFT JOIN、RIGHT JOIN、FULL OUTER JOIN)

《MySQL多表连接操作方法(INNERJOIN、LEFTJOIN、RIGHTJOIN、FULLOUTERJOIN)》多表连接是一种将两个或多个表中的数据组合在一起的SQL操作,通过连接,... 目录一、 什么是多表连接?二、 mysql 支持的连接类型三、 多表连接的语法四、实战示例 数据准备五、连接的性

MySQL中的分组和多表连接详解

《MySQL中的分组和多表连接详解》:本文主要介绍MySQL中的分组和多表连接的相关操作,本文通过实例代码给大家介绍的非常详细,感兴趣的朋友一起看看吧... 目录mysql中的分组和多表连接一、MySQL的分组(group javascriptby )二、多表连接(表连接会产生大量的数据垃圾)MySQL中的

MySQL中的交叉连接、自然连接和内连接查询详解

《MySQL中的交叉连接、自然连接和内连接查询详解》:本文主要介绍MySQL中的交叉连接、自然连接和内连接查询,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、引入二、交php叉连接(cross join)三、自然连接(naturalandroid join)四

python连接本地SQL server详细图文教程

《python连接本地SQLserver详细图文教程》在数据分析领域,经常需要从数据库中获取数据进行分析和处理,下面:本文主要介绍python连接本地SQLserver的相关资料,文中通过代码... 目录一.设置本地账号1.新建用户2.开启双重验证3,开启TCP/IP本地服务二js.python连接实例1.

Ubuntu中远程连接Mysql数据库的详细图文教程

《Ubuntu中远程连接Mysql数据库的详细图文教程》Ubuntu是一个以桌面应用为主的Linux发行版操作系统,这篇文章主要为大家详细介绍了Ubuntu中远程连接Mysql数据库的详细图文教程,有... 目录1、版本2、检查有没有mysql2.1 查询是否安装了Mysql包2.2 查看Mysql版本2.