RobotMem + Stable-Baselines3
一个 BaseCallback,为你的 RL 智能体提供持久化记忆 —— 在训练中保存感知数据,并跨训练轮次检索历史经验。
快速开始
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from robotmem import RobotMemory
class RobotMemCallback(BaseCallback):
def __init__(self, db="sb3_memory.db"):
super().__init__()
self.mem = RobotMemory(db=db)
def _on_step(self) -> bool:
obs = self.locals["new_obs"][0]
action = self.locals["actions"][0]
reward = self.locals["rewards"][0]
self.mem.save_perception(
observation=obs,
action=action,
reward=reward,
metadata={"step": self.num_timesteps}
)
return True
model = SAC("MlpPolicy", "FetchReach-v3")
model.learn(50_000, callback=RobotMemCallback())
集成功能
Stable-Baselines3 是 Python 生态系统中使用最广泛的强化学习库。它提供了 PPO、SAC、TD3、A2C、DQN 等标准算法的可靠、经过充分测试的实现,以及一套简洁的回调系统,让你可以在训练循环的每一步进行挂钩。RobotMem 通过这一回调系统进行集成,这意味着它可以与 SB3 支持的每种算法配合使用,无需修改核心训练代码的任何一行。
RobotMemCallback 在训练过程中监听每个环境步骤。它捕获观测、策略选择的动作以及获得的奖励,然后通过 save_perception 将它们存储到本地 SQLite 数据库中。这会创建智能体所有经历的持久化记录 —— 不仅是当前训练轮次,而是共享同一数据库文件的所有训练轮次。当你明天重启训练、切换到不同的算法或在新任务上微调时,过去经验的完整历史仍然存在。
在检索方面,你可以在任何时候使用 recall 来查询记忆。一种常见的模式是在每个回合开始时检索相似的历史观测,以热启动智能体的价值估计,或者从高奖励经验中构建示范缓冲区用于离线 RL 预训练。由于 RobotMem 使用向量相似度搜索,检索速度足够快,可以在训练循环内运行而不会产生明显的开销。
- 即插即用的 BaseCallback —— 通过单个回调参数为任何 SB3 算法(SAC、PPO、TD3、A2C、DQN)添加持久化记忆,无需更改训练代码。
- 跨轮次经验持久化 —— 所有观测、动作和奖励存储在 SQLite 中,可在训练重启、机器重启和算法更换后保持不变。
- 高奖励经验挖掘 —— 查询记忆中表现最佳的回合,构建示范数据集用于模仿学习或离线 RL 预训练。
- 步级粒度 —— 每个环境步骤都以其时间步索引记录,可以精确分析智能体在何时何处学到了特定行为。
- 多环境支持 —— 通过环境名称标记记忆来维护独立的经验池,或将它们合并用于跨任务迁移学习。
进阶:训练中检索
基础回调被动地存储经验。如需更高级的用法,你可以扩展回调以主动检索历史经验,并将其注入智能体的观测或回放缓冲区。例如,当智能体遇到与之前获得高奖励的状态相似的状态时,你可以提升该转换在回放缓冲区中的优先级,或将检索到的示范追加到训练批次中。
这种模式对稀疏奖励环境特别有效。在 FetchPickAndPlace 或复杂操作场景等任务中,智能体可能经历数千个回合都没有收到任何正向奖励。使用 RobotMem,你可以用少量成功的示范(来自人类遥操作、脚本策略或之前的训练轮次)预填充记忆,回调会在智能体到达类似状态时检索这些示范。这为探索过程提供了一个具体的目标,而不是依赖纯随机探索。
SB3 的回调架构使这一集成点非常自然。_on_step 方法可以访问训练循环的完整局部变量作用域,包括回放缓冲区、当前策略和环境状态。这意味着 RobotMem 可以在每一步同时读取和写入训练过程,使其成为 SB3 训练中集成最深的记忆层。