SKIL/工作流程/在实验中训练模型

2023-10-21 16:50

本文主要是介绍SKIL/工作流程/在实验中训练模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在实验中训练模型

如果你想跟踪结果并进行可重复的评估,实验对于训练模型很有用。一旦你学习了工作间,笔记本和进行实验的基本知识你就准备好用SKIL练一个模型了。


先决条件
这个文档假设你已经设置了一个工作间并在SKIL中创建了一个新的实验。创建实验后,打开“笔记本”选项卡,该选项卡将显示scala的模板笔记本,其中已设置导入和结构化训练代码。

如果你不打算动态加载任何其他依赖项,可以单击工具栏左上角的“play”按钮(形状像一个侧面三角形),以评估模板笔记本中的所有单元,并将SkilContext和deeplearning4j库放到作用域中。
如果你喜欢使用其他库,SKIL已将TensorFlow和Keras预先打包。更多信息请参见实验中的TensorFlow。

 

典型工作流程

为训练而设置的笔记本通常遵循此工作流程:

  1. 将第一个和顶部单元用于动态依赖项(可选)。
  2. 把所有常见的导入放在最上面。
  3. 实例化SkilContext并引用SkilContext.client。
  4. 添加用于加载、拆分和转换数据集的代码。
  5. 编写深度学习模型配置和超参数。
  6. 把数据传入Model.fit() 或者,如果使用多GPU,传入 ParallelWrapper.fit.
  7. 使用测试/验证/维持数据集评估模型。
  8. 将经过训练的模型和评估结果传递给SkilContext进行存储。

 

样例代码
TensorFlow、多个Keras后端和Deeplarning4J是默认情况下可用的深度学习框架。下面的示例代码使用scala语言和deeplearning4j。如果要完全下载示例笔记本,建议使用uci_quickstart_notebook.json。
如果要使用外部库,请使用笔记本第一个单元格中的%spark.dep解释器预加载要在笔记本中使用的任何依赖项。

 

%spark.dep//清除以前添加的项目和仓库
z.reset() // 添加maven仓库
z.addRepo("RepoName").url("RepoURL")// 添加Maven快照仓库
z.addRepo("RepoName").url("RepoURL").snapshot()// 添加私有Maven仓库的凭据
z.addRepo("RepoName").url("RepoURL").username("username").password("password")// 从文件系统添加项目
z.load("/path/to.jar")

在配置模型或运行代码之前,需要将必要的类导入作用域。通常,这涉及到deeplarning4j及其一些实用程序库(如ND4J和DataVec)的导入。还要记住导入SKIL实用程序,以便将模型和评估保存到SKIL存储。下面的代码拥有训练LSTM序列分类器所需要的一切。

import scala.collection.JavaConversions._import io.skymind.zeppelin.utils._
import io.skymind.modelproviders.history.client.ModelHistoryClient
import io.skymind.modelproviders.history.model._import org.deeplearning4j.datasets.iterator._
import org.deeplearning4j.datasets.iterator.impl._
import org.deeplearning4j.nn.api._
import org.deeplearning4j.nn.multilayer._
import org.deeplearning4j.nn.graph._
import org.deeplearning4j.nn.conf._
import org.deeplearning4j.nn.conf.inputs._
import org.deeplearning4j.nn.conf.layers._
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex
import org.deeplearning4j.nn.weights._
import org.deeplearning4j.optimize.listeners._
import org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter
import org.deeplearning4j.ui.stats.StatsListener
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator
import org.deeplearning4j.eval.Evaluationimport org.datavec.api.transform._
import org.datavec.api.records.reader.RecordReader
import org.datavec.api.records.reader.SequenceRecordReader
import org.datavec.api.records.reader.impl.csv.CSVRecordReader
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader
import org.datavec.api.split.NumberedFileInputSplitimport org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.learning.config._
import org.nd4j.linalg.lossfunctions.LossFunctions._
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.primitives.Pair
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize
import org.nd4j.linalg.util.ArrayUtilimport java.io.File
import java.net.URL
import java.util.ArrayList
import java.util.Collections
import java.util.List
import java.util.Random

假设你已将数据集序列保存到单独的特征和标签文件中,则可以定义一个CSVSequenceRecordReader。它使用RecordReader基类从csv文件中提取单个序列。最后,在使用神经网络中的数据之前,必须将RecordReader传递给一个扩展DataSetIterator的类。这允许预取和批处理你的训练。

 

val trainFeatures: SequenceRecordReader = new CSVSequenceRecordReader()
trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath + "/%d.csv",0,449))val trainLabels: RecordReader = new CSVRecordReader()
trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath + "/%d.csv",0,449))val minibatch: Int = 10
val numLabelClasses: Int = 6val trainData: MultiDataSetIterator = new RecordReaderMultiDataSetIterator.Builder(minibatch).addSequenceReader("features", trainFeatures).addReader("labels", trainLabels).addInput("features").addOutputOneHot("labels", 0, numLabelClasses).build()

 

最后,初始化网络配置。Deeplarning4J公开了一个称为MultiLayerNetwork的简单接口,并且一个更复杂的配置ComputationGraph可用于多个输入和输出。它们类似于Keras中的两个API,ComputationGraph的工作原理与TensorFlow自己的配置非常相似。
配置网络时,必须首先使用NeuralNetConfiguration Builder定义层、输入、输出和其他超参数。然后传递到ComputationGraphMultiLayerNetwork类,不要忘记调用init()

val conf: ComputationGraphConfiguration = new NeuralNetConfiguration.Builder().seed(123).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).weightInit(WeightInit.XAVIER).updater(new Nesterovs(0.005, 0.9)).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(0.5).graphBuilder().addInputs("input").setInputTypes(InputType.recurrent(1)).addLayer("lstm", new GravesLSTM.Builder().activation(Activation.TANH).nIn(1).nOut(10).build(), "input").addVertex("pool", new LastTimeStepVertex("input"), "lstm").addLayer("output", new OutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(numLabelClasses).build(), "pool").setOutputs("output").pretrain(false).backprop(true).build()val network_model: ComputationGraph = new ComputationGraph(conf)
network_model.init()

Training the network is fairly simple. You can either use a MultipleEpochsIteratorincluded with Deeplearning4j or manually iterate through each epoch if you prefer to perform other operations such as evaluation.

训练网络相当简单。如果你愿意执行其他操作(如评估),可以使用MultipleEpochsIterator(包括deeplarming4j),也可以手动迭代每个epoch。

for (i <- 0 until nEpochs) {network_model.fit(trainData)// 在测试集上评估:val evaluation = eval(testData)var accuracy = evaluation.accuracy()var f1 = evaluation.f1()println(s"Test set evaluation at epoch $i: Accuracy = $accuracy, F1 = $f1")testData.reset()trainData.reset()
}

Certain datasets might require more complex evaluation. The code below shows you how to create an evaluation method that returns an Evaluation class which is compatible with SKIL's model storage system.

某些数据集可能需要更复杂的评估。下面的代码向你展示了如何创建一个返回与SKIL's的模型存储系统兼容的evaluation类的评估方法。

def eval(it:MultiDataSetIterator) : Evaluation = {val evaluation = new Evaluation(numLabelClasses)it.reset()while (it.hasNext()) {val ds = it.next()val prediction = network_model.outputSingle(ds.getFeatures(0))evaluation.eval(ds.getLabels(0), prediction)}return evaluation
}

最后,使用SkilContext类将模型上传到SKIL并附加评估结果。

var evaluation = eval(testData)
val modelId = skilContext.addModelToExperiment(z, network_model)
val evalId = skilContext.addEvaluationToModel(z, modelId, evaluation)

 

这篇关于SKIL/工作流程/在实验中训练模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/255671

相关文章

深入浅出Spring中的@Autowired自动注入的工作原理及实践应用

《深入浅出Spring中的@Autowired自动注入的工作原理及实践应用》在Spring框架的学习旅程中,@Autowired无疑是一个高频出现却又让初学者头疼的注解,它看似简单,却蕴含着Sprin... 目录深入浅出Spring中的@Autowired:自动注入的奥秘什么是依赖注入?@Autowired

redis-sentinel基础概念及部署流程

《redis-sentinel基础概念及部署流程》RedisSentinel是Redis的高可用解决方案,通过监控主从节点、自动故障转移、通知机制及配置提供,实现集群故障恢复与服务持续可用,核心组件包... 目录一. 引言二. 核心功能三. 核心组件四. 故障转移流程五. 服务部署六. sentinel部署

SpringBoot集成XXL-JOB实现任务管理全流程

《SpringBoot集成XXL-JOB实现任务管理全流程》XXL-JOB是一款轻量级分布式任务调度平台,功能丰富、界面简洁、易于扩展,本文介绍如何通过SpringBoot项目,使用RestTempl... 目录一、前言二、项目结构简述三、Maven 依赖四、Controller 代码详解五、Service

Python中的filter() 函数的工作原理及应用技巧

《Python中的filter()函数的工作原理及应用技巧》Python的filter()函数用于筛选序列元素,返回迭代器,适合函数式编程,相比列表推导式,内存更优,尤其适用于大数据集,结合lamb... 目录前言一、基本概念基本语法二、使用方式1. 使用 lambda 函数2. 使用普通函数3. 使用 N

MySQL 临时表与复制表操作全流程案例

《MySQL临时表与复制表操作全流程案例》本文介绍MySQL临时表与复制表的区别与使用,涵盖生命周期、存储机制、操作限制、创建方法及常见问题,本文结合实例代码给大家介绍的非常详细,感兴趣的朋友跟随小... 目录一、mysql 临时表(一)核心特性拓展(二)操作全流程案例1. 复杂查询中的临时表应用2. 临时

MySQL 升级到8.4版本的完整流程及操作方法

《MySQL升级到8.4版本的完整流程及操作方法》本文详细说明了MySQL升级至8.4的完整流程,涵盖升级前准备(备份、兼容性检查)、支持路径(原地、逻辑导出、复制)、关键变更(空间索引、保留关键字... 目录一、升级前准备 (3.1 Before You Begin)二、升级路径 (3.2 Upgrade

setsid 命令工作原理和使用案例介绍

《setsid命令工作原理和使用案例介绍》setsid命令在Linux中创建独立会话,使进程脱离终端运行,适用于守护进程和后台任务,通过重定向输出和确保权限,可有效管理长时间运行的进程,本文给大家介... 目录setsid 命令介绍和使用案例基本介绍基本语法主要特点命令参数使用案例1. 在后台运行命令2.

Spring Boot 中的默认异常处理机制及执行流程

《SpringBoot中的默认异常处理机制及执行流程》SpringBoot内置BasicErrorController,自动处理异常并生成HTML/JSON响应,支持自定义错误路径、配置及扩展,如... 目录Spring Boot 异常处理机制详解默认错误页面功能自动异常转换机制错误属性配置选项默认错误处理

Spring Boot从main方法到内嵌Tomcat的全过程(自动化流程)

《SpringBoot从main方法到内嵌Tomcat的全过程(自动化流程)》SpringBoot启动始于main方法,创建SpringApplication实例,初始化上下文,准备环境,刷新容器并... 目录1. 入口:main方法2. SpringApplication初始化2.1 构造阶段3. 运行阶

Java中的xxl-job调度器线程池工作机制

《Java中的xxl-job调度器线程池工作机制》xxl-job通过快慢线程池分离短时与长时任务,动态降级超时任务至慢池,结合异步触发和资源隔离机制,提升高频调度的性能与稳定性,支撑高并发场景下的可靠... 目录⚙️ 一、调度器线程池的核心设计 二、线程池的工作流程 三、线程池配置参数与优化 四、总结:线程