Skip to content

Commit 088bdb6

Browse files
added_mcts_and_metrics
1 parent 6c79259 commit 088bdb6

File tree

3 files changed

+568
-0
lines changed

3 files changed

+568
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"id": "18676180",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import numpy as np\n",
11+
"import ipywidgets as widgets\n",
12+
"from tqdm import tqdm\n",
13+
"import random\n",
14+
"import matplotlib.pyplot as plt"
15+
]
16+
},
17+
{
18+
"cell_type": "code",
19+
"execution_count": 254,
20+
"id": "d69c70f2",
21+
"metadata": {},
22+
"outputs": [],
23+
"source": [
24+
"class MCTSNode:\n",
25+
" def __init__(self, state, parent_node):\n",
26+
" self.state = state\n",
27+
" self.parent_node = parent_node\n",
28+
" self.total_visits = 0\n",
29+
" self.total_score = 0\n",
30+
" self.children_nodes = []\n",
31+
" self.player = self.check_player(state)\n",
32+
" self.terminate_state = False\n",
33+
" self.all_children_nodes = False\n",
34+
"\n",
35+
" def check_player(self, state):\n",
36+
" if np.sum(state==1) > np.sum(state==2):\n",
37+
" return 2\n",
38+
" else:\n",
39+
" return 1\n",
40+
"\n",
41+
"class MCTS:\n",
42+
" def __init__(self, exploration_constant = 2):\n",
43+
" self.exploration_constant = exploration_constant\n",
44+
"\n",
45+
" def is_terminal(self, board):\n",
46+
" return not np.any(board == 0)\n",
47+
"\n",
48+
" def is_win(self, state, player):\n",
49+
" col_win = (np.sum(state == player, axis=0) == 3).any()\n",
50+
" row_win = (np.sum(state == player, axis=1) == 3).any()\n",
51+
" diagonal_win = np.trace(state == player) == 3\n",
52+
" opposite_diagonal = np.trace(np.fliplr(state) == player) == 3\n",
53+
" return col_win or row_win or diagonal_win or opposite_diagonal\n",
54+
"\n",
55+
" def select(self, curr_node, should_explore=True):\n",
56+
" while not is_terminal(curr_node.state) and not (self.is_win(curr_node.state, 1) or self.is_win(curr_node.state, 2)):\n",
57+
" if curr_node.all_children_nodes:\n",
58+
" highest_value = -float(\"inf\")\n",
59+
" chosen_child = None\n",
60+
"\n",
61+
" # loop all children nodes and take the best one according to heuristic\n",
62+
" for child in curr_node.children_nodes:\n",
63+
" # compute UCB1 score\n",
64+
" child_val = (child.total_score/child.total_visits) + should_explore*self.exploration_constant*np.sqrt(np.log(curr_node.total_visits)/child.total_visits)\n",
65+
"\n",
66+
" # if it has highest value then store it as the chosen child from this step\n",
67+
" if child_val > highest_value:\n",
68+
" highest_value = child_val\n",
69+
" chosen_child = child\n",
70+
"\n",
71+
" # choose highest value move\n",
72+
" return chosen_child\n",
73+
"\n",
74+
" else:\n",
75+
" # if not all children nodes accessible then expand the node first\n",
76+
" return self.expand(curr_node)\n",
77+
"\n",
78+
" print(\"should never come here\")\n",
79+
"\n",
80+
" def expand(self, curr_node):\n",
81+
" states = self.generate_next_states(curr_node)\n",
82+
"\n",
83+
" for state in states:\n",
84+
" # unroll children states, and ensure we do not expand to a state we have \n",
85+
" # already expanded to in a previous iteration\n",
86+
" if str(state) not in [str(b.state) for b in curr_node.children_nodes]:\n",
87+
" child_node = MCTSNode(state, curr_node)\n",
88+
" curr_node.children_nodes.append(child_node)\n",
89+
" \n",
90+
" # if the num children nodes equal the amount of possible next states\n",
91+
" # we have explored all child nodes for this state\n",
92+
" if len(states) == len(curr_node.children_nodes):\n",
93+
" curr_node.all_children_nodes = True\n",
94+
"\n",
95+
" return child_node\n",
96+
"\n",
97+
"\n",
98+
" def simulate(self, curr_node, computer_playing):\n",
99+
" opponent = 1 if computer_playing == 2 else 1\n",
100+
" \n",
101+
" while not is_terminal(curr_node.state) and not (self.is_win(curr_node.state, 1) or self.is_win(curr_node.state, 2)):\n",
102+
" next_states = self.generate_next_states(curr_node)\n",
103+
" curr_node = MCTSNode(next_states[random.randint(0, len(next_states) - 1)], curr_node)\n",
104+
" \n",
105+
" if self.is_win(curr_node.state, player=computer_playing):\n",
106+
" return 1\n",
107+
" elif self.is_win(curr_node.state, player=opponent):\n",
108+
" return -1\n",
109+
" else:\n",
110+
" return 0\n",
111+
"\n",
112+
" \n",
113+
" def backpropagate(self, node, score):\n",
114+
" while node:\n",
115+
" node.total_visits += 1\n",
116+
" node.total_score += score\n",
117+
" node = node.parent_node\n",
118+
" \n",
119+
" def generate_next_states(self, curr_node):\n",
120+
" player = curr_node.player\n",
121+
" curr_state = curr_node.state\n",
122+
" next_states = []\n",
123+
" for i in range(3):\n",
124+
" for j in range(3):\n",
125+
" if curr_state[i,j] == 0:\n",
126+
" to_append = np.copy(curr_state)\n",
127+
" to_append[i,j] = player\n",
128+
" next_states.append(to_append)\n",
129+
" return next_states\n",
130+
"\n",
131+
"\n",
132+
" def get_move(self, root, num_iterations=1000):\n",
133+
" for it in range(num_iterations):\n",
134+
" curr_node = self.select(root)\n",
135+
" obtained_value = self.simulate(curr_node, root.player)\n",
136+
" self.backpropagate(curr_node, obtained_value)\n",
137+
" \n",
138+
" chosen_move = self.select(root, should_explore=False)\n",
139+
" return chosen_move"
140+
]
141+
},
142+
{
143+
"cell_type": "code",
144+
"execution_count": 263,
145+
"id": "36e39228",
146+
"metadata": {
147+
"scrolled": true
148+
},
149+
"outputs": [
150+
{
151+
"name": "stdout",
152+
"output_type": "stream",
153+
"text": [
154+
"Row and column to place with ,1,1\n",
155+
"[[0. 0. 0.]\n",
156+
" [0. 1. 0.]\n",
157+
" [0. 0. 2.]]\n",
158+
"Row and column to place with ,0,0\n",
159+
"[[1. 0. 0.]\n",
160+
" [0. 1. 0.]\n",
161+
" [2. 0. 2.]]\n",
162+
"Row and column to place with ,2,1\n",
163+
"[[1. 2. 0.]\n",
164+
" [0. 1. 0.]\n",
165+
" [2. 1. 2.]]\n",
166+
"Row and column to place with ,1,2\n",
167+
"[[1. 2. 0.]\n",
168+
" [2. 1. 1.]\n",
169+
" [2. 1. 2.]]\n",
170+
"Row and column to place with ,0,2\n",
171+
"should never come here\n"
172+
]
173+
},
174+
{
175+
"ename": "AttributeError",
176+
"evalue": "'NoneType' object has no attribute 'state'",
177+
"output_type": "error",
178+
"traceback": [
179+
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
180+
"\u001b[1;31mAttributeError\u001b[0m Traceback (most recent call last)",
181+
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_15720/2518229713.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 9\u001b[0m \u001b[0mnext_node\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mMCTSNode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mroot\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 10\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 11\u001b[1;33m \u001b[0mroot\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmc\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_move\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnext_node\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 12\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mroot\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 13\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
182+
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_15720/416212796.py\u001b[0m in \u001b[0;36mget_move\u001b[1;34m(self, root, num_iterations)\u001b[0m\n\u001b[0;32m 110\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mit\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnum_iterations\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 111\u001b[0m \u001b[0mcurr_node\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mselect\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mroot\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 112\u001b[1;33m \u001b[0mobtained_value\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msimulate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mroot\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mplayer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 113\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackpropagate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mobtained_value\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 114\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
183+
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_15720/416212796.py\u001b[0m in \u001b[0;36msimulate\u001b[1;34m(self, curr_node, computer_playing)\u001b[0m\n\u001b[0;32m 76\u001b[0m \u001b[0mopponent\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m1\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mcomputer_playing\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m2\u001b[0m \u001b[1;32melse\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 77\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 78\u001b[1;33m \u001b[1;32mwhile\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mis_terminal\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_win\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_win\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 79\u001b[0m \u001b[0mnext_states\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgenerate_next_states\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 80\u001b[0m \u001b[0mcurr_node\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mMCTSNode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnext_states\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mrandom\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnext_states\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m-\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcurr_node\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
184+
"\u001b[1;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'state'"
185+
]
186+
}
187+
],
188+
"source": [
189+
"a = np.zeros((3,3))\n",
190+
"root = MCTSNode(a, None)\n",
191+
"mc = MCTS()\n",
192+
"\n",
193+
"for i in range(9):\n",
194+
" row_col = input(\"Row and column to place with ,\").split(\",\")\n",
195+
" state = np.copy(root.state)\n",
196+
" state[int(row_col[0]), int(row_col[1])] = 1\n",
197+
" next_node = MCTSNode(state, root)\n",
198+
" \n",
199+
" root = mc.get_move(next_node)\n",
200+
" print(root.state)\n",
201+
"\n",
202+
"print(\"Final: {root.state}\")\n",
203+
" "
204+
]
205+
}
206+
],
207+
"metadata": {
208+
"kernelspec": {
209+
"display_name": "Python 3 (ipykernel)",
210+
"language": "python",
211+
"name": "python3"
212+
},
213+
"language_info": {
214+
"codemirror_mode": {
215+
"name": "ipython",
216+
"version": 3
217+
},
218+
"file_extension": ".py",
219+
"mimetype": "text/x-python",
220+
"name": "python",
221+
"nbconvert_exporter": "python",
222+
"pygments_lexer": "ipython3",
223+
"version": "3.9.5"
224+
}
225+
},
226+
"nbformat": 4,
227+
"nbformat_minor": 5
228+
}

Diff for: ML/ml_metrics/data.txt

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
0 0.827142151760153
2+
0 0.6044595910412887
3+
0 0.7916340858282026
4+
0 0.16080518180592987
5+
0 0.611222921705038
6+
0 0.2555087295500818
7+
0 0.5681507664364468
8+
0 0.05990570219972058
9+
0 0.6644434078306367
10+
0 0.11293577405861703
11+
0 0.06152372321587048
12+
0 0.35250697207600584
13+
0 0.3226701829081975
14+
0 0.43339115381458776
15+
0 0.2280744262436838
16+
0 0.7219848389339433
17+
0 0.23527698971402375
18+
0 0.2850245335200196
19+
0 0.4107047877448165
20+
0 0.2008356196164621
21+
0 0.3711921802697385
22+
0 0.4234822657253734
23+
0 0.4876482027124213
24+
0 0.4234822657253734
25+
0 0.5750985220664769
26+
0 0.6734047730095499
27+
0 0.7355892648444824
28+
0 0.7137899092959652
29+
0 0.3873972469024071
30+
0 0.24042033264833723
31+
0 0.1663411647259707
32+
0 0.1663411647259707
33+
0 0.2850245335200196
34+
0 0.3683741846950643
35+
0 0.17375784896208155
36+
0 0.43636290738886574
37+
0 0.7219848389339433
38+
0 0.46745878087292836
39+
0 0.23527698971402375
40+
0 0.17202866439941822
41+
0 0.17786913865061538
42+
0 0.44335359557308707
43+
0 0.2768503833164947
44+
0 0.06891755391553003
45+
0 0.21414010746535972
46+
0 0.27120595352357546
47+
0 0.26328216986315905
48+
0 0.48056205121673834
49+
0 0.08848560476699129
50+
0 0.2555087295500818
51+
1 0.5681507664364468
52+
1 0.2850245335200196
53+
1 0.842216416418616
54+
1 0.5280820469827786
55+
1 0.6302728469340095
56+
1 0.9325162813331325
57+
1 0.062225621463076315
58+
1 0.8823445035377085
59+
1 0.670739773835188
60+
1 0.891663414209465
61+
1 0.6489254823470298
62+
1 0.5552119758821265
63+
1 0.7510275470993321
64+
1 0.23310831157247616
65+
1 0.2933421288888426
66+
1 0.6044595910412887
67+
1 0.6302728469340095
68+
1 0.9585115007613662
69+
1 0.9342800686704079
70+
1 0.3226701829081975
71+
1 0.7982301827889998
72+
1 0.22102862644325694
73+
1 0.9390780973389883
74+
1 0.5078780077620866
75+
1 0.7379344573081708
76+
1 0.8750078631067137
77+
1 0.4704701704107932
78+
1 0.44335359557308707
79+
1 0.5651814720676593
80+
1 0.8658845001112441
81+
1 0.897024614730928
82+
1 0.9712637967845552
83+
1 0.5651814720676593
84+
1 0.517987379389242
85+
1 0.40385540386469254
86+
1 0.9435470013187671
87+
1 0.5780506539476005
88+
1 0.594744923406366
89+
1 0.3970432858350056
90+
1 0.7916340858282026
91+
1 0.7219848389339433
92+
1 0.7916340858282026
93+
1 0.2850245335200196
94+
1 0.7658513560779588
95+
1 0.7379344573081708
96+
1 0.7137899092959652
97+
1 0.4876482027124213
98+
1 0.6302728469340095
99+
1 0.5310944974701136
100+
1 0.35250697207600584

0 commit comments

Comments
 (0)