Transfer Learning
How to read this page: This article maps the topic from beginner to expert across six levels: Remembering, Understanding, Applying, Analyzing, Evaluating, and Creating. Scan the headings to see the full scope, then read from wherever your knowledge starts to feel uncertain.
Transfer learning is a machine learning technique in which knowledge gained while training a model on one task or dataset is reused as the starting point for a model on a different, related task. Rather than training from scratch — which requires vast data and compute — transfer learning allows practitioners to leverage models pre-trained on large datasets, then adapt them to specific domains or tasks with far less data and time. Transfer learning is one of the most practically impactful ideas in modern AI, enabling high-performance models in domains where labeled data is scarce.
Remembering
- Pre-trained model — A model trained on a large, general dataset (e.g., ImageNet for vision, Wikipedia for NLP) whose weights serve as a starting point for a new task.
- Fine-tuning — The process of continuing training of a pre-trained model on task-specific data, adjusting weights to specialize behavior.
- Feature extraction — Using a pre-trained model as a fixed feature extractor: freeze all its weights and add a new trainable head for the new task.
- Domain adaptation — Adapting a model trained on one data distribution (source domain) to perform well on a different but related distribution (target domain).
- Source domain — The original domain on which the model was pre-trained.
- Target domain — The new domain or task to which you are transferring knowledge.
- Domain shift — The difference in statistical distribution between source and target domains.
- Frozen layers — Layers of a pre-trained model whose weights are not updated during fine-tuning.
- Trainable layers — Layers whose weights are updated during fine-tuning (typically the head and possibly last few blocks).
- Head — The task-specific output layer(s) added on top of a pre-trained backbone for a new task.
- Backbone — The main body of a pre-trained model (the feature extractor), as opposed to the task-specific head.
- ImageNet — A 1.2-million-image classification dataset; models pre-trained on ImageNet are the standard starting point for most computer vision tasks.
- BERT — A pre-trained transformer encoder; the standard starting point for many NLP fine-tuning tasks.
- Domain-adaptive pre-training — Additional pre-training on in-domain unlabeled data before task-specific fine-tuning.
- Zero-shot transfer — Applying a model trained on one task directly to a new task without any task-specific training (e.g., CLIP for zero-shot image classification).
Understanding
Transfer learning rests on a fundamental insight: lower layers of neural networks learn general features that are useful across many tasks, while higher layers learn task-specific features.
In a CNN trained on ImageNet:
- Early layers detect edges, colors, and simple textures — useful for any image task
- Middle layers detect shapes, patterns, and object parts
- Later layers detect high-level semantic features specific to ImageNet classes
When you transfer this model to a medical imaging task, the early and middle layer features are still useful (edges, textures, shapes are relevant in X-rays too), and only the final layers need to be adapted to the new task.
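This layer hierarchy is easy to inspect directly. Below is a minimal sketch using torchvision's feature-extraction utility (available in torchvision 0.11+) to tap activations from an early, middle, and late stage of a ResNet-50; the stage names and dummy input are illustrative choices, not part of the article above.
<syntaxhighlight lang="python">
import torch
import torchvision.models as models
from torchvision.models.feature_extraction import create_feature_extractor

model = models.resnet50(weights='IMAGENET1K_V2')

# Tap activations at an early, middle, and late stage of the backbone
extractor = create_feature_extractor(
    model,
    return_nodes={'layer1': 'early', 'layer2': 'middle', 'layer4': 'late'},
)

feats = extractor(torch.randn(1, 3, 224, 224))  # dummy image batch
for stage, tensor in feats.items():
    print(stage, tuple(tensor.shape))
# early  (1, 256, 56, 56)   -- edges/textures, high spatial resolution
# middle (1, 512, 28, 28)   -- parts/patterns
# late   (1, 2048, 7, 7)    -- semantic features, low resolution
</syntaxhighlight>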
Why not always train from scratch? Three reasons:
1. Data efficiency: You may only have 500 labeled medical images, not enough to train a good model from scratch. Starting from a pre-trained model gives you millions of examples' worth of feature learning for free.
2. Compute efficiency: Pre-training on ImageNet takes weeks on many GPUs; fine-tuning takes minutes to hours.
3. Better generalization: Pre-trained features are often more robust and generalizable than features learned from a small dataset.
When does transfer learning work best? When source and target domains share underlying structure. A model pre-trained on natural photos transfers well to satellite imagery (both are images), but transfers poorly to audio spectrograms (very different structure). The more similar the domains, the more layers you can freeze and the less fine-tuning data you need.
Zero-shot transfer is the most powerful form: a model like CLIP or GPT-4 trained on massive diverse data can perform tasks at inference time that it was never explicitly trained on — by virtue of having learned general-purpose representations.
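As a concrete illustration, here is a hedged sketch of zero-shot image classification with CLIP via the Hugging Face transformers library; the checkpoint name is a commonly used public one, and "example.jpg" is a placeholder for any RGB image.
<syntaxhighlight lang="python">
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Candidate classes expressed as natural-language prompts
labels = ["a photo of a cat", "a photo of a dog", "a photo of a truck"]
image = Image.open("example.jpg")  # placeholder path

inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
probs = model(**inputs).logits_per_image.softmax(dim=-1)
print(dict(zip(labels, probs[0].tolist())))
</syntaxhighlight>
No task-specific training happens here: the "classifier" is just a list of text prompts scored against the image in CLIP's shared embedding space.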
Applying
Fine-tuning a pre-trained ResNet for a custom image classification task:
<syntaxhighlight lang="python">
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

# Load pre-trained ResNet-50
model = models.resnet50(weights='IMAGENET1K_V2')

# === Strategy 1: Feature extraction (small dataset, <1k images) ===
# Freeze ALL pre-trained weights
for param in model.parameters():
    param.requires_grad = False

# Replace the final layer with a new head for 5 classes
model.fc = nn.Linear(model.fc.in_features, 5)
# Only model.fc parameters are now trainable

# === Strategy 2: Fine-tuning (larger dataset, 1k+ images) ===
# Unfreeze the last two residual blocks plus the head
for name, param in model.named_parameters():
    if 'layer4' in name or 'layer3' in name or 'fc' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

# Use discriminative learning rates: lower for pre-trained layers,
# higher for the newly initialized head
optimizer = torch.optim.Adam([
    {'params': model.layer3.parameters(), 'lr': 1e-5},  # small LR
    {'params': model.layer4.parameters(), 'lr': 1e-4},  # medium LR
    {'params': model.fc.parameters(), 'lr': 1e-3},      # large LR
])

# === Preprocessing must match pre-training ===
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),  # ImageNet stats
])

dataset = ImageFolder("data/my_dataset/", transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
</syntaxhighlight>
Strategy selection guide (a helper sketch encoding these heuristics follows the list):
- <100 examples per class → Feature extraction only; freeze all backbone layers; only train head
- 100–1000 examples → Unfreeze last 1–2 blocks + head; use low LR for backbone
- 1000–10k examples → Fine-tune the last half of the backbone; discriminative LRs
- >10k examples → Full fine-tuning or even train from scratch if very different domain
- No labeled data → Zero-shot (CLIP, GPT) or self-supervised domain adaptation
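The guide can be encoded as a rough heuristic in code. This is an illustrative sketch, treating all thresholds as labeled-example counts; the boundaries are rules of thumb from the list above, not hard rules.
<syntaxhighlight lang="python">
def choose_strategy(n_labeled: int, has_labels: bool = True) -> str:
    """Illustrative heuristic mirroring the selection guide above."""
    if not has_labels:
        return "zero-shot (CLIP, GPT) or self-supervised domain adaptation"
    if n_labeled < 100:
        return "feature extraction: freeze backbone, train head only"
    if n_labeled < 1_000:
        return "unfreeze last 1-2 blocks + head; low backbone LR"
    if n_labeled < 10_000:
        return "fine-tune last half of backbone; discriminative LRs"
    return "full fine-tuning (or from scratch if the domain is very different)"
</syntaxhighlight>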
Analyzing
| Strategy | Labeled Data Needed | Compute | Risk of Overfitting | Flexibility |
|---|---|---|---|---|
| Feature extraction | Very small (<500) | Very low | Very low | Low (head only) |
| Partial fine-tuning | Small (500–5k) | Low | Low | Medium |
| Full fine-tuning | Medium (5k+) | Medium | Medium | High |
| Train from scratch | Large (100k+) | Very high | Low (with enough data) | Maximum |
| Zero-shot transfer | None | None (inference only) | N/A | Moderate |
Failure modes and pitfalls:
- Negative transfer — When source and target domains are too dissimilar, pre-trained features actually hurt performance. Example: NLP models transfer poorly to genomic sequences; better to use domain-specific models.
- Catastrophic forgetting — Full fine-tuning on a small dataset can cause the model to "forget" general pre-trained knowledge. Mitigated by lower learning rates, fewer epochs, and progressive unfreezing.
- Data preprocessing mismatch — Pre-trained models expect a specific normalization. Using wrong mean/std values causes significant performance degradation even with correct weights.
- Label distribution shift — If the pre-training task had very different class balance than the target task, the model's feature priorities may be poorly aligned.
- Overconfident transfer — Assuming a pre-trained model from a similar domain will work without validation. Always run a baseline evaluation on target domain before assuming transferability.
Evaluating
Expert transfer learning practitioners evaluate along a specific set of dimensions:
Transfer gain: Compare fine-tuned model vs. training-from-scratch baseline on the same target data. The transfer gain is the performance improvement attributed to the pre-trained initialization. If fine-tuning doesn't beat scratch after enough epochs, transfer learning may not be helping.
Few-shot evaluation curves: Plot performance as a function of available labeled data. Transfer learning should show much better performance at low data regimes, converging toward the from-scratch baseline as data increases.
Feature analysis: Use linear probing (train only a linear classifier on frozen pre-trained features) to measure how much task-relevant information is already encoded in the pre-trained representations before any fine-tuning.
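A minimal linear-probe sketch, reusing the ResNet-50 backbone and `loader` from the Applying section and assuming scikit-learn is available:
<syntaxhighlight lang="python">
import torch
import torch.nn as nn
import torchvision.models as models
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

backbone = models.resnet50(weights='IMAGENET1K_V2')
backbone.fc = nn.Identity()  # drop the classification head
backbone.eval()

# Extract frozen features for the whole dataset
feats, labels = [], []
with torch.no_grad():
    for x, y in loader:  # `loader` from the Applying section
        feats.append(backbone(x))
        labels.append(y)
feats = torch.cat(feats).numpy()
labels = torch.cat(labels).numpy()

# Fit a linear classifier on held-out split to score the representations
X_tr, X_te, y_tr, y_te = train_test_split(feats, labels, test_size=0.2)
probe = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
print("linear-probe accuracy:", probe.score(X_te, y_te))
</syntaxhighlight>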
Domain proximity metrics: Measure the distribution distance (FID, A-distance, or H-divergence) between source and target domains. Higher distance predicts lower transfer benefit and may suggest negative transfer risk.
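One cheap proximity measure is the proxy A-distance: train a simple classifier to distinguish source features from target features; the easier that job is, the farther apart the domains are. A hedged sketch, assuming `src_feats` and `tgt_feats` are NumPy arrays of backbone features from each domain:
<syntaxhighlight lang="python">
import numpy as np
from sklearn.svm import LinearSVC
from sklearn.model_selection import train_test_split

# Label source features 0 and target features 1
X = np.concatenate([src_feats, tgt_feats])
y = np.concatenate([np.zeros(len(src_feats)), np.ones(len(tgt_feats))])
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.5)

# Proxy A-distance = 2 * (1 - 2 * domain-classifier error)
err = 1.0 - LinearSVC().fit(X_tr, y_tr).score(X_te, y_te)
pad = 2.0 * (1.0 - 2.0 * err)  # ~0: indistinguishable domains; ~2: fully separable
print("proxy A-distance:", pad)
</syntaxhighlight>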
Expert practitioners also perform layer-wise learning rate sweeps to identify which layers benefit most from fine-tuning, rather than applying a blanket strategy.
Creating
Designing a transfer learning pipeline for a new domain:
1. Source model selection
<syntaxhighlight lang="text">
Task modality:
├── Images → ImageNet pre-trained (ResNet, EfficientNet, ViT)
│   ├── Medical images → MedCLIP, RadDINO, BioViL
│   ├── Satellite images → SatMAE, Prithvi
│   └── Microscopy → CellViT, BioCLIP
├── Text → BERT, RoBERTa, DeBERTa (encoder); LLaMA (decoder)
│   ├── Scientific text → SciBERT, BioMedBERT
│   ├── Legal text → LegalBERT
│   └── Code → CodeBERT, StarCoder
└── Multimodal → CLIP, Flamingo, LLaVA
</syntaxhighlight>
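Swapping in a domain-specific source model is usually a one-line change with Hugging Face transformers. A sketch (the SciBERT checkpoint name is the one published by AllenAI; verify availability for your setup):
<syntaxhighlight lang="python">
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")

inputs = tokenizer("The protein binds to the receptor.", return_tensors="pt")
hidden = model(**inputs).last_hidden_state  # contextual features for a task head
</syntaxhighlight>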
2. Progressive fine-tuning schedule
<syntaxhighlight lang="text">
[Optional] Domain-adaptive pre-training (DAP):
  - Pre-train on in-domain unlabeled text/images (MLM or MAE)
  - No labels required; leverages domain corpus
        ↓
Phase 1: Feature extraction (head only, 5–10 epochs)
  - Verify features are useful for target task
  - Establishes good head initialization
        ↓
Phase 2: Gradual unfreezing
  - Unfreeze last block → train 5 epochs
  - Unfreeze next block → train 5 epochs
  - Continue until performance plateaus
        ↓
Phase 3: Full fine-tuning (optional, if data permits)
  - Very low learning rate (1e-5)
  - Early stopping on validation metric
</syntaxhighlight>
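Phase 2 can be expressed as a short loop. The sketch below assumes the ResNet-50 `model` and `loader` from the Applying section; `train_one_epoch` is a hypothetical helper for one standard training pass.
<syntaxhighlight lang="python">
import torch

# Unfreeze from the head backwards, one stage at a time
stages = ['fc', 'layer4', 'layer3', 'layer2']

for depth in range(1, len(stages) + 1):
    active = stages[:depth]
    for name, param in model.named_parameters():
        param.requires_grad = any(stage in name for stage in active)

    # Re-create the optimizer over the currently trainable parameters
    optimizer = torch.optim.Adam(
        (p for p in model.parameters() if p.requires_grad), lr=1e-4)

    for _ in range(5):
        train_one_epoch(model, loader, optimizer)  # hypothetical helper
</syntaxhighlight>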
3. Preventing catastrophic forgetting
- Elastic Weight Consolidation (EWC): penalize updates to weights important for source task
- Learning Without Forgetting (LwF): use original model as soft-label teacher
- For LLMs: KL penalty on fine-tuned vs. original model's output distribution (the RLHF approach)
- Replay: mix small fraction of source-domain data into fine-tuning batches
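As one concrete example, an EWC-style penalty can be added to the fine-tuning loss in a few lines. A hedged sketch: `fisher` is assumed to be a precomputed dict mapping parameter names to Fisher-importance tensors (typically estimated from squared gradients on source-task data), and the weighting `lam` is a tunable hyperparameter.
<syntaxhighlight lang="python">
import torch

# Snapshot of the pre-trained weights before fine-tuning begins
ref_params = {n: p.detach().clone() for n, p in model.named_parameters()}

def ewc_penalty(model, fisher, lam=0.4):
    """Quadratic penalty on drifting from important pre-trained weights."""
    penalty = 0.0
    for name, param in model.named_parameters():
        if name in fisher:
            penalty = penalty + (fisher[name] * (param - ref_params[name]) ** 2).sum()
    return lam * penalty

# Inside the fine-tuning loop (hypothetical variables):
#   loss = task_loss + ewc_penalty(model, fisher)
</syntaxhighlight>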