You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
import os
from PIL import Image
import flwr as fl
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from ultralytics import YOLO
class CustomDataset(Dataset):
def __init__(self, data_folder, transform=None):
self.data_folder = data_folder
self.transform = transform
self.classes = sorted(os.listdir(data_folder))
self.data = []
self.labels = []
for i, class_name in enumerate(self.classes):
class_path = os.path.join(data_folder, class_name)
for file_name in os.listdir(class_path):
file_path = os.path.join(class_path, file_name)
self.data.append(file_path)
self.labels.append(i)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
img_path = self.data[idx]
label = self.labels[idx]
img = Image.open(img_path).convert("RGB")
if self.transform:
img = self.transform(img)
return {'image': img, 'label': label}
class PyTorchClient(fl.client.NumPyClient):
def __init__(self, model, train_loader):
self.model = model
self.train_loader = train_loader
def get_parameters(self):
return [param.cpu().numpy() for param in self.model.parameters()]
def set_parameters(self, parameters):
for param, param_data in zip(self.model.parameters(), parameters):
param.data = torch.from_numpy(param_data)
def fit(self, parameters, config):
self.set_parameters(parameters)
optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
for epoch in range(1): # Adjust the number of epochs as needed
self.model.train()
for batch_idx, batch in enumerate(self.train_loader):
data, target = batch['image'], batch['label']
optimizer.zero_grad()
output = self.model(data)
loss = torch.nn.CrossEntropyLoss()(output, target)
loss.backward()
optimizer.step()
return self.get_parameters(), len(self.train_loader.dataset), {}
custom_data = '/content/drive/MyDrive/local_1_final_all/train'
custom_dataset = CustomDataset(data_folder=custom_data, transform=transforms.Compose([transforms.ToTensor()]))
train_loader = DataLoader(custom_dataset, batch_size=64, shuffle=True)
model_files = ["/content/drive/MyDrive/runs/detect/local_1_train_result/weights/best.pt",
"/content/drive/MyDrive/runs/detect/local_2_train_result/weights/best.pt",
"/content/drive/MyDrive/runs/detect/local_3_train_result/weights/best.pt",
"/content/drive/MyDrive/runs/detect/local_4_train_result/weights/best.pt",
"/content/drive/MyDrive/runs/detect/local_5_train_result/weights/best.pt",
"/content/drive/MyDrive/runs/detect/local_6_train_result/weights/best.pt",
"/content/drive/MyDrive/runs/detect/local_7_train_result/weights/best.pt"]
clients = []
for model_file in model_files:
model = YOLO(model_file) # Load YOLO model
client = PyTorchClient(model, train_loader)
clients.append(client)
fl.server.start_server(config=fl.server.app.ServerConfig(num_rounds=3), server_address="[::]:8888")
fl.server.add_clients(clients)
server_model = fl.server.get_model()
torch.save(server_model.state_dict(), "federated_model.pt")
My development environment is Mac, and I'm running code on Google colab.
fl.server.start_server(config=fl.server.app.ServerConfig(num_rounds=3), server_address="[::]:8888")
Infinite loading is occurring in this area, and the
start_server() > run_fl() > fit() > _get_initial_parameters() > wait_for() > wait() > wait()
It's marked like this.
QuestionProgramming HelpProgramming languages, open source, and software development.Universe 2023All things related to our global developer conference, Universe 2023
1 participant
Heading
Bold
Italic
Quote
Code
Link
Numbered list
Unordered list
Task list
Attach files
Mention
Reference
Menu
reacted with thumbs up emoji reacted with thumbs down emoji reacted with laugh emoji reacted with hooray emoji reacted with confused emoji reacted with heart emoji reacted with rocket emoji reacted with eyes emoji
-
My development environment is Mac, and I'm running code on Google colab.
fl.server.start_server(config=fl.server.app.ServerConfig(num_rounds=3), server_address="[::]:8888")
Infinite loading is occurring in this area, and the
start_server() > run_fl() > fit() > _get_initial_parameters() > wait_for() > wait() > wait()
It's marked like this.
Beta Was this translation helpful? Give feedback.
All reactions