世界模型(World Model) 是具身智能领域的重要范式,它通过学习环境的内部表示和动态规律,使智能体能够在想象空间中进行规划和学习。世界模型的核心思想源于认知科学:人类不需要每次决策都与真实环境交互,而是在大脑中构建了世界的"心智模型",可以在脑海中模拟未来场景。
在机器人控制和强化学习中,世界模型扮演着"环境模拟器"的角色。通过学习状态转移函数
本节将系统介绍世界模型的核心原理,包括模型架构、学习范式、以及与传统强化学习的对比。
世界模型是对环境动态的可学习表示,它包含三个核心组件:
graph TB
subgraph WorldModel["世界模型组件"]
V[视觉编码器 V]
M[记忆模型 M]
C[控制器 C]
end
Obs[观测 o_t] --> V
V --> |紧凑表示 z_t| M
M --> |隐状态 h_t| C
C --> |动作 a_t| Env[环境]
Env --> |奖励 r_t| M
Env --> Obs
V -.-> |VAE编码| Latent[潜在空间]
M -.-> |MDN-RNN| Dynamics[动态预测]
C -.-> |进化策略| Policy[策略网络]
style V fill:#e1f5ff
style M fill:#fff4e1
style C fill:#f0fff0
组件详解:
-
视觉编码器(V):
- 将高维观测(如图像)压缩为低维潜在表示
$z_t$ - 通常使用VAE(变分自编码器)实现
- 学习目标:$\mathcal{L}V = \mathbb{E}{q(z|o)}[\log p(o|z)] - \text{KL}(q(z|o) | p(z))$
- 将高维观测(如图像)压缩为低维潜在表示
-
记忆模型(M):
- 学习潜在空间的动态转移:$p(z_{t+1}, r_t | z_t, a_t, h_t)$
- 通常使用RNN/LSTM加MDN(混合密度网络)实现
- 捕获时序依赖和环境随机性
-
控制器(C):
- 根据当前状态
$(z_t, h_t)$ 选择动作$a_t$ - 可使用进化策略(ES)或强化学习训练
- 完全在世界模型内部训练,无需真实环境交互
- 根据当前状态
世界模型采用"先学习环境,再学习策略"的两阶段范式:
graph LR
subgraph Phase1["阶段1: 模型学习"]
D1[收集经验数据] --> T1[训练V编码器]
T1 --> T2[训练M动态模型]
T2 --> Model[环境模型]
end
subgraph Phase2["阶段2: 策略学习"]
Model --> Sim[在模型中想象]
Sim --> T3[训练C控制器]
T3 --> Policy[策略网络]
end
Policy -.-> |新数据| D1
style Phase1 fill:#ffe1e1
style Phase2 fill:#e1ffe1
关键特性:
- 样本效率:策略训练完全在模型内部进行,不消耗真实交互
- 可解释性:可以可视化模型的"想象"过程
- 泛化能力:学到的动态模型可迁移到新任务
- 安全探索:危险策略可在模型中提前检测
| 维度 | 传统强化学习(Model-Free) | 世界模型(Model-Based) |
|---|---|---|
| 环境交互 | 大量真实交互 | 少量真实交互+大量模拟 |
| 样本效率 | 低(百万级样本) | 高(千级样本) |
| 训练速度 | 慢(受环境限制) | 快(并行想象) |
| 计算成本 | 低(简单策略网络) | 高(复杂动态模型) |
| 适用场景 | 简单确定性环境 | 复杂随机性环境 |
| 可解释性 | 差(黑盒策略) | 好(可视化想象) |
为什么需要潜在空间?
直接在原始观测空间(如
VAE架构:
graph LR
subgraph Encoder["编码器 q(z|o)"]
O[观测 o] --> Conv[卷积层]
Conv --> Mu[均值 μ]
Conv --> Sigma[标准差 σ]
Mu --> Sample
Sigma --> Sample[重参数化采样]
end
Sample --> Z[潜在变量 z]
subgraph Decoder["解码器 p(o|z)"]
Z --> Deconv[反卷积层]
Deconv --> Recon[重建 ô]
end
O -.-> |L2损失| Recon
style Encoder fill:#e1f5ff
style Decoder fill:#fff4e1
VAE损失函数:
- 重建损失:确保潜在表示保留重要信息
- KL散度:使潜在分布接近标准正态分布,避免过拟合
-
$\beta$ 超参数:平衡重建质量和表示压缩度
现实世界环境通常具有随机性(如物理引擎的近似误差、传感器噪声)。确定性模型
混合密度网络(MDN):
MDN将下一状态建模为混合高斯分布:
其中:
-
$\pi_i$ :第$i$ 个高斯分量的权重 -
$\mu_i, \sigma_i$ :RNN输出的均值和标准差
graph TB
Input[输入: z_t, a_t, h_{t-1}] --> RNN[LSTM单元]
RNN --> H[隐状态 h_t]
H --> Pi[权重头 π_1...π_K]
H --> Mu[均值头 μ_1...μ_K]
H --> Sigma[方差头 σ_1...σ_K]
Pi --> Mix[混合高斯分布]
Mu --> Mix
Sigma --> Mix
Mix --> Sample[采样 z_{t+1}]
style RNN fill:#e1f5ff
style Mix fill:#fff4e1
MDN损失函数:
传统强化学习需要在真实环境中采样轨迹
graph TB
subgraph Real["真实环境(少量交互)"]
S0[初始状态 s_0] --> |a_0| R1[r_0, s_1]
R1 --> |a_1| R2[r_2, s_2]
end
subgraph Imagination["想象空间(大量rollout)"]
Z0[编码 z_0] --> |ã_0| I1[r̃_0, z̃_1]
I1 --> |ã_1| I2[r̃_1, z̃_2]
I2 --> |ã_2| I3[r̃_2, z̃_3]
I3 --> Dots[...]
Dots --> IN[r̃_N, z̃_N]
end
S0 -.-> |V编码| Z0
IN --> Loss[策略梯度损失]
style Real fill:#ffe1e1
style Imagination fill:#e1ffe1
想象轨迹生成伪代码:
/**
* 在世界模型中生成想象轨迹
*/
public class ImaginationRollout {
private VAE vae;
private MDNRNN mdnRnn;
private Controller controller;
public Trajectory rollout(Observation initialObs, int horizon) {
Trajectory trajectory = new Trajectory();
// 1. 编码初始观测
Tensor z = vae.encode(initialObs).sample();
Tensor h = mdnRnn.getInitialState();
// 2. 在想象空间中展开
for (int t = 0; t < horizon; t++) {
// 控制器决策
Action a = controller.forward(z, h);
// 世界模型预测
MDNOutput output = mdnRnn.forward(z, a, h);
Tensor zNext = output.sampleNextState();
float reward = output.predictReward();
h = output.getHiddenState();
// 记录轨迹
trajectory.add(z, a, reward);
z = zNext;
}
return trajectory;
}
}策略优化:
使用进化策略(CMA-ES)或策略梯度优化控制器参数:
其中轨迹
/**
* 变分自编码器(用于观测压缩)
*/
public class VAE {
private ConvEncoder encoder;
private ConvDecoder decoder;
private int latentDim = 32;
private float beta = 1.0f; // KL散度权重
/**
* 编码:观测 -> 潜在分布
*/
public Distribution encode(Tensor observation) {
// observation: [batch, 3, 64, 64]
Tensor features = encoder.forward(observation); // [batch, 256]
// 分离均值和对数方差
Tensor mu = features.slice(1, 0, latentDim); // [batch, 32]
Tensor logVar = features.slice(1, latentDim, 64); // [batch, 32]
return new GaussianDistribution(mu, logVar);
}
/**
* 解码:潜在变量 -> 重建观测
*/
public Tensor decode(Tensor z) {
// z: [batch, 32]
return decoder.forward(z); // [batch, 3, 64, 64]
}
/**
* 重参数化技巧
*/
public Tensor reparameterize(Tensor mu, Tensor logVar) {
Tensor std = logVar.mul(0.5f).exp();
Tensor eps = Tensor.randn(mu.shape());
return mu.add(std.mul(eps));
}
/**
* VAE损失函数
*/
public VAELoss computeLoss(Tensor observation, Distribution dist, Tensor reconstruction) {
// 1. 重建损失(MSE)
float reconLoss = observation.sub(reconstruction).pow(2).mean();
// 2. KL散度
Tensor mu = dist.getMean();
Tensor logVar = dist.getLogVar();
Tensor kl = mu.pow(2).add(logVar.exp()).sub(logVar).sub(1.0f);
float klLoss = kl.sum(1).mean() * 0.5f;
// 3. 总损失
float totalLoss = reconLoss + beta * klLoss;
return new VAELoss(totalLoss, reconLoss, klLoss);
}
}
/**
* 卷积编码器
*/
class ConvEncoder extends Module {
public ConvEncoder(int latentDim) {
// Conv1: [3, 64, 64] -> [32, 32, 32]
add(new Conv2d(3, 32, 4, 2, 1));
add(new ReLU());
// Conv2: [32, 32, 32] -> [64, 16, 16]
add(new Conv2d(32, 64, 4, 2, 1));
add(new ReLU());
// Conv3: [64, 16, 16] -> [128, 8, 8]
add(new Conv2d(64, 128, 4, 2, 1));
add(new ReLU());
// Conv4: [128, 8, 8] -> [256, 4, 4]
add(new Conv2d(128, 256, 4, 2, 1));
add(new ReLU());
// Flatten: [256, 4, 4] -> [4096]
// FC: [4096] -> [latentDim*2] (mu + logvar)
add(new Flatten());
add(new Linear(256 * 4 * 4, latentDim * 2));
}
}/**
* 混合密度网络RNN(用于动态预测)
*/
public class MDNRNN {
private LSTM lstm;
private Linear piHead; // 混合权重
private Linear muHead; // 均值
private Linear sigmaHead; // 标准差
private Linear rewardHead; // 奖励预测
private int hiddenSize = 256;
private int latentDim = 32;
private int actionDim = 3;
private int numMixtures = 5; // 混合高斯分量数
public MDNRNN() {
// LSTM: 输入(z_t + a_t) -> 隐状态
lstm = new LSTM(latentDim + actionDim, hiddenSize);
// MDN输出头
piHead = new Linear(hiddenSize, numMixtures);
muHead = new Linear(hiddenSize, numMixtures * latentDim);
sigmaHead = new Linear(hiddenSize, numMixtures * latentDim);
rewardHead = new Linear(hiddenSize, 1);
}
/**
* 前向传播
*/
public MDNOutput forward(Tensor z, Tensor action, Tensor hidden) {
// 1. 拼接输入
Tensor input = Tensor.cat(z, action, dim=1); // [batch, 35]
// 2. LSTM更新隐状态
LSTMOutput lstmOut = lstm.forward(input, hidden);
Tensor h = lstmOut.getHiddenState(); // [batch, 256]
// 3. MDN参数预测
Tensor pi = softmax(piHead.forward(h), dim=1); // [batch, 5]
Tensor mu = muHead.forward(h).reshape(-1, numMixtures, latentDim); // [batch, 5, 32]
Tensor sigma = exp(sigmaHead.forward(h)).reshape(-1, numMixtures, latentDim); // [batch, 5, 32]
// 4. 奖励预测
Tensor reward = rewardHead.forward(h); // [batch, 1]
return new MDNOutput(pi, mu, sigma, reward, h);
}
/**
* MDN负对数似然损失
*/
public float mdnLoss(MDNOutput output, Tensor targetZ) {
Tensor pi = output.getPi(); // [batch, 5]
Tensor mu = output.getMu(); // [batch, 5, 32]
Tensor sigma = output.getSigma(); // [batch, 5, 32]
// 扩展target维度: [batch, 32] -> [batch, 1, 32]
targetZ = targetZ.unsqueeze(1);
// 计算每个分量的概率密度
Tensor diff = targetZ.sub(mu); // [batch, 5, 32]
Tensor exponent = diff.pow(2).div(sigma.pow(2).mul(2.0f));
Tensor coeff = (float)(1.0 / Math.sqrt(2 * Math.PI));
Tensor prob = coeff * sigma.reciprocal().mul(exponent.neg().exp()); // [batch, 5, 32]
// 混合概率: Σ π_i * N(z | μ_i, σ_i)
prob = prob.prod(dim=2); // [batch, 5]
Tensor mixProb = pi.mul(prob).sum(dim=1); // [batch]
// 负对数似然
return mixProb.log().neg().mean();
}
/**
* 从MDN采样下一状态
*/
public Tensor sampleNextState(MDNOutput output) {
Tensor pi = output.getPi();
Tensor mu = output.getMu();
Tensor sigma = output.getSigma();
int batchSize = pi.shape(0);
Tensor samples = new Tensor(batchSize, latentDim);
for (int b = 0; b < batchSize; b++) {
// 根据混合权重选择分量
int k = categoricalSample(pi.getRow(b));
// 从选中的高斯分布采样
Tensor mean = mu.get(b, k);
Tensor std = sigma.get(b, k);
Tensor sample = mean.add(std.mul(Tensor.randn(latentDim)));
samples.setRow(b, sample);
}
return samples;
}
}
/**
* MDN输出
*/
class MDNOutput {
private Tensor pi; // 混合权重
private Tensor mu; // 均值
private Tensor sigma; // 标准差
private Tensor reward; // 奖励预测
private Tensor hidden; // LSTM隐状态
public Tensor sampleNextState() {
// 实现同上sampleNextState方法
}
public float predictReward() {
return reward.mean();
}
}/**
* 世界模型控制器(策略网络)
*/
public class Controller {
private Linear fc1;
private Linear fc2;
private int latentDim = 32;
private int hiddenDim = 256;
private int actionDim = 3;
public Controller() {
fc1 = new Linear(latentDim + hiddenDim, 64);
fc2 = new Linear(64, actionDim);
}
/**
* 策略前向传播
*/
public Action forward(Tensor z, Tensor h) {
// 拼接潜在状态和隐状态
Tensor state = Tensor.cat(z, h, dim=1); // [batch, 288]
// 两层全连接
Tensor x = tanh(fc1.forward(state));
Tensor action = tanh(fc2.forward(x)); // 输出范围[-1, 1]
return new Action(action);
}
/**
* 使用进化策略(CMA-ES)优化控制器
*/
public void trainWithEvolution(WorldModel worldModel, int generations) {
CMAES optimizer = new CMAES(this.getNumParameters());
for (int gen = 0; gen < generations; gen++) {
// 1. 采样参数种群
List<float[]> population = optimizer.samplePopulation();
// 2. 评估每个个体
float[] fitness = new float[population.size()];
for (int i = 0; i < population.size(); i++) {
this.setParameters(population.get(i));
fitness[i] = evaluateInWorldModel(worldModel);
}
// 3. 更新分布
optimizer.update(fitness);
System.out.printf("Generation %d: Best Fitness = %.2f\n",
gen, Arrays.stream(fitness).max().getAsDouble());
}
// 使用最优参数
this.setParameters(optimizer.getBestSolution());
}
/**
* 在世界模型中评估策略
*/
private float evaluateInWorldModel(WorldModel worldModel) {
float totalReward = 0;
int numRollouts = 16;
for (int i = 0; i < numRollouts; i++) {
Trajectory traj = worldModel.rollout(this, horizon=1000);
totalReward += traj.getTotalReward();
}
return totalReward / numRollouts;
}
}1. 样本效率极高
graph LR
subgraph Traditional["传统RL: 百万级样本"]
T1[环境步骤1] --> T2[环境步骤2]
T2 --> T3[...]
T3 --> T4[环境步骤1M]
end
subgraph WorldModel["世界模型: 千级样本"]
W1[真实交互1k步] --> W2[学习环境模型]
W2 --> W3[想象100万步]
W3 --> W4[策略优化]
end
style Traditional fill:#ffe1e1
style WorldModel fill:#e1ffe1
实验对比(CarRacing任务):
- PPO(Model-Free):需要 1000万 环境步骤
- World Models:仅需 1万 环境步骤
2. 并行加速
世界模型可在GPU上并行生成数千条想象轨迹,训练速度比真实环境快 100-1000倍。
3. 安全探索
危险动作(如机器人碰撞)可在模型中提前检测,避免真实损坏。
1. 模型误差累积
世界模型的预测误差会随时间步累积:
长期预测(如T>100步)容易偏离真实轨迹。
解决方案:
- 使用集成模型(Ensemble)估计不确定性
- 限制想象长度(如H=15步)
- 引入模型再训练(online adaptation)
2. 模型容量限制
复杂环境(如真实世界物理)难以用有限容量模型完全捕获。
解决方案:
- 使用分层世界模型(抽象层+细节层)
- 引入物理先验(如Newton定律)
- 仅建模关键因素(不建模无关背景)
3. 计算成本高
训练VAE + MDN-RNN比简单策略网络计算量大10倍以上。
权衡:
- 训练慢,但推理快(适合需要长期规划的任务)
- 初期投入大,后期收益高(模型可复用)
场景:机械臂抓取任务
- 传统方法:需要数千次真实抓取尝试
- 世界模型:100次真实尝试+百万次想象,学习成功抓取策略
场景:危险场景应对(如突然闯入行人)
- 在模拟器中无法覆盖所有长尾场景
- 世界模型学习真实交通规律,生成罕见但重要的测试用例
场景:Atari游戏、赛车游戏
- 学习游戏物理规律(如车辆动力学)
- 在想象中尝试策略,避免频繁重启游戏
| 模块 | 时间复杂度 | 空间复杂度 | 瓶颈 |
|---|---|---|---|
| VAE编码 | 卷积计算 | ||
| MDN-RNN | 序列长度 | ||
| 想象Rollout | 并行度 |
其中:
-
$C, H, W$ :图像通道、高、宽 -
$d$ :潜在维度 -
$h$ :RNN隐状态维度 -
$T$ :训练序列长度 -
$N$ :并行rollout数 -
$H$ :想象horizon
/**
* 世界模型推理优化
*/
public class OptimizedWorldModel {
/**
* 批量想象(并行加速)
*/
public List<Trajectory> batchRollout(List<Observation> initialStates, int horizon) {
int batchSize = initialStates.size();
// 1. 批量编码
Tensor zBatch = vae.encodeBatch(initialStates); // [batch, 32]
Tensor hBatch = mdnRnn.getInitialStateBatch(batchSize); // [batch, 256]
List<Trajectory> trajectories = new ArrayList<>();
// 2. 批量展开
for (int t = 0; t < horizon; t++) {
// 批量决策
Tensor actions = controller.forwardBatch(zBatch, hBatch);
// 批量预测
MDNOutput output = mdnRnn.forwardBatch(zBatch, actions, hBatch);
zBatch = output.sampleNextStateBatch();
hBatch = output.getHiddenStateBatch();
// 记录轨迹(省略)
}
return trajectories;
}
/**
* 模型集成(降低不确定性)
*/
public class EnsembleWorldModel {
private List<MDNRNN> models; // 5个独立训练的模型
public Tensor predictWithUncertainty(Tensor z, Tensor a, Tensor h) {
List<Tensor> predictions = new ArrayList<>();
for (MDNRNN model : models) {
predictions.add(model.forward(z, a, h).sampleNextState());
}
// 均值作为预测,方差作为不确定性
Tensor mean = Tensor.stack(predictions).mean(dim=0);
Tensor variance = Tensor.stack(predictions).var(dim=0);
return new UncertaintyPrediction(mean, variance);
}
}
}本节系统介绍了世界模型的核心原理:
-
三组件架构:
- 视觉编码器(V):观测压缩
- 记忆模型(M):动态预测
- 控制器(C):策略决策
-
核心技术:
- VAE潜在表示学习
- MDN-RNN随机动态建模
- 想象中的策略训练
-
主要优势:
- 样本效率高(千级vs百万级)
- 并行加速(GPU想象)
- 安全探索(无真实风险)
-
实际挑战:
- 模型误差累积
- 容量限制
- 计算成本
// 世界模型完整流程
VAE vae = new VAE(latentDim=32);
MDNRNN mdnRnn = new MDNRNN(hiddenSize=256, numMixtures=5);
Controller controller = new Controller();
// 阶段1: 学习环境模型
trainVAE(vae, observations);
trainMDNRNN(mdnRnn, latentTrajectories);
// 阶段2: 想象中训练策略
controller.trainWithEvolution(new WorldModel(vae, mdnRnn), generations=300);世界模型为具身智能提供了"大脑中模拟"的能力,是实现样本高效学习的关键技术。下一节将深入探讨VAE编码器的实现细节。
-
为什么世界模型使用VAE而不是普通自编码器?KL散度项的作用是什么?
-
MDN-RNN如何处理环境随机性?如果去掉混合密度网络,只用单高斯输出会有什么问题?
-
世界模型的"想象轨迹"与真实轨迹的差异如何影响策略学习?如何量化这种差异?
-
在什么场景下世界模型不适用?与Model-Free方法相比的劣势是什么?
-
如何设计实验验证世界模型学到了正确的环境动态?
-
World Models (Ha & Schmidhuber, 2018)
经典论文,提出V-M-C三组件架构 -
Dreamer系列 (Hafner et al., 2019-2023)
现代世界模型方法,引入Transformer和潜在动态 -
MBPO (Janner et al., 2019)
模型误差分析与短期想象策略 -
PlaNet & Dreamer (Hafner et al., 2019)
纯视觉输入的世界模型控制