Skip to content

Commit 6be84c8

Browse files
authored
Merge pull request #1 from edchapman88/policyserver
Policyserver implementation for RLPolicyType
2 parents 9668f2d + 573ff2f commit 6be84c8

4 files changed

Lines changed: 136 additions & 21 deletions

File tree

bin/main.ml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,18 @@
88
*)
99

1010
open Blue
11-
module Agent = Markov.Agent.Make (MarkovCompressor) (Reward) (CountBasedPolicy)
11+
12+
module type P =
13+
Markov.Agent.RLPolicyType
14+
with type state = MarkovCompressor.state
15+
with type reward = Reward.t
16+
17+
let policy_module is_server =
18+
if is_server then (module ServerPolicy : P) else (module CountBasedPolicy)
1219

1320
let () =
1421
Cli.arg_parse ();
22+
let is_server = Cli.server_policy () |> Option.is_some in
23+
let module Policy = (val policy_module is_server) in
24+
let module Agent = Markov.Agent.Make (MarkovCompressor) (Reward) (Policy) in
1525
Agent.act (Agent.init_policy ())

blue/cli.ml

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,26 @@ let usage_msg =
99
/dev/ttyUSB0 -s 3.0 \n\n\
1010
\ Options:"
1111

12+
let parse_ip raw_addr =
13+
try
14+
match String.split_on_char ':' raw_addr with
15+
| [ ip; port ] ->
16+
Unix.ADDR_INET (Unix.inet_addr_of_string ip, int_of_string port)
17+
| _ -> raise @@ Failure "Network addresses must have exactly one ':'."
18+
with Failure msg ->
19+
failwith
20+
@@ Printf.sprintf
21+
"Failed parsing address: '%s'.\n\
22+
\ Error: %s\n\
23+
\ Accepted format is e.g. '172.0.1.1:8081'" raw_addr msg
24+
1225
let parse_addr raw_addr =
1326
if String.contains raw_addr ':' then
14-
try
15-
match String.split_on_char ':' raw_addr with
16-
| [ ip; port ] ->
17-
Unix.ADDR_INET (Unix.inet_addr_of_string ip, int_of_string port)
18-
| _ -> raise @@ Failure "Network addresses must have exactly one ':'."
19-
with Failure msg ->
20-
failwith
21-
@@ Printf.sprintf
22-
"Failed parsing response signal address: '%s'. Accepted formats are \
23-
e.g. '/dev/ttyUSB0' or '172.0.1.1:8081'. %s"
24-
raw_addr msg
27+
try parse_ip raw_addr
28+
with Failure msg -> failwith @@ msg ^ " or '/dev/USB0'"
2529
else Unix.ADDR_UNIX raw_addr
2630

2731
let _log_path = ref ""
28-
let _n_exploration_steps = ref 300
2932
let _obs_time_delay = ref 5.0
3033
let _acceptable_fraction = ref 0.8
3134
let _request_interval = ref 1.
@@ -34,6 +37,9 @@ let _red_ip = ref "172.0.0.3"
3437
let _response_signal_addr = ref "/dev/ttyUSB0"
3538
let _parsed_response_signal_addr = ref (parse_addr !_response_signal_addr)
3639
let _rolling_window_secs = ref 3.0
40+
let _n_exploration_steps = ref 300
41+
let _server_policy_addr = ref ""
42+
let _parsed_server_policy_addr = ref None
3743

3844
let speclist =
3945
[
@@ -42,10 +48,6 @@ let speclist =
4248
": Optionally write a log file in the specified directory with \
4349
information about the sequence of states observed and actions taken. If \
4450
no log file is specified, the information is written to stdout.\n" );
45-
( "-e",
46-
Arg.Set_int _n_exploration_steps,
47-
": Set the number of observations used by the policy for exploration, \
48-
after which actions are selected for exploitation. Defaults to 300.\n" );
4951
( "-t",
5052
Arg.Set_float _obs_time_delay,
5153
": Set the constant time delay between observations in seconds. Defaults \
@@ -82,22 +84,36 @@ let speclist =
8284
": Set the length of the rolling window used to evaluate the average OK \
8385
response rate indicated by the data received over the out-of-band \
8486
channel with the client. Defaults to 3.0.\n" );
87+
( "-e",
88+
Arg.Set_int _n_exploration_steps,
89+
": Set the number of observations used by the policy for exploration, \
90+
after which actions are selected for exploitation. Defaults to 300.\n" );
91+
( "--server-policy",
92+
Arg.Set_string _server_policy_addr,
93+
": Set the IP address of a policy server, overriding the \
94+
CountBasedPolicy. Serialised observations and rewards are sent as POST \
95+
requests to the configured address and actions are parsed from the \
96+
responses. The '-e' flag, used to configure the CountBasedPolicy, is \
97+
ignored when '--server-policy' is used.\n" );
8598
]
8699

87100
let log_path () =
88101
if String.length !_log_path == 0 then None else Some !_log_path
89102

90-
let n_exploration_steps () = !_n_exploration_steps
91103
let obs_time_delay () = !_obs_time_delay
92104
let acceptable_fraction () = !_acceptable_fraction
93105
let request_interval () = !_request_interval
94106
let green_ip () = !_green_ip
95107
let red_ip () = !_red_ip
96108
let response_signal_addr () = !_parsed_response_signal_addr
97109
let rolling_window_secs () = !_rolling_window_secs
110+
let n_exploration_steps () = !_n_exploration_steps
111+
let server_policy () = !_parsed_server_policy_addr
98112

99113
let arg_parse () =
100114
let parse_positional_args _ = () in
101115
Arg.parse speclist parse_positional_args usage_msg;
102116
let raw_addr = !_response_signal_addr in
103-
_parsed_response_signal_addr := parse_addr raw_addr
117+
_parsed_response_signal_addr := parse_addr raw_addr;
118+
if String.length !_server_policy_addr > 0 then
119+
_parsed_server_policy_addr := Some (parse_ip !_server_policy_addr)

blue/serverPolicy.ml

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
type reward = Reward.t
2+
type state = MarkovCompressor.state
3+
type observer = unit -> state
4+
type action = unit -> unit
5+
6+
type t = int
7+
(** Number of policy steps. *)
8+
9+
type inference = {
10+
action : action;
11+
observer : observer;
12+
policy : t;
13+
}
14+
15+
let init () = 0
16+
let init_observer = MarkovCompressor.observe
17+
let time_before_retry = 3
18+
19+
let request (state, reward) =
20+
let addr = Cli.server_policy () |> Option.get in
21+
let sock = Unix.(socket PF_INET SOCK_STREAM 0) in
22+
let _ = Unix.connect sock addr in
23+
24+
let in_ch = Unix.in_channel_of_descr sock in
25+
let out_ch = Unix.out_channel_of_descr sock in
26+
let body = (state, reward) |> Reward.string_of_state_reward in
27+
output_string out_ch
28+
@@ Printf.sprintf
29+
"POST / HTTP/1.1\r\n\
30+
Accept: application/json\r\n\
31+
Content-Type: application/json\r\n\
32+
Content-Length: %d\r\n\
33+
\r\n\
34+
%s"
35+
(String.length body) body;
36+
flush out_ch;
37+
38+
let rec readall ic acc =
39+
try readall ic (acc ^ input_line ic)
40+
with End_of_file ->
41+
Unix.close sock;
42+
acc
43+
in
44+
45+
let res = readall in_ch "" in
46+
let start_bod = String.rindex_from res (String.length res - 1) '\r' + 1 in
47+
let len_bod = String.length res - start_bod in
48+
String.sub res start_bod len_bod
49+
50+
let post (state, reward) = request (state, reward) |> System.eff_of_string
51+
52+
let infer policy (state, reward) =
53+
(* Log state and reward. *)
54+
Log.write_msg @@ Reward.string_of_state_reward (state, reward);
55+
56+
(* Send [state] and [reward] to policy server. *)
57+
let chosen_eff =
58+
match post (state, reward) with
59+
| Error msg -> (
60+
(* Write policy server failure to log and do one retry. *)
61+
Log.write_msg msg;
62+
Unix.sleep time_before_retry;
63+
(* 'Fail-fast' if the policy server fails after one retry. *)
64+
match post (state, reward) with
65+
| Error msg ->
66+
let msg = "On first (and only) retry: " ^ msg in
67+
Log.write_msg msg;
68+
failwith msg
69+
| Ok eff -> eff)
70+
| Ok eff -> eff
71+
in
72+
73+
(* Log chosen effect. *)
74+
Log.write_msg @@ System.string_of_eff chosen_eff;
75+
76+
let action () = System.exec_eff chosen_eff in
77+
{
78+
action;
79+
observer =
80+
(fun () ->
81+
Unix.sleepf @@ Cli.obs_time_delay ();
82+
MarkovCompressor.observe ());
83+
policy = policy + 1;
84+
}

blue/system.ml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,17 @@ type eff =
3838
| ToggleGreen
3939
| ToggleRed
4040

41-
let string_of_eff e =
42-
match e with
41+
let string_of_eff = function
4342
| Wait -> "Wait"
4443
| ToggleGreen -> "ToggleGreen"
4544
| ToggleRed -> "ToggleRed"
4645

46+
let eff_of_string = function
47+
| "Wait" -> Ok Wait
48+
| "ToggleGreen" -> Ok ToggleGreen
49+
| "ToggleRed" -> Ok ToggleRed
50+
| s -> Error (Printf.sprintf "Unrecognised effect: '%s'" s)
51+
4752
let random_eff () =
4853
match Random.int 3 with
4954
| 0 -> Wait

0 commit comments

Comments
 (0)