SKIL/工作流程/在实验中训练模型

2023-10-21 16:50

本文主要是介绍SKIL/工作流程/在实验中训练模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在实验中训练模型

如果你想跟踪结果并进行可重复的评估,实验对于训练模型很有用。一旦你学习了工作间,笔记本和进行实验的基本知识你就准备好用SKIL练一个模型了。


先决条件
这个文档假设你已经设置了一个工作间并在SKIL中创建了一个新的实验。创建实验后,打开“笔记本”选项卡,该选项卡将显示scala的模板笔记本,其中已设置导入和结构化训练代码。

如果你不打算动态加载任何其他依赖项,可以单击工具栏左上角的“play”按钮(形状像一个侧面三角形),以评估模板笔记本中的所有单元,并将SkilContext和deeplearning4j库放到作用域中。
如果你喜欢使用其他库,SKIL已将TensorFlow和Keras预先打包。更多信息请参见实验中的TensorFlow。

 

典型工作流程

为训练而设置的笔记本通常遵循此工作流程:

  1. 将第一个和顶部单元用于动态依赖项(可选)。
  2. 把所有常见的导入放在最上面。
  3. 实例化SkilContext并引用SkilContext.client。
  4. 添加用于加载、拆分和转换数据集的代码。
  5. 编写深度学习模型配置和超参数。
  6. 把数据传入Model.fit() 或者,如果使用多GPU,传入 ParallelWrapper.fit.
  7. 使用测试/验证/维持数据集评估模型。
  8. 将经过训练的模型和评估结果传递给SkilContext进行存储。

 

样例代码
TensorFlow、多个Keras后端和Deeplarning4J是默认情况下可用的深度学习框架。下面的示例代码使用scala语言和deeplearning4j。如果要完全下载示例笔记本,建议使用uci_quickstart_notebook.json。
如果要使用外部库,请使用笔记本第一个单元格中的%spark.dep解释器预加载要在笔记本中使用的任何依赖项。

 

%spark.dep//清除以前添加的项目和仓库
z.reset() // 添加maven仓库
z.addRepo("RepoName").url("RepoURL")// 添加Maven快照仓库
z.addRepo("RepoName").url("RepoURL").snapshot()// 添加私有Maven仓库的凭据
z.addRepo("RepoName").url("RepoURL").username("username").password("password")// 从文件系统添加项目
z.load("/path/to.jar")

在配置模型或运行代码之前,需要将必要的类导入作用域。通常,这涉及到deeplarning4j及其一些实用程序库(如ND4J和DataVec)的导入。还要记住导入SKIL实用程序,以便将模型和评估保存到SKIL存储。下面的代码拥有训练LSTM序列分类器所需要的一切。

import scala.collection.JavaConversions._import io.skymind.zeppelin.utils._
import io.skymind.modelproviders.history.client.ModelHistoryClient
import io.skymind.modelproviders.history.model._import org.deeplearning4j.datasets.iterator._
import org.deeplearning4j.datasets.iterator.impl._
import org.deeplearning4j.nn.api._
import org.deeplearning4j.nn.multilayer._
import org.deeplearning4j.nn.graph._
import org.deeplearning4j.nn.conf._
import org.deeplearning4j.nn.conf.inputs._
import org.deeplearning4j.nn.conf.layers._
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex
import org.deeplearning4j.nn.weights._
import org.deeplearning4j.optimize.listeners._
import org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter
import org.deeplearning4j.ui.stats.StatsListener
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator
import org.deeplearning4j.eval.Evaluationimport org.datavec.api.transform._
import org.datavec.api.records.reader.RecordReader
import org.datavec.api.records.reader.SequenceRecordReader
import org.datavec.api.records.reader.impl.csv.CSVRecordReader
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader
import org.datavec.api.split.NumberedFileInputSplitimport org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.learning.config._
import org.nd4j.linalg.lossfunctions.LossFunctions._
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.primitives.Pair
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize
import org.nd4j.linalg.util.ArrayUtilimport java.io.File
import java.net.URL
import java.util.ArrayList
import java.util.Collections
import java.util.List
import java.util.Random

假设你已将数据集序列保存到单独的特征和标签文件中,则可以定义一个CSVSequenceRecordReader。它使用RecordReader基类从csv文件中提取单个序列。最后,在使用神经网络中的数据之前,必须将RecordReader传递给一个扩展DataSetIterator的类。这允许预取和批处理你的训练。

 

val trainFeatures: SequenceRecordReader = new CSVSequenceRecordReader()
trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath + "/%d.csv",0,449))val trainLabels: RecordReader = new CSVRecordReader()
trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath + "/%d.csv",0,449))val minibatch: Int = 10
val numLabelClasses: Int = 6val trainData: MultiDataSetIterator = new RecordReaderMultiDataSetIterator.Builder(minibatch).addSequenceReader("features", trainFeatures).addReader("labels", trainLabels).addInput("features").addOutputOneHot("labels", 0, numLabelClasses).build()

 

最后,初始化网络配置。Deeplarning4J公开了一个称为MultiLayerNetwork的简单接口,并且一个更复杂的配置ComputationGraph可用于多个输入和输出。它们类似于Keras中的两个API,ComputationGraph的工作原理与TensorFlow自己的配置非常相似。
配置网络时,必须首先使用NeuralNetConfiguration Builder定义层、输入、输出和其他超参数。然后传递到ComputationGraphMultiLayerNetwork类,不要忘记调用init()

val conf: ComputationGraphConfiguration = new NeuralNetConfiguration.Builder().seed(123).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).weightInit(WeightInit.XAVIER).updater(new Nesterovs(0.005, 0.9)).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(0.5).graphBuilder().addInputs("input").setInputTypes(InputType.recurrent(1)).addLayer("lstm", new GravesLSTM.Builder().activation(Activation.TANH).nIn(1).nOut(10).build(), "input").addVertex("pool", new LastTimeStepVertex("input"), "lstm").addLayer("output", new OutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(numLabelClasses).build(), "pool").setOutputs("output").pretrain(false).backprop(true).build()val network_model: ComputationGraph = new ComputationGraph(conf)
network_model.init()

Training the network is fairly simple. You can either use a MultipleEpochsIteratorincluded with Deeplearning4j or manually iterate through each epoch if you prefer to perform other operations such as evaluation.

训练网络相当简单。如果你愿意执行其他操作(如评估),可以使用MultipleEpochsIterator(包括deeplarming4j),也可以手动迭代每个epoch。

for (i <- 0 until nEpochs) {network_model.fit(trainData)// 在测试集上评估:val evaluation = eval(testData)var accuracy = evaluation.accuracy()var f1 = evaluation.f1()println(s"Test set evaluation at epoch $i: Accuracy = $accuracy, F1 = $f1")testData.reset()trainData.reset()
}

Certain datasets might require more complex evaluation. The code below shows you how to create an evaluation method that returns an Evaluation class which is compatible with SKIL's model storage system.

某些数据集可能需要更复杂的评估。下面的代码向你展示了如何创建一个返回与SKIL's的模型存储系统兼容的evaluation类的评估方法。

def eval(it:MultiDataSetIterator) : Evaluation = {val evaluation = new Evaluation(numLabelClasses)it.reset()while (it.hasNext()) {val ds = it.next()val prediction = network_model.outputSingle(ds.getFeatures(0))evaluation.eval(ds.getLabels(0), prediction)}return evaluation
}

最后,使用SkilContext类将模型上传到SKIL并附加评估结果。

var evaluation = eval(testData)
val modelId = skilContext.addModelToExperiment(z, network_model)
val evalId = skilContext.addEvaluationToModel(z, modelId, evaluation)

 

这篇关于SKIL/工作流程/在实验中训练模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/255671

相关文章

Nginx分布式部署流程分析

《Nginx分布式部署流程分析》文章介绍Nginx在分布式部署中的反向代理和负载均衡作用,用于分发请求、减轻服务器压力及解决session共享问题,涵盖配置方法、策略及Java项目应用,并提及分布式事... 目录分布式部署NginxJava中的代理代理分为正向代理和反向代理正向代理反向代理Nginx应用场景

Spring Boot分层架构详解之从Controller到Service再到Mapper的完整流程(用户管理系统为例)

《SpringBoot分层架构详解之从Controller到Service再到Mapper的完整流程(用户管理系统为例)》本文将以一个实际案例(用户管理系统)为例,详细解析SpringBoot中Co... 目录引言:为什么学习Spring Boot分层架构?第一部分:Spring Boot的整体架构1.1

nodejs打包作为公共包使用的完整流程

《nodejs打包作为公共包使用的完整流程》在Node.js项目中,打包和部署是发布应用的关键步骤,:本文主要介绍nodejs打包作为公共包使用的相关资料,文中通过代码介绍的非常详细,需要的朋友可... 目录前言一、前置准备二、创建与编码三、一键构建四、本地“白嫖”测试(可选)五、发布公共包六、常见踩坑提醒

C#利用Free Spire.XLS for .NET复制Excel工作表

《C#利用FreeSpire.XLSfor.NET复制Excel工作表》在日常的.NET开发中,我们经常需要操作Excel文件,本文将详细介绍C#如何使用FreeSpire.XLSfor.NET... 目录1. 环境准备2. 核心功能3. android示例代码3.1 在同一工作簿内复制工作表3.2 在不同

Ubuntu向多台主机批量传输文件的流程步骤

《Ubuntu向多台主机批量传输文件的流程步骤》:本文主要介绍在Ubuntu中批量传输文件到多台主机的方法,需确保主机互通、用户名密码统一及端口开放,通过安装sshpass工具,准备包含目标主机信... 目录Ubuntu 向多台主机批量传输文件1.安装 sshpass2.准备主机列表文件3.创建一个批处理脚

一个Java的main方法在JVM中的执行流程示例详解

《一个Java的main方法在JVM中的执行流程示例详解》main方法是Java程序的入口点,程序从这里开始执行,:本文主要介绍一个Java的main方法在JVM中执行流程的相关资料,文中通过代码... 目录第一阶段:加载 (Loading)第二阶段:链接 (Linking)第三阶段:初始化 (Initia

Linux五种IO模型的使用解读

《Linux五种IO模型的使用解读》文章系统解析了Linux的五种IO模型(阻塞、非阻塞、IO复用、信号驱动、异步),重点区分同步与异步IO的本质差异,强调同步由用户发起,异步由内核触发,通过对比各模... 目录1.IO模型简介2.五种IO模型2.1 IO模型分析方法2.2 阻塞IO2.3 非阻塞IO2.4

Git打标签从本地创建到远端推送的详细流程

《Git打标签从本地创建到远端推送的详细流程》在软件开发中,Git标签(Tag)是为发布版本、标记里程碑量身定制的“快照锚点”,它能永久记录项目历史中的关键节点,然而,仅创建本地标签往往不够,如何将其... 目录一、标签的两种“形态”二、本地创建与查看1. 打附注标http://www.chinasem.cn

通过Docker容器部署Python环境的全流程

《通过Docker容器部署Python环境的全流程》在现代化开发流程中,Docker因其轻量化、环境隔离和跨平台一致性的特性,已成为部署Python应用的标准工具,本文将详细演示如何通过Docker容... 目录引言一、docker与python的协同优势二、核心步骤详解三、进阶配置技巧四、生产环境最佳实践

MyBatis分页查询实战案例完整流程

《MyBatis分页查询实战案例完整流程》MyBatis是一个强大的Java持久层框架,支持自定义SQL和高级映射,本案例以员工工资信息管理为例,详细讲解如何在IDEA中使用MyBatis结合Page... 目录1. MyBATis框架简介2. 分页查询原理与应用场景2.1 分页查询的基本原理2.1.1 分