-
Notifications
You must be signed in to change notification settings - Fork 21
Expand file tree
/
Copy pathpolicy_config.py
More file actions
executable file
·99 lines (88 loc) · 4.33 KB
/
policy_config.py
File metadata and controls
executable file
·99 lines (88 loc) · 4.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import logging
import os
import pathlib
from typing import Any
import jax.numpy as jnp
import openpi.models.model as _model
import openpi.policies.policy as _policy
import openpi.shared.download as download
from openpi.training import checkpoints as _checkpoints
from openpi.training import config as _config
import openpi.transforms as transforms
def create_trained_policy(
    train_config: _config.TrainConfig,
    checkpoint_dir: pathlib.Path | str,
    *,
    repack_transforms: transforms.Group | None = None,
    sample_kwargs: dict[str, Any] | None = None,
    default_prompt: str | None = None,
    norm_stats: dict[str, transforms.NormStats] | None = None,
    pytorch_device: str | None = None,
) -> _policy.Policy:
    """Create a policy from a trained checkpoint.

    Args:
        train_config: The training config to use to create the model.
        checkpoint_dir: The directory to load the model from.
        repack_transforms: Optional transforms that will be applied before any other transforms.
        sample_kwargs: The kwargs to pass to the `sample_actions` method. If not provided, the default
            kwargs will be used.
        default_prompt: The default prompt to use for the policy. Will inject the prompt into the input
            data if it doesn't already exist.
        norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded
            from the checkpoint directory.
        pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda", "cuda:0").
            If None and is_pytorch=True, will use "cuda" if available, otherwise "cpu".

    Raises:
        ValueError: If norm stats must be loaded from the checkpoint but the data config
            has no asset id.

    Note:
        The function automatically detects whether the model is PyTorch-based by checking for the
        presence of "model.safetensors" in the checkpoint directory.
    """
    repack_transforms = repack_transforms or transforms.Group()
    checkpoint_dir = download.maybe_download(str(checkpoint_dir))

    # A PyTorch checkpoint ships its weights as "model.safetensors"; a JAX
    # checkpoint stores its params under the "params" subdirectory instead.
    weight_path = os.path.join(checkpoint_dir, "model.safetensors")
    is_pytorch = os.path.exists(weight_path)

    logging.info("Loading model...")
    if is_pytorch:
        model = train_config.model.load_pytorch(train_config, weight_path)
        # Cast selected parameters to bfloat16 to reduce memory footprint.
        model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16")
    else:
        model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))

    data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
    if norm_stats is None:
        # We are loading the norm stats from the checkpoint instead of the config assets dir to make sure
        # that the policy is using the same normalization stats as the original training process.
        if data_config.asset_id is None:
            raise ValueError("Asset id is required to load norm stats.")
        norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / "assets", data_config.asset_id)

    # Resolve the device for PyTorch models lazily: torch may not be installed
    # in a JAX-only deployment, so fall back to CPU if the import fails.
    if is_pytorch and pytorch_device is None:
        try:
            import torch

            pytorch_device = "cuda" if torch.cuda.is_available() else "cpu"
        except ImportError:
            pytorch_device = "cpu"

    return _policy.Policy(
        model,
        transforms=[
            *repack_transforms.inputs,
            transforms.InjectDefaultPrompt(default_prompt),
            *data_config.data_transforms.inputs,
            transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
            *data_config.model_transforms.inputs,
        ],
        # Outputs are unnormalized in the reverse order of the input pipeline.
        output_transforms=[
            *data_config.model_transforms.outputs,
            transforms.Unnormalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
            *data_config.data_transforms.outputs,
            *repack_transforms.outputs,
        ],
        sample_kwargs=sample_kwargs,
        metadata=train_config.policy_metadata,
        is_pytorch=is_pytorch,
        pytorch_device=pytorch_device if is_pytorch else None,
        # Wire RTC normalization params from the loaded checkpoint so served prev_action_chunk
        # is normalized into model space before reaching Pi0RTC.sample_actions().
        norm_stats=norm_stats,
        use_quantile_norm=data_config.use_quantile_norm,
        action_horizon=train_config.model.action_horizon,
    )