diff --git a/.gitignore b/.gitignore
index cddc84d..a7dd2c4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,7 @@
 .ipynb_checkpoints
 .*.swp
 *.pyc
+__pycache__
+.vscode/
+test.py
+venv/
\ No newline at end of file
diff --git a/muzero/model_torch.py b/muzero/model_torch.py
new file mode 100644
index 0000000..713b458
--- /dev/null
+++ b/muzero/model_torch.py
@@ -0,0 +1,274 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import numpy as np
+from collections import namedtuple
+from typing import Tuple, List, Union
+
+
+def set_seed(seed: int = 42) -> None:
+    """Seed the torch and numpy random number generators."""
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+
+
+def bstack(bb: List[List[Union[float, np.ndarray]]]) -> List[np.ndarray]:
+    """Stack a batch of samples field by field.
+
+    ``bb`` holds one sequence per sample, all of the same length. The result
+    has one array per field, each reshaped to ``(len(bb), -1)``.
+    """
+    batch_size, field_count = len(bb), len(bb[0])
+    return [np.array([sample[j] for sample in bb]).reshape(batch_size, -1)
+            for j in range(field_count)]
+
+
+def to_one_hot(a: np.ndarray, K: int, a_dim: int) -> List[np.ndarray]:
+    """One-hot encode ``K`` action indices as ``K`` vectors of length ``a_dim``.
+
+    Negative entries of ``a`` are skipped and stay all-zero vectors.
+    """
+    one_hot_action = np.zeros(K * a_dim)
+    valid = a >= 0
+    # Start of each step's slot in the flat buffer, for valid steps only.
+    offsets = np.arange(0, K * a_dim, a_dim)[valid]
+    one_hot_action[a[valid] + offsets] = 1
+    return np.split(one_hot_action, K)
+
+
+def reformat_batch(batch: np.ndarray, K: int, a_dim: int, remove_policy=False
+                   ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
+    """Turn ``(observation, actions, outs)`` samples into inputs and targets.
+
+    Inputs are ``[observations, a_1, ..., a_K]`` with one-hot actions. Targets
+    are the flattened ``outs`` of every sample, stacked per field; ``outs`` is
+    presumably a list of ``(value, reward, policy)`` triples -- confirm against
+    the replay buffer. The field at index 1 is always dropped; with
+    ``remove_policy`` the fields at indices 2, 5, 8, ... are dropped as well.
+    """
+    X, Y = [], []
+    for o, a, outs in batch:
+        X.append([o] + to_one_hot(np.array(a), K, a_dim))
+        # Flatten the per-step tuples into one flat list of fields.
+        Y.append([item for sublist in outs for item in sublist])
+
+    X, Y = bstack(X), bstack(Y)
+
+    if remove_policy:
+        kept = [Y[0]]
+        for i in range(3, len(Y), 3):
+            kept += [Y[i], Y[i + 1]]
+        Y = kept
+    else:
+        Y.pop(1)
+
+    return X, Y
+
+
+def _hidden_stack(in_dim: int, hidden_layer_dim: int,
+                  hidden_layer_count: int) -> nn.Sequential:
+    """Build ``Linear -> ELU`` followed by ``hidden_layer_count`` more pairs."""
+    layers = [nn.Linear(in_dim, hidden_layer_dim), nn.ELU()]
+    for _ in range(hidden_layer_count):
+        # One fresh Linear per layer: repeating a list with ``*`` reuses a
+        # single module instance, which ties the weights of all hidden layers.
+        layers += [nn.Linear(hidden_layer_dim, hidden_layer_dim), nn.ELU()]
+    return nn.Sequential(*layers)
+
+
+class DenseRepresentation(nn.Module):
+    """h network: maps an observation to a hidden state of size ``s_dim``."""
+
+    def __init__(self, o_dim: int, s_dim: int, hidden_layer_dim: int,
+                 hidden_layer_count: int) -> None:
+        super().__init__()
+        self.linearReluStack = _hidden_stack(
+            o_dim, hidden_layer_dim, hidden_layer_count)
+        self.out = nn.Linear(hidden_layer_dim, s_dim)
+
+    def forward(self, state: torch.Tensor) -> torch.Tensor:
+        return self.out(self.linearReluStack(state))
+
+
+class DenseDynamics(nn.Module):
+    """g network: maps (state, action) to (reward, next state)."""
+
+    def __init__(self, s_dim: int, a_dim: int, hidden_layer_dim: int,
+                 hidden_layer_count: int) -> None:
+        super().__init__()
+        self.linearReluStack = _hidden_stack(
+            s_dim + a_dim, hidden_layer_dim, hidden_layer_count)
+        self.out1 = nn.Linear(hidden_layer_dim, 1)
+        self.out2 = nn.Linear(hidden_layer_dim, s_dim)
+
+    def forward(self, state: torch.Tensor,
+                action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Concatenate along the feature axis; for 1-D and 2-D inputs this is
+        # the same result as concatenating the transposes and transposing back.
+        x = self.linearReluStack(torch.cat([state, action], dim=-1))
+        return self.out1(x), self.out2(x)
+
+
+class DensePrediction(nn.Module):
+    """f network: maps a state to (policy logits, value), or value only."""
+
+    def __init__(self, s_dim: int, a_dim: int, hidden_layer_dim: int,
+                 hidden_layer_count: int, with_policy: bool = True) -> None:
+        super().__init__()
+        self.with_policy = with_policy
+        self.linearReluStack = _hidden_stack(
+            s_dim, hidden_layer_dim, hidden_layer_count)
+        self.out1 = nn.Linear(hidden_layer_dim, a_dim)
+        self.out2 = nn.Linear(hidden_layer_dim, 1)
+
+    def forward(self, state: torch.Tensor
+                ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
+        x = self.linearReluStack(state)
+        if self.with_policy:
+            return self.out1(x), self.out2(x)
+        return self.out2(x)
+
+
+class MuModel:
+    """Representation (h), dynamics (g) and prediction (f) nets plus training.
+
+    The three networks share one Adam optimizer. ``losses`` collects, per
+    training step, the total loss followed by every individual loss term.
+    """
+
+    LAYER_COUNT = 4
+    LAYER_DIM = 128
+    BN = False
+
+    def __init__(self, observation_dim: Tuple[int, ...], action_dim: int,
+                 s_dim: int = 8, K: int = 5, lr: float = 1e-3,
+                 with_policy: bool = True, device='cpu') -> None:
+        self.observation_dim, self.action_dim = observation_dim, action_dim
+        self.s_dim, self.K, self.lr = s_dim, K, lr
+        self.with_policy, self.device = with_policy, device
+
+        self.h = DenseRepresentation(
+            o_dim=observation_dim[0], s_dim=s_dim,
+            hidden_layer_dim=self.LAYER_DIM,
+            hidden_layer_count=self.LAYER_COUNT).to(device)
+        self.g = DenseDynamics(
+            s_dim=s_dim, a_dim=action_dim,
+            hidden_layer_dim=self.LAYER_DIM,
+            hidden_layer_count=self.LAYER_COUNT).to(device)
+        self.f = DensePrediction(
+            s_dim=s_dim, a_dim=action_dim,
+            hidden_layer_dim=self.LAYER_DIM,
+            hidden_layer_count=self.LAYER_COUNT,
+            with_policy=with_policy).to(device)
+
+        params = [p for net in (self.h, self.g, self.f)
+                  for p in net.parameters()]
+        self.optimizer = optim.Adam(params, lr=self.lr)
+        self.losses = []
+
+        # make class compatible with Geohot's other code
+        self.o_dim = self.observation_dim
+        self.a_dim = self.action_dim
+        Mu = namedtuple('mu', 'predict')
+        self.mu = Mu(self.predict)
+
+    def _as_tensor(self, array: Union[np.ndarray, torch.Tensor]
+                   ) -> torch.Tensor:
+        """Convert a numpy array to a float32 tensor on ``self.device``.
+
+        Tensors are returned unchanged.
+        """
+        if torch.is_tensor(array):
+            return array
+        return torch.from_numpy(array.astype(np.float32)).to(self.device)
+
+    def forward(self, X: List[np.ndarray],
+                train: bool = True) -> List[torch.Tensor]:
+        """Unroll the model for ``K`` steps.
+
+        ``X`` is ``[observations, a_1, ..., a_K]`` as numpy arrays. Returns
+        ``[v_0, p_0, v_1, r_1, p_1, ...]``; the ``p_k`` entries are omitted
+        when ``with_policy`` is false.
+        """
+        for net in (self.h, self.g, self.f):
+            net.train(train)
+
+        inputs = [torch.from_numpy(x.astype(np.float32)).to(self.device)
+                  for x in X]
+        Y_pred = []
+
+        state = self.h(inputs[0])
+        if self.with_policy:
+            policy, value = self.f(state)
+            Y_pred += [value, policy]
+        else:
+            Y_pred.append(self.f(state))
+
+        for k in range(self.K):
+            reward, state = self.g(state, inputs[k + 1])
+            # Predict from the state reached *after* the action; predicting
+            # from the previous state would repeat the preceding outputs.
+            if self.with_policy:
+                policy, value = self.f(state)
+                Y_pred += [value, reward, policy]
+            else:
+                Y_pred += [self.f(state), reward]
+
+        return Y_pred
+
+    def predict(self, X: List[np.ndarray]) -> List[torch.Tensor]:
+        """Run ``forward`` in eval mode without tracking gradients."""
+        with torch.no_grad():
+            return self.forward(X, train=False)
+
+    def train(self, batch: List[np.ndarray]) -> None:
+        """Run one optimizer step on ``batch`` and record the losses."""
+        mse, smcel = F.mse_loss, nn.BCEWithLogitsLoss()
+
+        X, Y = reformat_batch(
+            batch, self.K, self.action_dim, not self.with_policy)
+        Y = [torch.from_numpy(y.astype(np.float32)).to(self.device) for y in Y]
+        Y_pred = self.forward(X, train=True)
+
+        losses = [mse(Y_pred[0], Y[0])]
+        if self.with_policy:
+            losses.append(smcel(Y_pred[1], Y[1]))
+
+        # Every unrolled step adds (value, reward[, policy]), so both the
+        # first index and the stride depend on whether policies are present.
+        first = 2 if self.with_policy else 1
+        stride = 3 if self.with_policy else 2
+        for k in range(self.K):
+            i = first + stride * k
+            losses.append(mse(Y_pred[i], Y[i]))
+            losses.append(mse(Y_pred[i + 1], Y[i + 1]))
+            if self.with_policy:
+                losses.append(smcel(Y_pred[i + 2], Y[i + 2]))
+
+        loss = sum(losses)
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+
+        self.losses.append([loss.item()] + [l.item() for l in losses])
+
+    def ht(self, state: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+        """Encode an observation with ``h`` without tracking gradients."""
+        with torch.no_grad():
+            return self.h(self._as_tensor(state))
+
+    def ft(self, state: Union[np.ndarray, torch.Tensor]
+           ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
+        """Evaluate ``f`` on a hidden state without tracking gradients.
+
+        Returns ``(exp(policy logits), value)``, or only ``value`` when
+        ``with_policy`` is false.
+        """
+        with torch.no_grad():
+            if not self.with_policy:
+                return self.f(self._as_tensor(state))
+            policy, value = self.f(self._as_tensor(state))
+            # NOTE(review): exp of the logits is not normalized; callers
+            # presumably normalize it themselves -- confirm.
+            return policy.exp(), value
diff --git a/muzero_torch_cartpole_v3.ipynb b/muzero_torch_cartpole_v3.ipynb
new file mode 100644
index 0000000..f0e9821
--- /dev/null
+++ b/muzero_torch_cartpole_v3.ipynb
@@ -0,0 +1,442 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "869c49bb-9b92-404e-9a93-2b5f6211a5f7",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Populating the interactive namespace from numpy and matplotlib\n"
+     ]
+    }
+   ],
+   "source": [
+    "%pylab inline\n",
+    "import gym\n",
+    "import collections\n",
+    "import random\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "from muzero import model_torch\n",
+    "from muzero import game as Game\n",
+    "from muzero.mcts import naive_search"
+   ]
+  },
+  {
+   
"cell_type": "code", + "execution_count": 2, + "id": "aee9b919-164a-4597-9d8f-93909ec87830", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.9.0+cu102\n" + ] + } + ], + "source": [ + "print(model_torch.torch.__version__)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b0e95d29-c29c-4a69-ac7f-f62c3e585d5a", + "metadata": {}, + "outputs": [], + "source": [ + "env = gym.make('CartPole-v0')" + ] + }, + { + "cell_type": "markdown", + "id": "53f1dd1a-bdee-45e0-b47a-290ae1edb033", + "metadata": {}, + "source": [ + "Set Seed for Reproducibility" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a3326cf7-6501-44f8-8b9b-5c801da203e0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[42]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "seed = 42\n", + "random.seed(seed)\n", + "np.random.seed(seed)\n", + "model_torch.set_seed(seed)\n", + "Game.random.seed(seed)\n", + "Game.np.random.seed(seed)\n", + "env.seed(seed)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "be635d82-c58a-4e02-9cd3-95c110b54b84", + "metadata": {}, + "outputs": [], + "source": [ + "def play_game(env, m):\n", + " game = Game.Game(env, discount=0.997)\n", + " while not game.terminal():\n", + " cc = random.random()\n", + " if (cc < 0.05):\n", + " policy = [1 / m.action_dim] * m.action_dim\n", + " else:\n", + " policy = naive_search(m, game.observation, T=1)\n", + " game.act_with_policy(policy)\n", + " return game" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "dcb7eb2b-c6e1-4972-a8da-308f22ab6ef3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(4,) 2\n" + ] + } + ], + "source": [ + "device = 'cpu'\n", + "m = model_torch.MuModel(env.observation_space.shape, env.action_space.n, s_dim=128, K=3, lr=1e-3, device=device)\n", + "replay_buffer = 
Game.ReplayBuffer(50, 128, m.K)\n", + "print(env.observation_space.shape, env.action_space.n)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3b6106d5-095b-41e7-a5c5-40d1fecda75e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "return: 29, action_count: 0 ( 16) 1 ( 13), loss: 252.86\n", + "return: 14, action_count: 0 ( 11) 1 ( 3), loss: 124.44\n", + "return: 16, action_count: 0 ( 6) 1 ( 10), loss: 143.55\n", + "return: 25, action_count: 0 ( 12) 1 ( 13), loss: 173.56\n", + "return: 11, action_count: 0 ( 2) 1 ( 9), loss: 210.48\n", + "return: 45, action_count: 0 ( 22) 1 ( 23), loss: 277.84\n", + "return: 18, action_count: 0 ( 10) 1 ( 8), loss: 203.85\n", + "return: 30, action_count: 0 ( 15) 1 ( 15), loss: 238.98\n", + "return: 16, action_count: 0 ( 6) 1 ( 10), loss: 94.64\n", + "return: 126, action_count: 0 ( 63) 1 ( 63), loss: 1131.46\n", + "return: 40, action_count: 0 ( 21) 1 ( 19), loss: 926.59\n", + "return: 41, action_count: 0 ( 20) 1 ( 21), loss: 1273.09\n", + "return: 200, action_count: 0 ( 97) 1 (103), loss: 2598.50\n", + "return: 194, action_count: 0 ( 96) 1 ( 98), loss: 4195.48\n", + "return: 41, action_count: 0 ( 24) 1 ( 17), loss: 4343.98\n", + "return: 81, action_count: 0 ( 39) 1 ( 42), loss: 3017.06\n", + "return: 54, action_count: 0 ( 27) 1 ( 27), loss: 4378.75\n", + "return: 53, action_count: 0 ( 33) 1 ( 20), loss: 2689.64\n", + "return: 41, action_count: 0 ( 23) 1 ( 18), loss: 3317.25\n", + "return: 48, action_count: 0 ( 27) 1 ( 21), loss: 2766.92\n", + "return: 81, action_count: 0 ( 42) 1 ( 39), loss: 2947.62\n", + "return: 200, action_count: 0 ( 99) 1 (101), loss: 3802.24\n", + "return: 59, action_count: 0 ( 32) 1 ( 27), loss: 3446.98\n", + "return: 100, action_count: 0 ( 46) 1 ( 54), loss: 3019.74\n", + "return: 200, action_count: 0 (102) 1 ( 98), loss: 3470.03\n", + "return: 200, action_count: 0 (101) 1 ( 99), loss: 3593.76\n", + "return: 200, action_count: 0 (102) 1 ( 98), 
loss: 3797.45\n", + "return: 200, action_count: 0 (100) 1 (100), loss: 3885.07\n", + "return: 200, action_count: 0 (101) 1 ( 99), loss: 4570.01\n", + "return: 200, action_count: 0 (103) 1 ( 97), loss: 4062.75\n" + ] + } + ], + "source": [ + "rews = []\n", + "\n", + "fmt = 'return: %3u, action_count: 0 (%3u) 1 (%3u), loss: %8.2f'\n", + "for j in range(30):\n", + " game = play_game(env, m)\n", + " replay_buffer.save_game(game)\n", + " for i in range(20):\n", + " m.train(replay_buffer.sample_batch())\n", + " rew = sum(game.rewards)\n", + " rews.append(rew)\n", + " history = np.array(game.history)\n", + " actions = [(history == 0).sum(), (history == 1).sum()]\n", + " print(fmt % (rew, *actions, m.losses[-1][0]))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "29cab57a-ebf3-4fa8-8ed6-852df7455568", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot(rews)\n", + "figure()\n", + "plt.yscale('log')\n", + "plot([x[0] for x in m.losses])\n", + "plot([x[1] for x in m.losses])\n", + "plot([x[-3] for x in m.losses])" + ] + }, + { + "cell_type": "markdown", + "id": "80a5e8ee-f19f-4baa-a150-0de278cbae27", + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, + "source": [ + "can act?" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "52850a94-9d85-47c9-bb24-15d9d029281a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "action: 1 value: 32.34, reward: 1, policy: [1.9e-15 1.0e+00]\n", + "action: 1 value: 45.02, reward: 1, policy: [9.0e-16 1.0e+00]\n", + "action: 0 value: 44.56, reward: 1, policy: [1.0e+00 2.6e-14]\n", + "action: 1 value: 48.30, reward: 1, policy: [2.7e-16 1.0e+00]\n", + "action: 0 value: 38.88, reward: 1, policy: [1.0e+00 3.8e-15]\n", + "action: 1 value: 50.63, reward: 1, policy: [2.8e-02 9.7e-01]\n", + "action: 0 value: 31.59, reward: 1, policy: [1.0e+00 4.9e-13]\n", + "action: 0 value: 46.69, reward: 1, policy: [1.0e+00 1.2e-14]\n", + "action: 1 value: 53.09, reward: 1, policy: [1.4e-13 1.0e+00]\n", + "action: 0 value: 38.67, reward: 1, policy: [1.0e+00 2.1e-15]\n", + "action: 0 value: 52.90, reward: 1, policy: [1.0e+00 7.7e-13]\n", + "action: 1 value: 57.10, reward: 1, policy: [9.0e-14 1.0e+00]\n", + "action: 0 value: 44.46, reward: 1, policy: [1.0e+00 1.2e-15]\n", + "action: 0 value: 57.76, reward: 1, policy: [1.0e+00 4.1e-08]\n", + "action: 1 value: 61.95, reward: 1, policy: [1.3e-11 1.0e+00]\n", + "action: 0 value: 49.79, reward: 1, policy: [1.0e+00 1.9e-15]\n", + "action: 0 value: 61.80, reward: 1, policy: [1.0e+00 1.9e-03]\n", + "action: 1 value: 66.00, reward: 1, policy: [1.3e-08 1.0e+00]\n", + "action: 0 value: 54.68, reward: 1, policy: [1.0e+00 3.8e-14]\n", + "action: 1 value: 65.05, 
reward: 1, policy: [2.7e-02 9.7e-01]\n", + "action: 0 value: 43.98, reward: 1, policy: [1.0e+00 6.2e-16]\n", + "action: 0 value: 57.79, reward: 1, policy: [1.0e+00 3.7e-12]\n", + "action: 1 value: 66.26, reward: 1, policy: [6.5e-04 1.0e+00]\n", + "action: 0 value: 48.09, reward: 1, policy: [1.0e+00 1.1e-15]\n", + "action: 0 value: 60.51, reward: 1, policy: [1.0e+00 1.1e-08]\n", + "action: 1 value: 65.65, reward: 1, policy: [4.9e-06 1.0e+00]\n", + "action: 0 value: 52.98, reward: 1, policy: [1.0e+00 3.7e-14]\n", + "action: 0 value: 62.41, reward: 1, policy: [8.1e-01 1.9e-01]\n", + "action: 1 value: 61.66, reward: 1, policy: [5.7e-10 1.0e+00]\n", + "action: 0 value: 58.33, reward: 1, policy: [1.0e+00 4.6e-09]\n", + "action: 1 value: 61.52, reward: 1, policy: [2.8e-07 1.0e+00]\n", + "action: 0 value: 53.30, reward: 1, policy: [1.0e+00 6.0e-13]\n", + "action: 1 value: 60.06, reward: 1, policy: [1.3e-03 1.0e+00]\n", + "action: 0 value: 47.97, reward: 1, policy: [1.0e+00 1.1e-14]\n", + "action: 0 value: 57.54, reward: 1, policy: [1.0e+00 2.0e-05]\n", + "action: 1 value: 55.00, reward: 1, policy: [8.3e-12 1.0e+00]\n", + "action: 0 value: 55.50, reward: 1, policy: [1.0e+00 9.3e-09]\n", + "action: 1 value: 55.33, reward: 1, policy: [4.6e-10 1.0e+00]\n", + "action: 0 value: 53.47, reward: 1, policy: [1.0e+00 1.3e-10]\n", + "action: 1 value: 55.26, reward: 1, policy: [3.6e-08 1.0e+00]\n", + "action: 0 value: 51.62, reward: 1, policy: [1.0e+00 1.5e-11]\n", + "action: 1 value: 54.96, reward: 1, policy: [2.9e-06 1.0e+00]\n", + "action: 0 value: 50.00, reward: 1, policy: [1.0e+00 4.5e-12]\n", + "action: 1 value: 54.53, reward: 1, policy: [2.0e-04 1.0e+00]\n", + "action: 0 value: 48.61, reward: 1, policy: [1.0e+00 1.8e-12]\n", + "action: 1 value: 54.02, reward: 1, policy: [1.9e-02 9.8e-01]\n", + "action: 0 value: 47.44, reward: 1, policy: [1.0e+00 8.4e-13]\n", + "action: 0 value: 53.47, reward: 1, policy: [3.4e-01 6.6e-01]\n", + "action: 1 value: 43.20, reward: 1, policy: [6.1e-13 
1.0e+00]\n", + "action: 1 value: 53.70, reward: 1, policy: [3.2e-03 1.0e+00]\n", + "action: 0 value: 47.63, reward: 1, policy: [1.0e+00 4.2e-12]\n", + "action: 1 value: 53.27, reward: 1, policy: [9.4e-04 1.0e+00]\n", + "action: 0 value: 47.52, reward: 1, policy: [1.0e+00 8.2e-12]\n", + "action: 1 value: 52.80, reward: 1, policy: [9.9e-05 1.0e+00]\n", + "action: 0 value: 47.51, reward: 1, policy: [1.0e+00 2.1e-11]\n", + "action: 1 value: 52.29, reward: 1, policy: [8.4e-06 1.0e+00]\n", + "action: 0 value: 47.58, reward: 1, policy: [1.0e+00 6.7e-11]\n", + "action: 1 value: 51.72, reward: 1, policy: [4.6e-07 1.0e+00]\n", + "action: 0 value: 47.73, reward: 1, policy: [1.0e+00 2.8e-10]\n", + "action: 1 value: 51.06, reward: 1, policy: [2.3e-08 1.0e+00]\n", + "action: 0 value: 47.92, reward: 1, policy: [1.0e+00 1.5e-09]\n", + "action: 1 value: 50.28, reward: 1, policy: [1.6e-09 1.0e+00]\n", + "action: 0 value: 48.14, reward: 1, policy: [1.0e+00 1.0e-08]\n", + "action: 1 value: 49.33, reward: 1, policy: [1.6e-10 1.0e+00]\n", + "action: 0 value: 48.36, reward: 1, policy: [1.0e+00 1.9e-07]\n", + "action: 1 value: 48.17, reward: 1, policy: [2.7e-11 1.0e+00]\n", + "action: 0 value: 48.55, reward: 1, policy: [1.0e+00 6.5e-06]\n", + "action: 1 value: 46.79, reward: 1, policy: [6.1e-12 1.0e+00]\n", + "action: 0 value: 48.68, reward: 1, policy: [1.0e+00 2.0e-04]\n", + "action: 1 value: 45.17, reward: 1, policy: [1.7e-12 1.0e+00]\n", + "action: 0 value: 48.69, reward: 1, policy: [9.9e-01 8.3e-03]\n", + "action: 1 value: 43.32, reward: 1, policy: [9.8e-13 1.0e+00]\n", + "action: 1 value: 48.53, reward: 1, policy: [2.7e-01 7.3e-01]\n", + "action: 0 value: 39.95, reward: 1, policy: [1.0e+00 2.8e-12]\n", + "action: 1 value: 47.45, reward: 1, policy: [3.6e-03 1.0e+00]\n", + "action: 0 value: 39.62, reward: 1, policy: [1.0e+00 8.7e-12]\n", + "action: 1 value: 46.10, reward: 1, policy: [3.0e-05 1.0e+00]\n", + "action: 0 value: 39.01, reward: 1, policy: [1.0e+00 1.7e-11]\n", + "action: 1 
value: 44.50, reward: 1, policy: [1.9e-07 1.0e+00]\n", + "action: 0 value: 38.12, reward: 1, policy: [1.0e+00 9.5e-11]\n", + "action: 1 value: 42.66, reward: 1, policy: [7.9e-10 1.0e+00]\n", + "action: 0 value: 36.95, reward: 1, policy: [1.0e+00 2.5e-09]\n", + "action: 1 value: 40.63, reward: 1, policy: [2.6e-13 1.0e+00]\n", + "action: 0 value: 35.56, reward: 1, policy: [1.0e+00 1.6e-08]\n", + "action: 1 value: 38.46, reward: 1, policy: [1.2e-17 1.0e+00]\n", + "action: 0 value: 33.97, reward: 1, policy: [1.0e+00 3.7e-08]\n", + "action: 1 value: 36.28, reward: 1, policy: [6.8e-15 1.0e+00]\n", + "action: 0 value: 32.18, reward: 1, policy: [1.0e+00 2.1e-08]\n", + "action: 1 value: 34.17, reward: 1, policy: [2.3e-13 1.0e+00]\n", + "action: 0 value: 30.28, reward: 1, policy: [1.0e+00 2.4e-09]\n", + "action: 1 value: 32.20, reward: 1, policy: [4.9e-12 1.0e+00]\n", + "action: 0 value: 28.18, reward: 1, policy: [1.0e+00 1.7e-10]\n", + "action: 1 value: 30.41, reward: 1, policy: [1.5e-10 1.0e+00]\n", + "action: 0 value: 25.75, reward: 1, policy: [1.0e+00 1.9e-11]\n", + "action: 1 value: 28.79, reward: 1, policy: [1.0e-07 1.0e+00]\n", + "action: 0 value: 23.17, reward: 1, policy: [1.0e+00 5.4e-14]\n", + "action: 1 value: 27.05, reward: 1, policy: [1.8e-02 9.8e-01]\n", + "action: 0 value: 20.34, reward: 1, policy: [1.0e+00 1.9e-15]\n", + "action: 0 value: 24.93, reward: 1, policy: [1.0e+00 5.5e-08]\n", + "action: 1 value: 23.81, reward: 1, policy: [1.5e-13 1.0e+00]\n", + "action: 0 value: 22.97, reward: 1, policy: [1.0e+00 9.3e-12]\n", + "action: 1 value: 22.90, reward: 1, policy: [1.0e-12 1.0e+00]\n", + "action: 0 value: 20.86, reward: 1, policy: [1.0e+00 2.4e-14]\n", + "action: 1 value: 21.72, reward: 1, policy: [5.6e-11 1.0e+00]\n", + "action: 0 value: 18.62, reward: 1, policy: [1.0e+00 3.7e-16]\n", + "action: 0 value: 20.24, reward: 1, policy: [3.9e-01 6.1e-01]\n", + "action: 1 value: 16.81, reward: 1, policy: [4.2e-17 1.0e+00]\n", + "action: 0 value: 19.00, reward: 1, 
policy: [1.0e+00 1.1e-11]\n", + "action: 1 value: 16.07, reward: 1, policy: [3.8e-17 1.0e+00]\n", + "action: 0 value: 17.66, reward: 1, policy: [1.0e+00 7.2e-16]\n", + "action: 1 value: 15.19, reward: 1, policy: [4.5e-16 1.0e+00]\n", + "action: 0 value: 16.26, reward: 1, policy: [1.0e+00 1.7e-17]\n", + "action: 1 value: 14.24, reward: 1, policy: [2.2e-13 1.0e+00]\n", + "action: 0 value: 14.82, reward: 1, policy: [1.0e+00 4.0e-16]\n", + "action: 1 value: 13.26, reward: 1, policy: [5.3e-09 1.0e+00]\n", + "action: 0 value: 13.47, reward: 1, policy: [1.0e+00 4.9e-15]\n", + "action: 1 value: 12.29, reward: 1, policy: [8.6e-02 9.1e-01]\n", + "action: 0 value: 12.24, reward: 1, policy: [1.0e+00 3.6e-14]\n", + "action: 0 value: 11.38, reward: 1, policy: [1.0e+00 1.7e-05]\n", + "action: 1 value: 8.88, reward: 1, policy: [2.8e-08 1.0e+00]\n", + "action: 0 value: 10.72, reward: 1, policy: [1.0e+00 5.1e-06]\n", + "action: 1 value: 8.48, reward: 1, policy: [2.0e-07 1.0e+00]\n", + "action: 0 value: 10.09, reward: 1, policy: [1.0e+00 7.5e-06]\n", + "action: 1 value: 8.06, reward: 1, policy: [8.9e-07 1.0e+00]\n", + "action: 0 value: 9.50, reward: 1, policy: [1.0e+00 2.7e-05]\n", + "action: 1 value: 7.61, reward: 1, policy: [2.5e-06 1.0e+00]\n", + "action: 0 value: 8.96, reward: 1, policy: [1.0e+00 1.8e-04]\n", + "action: 1 value: 7.17, reward: 1, policy: [4.5e-06 1.0e+00]\n", + "DONE 127\n" + ] + } + ], + "source": [ + "fmt = 'action: %1u value: %6.2f, reward: %2u, policy: [%2.1e %2.1e]'\n", + "state = env.reset()\n", + "for sn in range(2000):\n", + " p_0 = naive_search(m, state, debug=False, T=0.1)\n", + " #p_0, _ = mcts_search(m, state, 50)\n", + " #print(p_0)\n", + " \n", + " a_1 = np.random.choice(list(range(len(p_0))), p=p_0)\n", + " _, v_0 = m.ft(m.ht(state))\n", + " state,r,done,_ = env.step(a_1)\n", + " print(fmt % (a_1, v_0[0], r, p_0[0], p_0[1]))\n", + " if done:\n", + " print(\"DONE\", sn)\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"id": "4b3d6446-8808-4e17-8da3-4016b315401a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}