使用pysyft发送模型给带数据集的远端WebsocketServerWorker作联合训练

本文主要是介绍使用pysyft发送模型给带数据集的远端WebsocketServerWorker作联合训练,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

WebsocketServerWorker端代码:start_worker.py

import argparseimport torch as th
from syft.workers.websocket_server import WebsocketServerWorkerimport syft as sy# Arguments
parser = argparse.ArgumentParser(description="Run websocket server worker.")
parser.add_argument("--port", "-p", type=int, help="port number of the websocket server worker, e.g. --port 8777"
)
parser.add_argument("--host", type=str, default="localhost", help="host for the connection")
parser.add_argument("--id", type=str, help="name (id) of the websocket server worker, e.g. --id alice"
)
parser.add_argument("--verbose","-v",action="store_true",help="if set, websocket server worker will be started in verbose mode",
)def main(**kwargs):  # pragma: no cover"""Helper function for spinning up a websocket participant."""# Create websocket workerworker = WebsocketServerWorker(**kwargs)# Setup toy data (xor example)data = th.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)target = th.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)# Create a dataset using the toy datadataset = sy.BaseDataset(data, target)# Tell the worker about the datasetworker.add_dataset(dataset, key="xor")# Start workerworker.start()return workerif __name__ == "__main__":hook = sy.TorchHook(th)args = parser.parse_args()kwargs = {"id": args.id,"host": args.host,"port": args.port,"hook": hook,"verbose": args.verbose,}main(**kwargs)

启动worker

  python start_worker.py --host 172.16.5.45 --port 8777 --id alice

客户端代码:

import inspect
import start_workerprint(inspect.getsource(start_worker.main))# Dependencies
import torch as th
import torch.nn.functional as F
from torch import nnuse_cuda = th.cuda.is_available()
th.manual_seed(1)
device = th.device("cuda" if use_cuda else "cpu")import syft as sy
from syft import workershook = sy.TorchHook(th)  # hook torch as always :)class Net(th.nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(2, 20)self.fc2 = nn.Linear(20, 10)self.fc3 = nn.Linear(10, 1)def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# Instantiate the model
model = Net()# The data itself doesn't matter as long as the shape is right
mock_data = th.zeros(1, 2)# Create a jit version of the model
traced_model = th.jit.trace(model, mock_data)type(traced_model)# Loss function
@th.jit.script
def loss_fn(target, pred):return ((target.view(pred.shape).float() - pred.float()) ** 2).mean()type(loss_fn)optimizer = "SGD"batch_size = 4
optimizer_args = {"lr" : 0.1, "weight_decay" : 0.01}
epochs = 1
max_nr_batches = -1  # not used in this example
shuffle = Truetrain_config = sy.TrainConfig(model=traced_model,loss_fn=loss_fn,optimizer=optimizer,batch_size=batch_size,optimizer_args=optimizer_args,epochs=epochs,shuffle=shuffle)kwargs_websocket = {"host": "172.16.5.45", "hook": hook, "verbose": False}
alice = workers.websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)# Send train config
train_config.send(alice)# Setup toy data (xor example)
data = th.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)
target = th.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)print("\nEvaluation before training")
pred = model(data)
loss = loss_fn(target=target, pred=pred)
print("Loss: {}".format(loss))
print("Target: {}".format(target))
print("Pred: {}".format(pred))for epoch in range(10):loss = alice.fit(dataset_key="xor")  # ask alice to train using "xor" datasetprint("-" * 50)print("Iteration %s: alice's loss: %s" % (epoch, loss))new_model = train_config.model_ptr.get()print("\nEvaluation after training:")
pred = new_model(data)
loss = loss_fn(target=target, pred=pred)
print("Loss: {}".format(loss))
print("Target: {}".format(target))
print("Pred: {}".format(pred))

运行:

python worker-client.py 

输出结果:

Evaluation before training
Loss: 0.4933376908302307
Target: tensor([[1.],[1.],[0.],[0.]])
Pred: tensor([[ 0.1258],[-0.0994],[ 0.0033],[ 0.0210]], grad_fn=<AddmmBackward>)
--------------------------------------------------
Iteration 0: alice's loss: tensor(0.4933, requires_grad=True)
--------------------------------------------------
Iteration 1: alice's loss: tensor(0.3484, requires_grad=True)
--------------------------------------------------
Iteration 2: alice's loss: tensor(0.2858, requires_grad=True)
--------------------------------------------------
Iteration 3: alice's loss: tensor(0.2626, requires_grad=True)
--------------------------------------------------
Iteration 4: alice's loss: tensor(0.2529, requires_grad=True)
--------------------------------------------------
Iteration 5: alice's loss: tensor(0.2474, requires_grad=True)
--------------------------------------------------
Iteration 6: alice's loss: tensor(0.2441, requires_grad=True)
--------------------------------------------------
Iteration 7: alice's loss: tensor(0.2412, requires_grad=True)
--------------------------------------------------
Iteration 8: alice's loss: tensor(0.2388, requires_grad=True)
--------------------------------------------------
Iteration 9: alice's loss: tensor(0.2368, requires_grad=True)Evaluation after training:
Loss: 0.23491761088371277
Target: tensor([[1.],[1.],[0.],[0.]])
Pred: tensor([[0.6553],[0.3781],[0.4834],[0.4477]], grad_fn=<DifferentiableGraphBackward>)

这篇关于使用pysyft发送模型给带数据集的远端WebsocketServerWorker作联合训练的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/240449

相关文章

使用Redis快速实现共享Session登录的详细步骤

《使用Redis快速实现共享Session登录的详细步骤》在Web开发中,Session通常用于存储用户的会话信息,允许用户在多个页面之间保持登录状态,Redis是一个开源的高性能键值数据库,广泛用于... 目录前言实现原理:步骤:使用Redis实现共享Session登录1. 引入Redis依赖2. 配置R

使用Python的requests库调用API接口的详细步骤

《使用Python的requests库调用API接口的详细步骤》使用Python的requests库调用API接口是开发中最常用的方式之一,它简化了HTTP请求的处理流程,以下是详细步骤和实战示例,涵... 目录一、准备工作:安装 requests 库二、基本调用流程(以 RESTful API 为例)1.

使用Python开发一个Ditto剪贴板数据导出工具

《使用Python开发一个Ditto剪贴板数据导出工具》在日常工作中,我们经常需要处理大量的剪贴板数据,下面将介绍如何使用Python的wxPython库开发一个图形化工具,实现从Ditto数据库中读... 目录前言运行结果项目需求分析技术选型核心功能实现1. Ditto数据库结构分析2. 数据库自动定位3

Python yield与yield from的简单使用方式

《Pythonyield与yieldfrom的简单使用方式》生成器通过yield定义,可在处理I/O时暂停执行并返回部分结果,待其他任务完成后继续,yieldfrom用于将一个生成器的值传递给另一... 目录python yield与yield from的使用代码结构总结Python yield与yield

Go语言使用select监听多个channel的示例详解

《Go语言使用select监听多个channel的示例详解》本文将聚焦Go并发中的一个强力工具,select,这篇文章将通过实际案例学习如何优雅地监听多个Channel,实现多任务处理、超时控制和非阻... 目录一、前言:为什么要使用select二、实战目标三、案例代码:监听两个任务结果和超时四、运行示例五

python使用Akshare与Streamlit实现股票估值分析教程(图文代码)

《python使用Akshare与Streamlit实现股票估值分析教程(图文代码)》入职测试中的一道题,要求:从Akshare下载某一个股票近十年的财务报表包括,资产负债表,利润表,现金流量表,保存... 目录一、前言二、核心知识点梳理1、Akshare数据获取2、Pandas数据处理3、Matplotl

pandas数据的合并concat()和merge()方式

《pandas数据的合并concat()和merge()方式》Pandas中concat沿轴合并数据框(行或列),merge基于键连接(内/外/左/右),concat用于纵向或横向拼接,merge用于... 目录concat() 轴向连接合并(1) join='outer',axis=0(2)join='o

Django开发时如何避免频繁发送短信验证码(python图文代码)

《Django开发时如何避免频繁发送短信验证码(python图文代码)》Django开发时,为防止频繁发送验证码,后端需用Redis限制请求频率,结合管道技术提升效率,通过生产者消费者模式解耦业务逻辑... 目录避免频繁发送 验证码1. www.chinasem.cn避免频繁发送 验证码逻辑分析2. 避免频繁

批量导入txt数据到的redis过程

《批量导入txt数据到的redis过程》用户通过将Redis命令逐行写入txt文件,利用管道模式运行客户端,成功执行批量删除以Product*匹配的Key操作,提高了数据清理效率... 目录批量导入txt数据到Redisjs把redis命令按一条 一行写到txt中管道命令运行redis客户端成功了批量删除k

Java使用Thumbnailator库实现图片处理与压缩功能

《Java使用Thumbnailator库实现图片处理与压缩功能》Thumbnailator是高性能Java图像处理库,支持缩放、旋转、水印添加、裁剪及格式转换,提供易用API和性能优化,适合Web应... 目录1. 图片处理库Thumbnailator介绍2. 基本和指定大小图片缩放功能2.1 图片缩放的