Introduction

Reinforcement learning pipelines for large language models (LLMs) usually involve complex distributed sampling, large policy networks, and high-dimensional reward modeling, so newcomers easily get lost in engineering details. To strip away these distractions, this post uses the classic Frozen Lake environment from OpenAI Gym (accessed here through Gymnasium, its maintained fork), together with a compact but fully functional implementation, to dissect two mainstream policy-optimization algorithms: PPO (Proximal Policy Optimization) and GRPO (Group Relative Policy Optimization). We focus on the core of both methods, the construction of the advantage estimate, and explain the motivation and mathematical form behind it.

Overview of the Frozen Lake Environment

Frozen Lake is a grid-world task:

  • State space: the agent occupies one cell of an $N \times N$ grid, represented by an integer index (0 to $N^2 - 1$).
  • Action space: move up, down, left, or right (4 discrete actions).
  • Map elements:
    • S: start
    • F: frozen (safe) surface
    • H: hole; falling in ends the episode with reward 0
    • G: goal; reaching it ends the episode with reward 1
  • Key properties:
    • Sparse reward: only reaching the goal yields +1; every other step yields 0.
    • Stochasticity (optional): with is_slippery=True, an executed action may slip toward an adjacent direction, which makes exploration harder.

This setup closely mirrors RL training for LLMs: generating a complete sentence (a trajectory) requires token-by-token decisions (actions), while human preference or an automatic evaluator typically returns a single scalar only after the sentence is finished (e.g. a reward-model score). Likewise, in Frozen Lake each individual move (the analogue of generating one token) yields no reward by itself; only whether the agent finally reaches the goal (the analogue of sentence quality) determines the return of the whole trajectory. This makes Frozen Lake a convenient, lightweight proxy for studying LLM-style reinforcement learning. A minimal construction of the environment is sketched below.
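As a concrete starting point, here is a minimal sketch of how such an environment can be built with the Gymnasium API (the same call used by the appendix code); the 4x4 map size, is_slippery=False, and the fixed seed are illustrative choices, not requirements.

```python
import gymnasium as gym
from gymnasium.envs.toy_text.frozen_lake import generate_random_map

# Build a 4x4 Frozen Lake with a randomly generated map and deterministic moves.
env = gym.make("FrozenLake-v1", desc=generate_random_map(size=4), is_slippery=False)

obs, info = env.reset(seed=0)            # obs is an integer cell index in [0, 15]
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
print(obs, reward, terminated)           # reward stays 0.0 until the goal cell is reached
env.close()
```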

PPO: Advantage Estimation via Temporal Difference

Core Idea

PPO is an on-policy actor-critic method whose key ingredient is clipping the policy update to keep each step small, so that one large update cannot collapse performance. All of this rests on accurately estimating the advantage $A(s_t, a_t)$ of each state-action pair.

Constructing the Advantage: GAE

PPO uses Generalized Advantage Estimation (GAE) to trade off bias and variance:

$$\begin{aligned} \delta_t &= r_t + \gamma V(s_{t+1}) - V(s_t) \\ A_t^{\text{GAE}(\gamma, \lambda)} &= \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l} \end{aligned}$$

where:

  • $\delta_t$ is the TD error (temporal difference error), measuring how far the current value estimate $V(s_t)$ is from the bootstrapped return.
  • $\gamma$ is the discount factor, and $\lambda \in [0, 1]$ controls how many steps are mixed in:
    • $\lambda = 0$: only the one-step TD error is used (low variance, high bias)
    • $\lambda = 1$: equivalent to the Monte Carlo return minus a baseline (high variance, low bias)

In sparse-reward environments like Frozen Lake, GAE propagates the terminal +1 reward backwards to every state along the path, assigning them positive advantages and thereby telling the policy which intermediate steps were useful. The sketch below makes this concrete.
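To make the backward propagation tangible, here is a minimal sketch of the GAE recursion, mirroring the compute_gae method in the appendix; the all-zero critic values and the 4-step trajectory are illustrative assumptions.

```python
import torch

def compute_gae(rewards: torch.Tensor, values: torch.Tensor,
                gamma: float = 0.9, lam: float = 0.95) -> torch.Tensor:
    """Backward GAE recursion: A_t = delta_t + gamma * lam * A_{t+1}."""
    T = rewards.size(0)
    advantages = torch.zeros(T)
    lastgaelam = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0   # terminal state is assumed to have value 0
        delta = rewards[t] + gamma * next_value - values[t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advantages[t] = lastgaelam
    return advantages

# Sparse reward: only the final step yields +1; GAE pushes credit back to earlier steps.
rewards = torch.tensor([0.0, 0.0, 0.0, 1.0])
values = torch.zeros(4)                  # pretend the critic knows nothing yet
print(compute_gae(rewards, values))      # roughly [0.62, 0.73, 0.86, 1.0]: positive everywhere, largest near the goal
```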

PPO Loss Function

Given the log-probabilities $\log \pi_{\theta_{\text{old}}}(a_t \mid s_t)$ of the actions sampled under the old policy, the loss for the new policy is:

$$L^{\text{PPO}} = \mathbb{E}_t \left[ \min \left( r_t(\theta) A_t, \ \text{clip}\left( r_t(\theta), 1 - \epsilon, 1 + \epsilon \right) A_t \right) \right]$$

where $r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$ is the probability ratio. Clipping guarantees that even if the new policy drifts far from the old one, the gradient stays bounded, which keeps training stable. A minimal sketch of this objective follows.
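Here is a minimal sketch of the clipped surrogate objective, written in the same style as the appendix trainer; the function name and the per-step inputs are illustrative.

```python
import torch

def ppo_actor_loss(new_logp: torch.Tensor, old_logp: torch.Tensor,
                   advantages: torch.Tensor, clip_epsilon: float = 0.2) -> torch.Tensor:
    """Negative clipped surrogate objective, averaged over time steps."""
    ratio = torch.exp(new_logp - old_logp)                        # r_t(theta)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    return -torch.min(unclipped, clipped).mean()                  # maximize objective -> minimize its negative
```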

GRPO: Advantage Estimation via Group-Relative Ranking

Motivation and Background

GRPO was originally proposed by DeepSeek for outcome supervision in LLM training: a single scalar reward (e.g. a human rating) arrives only after the complete output sequence, with no intermediate signal. This matches the sparse-reward setting of Frozen Lake very closely.

Instead of modeling a fine-grained advantage at every step, GRPO assigns advantages by relative ranking within a group of rollouts.

Constructing the Advantage: Group Normalization

For a group of $K$ trajectories, let $R_i$ denote the total reward of trajectory $i$. GRPO sets the advantage of every time step in that trajectory to:

$$A_{i,t} = \frac{R_i - \mu_R}{\sigma_R + \epsilon}, \quad \forall t$$

where $\mu_R$ and $\sigma_R$ are the mean and standard deviation of the rewards within that group. Consequently:

  • every action in a high-reward trajectory receives a positive advantage;
  • every action in a low-reward trajectory receives a negative advantage;
  • the magnitude reflects how good the trajectory is relative to its group.

This completely sidesteps value-function modeling and temporal credit assignment, which makes it particularly suitable when rewards are trajectory-level and intermediate state values are hard to model. A minimal sketch of the group normalization follows.
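The following sketch mirrors compute_grpo in the appendix; the example scores assume a group of K = 8 rollouts with three successes.

```python
import torch

def group_advantages(rewards: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Standardize trajectory-level rewards within one group of K rollouts."""
    if rewards.numel() <= 1 or rewards.std() < eps:
        return rewards - rewards.mean()      # degenerate group: no usable ranking signal
    return (rewards - rewards.mean()) / (rewards.std() + eps)

# A group of K = 8 trajectories on Frozen Lake: 3 reached the goal, 5 fell into holes.
scores = torch.tensor([1., 0., 0., 1., 0., 0., 1., 0.])
print(group_advantages(scores))   # successes get positive values, failures negative
```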

GRPO Loss Function

The loss takes the same form as PPO's, but the per-step advantage $A_t$ is replaced by the group-normalized constant above:

$$L^{\text{GRPO}} = \mathbb{E}_{i,t} \left[ \min \left( r_{i,t}(\theta) \hat{A}_i, \ \text{clip}\left( r_{i,t}(\theta), 1 - \epsilon, 1 + \epsilon \right) \hat{A}_i \right) \right]$$

Key Questions

The following works through five key questions, tying the mechanics of PPO and GRPO back to the Frozen Lake environment and the implementation in the appendix.

How is the advantage $A$ computed?

The advantage function is defined as:

$$A(s_t, a_t) = Q(s_t, a_t) - V(s_t)$$

that is, how much more return action $a_t$ brings compared with the policy's average behavior in that state. In practice $Q(s_t, a_t)$ is hard to estimate directly, so the advantage is usually approximated by the return minus the value estimate:

$$A(s_t, a_t) \approx R_t - V(s_t)$$

where $R_t = \sum_{k=t}^{T} \gamma^{k-t} r_k$ is the realized discounted return from time step $t$. PPO, however, uses GAE (Generalized Advantage Estimation) for a more stable estimate:

$$\begin{aligned} \delta_t &= r_t + \gamma V(s_{t+1}) - V(s_t) \\ A_t^{\text{GAE}(\gamma, \lambda)} &= \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l} \end{aligned}$$

This is a backward recursion over the trajectory: it uses the critic's estimate of $V(s_t)$ and balances bias against variance.

GRPO, in contrast, does not compute per-step advantages. It standardizes the final rewards $R_i$ within a group (e.g. 8 trajectories):

$$\hat{A}_i = \frac{R_i - \mu_R}{\sigma_R + \epsilon}$$

and assigns this scalar $\hat{A}_i$ as the advantage of every time step in that trajectory.

Comparing the meaning of the advantage $A$ in PPO and GRPO

| Dimension | PPO | GRPO |
| --- | --- | --- |
| Source of the advantage | TD errors built from the critic's value estimates and realized returns | only the final reward of the complete trajectory (outcome supervision) |
| Granularity | computed independently at every time step (fine-grained) | one shared advantage for the whole trajectory (coarse-grained) |
| Needs a critic | yes (a value network must be trained) | no |
| Credit assignment | strong: can tell good from bad intermediate steps along a path | weak: assumes the whole trajectory is either good or bad |
| Typical setting | environments with intermediate rewards or a learnable value function | only a final scalar reward (e.g. human ratings or automatic scores for LLM outputs) |

💡 In Frozen Lake with is_slippery=True there may be several viable routes; PPO can learn which turns are more reliable, while GRPO only sees success or failure and cannot distinguish the fine-grained quality of different paths.

Which computations run with frozen weights, and which during the weight update?

✅ Frozen-weight phase (@torch.no_grad()): sampling and input preparation
Purpose: collect data with the current (old) policy so that the policy gradient stays unbiased.
It covers:

  • sample_round() / sample_batch(): roll out trajectories with the actor
  • prepare_inputs() (under @torch.no_grad()):
    • PPO: compute old_action_log_probs, values (critic inference), step_level_rewards, and advantages (GAE)
    • GRPO: compute old_action_log_probs and grouped_advantages (standardized from score)

⚠️ Note: although the critic is called inside prepare_inputs, its parameters are not updated in this phase; it is used for inference only.

🔁 Weight-update phase: update_model()
Purpose: update the actor (and critic) parameters using the data collected under the old policy.
It covers:

  • re-running the forward pass actor_model(states) to obtain the new action_log_probs
  • computing the probability ratio ratio = exp(new_logp - old_logp)
  • building the PPO/GRPO loss
  • backpropagation and the optimizer step

📌 Key point: old_action_log_probs must be recorded at sampling time and then frozen; otherwise the ratio becomes biased, because once the policy changes its log-probs change too. The sketch below condenses the two phases.
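The following sketch shows where gradients are and are not tracked; the linear toy actor and the flattened state features are stand-ins for the real networks.

```python
import torch
import torch.nn as nn

# Toy stand-in for the actor: a linear policy over 4 discrete actions (illustrative only).
actor = nn.Linear(16, 4)
states = torch.randn(5, 16)                      # 5 time steps of flattened state features
actions = torch.randint(0, 4, (5,))

# --- Frozen-weight phase: sample / prepare inputs, no gradients tracked ---
with torch.no_grad():
    old_logp = torch.log_softmax(actor(states), dim=-1).gather(1, actions[:, None]).squeeze(1)

# --- Update phase: recompute log-probs under the (possibly updated) policy, gradients enabled ---
new_logp = torch.log_softmax(actor(states), dim=-1).gather(1, actions[:, None]).squeeze(1)
ratio = torch.exp(new_logp - old_logp)           # importance ratio r_t(theta); carries gradients
print(ratio)                                     # all ones here, since the policy has not been updated yet
```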

How is a sentence-wise reward converted into token-wise rewards?

In sparse-reward settings such as LLMs or Frozen Lake:
The raw reward arrives only at the end of the sequence (e.g. score ∈ {0, 1} in Frozen Lake).
The goal is to give every token (every action step) a credit signal that can drive the policy update.

Method 1: PPO with GAE
Place the sentence-level reward on the final step: rewards = [0, 0, …, R].
GAE then propagates the terminal reward backwards through the earlier steps.
Result: tokens close to a successful path receive positive advantages; tokens far from it receive negative ones.

Method 2: GRPO broadcasting
Do not split the reward at all; broadcast the group-normalized sentence-level reward $\hat{A}$ to every token.
In other words: assume that if the sentence succeeded, every step of it deserves encouragement.

✅ Both are mappings from a sentence-wise reward to token-wise advantages; they differ only in granularity, as the small comparison below illustrates.
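A small worked comparison under simplifying assumptions (a critic that outputs zero everywhere, one successful 4-step trajectory, and a group in which three peer rollouts failed):

```python
import torch

gamma, lam = 0.9, 0.95

# Sentence-level reward placed on the last step of a successful 4-step trajectory.
rewards = torch.tensor([0.0, 0.0, 0.0, 1.0])

# PPO-style: GAE spreads the terminal reward backward (toy critic outputs V(s) = 0 everywhere).
adv, last = torch.zeros(4), 0.0
for t in reversed(range(4)):
    delta = rewards[t]                             # r_t + gamma * 0 - 0 under the zero critic
    last = delta + gamma * lam * last
    adv[t] = last
print("PPO/GAE advantages:", adv)                  # larger near the goal, smaller near the start

# GRPO-style: normalize the trajectory score within its group, then broadcast it to every step.
group_scores = torch.tensor([1.0, 0.0, 0.0, 0.0])  # this trajectory succeeded, 3 peers failed
a_hat = (group_scores - group_scores.mean()) / (group_scores.std() + 1e-8)
print("GRPO advantage (all steps):", a_hat[0].item())
```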

What is the difference between the Critic Model and the Reward Model, and what does each optimize?

| Aspect | Critic Model (in PPO) | Reward Model (RM, common in RLHF) |
| --- | --- | --- |
| Input | the current state $s_t$ (or a token-sequence prefix) | the complete output sequence (e.g. a whole sentence) |
| Output | an estimate of the expected future return $V(s_t) = \mathbb{E}[R_t]$ per step | a score $r \in \mathbb{R}$ for the full sequence |
| Training | jointly with the actor, minimizing $(V(s_t) - R_t)^2$ | separately (e.g. preference ranking or regression to human ratings) |
| Role | supports the advantage computation $A_t = R_t - V(s_t)$ | supplies the final reward signal (a stand-in for humans) |
| In the policy gradient | no (only used to build the target) | no (only provides the reward) |
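For reference, here is a minimal sketch of the critic's regression objective as used in the appendix's PPOTrainer; the linear toy critic and the hard-coded bootstrap targets are illustrative. The reward model has no learned counterpart in the Frozen Lake setup, since the environment itself supplies the trajectory score.

```python
import torch
import torch.nn as nn

critic = nn.Linear(16, 1)                                # toy critic over flattened state features
states = torch.randn(5, 16)
returns = torch.tensor([0.62, 0.73, 0.86, 0.95, 1.0])    # bootstrap targets: advantages + values

values = critic(states).squeeze(-1)                      # V(s_t)
critic_loss = 0.5 * torch.square(values - returns).mean()
critic_loss.backward()                                   # gradients flow only into the critic's parameters
```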

Appendix

Implementation

import os
import json
import copy
import time
import random
from typing import *
from tqdm import trange
from dataclasses import dataclass
from argparse import ArgumentParser

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import gymnasium as gym
from gymnasium.envs.toy_text.frozen_lake import generate_random_map


class ActorNet(nn.Module):

    def __init__(self, input_size: int, num_actions: int, feature_size: int = 128) -> None:
        super(ActorNet, self).__init__()

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=feature_size, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=feature_size, out_channels=feature_size, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        h_out = w_out = input_size // 4
        conv_output_size = feature_size * h_out * w_out

        self.fc_layers = nn.Sequential(
            nn.Linear(conv_output_size, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_actions),
        )

        self.num_actions = num_actions

    def forward(self, state: torch.Tensor, action: torch.Tensor = None):
        x = self.feature_extractor(state)
        x = x.view(x.size(0), -1)
        logits = self.fc_layers(x)  # (batch_size, num_actions)
        proba = F.softmax(logits, dim=-1)  # (batch_size, num_actions)

        if action is None:
            return proba, None

        # compute the log-probability of the action that was actually taken
        log_proba = F.log_softmax(logits, dim=-1)  # (batch_size, num_actions)
        log_proba_selected = log_proba.gather(1, action.long().unsqueeze(1)).squeeze(1)  # (batch_size,)

        return proba, log_proba_selected


class CriticNet(nn.Module):

    def __init__(self, input_size: int, feature_size: int = 128) -> None:
        super(CriticNet, self).__init__()

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=feature_size, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=feature_size, out_channels=feature_size, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        h_out = w_out = input_size // 4
        conv_output_size = feature_size * h_out * w_out

        self.fc_layers = nn.Sequential(
            nn.Linear(conv_output_size, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
        )

    def forward(self, state: torch.Tensor):
        # expected input shape: (batch_size, 2, input_size, input_size), i.e. (N, C, H, W)
        x = self.feature_extractor(state)
        x = x.view(x.size(0), -1)  # flatten the feature maps into one vector per sample
        return self.fc_layers(x)


class Utils():

    @staticmethod
    def set_seed(seed: int) -> None:
        """Seed all commonly used random number generators."""
        if seed is None:
            return
        random.seed(seed)  # Python's built-in random module
        np.random.seed(seed)  # NumPy
        os.environ['PYTHONHASHSEED'] = str(seed)  # hash randomization

        # PyTorch (CPU and CUDA)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU

    @staticmethod
    def whiten_sequence(sequence: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
        # with at most one element the std is degenerate, so handle that case directly
        if sequence.numel() <= 1:
            return sequence - sequence.mean() if shift_mean else sequence.clone()
        mean, std = sequence.mean(), sequence.std()
        # guard against a (near-)zero std blowing up the division
        if std.item() < 1e-8:
            return sequence - mean if shift_mean else sequence.clone()
        whiten = (sequence - mean) / (std + 1e-8)
        if not shift_mean:
            whiten += mean
        return whiten


class DataUtils():

    @staticmethod
    def get_env(size: int = 8, is_slippery: bool = True, render_mode: str = None) -> gym.Env:
        return gym.make(
            'FrozenLake-v1',
            desc=generate_random_map(size=size),
            is_slippery=is_slippery,
            render_mode=render_mode,
        )

    @staticmethod
    def build_static_grid(env: gym.Env) -> torch.Tensor:
        """Build the static map channel from env.unwrapped.desc (S/G/F/H -> 0/1/2/3)."""
        mapping = {b'S': 0.0, b'G': 1.0, b'F': 2.0, b'H': 3.0}
        desc = env.unwrapped.desc  # np.ndarray of bytes, shape (H, W)
        H, W = desc.shape
        grid = torch.empty((H, W), dtype=torch.float32)
        for i in range(H):
            for j in range(W):
                grid[i, j] = mapping[desc[i, j]]
        return grid

    @staticmethod
    def make_state_tensor(static_grid: torch.Tensor, obs: int) -> torch.Tensor:
        """Build a one-hot position channel from the discrete observation and stack it with the static map channel."""
        H, W = static_grid.shape
        pos = torch.zeros((H, W), dtype=torch.float32)
        pos[obs // W, obs % W] = 1.0
        return torch.stack([static_grid, pos], dim=0)  # (2, H, W)

    @staticmethod
    @torch.no_grad()
    def sample_action(actor_model: nn.Module, state: torch.Tensor) -> Tuple[int, float]:
        device = next(actor_model.parameters()).device
        state = state.unsqueeze(0).float().to(device)  # (1, 2, H, W)
        probas, _ = actor_model(state)
        dist = torch.distributions.Categorical(probas)
        action = dist.sample()
        action_log_proba = dist.log_prob(action)
        return int(action.item()), float(action_log_proba.item())

    @staticmethod
    @torch.no_grad()
    def sample_round(env: gym.Env, actor_model: nn.Module, render_mode: str = None) -> Dict[str, Any]:
        sequence = []
        score = None
        obs, info = env.reset()

        static_grid = DataUtils.build_static_grid(env)
        state = DataUtils.make_state_tensor(static_grid, obs)

        while True:
            if render_mode in ("rgb_array", "human"):
                env.render()
                time.sleep(0.3)
            action, _ = DataUtils.sample_action(actor_model, state)
            obs, reward, terminated, truncated, info = env.step(action)
            next_state = DataUtils.make_state_tensor(static_grid, obs)

            sequence.append((state, action))
            state = next_state

            if terminated or truncated:
                sequence.append((state, None))  # append the terminal state with no action
                score = float(reward)
                break

        states, actions = list(zip(*sequence))

        return dict(states=list(states), actions=list(actions), score=score)

    @staticmethod
    @torch.no_grad()
    def sample_batch(actor_model: nn.Module, batch_size: int, group_size: int, **env_args) -> List[Dict[str, Any]]:
        actor_model.eval()
        examples = []
        for _ in range(batch_size):
            env = DataUtils.get_env(**env_args)
            try:
                for _ in range(group_size):
                    examples.append(DataUtils.sample_round(env, actor_model))
            finally:
                env.close()
        return examples


@dataclass
class Config():

    version: str = "v0"
    seed: int = 42
    frozen_lake_size: int = 4
    frozen_lake_slippery: bool = False
    num_actions: int = 4

    whiten_rewards: bool = False

    max_steps: int = 1000
    save_steps: int = 100
    batch_size: int = 32
    group_size: int = 8
    num_updates_per_batch: int = 1
    max_grad_norm: float = 0.5

    clip_epsilon: float = 0.2
    entropy_coef: float = 0.01

    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    output_dir: str = None

    def __post_init__(self):
        self.output_dir = os.path.join("./", self.version)
        os.makedirs(self.output_dir, exist_ok=True)
        print(f"Saving to {self.output_dir}")


class Inferer():

    def __init__(self, config: Config, step_no: int, render_mode: str = "human") -> None:
        self.config = config
        self.step_no = step_no
        self.render_mode = render_mode

        # load the trained actor checkpoint
        save_dir = os.path.join(self.config.output_dir, f"checkpoint-{step_no:06d}")
        print(f"Loading model states from {save_dir}")
        self.actor_model = ActorNet(self.config.frozen_lake_size, self.config.num_actions).to(self.config.device)
        self.actor_model.load_state_dict(torch.load(os.path.join(save_dir, "actor.pt")))
        self.actor_model.eval()

    @torch.no_grad()
    def infer(self) -> Dict[str, Any]:
        # build a fresh environment for rendering / evaluation
        env = DataUtils.get_env(
            self.config.frozen_lake_size,
            self.config.frozen_lake_slippery,
            render_mode=self.render_mode,
        )
        return DataUtils.sample_round(env, self.actor_model, render_mode=self.render_mode)


class Trainer():

    def __init__(self, config: Config) -> None:
        self.config = config
        self.writer = SummaryWriter(
            os.path.join(config.output_dir, "logs/")
        )

    @torch.no_grad()
    def prepare_inputs(self, batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        raise NotImplementedError

    def update_model(self, batch: List[Dict[str, Any]]) -> Dict[str, float]:
        raise NotImplementedError

    def save_model(self, step_no: int) -> str:
        raise NotImplementedError

    def train(self):
        for step_no in trange(self.config.max_steps, desc="Training...", total=self.config.max_steps):
            # sample a batch of trajectories with the current (old) policy
            batch = DataUtils.sample_batch(
                self.actor_model,
                self.config.batch_size,
                self.config.group_size,
                size=self.config.frozen_lake_size,
                is_slippery=self.config.frozen_lake_slippery,
            )
            # prepare training inputs (old log-probs, advantages, ...)
            batch = self.prepare_inputs(batch)
            # update the model parameters
            metrics = self.update_model(batch)
            # log the metrics
            print(json.dumps(metrics, ensure_ascii=False))
            for score_name, score_value in metrics.items():
                self.writer.add_scalar(score_name, score_value, step_no)
            # periodically save checkpoints
            if step_no > 0 and step_no % self.config.save_steps == 0:
                model_path = self.save_model(step_no)
                print(f"Step [{step_no+1}/{self.config.max_steps}] model saved at {model_path}")


@dataclass
class PPOConfig(Config):

    actor_learning_rate: float = 1e-4
    critic_learning_rate: float = 3e-4

    gamma: float = 0.9
    lam: float = 0.95

    critic_loss_coef: float = 0.5


class PPOTrainer(Trainer):

    def __init__(self, config: PPOConfig) -> None:
        super().__init__(config)

        self.actor_model = ActorNet(config.frozen_lake_size, config.num_actions).to(config.device)
        self.critic_model = CriticNet(config.frozen_lake_size).to(config.device)
        self.reference_model = None  # in LLM training a pretrained model serves as the reference; this toy experiment has none

        self.actor_optimizer = optim.Adam(self.actor_model.parameters(), lr=config.actor_learning_rate)
        self.critic_optimizer = optim.Adam(self.critic_model.parameters(), lr=config.critic_learning_rate)

    def compute_gae(self, step_level_rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        sequence_length = step_level_rewards.size(0)
        lastgaelam = 0
        advantages_reversed = []
        for t in reversed(range(sequence_length)):  # the advantage depends on future values, so sweep backwards from the end
            next_value = values[t + 1] if t + 1 < sequence_length else 0.0  # the last step has no successor state, so its value is taken as 0
            # TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
            delta = step_level_rewards[t] + self.config.gamma * next_value - values[t]
            # GAE recursion: A_t = delta_t + gamma * lambda * delta_{t+1} + (gamma * lambda)^2 * delta_{t+2} + ...
            # lambda = 1: close to the Monte Carlo advantage (multi-step, high variance, low bias);
            # lambda = 0: reduces to the one-step TD error (low variance, high bias);
            # values in between trade off bias and variance.
            lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=-1)  # (sequence_length,)
        return advantages

    @torch.no_grad()
    def prepare_inputs(self, batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        batch = copy.deepcopy(batch)
        for example_no in range(self.config.batch_size * self.config.group_size):
            example = batch[example_no]
            states: List[torch.Tensor] = example["states"]  # (sequence_length,)
            actions: List[int] = example["actions"]  # (sequence_length,)
            score: float = example["score"]
            sequence_length: int = len(states)

            # step 1. compute the log-probability of each taken action (action_log_probs) and the state values (values)
            encode_states = torch.stack(states, dim=0).float()  # (sequence_length, channel, height, width)
            encode_actions = torch.tensor(actions[:-1], dtype=torch.int64)  # (sequence_length - 1,)
            _, action_log_probs = self.actor_model(encode_states[:-1], encode_actions)  # (sequence_length - 1,)
            values = self.critic_model(encode_states).squeeze(-1)  # (sequence_length,)

            # step 2. build step-level rewards; with a reference model one would additionally
            # 1) add the per-step KL penalty as a step-level reward, and
            # 2) add the sequence-level reward to the final step
            step_level_rewards = [0.0] * (sequence_length - 1) + [score]
            step_level_rewards = torch.tensor(step_level_rewards, dtype=torch.float32)  # (sequence_length,), e.g. [0.0, 0.0, ..., 1.0]
            if self.config.whiten_rewards:
                step_level_rewards = Utils.whiten_sequence(step_level_rewards, shift_mean=False)

            # step 3. GAE (Generalized Advantage Estimation): per-step advantages and returns
            advantages = self.compute_gae(step_level_rewards, values)  # (sequence_length,)
            returns = advantages + values  # the returns serve as the regression target for the critic
            # note: V(s_t) is the critic's current output and A(s_t, a_t) is derived from it,
            # so these returns are not independent ground-truth labels but bootstrapped targets
            advantages = Utils.whiten_sequence(advantages)

            example["action_log_probs"] = action_log_probs  # (sequence_length - 1,)
            example["values"] = values  # (sequence_length,)
            example["advantages"] = advantages  # (sequence_length,)
            example["returns"] = returns  # (sequence_length,)

        return batch

    def update_model(self, batch: List[Dict[str, Any]]) -> Dict[str, float]:
        self.actor_model.train()
        log_actor_loss = 0.0
        log_critic_loss = 0.0
        for epoch_no in range(self.config.num_updates_per_batch):
            # accumulate losses weighted by the number of steps
            device = next(self.actor_model.parameters()).device
            total_actor_loss = torch.tensor(0.0, device=device)
            total_actor_steps = 0  # count steps so that sequence length does not skew sample weights
            total_critic_loss = torch.tensor(0.0, device=device)
            total_critic_steps = 0

            for example_no in range(self.config.batch_size * self.config.group_size):
                example = batch[example_no]
                states: List[torch.Tensor] = example["states"]  # (sequence_length,)
                actions: List[int] = example["actions"]  # (sequence_length,)
                old_action_log_probs: torch.Tensor = example["action_log_probs"]  # (sequence_length - 1,)
                advantages: torch.Tensor = example["advantages"]  # (sequence_length,)
                returns: torch.Tensor = example["returns"]  # (sequence_length,)

                # re-run the forward pass under the current policy
                encode_states = torch.stack(states, dim=0).float()  # (sequence_length, channel, height, width)
                encode_actions = torch.tensor(actions[:-1], dtype=torch.int64)  # (sequence_length - 1,)
                probas, action_log_probs = self.actor_model(encode_states[:-1], encode_actions)  # (sequence_length - 1,)
                values = self.critic_model(encode_states).squeeze(-1)  # (sequence_length,)

                # actor: per-step clipped surrogate loss, no mean yet
                ratio = torch.exp(action_log_probs - old_action_log_probs)  # (sequence_length - 1,)
                step_actor_loss = - torch.min(
                    ratio * advantages[:-1],
                    torch.clamp(
                        ratio,
                        1 - self.config.clip_epsilon,
                        1 + self.config.clip_epsilon,
                    ) * advantages[:-1]
                )  # (sequence_length - 1,)

                # entropy bonus: maximize action entropy to encourage exploration
                entropy = - (probas * torch.log(torch.clamp(probas, min=1e-8))).sum(dim=1)  # (sequence_length - 1,)
                step_actor_loss = step_actor_loss - self.config.entropy_coef * entropy

                # critic: per-step MSE, no mean yet
                step_critic_loss = 0.5 * torch.square(values - returns)  # (sequence_length,)

                # accumulate the sums and the number of valid steps
                total_actor_loss += step_actor_loss.sum()
                total_actor_steps += step_actor_loss.numel()

                total_critic_loss += (self.config.critic_loss_coef * step_critic_loss).sum()
                total_critic_steps += step_critic_loss.numel()

                # per-example metrics (logging only, not used for gradients)
                example["actor_loss"] = step_actor_loss.mean().item()
                example["critic_loss"] = step_critic_loss.mean().item()

            # batch-level loss = total / number of steps, so every time step carries the same weight
            actor_loss = total_actor_loss / max(1, total_actor_steps)
            critic_loss = total_critic_loss / max(1, total_critic_steps)
            log_actor_loss += (actor_loss.item() / self.config.num_updates_per_batch)
            log_critic_loss += (critic_loss.item() / self.config.num_updates_per_batch)

            # update the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.config.max_grad_norm)
            self.actor_optimizer.step()

            # update the critic
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            nn.utils.clip_grad_norm_(self.critic_model.parameters(), self.config.max_grad_norm)
            self.critic_optimizer.step()

        # report metrics
        metrics = {
            "score/mean": torch.tensor([e["score"] for e in batch]).mean().item(),
            "score/max": torch.tensor([e["score"] for e in batch]).max().item(),
            "score/min": torch.tensor([e["score"] for e in batch]).min().item(),
            "actor_loss": log_actor_loss,
            "critic_loss": log_critic_loss,
        }
        return metrics

    def save_model(self, step_no: int) -> str:
        save_dir = os.path.join(self.config.output_dir, f"checkpoint-{step_no:06d}")
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.actor_model.state_dict(), os.path.join(save_dir, "actor.pt"))
        torch.save(self.critic_model.state_dict(), os.path.join(save_dir, "critic.pt"))
        return save_dir


@dataclass
class GRPOConfig(Config):

    actor_learning_rate: float = 1e-4


class GRPOTrainer(Trainer):

    def __init__(self, config: GRPOConfig) -> None:
        super().__init__(config)

        self.actor_model = ActorNet(config.frozen_lake_size, config.num_actions).to(config.device)
        self.reference_model = None  # in LLM training a pretrained model serves as the reference; this toy experiment has none

        self.actor_optimizer = optim.Adam(self.actor_model.parameters(), lr=config.actor_learning_rate)

    def compute_grpo(self, rewards: torch.Tensor) -> torch.Tensor:
        # with at most one element the std is degenerate, so handle that case directly
        if rewards.numel() <= 1:
            return rewards - rewards.mean()
        mean, std = rewards.mean(), rewards.std()
        # guard against a (near-)zero std blowing up the division
        if std.item() < 1e-8:
            return rewards - mean
        return (rewards - mean) / (std + 1e-8)

    @torch.no_grad()
    def prepare_inputs(self, batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        batch = copy.deepcopy(batch)
        device = self.config.device

        for group_no in range(self.config.batch_size):
            group_start = group_no * self.config.group_size
            group_end = group_start + self.config.group_size
            group = batch[group_start: group_end]

            # GRPO (Group Relative Policy Optimization): group-relative advantage estimation
            grouped_rewards = torch.tensor([example["score"] for example in group]).float().to(device)
            grouped_advantages = self.compute_grpo(grouped_rewards)  # (group_size,)

            for example_no in range(self.config.group_size):
                example = group[example_no]
                states: List[torch.Tensor] = example["states"]  # (sequence_length,)
                actions: List[int] = example["actions"]  # (sequence_length,)
                score: float = example["score"]
                sequence_length: int = len(states)

                # step 1. compute the log-probability of each taken action (action_log_probs)
                encode_states = torch.stack(states, dim=0).float()  # (sequence_length, channel, height, width)
                encode_actions = torch.tensor(actions[:-1], dtype=torch.int64)  # (sequence_length - 1,)
                _, action_log_probs = self.actor_model(encode_states[:-1], encode_actions)  # (sequence_length - 1,)

                # step 2. GRPO (Group Relative Policy Optimization): group-relative advantage estimation
                # From the DeepSeek paper: "Outcome supervision provides the normalized reward at the end of each output o_i and
                # sets the advantages A_hat_{i,t} of all tokens in the output as the normalized reward"
                advantages = grouped_advantages[example_no]  # scalar, shared by every step of this trajectory

                example["action_log_probs"] = action_log_probs.detach()  # (sequence_length - 1,)
                example["advantages"] = advantages.detach()  # scalar

        return batch

    def update_model(self, batch: List[Dict[str, Any]]) -> Dict[str, float]:
        self.actor_model.train()
        log_actor_loss = 0.0
        for epoch_no in range(self.config.num_updates_per_batch):
            # accumulate losses weighted by the number of steps
            device = next(self.actor_model.parameters()).device
            total_actor_loss = torch.tensor(0.0, device=device)
            total_actor_steps = 0  # count steps so that sequence length does not skew sample weights

            for example_no in range(self.config.batch_size * self.config.group_size):
                example = batch[example_no]
                states: List[torch.Tensor] = example["states"]  # (sequence_length,)
                actions: List[int] = example["actions"]  # (sequence_length,)
                old_action_log_probs: torch.Tensor = example["action_log_probs"]  # (sequence_length - 1,)
                advantages: torch.Tensor = example["advantages"]  # scalar

                # re-run the forward pass under the current policy
                encode_states = torch.stack(states, dim=0).float()  # (sequence_length, channel, height, width)
                encode_actions = torch.tensor(actions[:-1], dtype=torch.int64)  # (sequence_length - 1,)
                probas, action_log_probs = self.actor_model(encode_states[:-1], encode_actions)  # (sequence_length - 1,)

                # actor: per-step clipped surrogate loss, no mean yet
                ratio = torch.exp(action_log_probs - old_action_log_probs)  # (sequence_length - 1,)
                step_actor_loss = - torch.min(
                    ratio * advantages,
                    torch.clamp(
                        ratio,
                        1 - self.config.clip_epsilon,
                        1 + self.config.clip_epsilon,
                    ) * advantages
                )  # (sequence_length - 1,)

                # entropy bonus: maximize action entropy to encourage exploration
                entropy = - (probas * torch.log(torch.clamp(probas, min=1e-8))).sum(dim=1)  # (sequence_length - 1,)
                step_actor_loss = step_actor_loss - self.config.entropy_coef * entropy

                # accumulate the sums and the number of valid steps
                total_actor_loss += step_actor_loss.sum()
                total_actor_steps += step_actor_loss.numel()

                # per-example metrics (logging only, not used for gradients)
                example["actor_loss"] = step_actor_loss.mean().item()

            # batch-level loss = total / number of steps, so every time step carries the same weight
            actor_loss = total_actor_loss / max(1, total_actor_steps)
            log_actor_loss += (actor_loss.item() / self.config.num_updates_per_batch)

            # update the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.config.max_grad_norm)
            self.actor_optimizer.step()

        # report metrics
        metrics = {
            "score/mean": torch.tensor([e["score"] for e in batch]).mean().item(),
            "score/max": torch.tensor([e["score"] for e in batch]).max().item(),
            "score/min": torch.tensor([e["score"] for e in batch]).min().item(),
            "actor_loss": log_actor_loss,
        }
        return metrics

    def save_model(self, step_no: int) -> str:
        save_dir = os.path.join(self.config.output_dir, f"checkpoint-{step_no:06d}")
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.actor_model.state_dict(), os.path.join(save_dir, "actor.pt"))
        return save_dir


if __name__ == "__main__":
    parser = ArgumentParser(description="""
    # Minimal implementation without asynchronous sampling or training.
    # For readability the computation is not fully vectorized (e.g. the returns) and no GPU-specific optimizations are applied.

    # # Environment reference: https://gymnasium.farama.org/environments/toy_text/frozen_lake/
    # desc = generate_random_map(size=8)
    # env = gym.make('FrozenLake-v1', desc=desc, is_slippery=True)

    # RL references:
    # PPO: https://github.com/huggingface/trl/blob/20cc58d7772ae660792c7b5249d8b817986a547d/trl/trainer/ppo_trainer.py#L448
    # GRPO: https://github.com/huggingface/trl/blob/9e5e60c9334d0d6d52498da4de68632148fceafb/trl/trainer/grpo_trainer.py#L1362
    """)
    parser.add_argument("--version", type=str, default="v0")
    parser.add_argument("--seed", type=int, default=42)

    parser.add_argument("--observation_size", type=int, default=4)
    parser.add_argument("--num_actions", type=int, default=4)
    parser.add_argument("--frozen_lake_size", type=int, default=4)

    parser.add_argument("--adv_estimator", type=str, choices=["ppo", "grpo"], default="ppo")
    parser.add_argument("--max_steps", type=int, default=1000, help="total number of training steps")
    parser.add_argument("--save_steps", type=int, default=100, help="save a checkpoint every this many steps")
    parser.add_argument("--batch_size", type=int, default=32, help="number of samples per step")
    parser.add_argument("--group_size", type=int, default=8, help="rollouts per sample; the total per step is (batch_size * group_size)")
    parser.add_argument("--num_updates_per_batch", type=int, default=1, help="optimization epochs per sampled batch")
    parser.add_argument("--actor_learning_rate", type=float, default=1e-4, help="learning rate of the actor model")
    parser.add_argument("--critic_learning_rate", type=float, default=3e-4, help="learning rate of the critic model")
    parser.add_argument("--max_grad_norm", type=float, default=0.5)

    parser.add_argument("--whiten_rewards", action="store_true")
    parser.add_argument("--gamma", type=float, default=0.9)
    parser.add_argument("--lam", type=float, default=0.95)
    parser.add_argument("--clip_epsilon", type=float, default=0.2)

    parser.add_argument("--entropy_coef", type=float, default=0.0, help="entropy bonus coefficient; encourages exploration by maximizing action entropy")
    parser.add_argument("--critic_loss_coef", type=float, default=1.0, help="weight of the critic loss")

    args = parser.parse_args()

    Utils.set_seed(args.seed)

    if args.adv_estimator == "ppo":
        ppo_config = PPOConfig(
            version=args.version,
            seed=args.seed,
            frozen_lake_size=args.frozen_lake_size,
            num_actions=args.num_actions,
            max_steps=args.max_steps,
            save_steps=args.save_steps,
            batch_size=args.batch_size,
            group_size=args.group_size,
            num_updates_per_batch=args.num_updates_per_batch,
            actor_learning_rate=args.actor_learning_rate,
            critic_learning_rate=args.critic_learning_rate,
            max_grad_norm=args.max_grad_norm,
            whiten_rewards=args.whiten_rewards,
            gamma=args.gamma,
            lam=args.lam,
            clip_epsilon=args.clip_epsilon,
            entropy_coef=args.entropy_coef,
            critic_loss_coef=args.critic_loss_coef,
        )
        # Example of replaying a saved checkpoint instead of training:
        # inferer = Inferer(ppo_config, step_no=900)
        # for i in range(100):
        #     inferer.infer()
        # exit(0)

        trainer = PPOTrainer(ppo_config)
        trainer.train()

    elif args.adv_estimator == "grpo":
        grpo_config = GRPOConfig(
            version=args.version,
            seed=args.seed,
            frozen_lake_size=args.frozen_lake_size,
            num_actions=args.num_actions,
            max_steps=args.max_steps,
            save_steps=args.save_steps,
            batch_size=args.batch_size,
            group_size=args.group_size,
            num_updates_per_batch=args.num_updates_per_batch,
            actor_learning_rate=args.actor_learning_rate,
            max_grad_norm=args.max_grad_norm,
            whiten_rewards=args.whiten_rewards,
            clip_epsilon=args.clip_epsilon,
            entropy_coef=args.entropy_coef,
        )

        trainer = GRPOTrainer(grpo_config)
        trainer.train()