Другое

Как применить приоритетные веса в PyTorch PER

Руководство по применению приоритетных весов из TensorDictPrioritizedReplayBuffer в PyTorch. Узнайте, автоматическое применение не происходит и как исправить смещение PER.

При использовании TensorDictPrioritizedReplayBuffer из PyTorch, нужно ли вручную применять веса приоритетов к функции потерь, или они автоматически учитываются? В PER (Prioritized Experience Replay) мы используем параметр Beta для вычисления весов, которые компенсируют смещение, вызванное PER. В TensorDictPrioritizedReplayBuffer эти веса, как кажется, хранятся в поле _weight выходного TensorDict. При использовании функций потерь PyTorch, таких как DDPGLoss, они автоматически учитывают поле _weight, если в TensorDict его присутствует? Если нет, как следует вручную применять веса приоритетов к расчёту потерь?

С TensorDictPrioritizedReplayBuffer приоритетные веса не применяются автоматически к функциям потерь, таким как DDPGLoss. Поле _weight, содержащее приоритетные веса, автоматически включается в сэмпл TensorDict, но их необходимо вручную умножить на значение потерь, чтобы корректно компенсировать смещение, вызванное Prioritized Experience Replay (PER).

Содержание

Как работают приоритетные веса в TensorDictPrioritizedReplayBuffer

TensorDictPrioritizedReplayBuffer автоматически вычисляет и сохраняет приоритетные веса в поле _weight объектов TensorDict, которые возвращаются при сэмплинге. Согласно документации PyTorch RL, при сэмплинге из буфера возвращаемый TensorDict содержит:

python
TensorDict(
    fields={
        _weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
        a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([5, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)
            },
            batch_size=torch.Size([5]), device=cpu, is_shared=False
        ),
        index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False)
    },
    batch_size=torch.Size([5]), device=cpu, is_shared=False
)

Поле _weight содержит приоритетные веса, вычисленные с использованием параметра beta, который вы указываете при создании буфера. Эти веса предназначены для компенсации смещения, вызванного приоритетным сэмплингом.

Ручное применение приоритетных весов

Поскольку функции потерь, такие как DDPGLoss, не используют поле _weight автоматически, вам необходимо вручную применить эти веса к вычислению потерь. Вот как это сделать:

Шаг 1: Сэмплирование из буфера

python
from torchrl.data import TensorDictPrioritizedReplayBuffer

# Создание буфера с параметрами alpha и beta
rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, storage=LazyTensorStorage(1000), batch_size=32)

# Сэмплирование из буфера
sample = rb.sample(32)

Шаг 2: Вычисление потерь с приоритетными весами

python
from torchrl.objectives import DDPGLoss

# Инициализация функции потерь DDPG
loss_fn = DDPGLoss()

# Вычисление потерь
losses = loss_fn(sample)

# Извлечение приоритетных весов
priority_weights = sample.get("_weight")

# Применение весов к потере
weighted_loss = losses["loss_value"] * priority_weights

Шаг 3: Обратное распространение взвешенной потери

python
optimizer.zero_grad()
weighted_loss.mean().backward()
optimizer.step()

Примеры реализации

Ниже приведён полный пример, демонстрирующий правильную интеграцию:

python
import torch
from torchrl.data import TensorDictPrioritizedReplayBuffer, LazyTensorStorage
from torchrl.objectives import DDPGLoss

# Инициализация буфера воспроизведения
rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.6,  # Сила приоритезации
    beta=0.4,   # Коррекция смещения при важностном сэмплинге
    storage=LazyTensorStorage(10000),
    batch_size=64,
    priority_key="td_error"  # Ключ для TD‑ошибок
)

# Инициализация функции потерь DDPG
loss_fn = DDPGLoss()

# Цикл обучения
for epoch in range(num_epochs):
    # Сэмплирование из буфера воспроизведения
    batch = rb.sample(64)
    
    # Вычисление потерь
    losses = loss_fn(batch)
    
    # Извлечение приоритетных весов
    weights = batch.get("_weight")
    
    # Применение весов важностного сэмплинга
    weighted_loss = losses["loss_value"] * weights
    
    # Обратное распространение
    optimizer.zero_grad()
    weighted_loss.mean().backward()
    optimizer.step()
    
    # Обновление приоритетов (см. следующий раздел)
    update_priorities(rb, batch)

Автоматическое обновление приоритетов

TensorDictPrioritizedReplayBuffer упрощает процесс обновления приоритетов. Согласно документации PyTorch, функция DDPGLoss автоматически записывает ключ "td_error", который можно использовать буферами с приоритетом:

python
# Это обновит приоритеты на основе TD‑ошибок
rb.update_tensordict_priority(sample)

Метод update_tensordict_priority использует поле "td_error" из TensorDict для обновления приоритетов выбранных опытов. Это создаёт бесшовный рабочий процесс:

  1. Вы сэмплируете опыты с приоритетными весами
  2. Вы вычисляете потери и распространяете взвешенную потерю
  3. Вы обновляете приоритеты на основе новых TD‑ошибок

Лучшие практики

Анимация параметра beta

Постепенно увеличивайте beta от 0.4 до 1.0 в течение обучения, чтобы перейти от важностного сэмплинга к равномерному сэмплингу:

python
def get_beta(epoch, total_epochs, start_beta=0.4, end_beta=1.0):
    return start_beta + (end_beta - start_beta) * (epoch / total_epochs)

Обрезка весов

Обрежьте экстремальные веса, чтобы избежать численной нестабильности:

python
weights = torch.clamp(weights, 0.1, 10.0)

Многошаговые TD‑ошибки

Для более стабильного обновления приоритетов рассмотрите использование многошаговых TD‑ошибок:

python
# Сохраняйте n‑шаговые возвраты в буфере воспроизведения
# Используйте их для обновления приоритетов вместо 1‑шаговых TD‑ошибок

Отладка приоритетных весов

Отслеживайте приоритетные веса, чтобы убедиться, что они работают корректно:

python
# Печать статистики приоритетных весов
print(f"Priority weights - mean: {weights.mean():.3f}, std: {weights.std():.3f}, min: {weights.min():.3f}, max: {weights.max():.3f}")

Показатели производительности

  • Поле _weight автоматически включается во все сэмплы из TensorDictPrioritizedReplayBuffer
  • Вы не обязаны вручную вычислять веса важностного сэмплинга — они обрабатываются буфером
  • Метод update_tensordict_priority эффективно обновляет приоритеты, используя сохранённые индексы

Заключение

Приоритетные веса в TensorDictPrioritizedReplayBuffer не применяются автоматически к функциям потерь. Необходимо вручную умножить значения потерь на поле _weight из сэмплённого TensorDict, чтобы корректно компенсировать смещение PER. Буфер автоматически обрабатывает вычисление и хранение приоритетных весов, что делает реализацию простым, если вы знаете, что ручное применение необходимо.

Ключевые выводы:

  1. Всегда извлекайте _weight из сэмплённого TensorDict
  2. Умножайте ваши потери на эти веса перед обратным распространением
  3. Используйте update_tensordict_priority() для обновления приоритетов на основе TD‑ошибок
  4. Реализуйте анимацию beta для лучшей стабильности обучения
  5. Мониторьте статистику приоритетных весов, чтобы убедиться в корректной работе

Эта ручная стратегия даёт гибкость в применении важностного сэмплинга при сохранении удобства автоматического управления приоритетами, предоставляемого библиотекой PyTorch RL.

Источники

  1. PyTorch RL – TensorDictPrioritizedReplayBuffer Documentation
  2. Stack Overflow – Priority weights in TensorDictPrioritizedReplayBuffer
  3. PyTorch RL – DDPGLoss Documentation
  4. PyTorch RL – PrioritizedReplayBuffer Documentation
  5. GitHub – PyTorch RL Implementation
Авторы
Проверено модерацией
Модерация