Federated Learning

From BloomWiki
Revision as of 06:12, 23 April 2026 by Wordpad (New article: Federated Learning structured through Bloom's Taxonomy)

How to read this page: This article maps the topic from beginner to expert across six levels: Remembering, Understanding, Applying, Analyzing, Evaluating, and Creating. Scan the headings to see the full scope, then read from wherever your knowledge starts to feel uncertain. Learn more about how BloomWiki works.

Federated Learning (FL) is a machine learning paradigm in which a model is trained across multiple decentralized devices or servers holding local data, without the raw data ever leaving those devices. Instead of shipping data to a central server for training, the training comes to the data. Federated learning enables AI development in contexts where data cannot be centralized due to privacy requirements, regulatory constraints, or practical limitations — making it essential for healthcare, finance, mobile applications, and national security.

Remembering

  • Federated learning — A distributed ML approach where model training happens locally on client devices, and only model updates (gradients or weights) are shared, never raw data.
  • Client — A participating device or institution that holds local data and trains a local model. Examples: smartphones, hospitals, banks.
  • Server — The central coordinator that aggregates client updates, computes a new global model, and distributes it back to clients.
  • Round — One communication cycle: server sends model → clients train locally → clients send updates → server aggregates.
  • FedAvg (Federated Averaging) — The foundational FL algorithm by McMahan et al. (2017); aggregates client models by weighted averaging of parameters.
  • Local epochs — The number of training epochs each client performs on local data before sending updates to the server.
  • Non-IID data (not independent and identically distributed) — A key challenge in FL: each client's data reflects its own distribution, which may differ substantially from other clients'.
  • Communication efficiency — Minimizing the amount of data transferred between clients and server, a key FL challenge.
  • Gradient compression — Techniques that reduce the size of transmitted updates (e.g., sparsification, quantization).
  • Differential privacy (DP) — A mathematical privacy guarantee that limits what can be learned about any individual's data from shared model updates.
  • Secure aggregation — Cryptographic protocols ensuring the server can compute the sum of client updates without seeing any individual update.
  • Model poisoning — An attack where malicious clients submit corrupted updates to degrade or manipulate the global model.
  • Byzantine fault tolerance — The ability of an FL system to produce correct results even when some participants are malicious or faulty.
  • Cross-device FL — Federated learning across many mobile devices (millions of clients, heterogeneous, unreliable).
  • Cross-silo FL — Federated learning across a small number of organizations (hospitals, banks), each with large datasets.

Understanding

The intuition behind federated learning: instead of asking a thousand hospitals to share patient records (legally and ethically fraught), you send a copy of the model to each hospital. Each hospital trains it on its own patients' data and sends back only the updated model weights. The server averages these updates to produce a better global model, and no patient record ever leaves any hospital.

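The FedAvg algorithm:
1. The server initializes the global model weights <math>w_0</math>.
2. In round <math>t</math>, the server sends <math>w_t</math> to a subset <math>S_t</math> of <math>K</math> clients.
3. Each client <math>k \in S_t</math> trains on its local data for <math>E</math> epochs, producing <math>w_{t+1}^{k}</math>.
4. The server aggregates by a weighted average over dataset sizes: <math>w_{t+1} = \sum_{k \in S_t} \frac{n_k}{N}\, w_{t+1}^{k}</math>, where <math>n_k</math> is the number of training examples on client <math>k</math> and <math>N = \sum_{k \in S_t} n_k</math>.
5. Repeat for <math>T</math> rounds.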
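The five steps above can be sketched end to end in plain NumPy. This is a minimal, hypothetical simulation rather than a reference implementation: it assumes a toy linear-regression task with synthetic client data and local mini-batch SGD on a squared loss, and the helper names (`make_clients`, `local_train`, `fedavg`) are illustrative. The `np.average(..., weights=sizes)` line is the weighted average from step 4.

<syntaxhighlight lang="python">
import numpy as np


def make_clients(num_clients, dim, rng):
    """Hypothetical synthetic setup: every client holds a different number of
    samples from the same linear model y = X @ w_true + noise."""
    w_true = rng.normal(size=dim)
    clients = []
    for _ in range(num_clients):
        n_k = int(rng.integers(20, 200))
        X = rng.normal(size=(n_k, dim))
        y = X @ w_true + 0.1 * rng.normal(size=n_k)
        clients.append((X, y))
    return clients, w_true


def local_train(w, X, y, epochs, lr, batch_size, rng):
    """Step 3: E epochs of mini-batch SGD on the squared loss, starting from w."""
    w = w.copy()
    n = len(y)
    for _ in range(epochs):
        order = rng.permutation(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            grad = X[idx].T @ (X[idx] @ w - y[idx]) / len(idx)
            w -= lr * grad
    return w


def fedavg(clients, dim, rounds=20, clients_per_round=5, epochs=2,
           lr=0.05, batch_size=16, seed=0):
    rng = np.random.default_rng(seed)
    w_global = np.zeros(dim)                             # Step 1: w_0
    for _ in range(rounds):                              # Step 5: T rounds
        # Step 2: sample K clients without replacement
        selected = rng.choice(len(clients), size=clients_per_round, replace=False)
        local_weights, sizes = [], []
        for k in selected:
            X, y = clients[k]
            local_weights.append(local_train(w_global, X, y, epochs, lr, batch_size, rng))
            sizes.append(len(y))
        # Step 4: w_{t+1} = sum_k (n_k / N) w_{t+1}^k, N = total samples of selected clients
        w_global = np.average(np.stack(local_weights), axis=0, weights=sizes)
    return w_global


if __name__ == "__main__":
    clients, w_true = make_clients(num_clients=10, dim=5, rng=np.random.default_rng(42))
    w = fedavg(clients, dim=5)
    print("distance to w_true:", np.linalg.norm(w - w_true))
</syntaxhighlight>

Because every client here samples from the same underlying model (an IID setting), plain averaging converges close to `w_true`; the non-IID case discussed next is where this simple picture breaks down.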

The non-IID problem is FL's central challenge. If a smartphone's photos are taken mostly at night, its local model update will be biased toward night photography. Hospital A serves elderly patients; Hospital B serves pediatric patients. Their local models will diverge, and naive averaging may harm performance on each individual distribution.
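To study the non-IID problem in simulation, a widely used approach is to give each client a skewed label mix by drawing, for every class, the share each client receives from a Dirichlet distribution. The sketch below assumes integer class labels and a hypothetical helper `dirichlet_partition`; the concentration parameter `alpha` controls the skew (small values give each client only a few classes, large values approach an IID split).

<syntaxhighlight lang="python">
import numpy as np


def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients with label skew.

    For each class, draw client proportions from Dirichlet(alpha) and deal that
    class's samples out accordingly."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == c))
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Cumulative proportions give the cut points inside this class's samples
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for k, part in enumerate(np.split(idx, cuts)):
            client_indices[k].extend(part.tolist())
    return [np.array(ix, dtype=int) for ix in client_indices]


if __name__ == "__main__":
    labels = np.repeat(np.arange(10), 100)   # hypothetical: 10 classes x 100 samples
    for alpha in (100.0, 0.1):
        parts = dirichlet_partition(labels, num_clients=5, alpha=alpha)
        # Number of distinct classes each client ends up with
        print(alpha, [len(np.unique(labels[p])) for p in parts])
</syntaxhighlight>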

Communication bottleneck: Each round requires clients to transmit model weights, potentially hundreds of MB for large models. With limited bandwidth (mobile devices, rural hospitals), this is a critical constraint. Solutions include gradient sparsification (transmit only the largest-magnitude entries of each update), quantization (send values at reduced precision, e.g., 8 bits instead of 32), and distillation (compress the model's knowledge into a smaller model or into predictions before transmission).
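Sparsification and quantization can be sketched in NumPy as follows. This is a simplified illustration under stated assumptions: a single dense update tensor, top-k selection by magnitude, and uniform 8-bit quantization with a per-tensor minimum and scale. The function names are hypothetical, and practical systems add refinements such as accumulating the dropped residual locally for the next round (error feedback).

<syntaxhighlight lang="python">
import numpy as np


def top_k_sparsify(update, fraction):
    """Keep only the largest-magnitude entries; the client sends (indices, values)."""
    flat = update.ravel()
    k = max(1, int(fraction * flat.size))
    idx = np.argpartition(np.abs(flat), -k)[-k:]   # indices of the k largest |values|
    return idx.astype(np.int32), flat[idx].astype(np.float32)


def densify(idx, values, shape):
    """Server side: rebuild a dense update, with zeros where nothing was sent."""
    flat = np.zeros(int(np.prod(shape)), dtype=np.float32)
    flat[idx] = values
    return flat.reshape(shape)


def quantize_uint8(update):
    """Uniform 8-bit quantization: send uint8 codes plus (min, scale)."""
    lo, hi = float(update.min()), float(update.max())
    scale = (hi - lo) / 255 if hi > lo else 1.0
    codes = np.round((update - lo) / scale).astype(np.uint8)
    return codes, lo, scale


def dequantize_uint8(codes, lo, scale):
    """Server side: map codes back to approximate float values."""
    return codes.astype(np.float32) * scale + lo


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    update = rng.normal(size=(128, 784)).astype(np.float32)   # hypothetical layer update
    idx, vals = top_k_sparsify(update, fraction=0.01)
    codes, lo, scale = quantize_uint8(update)
    # Bytes for: dense float32, top-1% (int32 index + float32 value), 8-bit codes
    print(update.nbytes, idx.nbytes + vals.nbytes, codes.nbytes)   # 401408 8024 100352
</syntaxhighlight>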

Privacy guarantees: Even without sharing raw data, shared updates can leak information; gradient inversion attacks can, in some settings, reconstruct training images from gradients. Differential privacy adds calibrated noise to updates so that what the server (or anyone who sees the model) can infer about any individual's data is bounded by the privacy parameters (ε, δ).
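A common way to apply differential privacy in FL, in the style of DP-FedAvg, is to clip each client update to a maximum L2 norm (bounding any one client's influence on the sum) and add Gaussian noise to the sum. The sketch below assumes that scheme, with hypothetical parameter names `clip_norm` and `noise_multiplier`. It does not compute ε; that requires a privacy accountant that tracks the noise level, client sampling rate, and number of rounds.

<syntaxhighlight lang="python">
import numpy as np


def clip_update(update, clip_norm):
    """Scale an update down so its L2 norm is at most clip_norm.
    This bounds how much any single client can change the sum (the sensitivity)."""
    norm = np.linalg.norm(update)
    return update * min(1.0, clip_norm / (norm + 1e-12))


def dp_aggregate(updates, clip_norm, noise_multiplier, rng):
    """Clip every client update, sum them, add Gaussian noise with standard
    deviation noise_multiplier * clip_norm, then average over the clients."""
    clipped = [clip_update(u, clip_norm) for u in updates]
    total = np.sum(clipped, axis=0)
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=total.shape)
    return (total + noise) / len(updates)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical updates from 100 clients with 10 parameters each
    updates = [rng.normal(size=10) for _ in range(100)]
    noisy_mean = dp_aggregate(updates, clip_norm=1.0, noise_multiplier=1.0, rng=rng)
    print(noisy_mean)
</syntaxhighlight>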

Applying

Federated learning simulation with Flower (flwr):

<syntaxhighlight lang="python">
import flwr as fl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Flatten accepts both (B, 784) vectors and (B, 1, 28, 28) images
        self.net = nn.Sequential(
            nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)
        )

    def forward(self, x):
        return self.net(x)


class FlowerClient(fl.client.NumPyClient):
    def __init__(self, model, trainloader: DataLoader, valloader: DataLoader):
        self.model = model
        self.trainloader = trainloader
        self.valloader = valloader
        self.optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        self.criterion = nn.CrossEntropyLoss()

    def get_parameters(self, config):
        # Model weights as a list of NumPy arrays, in state_dict order
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters):
        # Load the global weights sent by the server into the local model
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        self.model.train()
        # Local training for E=5 epochs
        for _ in range(5):
            for X, y in self.trainloader:
                self.optimizer.zero_grad()
                self.criterion(self.model(X), y).backward()
                self.optimizer.step()
        # Updated weights plus the local dataset size n_k (the FedAvg weight)
        return self.get_parameters(config={}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        total_loss, correct = 0.0, 0
        # Evaluate on the local validation set
        with torch.no_grad():
            for X, y in self.valloader:
                logits = self.model(X)
                total_loss += self.criterion(logits, y).item() * y.size(0)
                correct += (logits.argmax(dim=1) == y).sum().item()
        n = len(self.valloader.dataset)
        loss, acc = total_loss / n, correct / n
        return float(loss), n, {"accuracy": float(acc)}


# Start simulation (Flower 1.x simulation API; needs the extra: pip install "flwr[simulation]")
fl.simulation.start_simulation(
    # Build each client's model and data loaders here (arguments elided)
    client_fn=lambda cid: FlowerClient(...),
    num_clients=10,
    config=fl.server.ServerConfig(num_rounds=20),
    strategy=fl.server.strategy.FedAvg(
        fraction_fit=0.5,        # 50% of clients per round
        min_fit_clients=3,
        min_available_clients=5,
    ),
)
</syntaxhighlight>

Federated learning framework comparison
Flower (flwr) → Framework-agnostic (PyTorch, TensorFlow, JAX, and others); used for both research and production
PySyft (OpenMined) → Focus on privacy; strong differential privacy and secure aggregation
TensorFlow Federated (TFF) → Google's FL framework; tight TF integration; cross-device simulation
FATE → Enterprise cross-silo FL; strong support for healthcare and finance use cases

Analyzing

{| class="wikitable"
|+ Cross-device vs Cross-silo FL
! Aspect !! Cross-device (Smartphones) !! Cross-silo (Hospitals/Banks)
|-
| Number of clients || Millions || Tens to hundreds
|-
| Data per client || Small, highly non-IID || Large, moderately non-IID
|-
| Participation || Intermittent, unreliable || Reliable, scheduled
|-
| Communication || Limited bandwidth (mobile) || High bandwidth (datacenter)
|-
| Trust model || Untrusted clients || Semi-trusted institutions
|-
| Deployment || Google Keyboard, Apple Siri || Medical research consortia, banking
|}

Key challenges and failure modes:

  • Non-IID degradation — With highly heterogeneous client data, FedAvg can converge to a poor global model or oscillate. FedProx (adds a proximal term that keeps local models close to the global model) and SCAFFOLD (uses control variates to correct client drift) mitigate this.
  • Client stragglers — If the server waits for slow clients, each round is bottlenecked by the slowest participant. Asynchronous FL avoids the wait but applies stale updates computed against an older global model; simply dropping slow clients avoids the wait but can bias the model against their data.
  • Model poisoning and backdoor attacks — Malicious clients inject backdoors (e.g., "if the input contains a specific trigger, classify it as the target class") that can survive averaging. Defenses include Krum, coordinate-wise median or trimmed-mean aggregation, and anomaly detection on updates.
  • Gradient inversion — Shared gradients, especially those computed on small batches, can be used to reconstruct training data. Mitigation: secure aggregation (cryptographic) plus differential privacy (statistical).
  • Concept drift — Client data distributions change over time; the global model becomes stale for some clients. Personalization and continual learning help.
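The robust aggregation rules named above (coordinate-wise median, trimmed mean, and Krum) can be sketched in NumPy. This assumes updates arrive as equal-length flat vectors and that an upper bound `num_byzantine` (f) on the number of malicious clients is known; Krum's usual requirement n > 2f + 2 is checked explicitly. Function names are illustrative.

<syntaxhighlight lang="python">
import numpy as np


def coordinate_median(updates):
    """Per-parameter median across clients."""
    return np.median(np.stack(updates), axis=0)


def trimmed_mean(updates, trim):
    """Per parameter, drop the `trim` largest and `trim` smallest values, average the rest."""
    stacked = np.sort(np.stack(updates), axis=0)
    n = stacked.shape[0]
    if 2 * trim >= n:
        raise ValueError("trim is too large for the number of updates")
    return stacked[trim:n - trim].mean(axis=0)


def krum(updates, num_byzantine):
    """Return the single update whose summed squared distance to its
    n - f - 2 nearest neighbours is smallest."""
    stacked = np.stack(updates)
    n = stacked.shape[0]
    if n <= 2 * num_byzantine + 2:
        raise ValueError("Krum needs n > 2f + 2")
    m = n - num_byzantine - 2
    dists = np.sum((stacked[:, None, :] - stacked[None, :, :]) ** 2, axis=-1)
    scores = [np.sort(np.delete(dists[i], i))[:m].sum() for i in range(n)]
    return stacked[int(np.argmin(scores))]


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    honest = [1.0 + 0.1 * rng.normal(size=4) for _ in range(7)]   # hypothetical updates
    attackers = [np.full(4, 100.0) for _ in range(2)]             # poisoned updates
    ups = honest + attackers
    print("mean:   ", np.mean(ups, axis=0))   # dragged toward the attackers
    print("median: ", coordinate_median(ups))
    print("trimmed:", trimmed_mean(ups, trim=2))
    print("krum:   ", krum(ups, num_byzantine=2))
</syntaxhighlight>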

Evaluating

Expert FL evaluation extends beyond model accuracy:

Privacy-utility trade-off: Differential privacy adds noise that degrades model accuracy. Practitioners track the Pareto frontier between ε (privacy budget) and model utility. Lower ε means stronger privacy, which requires more noise and typically yields a less accurate model. The right operating point depends on the sensitivity of the use case.
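One way to make the link between ε and noise concrete is the classical Gaussian mechanism, whose noise standard deviation for a single release with sensitivity Δ is <math>\sigma = \Delta \sqrt{2 \ln(1.25/\delta)} / \varepsilon</math>, stated for 0 < ε < 1. The sketch below only evaluates that formula with an assumed δ = 10⁻⁵; multi-round FL training and larger budgets (such as ε ≈ 8) are analyzed with privacy accountants instead, so treat it as an illustration of the trend rather than a calibration tool.

<syntaxhighlight lang="python">
import math


def gaussian_sigma(epsilon, delta, sensitivity=1.0):
    """Noise standard deviation for the classical Gaussian mechanism
    (single release, valid for 0 < epsilon < 1)."""
    if not 0 < epsilon < 1:
        raise ValueError("the classical bound is stated for 0 < epsilon < 1")
    return sensitivity * math.sqrt(2 * math.log(1.25 / delta)) / epsilon


if __name__ == "__main__":
    for eps in (0.9, 0.5, 0.1):
        # Smaller epsilon (stronger privacy) -> larger noise for the same sensitivity
        print(f"epsilon={eps}: sigma={gaussian_sigma(eps, delta=1e-5):.2f}")
</syntaxhighlight>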

Communication efficiency metrics: Measure total bytes transmitted, i.e., per-round upload and download traffic × number of rounds, and compare baseline FedAvg against compressed (quantized, sparse) communication. Target: an accuracy drop below 1% for a communication reduction above 10×.
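The byte count can be worked out directly. The sketch below assumes 32-bit weights and that each selected client downloads the full global model and uploads an update of the same size, optionally shrunk by a hypothetical `upload_compression` factor. It plugs in the `SimpleModel` and round settings from the Flower example in the Applying section (10 clients with `fraction_fit=0.5`, so 5 clients per round, for 20 rounds).

<syntaxhighlight lang="python">
def fl_traffic_bytes(num_params, bytes_per_param, clients_per_round, rounds,
                     upload_compression=1.0):
    """Total client-server traffic over all rounds. Each selected client
    downloads the full global model and uploads an update whose size is divided
    by upload_compression (1.0 = uncompressed)."""
    model_bytes = num_params * bytes_per_param
    per_client_per_round = model_bytes + model_bytes / upload_compression
    return per_client_per_round * clients_per_round * rounds


# Parameters of SimpleModel: Linear(784 -> 128) plus Linear(128 -> 10), weights and biases
num_params = 784 * 128 + 128 + 128 * 10 + 10        # 101,770

baseline = fl_traffic_bytes(num_params, 4, clients_per_round=5, rounds=20)
upload_8bit = fl_traffic_bytes(num_params, 4, clients_per_round=5, rounds=20,
                               upload_compression=4.0)   # float32 -> uint8 uploads
print(f"baseline: {baseline / 1e6:.1f} MB, 8-bit uploads: {upload_8bit / 1e6:.1f} MB")
# baseline: 81.4 MB, 8-bit uploads: 50.9 MB
</syntaxhighlight>

In this setting, compressing only the uploads by 4× reduces total traffic by just 1.6×, because the downloads still carry full-precision weights; reaching the 10× range generally means compressing both directions or sending far fewer bytes per update.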

Fairness across clients: Global accuracy can mask poor performance on underrepresented clients. Evaluate accuracy per client or per client cluster and use min-max or worst-case client fairness objectives.
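A per-client report shows how a pooled metric can hide a struggling client. The sketch assumes each client reports a count of correct predictions and a count of evaluated examples; the function name and the numbers in the usage lines are hypothetical.

<syntaxhighlight lang="python">
import numpy as np


def client_accuracy_report(correct, totals):
    """Compare the sample-weighted global accuracy with per-client statistics."""
    correct = np.asarray(correct, dtype=float)
    totals = np.asarray(totals, dtype=float)
    per_client = correct / totals
    return {
        "global": correct.sum() / totals.sum(),   # pooled metric, dominated by large clients
        "mean_client": per_client.mean(),          # every client counts equally
        "worst_client": per_client.min(),          # what a min-max objective targets
        "std_client": per_client.std(),            # spread across clients
    }


if __name__ == "__main__":
    # Hypothetical: two large clients do well, one small client does badly
    report = client_accuracy_report(correct=[950, 900, 30], totals=[1000, 1000, 100])
    for name, value in report.items():
        print(f"{name}: {value:.3f}")
</syntaxhighlight>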

Convergence analysis: Track global validation loss over rounds. FL usually converges more slowly and noisily than centralized training because of non-IID data, multiple local steps per round, and partial client participation, so it typically needs more rounds, each of which costs a full communication cycle.

Expert practitioners in cross-silo FL conduct institutional audits — verifying that each participating institution's data governance policies are respected and that the FL protocol doesn't violate regulatory requirements (HIPAA, GDPR).

Creating

Designing a production federated learning system:

1. Threat model and privacy specification
<syntaxhighlight lang="text">
Who might be adversarial? (honest-but-curious server? malicious clients?)
What level of privacy is required? (DP with ε ≤ 8? Secure aggregation?)
What is the regulatory framework? (HIPAA, GDPR, CCPA?)
Define the privacy budget and the acceptable accuracy-privacy trade-off
</syntaxhighlight>

2. System architecture
<syntaxhighlight lang="text">
[Orchestration server]
   ↓ (global model broadcast)
[Client selection algorithm]
   ↓ (based on data quality, staleness, fairness)
[Secure aggregation protocol] ← Clients send encrypted shares
   ↓
[Differential privacy: add calibrated Gaussian noise to aggregate]
   ↓
[Global model update]
   ↓
[Evaluation on held-out global validation set]
   ↓
[Anomaly detection: flag outlier client updates]
</syntaxhighlight>
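The secure aggregation stage can be illustrated with pairwise masking: each pair of clients derives a shared random mask that one adds and the other subtracts, so every individual upload looks random but the masks cancel in the sum. This is a simplified sketch with hypothetical function names. Shared seeds are simulated with a NumPy generator rather than agreed via key exchange, values are floats rather than integers modulo a fixed modulus, and clients dropping out are not handled; real protocols address all three. Note also that once secure aggregation is in place the server sees only the sum, so the anomaly-detection stage cannot inspect individual client updates directly.

<syntaxhighlight lang="python">
import numpy as np


def pairwise_masks(client_ids, dim, round_seed):
    """For every pair a < b, derive one random mask from a seed both clients
    share; client a adds it and client b subtracts it, so masks cancel in the sum.
    The shared seed is simulated here; a real protocol agrees on it via key exchange."""
    masks = {c: np.zeros(dim) for c in client_ids}
    for a in client_ids:
        for b in client_ids:
            if a < b:
                pair_rng = np.random.default_rng([round_seed, a, b])
                m = pair_rng.normal(size=dim)
                masks[a] += m
                masks[b] -= m
    return masks


def mask_updates(updates, round_seed):
    """Client side: each client uploads update + mask instead of its raw update."""
    ids = sorted(updates)
    dim = len(updates[ids[0]])
    masks = pairwise_masks(ids, dim, round_seed)
    return {c: updates[c] + masks[c] for c in ids}


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    updates = {c: rng.normal(size=4) for c in range(3)}   # hypothetical client updates
    masked = mask_updates(updates, round_seed=1)
    # Server side: summing the masked uploads recovers the true sum
    print(np.allclose(sum(masked.values()), sum(updates.values())))   # True
</syntaxhighlight>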

3. Personalization strategies for non-IID data

  • FedProx: penalize deviation from global model during local training
  • Per-FedAvg: meta-learning approach; global model as a good initialization for local fine-tuning
  • Clustered FL: group clients by data similarity, train separate global models per cluster
  • Local fine-tuning: ship global model → each client fine-tunes for a few steps on local data
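FedProx's proximal term and simple local fine-tuning can both be sketched with one local-training routine. The example assumes a toy linear-regression client trained by full-batch gradient descent and uses a hypothetical function name; `mu` is the proximal coefficient, and `mu=0` reduces to ordinary local training.

<syntaxhighlight lang="python">
import numpy as np


def local_train_prox(w_global, X, y, mu, epochs=5, lr=0.05):
    """Full-batch gradient descent on the local objective
        (1 / 2n) * ||X w - y||^2 + (mu / 2) * ||w - w_global||^2.
    mu = 0 gives ordinary local training; larger mu keeps w closer to w_global."""
    w = w_global.copy()
    n = len(y)
    for _ in range(epochs):
        grad = X.T @ (X @ w - y) / n + mu * (w - w_global)
        w -= lr * grad
    return w


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 3))
    y = X @ np.full(3, 3.0)          # hypothetical client whose optimum is far from w_global
    w_global = np.zeros(3)
    for mu in (0.0, 1.0, 10.0):
        w_local = local_train_prox(w_global, X, y, mu=mu, epochs=50)
        print(mu, np.linalg.norm(w_local - w_global))
</syntaxhighlight>

Starting from the global weights and calling it with `mu=0.0` for a few epochs is one minimal form of the local fine-tuning strategy; a positive `mu` trades some local fit for staying close to the global model.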

4. Infrastructure for cross-silo FL

  • Dedicated FL node at each institution behind their firewall
  • Encrypted communication channels (mTLS)
  • Audit logs for every gradient exchange
  • Governance contract specifying data usage, participation requirements, IP ownership