Как применить приоритетные веса в PyTorch PER
Руководство по применению приоритетных весов из TensorDictPrioritizedReplayBuffer в PyTorch. Узнайте, автоматическое применение не происходит и как исправить смещение PER.
При использовании TensorDictPrioritizedReplayBuffer из PyTorch, нужно ли вручную применять веса приоритетов к функции потерь, или они автоматически учитываются? В PER (Prioritized Experience Replay) мы используем параметр Beta для вычисления весов, которые компенсируют смещение, вызванное PER. В TensorDictPrioritizedReplayBuffer эти веса, как кажется, хранятся в поле _weight выходного TensorDict. При использовании функций потерь PyTorch, таких как DDPGLoss, они автоматически учитывают поле _weight, если в TensorDict его присутствует? Если нет, как следует вручную применять веса приоритетов к расчёту потерь?
С TensorDictPrioritizedReplayBuffer приоритетные веса не применяются автоматически к функциям потерь, таким как DDPGLoss. Поле _weight, содержащее приоритетные веса, автоматически включается в сэмпл TensorDict, но их необходимо вручную умножить на значение потерь, чтобы корректно компенсировать смещение, вызванное Prioritized Experience Replay (PER).
Содержание
- Как работают приоритетные веса в
TensorDictPrioritizedReplayBuffer - Ручное применение приоритетных весов
- Примеры реализации
- Автоматическое обновление приоритетов
- Лучшие практики
Как работают приоритетные веса в TensorDictPrioritizedReplayBuffer
TensorDictPrioritizedReplayBuffer автоматически вычисляет и сохраняет приоритетные веса в поле _weight объектов TensorDict, которые возвращаются при сэмплинге. Согласно документации PyTorch RL, при сэмплинге из буфера возвращаемый TensorDict содержит:
TensorDict(
fields={
_weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: TensorDict(
fields={
c: Tensor(shape=torch.Size([5, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)
},
batch_size=torch.Size([5]), device=cpu, is_shared=False
),
index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False)
},
batch_size=torch.Size([5]), device=cpu, is_shared=False
)
Поле _weight содержит приоритетные веса, вычисленные с использованием параметра beta, который вы указываете при создании буфера. Эти веса предназначены для компенсации смещения, вызванного приоритетным сэмплингом.
Ручное применение приоритетных весов
Поскольку функции потерь, такие как DDPGLoss, не используют поле _weight автоматически, вам необходимо вручную применить эти веса к вычислению потерь. Вот как это сделать:
Шаг 1: Сэмплирование из буфера
from torchrl.data import TensorDictPrioritizedReplayBuffer
# Создание буфера с параметрами alpha и beta
rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, storage=LazyTensorStorage(1000), batch_size=32)
# Сэмплирование из буфера
sample = rb.sample(32)
Шаг 2: Вычисление потерь с приоритетными весами
from torchrl.objectives import DDPGLoss
# Инициализация функции потерь DDPG
loss_fn = DDPGLoss()
# Вычисление потерь
losses = loss_fn(sample)
# Извлечение приоритетных весов
priority_weights = sample.get("_weight")
# Применение весов к потере
weighted_loss = losses["loss_value"] * priority_weights
Шаг 3: Обратное распространение взвешенной потери
optimizer.zero_grad() weighted_loss.mean().backward() optimizer.step()
Примеры реализации
Ниже приведён полный пример, демонстрирующий правильную интеграцию:
import torch
from torchrl.data import TensorDictPrioritizedReplayBuffer, LazyTensorStorage
from torchrl.objectives import DDPGLoss
# Инициализация буфера воспроизведения
rb = TensorDictPrioritizedReplayBuffer(
alpha=0.6, # Сила приоритезации
beta=0.4, # Коррекция смещения при важностном сэмплинге
storage=LazyTensorStorage(10000),
batch_size=64,
priority_key="td_error" # Ключ для TD‑ошибок
)
# Инициализация функции потерь DDPG
loss_fn = DDPGLoss()
# Цикл обучения
for epoch in range(num_epochs):
# Сэмплирование из буфера воспроизведения
batch = rb.sample(64)
# Вычисление потерь
losses = loss_fn(batch)
# Извлечение приоритетных весов
weights = batch.get("_weight")
# Применение весов важностного сэмплинга
weighted_loss = losses["loss_value"] * weights
# Обратное распространение
optimizer.zero_grad()
weighted_loss.mean().backward()
optimizer.step()
# Обновление приоритетов (см. следующий раздел)
update_priorities(rb, batch)
Автоматическое обновление приоритетов
TensorDictPrioritizedReplayBuffer упрощает процесс обновления приоритетов. Согласно документации PyTorch, функция DDPGLoss автоматически записывает ключ "td_error", который можно использовать буферами с приоритетом:
# Это обновит приоритеты на основе TD‑ошибок
rb.update_tensordict_priority(sample)
Метод update_tensordict_priority использует поле "td_error" из TensorDict для обновления приоритетов выбранных опытов. Это создаёт бесшовный рабочий процесс:
- Вы сэмплируете опыты с приоритетными весами
- Вы вычисляете потери и распространяете взвешенную потерю
- Вы обновляете приоритеты на основе новых TD‑ошибок
Лучшие практики
Анимация параметра beta
Постепенно увеличивайте beta от 0.4 до 1.0 в течение обучения, чтобы перейти от важностного сэмплинга к равномерному сэмплингу:
def get_beta(epoch, total_epochs, start_beta=0.4, end_beta=1.0):
return start_beta + (end_beta - start_beta) * (epoch / total_epochs)
Обрезка весов
Обрежьте экстремальные веса, чтобы избежать численной нестабильности:
weights = torch.clamp(weights, 0.1, 10.0)
Многошаговые TD‑ошибки
Для более стабильного обновления приоритетов рассмотрите использование многошаговых TD‑ошибок:
# Сохраняйте n‑шаговые возвраты в буфере воспроизведения
# Используйте их для обновления приоритетов вместо 1‑шаговых TD‑ошибок
Отладка приоритетных весов
Отслеживайте приоритетные веса, чтобы убедиться, что они работают корректно:
# Печать статистики приоритетных весов
print(f"Priority weights - mean: {weights.mean():.3f}, std: {weights.std():.3f}, min: {weights.min():.3f}, max: {weights.max():.3f}")
Показатели производительности
- Поле
_weightавтоматически включается во все сэмплы изTensorDictPrioritizedReplayBuffer - Вы не обязаны вручную вычислять веса важностного сэмплинга — они обрабатываются буфером
- Метод
update_tensordict_priorityэффективно обновляет приоритеты, используя сохранённые индексы
Заключение
Приоритетные веса в TensorDictPrioritizedReplayBuffer не применяются автоматически к функциям потерь. Необходимо вручную умножить значения потерь на поле _weight из сэмплённого TensorDict, чтобы корректно компенсировать смещение PER. Буфер автоматически обрабатывает вычисление и хранение приоритетных весов, что делает реализацию простым, если вы знаете, что ручное применение необходимо.
Ключевые выводы:
- Всегда извлекайте
_weightиз сэмплённогоTensorDict - Умножайте ваши потери на эти веса перед обратным распространением
- Используйте
update_tensordict_priority()для обновления приоритетов на основе TD‑ошибок - Реализуйте анимацию
betaдля лучшей стабильности обучения - Мониторьте статистику приоритетных весов, чтобы убедиться в корректной работе
Эта ручная стратегия даёт гибкость в применении важностного сэмплинга при сохранении удобства автоматического управления приоритетами, предоставляемого библиотекой PyTorch RL.