PaddleOCR识别框架解读[04] 文本检测det模型构建

2024-03-07 11:12

本文主要是介绍PaddleOCR识别框架解读[04] 文本检测det模型构建,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

    • det_mv3_db.yml
    • build_model函数
      • base_model类
    • build_backbone函数
      • MobileNetV3
    • build_neck函数
    • build_head函数

det_mv3_db.yml

Global:use_gpu: trueuse_xpu: falseepoch_num: 1200log_smooth_window: 20print_batch_step: 10save_model_dir: ./output/db_mv3/save_epoch_step: 1200# evaluation is run every 2000 iterationseval_batch_step: [0, 2000]cal_metric_during_train: Falsepretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrainedcheckpoints:save_inference_dir:use_visualdl: Falseinfer_img: doc/imgs_en/img_10.jpgsave_res_path: ./output/det_db/predicts_db.txtArchitecture:model_type: detalgorithm: DBTransform:Backbone:name: MobileNetV3scale: 0.5model_name: largeNeck:name: DBFPNout_channels: 256Head:name: DBHeadk: 50Loss:name: DBLossbalance_loss: truemain_loss_type: DiceLossalpha: 5beta: 10ohem_ratio: 3Optimizer:name: Adambeta1: 0.9beta2: 0.999lr:learning_rate: 0.001regularizer:name: 'L2'factor: 0PostProcess:name: DBPostProcessthresh: 0.3box_thresh: 0.6max_candidates: 1000unclip_ratio: 1.5Metric:name: DetMetricmain_indicator: hmeanTrain:dataset:name: SimpleDataSetdata_dir: ./train_data/icdar2015/text_localization/label_file_list:- ./train_data/icdar2015/text_localization/train_icdar2015_label.txtratio_list: [1.0]transforms:- DecodeImage: # load imageimg_mode: BGRchannel_first: False- DetLabelEncode: # Class handling label- IaaAugment:augmenter_args:- { 'type': Fliplr, 'args': { 'p': 0.5 } }- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }- { 'type': Resize, 'args': { 'size': [0.5, 3] } }- EastRandomCropData:size: [640, 640]max_tries: 50keep_ratio: true- MakeBorderMap:shrink_ratio: 0.4thresh_min: 0.3thresh_max: 0.7- MakeShrinkMap:shrink_ratio: 0.4min_text_size: 8- NormalizeImage:scale: 1./255.mean: [0.485, 0.456, 0.406]std: [0.229, 0.224, 0.225]order: 'hwc'- ToCHWImage:- KeepKeys:keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader listloader:shuffle: Truedrop_last: Falsebatch_size_per_card: 16num_workers: 8use_shared_memory: TrueEval:dataset:name: SimpleDataSetdata_dir: ./train_data/icdar2015/text_localization/label_file_list:- ./train_data/icdar2015/text_localization/test_icdar2015_label.txttransforms:- DecodeImage: # load imageimg_mode: BGRchannel_first: False- DetLabelEncode: # Class handling label- DetResizeForTest:image_shape: [736, 1280]- NormalizeImage:scale: 1./255.mean: [0.485, 0.456, 0.406]std: [0.229, 0.224, 0.225]order: 'hwc'- ToCHWImage:- KeepKeys:keep_keys: ['image', 'shape', 'polys', 'ignore_tags']loader:shuffle: Falsedrop_last: Falsebatch_size_per_card: 1 # must be 1num_workers: 8use_shared_memory: True

build_model函数

def build_model(config):config = copy.deepcopy(config)if not "name" in config:arch = BaseModel(config)else:name = config.pop("name")mod = importlib.import_module(__name__)arch = getattr(mod, name)(config)return arch

base_model类

from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionfrom paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head__all__ = ['BaseModel']class BaseModel(nn.Layer):def __init__(self, config):super(BaseModel, self).__init__()# 输入通道in_channels = config.get('in_channels', 3)# 网络类型, 目前支持det, rec, cls.model_type = config['model_type']# ==========构建transfrom==========# 识别rec任务, transfrom可以设置为TPS、None;# 检测det和分类cls任务, transform可以设置为None;# if you make model differently, you can use transfrom in det and cls.if 'Transform' not in config or config['Transform'] is None:self.use_transform = Falseelse:self.use_transform = Trueconfig['Transform']['in_channels'] = in_channelsself.transform = build_transform(config['Transform'])in_channels = self.transform.out_channels# ==========构建backbone==========if 'Backbone' not in config or config['Backbone'] is None:self.use_backbone = Falseelse:self.use_backbone = Trueconfig["Backbone"]['in_channels'] = in_channelsself.backbone = build_backbone(config["Backbone"], model_type)in_channels = self.backbone.out_channels# ==========构建neck==========# 识别rec任务, neck可以是cnn、rnn或者reshape(None);# 检测det任务, neck可以是FPN、BIFPN等;# 分类cls任务, neck是none.if 'Neck' not in config or config['Neck'] is None:self.use_neck = Falseelse:self.use_neck = Trueconfig['Neck']['in_channels'] = in_channelsself.neck = build_neck(config['Neck'])in_channels = self.neck.out_channels# ==========构建head==========if 'Head' not in config or config['Head'] is None:self.use_head = Falseelse:self.use_head = Trueconfig["Head"]['in_channels'] = in_channelsself.head = build_head(config["Head"])self.return_all_feats = config.get("return_all_feats", False)def forward(self, x, data=None):# 以rec任务为例,输入x, 即data['image']的shape为[bs,3,48,320], # data['label_ctc']的shape为[bs,30]# data['label_sar']的shape为[bs,30]# data['length']的shape为[bs]# data['valid_ratio']的shape为[bs]y = dict()if self.use_transform:x = self.transform(x)# 以det任务为例,骨干网络MobileNetv3_large输出为列表# 特征图大小分别为原图的1/4, 1/8, 1/16, 1/32# [bs, 16, 160, 160], [bs, 24, 80, 80], [bs, 56, 40, 40], [bs, 480, 20, 20]# 以rec任务为例,骨干网络MobileNetV1Enhance输出为[bs, 512, 1, 40]if self.use_backbone:x = self.backbone(x)if isinstance(x, dict):y.update(x)else:y["backbone_out"] = xfinal_name = "backbone_out"# 以det任务为例,Neck网络DBFPN输出为特征图为原图的1/4大小 # [bs, 256, 160, 160]if self.use_neck:x = self.neck(x)if isinstance(x, dict):y.update(x)else:y["neck_out"] = xfinal_name = "neck_out"# 以det任务为例,Head网络DBHead输出为字典# 特征图大小为原图的大小,{'maps': y}  [bs, 3, 160, 160]# 以rec任务为例,Head网络MultiHead(CTCHead + SARHead)输出为字典# 'ctc_neck': [bs, 40, 64]# 'ctc_head': [bs, 40, 35], 35个字符是因为character_dict_path + blank + " "# 'sar_head': [bs, 30, 36], 36个字符是因为character_dict_path + " " + "<UKN>" + "<BOS/EOS>" + "<PAD>"if self.use_head:x = self.head(x, targets=data)if isinstance(x, dict) and 'ctc_neck' in x.keys():y["neck_out"] = x["ctc_neck"]y["head_out"] = xelif isinstance(x, dict):y.update(x)else:y["head_out"] = xfinal_name = "head_out"if self.return_all_feats:if self.training:return yelif isinstance(x, dict):return xelse:return {final_name: x}else:return x

build_backbone函数

__all__ = ["build_backbone"]def build_backbone(config, model_type):if model_type == "det" or model_type == "table":from .det_mobilenet_v3 import MobileNetV3from .det_resnet import ResNetfrom .det_resnet_vd import ResNet_vdfrom .det_resnet_vd_sast import ResNet_SASTfrom .det_pp_lcnet import PPLCNetsupport_dict = ["MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet"]if model_type == "table":from .table_master_resnet import TableResNetExtrasupport_dict.append('TableResNetExtra')elif model_type == "rec" or model_type == "cls":from .rec_mobilenet_v3 import MobileNetV3from .rec_resnet_vd import ResNetfrom .rec_resnet_fpn import ResNetFPNfrom .rec_mv1_enhance import MobileNetV1Enhancefrom .rec_nrtr_mtb import MTBfrom .rec_resnet_31 import ResNet31from .rec_resnet_32 import ResNet32from .rec_resnet_45 import ResNet45from .rec_resnet_aster import ResNet_ASTERfrom .rec_micronet import MicroNetfrom .rec_efficientb3_pren import EfficientNetb3_PRENfrom .rec_svtrnet import SVTRNetfrom .rec_vitstr import ViTSTRfrom .rec_resnet_rfl import ResNetRFLfrom .rec_densenet import DenseNetsupport_dict = ['MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB','ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet','EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL','DenseNet']elif model_type == 'e2e':from .e2e_resnet_vd_pg import ResNetsupport_dict = ['ResNet']elif model_type == 'kie':from .kie_unet_sdmgr import Kie_backbonefrom .vqa_layoutlm import LayoutLMForSer, LayoutLMv2ForSer, LayoutLMv2ForRe, LayoutXLMForSer, LayoutXLMForResupport_dict = ['Kie_backbone', 'LayoutLMForSer', 'LayoutLMv2ForSer','LayoutLMv2ForRe', 'LayoutXLMForSer', 'LayoutXLMForRe']elif model_type == 'table':from .table_resnet_vd import ResNetfrom .table_mobilenet_v3 import MobileNetV3support_dict = ['ResNet', 'MobileNetV3']else:raise NotImplementedErrormodule_name = config.pop('name')assert module_name in support_dict, Exception("when model typs is {}, backbone only support {}".format(model_type, support_dict))module_class = eval(module_name)(**config)return module_class

MobileNetV3

from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr__all__ = ['MobileNetV3']def make_divisible(v, divisor=8, min_value=None):if min_value is None:min_value = divisornew_v = max(min_value, int(v+divisor/2)//divisor*divisor)if new_v < 0.9*v:new_v += divisorreturn new_vclass ConvBNLayer(nn.Layer):def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups=1, if_act=True, act=None):super(ConvBNLayer, self).__init__()self.if_act = if_actself.act = actself.conv = nn.Conv2D(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,padding=padding,groups=groups,bias_attr=False)self.bn = nn.BatchNorm(num_channels=out_channels, act=None)def forward(self, x):x = self.conv(x)x = self.bn(x)if self.if_act:if self.act == "relu":x = F.relu(x)elif self.act == "hardswish":x = F.hardswish(x)else:print("The activation function({}) is selected incorrectly.".format(self.act))exit()return xclass ResidualUnit(nn.Layer):def __init__(self, in_channels, mid_channels, out_channels, kernel_size, stride, use_se, act=None):super(ResidualUnit, self).__init__()self.if_shortcut = stride == 1 and in_channels == out_channelsself.if_se = use_se# 1x1卷积self.expand_conv = ConvBNLayer(in_channels=in_channels,out_channels=mid_channels,kernel_size=1,stride=1,padding=0,if_act=True,act=act)# 膨胀卷积self.bottleneck_conv = ConvBNLayer(in_channels=mid_channels,out_channels=mid_channels,kernel_size=kernel_size,stride=stride,padding=int((kernel_size - 1) // 2),groups=mid_channels,if_act=True,act=act)# SE注意力机制if self.if_se:self.mid_se = SEModule(mid_channels)# 1x1卷积self.linear_conv = ConvBNLayer(in_channels=mid_channels,out_channels=out_channels,kernel_size=1,stride=1,padding=0,if_act=False,act=None)def forward(self, inputs):x = self.expand_conv(inputs)x = self.bottleneck_conv(x)if self.if_se:x = self.mid_se(x)x = self.linear_conv(x)if self.if_shortcut:x = paddle.add(inputs, x)return xclass SEModule(nn.Layer):def __init__(self, in_channels, reduction=4):super(SEModule, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2D(1)self.conv1 = nn.Conv2D(in_channels=in_channels,out_channels=in_channels // reduction,kernel_size=1,stride=1,padding=0)self.conv2 = nn.Conv2D(in_channels=in_channels // reduction,out_channels=in_channels,kernel_size=1,stride=1,padding=0)def forward(self, inputs):outputs = self.avg_pool(inputs)outputs = self.conv1(outputs)outputs = F.relu(outputs)outputs = self.conv2(outputs)outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)return inputs * outputsclass MobileNetV3(nn.Layer):def __init__(self, in_channels=3, model_name='large', scale=0.5, disable_se=False, **kwargs):super(MobileNetV3, self).__init__()# 不启用注意力机制SEself.disable_se = disable_seif model_name == "large":cfg = [# k, exp, c,  se,     nl,  s,[3, 16, 16, False, 'relu', 1],[3, 64, 24, False, 'relu', 2],[3, 72, 24, False, 'relu', 1],[5, 72, 40, True, 'relu', 2],[5, 120, 40, True, 'relu', 1],[5, 120, 40, True, 'relu', 1],[3, 240, 80, False, 'hardswish', 2],[3, 200, 80, False, 'hardswish', 1],[3, 184, 80, False, 'hardswish', 1],[3, 184, 80, False, 'hardswish', 1],[3, 480, 112, True, 'hardswish', 1],[3, 672, 112, True, 'hardswish', 1],[5, 672, 160, True, 'hardswish', 2],[5, 960, 160, True, 'hardswish', 1],[5, 960, 160, True, 'hardswish', 1],]cls_ch_squeeze = 960elif model_name == "small":cfg = [# k, exp, c,  se,     nl,  s,[3, 16, 16, True, 'relu', 2],[3, 72, 24, False, 'relu', 2],[3, 88, 24, False, 'relu', 1],[5, 96, 40, True, 'hardswish', 2],[5, 240, 40, True, 'hardswish', 1],[5, 240, 40, True, 'hardswish', 1],[5, 120, 48, True, 'hardswish', 1],[5, 144, 48, True, 'hardswish', 1],[5, 288, 96, True, 'hardswish', 2],[5, 576, 96, True, 'hardswish', 1],[5, 576, 96, True, 'hardswish', 1],]cls_ch_squeeze = 576else:raise NotImplementedError("mode[" + model_name + "_model] is not implemented!")supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]assert scale in supported_scale, "supported scale are {} but input scale is {}".format(supported_scale, scale)inplanes = 16# conv1self.conv = ConvBNLayer(in_channels=in_channels,out_channels=make_divisible(inplanes * scale),kernel_size=3,stride=2,padding=1,groups=1,if_act=True,act='hardswish')self.stages = []self.out_channels = []block_list = []i = 0inplanes = make_divisible(inplanes * scale)# k表示卷积核大小,kernal_size;# exp表示隐藏层通道数;# c表示输出通道数;# se表示是否使用SENet;# nl表示激活函数;# s表示stride;for (k, exp, c, se, nl, s) in cfg:se = se and not self.disable_sestart_idx = 2 if model_name == 'large' else 0if s == 2 and i > start_idx:self.out_channels.append(inplanes)self.stages.append(nn.Sequential(*block_list))block_list = []block_list.append(ResidualUnit(in_channels=inplanes,mid_channels=make_divisible(scale * exp),out_channels=make_divisible(scale * c),kernel_size=k,stride=s,use_se=se,act=nl))inplanes = make_divisible(scale * c)i += 1block_list.append(ConvBNLayer(in_channels=inplanes,out_channels=make_divisible(scale * cls_ch_squeeze),kernel_size=1,stride=1,padding=0,groups=1,if_act=True,act='hardswish'))self.stages.append(nn.Sequential(*block_list))self.out_channels.append(make_divisible(scale * cls_ch_squeeze))for i, stage in enumerate(self.stages):self.add_sublayer(sublayer=stage, name="stage{}".format(i))def forward(self, x):# 输入shape [16, 3, 640, 640]x = self.conv(x)    out_list = []# 有四个stage, 1/4, 1/8, 1/16, 1/32# [bs, 16, 160, 160]# [bs, 24, 80, 80]# [bs, 56, 40, 40]# [bs, 480, 20, 20]for stage in self.stages:x = stage(x)out_list.append(x)return out_list

build_neck函数

__all__ = ['build_neck']def build_neck(config):from .db_fpn import DBFPN, RSEFPN, LKPANfrom .east_fpn import EASTFPNfrom .sast_fpn import SASTFPNfrom .rnn import SequenceEncoderfrom .pg_fpn import PGFPNfrom .table_fpn import TableFPNfrom .fpn import FPNfrom .fce_fpn import FCEFPNfrom .pren_fpn import PRENFPNfrom .csp_pan import CSPPANfrom .ct_fpn import CTFPNfrom .fpn_unet import FPN_UNetfrom .rf_adaptor import RFAdaptorsupport_dict = ['FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN','SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN','RFAdaptor', 'FPN_UNet']module_name = config.pop('name')assert module_name in support_dict, Exception('neck only support {}'.format(support_dict))module_class = eval(module_name)(**config)return module_class

build_head函数

__all__ = ['build_head']def build_head(config):# det headfrom .det_db_head import DBHeadfrom .det_east_head import EASTHeadfrom .det_sast_head import SASTHeadfrom .det_pse_head import PSEHeadfrom .det_fce_head import FCEHeadfrom .e2e_pg_head import PGHeadfrom .det_ct_head import CT_Head# rec headfrom .rec_ctc_head import CTCHeadfrom .rec_att_head import AttentionHeadfrom .rec_srn_head import SRNHeadfrom .rec_nrtr_head import Transformerfrom .rec_sar_head import SARHeadfrom .rec_aster_head import AsterHeadfrom .rec_pren_head import PRENHeadfrom .rec_multi_head import MultiHeadfrom .rec_spin_att_head import SPINAttentionHeadfrom .rec_abinet_head import ABINetHeadfrom .rec_robustscanner_head import RobustScannerHeadfrom .rec_visionlan_head import VLHeadfrom .rec_rfl_head import RFLHeadfrom .rec_can_head import CANHead# cls headfrom .cls_head import ClsHead# kie headfrom .kie_sdmgr_head import SDMGRHead# table headfrom .table_att_head import TableAttentionHead, SLAHeadfrom .table_master_head import TableMasterHeadsupport_dict = ['DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead','ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer','TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead','MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead','VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead','DRRGHead', 'CANHead']if config['name'] == 'DRRGHead':from .det_drrg_head import DRRGHeadsupport_dict.append('DRRGHead')module_name = config.pop('name')assert module_name in support_dict, Exception('head only support {}'.format(support_dict))module_class = eval(module_name)(**config)return module_class

这篇关于PaddleOCR识别框架解读[04] 文本检测det模型构建的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/783310

相关文章

Java实现在Word文档中添加文本水印和图片水印的操作指南

《Java实现在Word文档中添加文本水印和图片水印的操作指南》在当今数字时代,文档的自动化处理与安全防护变得尤为重要,无论是为了保护版权、推广品牌,还是为了在文档中加入特定的标识,为Word文档添加... 目录引言Spire.Doc for Java:高效Word文档处理的利器代码实战:使用Java为Wo

Three.js构建一个 3D 商品展示空间完整实战项目

《Three.js构建一个3D商品展示空间完整实战项目》Three.js是一个强大的JavaScript库,专用于在Web浏览器中创建3D图形,:本文主要介绍Three.js构建一个3D商品展... 目录引言项目核心技术1. 项目架构与资源组织2. 多模型切换、交互热点绑定3. 移动端适配与帧率优化4. 可

GSON框架下将百度天气JSON数据转JavaBean

《GSON框架下将百度天气JSON数据转JavaBean》这篇文章主要为大家详细介绍了如何在GSON框架下实现将百度天气JSON数据转JavaBean,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下... 目录前言一、百度天气jsON1、请求参数2、返回参数3、属性映射二、GSON属性映射实战1、类对象映

Python文本相似度计算的方法大全

《Python文本相似度计算的方法大全》文本相似度是指两个文本在内容、结构或语义上的相近程度,通常用0到1之间的数值表示,0表示完全不同,1表示完全相同,本文将深入解析多种文本相似度计算方法,帮助您选... 目录前言什么是文本相似度?1. Levenshtein 距离(编辑距离)核心公式实现示例2. Jac

Python利用PySpark和Kafka实现流处理引擎构建指南

《Python利用PySpark和Kafka实现流处理引擎构建指南》本文将深入解剖基于Python的实时处理黄金组合:Kafka(分布式消息队列)与PySpark(分布式计算引擎)的化学反应,并构建一... 目录引言:数据洪流时代的生存法则第一章 Kafka:数据世界的中央神经系统消息引擎核心设计哲学高吞吐

Springboot项目构建时各种依赖详细介绍与依赖关系说明详解

《Springboot项目构建时各种依赖详细介绍与依赖关系说明详解》SpringBoot通过spring-boot-dependencies统一依赖版本管理,spring-boot-starter-w... 目录一、spring-boot-dependencies1.简介2. 内容概览3.核心内容结构4.

Python脚本轻松实现检测麦克风功能

《Python脚本轻松实现检测麦克风功能》在进行音频处理或开发需要使用麦克风的应用程序时,确保麦克风功能正常是非常重要的,本文将介绍一个简单的Python脚本,能够帮助我们检测本地麦克风的功能,需要的... 目录轻松检测麦克风功能脚本介绍一、python环境准备二、代码解析三、使用方法四、知识扩展轻松检测麦

Python中高级文本模式匹配与查找技术指南

《Python中高级文本模式匹配与查找技术指南》文本处理是编程世界的永恒主题,而模式匹配则是文本处理的基石,本文将深度剖析PythonCookbook中的核心匹配技术,并结合实际工程案例展示其应用,希... 目录引言一、基础工具:字符串方法与序列匹配二、正则表达式:模式匹配的瑞士军刀2.1 re模块核心AP

Go语言使用net/http构建一个RESTful API的示例代码

《Go语言使用net/http构建一个RESTfulAPI的示例代码》Go的标准库net/http提供了构建Web服务所需的强大功能,虽然众多第三方框架(如Gin、Echo)已经封装了很多功能,但... 目录引言一、什么是 RESTful API?二、实战目标:用户信息管理 API三、代码实现1. 用户数据

解决若依微服务框架启动报错的问题

《解决若依微服务框架启动报错的问题》Invalidboundstatement错误通常由MyBatis映射文件未正确加载或Nacos配置未读取导致,需检查XML的namespace与方法ID是否匹配,... 目录ruoyi-system模块报错报错详情nacos文件目录总结ruoyi-systnGLNYpe