# MoCo Implementation in PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
class MoCo(nn.Module):
    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07):
        """
        MoCo: Momentum Contrast for Unsupervised Visual Representation Learning
        Args:
            base_encoder: backbone CNN architecture (ResNet)
            dim: feature dimension for contrastive learning
            K: queue size (number of negative samples)
            m: momentum coefficient for key encoder update
            T: temperature parameter for InfoNCE loss
        """
        super(MoCo, self).__init__()
        self.K = K
        self.m = m
        self.T = T

        # Create query and key encoders
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        # Initialize key encoder parameters with query encoder
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False  # Key encoder is not updated by gradient

        # Create the queue for storing keys
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        # Queue pointer for circular buffer
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """
        Update the queue by dequeuing old keys and enqueuing new keys
        """
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # queue size must be divisible by batch size
        self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1)
        ptr = (ptr + batch_size) % self.K  # Move pointer
        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle for distributed training
        This prevents information leakage between query and key encoders
        """
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]
        num_gpus = batch_size_all // batch_size_this
        idx_shuffle = torch.randperm(batch_size_all).cuda()
        torch.distributed.broadcast(idx_shuffle, src=0)
        idx_unshuffle = torch.argsort(idx_shuffle)
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle for distributed training
        """
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]
        num_gpus = batch_size_all // batch_size_this
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
        return x_gather[idx_this]

    def forward(self, im_q, im_k):
        """
        Forward pass for MoCo
        Args:
            im_q: query images
            im_k: key images
        Returns:
            logits: logits for InfoNCE loss
            labels: labels for InfoNCE loss
        """
        # Compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # Compute key features
        with torch.no_grad():  # No gradient for key encoder
            self._momentum_update_key_encoder()  # Update key encoder
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)
            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # Compute logits: positive logits Nx1, negative logits NxK
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        logits = torch.cat([l_pos, l_neg], dim=1)
        logits /= self.T
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()  # positive key sits at index 0
        self._dequeue_and_enqueue(k)
        return logits, labels


# Utility function for distributed training
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors
    """
    tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    output = torch.cat(tensors_gather, dim=0)
    return output
## Momentum Contrast (MoCo)

### Introduction
MoCo (Momentum Contrast) is a self-supervised learning framework introduced by Facebook AI Research in 2019. It addresses the fundamental challenge of learning visual representations without labeled data by treating contrastive learning as a dictionary look-up problem.
The key insight behind MoCo is that contrastive learning can be viewed as training an encoder to perform a dictionary look-up task, where we want to match a query representation with its corresponding key.
Read the original paper: He, Kaiming, et al. "Momentum Contrast for Unsupervised Visual Representation Learning" (2019). arXiv:1911.05722.
## Core Concepts of MoCo

### 1. Dictionary Look-up Perspective

MoCo reframes contrastive learning as a dictionary look-up task:
- Query (q): An encoded representation of an augmented image.
- Positive Key (k⁺): The encoded representation of a different augmentation of the same image as the query.
- Negative Keys (k⁻): Encoded representations of different images.
- Dictionary: A dynamic set of keys (representations) used to compare against the query. The goal is to bring the query closer to its positive key while pushing it away from negative keys.
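To make the look-up intuition concrete, here is a small self-contained sketch (not part of the MoCo code above; `query`, `pos_key`, and `neg_keys` are made-up stand-ins for encoder outputs) showing that, after L2 normalization, the positive key wins the dot-product comparison:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim = 128
query = F.normalize(torch.randn(dim), dim=0)                     # encoded query q
pos_key = F.normalize(query + 0.1 * torch.randn(dim), dim=0)     # k+: another view of the same image
neg_keys = F.normalize(torch.randn(8, dim), dim=1)               # k-: other images

dictionary = torch.cat([pos_key.unsqueeze(0), neg_keys], dim=0)  # positive key stored first
similarities = dictionary @ query                                # cosine similarities after L2 norm
print(similarities.argmax().item())  # 0: the query "looks up" its positive key
```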
### 2. Queue Mechanism
MoCo introduces a queue-based dictionary to efficiently manage a large set of negative samples:
- A FIFO queue stores encoded representations (keys) from previous mini-batches.
- As new keys are added to the queue, the oldest keys are removed.
- This design ensures:
  - A large and consistent dictionary size, independent of the mini-batch size.
  - Better utilization of past samples, improving contrastive learning.
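To see the FIFO queue in action, here is a minimal standalone sketch of the same circular-buffer update that `_dequeue_and_enqueue` performs in the implementation above (tiny `dim`, `K`, and batch size are chosen only to keep the printout readable):

```python
import torch

dim, K, batch_size = 4, 8, 2           # tiny sizes for illustration
queue = torch.zeros(dim, K)            # one stored key per column
ptr = 0                                # circular-buffer pointer

for step in range(6):                  # six mini-batches of new keys arrive
    new_keys = torch.full((batch_size, dim), float(step + 1))
    queue[:, ptr:ptr + batch_size] = new_keys.T   # enqueue: overwrite the oldest slots
    ptr = (ptr + batch_size) % K                  # dequeue happens implicitly via wrap-around
    print(f"step {step}: ptr={ptr}, queue row 0 = {queue[0].tolist()}")
```

The printout shows the oldest keys being overwritten as the pointer wraps around, which is exactly how the real queue keeps a fixed number of the most recent keys.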
### 3. Momentum Update
Instead of training both encoders via backpropagation, MoCo stabilizes learning with a momentum update for the key encoder:
- Query Encoder (`f_q`): Updated normally using gradient descent.
- Key Encoder (`f_k`): Updated as an exponential moving average of the query encoder:
  \[ \theta_k \leftarrow m \cdot \theta_k + (1 - m) \cdot \theta_q \]
- Momentum Coefficient (`m`): Typically set to 0.999, ensuring slow, stable updates.
This strategy helps maintain consistent representations for keys, reducing noise in the contrastive learning process.
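A tiny numeric sketch (scalars standing in for real weight tensors, purely illustrative) shows how m = 0.999 makes the key encoder trail the query encoder slowly and smoothly:

```python
m = 0.999
theta_q, theta_k = 0.0, 0.0

for step in range(1000):
    theta_q += 0.01                              # pretend gradient descent moves f_q each step
    theta_k = m * theta_k + (1 - m) * theta_q    # momentum (EMA) update of f_k

print(f"theta_q = {theta_q:.2f}, theta_k = {theta_k:.2f}")
# theta_k lags far behind theta_q and changes only gradually, which keeps keys
# encoded at different times (and stored in the queue) mutually consistent
```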
## MoCo Architecture and Algorithm

### Architecture Components
- Query Encoder (`f_q`): A CNN (typically ResNet) that encodes query images.
- Key Encoder (`f_k`): A CNN with identical architecture to `f_q`, updated via momentum.
- Queue: A memory bank storing encoded keys from previous batches.
- Projection Head: An MLP that projects features into a lower-dimensional embedding space.
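As a shape-level sanity check of how these components fit together (a self-contained sketch; random tensors stand in for the encoder outputs, with `dim` and `K` matching the defaults used in the implementation above):

```python
import torch
import torch.nn.functional as F

N, dim, K = 32, 128, 65536                   # batch size, embedding dim, queue size

# f_q / f_k map an image batch (N, 3, 224, 224) to (N, dim) embeddings via the
# backbone + projection head; here the outputs are faked to focus on shapes.
q = F.normalize(torch.randn(N, dim), dim=1)  # query embeddings
k = F.normalize(torch.randn(N, dim), dim=1)  # positive key embeddings
queue = torch.randn(dim, K)                  # memory bank: one stored key per column

l_pos = (q * k).sum(dim=1, keepdim=True)     # (N, 1): each query vs its positive key
l_neg = q @ queue                            # (N, K): each query vs every queued key
logits = torch.cat([l_pos, l_neg], dim=1)
print(logits.shape)                          # torch.Size([32, 65537]) = (N, 1 + K)
```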
### Training Process
1. Sample a mini-batch of `N` images.
2. Apply data augmentation to each image to create query and key views.
3. Encode queries using `f_q` and keys using `f_k`.
4. Compute the contrastive loss between each query, its positive key, and all keys in the queue.
5. Update the query encoder (`f_q`) via gradient descent.
6. Update the key encoder (`f_k`) via momentum update:
   \[ \theta_k \leftarrow m \cdot \theta_k + (1 - m) \cdot \theta_q \]
7. Update the queue by enqueuing the new keys and dequeuing the oldest keys.
### InfoNCE Loss
MoCo uses the InfoNCE (Noise Contrastive Estimation) loss:
\[ \mathcal{L}_q = -\log \left( \frac{\exp(q \cdot k^+ / \tau)}{\sum_i \exp(q \cdot k_i / \tau)} \right) \]
Where:
- \(q\): Query representation
- \(k^+\): Positive key representation
- \(k_i\): All keys in the dictionary (including positive and negatives)
- \(\tau\): Temperature parameter
- \(\cdot\): Dot product (cosine similarity after L2 normalization)
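To tie the formula to the code: with the positive key placed in column 0 of the logits, InfoNCE is exactly a (1 + K)-way cross-entropy with target class 0, which is how the loss is usually computed in practice. A small self-contained check (random tensors stand in for real embeddings):

```python
import torch
import torch.nn.functional as F

N, dim, K, tau = 8, 128, 16, 0.07
q = F.normalize(torch.randn(N, dim), dim=1)       # queries
k_pos = F.normalize(torch.randn(N, dim), dim=1)   # positive keys
queue = F.normalize(torch.randn(dim, K), dim=0)   # negative keys (the queue)

l_pos = (q * k_pos).sum(dim=1, keepdim=True)      # q . k+
l_neg = q @ queue                                 # q . k_i for every queued key
logits = torch.cat([l_pos, l_neg], dim=1) / tau

# explicit InfoNCE: -log softmax probability of the positive (index 0)
loss_explicit = (-torch.log_softmax(logits, dim=1)[:, 0]).mean()
# equivalent cross-entropy formulation used in practice
loss_ce = F.cross_entropy(logits, torch.zeros(N, dtype=torch.long))

print(torch.allclose(loss_explicit, loss_ce))     # True
```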
## Training Loop Implementation
# Example usage
def main():
    import torchvision.models as models

    # Create ResNet-50 base encoder with a 2-layer projection head
    def resnet50_encoder(num_classes=128):
        model = models.resnet50(pretrained=False)
        model.fc = nn.Sequential(
            nn.Linear(model.fc.in_features, model.fc.in_features),
            nn.ReLU(),
            nn.Linear(model.fc.in_features, num_classes)
        )
        return model

    # Initialize MoCo model
    model = MoCo(resnet50_encoder, dim=128, K=65536, m=0.999, T=0.07)

    # Setup optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9, weight_decay=1e-4)

    # Training loop (train_loader, device, and train_moco are assumed to be defined elsewhere)
    for epoch in range(200):
        train_moco(model, train_loader, optimizer, epoch, device)
        # Add validation and checkpointing as needed


# Data augmentation typically used with MoCo
def get_moco_augmentation():
    from torchvision import transforms
    # MoCo v1 augmentation
    augmentation = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
        transforms.RandomGrayscale(p=0.2),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return augmentation
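`main()` above refers to `train_moco`, `train_loader`, and `device` without defining them. Here is a hedged sketch of what the missing `train_moco` might look like, assuming the data loader yields two augmented views per image (e.g. by applying the augmentation above twice):

```python
def train_moco(model, train_loader, optimizer, epoch, device):
    """One epoch of MoCo pre-training (illustrative sketch, not the official recipe)."""
    model.train()
    running_loss = 0.0
    for images, _ in train_loader:                    # dataset labels are ignored (self-supervised)
        im_q, im_k = images[0].to(device), images[1].to(device)  # two augmented views

        logits, labels = model(im_q, im_k)            # momentum + queue updates happen inside forward
        loss = F.cross_entropy(logits, labels)        # InfoNCE as cross-entropy, positive is class 0

        optimizer.zero_grad()
        loss.backward()                               # gradients flow only through the query encoder
        optimizer.step()

        running_loss += loss.item()
    print(f"epoch {epoch}: avg loss {running_loss / max(len(train_loader), 1):.4f}")
```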
## MoCo vs SimCLR

| Aspect | MoCo | SimCLR |
|---|---|---|
| 1. Dictionary Management | Queue-based dictionary with a large, consistent size; independent of batch size; memory efficient | Uses within-batch negatives only; requires large batch sizes; memory intensive |
| 2. Encoder Architecture | Two encoders: query (`f_q`) and key (`f_k`); key encoder updated via momentum, \( \theta_k \leftarrow m \cdot \theta_k + (1 - m) \cdot \theta_q \); asymmetric design | Single encoder for all samples; symmetric design; no momentum update |
| 3. Training Dynamics | Stable training thanks to the momentum update; diverse negatives via the queue; robust to batch size | Requires large batches; all samples updated together; more sensitive to batch size |
| 4. Computational Requirements | Lower memory footprint; efficient for small batches; works on modest hardware | High memory requirements; typically needs multiple GPUs; heavy batch computations |
| 5. Augmentation Strategy | Started with simple augmentations; MoCo v2 adopts stronger ones; less dependent on augmentation strength | Strong augmentations are essential; uses heavy transforms (blur, color distortion); performance depends on augmentation strength |
## Summary and Practical Recommendations

### When to Choose MoCo
- Limited computational resources
- Suitable for academic research or prototyping environments
- Preferred when stable training is important
- Flexible with varying batch sizes
- Enables faster experimentation cycles
### When to Choose SimCLR
- Abundant computational resources
- Ideal for production environments with large-scale data
- Needed when maximum performance is a priority
- Well-suited for large-scale industrial applications
- Works best when strong augmentation pipelines are already established
### Key Takeaways
- MoCo democratizes contrastive learning by making it accessible with limited resources
- SimCLR achieves strong performance but requires significant computational investment
- The two methods are complementary, serving different use cases
- MoCo’s queue mechanism is an efficient solution to the negative sampling problem
- SimCLR’s simplicity makes it easier to understand and adapt to specific applications
The choice between MoCo and SimCLR depends on your available resources and performance needs. MoCo strikes a practical balance between efficiency and effectiveness, while SimCLR excels when compute and scale are not limiting factors.