【lightgbm/xgboost/nn代码整理三】keras做二分类,多分类以及回

2024-06-12 22:38

本文主要是介绍【lightgbm/xgboost/nn代码整理三】keras做二分类,多分类以及回,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

【lightgbm/xgboost/nn代码整理三】keras做二分类,多分类以及回归任务

浏览更多内容,可访问:http://www.growai.cn

1.简介

该部分是比较基础的深度网络部分,是基于keras实现的多层感知机网络(mlp),使用nn个人感觉最大的一个好处就是目标函数自定义很方便,下面将从数据处理、网络搭建和模型训练三个部分介绍。如果只是想要阅读代码,可直接移步到尾部链接。

2. 数据处理

神经网络对数据的要求比较多,不能处理缺失值,并且数据分布对其影响也很大,输入模型前需要对数据做预处理。具体需要做如下处理

  • onehot:参考上一节

  • 填充:常用的有均值填充,常数值填充,中位数填充等,根据数据场景做选择,这里直接填充的常数值-1

    for i in train_x.columns:if train_x[i].isnull().sum() != 0:train_x[i] = train_x[i].fillna(-1)test[i] = test[i].fillna(-1)
    
  • 归一化:如果各个特征值差距很大,会严重影响模型参数分布,需要对整体数据进行归一化处理

    scaler = StandardScaler()
    train_X = scaler.fit_transform(train_x)
    test_X = scaler.transform(test)
    

3.模型部分

def MLP(dropout_rate=0.25, activation='relu'):start_neurons = 512model = Sequential()model.add(Dense(start_neurons, input_dim=train_X.shape[1], activation=activation))model.add(BatchNormalization())model.add(Dropout(dropout_rate))model.add(Dense(start_neurons // 2, activation=activation))model.add(BatchNormalization())model.add(Dropout(dropout_rate))model.add(Dense(start_neurons // 4, activation=activation))model.add(BatchNormalization())model.add(Dropout(dropout_rate))model.add(Dense(start_neurons // 8, activation=activation))model.add(BatchNormalization())model.add(Dropout(dropout_rate / 2))model.add(Dense(classes, activation='sigmoid'))return model

这里定义的是四层感知网络,为了提高网络的性能,添加的dropout层和BN层。Dropout的具体工作原理是随机的使一些神经元失活,从而达到防止过拟合的作用。直观的理解的话,dropout有点像集成学习中的bagging的思路,每次训练的时候只训练一部分神经元,相当于训练了多个弱分类器,预测的时候则是全部分类器同时作用。而bagging的作用也是为了减少方差(防止过拟合)。BN,Batch Normalization,就是在深度神经网络训练过程中使得每一层神经网络的输入保持相近的分布,可以加速训练。

针对不同的网络,输出层的激活函数不同

  • 二分类:sigmoid
  • 多分类:softmax
  • 回归:linear

4. 模型训练

首先需要定义网络模型,然后定义loss优化和目标函数,keras训练函数和sklearn很相似,直接调用fit函数即可。

model = MLP(dropout_rate=0.5, activation='relu')
model.compile(optimizer='adam', loss='binary_crossentropy',  metrics=['accuracy'])
history = model.fit(x_train, y_train,validation_data=[x_valid, y_valid],epochs=epochs,batch_size=batch_size,callbacks=[call_ES, ],shuffle=True,verbose=1)
  • optimizer:loss优化函数,常用的有sgd, rmsprop, adam等

  • loss:常用的loss损失函数

    • 二分类:binary_crossentropy等
    • 多分类:categorical_crossentropy等
    • 回归:mse,mae等
  • metrics:评价函数:

    • 分类:accuracy等
    • 回归:mse, mae等
  • callbacks:这个是回调函数,该函数是在加载完一次数据后调用,可以用他来加载loss,打印tensorboard,提前停止等,这里给出了提前停止的代码

    call_ES = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=patience, verbose=1, mode='auto', baseline=None)
    

模型预测部分

##分类
predictions = model.predict_proba(test_X, batch_size=batch_size)##回归&分类
oof_preds[val_] = model.predict(x_valid, batch_size=batch_size)

分类任务可以通过第一个式子预测每个类别的概率。对于二分类任务可以自定义阈值,得到最终的分类结果

threshold = 0.5
result = []
for pred in predictions:result.append(1 if pred > threshold else 0)

对于多分类:

result = np.argmax(predictions, axis=1)

代码地址:data_mining_models

写在后面

欢迎您关注作者知乎:ML与DL成长之路

推荐关注公众号:AI成长社,ML与DL的成长圣地。

这篇关于【lightgbm/xgboost/nn代码整理三】keras做二分类,多分类以及回的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/1055538

相关文章

IIS 7.0 及更高版本中的 FTP 状态代码

《IIS7.0及更高版本中的FTP状态代码》本文介绍IIS7.0中的FTP状态代码,方便大家在使用iis中发现ftp的问题... 简介尝试使用 FTP 访问运行 Internet Information Services (IIS) 7.0 或更高版本的服务器上的内容时,IIS 将返回指示响应状态的数字代

MySQL 添加索引5种方式示例详解(实用sql代码)

《MySQL添加索引5种方式示例详解(实用sql代码)》在MySQL数据库中添加索引可以帮助提高查询性能,尤其是在数据量大的表中,下面给大家分享MySQL添加索引5种方式示例详解(实用sql代码),... 在mysql数据库中添加索引可以帮助提高查询性能,尤其是在数据量大的表中。索引可以在创建表时定义,也可

使用C#删除Excel表格中的重复行数据的代码详解

《使用C#删除Excel表格中的重复行数据的代码详解》重复行是指在Excel表格中完全相同的多行数据,删除这些重复行至关重要,因为它们不仅会干扰数据分析,还可能导致错误的决策和结论,所以本文给大家介绍... 目录简介使用工具C# 删除Excel工作表中的重复行语法工作原理实现代码C# 删除指定Excel单元

Python实现一键PDF转Word(附完整代码及详细步骤)

《Python实现一键PDF转Word(附完整代码及详细步骤)》pdf2docx是一个基于Python的第三方库,专门用于将PDF文件转换为可编辑的Word文档,下面我们就来看看如何通过pdf2doc... 目录引言:为什么需要PDF转Word一、pdf2docx介绍1. pdf2docx 是什么2. by

Spring Security介绍及配置实现代码

《SpringSecurity介绍及配置实现代码》SpringSecurity是一个功能强大的Java安全框架,它提供了全面的安全认证(Authentication)和授权(Authorizatio... 目录简介Spring Security配置配置实现代码简介Spring Security是一个功能强

通过cmd获取网卡速率的代码

《通过cmd获取网卡速率的代码》今天从群里看到通过bat获取网卡速率两段代码,感觉还不错,学习bat的朋友可以参考一下... 1、本机有线网卡支持的最高速度:%v%@echo off & setlocal enabledelayedexpansionecho 代码开始echo 65001编码获取: >

Java集成Onlyoffice的示例代码及场景分析

《Java集成Onlyoffice的示例代码及场景分析》:本文主要介绍Java集成Onlyoffice的示例代码及场景分析,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要... 需求场景:实现文档的在线编辑,团队协作总结:两个接口 + 前端页面 + 配置项接口1:一个接口,将o

SpringBoot实现Kafka动态反序列化的完整代码

《SpringBoot实现Kafka动态反序列化的完整代码》在分布式系统中,Kafka作为高吞吐量的消息队列,常常需要处理来自不同主题(Topic)的异构数据,不同的业务场景可能要求对同一消费者组内的... 目录引言一、问题背景1.1 动态反序列化的需求1.2 常见问题二、动态反序列化的核心方案2.1 ht

IDEA实现回退提交的git代码(四种常见场景)

《IDEA实现回退提交的git代码(四种常见场景)》:本文主要介绍IDEA实现回退提交的git代码(四种常见场景),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1.已提交commit,还未push到远端(Undo Commit)2.已提交commit并push到

Kotlin Compose Button 实现长按监听并实现动画效果(完整代码)

《KotlinComposeButton实现长按监听并实现动画效果(完整代码)》想要实现长按按钮开始录音,松开发送的功能,因此为了实现这些功能就需要自己写一个Button来解决问题,下面小编给大... 目录Button 实现原理1. Surface 的作用(关键)2. InteractionSource3.