DARTS论文和算法解析

2023-10-31 20:38
文章标签 算法 解析 论文 darts

本文主要是介绍DARTS论文和算法解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

DARTS,论文的全名是Differentiable Architecture Search,即可微分的架构搜索。

综合之前的一些NAS论文方法可以看出,不管是强化学习、进化算法还是SMBO,这些都无法通过像传统深度学习那样由Loss的梯度来更新网络架构,只能间接去优化生成子网络模型的控制器(Controller RNN,Predictor)或方法(进化算法)。

DARTS论文第一次把网络模型以可微分参数化的形式实现,网络模型和网络架构整合在一起,通过数据的训练集和验证集交替优化。在训练结束后,再从网络架构参数中解析出搜索出来的子网络。

DARTS论文的基本设计思想:

  1. 采用NASNet里的Cell和Block的设计方法;
  2. 对Cell里的所有Block的可能性架构参数化;
  3. DARTS搜索阶段训练的Cell架构是所有可能性的集合;
  4. 在验证集上对Cell的架构参数求导优化。

前面两点比较好理解,第3点怎么理解所有可能性的集合呢?

从之前的研究方法我们可以看出,每次都是先挑选出子网络后,再进行训练,要么是从头训练(大部分方法),要么在之前训练的基础上(ENAS)。DARTS避免了挑选子网络的过程,它将Cell里面所有的可能性以参数化的形式表示,在训练时,Cell里面所有的可能性连接和操作都会进行前向计算和反向推理,所有操作的模型参数均会进行更新,只是可能性更大的参数有更大的梯度更新。

DARTS的Cell结构图如下图所示。

图1. DARTS的Cell架构图
在这里插入图片描述

图1中的每一个小矩形(称为Node)表示的是特征图,第1个Node是Cell的输入,最后一个Node是Cell的输出。特征图之间的颜色线表示的是operation(操作),图中假设特征图之间只有三种operation可选空间(分别用红绿蓝表示)。图 ( a ) (a) (a)表示的是搜索问题,即两两特征图之间要用哪一种operation;图 ( b ) (b) (b)表示的是Cell中所有operation的集合;图 ( c ) (c) (c)表示经过训练后,各个operation的权重(表示选择可能性)变化值,越粗的表示参数权重越大;图 ( d ) (d) (d)表示最终选出的Cell架构,可以看出挑选的是当前Node跟前继Node中可能性最大的一条线(除了最后一个Node,与前继Node都挑选出最大可能性的线)。

图1中颜色线在DARTS中叫做架构参数 α \alpha α,Node表示的特征图为 x x x o o o表示操作。那么中间任意一个Node可以用公式表示为:

x ( j ) = ∑ i < j o ( i , j ) ( x ( i ) ) x^{(j)}=\sum_{i<j}o^{(i,j)}(x^{(i)}) x(j)=i<jo(i,j)(x(i))

其中, i i i j j j表示Node序号,公式的意思是中间Node是所有前继Node经过操作后之和。两个Node之间的操作可以表示为:

o ‾ ( i , j ) ( x ) = ∑ o ∈ O e x p ( α o ( i , j ) ) ∑ o ′ ∈ O e x p ( α o ′ ( i , j ) ) o ( x ) \overline{o}^{(i,j)}(x)=\sum_{o\in O}\frac{exp(\alpha_{o}^{(i,j)})}{\sum_{o^{'}\in O}exp(\alpha_{o^{'}}^{(i,j)})}o(x) o(i,j)(x)=oOoOexp(αo(i,j))exp(αo(i,j))o(x)

这个公式表示两个Node之间的操作是它们之间所有操作的softmax之和。

在训练的时候,需要交替对网络的模型参数 w w w和架构参数 α \alpha α进行优化,优化的目标函数是:

m i n α L v a l ( w ∗ ( α ) , α ) s . t . w ∗ ( α ) = a r g m i n w L t r a i n ( w , α ) \begin{aligned} & \underset{\alpha}{min}~~~\mathcal {L}_{val}(w^{*}(\alpha), \alpha) \\ & s.t. ~~~ w^{*}(\alpha)=argmin_{w}~\mathcal {L}_{train}(w,\alpha) \end{aligned} αmin   Lval(w(α),α)s.t.   w(α)=argminw Ltrain(w,α)

训练的方法过程如下图所示:

图2. DARTS的搜索训练方法
在这里插入图片描述

大致的步骤只有两个,而且是交替进行:

  1. 固定架构参数,用训练数据集训练模型参数;
  2. 固定模型参数,用验证数据集训练架构参数。

训练结束后,选择子网络的方式:对于中间Node,每个Node会挑选出前继Node中可能性最大的两个作为连接对象,两个Node之间最多只有一条线(operation)可以连接,所以中间Node只有两个输入来源和对应的operation;最后的Node是所有前继Node(除了输入)按照channel维度concat起来的结果。

由于要同时训练所有的架构,所以Cell叠加的个数不能太大,也不能在大的数据集上进行搜索。作者在Cifar-10小数据集上进行搜索,叠加8个Cell,第一个Cell的输出通道为16,使用1个GPU(GTX 1080Ti)训练50个epoch,耗时1天。

搜索完成后,解析架构参数确定最佳子网络,将Cell个数扩充,并进行正式训练。DARTS搜索出来的网络在Cifar-10和ImageNet上的实验结果如下两图所示。

图3. DARTS在Cifar-10上的实验性能和对比
在这里插入图片描述

图4. DARTS在ImageNet上的实验性能和对比
在这里插入图片描述

从Cifar-10的实验可以看出,二阶梯度方法的DARTS精度只比NASNet和AmoebaNet-B的方法差,但是在训练的计算资源和耗时上要远远小于它们。DARTS的搜索时间比ENAS长,但是精度比它高。

在ImageNet的实验上我们可以看到,DARTS的精度也能接近之前的NAS方法,同等参数量条件下与NASNet相当,比AmoebaNet和PNASNet差一些,但是在搜索消耗的GPU时长上,DARTS方法具有明显的优势。

这篇关于DARTS论文和算法解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/317977

相关文章

Java中Redisson 的原理深度解析

《Java中Redisson的原理深度解析》Redisson是一个高性能的Redis客户端,它通过将Redis数据结构映射为Java对象和分布式对象,实现了在Java应用中方便地使用Redis,本文... 目录前言一、核心设计理念二、核心架构与通信层1. 基于 Netty 的异步非阻塞通信2. 编解码器三、

Java HashMap的底层实现原理深度解析

《JavaHashMap的底层实现原理深度解析》HashMap基于数组+链表+红黑树结构,通过哈希算法和扩容机制优化性能,负载因子与树化阈值平衡效率,是Java开发必备的高效数据结构,本文给大家介绍... 目录一、概述:HashMap的宏观结构二、核心数据结构解析1. 数组(桶数组)2. 链表节点(Node

Java 虚拟线程的创建与使用深度解析

《Java虚拟线程的创建与使用深度解析》虚拟线程是Java19中以预览特性形式引入,Java21起正式发布的轻量级线程,本文给大家介绍Java虚拟线程的创建与使用,感兴趣的朋友一起看看吧... 目录一、虚拟线程简介1.1 什么是虚拟线程?1.2 为什么需要虚拟线程?二、虚拟线程与平台线程对比代码对比示例:三

一文解析C#中的StringSplitOptions枚举

《一文解析C#中的StringSplitOptions枚举》StringSplitOptions是C#中的一个枚举类型,用于控制string.Split()方法分割字符串时的行为,核心作用是处理分割后... 目录C#的StringSplitOptions枚举1.StringSplitOptions枚举的常用

Python函数作用域与闭包举例深度解析

《Python函数作用域与闭包举例深度解析》Python函数的作用域规则和闭包是编程中的关键概念,它们决定了变量的访问和生命周期,:本文主要介绍Python函数作用域与闭包的相关资料,文中通过代码... 目录1. 基础作用域访问示例1:访问全局变量示例2:访问外层函数变量2. 闭包基础示例3:简单闭包示例4

MyBatis延迟加载与多级缓存全解析

《MyBatis延迟加载与多级缓存全解析》文章介绍MyBatis的延迟加载与多级缓存机制,延迟加载按需加载关联数据提升性能,一级缓存会话级默认开启,二级缓存工厂级支持跨会话共享,增删改操作会清空对应缓... 目录MyBATis延迟加载策略一对多示例一对多示例MyBatis框架的缓存一级缓存二级缓存MyBat

深入理解Mysql OnlineDDL的算法

《深入理解MysqlOnlineDDL的算法》本文主要介绍了讲解MysqlOnlineDDL的算法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小... 目录一、Online DDL 是什么?二、Online DDL 的三种主要算法2.1COPY(复制法)

前端缓存策略的自解方案全解析

《前端缓存策略的自解方案全解析》缓存从来都是前端的一个痛点,很多前端搞不清楚缓存到底是何物,:本文主要介绍前端缓存的自解方案,文中通过代码介绍的非常详细,需要的朋友可以参考下... 目录一、为什么“清缓存”成了技术圈的梗二、先给缓存“把个脉”:浏览器到底缓存了谁?三、设计思路:把“发版”做成“自愈”四、代码

Java集合之Iterator迭代器实现代码解析

《Java集合之Iterator迭代器实现代码解析》迭代器Iterator是Java集合框架中的一个核心接口,位于java.util包下,它定义了一种标准的元素访问机制,为各种集合类型提供了一种统一的... 目录一、什么是Iterator二、Iterator的核心方法三、基本使用示例四、Iterator的工

Java JDK Validation 注解解析与使用方法验证

《JavaJDKValidation注解解析与使用方法验证》JakartaValidation提供了一种声明式、标准化的方式来验证Java对象,与框架无关,可以方便地集成到各种Java应用中,... 目录核心概念1. 主要注解基本约束注解其他常用注解2. 核心接口使用方法1. 基本使用添加依赖 (Maven