FATE —— 二.2.3 Homo-NN自定义损失函数

2023-12-19 11:10

本文主要是介绍FATE —— 二.2.3 Homo-NN自定义损失函数,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言

当Pytorch的内置损失功能不能满足您的使用需求时,您可以使用自定义损失来训练您的模型

MNIST示例的一个小问题

您可能会注意到,在上一个教程“自定义数据集”中的MNIST示例中,分类器输出分数是Softmax函数的结果,我们使用torch内置的CrossEntropyLoss来计算损失。然而,它在文档(CrossEntropyLoss Doc)中显示,输入预期包含每个类的未规范化逻辑,也就是说,在该示例中,我们计算Softmax两次。为了解决这个问题,我们可以使用定制的CrossEntropyLoss。

开发自定义丢失

Customized Loss是torch.nn.Module的子类并实现forward函数的类。在FATE训练器中,损失函数将传递两个参数:预测分数和标签(loss_fn(pred,loss)),因此当您使用FATE的训练器时,损失函数需要将两个参数作为输入(预测分数&标签)。然而,如果您使用的是自己的培训师,并且定义了自己的培训流程,那么您不受如何使用损失函数的限制。

一种新的交叉熵损失

在这里,我们实现了一个新的CrossEntropyLoss,它跳过了softmax计算。我们可以使用jupyter接口save_to_rate将代码更新为federatedml.nn.loss(名为ce.py),当然,您可以手动将代码文件复制到目录中。

import torch as t
from federatedml.util import consts
from torch.nn.functional import one_hotdef cross_entropy(p2, p1, reduction='mean'):p2 = p2 + consts.FLOAT_ZERO  # to avoid nanassert p2.shape == p1.shapeif reduction == 'sum':return -t.sum(p1 * t.log(p2))elif reduction == 'mean':return -t.mean(t.sum(p1 * t.log(p2), dim=1))elif reduction == 'none':return -t.sum(p1 * t.log(p2), dim=1)else:raise ValueError('unknown reduction')class CrossEntropyLoss(t.nn.Module):"""A CrossEntropy Loss that will not compute Softmax"""def __init__(self, reduction='mean'):super(CrossEntropyLoss, self).__init__()self.reduction = reductiondef forward(self, pred, label):one_hot_label = one_hot(label.flatten())loss_ = cross_entropy(pred, one_hot_label, self.reduction)return loss_

训练新的损失

导入组件
import torch as t
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component import HomoNN
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation, DataTransform
from pipeline.interface import Data, Modelt = fate_torch_hook(t)
将数据路径绑定到名称和命名空间
import os
# bind data path to name & namespace
# fate_project_path = os.path.abspath('../')
arbiter = 10000
host = 10000
guest = 9999
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host,arbiter=arbiter)data_0 = {"name": "mnist_guest", "namespace": "experiment"}
data_1 = {"name": "mnist_host", "namespace": "experiment"}
# 路径根据自己得文件位置及名称进行调整,这里以FATE 1.10.0 版本为例
data_path_0 = '/mnt/hgfs/mnist/'
data_path_1 = '/mnt/hgfs/mnist/'
pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path_0)
pipeline.bind_table(name=data_1['name'], namespace=data_1['namespace'], path=data_path_1)

{'namespace': 'experiment', 'table_name': 'mnist_host'}

reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=data_1)
使用CustLoss

在fate_torch_hook之后,我们可以使用t.nn.CustLoss指定您自己的损失。我们将在参数中指定模块名和类名,后面是损失类的初始化参数。初始化参数必须是JSON可序列化的,否则无法提交此PipeLine。

from pipeline.component.homo_nn import TrainerParam, DatasetParam  # Interface# your loss class
loss = t.nn.CustLoss(loss_module_name='cross_entropy', class_name='CrossEntropyLoss', reduction='mean')# our simple classification model:
model = t.nn.Sequential(t.nn.Linear(784, 32),t.nn.ReLU(),t.nn.Linear(32, 10),t.nn.Softmax(dim=1)
)nn_component = HomoNN(name='nn_0',model=model, # modelloss=loss,  # lossoptimizer=t.optim.Adam(model.parameters(), lr=0.01), # optimizerdataset=DatasetParam(dataset_name='mnist_dataset', flatten_feature=True),  # datasettrainer=TrainerParam(trainer_name='fedavg_trainer', epochs=2, batch_size=1024, validation_freqs=1),torch_seed=100 # random seed)
pipeline.add_component(reader_0)
pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))
pipeline.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=nn_component.output.data))pipeline.compile()
pipeline.fit()
pipeline.get_component('nn_0').get_output_data()
pipeline.get_component('nn_0').get_summary()

{'best_epoch': 1,

'loss_history': [3.472281552891043, 2.6957144274613256],

'metrics_summary': {'train': {'accuracy': [0.41711229946524064,

0.6348357524828113],

'precision': [0.5812903622442052, 0.7334376862468294],

'recall': [0.39894927536231883, 0.6243379446640317]}},

'need_stop': False}

这篇关于FATE —— 二.2.3 Homo-NN自定义损失函数的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/511945

相关文章

Python中help()和dir()函数的使用

《Python中help()和dir()函数的使用》我们经常需要查看某个对象(如模块、类、函数等)的属性和方法,Python提供了两个内置函数help()和dir(),它们可以帮助我们快速了解代... 目录1. 引言2. help() 函数2.1 作用2.2 使用方法2.3 示例(1) 查看内置函数的帮助(

C++ 函数 strftime 和时间格式示例详解

《C++函数strftime和时间格式示例详解》strftime是C/C++标准库中用于格式化日期和时间的函数,定义在ctime头文件中,它将tm结构体中的时间信息转换为指定格式的字符串,是处理... 目录C++ 函数 strftipythonme 详解一、函数原型二、功能描述三、格式字符串说明四、返回值五

如何自定义一个log适配器starter

《如何自定义一个log适配器starter》:本文主要介绍如何自定义一个log适配器starter的问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录需求Starter 项目目录结构pom.XML 配置LogInitializer实现MDCInterceptor

Python中bisect_left 函数实现高效插入与有序列表管理

《Python中bisect_left函数实现高效插入与有序列表管理》Python的bisect_left函数通过二分查找高效定位有序列表插入位置,与bisect_right的区别在于处理重复元素时... 目录一、bisect_left 基本介绍1.1 函数定义1.2 核心功能二、bisect_left 与

java中BigDecimal里面的subtract函数介绍及实现方法

《java中BigDecimal里面的subtract函数介绍及实现方法》在Java中实现减法操作需要根据数据类型选择不同方法,主要分为数值型减法和字符串减法两种场景,本文给大家介绍java中BigD... 目录Java中BigDecimal里面的subtract函数的意思?一、数值型减法(高精度计算)1.

C++/类与对象/默认成员函数@构造函数的用法

《C++/类与对象/默认成员函数@构造函数的用法》:本文主要介绍C++/类与对象/默认成员函数@构造函数的用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录名词概念默认成员函数构造函数概念函数特征显示构造函数隐式构造函数总结名词概念默认构造函数:不用传参就可以

C++类和对象之默认成员函数的使用解读

《C++类和对象之默认成员函数的使用解读》:本文主要介绍C++类和对象之默认成员函数的使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、默认成员函数有哪些二、各默认成员函数详解默认构造函数析构函数拷贝构造函数拷贝赋值运算符三、默认成员函数的注意事项总结一

Druid连接池实现自定义数据库密码加解密功能

《Druid连接池实现自定义数据库密码加解密功能》在现代应用开发中,数据安全是至关重要的,本文将介绍如何在​​Druid​​连接池中实现自定义的数据库密码加解密功能,有需要的小伙伴可以参考一下... 目录1. 环境准备2. 密码加密算法的选择3. 自定义 ​​DruidDataSource​​ 的密码解密3

Python函数返回多个值的多种方法小结

《Python函数返回多个值的多种方法小结》在Python中,函数通常用于封装一段代码,使其可以重复调用,有时,我们希望一个函数能够返回多个值,Python提供了几种不同的方法来实现这一点,需要的朋友... 目录一、使用元组(Tuple):二、使用列表(list)三、使用字典(Dictionary)四、 使

spring-gateway filters添加自定义过滤器实现流程分析(可插拔)

《spring-gatewayfilters添加自定义过滤器实现流程分析(可插拔)》:本文主要介绍spring-gatewayfilters添加自定义过滤器实现流程分析(可插拔),本文通过实例图... 目录需求背景需求拆解设计流程及作用域逻辑处理代码逻辑需求背景公司要求,通过公司网络代理访问的请求需要做请