diff --git a/opcua/client/ua_client.py b/opcua/client/ua_client.py index d7121501f..e64b748c1 100644 --- a/opcua/client/ua_client.py +++ b/opcua/client/ua_client.py @@ -3,182 +3,170 @@ """ import logging -import socket -from threading import Thread, Lock -from concurrent.futures import Future from functools import partial - +from opcua.common import uaasync +from opcua.common.uaasync import coroutine, From, Return, await_super_coro from opcua import ua from opcua.common import utils -class UASocketClient(object): - """ - handle socket connection and send ua messages - timeout is the timeout used while waiting for an ua answer from server - """ - def __init__(self, timeout=1, security_policy=ua.SecurityPolicy()): - self.logger = logging.getLogger(__name__ + "Socket") - self._thread = None - self._lock = Lock() - self.timeout = timeout - self._socket = None - self._do_stop = False +def _cancel_handle_cb(hdl, _): + hdl.cancel() + + +class UAClientProtocol(uaasync.asyncio.Protocol): + def __init__(self, security_connection, disconnected_cb): + uaasync.asyncio.Protocol.__init__(self) + self.logger = logging.getLogger(__name__ + "Protocol") self.authentication_token = ua.NodeId() + self._connection = security_connection + self.disconnected_cb = disconnected_cb + self.transport = None + self.futuremap = {} + self._buffer = utils.Buffer(b"") self._request_id = 0 self._request_handle = 0 - self._callbackmap = {} - self._connection = ua.SecureConnection(security_policy) - - def start(self): - """ - Start receiving thread. - this is called automatically in connect and - should not be necessary to call directly - """ - self._thread = Thread(target=self._run) - self._thread.start() - - def _send_request(self, request, callback=None, timeout=1000, message_type=ua.MessageType.SecureMessage): - """ - send request to server, lower-level method - timeout is the timeout written in ua header - returns future - """ - with self._lock: - request.RequestHeader = self._create_request_header(timeout) - try: - binreq = request.to_binary() - except: - # reset reqeust handle if any error - # see self._create_request_header - self._request_handle -= 1 - raise - self._request_id += 1 - future = Future() - if callback: - future.add_done_callback(callback) - self._callbackmap[self._request_id] = future - msg = self._connection.message_to_binary(binreq, message_type, self._request_id) - self._socket.write(msg) - return future - - def send_request(self, request, callback=None, timeout=1000, message_type=ua.MessageType.SecureMessage): - """ - send request to server. - timeout is the timeout written in ua header - returns response object if no callback is provided - """ - future = self._send_request(request, callback, timeout, message_type) - if not callback: - data = future.result(self.timeout) - self.check_answer(data, " in response to " + request.__class__.__name__) - return data - - def check_answer(self, data, context): - data = data.copy() - typeid = ua.NodeId.from_binary(data) - if typeid == ua.FourByteNodeId(ua.ObjectIds.ServiceFault_Encoding_DefaultBinary): - self.logger.warning("ServiceFault from server received %s", context) - hdr = ua.ResponseHeader.from_binary(data) - hdr.ServiceResult.check() - return False - return True - - def _run(self): - self.logger.info("Thread started") - while not self._do_stop: - try: - self._receive() - except ua.utils.SocketClosedException: - self.logger.info("Socket has closed connection") - break - self.logger.info("Thread ended") - - def _receive(self): - msg = self._connection.receive_from_socket(self._socket) - if msg is None: - return - elif isinstance(msg, ua.Message): - self._call_callback(msg.request_id(), msg.body()) - elif isinstance(msg, ua.Acknowledge): - self._call_callback(0, msg) - elif isinstance(msg, ua.ErrorMessage): - self.logger.warning("Received an error: {}".format(msg)) + self._cur_header = None + + def connection_made(self, transport): + self.logger.info("connect to server") + self.transport = transport + + def connection_lost(self, ex): + self.logger.info("connection lost") + self.close(ex) + + def close(self, ex=None): + for k in self.futuremap: + # cancel all waiting response callback + fut = self.futuremap[k] + if fut is not None and not fut.done(): + if ex is None: + fut.cancel() + else: + fut.set_exception(ex) + self.futuremap.clear() + if self.transport is not None: + self.disconnected_cb(ex) + self.transport.close() + self.transport = None + + def data_received(self, data): + self._buffer.write(data) + self.transport.pause_reading() + self._process_msg() + + def _process_msg(self): + try: + msgnum = 0 + while True: + hdr = self._get_header() + if hdr is None or len(self._buffer) < hdr.body_size: + self.transport.resume_reading() + return + # entire packet recieved, clear curent header and process it + self._cur_header = None + msgnum += self._process_one_packet(hdr) + if msgnum >= 1: + # yield cpu for packet processing + uaasync.call_soon(self._process_msg) + return + except Exception as e: + self.close(e) + + def _process_one_packet(self, hdr): + msg = self._connection.receive_from_header_and_body(hdr, self._buffer) + if msg is None: + # wait for more chunk + return 0 + elif isinstance(msg, ua.Message): + self._set_result(msg.request_id(), msg.body()) + return 1 + elif isinstance(msg, ua.Acknowledge): + self._set_result(0, msg) + return 1 + elif isinstance(msg, ua.ErrorMessage): + raise ua.UaError("Received an error: {}".format(msg)) + else: + raise ua.UaError("Unsupported message type: {}".format(msg)) + + def _get_header(self): + if self._cur_header is not None: + return self._cur_header + buf = self._buffer + if len(buf) < 8: + # a UA Header is at least 8 bytes, wait for more + return None + buf_copy = buf.copy() + try: + header = ua.Header.from_string(buf) + except ua.NotEnoughData: + # not enougth data for header, wait for more + self._buffer = buf_copy + return None + self._cur_header = header + return header + + def _set_result(self, reqid, body): + if reqid not in self.futuremap: + raise ua.UaError( + "No future found for request: {}, futures in list are {}".format( + reqid, self.futuremap.keys())) + fut = self.futuremap.pop(reqid) + if fut is not None and not fut.done(): + fut.set_result(body) + + def _timeout_request(self, name, reqid): + if reqid in self.futuremap: + fut = self.futuremap[reqid] + # keep the slot, in case a response is return in future + self.futuremap[reqid] = None + if fut is not None and not fut.done(): + fut.set_exception(uaasync.asyncio.TimeoutError("%s id %d timeout" % (name, reqid))) + + def set_authentication_token(self, token): + self.authentication_token = token + + def send_request(self, request, timeout): + if isinstance(request, ua.Hello): + msg = self._connection.tcp_to_binary(ua.MessageType.Hello, request) + reqid = 0 else: - raise ua.UaError("Unsupported message type: {}".format(msg)) - - def _call_callback(self, request_id, body): - with self._lock: - future = self._callbackmap.pop(request_id, None) - if future is None: - raise ua.UaError("No future object found for request: {}, callbacks in list are {}".format(request_id, self._callbackmap.keys())) - future.set_result(body) - - def _create_request_header(self, timeout=1000): + if isinstance(request, ua.OpenSecureChannelRequest): + msgtype = ua.MessageType.SecureOpen + elif isinstance(request, ua.CloseSecureChannelRequest): + msgtype = ua.MessageType.SecureClose + else: + msgtype = ua.MessageType.SecureMessage + timeout_hint = int(timeout * 1000) + binreq = self._to_binreq(request, timeout_hint) + self._request_id += 1 + reqid = self._request_id + msg = self._connection.message_to_binary(binreq, msgtype, reqid) + self.transport.write(msg) + fut = uaasync.new_future() + self.futuremap[reqid] = fut + handle = uaasync.call_later(timeout * 1.1, self._timeout_request, request.__class__.__name__, reqid) + fut.add_done_callback(partial(_cancel_handle_cb, handle)) + return fut + + def _to_binreq(self, request, timeout_hint): hdr = ua.RequestHeader() hdr.AuthenticationToken = self.authentication_token self._request_handle += 1 hdr.RequestHandle = self._request_handle - hdr.TimeoutHint = timeout - return hdr - - def connect_socket(self, host, port): - """ - connect to server socket and start receiving thread - """ - self.logger.info("opening connection") - sock = socket.create_connection((host, port)) - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) # nodelay ncessary to avoid packing in one frame, some servers do not like it - self._socket = utils.SocketWrapper(sock) - self.start() - - def disconnect_socket(self): - self.logger.info("stop request") - self._do_stop = True - self._socket.socket.shutdown(socket.SHUT_WR) - self._socket.socket.close() - - def send_hello(self, url): - hello = ua.Hello() - hello.EndpointUrl = url - future = Future() - with self._lock: - self._callbackmap[0] = future - binmsg = self._connection.tcp_to_binary(ua.MessageType.Hello, hello) - self._socket.write(binmsg) - ack = future.result(self.timeout) - return ack - - def open_secure_channel(self, params): - self.logger.info("open_secure_channel") - request = ua.OpenSecureChannelRequest() - request.Parameters = params - future = self._send_request(request, message_type=ua.MessageType.SecureOpen) - - response = ua.OpenSecureChannelResponse.from_binary(future.result(self.timeout)) - response.ResponseHeader.ServiceResult.check() - self._connection.set_security_token(response.Parameters.SecurityToken) - return response.Parameters - - def close_secure_channel(self): - """ - close secure channel. It seems to trigger a shutdown of socket - in most servers, so be prepare to reconnect. - OPC UA specs Part 6, 7.1.4 say that Server does not send a CloseSecureChannel response and should just close socket - """ - self.logger.info("close_secure_channel") - request = ua.CloseSecureChannelRequest() - future = self._send_request(request, message_type=ua.MessageType.SecureClose) - with self._lock: - # don't expect any more answers - future.cancel() - self._callbackmap.clear() - - # some servers send a response here, most do not ... so we ignore + hdr.TimeoutHint = timeout_hint + request.RequestHeader = hdr + try: + binreq = request.to_binary() + except: + # reset reqeust handle if any error + self._request_handle -= 1 + raise + return binreq -class UaClient(object): +class AsyncUaClient(object): """ low level OPC-UA client. @@ -192,79 +180,142 @@ class UaClient(object): def __init__(self, timeout=1): self.logger = logging.getLogger(__name__) - # _publishcallbacks should be accessed in recv thread only self._publishcallbacks = {} self._timeout = timeout - self._uasocket = None self._security_policy = ua.SecurityPolicy() + self._proto_connected = False def set_security(self, policy): self._security_policy = policy + @coroutine def connect_socket(self, host, port): """ connect to server socket and start receiving thread """ - self._uasocket = UASocketClient(self._timeout, security_policy=self._security_policy) - return self._uasocket.connect_socket(host, port) + self.logger.info("connect_socket %s %d", host, port) + self._connection = ua.SecureConnection(self._security_policy) + self._proto = UAClientProtocol(self._connection, self._disconn_cb) + loop = uaasync.get_loop() + coro = loop.create_connection(lambda: self._proto, host, port) + yield From(coro) + self._proto_connected = True + + def _disconn_cb(self, ex): + self._proto_connected = False def disconnect_socket(self): - return self._uasocket.disconnect_socket() + self._proto.close() + self._proto = None + self._connection = None + self._publishcallbacks.clear() + + def _check_answer(self, data, context=None): + if not isinstance(data, utils.Buffer): + # may be an ack etc... + return + data = data.copy() + typeid = ua.NodeId.from_binary(data) + if typeid == ua.FourByteNodeId(ua.ObjectIds.ServiceFault_Encoding_DefaultBinary): + if context is None: + self.logger.warning("ServiceFault from server received") + else: + self.logger.warning("ServiceFault from server received in response to %s", context) + hdr = ua.ResponseHeader.from_binary(data) + hdr.ServiceResult.check() + @coroutine + def _send_request(self, request, timeout=None): + if not self._proto_connected: + raise ua.UaError("client is disconnected") + reqname = request.__class__.__name__ + self.logger.info("sending %s", reqname) + if timeout is None: + timeout = self._timeout + fut = self._proto.send_request(request, timeout) + data = yield From(fut) + # data = yield From(uaasync.wait_for(fut, self._timeout * 1.1)) + self._check_answer(data, reqname) + raise Return(data) + + def send_request(self, request): + return self._send_request(request) + + @coroutine def send_hello(self, url): - return self._uasocket.send_hello(url) + hello = ua.Hello() + hello.EndpointUrl = url + ack = yield From(self._send_request(hello)) + raise Return(ack) + @coroutine def open_secure_channel(self, params): - return self._uasocket.open_secure_channel(params) + request = ua.OpenSecureChannelRequest() + request.Parameters = params + data = yield From(self._send_request(request)) + response = ua.OpenSecureChannelResponse.from_binary(data) + response.ResponseHeader.ServiceResult.check() + self._connection.set_security_token(response.Parameters.SecurityToken) + raise Return(response.Parameters) + @coroutine def close_secure_channel(self): """ close secure channel. It seems to trigger a shutdown of socket - in most servers, so be prepare to reconnect + in most servers, so be prepare to reconnect. + OPC UA specs Part 6, 7.1.4 say that Server does not send a CloseSecureChannel response + and should just close socket """ - return self._uasocket.close_secure_channel() + reqeust = ua.CloseSecureChannelRequest() + try: + yield From(self._send_request(reqeust)) + except uaasync.asyncio.CancelledError: + # some servers send a response here, most do not ... so we ignore + pass + @coroutine def create_session(self, parameters): - self.logger.info("create_session") request = ua.CreateSessionRequest() request.Parameters = parameters - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.CreateSessionResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() - self._uasocket.authentication_token = response.Parameters.AuthenticationToken - return response.Parameters + self._proto.set_authentication_token(response.Parameters.AuthenticationToken) + raise Return(response.Parameters) + @coroutine def activate_session(self, parameters): - self.logger.info("activate_session") request = ua.ActivateSessionRequest() request.Parameters = parameters - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.ActivateSessionResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() - return response.Parameters + raise Return(response.Parameters) + @coroutine def close_session(self, deletesubscriptions): - self.logger.info("close_session") request = ua.CloseSessionRequest() request.DeleteSubscriptions = deletesubscriptions - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) ua.CloseSessionResponse.from_binary(data) - # response.ResponseHeader.ServiceResult.check() #disabled, it seems we sent wrong session Id, but where is the sessionId supposed to be sent??? + # disabled, it seems we sent wrong session Id, + # but where is the sessionId supposed to be sent??? + # response.ResponseHeader.ServiceResult.check() + @coroutine def browse(self, parameters): - self.logger.info("browse") request = ua.BrowseRequest() request.Parameters = parameters - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.BrowseResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() - return response.Results + raise Return(response.Results) + @coroutine def read(self, parameters): - self.logger.info("read") request = ua.ReadRequest() request.Parameters = parameters - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.ReadResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() # cast to Enum attributes that need to @@ -277,182 +328,286 @@ def read(self, parameters): dv = response.Results[idx] if dv.StatusCode.is_good() and dv.Value.Value in (-3, -2, -1, 0, 1, 2, 3, 4): dv.Value.Value = ua.ValueRank(dv.Value.Value) - return response.Results + raise Return(response.Results) + @coroutine def write(self, params): - self.logger.info("read") request = ua.WriteRequest() request.Parameters = params - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.WriteResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() - return response.Results + raise Return(response.Results) + @coroutine def get_endpoints(self, params): - self.logger.info("get_endpoint") request = ua.GetEndpointsRequest() request.Parameters = params - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.GetEndpointsResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() - return response.Endpoints + raise Return(response.Endpoints) + @coroutine def find_servers(self, params): - self.logger.info("find_servers") request = ua.FindServersRequest() request.Parameters = params - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.FindServersResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() - return response.Servers + raise Return(response.Servers) + @coroutine def find_servers_on_network(self, params): - self.logger.info("find_servers_on_network") request = ua.FindServersOnNetworkRequest() request.Parameters = params - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.FindServersOnNetworkResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() - return response.Parameters + raise Return(response.Parameters) + @coroutine def register_server(self, registered_server): - self.logger.info("register_server") request = ua.RegisterServerRequest() request.Server = registered_server - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.RegisterServerResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() - # nothing to return for this service + @coroutine def register_server2(self, params): - self.logger.info("register_server2") request = ua.RegisterServer2Request() request.Parameters = params - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.RegisterServer2Response.from_binary(data) response.ResponseHeader.ServiceResult.check() - return response.ConfigurationResults + raise Return(response.ConfigurationResults) + @coroutine def translate_browsepaths_to_nodeids(self, browsepaths): - self.logger.info("translate_browsepath_to_nodeid") request = ua.TranslateBrowsePathsToNodeIdsRequest() request.Parameters.BrowsePaths = browsepaths - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.TranslateBrowsePathsToNodeIdsResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() - return response.Results + raise Return(response.Results) + @coroutine def create_subscription(self, params, callback): - self.logger.info("create_subscription") request = ua.CreateSubscriptionRequest() request.Parameters = params - resp_fut = Future() - mycallbak = partial(self._create_subscription_callback, callback, resp_fut) - self._uasocket.send_request(request, mycallbak) - return resp_fut.result(self._timeout) - - def _create_subscription_callback(self, pub_callback, resp_fut, data_fut): - self.logger.info("_create_subscription_callback") - data = data_fut.result() + data = yield From(self._send_request(request)) response = ua.CreateSubscriptionResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() - self._publishcallbacks[response.Parameters.SubscriptionId] = pub_callback - resp_fut.set_result(response.Parameters) + self._publishcallbacks[response.Parameters.SubscriptionId] = callback + raise Return(response.Parameters) + @coroutine def delete_subscriptions(self, subscriptionids): - self.logger.info("delete_subscription") request = ua.DeleteSubscriptionsRequest() request.Parameters.SubscriptionIds = subscriptionids - resp_fut = Future() - mycallbak = partial(self._delete_subscriptions_callback, subscriptionids, resp_fut) - self._uasocket.send_request(request, mycallbak) - return resp_fut.result(self._timeout) - - def _delete_subscriptions_callback(self, subscriptionids, resp_fut, data_fut): - self.logger.info("_delete_subscriptions_callback") - data = data_fut.result() + data = yield From(self._send_request(request)) response = ua.DeleteSubscriptionsResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() for sid in subscriptionids: self._publishcallbacks.pop(sid) - resp_fut.set_result(response.Results) + raise Return(response.Results) def publish(self, acks=None): - self.logger.info("publish") + uaasync.ensure_future(self._do_publish(acks)) + + @coroutine + def _do_publish(self, acks): if acks is None: acks = [] request = ua.PublishRequest() request.Parameters.SubscriptionAcknowledgements = acks - self._uasocket.send_request(request, self._call_publish_callback, timeout=int(9e8)) # timeout could be set to 0 but some servers to not support it + try: + data = yield From(self._send_request(request, int(9e5))) + except uaasync.asyncio.CancelledError: + raise Return() + except Exception: + self.logger.exception("Error receving publish response") + # send publish request to server so the server does not stop sending notifications + self.publish([]) + raise Return() - def _call_publish_callback(self, future): - self.logger.info("call_publish_callback") - data = future.result() - self._uasocket.check_answer(data, "ServiceFault received from server while waiting for publish response") try: response = ua.PublishResponse.from_binary(data) except Exception: self.logger.exception("Error parsing notificatipn from server") - self.publish([]) #send publish request ot server so he does stop sending notifications - return + # send publish request to server so the server does not stop sending notifications + self.publish([]) + raise Return() if response.Parameters.SubscriptionId not in self._publishcallbacks: self.logger.warning("Received data for unknown subscription: %s ", response.Parameters.SubscriptionId) - return + raise Return() callback = self._publishcallbacks[response.Parameters.SubscriptionId] try: callback(response.Parameters) except Exception: # we call client code, catch everything! - self.logger.exception("Exception while calling user callback: %s") + self.logger.exception("Exception while calling user callback") + @coroutine def create_monitored_items(self, params): - self.logger.info("create_monitored_items") request = ua.CreateMonitoredItemsRequest() request.Parameters = params - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.CreateMonitoredItemsResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() - return response.Results + raise Return(response.Results) + @coroutine def delete_monitored_items(self, params): - self.logger.info("delete_monitored_items") request = ua.DeleteMonitoredItemsRequest() request.Parameters = params - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.DeleteMonitoredItemsResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() - return response.Results + raise Return(response.Results) + @coroutine def add_nodes(self, nodestoadd): - self.logger.info("add_nodes") request = ua.AddNodesRequest() request.Parameters.NodesToAdd = nodestoadd - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.AddNodesResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() - return response.Results + raise Return(response.Results) + @coroutine def delete_nodes(self, nodestodelete): - self.logger.info("delete_nodes") request = ua.DeleteNodesRequest() request.Parameters.NodesToDelete = nodestodelete - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.DeleteNodesResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() - return response.Results + raise Return(response.Results) + @coroutine def call(self, methodstocall): request = ua.CallRequest() request.Parameters.MethodsToCall = methodstocall - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.CallResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() - return response.Results + raise Return(response.Results) + @coroutine def history_read(self, params): - self.logger.info("history_read") request = ua.HistoryReadRequest() request.Parameters = params - data = self._uasocket.send_request(request) + data = yield From(self._send_request(request)) response = ua.HistoryReadResponse.from_binary(data) response.ResponseHeader.ServiceResult.check() - return response.Results + raise Return(response.Results) + + +class UaClient(AsyncUaClient): + def connect_socket(self, host, port): + uaasync.start_loop() + uaasync.await_coro(AsyncUaClient.connect_socket, self, host, port) + + def disconnect_socket(self): + uaasync.await_call(AsyncUaClient.disconnect_socket, self) + uaasync.stop_loop() + + @await_super_coro + def send_request(self, request): + pass + + @await_super_coro + def send_hello(self, url): + pass + + @await_super_coro + def open_secure_channel(self, params): + pass + + @await_super_coro + def close_secure_channel(self): + pass + + @await_super_coro + def create_session(self, parameters): + pass + + @await_super_coro + def activate_session(self, parameters): + pass + + @await_super_coro + def close_session(self, deletesubscriptions): + pass + + @await_super_coro + def browse(self, parameters): + pass + + @await_super_coro + def read(self, parameters): + pass + + @await_super_coro + def write(self, params): + pass + + @await_super_coro + def get_endpoints(self, params): + pass + + @await_super_coro + def find_servers(self, params): + pass + + @await_super_coro + def find_servers_on_network(self, params): + pass + + @await_super_coro + def register_server(self, registered_server): + pass + + @await_super_coro + def register_server2(self, params): + pass + + @await_super_coro + def translate_browsepaths_to_nodeids(self, browsepaths): + pass + + @await_super_coro + def create_subscription(self, params, callback): + pass + + @await_super_coro + def delete_subscriptions(self, subscriptionids): + pass + + @uaasync.await_super_call + def publish(self, acks=None): + pass + + @await_super_coro + def create_monitored_items(self, params): + pass + + @await_super_coro + def delete_monitored_items(self, params): + pass + + @await_super_coro + def add_nodes(self, nodestoadd): + pass + + @await_super_coro + def delete_nodes(self, nodestodelete): + pass + + @await_super_coro + def call(self, methodstocall): + pass + + @await_super_coro + def history_read(self, params): + pass diff --git a/opcua/common/uaasync.py b/opcua/common/uaasync.py new file mode 100644 index 000000000..8ac60f7a7 --- /dev/null +++ b/opcua/common/uaasync.py @@ -0,0 +1,229 @@ +import sys + +if sys.version_info[0] < 3 or sys.version_info[1] <= 2: + # python 2.7 - python 3.2 + from trollius import coroutine, From, Return + import trollius as asyncio +else: + # python 3.3 or above + from asyncio import coroutine + import asyncio + + def From(*args, **kwargs): + # we don't want anyone catch it using try: ... except Exception: ... + raise BaseException("'yield from' shoud be used, not yield From()") + + class Return(BaseException): + pass + +import threading +try: + from threading import get_ident as _get_thread_ident +except ImportError: + # Python 2 + from threading import _get_ident as _get_thread_ident +import logging +from concurrent.futures import Future +from functools import wraps, partial +from .uaerrors import UaError + +logger = logging.getLogger(__name__) + + +class LoopController(object): + def __init__(self): + self.loop = None + self.external = False + self.lock = threading.Lock() + self.count = 0 + self.thread = None + self.thread_id = None + + def install_loop(self, loop): + with self.lock: + if self.count > 0 or self.thread is not None or self.external is True: + raise UaError("install_loop must be called before using the opcua library") + self.loop = loop + self.external = True + + def start_loop(self): + if self.external: + raise UaError("synchonized interface unavilable after install_loop()") + with self.lock: + if self.count > 0: + self.count += 1 + return + self._start_loop_thread() + self.count = 1 + + def stop_loop(self): + if self.external: + return + with self.lock: + if self.count < 1: + return + elif self.count > 1: + self.count -= 1 + return + self._stop_loop_thread() + self.count = 0 + + def _start_loop_thread(self): + loop = asyncio.new_event_loop() + cond = threading.Condition() + thread = threading.Thread(target=self._loop_run, args=(loop, cond)) + thread.daemon = True + with cond: + thread.start() + cond.wait() + self.thread = thread + self.loop = loop + + def _stop_loop_thread(self): + self.loop.call_soon_threadsafe(self.loop.stop) + self.thread.join(10) + if self.thread.is_alive(): + raise UaError("can not stop loop thread") + self.thread = None + self.loop.close() + self.loop = None + + def _loop_run(self, loop, cond): + logger.info("start loop thread") + self.thread_id = _get_thread_ident() + with cond: + cond.notify_all() + del cond + try: + loop.run_forever() + finally: + self.thread_id = None + logger.info("stop loop thread") + + def in_loop_thread(self): + if self.thread_id is None: + raise UaError("loop thread is not running") + return self.thread_id == _get_thread_ident() + + +_ctrl = LoopController() + + +def install_loop(loop): + _ctrl.install_loop(loop) + + +def get_loop(): + return _ctrl.loop + + +def start_loop(): + _ctrl.start_loop() + + +def stop_loop(): + _ctrl.stop_loop() + + +def new_future(): + return asyncio.Future(loop=_ctrl.loop) + + +def call_soon(*args): + return _ctrl.loop.call_soon(*args) + + +def call_later(*args): + return _ctrl.loop.call_later(*args) + + +def call_at(*args): + return _ctrl.loop.call_at(*args) + + +def call_soon_threadsafe(*args): + return _ctrl.loop.call_soon_threadsafe(*args) + + +def wait_for(*args): + return asyncio.wait_for(*args, loop=_ctrl.loop) + +if hasattr(asyncio, "ensure_future"): + def ensure_future(*args): + return asyncio.ensure_future(*args, loop=_ctrl.loop) +else: + def ensure_future(*args): + return asyncio.async(*args, loop=_ctrl.loop) + + +def _transfer_future(dstfut, srcfut): + if srcfut.cancelled(): + dstfut.cancel() + return + ex = srcfut.exception() + if ex is not None: + dstfut.set_exception(ex) + return + dstfut.set_result(srcfut.result()) + + +def _wait_coro_in_loop(fut, coro, args, kwargs): + task = ensure_future(coro(*args, **kwargs)) + task.add_done_callback(partial(_transfer_future, fut)) + + +def await_coro(coro, *args, **kwargs): + if _ctrl.in_loop_thread(): + raise UaError("coro should not be called in loop thread") + fut = Future() + call_soon_threadsafe(_wait_coro_in_loop, fut, coro, args, kwargs) + return fut.result() + + +def await_super_coro(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if _ctrl.in_loop_thread(): + raise UaError("coro %s should not be called in loop thread" % func.__name__) + coro = getattr(super(self.__class__, self), func.__name__) + fut = Future() + call_soon_threadsafe(_wait_coro_in_loop, fut, coro, args, kwargs) + return fut.result() + return wrapper + + +def _wait_call_in_loop(fut, func, args, kwargs): + try: + rs = func(*args, **kwargs) + except Exception as e: + fut.set_exception(e) + fut.set_result(rs) + + +def await_call(func, *args, **kwargs): + if _ctrl.in_loop_thread(): + return func(*args, **kwargs) + fut = Future() + call_soon_threadsafe(_wait_call_in_loop, fut, func, args, kwargs) + return fut.result() + + +def await_super_call(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + method = getattr(super(self.__class__, self), func.__name__) + if _ctrl.in_loop_thread(): + return method(*args, **kwargs) + fut = Future() + call_soon_threadsafe(_wait_call_in_loop, fut, method, args, kwargs) + return fut.result() + return wrapper + +__all__ = [ + "asyncio", "coroutine", "From", "Return", + "install_loop", "get_loop", "start_loop", "stop_loop", + "new_future", "call_soon", "call_soon_threadsafe", + "call_later", "call_at", "wait_for", + "await_coro", "await_super_coro", + "await_call", "await_super_call", +] diff --git a/opcua/common/utils.py b/opcua/common/utils.py index 05f9a9854..ae867b32c 100644 --- a/opcua/common/utils.py +++ b/opcua/common/utils.py @@ -83,6 +83,18 @@ def skip(self, size): self._size -= size self._cur_pos += size + def write(self, bstr): + """ + append bstr to current buffer + """ + if self._cur_pos > 0: + # dicard old buffer to free memory + self._data = self._data[self._cur_pos:] + bstr + self._cur_pos = 0 + else: + self._data += bstr + self._size = len(self._data) + class SocketWrapper(object): """ diff --git a/opcua/ua/uaprotocol_hand.py b/opcua/ua/uaprotocol_hand.py index 0d555a7bb..dde4eb260 100644 --- a/opcua/ua/uaprotocol_hand.py +++ b/opcua/ua/uaprotocol_hand.py @@ -632,20 +632,6 @@ def receive_from_header_and_body(self, header, body): else: raise UaError("Unsupported message type {}".format(header.MessageType)) - def receive_from_socket(self, socket): - """ - Convert binary stream to OPC UA TCP message (see OPC UA - specs Part 6, 7.1: Hello, Acknowledge or ErrorMessage), or a Message - object, or None (if intermediate chunk is received) - """ - logger.debug("Waiting for header") - header = Header.from_string(socket) - logger.info("received header: %s", header) - body = socket.read(header.body_size) - if len(body) != header.body_size: - raise UaError("{} bytes expected, {} available".format(header.body_size, len(body))) - return self.receive_from_header_and_body(header, utils.Buffer(body)) - def _receive(self, msg): self._check_incoming_chunk(msg) self._incoming_parts.append(msg) diff --git a/setup.py b/setup.py index 31f08e7c0..fbb3ebef6 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,45 @@ from setuptools import setup, find_packages import sys +import re +from os import path -if sys.version_info[0] < 3: - install_requires = ["enum34", "trollius", "futures"] + +def translate_from_trollius_to_asyncio(): + pypaths = [ + ["client", "ua_client.py"], + ] + + curdir = path.dirname(path.realpath(__file__)) + for pypath in pypaths: + fpath = path.join(curdir, "opcua", *pypath) + with open(fpath, 'r+b') as f: + src = f.read() + src = re.sub(br"raise Return\(\s*\)", b"return", src) + src = re.sub(br"raise Return\(", b"return (", src) + src = re.sub(br"yield From\(None\)", b"yield from []", src) + src = re.sub(br"yield From\(", b"yield from (", src) + f.seek(0) + f.truncate(0) + f.write(src) + +if len(sys.argv) >= 2 and sys.argv[1] == "asyncio_translate": + translate_from_trollius_to_asyncio() + sys.exit() + +install_requires = [] +if sys.version_info[0] < 3 or sys.version_info[1] < 2: + # python 2.7 - python 3.1 + install_requires.append("futures") +if sys.version_info[0] < 3 or sys.version_info[1] < 3: + # python 2.7 - python 3.2 + install_requires.append("trollius") else: - install_requires = [] + translate_from_trollius_to_asyncio() +if sys.version_info[0] < 3 or sys.version_info[1] < 4: + # python 2.7 - python 3.3 + install_requires.append("enum34") + setup(name="freeopcua", version="0.10.7", @@ -29,7 +63,7 @@ "License :: OSI Approved :: GNU Lesser General Public License v3 or later (LGPLv3+)", "Topic :: Software Development :: Libraries :: Python Modules", ], - entry_points={'console_scripts': + entry_points={'console_scripts': [ 'uaread = opcua.tools:uaread', 'uals = opcua.tools:uals', @@ -43,4 +77,3 @@ ] } ) - diff --git a/tests/tests.py b/tests/tests.py old mode 100644 new mode 100755 diff --git a/tests/tests_client.py b/tests/tests_client.py index 1493e3103..eca1651ff 100644 --- a/tests/tests_client.py +++ b/tests/tests_client.py @@ -48,7 +48,7 @@ def test_service_fault(self): request = ua.ReadRequest() request.TypeId = ua.FourByteNodeId(999) # bad type! with self.assertRaises(ua.UaStatusCodeError): - self.clt.uaclient._uasocket.send_request(request) + self.clt.uaclient.send_request(request) def test_objects_anonymous(self): objects = self.ro_clt.get_objects_node()