Другое

Как остановить коллапс эмбеддингов SimSiam ViT на CUB-200

Узнайте причины коллапса эмбеддингов SimSiam при ViT на CUB-200 и получите проверенные решения: аугментации, модификации архитектуры и альтернативные подходы.

Почему эмбеддинги SimSiam с ViT на датасете CUB‑200‑2011 коллапсируют?

Я реализую SimSiam с Vision Transformer (ViT) в качестве backbone на датасете CUB‑200‑2011, но во время обучения эмбеддинги коллапсируют в одну сторону, несмотря на использование stop‑gradient. Вот что я наблюдаю в первых нескольких эпохах:

Epoch 0:

  • Loss = -0.12 | Collapse Level: 0.46 / 1.00
  • Cosine similarity (off‑diagonal): mean=0.035835, std=0.318266, min=-0.780536, max=0.997578
  • Top 10 eigenvalues: [51.52014, 10.083374, 7.2546287, 5.572749, 4.3434677, 3.533019, 3.0718656, 2.5875258, 2.0254238, 1.9101429]
  • Embedding metrics: N=5794, D=128, norm_mean=11.402124404907227, norm_std=2.796746253967285, norm_min=6.187736511230469, norm_max=22.01985740661621
  • Recall: 0.3715912997722626 | Recall_b: 0.6962375044822693

Epoch 2:

  • Loss = -0.91 | Collapse Level: 0.84 / 1.00
  • Cosine similarity (off‑diagonal): mean=1.000000, std=0.000006, min=0.999890, max=1.000000
  • Top 10 eigenvalues: [1.6440651e+02, 1.3151270e-01, 8.6707681e-02, 6.4878970e-02, 5.0928112e-02, 3.0504635e-02, 1.9978724e-02, 1.4542857e-02, 7.8499522e-03, 6.8454165e-03]
  • Embedding metrics: N=5794, D=128, norm_mean=515.996826171875, norm_std=12.814229965209961, norm_min=489.2132263183594, norm_max=591.9946899414062
  • Recall: 0.005177770275622606 | Recall_b: 0.012599240988492966

Implementation Details

Data Augmentations

python
transform = T.Compose([
    T.RandomResizedCrop(224, scale=(0.2, 1.)),
    T.RandomApply([
        T.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
    ], p=0.8),
    T.RandomGrayscale(p=0.2),
    T.RandomApply([T.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))], p=0.5),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_transform = lambda x: (transform(x), transform(x))

SimSiam Implementation

python
class SimSiam(nn.Module):
    def __init__(self, encoder, head_dim=128, predictor_hidden=64):
        super().__init__()
        self.encoder = encoder
        self.head_dim = head_dim
        
        # Collapse avoidance requires a non-trivial projector
        prev_dim = self.encoder.model.backbone.num_features
        self.encoder.model.head = nn.Sequential(
            nn.Linear(prev_dim, prev_dim, bias=False),
            nn.BatchNorm1d(prev_dim),
            nn.ReLU(inplace=True),  # first layer
            nn.Linear(prev_dim, prev_dim, bias=False),
            nn.BatchNorm1d(prev_dim),
            nn.ReLU(inplace=True),  # second layer
            nn.Linear(prev_dim, head_dim, bias=True),
            nn.BatchNorm1d(head_dim, affine=False)  # output layer
        )
        
        self.encoder.model.head[6].bias.requires_grad = False  # hack: not use bias as it is followed by BN

        self.predictor = EuclideanPredictor(dim=head_dim, hidden_dim=predictor_hidden)
        
        # ensure predictor on same device as encoder
        try:
            self.device = next(self.encoder.parameters()).device
        except StopIteration:
            self.device = torch.device('cpu')
        self.predictor.to(self.device)

    def forward(self, x1, x2):
        # Projector outputs in hyperbolic/spherical space
        z1 = self.encoder(x1) 
        z2 = self.encoder(x2)
        
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)
        
        # Predictor (symmetric)
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)

        return p1, p2, z1.detach(), z2.detach()

    def loss(self, p1, p2, z1, z2):
        # safety: ensure targets are stop-grad
        z1 = z1.detach()
        z2 = z2.detach()
        
        assert (not z1.requires_grad) and (not z2.requires_grad), "Stop gradient is not implemented!"

        sign = -1.0
        # Explicitly normalization
        p1n = F.normalize(p1, dim=1)
        p2n = F.normalize(p2, dim=1)
        z1n = F.normalize(z1, dim=1)
        z2n = F.normalize(z2, dim=1)
        criterion = nn.CosineSimilarity(dim=1)
        
        simsiam_loss = sign * 0.5 * (criterion(p1n, z2n).mean() + criterion(p2n, z1n).mean())
        return simsiam_loss

Encoder Setup

python
backbone = timm.create_model(encoder_version, pretrained=True, cache_dir="./cached_models/") + NormLayer

Troubleshooting Steps Already Taken

  1. Verified that stop-gradient is working (p1.requires_grad=True, z1.requires_grad=False)
  2. Confirmed predictor output matches paper recommendations (no BN/ReLU on output)
  3. Checked that gradient norms are reasonable (except occasionally spiking)

Вопрос

Почему эмбеддинги коллапсируют в моей реализации SimSiam? Может ли это быть связано с датасетом CUB‑200‑2011 или стратегией аугментации? Я проверил аугментации, и они действительно создают разные представления одного изображения.

Понимание феномена коллапса

Согласно оригинальному исследованию SimSiam, коллапсирующие решения действительно существуют для данной функции потерь и структуры, однако операция stop‑gradient играет ключевую роль в их предотвращении. Тем не менее, ваш случай показывает, что stop‑gradient сам по себе недостаточен на датасете CUB‑200‑2011.

Коллапс, который вы наблюдаете, характеризуется:

  • Сходимость косинусных сходств: сходства по диагонали приближаются к 1.0 (идеальная похожесть)
  • Доминирование собственных значений: первое собственное значение захватывает большую часть дисперсии (164.4 против 0.131 для второго)
  • Улучшение потерь без обучения: потери улучшаются (‑0.12 → ‑0.91), но recall резко падает

Это указывает на то, что модель находит тривиальное решение, при котором все эмбеддинги указывают в одном направлении, не предоставляя различительной силы.

Проблемы датасета CUB‑200‑2011

Датасет CUB‑200‑2011 предъявляет уникальные требования, которые усиливают коллапс:

  1. Финогранулярная природа: 200 видов птиц, которые визуально очень похожи, требуют обучения тонким различиям
  2. Малая внутриклассовая вариация: изображения одного вида имеют очень похожие признаки
  3. Ограниченный размер датасета: 6 000 обучающих изображений (в отличие от 1.2 М ImageNet) затрудняет обучение разнообразных представлений

Как отмечено в исследованиях по финогранулярной классификации, этот датасет требует более сложного обучения признаков, чем стандартные наборы данных.

Критические проблемы в вашей реализации

1. Недостаточная разнообразность аугментаций

Текущая стратегия аугментации не обеспечивает достаточного разнообразия для финогранулярного обучения:

python
transform = T.Compose([
    T.RandomResizedCrop(224, scale=(0.2, 1.)),
    T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    T.RandomGrayscale(p=0.2),
    T.RandomApply([T.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))], p=0.5),
    T.RandomHorizontalFlip(),
    # Отсутствует: более сильные аугментации, необходимые для финогранулярного обучения
])

2. Проблемы с архитектурой предиктора

Ваш предиктор может быть слишком простым или неправильно инициализированным:

python
class EuclideanPredictor(nn.Module):
    # Возможно, потребуется более сложная архитектура
    # или другая стратегия инициализации

3. Нестабильность скорости обучения

Быстрый коллапс указывает на проблемы с learning‑rate, которые не фиксируются вашими проверками нормы градиентов.

Проверенные решения для предотвращения коллапса

1. Усиление стратегии аугментаций

Согласно best practices SimSiam, необходимо более агрессивные аугментации:

python
transform = T.Compose([
    T.RandomResizedCrop(224, scale=(0.08, 1.)),  # Меньший масштаб для большего разнообразия
    T.RandomApply([
        T.ColorJitter(0.8, 0.8, 0.8, 0.2)  # Усиленный цветовой джиттер
    ], p=0.8),
    T.RandomGrayscale(p=0.2),
    T.RandomApply([T.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))], p=0.8),  # Увеличена вероятность
    T.RandomHorizontalFlip(),
    T.RandomAffine(degrees=30, translate=(0.1, 0.1), scale=(0.9, 1.1)),  # Добавлены геометрические трансформации
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

2. Реализация правильного отрицательного сэмплинга

В отличие от оригинального SimSiam, исследования показывают, что на сложных датасетах, таких как CUB‑200‑2011, некоторая форма отрицательного сэмплинга может предотвратить коллапс:

python
class ImprovedSimSiam(nn.Module):
    def __init__(self, encoder, head_dim=128, predictor_hidden=64, num_negatives=32):
        super().__init__()
        # ... существующий код ...
        self.num_negatives = num_negatives
        
    def loss(self, p1, p2, z1, z2):
        # Оригинальная функция потерь SimSiam
        original_loss = self.original_loss_function(p1, p2, z1, z2)
        
        # Добавляем отрицательный сэмплинг для CUB‑200‑2011
        batch_size = z1.shape[0]
        # Сэмплируем отрицательные примеры из других батчей
        negatives = self.sample_negatives(z1, z2)
        negative_loss = self.compute_negative_loss(p1, p2, negatives)
        
        return original_loss + 0.1 * negative_loss  # Небольшой вес для отрицательных

3. Корректировка расписания learning‑rate

Используйте warmup и cosine annealing:

python
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

Рекомендованные модификации архитектуры

1. Улучшение архитектуры проекторa

Согласно исследованиям по transfer learning, вашему проектору нужна большая емкость:

python
self.encoder.model.head = nn.Sequential(
    nn.Linear(prev_dim, prev_dim * 2, bias=False),
    nn.BatchNorm1d(prev_dim * 2),
    nn.ReLU(inplace=True),
    nn.Linear(prev_dim * 2, prev_dim * 2, bias=False),
    nn.BatchNorm1d(prev_dim * 2),
    nn.ReLU(inplace=True),
    nn.Linear(prev_dim * 2, head_dim, bias=True),
    nn.BatchNorm1d(head_dim, affine=False)
)

2. Изменение структуры предиктора

Сделайте предиктор асимметричным и более сложным:

python
class ImprovedPredictor(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, dim, bias=False)  # Нет BN/ReLU на выходе
        )
        
    def forward(self, x):
        return self.net(x)

Корректировки стратегии обучения

1. Используйте больший размер батча

Маленькие батчи могут вызывать нестабильность. Попробуйте размеры 256‑512:

python
# В вашем загрузчике данных
batch_size = 256  # Вместо текущего меньшего размера

2. Реализуйте обрезку градиентов

python
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

3. Мониторинг разнообразия эмбеддингов

Добавьте эти метрики в мониторинг:

python
def compute_embedding_diversity(embeddings):
    # Вычисляем попарные сходства
    similarities = F.cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2)
    
    # Убираем диагональ (само‑сходства)
    off_diag_similarities = similarities[~torch.eye(similarities.shape[0], dtype=torch.bool)]
    
    # Вычисляем статистику
    mean_sim = off_diag_similarities.mean()
    std_sim = off_diag_similarities.std()
    
    # Вычисляем отношение собственных значений (индикатор коллапса)
    cov_matrix = torch.cov(embeddings.T)
    eigenvalues = torch.linalg.eigvals(cov_matrix).real
    eigenvalues = eigenvalues[eigenvalues > 1e-6]  # Убираем почти нулевые значения
    if len(eigenvalues) > 1:
        eigenvalue_ratio = eigenvalues[0] / eigenvalues[1]
    else:
        eigenvalue_ratio = float('inf')
    
    return {
        'mean_similarity': mean_sim.item(),
        'std_similarity': std_sim.item(),
        'eigenvalue_ratio': eigenvalue_ratio.item(),
        'collapse_level': min(1.0, eigenvalue_ratio / 100.0)
    }

Альтернативные подходы

Если SimSiam продолжает коллапсировать на CUB‑200‑2011, рассмотрите следующие варианты:

1. Используйте MoCo или BYOL

Эти методы доказали свою эффективность на финогранулярных датасетах:

python
# Пример реализации MoCo
class MoCo(nn.Module):
    def __init__(self, dim=128, K=4096, m=0.999, T=0.07):
        super().__init__()
        self.K = K
        self.m = m
        self.T = T
        
        # Создаём энкодеры
        self.encoder_q = self.create_encoder()
        self.encoder_k = self.create_encoder()
        
        # Инициализируем ключевой энкодер
        for param in self.encoder_k.parameters():
            param.requires_grad = False
    
    def forward(self, q, k):
        q = self.encoder_q(q)
        k = self.encoder_k(k)
        k = k.detach()
        
        # Вычисляем логиты
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        
        logits = torch.cat([l_pos, l_neg], dim=1)
        logits /= self.T
        
        return logits

2. Супервизированная дообучаемость

Поскольку CUB‑200‑2011 имеет метки, рассмотрите полу‑супервизированные подходы:

python
class SemiSupervisedSimSiam(nn.Module):
    def __init__(self, encoder, num_classes=200):
        super().__init__()
        self.encoder = encoder
        self.sim_siam = SimSiam(encoder)
        self.classifier = nn.Linear(encoder.model.backbone.num_features, num_classes)
        
    def forward(self, x, labels=None):
        # Потеря само‑обучения
        p1, p2, z1, z2 = self.sim_siam(x, x)
        ssl_loss = self.sim_siam.loss(p1, p2, z1, z2)
        
        if labels is not None:
            # Супервизированная потеря
            features = self.encoder(x)
            cls_loss = F.cross_entropy(self.classifier(features), labels)
            total_loss = ssl_loss + 0.1 * cls_loss
            return total_loss, features
        else:
            return ssl_loss, z1

Итоги рекомендаций

  1. Аугментации: Усильте данные более разнообразными трансформациями
  2. Архитектура: Увеличьте емкость проекторa и улучшите дизайн предиктора
  3. Обучение: Используйте большие батчи, корректное расписание learning‑rate и обрезку градиентов
  4. Мониторинг: Отслеживайте метрики разнообразия эмбеддингов более подробно
  5. Альтернативы: Рассмотрите MoCo/BYOL или полу‑супервизированные подходы, если SimSiam продолжает коллапсировать

Коллапс, который вы наблюдаете, особенно труден на CUB‑200‑2011 из‑за его финогранулярной природы. Внедряйте эти изменения последовательно и внимательно отслеживайте влияние каждого шага на метрики коллапса.

Авторы
Проверено модерацией
Модерация