DAY9-深度学习100例-循环神经网络(RNN)实现股票预测

2024-03-20 03:50

本文主要是介绍DAY9-深度学习100例-循环神经网络(RNN)实现股票预测,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!


活动地址:CSDN21天学习挑战赛

  • 本文为🔗365天深度学习训练营 中的学习记录博客
  • 参考文章地址: 🔗深度学习100例-循环神经网络(RNN)实现股票预测 | 第9天

文章目录

  • 前言
  • 一、RNN
  • 二、准备工作
    • 1.设置GPU
    • 2.加载数据
  • 三、数据预处理
    • 1.归一化
    • 2.设置训练集、测试集
  • 四、构建模型
  • 五、激活模型
  • 六、训练模型
  • 七、结果可视化
    • 1.绘制loss图
    • 2.预测
    • 3.评估


前言

今天会通过RNN实现股票开盘价格的预测。

一、RNN

传统神经网络的结构:输入层—隐藏层—输出层,结构简单。
在这里插入图片描述
RNN与传统神经网络的区别在于:每次会将前一次的输出结果带入隐含层中,一起参与训练。
在这里插入图片描述
举例:用户说了一句"What time is it?",神经网络会先将这句话分为五个基本单元(四个单词+一个问号)。
在这里插入图片描述
然后按照顺序输入到RNN中,先输入"What"得到输出O1,随后按照顺序将"time"输入,得到输出O2,可以看到"What"的输出对于O2来说也有影响,以此类推,前面所有输入的输出结果都对后续的输出产生了影响。当神经网络判断意图时,只需要最后一层的输出O5即可。

在这里插入图片描述

二、准备工作

1.设置GPU

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")

2.加载数据

import os,math
from tensorflow.keras.layers import Dropout, Dense, SimpleRNN
from sklearn.preprocessing   import MinMaxScaler
from sklearn                 import metrics
import numpy             as np
import pandas            as pd
import tensorflow        as tf
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
data = pd.read_csv('SH600519.csv')  # 读取股票文件data

在这里插入图片描述

"""
前(2426-300=2126)天的开盘价作为训练集,表格从0开始计数,2:3 是提取[2:3)列,前闭后开,故提取出C列开盘价
后300天的开盘价作为测试集
"""
training_set = data.iloc[0:2426 - 300, 2:3].values  
test_set = data.iloc[2426 - 300:, 2:3].values  

三、数据预处理

1.归一化

sc           = MinMaxScaler(feature_range=(0, 1))
training_set = sc.fit_transform(training_set)
test_set     = sc.transform(test_set) 

2.设置训练集、测试集

x_train = []
y_train = []x_test = []
y_test = []"""
使用前60天的开盘价作为输入特征x_train第61天的开盘价作为输入标签y_trainfor循环共构建2426-300-60=2066组训练数据。共构建300-60=260组测试数据
"""
for i in range(60, len(training_set)):x_train.append(training_set[i - 60:i, 0])y_train.append(training_set[i, 0])for i in range(60, len(test_set)):x_test.append(test_set[i - 60:i, 0])y_test.append(test_set[i, 0])# 对训练集进行打乱
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
"""
将训练数据调整为数组(array)调整后的形状:
x_train:(2066, 60, 1)
y_train:(2066,)
x_test :(240, 60, 1)
y_test :(240,)
"""
x_train, y_train = np.array(x_train), np.array(y_train) # x_train形状为:(2066, 60, 1)
x_test,  y_test  = np.array(x_test),  np.array(y_test)"""
输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]
"""
x_train = np.reshape(x_train, (x_train.shape[0], 60, 1))
x_test  = np.reshape(x_test,  (x_test.shape[0], 60, 1))

四、构建模型

model = tf.keras.Sequential([SimpleRNN(100, return_sequences=True), #布尔值。是返回输出序列中的最后一个输出,还是全部序列。Dropout(0.1),                         #防止过拟合SimpleRNN(100),Dropout(0.1),Dense(1)
])

五、激活模型

# 该应用只观测loss数值,不观测准确率,所以删去metrics选项,一会在每个epoch迭代显示时只显示loss值
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss='mean_squared_error')  # 损失函数用均方误差

六、训练模型

history = model.fit(x_train, y_train, batch_size=64, epochs=20, validation_data=(x_test, y_test), validation_freq=1)                  #测试的epoch间隔数model.summary()

在这里插入图片描述
在这里插入图片描述

七、结果可视化

1.绘制loss图

plt.plot(history.history['loss']    , label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

在这里插入图片描述

2.预测

predicted_stock_price = model.predict(x_test)                       # 测试集输入模型进行预测
predicted_stock_price = sc.inverse_transform(predicted_stock_price) # 对预测数据还原---从(01)反归一化到原始范围
real_stock_price = sc.inverse_transform(test_set[60:])              # 对真实数据还原---从(01)反归一化到原始范围# 画出真实数据和预测数据的对比曲线
plt.plot(real_stock_price, color='red', label='Stock Price')
plt.plot(predicted_stock_price, color='blue', label='Predicted Stock Price')
plt.title('Stock Price Prediction')
plt.xlabel('Time')
plt.ylabel('Stock Price')
plt.legend()
plt.show()

在这里插入图片描述

3.评估

"""
MSE  :均方误差    ----->  预测值减真实值求平方后求均值
RMSE :均方根误差  ----->  对均方误差开方
MAE  :平均绝对误差----->  预测值减真实值求绝对值后求均值
R2   :决定系数,可以简单理解为反映模型拟合优度的重要的统计量详细介绍可以参考文章:https://blog.csdn.net/qq_38251616/article/details/107997435
"""
MSE   = metrics.mean_squared_error(predicted_stock_price, real_stock_price)
RMSE  = metrics.mean_squared_error(predicted_stock_price, real_stock_price)**0.5
MAE   = metrics.mean_absolute_error(predicted_stock_price, real_stock_price)
R2    = metrics.r2_score(predicted_stock_price, real_stock_price)print('均方误差: %.5f' % MSE)
print('均方根误差: %.5f' % RMSE)
print('平均绝对误差: %.5f' % MAE)
print('R2: %.5f' % R2)
均方误差: 2832.56937
均方根误差: 53.22189
平均绝对误差: 46.33860
R2: 0.59097

这篇关于DAY9-深度学习100例-循环神经网络(RNN)实现股票预测的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/828195

相关文章

使用animation.css库快速实现CSS3旋转动画效果

《使用animation.css库快速实现CSS3旋转动画效果》随着Web技术的不断发展,动画效果已经成为了网页设计中不可或缺的一部分,本文将深入探讨animation.css的工作原理,如何使用以及... 目录1. css3动画技术简介2. animation.css库介绍2.1 animation.cs

Java进行日期解析与格式化的实现代码

《Java进行日期解析与格式化的实现代码》使用Java搭配ApacheCommonsLang3和Natty库,可以实现灵活高效的日期解析与格式化,本文将通过相关示例为大家讲讲具体的实践操作,需要的可以... 目录一、背景二、依赖介绍1. Apache Commons Lang32. Natty三、核心实现代

SpringBoot实现接口数据加解密的三种实战方案

《SpringBoot实现接口数据加解密的三种实战方案》在金融支付、用户隐私信息传输等场景中,接口数据若以明文传输,极易被中间人攻击窃取,SpringBoot提供了多种优雅的加解密实现方案,本文将从原... 目录一、为什么需要接口数据加解密?二、核心加解密算法选择1. 对称加密(AES)2. 非对称加密(R

基于Go语言实现Base62编码的三种方式以及对比分析

《基于Go语言实现Base62编码的三种方式以及对比分析》Base62编码是一种在字符编码中使用62个字符的编码方式,在计算机科学中,,Go语言是一种静态类型、编译型语言,它由Google开发并开源,... 目录一、标准库现状与解决方案1. 标准库对比表2. 解决方案完整实现代码(含边界处理)二、关键实现细

python通过curl实现访问deepseek的API

《python通过curl实现访问deepseek的API》这篇文章主要为大家详细介绍了python如何通过curl实现访问deepseek的API,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编... API申请和充值下面是deepeek的API网站https://platform.deepsee

Maven 插件配置分层架构深度解析

《Maven插件配置分层架构深度解析》:本文主要介绍Maven插件配置分层架构深度解析,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录Maven 插件配置分层架构深度解析引言:当构建逻辑遇上复杂配置第一章 Maven插件配置的三重境界1.1 插件配置的拓扑

重新对Java的类加载器的学习方式

《重新对Java的类加载器的学习方式》:本文主要介绍重新对Java的类加载器的学习方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、介绍1.1、简介1.2、符号引用和直接引用1、符号引用2、直接引用3、符号转直接的过程2、加载流程3、类加载的分类3.1、显示

SpringBoot实现二维码生成的详细步骤与完整代码

《SpringBoot实现二维码生成的详细步骤与完整代码》如今,二维码的应用场景非常广泛,从支付到信息分享,二维码都扮演着重要角色,SpringBoot是一个非常流行的Java基于Spring框架的微... 目录一、环境搭建二、创建 Spring Boot 项目三、引入二维码生成依赖四、编写二维码生成代码五

MyBatisX逆向工程的实现示例

《MyBatisX逆向工程的实现示例》本文主要介绍了MyBatisX逆向工程的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学... 目录逆向工程准备好数据库、表安装MyBATisX插件项目连接数据库引入依赖pom.XML生成实体类、

C#实现查找并删除PDF中的空白页面

《C#实现查找并删除PDF中的空白页面》PDF文件中的空白页并不少见,因为它们有可能是作者有意留下的,也有可能是在处理文档时不小心添加的,下面我们来看看如何使用Spire.PDFfor.NET通过C#... 目录安装 Spire.PDF for .NETC# 查找并删除 PDF 文档中的空白页C# 添加与删