theano中训练方法和模型的一些写法

2023-12-15 13:58

本文主要是介绍theano中训练方法和模型的一些写法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

按照这theano的tutorial开始跟着写了,因为去年年底之前学习过一段时间,但是当时时间少,并且很多地方也没搞懂,大多都是看着书来模仿,结果出错不好找地方之外,自己如果根据自己的想法随便写下结果就出错了。这几天好好的学习了下。我来总结下

在softmax或者逻辑回归的代码中:

1:在教程中写法是,先写一个类,在类的init方法中初始化w和b,以及计算概率。

2:然后分别再另外的函数中计算cost和error。

3:在类外来进行训练。

以上方法由于在init方法中,要进行概率计算,所以在初始化类的时候至少要传递进去x,在计算cost和error的时候要传递进去y。所以如果此时训练方法写在类内的话就不行。因为数据也是在训练方法内生成的,在类别没办法传递进去(此时,你也许会说在类外面单独加载数据怎么样?因为训练方法内其实传递进去的是数据的索引,所以这个在外面单独加载数据的话也不太现实。所以加载数据还是和训练方法写在一起是最方便的)


为了解决以上问题,很容易,就是把init方法中,把在类外面不能传递的参数,全部去掉,然后在别的计算的地方用到的话,再传递进去。比如把概率预测拿出来,写成一个函数,把x传递进去。然后其余的保持不变。这样在训练方法内,因为有对x和y的声明,以及对应数据的传递进去,所以不会出问题


我这里写好了一个,发上来作为备忘吧,为下一步的更高层次的封装做准备:

import numpy, theano, theano.tensor as T, gzip, cPickleclass NN():def __init__(self, n_in, n_out):self.w = theano.shared(numpy.asarray(numpy.zeros([n_in, n_out]), theano.config.floatX))self.b = theano.shared(numpy.asarray(numpy.zeros(n_out), theano.config.floatX))def get_probalblity(self, x):return  T.nnet.softmax(T.dot(x, self.w) + self.b)  def get_prediction(self, x, y):return T.argmax(self.get_probalblity(x), 1)def cost(self, x, y):p_y_given_x = self.get_probalblity(x)return  -T.mean(T.log(p_y_given_x[T.arange(y.shape[0]), y]))def error(self, x, y):prediction = self.get_prediction(x, y)return T.mean(T.neq(prediction, y))def load_data(self):f = gzip.open('mnist.pkl.gz')trainxy, validatexy, testxy = cPickle.load(f)def share_data(xy):x,y = xyx = theano.shared(numpy.asarray(x, theano.config.floatX))y = theano.shared(numpy.asarray(y, theano.config.floatX))return [x, T.cast(y, 'int32')]trainx, trainy = share_data(trainxy)validatex,validatey = share_data(validatexy)testx, testy = share_data(testxy)return [(trainx,trainy),(validatex,validatey),(testx,testy)]def train(self):x = T.matrix('x', theano.config.floatX)y = T.ivector('y')[(trainx,trainy),(validatex,validatey),(testx,testy)] = self.load_data()gw,gb = T.grad(self.cost(x,y), [self.w, self.b])index = T.lscalar()batch_size = 600trainModel = theano.function([index], self.cost(x,y), updates=[(self.w, self.w-0.13*gw), (self.b, self.b-0.13*gb)], givens={x:trainx[index*batch_size:(index+1)*batch_size], y:trainy[index*batch_size:(index+1)*batch_size]})validateModel = theano.function([index], self.error(x,y), givens={x:validatex[index*batch_size:(index+1)*batch_size], y:validatey[index*batch_size:(index+1)*batch_size]})testModel = theano.function([index], self.error(x,y), givens={x:testx[index*batch_size:(index+1)*batch_size], y:testy[index*batch_size:(index+1)*batch_size]})best_validate_error = numpy.Infbest_test_error = 0patience = 5000increasement = 2train_batchs = trainx.get_value().shape[0]/batch_sizevalidate_batchs = validatex.get_value().shape[0]/batch_sizetest_batchs = testx.get_value().shape[0]/batch_sizevalidate_frequency = min(patience/2, train_batchs)epochs = 1000epoch = 1ite = 0stopping = Falsewhile (epoch < epochs) and (not stopping):for i in xrange(train_batchs):ite += 1this_cost = trainModel(i)if ite%validate_frequency == 0:this_validate_error = numpy.mean([validateModel(j) for j in xrange(validate_batchs)])print ('ite:%d/%d, cost:%f, validate:%f'%(ite, epoch, this_cost, this_validate_error))    if this_validate_error < best_validate_error:if this_validate_error < 0.995*best_validate_error:patience = max(patience, ite*increasement)this_test_error = numpy.mean([testModel(j) for j in xrange(test_batchs)])best_validate_error = this_validate_errorbest_test_error = this_test_errorprint ('ite:%d/%d, cost:%f, validate:%f, test:%f'%(ite, epoch, this_cost, this_validate_error, this_test_error))    if patience <= ite:stopping = Truebreakepoch +=1print ('best validate error:%f, best test error:%f'%(best_validate_error, best_test_error))if __name__ == '__main__':nn = NN(784, 10)nn.train()

以上训练结果,和tutorial给出的基本一致。

。。。。。

ite:5810/70, cost:0.329380, validate:0.075104
ite:5810/70, cost:0.329380, validate:0.075104, test:0.075000
ite:5893/71, cost:0.329054, validate:0.075208
ite:5976/72, cost:0.328735, validate:0.075104
ite:5976/72, cost:0.328735, validate:0.075104, test:0.075104
ite:6059/73, cost:0.328422, validate:0.075000
ite:6059/73, cost:0.328422, validate:0.075000, test:0.074896
ite:6142/74, cost:0.328116, validate:0.074792
ite:6142/74, cost:0.328116, validate:0.074792, test:0.074896
best validate error:0.074792, best test error:0.074896

这篇关于theano中训练方法和模型的一些写法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/496676

相关文章

详解如何使用Python从零开始构建文本统计模型

《详解如何使用Python从零开始构建文本统计模型》在自然语言处理领域,词汇表构建是文本预处理的关键环节,本文通过Python代码实践,演示如何从原始文本中提取多尺度特征,并通过动态调整机制构建更精确... 目录一、项目背景与核心思想二、核心代码解析1. 数据加载与预处理2. 多尺度字符统计3. 统计结果可

SpringBoot整合Sa-Token实现RBAC权限模型的过程解析

《SpringBoot整合Sa-Token实现RBAC权限模型的过程解析》:本文主要介绍SpringBoot整合Sa-Token实现RBAC权限模型的过程解析,本文给大家介绍的非常详细,对大家的学... 目录前言一、基础概念1.1 RBAC模型核心概念1.2 Sa-Token核心功能1.3 环境准备二、表结

mybatis的mapper对应的xml写法及配置详解

《mybatis的mapper对应的xml写法及配置详解》这篇文章给大家介绍mybatis的mapper对应的xml写法及配置详解,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,... 目录前置mapper 对应 XML 基础配置mapper 对应 xml 复杂配置Mapper 中的相

Spring Security基于数据库的ABAC属性权限模型实战开发教程

《SpringSecurity基于数据库的ABAC属性权限模型实战开发教程》:本文主要介绍SpringSecurity基于数据库的ABAC属性权限模型实战开发教程,本文给大家介绍的非常详细,对大... 目录1. 前言2. 权限决策依据RBACABAC综合对比3. 数据库表结构说明4. 实战开始5. MyBA

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)

《C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)》本文主要介绍了C#集成DeepSeek模型实现AI私有化的方法,包括搭建基础环境,如安装Ollama和下载DeepS... 目录前言搭建基础环境1、安装 Ollama2、下载 DeepSeek R1 模型客户端 ChatBo

SpringBoot快速接入OpenAI大模型的方法(JDK8)

《SpringBoot快速接入OpenAI大模型的方法(JDK8)》本文介绍了如何使用AI4J快速接入OpenAI大模型,并展示了如何实现流式与非流式的输出,以及对函数调用的使用,AI4J支持JDK8... 目录使用AI4J快速接入OpenAI大模型介绍AI4J-github快速使用创建SpringBoot

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型的操作流程

《0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeekR1模型的操作流程》DeepSeekR1模型凭借其强大的自然语言处理能力,在未来具有广阔的应用前景,有望在多个领域发... 目录0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型,3步搞定一个应

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll