-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclient.py
More file actions
35 lines (26 loc) · 1.12 KB
/
client.py
File metadata and controls
35 lines (26 loc) · 1.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from collections import OrderedDict
from centralized import load_data, load_model, train, test
import flwr as fl
import torch
def set_parameters(model, parameters):
    """Overwrite ``model``'s state dict with ``parameters``, matched by position.

    The i-th array is loaded into the i-th entry of
    ``model.state_dict().keys()``. The arrays must therefore be in the same
    order the state dict iterates in, which is the order ``get_parameters``
    produces them in.

    Args:
        model: The ``torch.nn.Module`` whose weights are replaced in place.
        parameters: Iterable of array-likes (e.g. NumPy arrays), one per
            state-dict entry.

    Raises:
        ValueError: If the number of arrays differs from the number of
            state-dict entries. Without this check ``zip`` would silently
            ignore any extra arrays.
        RuntimeError: Propagated from ``load_state_dict`` (e.g. on a tensor
            shape mismatch).
    """
    # Materialise once so both len() and zip() work on generators too.
    parameters = list(parameters)
    keys = list(model.state_dict().keys())
    if len(parameters) != len(keys):
        raise ValueError(
            f"Expected {len(keys)} parameter arrays for this model, "
            f"got {len(parameters)}"
        )
    state_dict = OrderedDict(
        (key, torch.tensor(value)) for key, value in zip(keys, parameters)
    )
    # strict=True additionally guarantees every expected key is present.
    model.load_state_dict(state_dict, strict=True)
# Module-level model and data loaders, created once when this file is
# imported and used by FlowerClient. load_model/load_data come from the
# local `centralized` module; what they actually build is defined there.
net = load_model()
# load_data() must return exactly two items, unpacked here as (train, test).
trainloader, testloader = load_data()
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a local PyTorch model.

    By default the client uses the module-level ``net``, ``trainloader`` and
    ``testloader``, so ``FlowerClient()`` behaves as before. Passing explicit
    objects lets the class be reused (e.g. several clients in one process,
    or in tests) without relying on module globals.

    Args:
        model: Model to train/evaluate; defaults to the module-level ``net``.
        train_loader: Loader passed to ``train`` in ``fit``; defaults to the
            module-level ``trainloader``.
        test_loader: Loader passed to ``test`` in ``evaluate``; defaults to
            the module-level ``testloader``.
    """

    def __init__(self, model=None, train_loader=None, test_loader=None):
        super().__init__()
        # Fall back to the module-level objects so the no-argument
        # constructor keeps its original behaviour.
        self.net = net if model is None else model
        self.trainloader = trainloader if train_loader is None else train_loader
        self.testloader = testloader if test_loader is None else test_loader

    def get_parameters(self, config):
        """Return the model's state-dict tensors as NumPy arrays, in order.

        ``config`` is accepted for the ``NumPyClient`` interface but unused.
        """
        # Move to CPU first: Tensor.numpy() cannot convert CUDA tensors.
        return [tensor.cpu().numpy() for tensor in self.net.state_dict().values()]

    def fit(self, parameters, config):
        """Load the received weights, train locally, and return the result.

        ``config`` is accepted for the ``NumPyClient`` interface but unused.

        Returns:
            ``(updated_weights, num_training_examples, {})``.
        """
        set_parameters(self.net, parameters)
        # NOTE(review): epoch=1 presumably means one local epoch; the exact
        # meaning is defined by `train` in the `centralized` module.
        train(self.net, self.trainloader, epoch=1)
        return self.get_parameters({}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load the received weights and evaluate on the local test loader.

        ``config`` is accepted for the ``NumPyClient`` interface but unused.

        Returns:
            ``(loss, num_test_examples, {"accuracy": accuracy})``, with loss
            and accuracy converted to plain ``float``.
        """
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.testloader)
        print(f"Evaluation: Loss = {loss}, Accuracy = {accuracy}")
        return (
            float(loss),
            len(self.testloader.dataset),
            {"accuracy": float(accuracy)},
        )
if __name__ == "__main__":
    # Connect only when run as a script. Importing this module (e.g. to reuse
    # FlowerClient or set_parameters) then no longer starts a client session
    # against the server.
    fl.client.start_client(
        server_address="127.0.0.1:8080",
        client=FlowerClient().to_client(),
    )