深度学习_数据读取到model模型存储

2024-09-01 01:20

本文主要是介绍深度学习_数据读取到model模型存储,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

概要

应用场景:用户流失
本文将介绍模型调用预测的步骤,这里深度学习模型使用的是自定义的deepfm,并用机器学习lgb做比较

代码

导包

import pandas as pd
import numpy as npimport matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict  
from scipy import stats
from scipy import signal
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, f1_score
from scipy.spatial.distance import cosineimport lightgbm as lgbfrom sklearn.preprocessing import LabelEncoder, MinMaxScaler, StandardScaler
from tensorflow.keras.layers import *
import tensorflow.keras.backend as K
import tensorflow as tf
from tensorflow.keras.models import Modelimport os,gc,re,warnings,sys,math
warnings.filterwarnings("ignore")pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)

读取数据

data = pd.read_csv('df_03m.csv')

区分稀疏及类别变量

sparse_cols = ['shop_id','sex']
dense_cols  = [c for c in data.columns if c not in sparse_cols + ['customer_id', 'flag', 'duartion_is_lm']]

dense特征处理

def process_dense_feats(data, cols):d = data.copy()for f in cols:d[f] = d[f].fillna(0)ss=StandardScaler()d[f] = ss.fit_transform(d[[f]])return ddata = process_dense_feats(data, dense_cols)

sparse稀疏特征处理

def process_sparse_feats(data, cols):d = data.copy()for f in cols:d[f] = d[f].fillna('-1').astype(str)label_encoder = LabelEncoder()d[f] = label_encoder.fit_transform(d[f])return ddata = process_sparse_feats(data, sparse_cols)

切分训练及测试集

X_train, X_test, _, _ = train_test_split(data, data, test_size=0.3, random_state=2024)y_train = X_train['flag']
y_test = X_test['flag']X_train1 = X_train.drop(['customer_id', 'flag', 'duartion_is_lm'], axis = 1)
X_test1 = X_test.drop(['customer_id', 'flag', 'duartion_is_lm'], axis = 1)

模型定义

def deepfm_model(sparse_columns, dense_columns, train, test):####### sparse features ##########sparse_input = []lr_embedding = []fm_embedding = []for col in sparse_columns:## lr_embedding_input = Input(shape=(1,))sparse_input.append(_input)nums = pd.concat((train[col], test[col])).nunique() + 1embed = Flatten()(Embedding(nums, 1, embeddings_regularizer=tf.keras.regularizers.l2(0.5))(_input))lr_embedding.append(embed)## fm_embeddingembed = Embedding(nums, 10, input_length=1, embeddings_regularizer=tf.keras.regularizers.l2(0.5))(_input)reshape = Reshape((10,))(embed)fm_embedding.append(reshape)####### fm layer ##########fm_square = Lambda(lambda x: K.square(x))(Add()(fm_embedding)) # square_fm = Add()([Lambda(lambda x:K.square(x))(embed)for embed in fm_embedding])snd_order_sparse_layer = subtract([fm_square, square_fm])snd_order_sparse_layer  = Lambda(lambda x: x * 0.5)(snd_order_sparse_layer)####### dense features ##########dense_input = []for col in dense_columns:_input = Input(shape=(1,))dense_input.append(_input)concat_dense_input = concatenate(dense_input)fst_order_dense_layer = Dense(4, activation='relu')(concat_dense_input)#     #######  NFM  ##########
#     inner_product = []
#     for i in range(field_cnt):
#         for j in range(i + 1, field_cnt):
#             tmp = dot([fm_embedding[i], fm_embedding[j]], axes=1)
#             # tmp = multiply([fm_embedding[i], fm_embedding[j]])
#             inner_product.append(tmp)
#     add_inner_product = add(inner_product)#     #######  PNN  ##########
#     for i in range(field_cnt):
#         for j in range(i+1,field_cnt):
#             tmp = dot([lr_embedding[i],lr_embedding[j]],axes=1)
#             product_list.append(temp)
#     inp = concatenate(lr_embedding+product_list)####### linear concat ##########fst_order_sparse_layer = concatenate(lr_embedding)linear_part = concatenate([fst_order_dense_layer, fst_order_sparse_layer])#     #######  DCN  ##########
#     linear_part = concatenate([fst_order_dense_layer, fst_order_sparse_layer])
#     x0 = linear_part
#     xl = x0
#     for i in range(3):
#         embed_dim = xl.shape[-1]
#         w = tf.Variable(tf.random.truncated_normal(shape=(embed_dim,), stddev=0.01))
#         b = tf.Variable(tf.zeros(shape=(embed_dim,)))
#         x_lw = tf.tensordot(tf.reshape(xl, [-1, 1, embed_dim]), w, axes=1)
#         cross = x0 * x_lw 
#         xl = cross + b + xl#######dnn layer##########concat_fm_embedding = concatenate(fm_embedding, axis=-1) # (None, 10*26)fc_layer = Dropout(0.2)(Activation(activation="relu")(BatchNormalization()(Dense(128)(concat_fm_embedding))))fc_layer = Dropout(0.2)(Activation(activation="relu")(BatchNormalization()(Dense(64)(fc_layer))))fc_layer = Dropout(0.2)(Activation(activation="relu")(BatchNormalization()(Dense(32)(fc_layer))))######## output layer ##########output_layer = concatenate([linear_part, snd_order_sparse_layer, fc_layer]) # (None, )output_layer = Dense(1, activation='sigmoid')(output_layer)model = Model(inputs=sparse_input+dense_input, outputs=output_layer)return model
model = deepfm_model(sparse_cols, dense_cols, X_train1, X_test1)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["binary_crossentropy", tf.keras.metrics.AUC(name='auc')])
train_sparse_x = [X_train1[f].values for f in sparse_cols]
train_dense_x = [X_train1[f].values for f in dense_cols]
train_label = [y_train.values]test_sparse_x = [X_test1[f].values for f in sparse_cols]
test_dense_x = [X_test1[f].values for f in dense_cols]
test_label = [y_test.values]
test_sparse_x

训练模型

from keras.callbacks import *
# 回调函数
file_path = "deepfm_model_data.h5"
earlystopping = EarlyStopping(monitor="val_loss", patience=3)
checkpoint = ModelCheckpoint(file_path, save_weights_only=True, verbose=1, save_best_only=True)
callbacks_list = [earlystopping, checkpoint]hist = model.fit(train_sparse_x+train_dense_x, train_label,batch_size=4096,epochs=20,validation_data=(test_sparse_x+test_dense_x, test_label),callbacks=callbacks_list,shuffle=False)

模型存储

model.save('deepfm_model.h5')
loaded_model = tf.keras.models.load_model('deepfm_model.h5')
print("np.min(hist.history['val_loss']):", np.min(hist.history['val_loss']))
#np.min(hist.history['val_loss']):0.19
print("np.max(hist.history['val_auc']):", np.max(hist.history['val_auc']))
#np.max(hist.history['val_auc']):0.95

模型预测

deepfm_prob = model.predict(test_sparse_x+test_dense_x, batch_size=4096*4, verbose=1)
deepfm_prob.shape
deepfm_prob
df_submit          = pd.DataFrame()
df_submit          = X_test
df_submit['prob']  = deepfm_prob 
df_submit.head(3)
df_submit.shape
df_submit['y_pre'] = ''
df_submit['y_pre'].loc[(df_submit['prob']>=0.5)] = 1
df_submit['y_pre'].loc[(df_submit['prob']<0.5)]  = 0
df_submit.head(3)
df_submit = df_submit.reset_index()
df_submit.head(1)
df_submit = df_submit.drop('index', axis = 1)
df_submit.head(1)
df_submit.groupby(['flag', 'y_pre'])['customer_id'].count()

根据上述结果打印召回及精准

precision = 
recall  = 

查看lgb结果做比较

from lightgbm import LGBMClassifier
from sklearn.model_selection import GridSearchCV
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import f1_score, confusion_matrix, recall_score, precision_scoreparams = {'n_estimators': 1500, 'learning_rate': 0.1,'max_depth': 15,'metric': 'auc','verbose': -1, 'seed: 2023,'n_jobs':-1model=LGBMClarsifier(**params) 
model.fit(X_train, y_train,eval_set=[(X_train1, y_train), (X_test1, y_test)], eval_metric = 'auc', verbose=50,early_stopping_rounds = 100)
y_pred = model.predict(X_test1, num_iteration = model.best_iteration_)y_pred = model.predict(X_test1)
y_pred_proba = model.predict_proba(X_test1)
lgb_acc = model.score(X_test1, y_test) * 100
lgb_recall = recall_score(y_test, y_pred) * 100
lgb_precision = precision_score(y_test, y_pred) * 100 I 
lgb_f1 = f1_score(y_test, y_pred, pos_label=1) * 100
print("1gb 准确率:{:.2f}%".format(lgb_acc))
print("lgb 召回率:{:.2f}%".fornat(lgb_recall))
print("lgb 精准率:{:.2f}%".format(lgb_precision))
print("lgb F1分数:{:.2f}%".format(lgb_f1))#from sklearn.metrics import classification_report
#printf(classification_report(y_test, y_pred))# 混淆矩阵
plt.title("混淆矩阵", fontsize=21)
data_confusion_matrix = confusion_matrix(y_test, y_pred)
sns.heatmap(data_confusion_matrix, annot=True, cmap='Blues', fmt='d', cbar='False', annot_kws={'size': 28})
plt.xlabel('Predicted label') 
plt.ylabel('True label')from sklearn.metrics import roc_curve, auc
probs = model.predict_proba(X_test1)
preds = probs[:, 1]
fpr, tpr, threshold = roc_curve(y_test, preds)
# 绘制ROC曲线
roc_auc = auc(fpr, tpr)
plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
plt.plot([0, 1], [0, 1], 'r--')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Positive(TPR)')
plt.xlabel('False Positive(FPR)')
plt.title('ROC')
plt.legend(loc='lower right')
plt.show()

参考资料:自己琢磨将资料整合

这篇关于深度学习_数据读取到model模型存储的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1125544

相关文章

一文教你Python如何快速精准抓取网页数据

《一文教你Python如何快速精准抓取网页数据》这篇文章主要为大家详细介绍了如何利用Python实现快速精准抓取网页数据,文中的示例代码简洁易懂,具有一定的借鉴价值,有需要的小伙伴可以了解下... 目录1. 准备工作2. 基础爬虫实现3. 高级功能扩展3.1 抓取文章详情3.2 保存数据到文件4. 完整示例

使用Java将各种数据写入Excel表格的操作示例

《使用Java将各种数据写入Excel表格的操作示例》在数据处理与管理领域,Excel凭借其强大的功能和广泛的应用,成为了数据存储与展示的重要工具,在Java开发过程中,常常需要将不同类型的数据,本文... 目录前言安装免费Java库1. 写入文本、或数值到 Excel单元格2. 写入数组到 Excel表格

python处理带有时区的日期和时间数据

《python处理带有时区的日期和时间数据》这篇文章主要为大家详细介绍了如何在Python中使用pytz库处理时区信息,包括获取当前UTC时间,转换为特定时区等,有需要的小伙伴可以参考一下... 目录时区基本信息python datetime使用timezonepandas处理时区数据知识延展时区基本信息

Qt实现网络数据解析的方法总结

《Qt实现网络数据解析的方法总结》在Qt中解析网络数据通常涉及接收原始字节流,并将其转换为有意义的应用层数据,这篇文章为大家介绍了详细步骤和示例,感兴趣的小伙伴可以了解下... 目录1. 网络数据接收2. 缓冲区管理(处理粘包/拆包)3. 常见数据格式解析3.1 jsON解析3.2 XML解析3.3 自定义

SpringMVC 通过ajax 前后端数据交互的实现方法

《SpringMVC通过ajax前后端数据交互的实现方法》:本文主要介绍SpringMVC通过ajax前后端数据交互的实现方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价... 在前端的开发过程中,经常在html页面通过AJAX进行前后端数据的交互,SpringMVC的controll

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen

Pandas统计每行数据中的空值的方法示例

《Pandas统计每行数据中的空值的方法示例》处理缺失数据(NaN值)是一个非常常见的问题,本文主要介绍了Pandas统计每行数据中的空值的方法示例,具有一定的参考价值,感兴趣的可以了解一下... 目录什么是空值?为什么要统计空值?准备工作创建示例数据统计每行空值数量进一步分析www.chinasem.cn处

如何使用 Python 读取 Excel 数据

《如何使用Python读取Excel数据》:本文主要介绍使用Python读取Excel数据的详细教程,通过pandas和openpyxl,你可以轻松读取Excel文件,并进行各种数据处理操... 目录使用 python 读取 Excel 数据的详细教程1. 安装必要的依赖2. 读取 Excel 文件3. 读

关于MongoDB图片URL存储异常问题以及解决

《关于MongoDB图片URL存储异常问题以及解决》:本文主要介绍关于MongoDB图片URL存储异常问题以及解决方案,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录MongoDB图片URL存储异常问题项目场景问题描述原因分析解决方案预防措施js总结MongoDB图

Spring Boot读取配置文件的五种方式小结

《SpringBoot读取配置文件的五种方式小结》SpringBoot提供了灵活多样的方式来读取配置文件,这篇文章为大家介绍了5种常见的读取方式,文中的示例代码简洁易懂,大家可以根据自己的需要进... 目录1. 配置文件位置与加载顺序2. 读取配置文件的方式汇总方式一:使用 @Value 注解读取配置方式二