[deeplearning-015]一文理解rnn

2024-06-11 09:18

本文主要是介绍[deeplearning-015]一文理解rnn,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.参考文档 https://iamtrask.github.io/2015/11/15/anyone-can-code-lstm/

  前提,熟悉bp算法推导,熟悉bptt算法推导

2.py源码

#!/usr/bin/env python3#参考文献 https://iamtrask.github.io/2015/11/15/anyone-can-code-lstm/import copy, numpy as npnp.random.seed(0)# compute sigmoid nonlinearity
def sigmoid(x):output = 1 / (1 + np.exp(-x))return output# convert output of sigmoid function to its derivative
def sigmoid_output_to_derivative(output):return output * (1 - output)# training dataset generation
int2binary = {}#空字典
binary_dim = 8#=256
largest_number = pow(2, binary_dim)binary = np.unpackbits(np.array([range(largest_number)], dtype=np.uint8).T, axis = 1)
for i in range(largest_number):int2binary[i] = binary[i]# input variables
alpha = 0.1
input_dim = 2
hidden_dim = 16
output_dim = 1# initialize neural network weights
synapse_0 = 2 * np.random.random((input_dim, hidden_dim)) - 1 #把值调整到(-1,1)
synapse_1 = 2 * np.random.random((hidden_dim, output_dim)) - 1
synapse_h = 2 * np.random.random((hidden_dim, hidden_dim)) - 1
synapse_0_update = np.zeros_like(synapse_0)
synapse_1_update = np.zeros_like(synapse_1)
synapse_h_update = np.zeros_like(synapse_h)# 随机生成ab两个整数,c=a+b。将abc转成二进制,a和b的二进制是rnn的输入,c的二进制是目标值,训练rnn。# training logic
for j in range(50000):# generate a simple addition problem (a + b = c)a_int = np.random.randint(largest_number / 2)  # int versiona = int2binary[a_int]  # binary encodingb_int = np.random.randint(largest_number / 2)  # int versionb = int2binary[b_int]  # binary encoding# true answerc_int = a_int + b_intc = int2binary[c_int]# print(a_int, a, b_int, b,c_int,c)# exit(1)# where we'll store our best guess (binary encoded)d = np.zeros_like(c)overallError = 0#记录第二层的deta值layer_2_deltas = list()#记录第一层的输出值layer_1_values = list()layer_1_values.append(np.zeros(hidden_dim))# moving along the positions in the binary encodingfor position in range(binary_dim):# generate input and output# 从abc三个数的二进制的最低位开始,每次取一位,生成训练样本# print("\nposition=", position)# print("a=",a)# print("b=",b)# print("c=", c)#比如,取第一位的时候,X=[[1,0]],y=[[1]],是训练样本和目标值X = np.array([[a[binary_dim - position - 1], b[binary_dim - position - 1]]])y = np.array([[c[binary_dim - position - 1]]]).T# print('X=',X)# print('y=',y)# hidden layer (input ~+ prev_hidden)# 计算第一层的输出值:sigmoid里面有两个部分,分别是input*synapse_0和前次第一层输出值*synapse_h隐层layer_1 = sigmoid(np.dot(X, synapse_0) + np.dot(layer_1_values[-1], synapse_h))# output layer (new binary representation)#第二层输出就很直接了layer_2 = sigmoid(np.dot(layer_1, synapse_1))# did we miss?... if so, by how much?# 计算输出误差layer_2_error = y - layer_2#根据误差,计算第二层权重系数的deta,这个公式是对第二层微分求导得到的,熟悉BP算法的都清楚,不在叙述layer_2_deltas.append((layer_2_error) * sigmoid_output_to_derivative(layer_2))#计算绝对误差之和,以衡量8个位的总误差大小overallError += np.abs(layer_2_error[0])# decode estimate so we can print it out# 记录下 rnn对当前位的预测值d[binary_dim - position - 1] = np.round(layer_2[0][0])# store hidden layer so we can use it in the next timestep# 把本轮第一层的输出记录到layer_1_values,以备bptt使用。layer_1_values.append(copy.deepcopy(layer_1))#计算第一个隐层的deltafuture_layer_1_delta = np.zeros(hidden_dim)for position in range(binary_dim):#注意喽,X这一轮循环是从ab的二进制的高位开始往低位走的X = np.array([[a[position], b[position]]])#高位二进制是后计算的,因此其值在layer_1的尾部layer_1 = layer_1_values[-position - 1]prev_layer_1 = layer_1_values[-position - 2]# error at output layerlayer_2_delta = layer_2_deltas[-position - 1]# error at hidden layerlayer_1_delta = (future_layer_1_delta.dot(synapse_h.T) + layer_2_delta.dot(synapse_1.T)) * sigmoid_output_to_derivative(layer_1)# let's update all our weights so we can try againsynapse_1_update += np.atleast_2d(layer_1).T.dot(layer_2_delta)synapse_h_update += np.atleast_2d(prev_layer_1).T.dot(layer_1_delta)synapse_0_update += X.T.dot(layer_1_delta)future_layer_1_delta = layer_1_deltasynapse_0 += synapse_0_update * alphasynapse_1 += synapse_1_update * alphasynapse_h += synapse_h_update * alphasynapse_0_update *= 0synapse_1_update *= 0synapse_h_update *= 0# print out progressif (j % 1000 == 0):print("Error:" + str(overallError))print("Pred:" + str(d))print("True:" + str(c))out = 0for index, x in enumerate(reversed(d)):out += x * pow(2, index)print(str(a_int) + " + " + str(b_int) + " = " + str(out))print("------------")

 

这篇关于[deeplearning-015]一文理解rnn的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1050761

相关文章

一文解析C#中的StringSplitOptions枚举

《一文解析C#中的StringSplitOptions枚举》StringSplitOptions是C#中的一个枚举类型,用于控制string.Split()方法分割字符串时的行为,核心作用是处理分割后... 目录C#的StringSplitOptions枚举1.StringSplitOptions枚举的常用

一文详解Python如何开发游戏

《一文详解Python如何开发游戏》Python是一种非常流行的编程语言,也可以用来开发游戏模组,:本文主要介绍Python如何开发游戏的相关资料,文中通过代码介绍的非常详细,需要的朋友可以参考下... 目录一、python简介二、Python 开发 2D 游戏的优劣势优势缺点三、Python 开发 3D

深入理解Mysql OnlineDDL的算法

《深入理解MysqlOnlineDDL的算法》本文主要介绍了讲解MysqlOnlineDDL的算法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小... 目录一、Online DDL 是什么?二、Online DDL 的三种主要算法2.1COPY(复制法)

一文详解MySQL索引(六张图彻底搞懂)

《一文详解MySQL索引(六张图彻底搞懂)》MySQL索引的建立对于MySQL的高效运行是很重要的,索引可以大大提高MySQL的检索速度,:本文主要介绍MySQL索引的相关资料,文中通过代码介绍的... 目录一、什么是索引?为什么需要索引?二、索引该用哪种数据结构?1. 哈希表2. 跳表3. 二叉排序树4.

一文带你迅速搞懂路由器/交换机/光猫三者概念区别

《一文带你迅速搞懂路由器/交换机/光猫三者概念区别》讨论网络设备时,常提及路由器、交换机及光猫等词汇,日常生活、工作中,这些设备至关重要,居家上网、企业内部沟通乃至互联网冲浪皆无法脱离其影响力,本文将... 当谈论网络设备时,我们常常会听到路由器、交换机和光猫这几个名词。它们是构建现代网络基础设施的关键组成

深入理解go中interface机制

《深入理解go中interface机制》本文主要介绍了深入理解go中interface机制,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学... 目录前言interface使用类型判断总结前言go的interface是一组method的集合,不

Java Spring的依赖注入理解及@Autowired用法示例详解

《JavaSpring的依赖注入理解及@Autowired用法示例详解》文章介绍了Spring依赖注入(DI)的概念、三种实现方式(构造器、Setter、字段注入),区分了@Autowired(注入... 目录一、什么是依赖注入(DI)?1. 定义2. 举个例子二、依赖注入的几种方式1. 构造器注入(Con

一文解密Python进行监控进程的黑科技

《一文解密Python进行监控进程的黑科技》在计算机系统管理和应用性能优化中,监控进程的CPU、内存和IO使用率是非常重要的任务,下面我们就来讲讲如何Python写一个简单使用的监控进程的工具吧... 目录准备工作监控CPU使用率监控内存使用率监控IO使用率小工具代码整合在计算机系统管理和应用性能优化中,监

一文详解如何使用Java获取PDF页面信息

《一文详解如何使用Java获取PDF页面信息》了解PDF页面属性是我们在处理文档、内容提取、打印设置或页面重组等任务时不可或缺的一环,下面我们就来看看如何使用Java语言获取这些信息吧... 目录引言一、安装和引入PDF处理库引入依赖二、获取 PDF 页数三、获取页面尺寸(宽高)四、获取页面旋转角度五、判断

深入理解Go语言中二维切片的使用

《深入理解Go语言中二维切片的使用》本文深入讲解了Go语言中二维切片的概念与应用,用于表示矩阵、表格等二维数据结构,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起学习学习吧... 目录引言二维切片的基本概念定义创建二维切片二维切片的操作访问元素修改元素遍历二维切片二维切片的动态调整追加行动态