小波卷积:为计算机视觉任务开辟新的参数效率之路

2024-08-24 16:36

本文主要是介绍小波卷积:为计算机视觉任务开辟新的参数效率之路,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

论文复述

这篇论文介绍了一种创新的卷积神经网络层——WTConv,它通过小波变换技术显著扩展了CNN的感受野,同时保持了参数效率。WTConv层能够实现对输入数据的多频率响应,增强了模型对形状而非纹理的特征识别能力,提高了在图像分类、语义分割和目标检测等视觉任务中的性能和鲁棒性。论文通过广泛的实验验证了WTConv的有效性,并展示了其在不同视觉任务中的应用潜力。

论文地址: https://arxiv.org/abs/2407.05848

摘要

论文指出,近年来尝试通过增加卷积核的大小来模仿视觉变换器(Vision Transformers, ViTs)自注意力模块的全局感受野,但这种方法很快遇到了上限,并且在达到全局感受野之前就饱和了。作者展示了通过利用小波变换(WT),实际上可以不遭受过度参数化的问题,获得非常大的感受野。例如,对于一个k×k的感受野,所提出方法中可训练参数的数量仅以k的对数级增长。提出的层名为WTConv,可以作为现有架构中的替代品,有效响应多频率,并随着感受野大小的增加而优雅地扩展。通过在ConvNeXt和MobileNetV2架构中展示WTConv层的有效性,以及作为下游任务的骨干网络,并展示了它带来的额外属性,如对图像损坏的鲁棒性增加以及对形状而非纹理的响应增加。

引言

引言指出了卷积神经网络(CNN)在计算机视觉领域的统治地位正受到视觉变换器(ViTs)的挑战,特别是由于ViTs的多头自注意力层能够实现全局特征混合。为了缩小CNN和ViTs之间的性能差距,研究人员尝试通过增大卷积核来增加感受野,但这种方法遇到了饱和问题。论文提出了一个问题:是否有可能在不增加过多参数的情况下,利用信号处理工具有效增加卷积的感受野,从而提高性能。

总结

论文成功地利用小波变换(WT)提出了WTConv层,这是一种新的CNN层,能够在不大幅增加参数的情况下显著增加感受野。WTConv层通过在小波域中进行卷积操作,实现了对输入数据的多频率响应,这使得网络能够更好地捕捉低频信息,从而提高了对形状的敏感性,并增强了网络的鲁棒性。实验结果表明,WTConv层在多个视觉任务中都取得了性能提升,证明了其有效性。

全文要点

WTConv

WTConv(Wavelet Transform Convolution)是一种基于小波变换的卷积层,它旨在为卷积神经网络(CNN)提供更大的感受野,同时避免因使用大卷积核而带来的参数数量急剧增加的问题。WTConv是一种创新的卷积神经网络层,它通过小波变换技术实现了对输入数据的深层次和多尺度分析。以下是WTConv的几个关键特点和工作原理的详细概括:

  1. 小波变换的应用:WTConv使用小波变换对输入信号进行分解,这允许网络在不同的频率和空间尺度上捕捉信息。小波变换提供了一种将信号分解为可提供时间和频率信息的组成部分的方法。

  2. 感受野的显著扩展:通过小波变换的多级分解,WTConv能够在保持参数数量相对较低的同时,实现对输入数据更大范围的覆盖。这意味着即使是小的卷积核也能够通过小波变换捕捉到更广泛的上下文信息。

  3. 参数效率与性能提升:WTConv的设计减少了模型参数的数量,与传统的大卷积核相比,它以参数数量的对数级增长实现了感受野的扩展。这种效率的提升使得WTConv在保持计算成本较低的同时,能够提高模型在图像分类、语义分割等任务上的性能。

  4. 多频率特征的独立处理:WTConv允许网络对分解出的不同频率特征进行独立的卷积处理,这增强了模型对信号中不同特征的响应能力,特别是对低频特征的捕捉,这对于理解图像中的形状和结构非常重要。

  5. 小波反变换的集成:在小波域中处理完信号后,WTConv利用小波反变换将处理后的信号重新组合,以生成最终的输出。这一步骤确保了信号的完整性,并允许网络在原始域中进行最终的特征整合。

WTConv通过这些设计,有效地结合了小波变换的多尺度分析能力和卷积神经网络的深度学习能力,为解决计算机视觉中的复杂问题提供了一种新的工具。

wt(Wavelet Transform)

小波变换(Wavelet Transform, WT)是一种数学变换,用于将信号分解成不同时间尺度上的成分,这些成分能够提供信号的时频信息。它广泛应用于信号处理、图像分析、数据压缩和其他许多领域。以下是小波变换的几个关键特点:

  1. 时频联合表示:与仅提供频率信息的傅里叶变换相比,小波变换能够同时提供信号的时间(或空间)和频率信息,使得它在分析非平稳信号时特别有用。

  2. 多尺度分析能力:小波变换通过在不同的尺度上分析信号,能够揭示信号在不同分辨率下的特性。这种多尺度分解使得小波变换能够适应信号的局部变化,捕捉到重要细节。

  3. 正交小波基:在某些小波变换中,如Haar小波变换,变换基是正交的,这允许无失真地从变换后的系数重构原始信号,保证了变换的逆过程的准确性。

  4. 稀疏性优势:小波变换通常能够产生稀疏的系数矩阵,其中许多系数为零或很小,这不仅有助于数据压缩,还可以在信号去噪和特征提取中发挥作用。

  5. 计算效率:小波变换可以通过快速算法实现,如快速小波变换(FWT),它减少了计算量,提高了处理速度。

小波变换的这些特性使其成为分析和处理信号的理想选择,特别是在需要同时考虑时间和频率信息的复杂场景中。

iwt

小波反变换(Inverse Wavelet Transform, IWT)是小波变换的逆过程,它用于从小波变换的系数中重构原始信号。以下是IWT的关键特点和工作原理:

  1. 信号重构:IWT的主要目的是将小波变换产生的系数转换回原始的信号或数据。这是通过使用小波变换时定义的相同小波函数来实现的,但是以相反的顺序。

  2. 逆过程:IWT是小波变换的逆操作,它利用了小波变换的正交性质,特别是当使用正交小波基时,可以确保信号的精确重构。

  3. 多尺度合成:在多级小波分解的情况下,IWT通过逐步合成不同尺度(或分辨率)上的细节信息来重构信号。这包括将低频和高频成分重新组合。

  4. 系数的整合:IWT通过整合小波变换产生的所有系数,包括近似系数(Approximation coefficients)和细节系数(Detail coefficients),来恢复原始数据。

  5. 计算流程:IWT的计算通常涉及从最粗糙的尺度开始,逐步向上细化至更高尺度的过程。每一步都涉及到将当前尺度的系数与小波函数相结合,以及将从更粗糙尺度上恢复的信息逐步添加进来。

  6. 稀疏性利用:如果小波变换产生了稀疏系数,IWT可以利用这一特性来减少计算量,因为许多接近零的系数可以被忽略或近似处理。

  7. 与WT的兼容性:IWT与小波变换紧密兼容,确保了变换和反变换过程的一致性,这对于保持信号的完整性至关重要。

小波反变换是小波分析中不可或缺的一部分,它确保了小波变换的实用性和有效性,特别是在需要从变换后的系数中恢复原始信号的场景中。

pytorch代码实现

源自:https://github.com/BGU-CS-VIL/WTConv

import torch
import torch.nn as nn
import torch.nn.functional as F
import pywt
import pywt.datafrom functools import partialdef create_wavelet_filter(wave, in_size, out_size, type=torch.float):w = pywt.Wavelet(wave)dec_hi = torch.tensor(w.dec_hi[::-1], dtype=type)dec_lo = torch.tensor(w.dec_lo[::-1], dtype=type)dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)rec_hi = torch.tensor(w.rec_hi[::-1], dtype=type).flip(dims=[0])rec_lo = torch.tensor(w.rec_lo[::-1], dtype=type).flip(dims=[0])rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)return dec_filters, rec_filtersdef wavelet_transform(x, filters):b, c, h, w = x.shapepad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)x = F.conv2d(x, filters, stride=2, groups=c, padding=pad)x = x.reshape(b, c, 4, h // 2, w // 2)return xdef inverse_wavelet_transform(x, filters):b, c, _, h_half, w_half = x.shapepad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)x = x.reshape(b, c * 4, h_half, w_half)x = F.conv_transpose2d(x, filters, stride=2, groups=c, padding=pad)return xclass _ScaleModule(nn.Module):def __init__(self, dims, init_scale=1.0, init_bias=0):super(_ScaleModule, self).__init__()self.dims = dimsself.weight = nn.Parameter(torch.ones(*dims) * init_scale)self.bias = Nonedef forward(self, x):return torch.mul(self.weight, x)class WTConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, bias=True, wt_levels=1, wt_type='db1'):super(WTConv2d, self).__init__()assert in_channels == out_channelsself.in_channels = in_channelsself.wt_levels = wt_levelsself.stride = strideself.dilation = 1self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)self.wt_function = partial(wavelet_transform, filters=self.wt_filter)self.iwt_function = partial(inverse_wavelet_transform, filters=self.iwt_filter)self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1,groups=in_channels, bias=bias)self.base_scale = _ScaleModule([1, in_channels, 1, 1])self.wavelet_convs = nn.ModuleList([nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', stride=1, dilation=1,groups=in_channels * 4, bias=False) for _ in range(self.wt_levels)])self.wavelet_scale = nn.ModuleList([_ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1) for _ in range(self.wt_levels)])if self.stride > 1:self.stride_filter = nn.Parameter(torch.ones(in_channels, 1, 1, 1), requires_grad=False)self.do_stride = lambda x_in: F.conv2d(x_in, self.stride_filter, bias=None, stride=self.stride,groups=in_channels)else:self.do_stride = Nonedef forward(self, x):x_ll_in_levels = []x_h_in_levels = []shapes_in_levels = []curr_x_ll = xfor i in range(self.wt_levels):curr_shape = curr_x_ll.shapeshapes_in_levels.append(curr_shape)if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)curr_x_ll = F.pad(curr_x_ll, curr_pads)curr_x = self.wt_function(curr_x_ll)curr_x_ll = curr_x[:, :, 0, :, :]shape_x = curr_x.shapecurr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))curr_x_tag = curr_x_tag.reshape(shape_x)x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])next_x_ll = 0for i in range(self.wt_levels - 1, -1, -1):curr_x_ll = x_ll_in_levels.pop()curr_x_h = x_h_in_levels.pop()curr_shape = shapes_in_levels.pop()curr_x_ll = curr_x_ll + next_x_llcurr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)next_x_ll = self.iwt_function(curr_x)next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]x_tag = next_x_llassert len(x_ll_in_levels) == 0x = self.base_scale(self.base_conv(x))x = x + x_tagif self.do_stride is not None:x = self.do_stride(x)return xx = torch.randn((4, 64, 128, 128))
model = WTConv2d(in_channels=64, out_channels=64)
out = model(x)
print(out.shape)

这篇关于小波卷积:为计算机视觉任务开辟新的参数效率之路的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1103016

相关文章

SpringBoot 获取请求参数的常用注解及用法

《SpringBoot获取请求参数的常用注解及用法》SpringBoot通过@RequestParam、@PathVariable等注解支持从HTTP请求中获取参数,涵盖查询、路径、请求体、头、C... 目录SpringBoot 提供了多种注解来方便地从 HTTP 请求中获取参数以下是主要的注解及其用法:1

HTTP 与 SpringBoot 参数提交与接收协议方式

《HTTP与SpringBoot参数提交与接收协议方式》HTTP参数提交方式包括URL查询、表单、JSON/XML、路径变量、头部、Cookie、GraphQL、WebSocket和SSE,依据... 目录HTTP 协议支持多种参数提交方式,主要取决于请求方法(Method)和内容类型(Content-Ty

SpringBoot集成XXL-JOB实现任务管理全流程

《SpringBoot集成XXL-JOB实现任务管理全流程》XXL-JOB是一款轻量级分布式任务调度平台,功能丰富、界面简洁、易于扩展,本文介绍如何通过SpringBoot项目,使用RestTempl... 目录一、前言二、项目结构简述三、Maven 依赖四、Controller 代码详解五、Service

Java利用@SneakyThrows注解提升异常处理效率详解

《Java利用@SneakyThrows注解提升异常处理效率详解》这篇文章将深度剖析@SneakyThrows的原理,用法,适用场景以及隐藏的陷阱,看看它如何让Java异常处理效率飙升50%,感兴趣的... 目录前言一、检查型异常的“诅咒”:为什么Java开发者讨厌它1.1 检查型异常的痛点1.2 为什么说

python中的显式声明类型参数使用方式

《python中的显式声明类型参数使用方式》文章探讨了Python3.10+版本中类型注解的使用,指出FastAPI官方示例强调显式声明参数类型,通过|操作符替代Union/Optional,可提升代... 目录背景python函数显式声明的类型汇总基本类型集合类型Optional and Union(py

Linux系统管理与进程任务管理方式

《Linux系统管理与进程任务管理方式》本文系统讲解Linux管理核心技能,涵盖引导流程、服务控制(Systemd与GRUB2)、进程管理(前台/后台运行、工具使用)、计划任务(at/cron)及常用... 目录引言一、linux系统引导过程与服务控制1.1 系统引导的五个关键阶段1.2 GRUB2的进化优

Go语言使用Gin处理路由参数和查询参数

《Go语言使用Gin处理路由参数和查询参数》在WebAPI开发中,处理路由参数(PathParameter)和查询参数(QueryParameter)是非常常见的需求,下面我们就来看看Go语言... 目录一、路由参数 vs 查询参数二、Gin 获取路由参数和查询参数三、示例代码四、运行与测试1. 测试编程路

Python Flask实现定时任务的不同方法详解

《PythonFlask实现定时任务的不同方法详解》在Flask中实现定时任务,最常用的方法是使用APScheduler库,本文将提供一个完整的解决方案,有需要的小伙伴可以跟随小编一起学习一下... 目录完js整实现方案代码解释1. 依赖安装2. 核心组件3. 任务类型4. 任务管理5. 持久化存储生产环境

Python lambda函数(匿名函数)、参数类型与递归全解析

《Pythonlambda函数(匿名函数)、参数类型与递归全解析》本文详解Python中lambda匿名函数、灵活参数类型和递归函数三大进阶特性,分别介绍其定义、应用场景及注意事项,助力编写简洁高效... 目录一、lambda 匿名函数:简洁的单行函数1. lambda 的定义与基本用法2. lambda

SpringBoot中六种批量更新Mysql的方式效率对比分析

《SpringBoot中六种批量更新Mysql的方式效率对比分析》文章比较了MySQL大数据量批量更新的多种方法,指出REPLACEINTO和ONDUPLICATEKEY效率最高但存在数据风险,MyB... 目录效率比较测试结构数据库初始化测试数据批量修改方案第一种 for第二种 case when第三种