diff --git a/pyproject.toml b/pyproject.toml index 5bbc5cc91..93a4dff32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,10 @@ classifiers = [ "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", ] -dependencies = ["cffi; implementation_name == 'pypy'"] +dependencies = [ + "cffi; implementation_name == 'pypy'", + "anyioutils >=0.4.2" +] description = "Python bindings for 0MQ" readme = "README.md" @@ -144,7 +147,7 @@ search = '__version__: str = "{current_version}"' [tool.cibuildwheel] build-verbosity = "1" free-threaded-support = true -test-requires = ["pytest>=6", "importlib_metadata"] +test-requires = ["pytest>=6", "importlib_metadata", "exceptiongroup;python_version<'3.11'"] test-command = "pytest -vsx {package}/tools/test_wheel.py" [tool.cibuildwheel.linux] diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 413239734..33de64fe4 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -11,11 +11,19 @@ from multiprocessing import Process import pytest +from anyio import create_task_group, move_on_after, sleep +from anyioutils import CancelledError, create_task from pytest import mark import zmq import zmq.asyncio as zaio +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup, ExceptionGroup + + +pytestmark = pytest.mark.anyio + @pytest.fixture def Context(event_loop): @@ -46,23 +54,17 @@ def test_instance_subclass_second(context): async def test_recv_multipart(context, create_bound_pair): a, b = create_bound_pair(zmq.PUSH, zmq.PULL) f = b.recv_multipart() - assert not f.done() await a.send(b"hi") - recvd = await f - assert recvd == [b"hi"] + assert await f == [b"hi"] async def test_recv(create_bound_pair): a, b = create_bound_pair(zmq.PUSH, zmq.PULL) f1 = b.recv() f2 = b.recv() - assert not f1.done() - assert not f2.done() await a.send_multipart([b"hi", b"there"]) - recvd = await f2 - assert f1.done() - assert f1.result() == b"hi" - assert recvd == b"there" + assert await f1 == b"hi" + assert await f2 == b"there" @mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO") @@ -72,11 +74,11 @@ async def test_recv_timeout(push_pull): f1 = b.recv() b.rcvtimeo = 1000 f2 = b.recv_multipart() - with pytest.raises(zmq.Again): + with pytest.raises(ExceptionGroup) as excinfo: await f1 + assert excinfo.group_contains(zmq.Again) await a.send_multipart([b"hi", b"there"]) recvd = await f2 - assert f2.done() assert recvd == [b"hi", b"there"] @@ -84,70 +86,58 @@ async def test_recv_timeout(push_pull): async def test_send_timeout(socket): s = socket(zmq.PUSH) s.sndtimeo = 100 - with pytest.raises(zmq.Again): + with pytest.raises(ExceptionGroup) as excinfo: await s.send(b"not going anywhere") + assert excinfo.group_contains(zmq.Again) async def test_recv_string(push_pull): a, b = push_pull f = b.recv_string() - assert not f.done() msg = "πøøπ" await a.send_string(msg) recvd = await f - assert f.done() - assert f.result() == msg assert recvd == msg async def test_recv_json(push_pull): a, b = push_pull f = b.recv_json() - assert not f.done() obj = dict(a=5) await a.send_json(obj) recvd = await f - assert f.done() - assert f.result() == obj assert recvd == obj async def test_recv_json_cancelled(push_pull): - a, b = push_pull - f = b.recv_json() - assert not f.done() - f.cancel() - # cycle eventloop to allow cancel events to fire - await asyncio.sleep(0) - obj = dict(a=5) - await a.send_json(obj) - # CancelledError change in 3.8 https://bugs.python.org/issue32528 - if sys.version_info < (3, 8): - with pytest.raises(CancelledError): + async with create_task_group() as tg: + a, b = push_pull + f = create_task(b.recv_json(), tg) + f.cancel(raise_exception=False) + # cycle eventloop to allow cancel events to fire + await sleep(0) + obj = dict(a=5) + await a.send_json(obj) + recvd = await f.wait() + assert f.cancelled() + assert f.done() + # give it a chance to incorrectly consume the event + events = await b.poll(timeout=5) + assert events + await sleep(0) + # make sure cancelled recv didn't eat up event + f = b.recv_json() + with move_on_after(5): recvd = await f - else: - with pytest.raises(asyncio.exceptions.CancelledError): - recvd = await f - assert f.done() - # give it a chance to incorrectly consume the event - events = await b.poll(timeout=5) - assert events - await asyncio.sleep(0) - # make sure cancelled recv didn't eat up event - f = b.recv_json() - recvd = await asyncio.wait_for(f, timeout=5) - assert recvd == obj + assert recvd == obj async def test_recv_pyobj(push_pull): a, b = push_pull f = b.recv_pyobj() - assert not f.done() obj = dict(a=5) await a.send_pyobj(obj) recvd = await f - assert f.done() - assert f.result() == obj assert recvd == obj @@ -206,85 +196,90 @@ async def test_custom_serialize_error(dealer_router): async def test_recv_dontwait(push_pull): push, pull = push_pull f = pull.recv(zmq.DONTWAIT) - with pytest.raises(zmq.Again): + with pytest.raises(BaseExceptionGroup) as excinfo: await f + assert excinfo.group_contains(zmq.Again) await push.send(b"ping") await pull.poll() # ensure message will be waiting - f = pull.recv(zmq.DONTWAIT) - assert f.done() - msg = await f + msg = await pull.recv(zmq.DONTWAIT) assert msg == b"ping" async def test_recv_cancel(push_pull): - a, b = push_pull - f1 = b.recv() - f2 = b.recv_multipart() - assert f1.cancel() - assert f1.done() - assert not f2.done() - await a.send_multipart([b"hi", b"there"]) - recvd = await f2 - assert f1.cancelled() - assert f2.done() - assert recvd == [b"hi", b"there"] + async with create_task_group() as tg: + a, b = push_pull + f1 = create_task(b.recv(), tg) + f2 = create_task(b.recv_multipart(), tg) + f1.cancel(raise_exception=False) + assert f1.done() + assert not f2.done() + await a.send_multipart([b"hi", b"there"]) + recvd = await f2.wait() + assert f1.cancelled() + assert f2.done() + assert recvd == [b"hi", b"there"] async def test_poll(push_pull): - a, b = push_pull - f = b.poll(timeout=0) - await asyncio.sleep(0) - assert f.result() == 0 + async with create_task_group() as tg: + a, b = push_pull + f = create_task(b.poll(timeout=0), tg) + await sleep(0.01) + assert f.result() == 0 - f = b.poll(timeout=1) - assert not f.done() - evt = await f + f = create_task(b.poll(timeout=1), tg) + assert not f.done() + evt = await f.wait() - assert evt == 0 + assert evt == 0 - f = b.poll(timeout=1000) - assert not f.done() - await a.send_multipart([b"hi", b"there"]) - evt = await f - assert evt == zmq.POLLIN - recvd = await b.recv_multipart() - assert recvd == [b"hi", b"there"] + f = create_task(b.poll(timeout=1000), tg) + assert not f.done() + await a.send_multipart([b"hi", b"there"]) + evt = await f.wait() + assert evt == zmq.POLLIN + recvd = await b.recv_multipart() + assert recvd == [b"hi", b"there"] async def test_poll_base_socket(sockets): - ctx = zmq.Context() - url = "inproc://test" - a = ctx.socket(zmq.PUSH) - b = ctx.socket(zmq.PULL) - sockets.extend([a, b]) - a.bind(url) - b.connect(url) - - poller = zaio.Poller() - poller.register(b, zmq.POLLIN) - - f = poller.poll(timeout=1000) - assert not f.done() - a.send_multipart([b"hi", b"there"]) - evt = await f - assert evt == [(b, zmq.POLLIN)] - recvd = b.recv_multipart() - assert recvd == [b"hi", b"there"] + async with create_task_group() as tg: + ctx = zmq.Context() + url = "inproc://test" + a = ctx.socket(zmq.PUSH) + b = ctx.socket(zmq.PULL) + sockets.extend([a, b]) + a.bind(url) + b.connect(url) + + poller = zaio.Poller() + poller.register(b, zmq.POLLIN) + + f = create_task(poller.poll(timeout=1000), tg) + assert not f.done() + a.send_multipart([b"hi", b"there"]) + evt = await f.wait() + assert evt == [(b, zmq.POLLIN)] + recvd = b.recv_multipart() + assert recvd == [b"hi", b"there"] async def test_poll_on_closed_socket(push_pull): - a, b = push_pull + with pytest.raises(BaseExceptionGroup) as excinfo: + async with create_task_group() as tg: + a, b = push_pull - f = b.poll(timeout=1) - b.close() + f = create_task(b.poll(timeout=1), tg) + b.close() - # The test might stall if we try to await f directly so instead just make a few - # passes through the event loop to schedule and execute all callbacks - for _ in range(5): - await asyncio.sleep(0) - if f.cancelled(): - break - assert f.cancelled() + # The test might stall if we try to await f directly so instead just make a few + # passes through the event loop to schedule and execute all callbacks + for _ in range(5): + await sleep(0) + if f.cancelled(): + break + assert f.done() + assert excinfo.group_contains(zmq.error.ZMQError) @pytest.mark.skipif( @@ -344,16 +339,17 @@ def test_shadow(): async def test_poll_leak(): - ctx = zmq.asyncio.Context() - with ctx, ctx.socket(zmq.PULL) as s: - assert len(s._recv_futures) == 0 - for i in range(10): - f = asyncio.ensure_future(s.poll(timeout=1000, flags=zmq.PollEvent.POLLIN)) - f.cancel() - await asyncio.sleep(0) - # one more sleep allows further chained cleanup - await asyncio.sleep(0.1) - assert len(s._recv_futures) == 0 + async with create_task_group() as tg: + ctx = zmq.asyncio.Context() + with ctx, ctx.socket(zmq.PULL) as s: + assert len(s._recv_futures) == 0 + for i in range(10): + f = create_task(s.poll(timeout=1000, flags=zmq.PollEvent.POLLIN), tg) + f.cancel(raise_exception=False) + await sleep(0) + # one more sleep allows further chained cleanup + await sleep(0.1) + assert len(s._recv_futures) == 0 class ProcessForTeardownTest(Process): diff --git a/tests/test_ioloop.py b/tests/test_ioloop.py index 39bcc53a0..bd4229e6c 100644 --- a/tests/test_ioloop.py +++ b/tests/test_ioloop.py @@ -12,9 +12,8 @@ _tornado = True -def setup(): - if not _tornado: - pytest.skip("requires tornado") +if not _tornado: + pytest.skip("requires tornado", allow_module_level=True) def test_ioloop(): diff --git a/zmq/_future.py b/zmq/_future.py index 388284e74..51edaf13a 100644 --- a/zmq/_future.py +++ b/zmq/_future.py @@ -4,8 +4,8 @@ # Distributed under the terms of the Modified BSD License. from __future__ import annotations +import selectors import warnings -from asyncio import Future from collections import deque from functools import partial from itertools import chain @@ -18,6 +18,9 @@ cast, ) +from anyio import create_task_group, sleep +from anyioutils import Future, Task, create_task + import zmq as _zmq from zmq import EVENTS, POLLIN, POLLOUT @@ -36,34 +39,13 @@ class _FutureEvent(NamedTuple): # _Future # _READ # _WRITE -# _default_loop() class _Async: """Mixin for common async logic""" - _current_loop: Any = None _Future: type[Future] - - def _get_loop(self) -> Any: - """Get event loop - - Notice if event loop has changed, - and register init_io_state on activation of a new event loop - """ - if self._current_loop is None: - self._current_loop = self._default_loop() - self._init_io_state(self._current_loop) - return self._current_loop - current_loop = self._default_loop() - if current_loop is not self._current_loop: - # warn? This means a socket is being used in multiple loops! - self._current_loop = current_loop - self._init_io_state(current_loop) - return current_loop - - def _default_loop(self) -> Any: - raise NotImplementedError("Must be implemented in a subclass") + _event_handler_initialized = False def _init_io_state(self, loop=None) -> None: pass @@ -77,120 +59,119 @@ class _AsyncPoller(_Async, _zmq.Poller): _WRITE: int raw_sockets: list[Any] - def _watch_raw_socket(self, loop: Any, socket: Any, evt: int, f: Callable) -> None: + def _watch_raw_socket(self, socket: Any, evt: int, f: Callable) -> None: """Schedule callback for a raw socket""" raise NotImplementedError() - def _unwatch_raw_sockets(self, loop: Any, *sockets: Any) -> None: + def _unwatch_raw_sockets(self, *sockets: Any) -> None: """Unschedule callback for a raw socket""" raise NotImplementedError() - def poll(self, timeout=-1) -> Awaitable[list[tuple[Any, int]]]: # type: ignore - """Return a Future for a poll event""" - future = self._Future() - if timeout == 0: - try: - result = super().poll(0) - except Exception as e: - future.set_exception(e) - else: - future.set_result(result) - return future - - loop = self._get_loop() - - # register Future to be called as soon as any event is available on any socket - watcher = self._Future() - - # watch raw sockets: - raw_sockets: list[Any] = [] - - def wake_raw(*args): - if not watcher.done(): - watcher.set_result(None) - - watcher.add_done_callback( - lambda f: self._unwatch_raw_sockets(loop, *raw_sockets) - ) - - wrapped_sockets: list[_AsyncSocket] = [] - - def _clear_wrapper_io(f): - for s in wrapped_sockets: - s._clear_io_state() - - for socket, mask in self.sockets: - if isinstance(socket, _zmq.Socket): - if not isinstance(socket, self._socket_class): - # it's a blocking zmq.Socket, wrap it in async - socket = self._socket_class.from_socket(socket) - wrapped_sockets.append(socket) - if mask & _zmq.POLLIN: - socket._add_recv_event('poll', future=watcher) - if mask & _zmq.POLLOUT: - socket._add_send_event('poll', future=watcher) - else: - raw_sockets.append(socket) - evt = 0 - if mask & _zmq.POLLIN: - evt |= self._READ - if mask & _zmq.POLLOUT: - evt |= self._WRITE - self._watch_raw_socket(loop, socket, evt, wake_raw) - - def on_poll_ready(f): - if future.done(): - return - if watcher.cancelled(): - try: - future.cancel() - except RuntimeError: - # RuntimeError may be called during teardown - pass - return - if watcher.exception(): - future.set_exception(watcher.exception()) - else: + async def poll(self, timeout=-1) -> list[tuple[Any, int]]: # type: ignore + """Return a poll event""" + async with create_task_group() as tg: + future = self._Future() + if timeout == 0: try: - result = super(_AsyncPoller, self).poll(0) + result = super().poll(0) except Exception as e: future.set_exception(e) else: future.set_result(result) + return await future.wait() - watcher.add_done_callback(on_poll_ready) + # register Future to be called as soon as any event is available on any socket + watcher = self._Future() - if wrapped_sockets: - watcher.add_done_callback(_clear_wrapper_io) + # watch raw sockets: + raw_sockets: list[Any] = [] - if timeout is not None and timeout > 0: - # schedule cancel to fire on poll timeout, if any - def trigger_timeout(): + def wake_raw(*args): if not watcher.done(): watcher.set_result(None) - timeout_handle = loop.call_later(1e-3 * timeout, trigger_timeout) - - def cancel_timeout(f): - if hasattr(timeout_handle, 'cancel'): - timeout_handle.cancel() + watcher.add_done_callback(lambda f: self._unwatch_raw_sockets(*raw_sockets)) + + wrapped_sockets: list[_AsyncSocket] = [] + + def _clear_wrapper_io(f): + for s in wrapped_sockets: + s._clear_io_state() + + for socket, mask in self.sockets: + if isinstance(socket, _zmq.Socket): + if not isinstance(socket, self._socket_class): + # it's a blocking zmq.Socket, wrap it in async + socket = self._socket_class.from_socket(socket) + wrapped_sockets.append(socket) + if mask & _zmq.POLLIN: + create_task( + socket._add_recv_event(tg, 'poll', future=watcher), tg + ) + if mask & _zmq.POLLOUT: + create_task( + socket._add_send_event(tg, 'poll', future=watcher), tg + ) else: - loop.remove_timeout(timeout_handle) + raw_sockets.append(socket) + evt = 0 + if mask & _zmq.POLLIN: + evt |= self._READ + if mask & _zmq.POLLOUT: + evt |= self._WRITE + self._watch_raw_socket(socket, evt, wake_raw) + + def on_poll_ready(f): + if future.done(): + return + if watcher.cancelled(): + try: + future.cancel(raise_exception=False) + except RuntimeError: + # RuntimeError may be called during teardown + pass + return + if watcher.exception(): + future.set_exception(watcher.exception()) + else: + try: + result = super(_AsyncPoller, self).poll(0) + except Exception as e: + future.set_exception(e) + else: + future.set_result(result) - future.add_done_callback(cancel_timeout) + watcher.add_done_callback(on_poll_ready) - def cancel_watcher(f): - if not watcher.done(): - watcher.cancel() + if wrapped_sockets: + watcher.add_done_callback(_clear_wrapper_io) - future.add_done_callback(cancel_watcher) + if timeout is not None and timeout > 0: + # schedule cancel to fire on poll timeout, if any + async def trigger_timeout(): + await sleep(1e-3 * timeout) + if not watcher.done(): + watcher.set_result(None) - return future + timeout_handle = create_task(trigger_timeout(), tg) + + def cancel_timeout(f): + timeout_handle.cancel(raise_exception=False) + + future.add_done_callback(cancel_timeout) + + def cancel_watcher(f): + if not watcher.done(): + watcher.cancel(raise_exception=False) + + future.add_done_callback(cancel_watcher) + + return await future.wait() class _NoTimer: @staticmethod - def cancel(): + def cancel(raise_exception=True): pass @@ -249,7 +230,7 @@ def close(self, linger: int | None = None) -> None: for event in event_list: if not event.future.done(): try: - event.future.cancel() + event.future.cancel(raise_exception=False) except RuntimeError: # RuntimeError may be called during teardown pass @@ -260,24 +241,25 @@ def close(self, linger: int | None = None) -> None: def get(self, key): result = super().get(key) - if key == EVENTS: - self._schedule_remaining_events(result) + # if key == EVENTS: + # self._schedule_remaining_events(result) return result get.__doc__ = _zmq.Socket.get.__doc__ - def recv_multipart( + async def recv_multipart( self, flags: int = 0, copy: bool = True, track: bool = False - ) -> Awaitable[list[bytes] | list[_zmq.Frame]]: + ) -> list[bytes] | list[_zmq.Frame]: """Receive a complete multipart zmq message. Returns a Future whose result will be a multipart message. """ - return self._add_recv_event( - 'recv_multipart', dict(flags=flags, copy=copy, track=track) - ) + async with create_task_group() as tg: + return await self._add_recv_event( + tg, 'recv_multipart', dict(flags=flags, copy=copy, track=track) + ) - def recv( # type: ignore + async def recv( # type: ignore self, flags: int = 0, copy: bool = True, track: bool = False ) -> Awaitable[bytes | _zmq.Frame]: """Receive a single zmq frame. @@ -286,11 +268,14 @@ def recv( # type: ignore Recommend using recv_multipart instead. """ - return self._add_recv_event('recv', dict(flags=flags, copy=copy, track=track)) + async with create_task_group() as tg: + return await self._add_recv_event( + tg, 'recv', dict(flags=flags, copy=copy, track=track) + ) - def send_multipart( # type: ignore + async def send_multipart( # type: ignore self, msg_parts: Any, flags: int = 0, copy: bool = True, track=False, **kwargs - ) -> Awaitable[_zmq.MessageTracker | None]: + ) -> _zmq.MessageTracker | None: """Send a complete multipart zmq message. Returns a Future that resolves when sending is complete. @@ -298,16 +283,20 @@ def send_multipart( # type: ignore kwargs['flags'] = flags kwargs['copy'] = copy kwargs['track'] = track - return self._add_send_event('send_multipart', msg=msg_parts, kwargs=kwargs) + async with create_task_group() as tg: + self._init_io_state(tg) + return await self._add_send_event( + tg, 'send_multipart', msg=msg_parts, kwargs=kwargs + ) - def send( # type: ignore + async def send( # type: ignore self, data: Any, flags: int = 0, copy: bool = True, track: bool = False, **kwargs: Any, - ) -> Awaitable[_zmq.MessageTracker | None]: + ) -> _zmq.MessageTracker | None: """Send a single zmq frame. Returns a Future that resolves when sending is complete. @@ -318,51 +307,54 @@ def send( # type: ignore kwargs['copy'] = copy kwargs['track'] = track kwargs.update(dict(flags=flags, copy=copy, track=track)) - return self._add_send_event('send', msg=data, kwargs=kwargs) + async with create_task_group() as tg: + self._init_io_state(tg) + return await self._add_send_event(tg, 'send', msg=data, kwargs=kwargs) def _deserialize(self, recvd, load): """Deserialize with Futures""" - f = self._Future() - - def _chain(_): - """Chain result through serialization to recvd""" - if f.done(): - # chained future may be cancelled, which means nobody is going to get this result - # if it's an error, that's no big deal (probably zmq.Again), - # but if it's a successful recv, this is a dropped message! - if not recvd.cancelled() and recvd.exception() is None: - warnings.warn( - # is there a useful stacklevel? - # ideally, it would point to where `f.cancel()` was called - f"Future {f} completed while awaiting {recvd}. A message has been dropped!", - RuntimeWarning, - ) - return - if recvd.exception(): - f.set_exception(recvd.exception()) - else: - buf = recvd.result() - try: - loaded = load(buf) - except Exception as e: - f.set_exception(e) - else: - f.set_result(loaded) - - recvd.add_done_callback(_chain) - - def _chain_cancel(_): - """Chain cancellation from f to recvd""" - if recvd.done(): - return - if f.cancelled(): - recvd.cancel() - - f.add_done_callback(_chain_cancel) - - return f - - def poll(self, timeout=None, flags=_zmq.POLLIN) -> Awaitable[int]: # type: ignore + return load(recvd) + # f = self._Future() + + # def _chain(_): + # """Chain result through serialization to recvd""" + # if f.done(): + # # chained future may be cancelled, which means nobody is going to get this result + # # if it's an error, that's no big deal (probably zmq.Again), + # # but if it's a successful recv, this is a dropped message! + # if not recvd.cancelled() and recvd.exception() is None: + # warnings.warn( + # # is there a useful stacklevel? + # # ideally, it would point to where `f.cancel()` was called + # f"Future {f} completed while awaiting {recvd}. A message has been dropped!", + # RuntimeWarning, + # ) + # return + # if recvd.exception(): + # f.set_exception(recvd.exception()) + # else: + # buf = recvd.result() + # try: + # loaded = load(buf) + # except Exception as e: + # f.set_exception(e) + # else: + # f.set_result(loaded) + + # recvd.add_done_callback(_chain) + + # def _chain_cancel(_): + # """Chain cancellation from f to recvd""" + # if recvd.done(): + # return + # if f.cancelled(): + # recvd.cancel() + + # f.add_done_callback(_chain_cancel) + + # return await f.wait() + + async def poll(self, timeout=None, flags=_zmq.POLLIN) -> int: # type: ignore """poll the socket for events returns a Future for the poll results. @@ -371,48 +363,49 @@ def poll(self, timeout=None, flags=_zmq.POLLIN) -> Awaitable[int]: # type: igno if self.closed: raise _zmq.ZMQError(_zmq.ENOTSUP) - p = self._poller_class() - p.register(self, flags) - poll_future = cast(Future, p.poll(timeout)) + async with create_task_group() as tg: + p = self._poller_class() + p.register(self, flags) + poll_future = cast(Task, create_task(p.poll(timeout), tg)) - future = self._Future() + future = self._Future() - def unwrap_result(f): - if future.done(): - return - if poll_future.cancelled(): - try: - future.cancel() - except RuntimeError: - # RuntimeError may be called during teardown - pass - return - if f.exception(): - future.set_exception(poll_future.exception()) - else: - evts = dict(poll_future.result()) - future.set_result(evts.get(self, 0)) + def unwrap_result(f): + if future.done(): + return + if poll_future.cancelled(): + try: + future.cancel() + except RuntimeError: + # RuntimeError may be called during teardown + pass + return + if f.exception(): + future.set_exception(poll_future.exception()) + else: + evts = dict(poll_future.result()) + future.set_result(evts.get(self, 0)) - if poll_future.done(): - # hook up result if already done - unwrap_result(poll_future) - else: - poll_future.add_done_callback(unwrap_result) + if poll_future.done(): + # hook up result if already done + unwrap_result(poll_future) + else: + poll_future.add_done_callback(unwrap_result) - def cancel_poll(future): - """Cancel underlying poll if request has been cancelled""" - if not poll_future.done(): - try: - poll_future.cancel() - except RuntimeError: - # RuntimeError may be called during teardown - pass + def cancel_poll(future): + """Cancel underlying poll if request has been cancelled""" + if not poll_future.done(): + try: + poll_future.cancel() + except RuntimeError: + # RuntimeError may be called during teardown + pass - future.add_done_callback(cancel_poll) + future.add_done_callback(cancel_poll) - return future + return await future.wait() - def _add_timeout(self, future, timeout): + def _add_timeout(self, task_group, future, timeout): """Add a timeout for a send or recv Future""" def future_timeout(): @@ -423,9 +416,9 @@ def future_timeout(): # raise EAGAIN future.set_exception(_zmq.Again()) - return self._call_later(timeout, future_timeout) + return self._call_later(task_group, timeout, future_timeout) - def _call_later(self, delay, callback): + def _call_later(self, task_group, delay, callback): """Schedule a function to be called later Override for different IOLoop implementations @@ -433,7 +426,12 @@ def _call_later(self, delay, callback): Tornado and asyncio happen to both have ioloop.call_later with the same signature. """ - return self._get_loop().call_later(delay, callback) + + async def call_later(): + await sleep(delay) + callback() + + return create_task(call_later(), task_group) @staticmethod def _remove_finished_future(future, event_list, event=None): @@ -453,7 +451,7 @@ def _remove_finished_future(future, event_list, event=None): # usually this will have been removed by being consumed return - def _add_recv_event(self, kind, kwargs=None, future=None): + async def _add_recv_event(self, task_group, kind, kwargs=None, future=None): """Add a recv event, returning the corresponding Future""" f = future or self._Future() if kind.startswith('recv') and kwargs.get('flags', 0) & _zmq.DONTWAIT: @@ -465,13 +463,13 @@ def _add_recv_event(self, kind, kwargs=None, future=None): f.set_exception(e) else: f.set_result(r) - return f + return await f.wait() timer = _NoTimer if hasattr(_zmq, 'RCVTIMEO'): timeout_ms = self._shadow_sock.rcvtimeo if timeout_ms >= 0: - timer = self._add_timeout(f, timeout_ms * 1e-3) + timer = self._add_timeout(task_group, f, timeout_ms * 1e-3) # we add it to the list of futures before we add the timeout as the # timeout will remove the future from recv_futures to avoid leaks @@ -480,7 +478,7 @@ def _add_recv_event(self, kind, kwargs=None, future=None): if self._shadow_sock.get(EVENTS) & POLLIN: # recv immediately, if we can - self._handle_recv() + self._handle_recv(task_group) if self._recv_futures and _future_event in self._recv_futures: # Don't let the Future sit in _recv_events after it's done # no need to register this if we've already been handled @@ -492,10 +490,12 @@ def _add_recv_event(self, kind, kwargs=None, future=None): event=_future_event, ) ) - self._add_io_state(POLLIN) - return f + self._add_io_state(task_group, POLLIN) + return await f.wait() - def _add_send_event(self, kind, msg=None, kwargs=None, future=None): + async def _add_send_event( + self, task_group, kind, msg=None, kwargs=None, future=None + ): """Add a send event, returning the corresponding Future""" f = future or self._Future() # attempt send with DONTWAIT if no futures are waiting @@ -529,14 +529,14 @@ def _add_send_event(self, kind, msg=None, kwargs=None, future=None): # short-circuit resolved, return finished Future # schedule wake for recv if there are any receivers waiting if self._recv_futures: - self._schedule_remaining_events() - return f + self._schedule_remaining_events(task_group) + return await f.wait() timer = _NoTimer if hasattr(_zmq, 'SNDTIMEO'): timeout_ms = self._shadow_sock.get(_zmq.SNDTIMEO) if timeout_ms >= 0: - timer = self._add_timeout(f, timeout_ms * 1e-3) + timer = self._add_timeout(task_group, f, timeout_ms * 1e-3) # we add it to the list of futures before we add the timeout as the # timeout will remove the future from recv_futures to avoid leaks @@ -551,10 +551,10 @@ def _add_send_event(self, kind, msg=None, kwargs=None, future=None): ) ) - self._add_io_state(POLLOUT) - return f + self._add_io_state(task_group, POLLOUT) + return await f.wait() - def _handle_recv(self): + def _handle_recv(self, task_group): """Handle recv events""" if not self._shadow_sock.get(EVENTS) & POLLIN: # event triggered, but state may have been changed between trigger and callback @@ -569,12 +569,12 @@ def _handle_recv(self): break if not self._recv_futures: - self._drop_io_state(POLLIN) + self._drop_io_state(task_group, POLLIN) if f is None: return - timer.cancel() + timer.cancel(raise_exception=False) if kind == 'poll': # on poll event, just signal ready, nothing else. @@ -595,7 +595,7 @@ def _handle_recv(self): else: f.set_result(result) - def _handle_send(self): + def _handle_send(self, task_group): if not self._shadow_sock.get(EVENTS) & POLLOUT: # event triggered, but state may have been changed between trigger and callback return @@ -609,7 +609,7 @@ def _handle_send(self): break if not self._send_futures: - self._drop_io_state(POLLOUT) + self._drop_io_state(task_group, POLLOUT) if f is None: return @@ -636,7 +636,7 @@ def _handle_send(self): f.set_result(result) # event masking from ZMQStream - def _handle_events(self, fd=0, events=0): + async def _handle_events(self, task_group, fd=0, events=0): """Dispatch IO events to _handle_recv, etc.""" if self._shadow_sock.closed: return @@ -645,10 +645,10 @@ def _handle_events(self, fd=0, events=0): if zmq_events & _zmq.POLLIN: self._handle_recv() if zmq_events & _zmq.POLLOUT: - self._handle_send() - self._schedule_remaining_events() + self._handle_send(task_group) + self._schedule_remaining_events(task_group) - def _schedule_remaining_events(self, events=None): + def _schedule_remaining_events(self, task_group, events=None): """Schedule a call to handle_events next loop iteration If there are still events to handle. @@ -662,37 +662,34 @@ def _schedule_remaining_events(self, events=None): if events is None: events = self._shadow_sock.get(EVENTS) if events & self._state: - self._call_later(0, self._handle_events) + create_task(self._handle_events(task_group), task_group) - def _add_io_state(self, state): + def _add_io_state(self, task_group, state): """Add io_state to poller.""" if self._state != state: state = self._state = self._state | state - self._update_handler(self._state) + self._update_handler(task_group, self._state) - def _drop_io_state(self, state): + def _drop_io_state(self, task_group, state): """Stop poller from watching an io_state.""" if self._state & state: self._state = self._state & (~state) - self._update_handler(self._state) + self._update_handler(task_group, self._state) - def _update_handler(self, state): + def _update_handler(self, task_group, state): """Update IOLoop handler with state. zmq FD is always read-only. """ - # ensure loop is registered and init_io has been called - # if there are any events to watch for - if state: - self._get_loop() - self._schedule_remaining_events() + self._schedule_remaining_events(task_group) - def _init_io_state(self, loop=None): + def _init_io_state(self, task_group, loop=None): """initialize the ioloop event handler""" - if loop is None: - loop = self._get_loop() - loop.add_handler(self._shadow_sock, self._handle_events, self._READ) - self._call_later(0, self._handle_events) + if not self._event_handler_initialized: + self._event_handler_initialized = True + sel = selectors.DefaultSelector() + sel.register(self._shadow_sock, self._READ, self._handle_events) + create_task(self._handle_events(task_group), task_group) def _clear_io_state(self): """unregister the ioloop event handler diff --git a/zmq/asyncio.py b/zmq/asyncio.py index 22bbd14d4..58c8f551a 100644 --- a/zmq/asyncio.py +++ b/zmq/asyncio.py @@ -11,9 +11,11 @@ import selectors import sys import warnings -from asyncio import Future, SelectorEventLoop +from asyncio import SelectorEventLoop from weakref import WeakKeyDictionary +from anyioutils import Future + import zmq as _zmq from zmq import _future @@ -119,13 +121,13 @@ def _default_loop(self): class Poller(_AsyncIO, _future._AsyncPoller): """Poller returning asyncio.Future for poll results.""" - def _watch_raw_socket(self, loop, socket, evt, f): + def _watch_raw_socket(self, socket, evt, f): """Schedule callback for a raw socket""" - selector = _get_selector(loop) + selelector = selectors.DefaultSelector() if evt & self._READ: - selector.add_reader(socket, lambda *args: f()) + selelector.register(socket, self._READ, lambda *args: f()) if evt & self._WRITE: - selector.add_writer(socket, lambda *args: f()) + selelector.register(socket, self._WRITE, lambda *args: f()) def _unwatch_raw_sockets(self, loop, *sockets): """Unschedule callback for a raw socket""" @@ -145,20 +147,20 @@ def _get_selector(self, io_loop=None): io_loop = self._get_loop() return _get_selector(io_loop) - def _init_io_state(self, io_loop=None): - """initialize the ioloop event handler""" - self._get_selector(io_loop).add_reader( - self._fd, lambda: self._handle_events(0, 0) - ) + # def _init_io_state(self, io_loop=None): + # """initialize the ioloop event handler""" + # self._get_selector(io_loop).add_reader( + # self._fd, lambda: self._handle_events(0, 0) + # ) def _clear_io_state(self): """clear any ioloop event handler called once at close """ - loop = self._current_loop - if loop and not loop.is_closed() and self._fd != -1: - self._get_selector(loop).remove_reader(self._fd) + # loop = self._current_loop + # if loop and not loop.is_closed() and self._fd != -1: + # self._get_selector(loop).remove_reader(self._fd) Poller._socket_class = Socket diff --git a/zmq/sugar/socket.py b/zmq/sugar/socket.py index a4a906b53..c8251dd7b 100644 --- a/zmq/sugar/socket.py +++ b/zmq/sugar/socket.py @@ -857,7 +857,7 @@ def send_serialized(self, msg, serialize, flags=0, copy=True, **kwargs): frames = serialize(msg) return self.send_multipart(frames, flags=flags, copy=copy, **kwargs) - def recv_serialized(self, deserialize, flags=0, copy=True): + async def recv_serialized(self, deserialize, flags=0, copy=True): """Receive a message with a custom deserialization function. .. versionadded:: 17 @@ -883,7 +883,7 @@ def recv_serialized(self, deserialize, flags=0, copy=True): ZMQError for any of the reasons :func:`~Socket.recv` might fail """ - frames = self.recv_multipart(flags=flags, copy=copy) + frames = await self.recv_multipart(flags=flags, copy=copy) return self._deserialize(frames, deserialize) def send_string( @@ -914,7 +914,7 @@ def send_string( send_unicode = send_string - def recv_string(self, flags: int = 0, encoding: str = 'utf-8') -> str: + async def recv_string(self, flags: int = 0, encoding: str = 'utf-8') -> str: """Receive a unicode string, as sent by send_string. Parameters @@ -934,12 +934,12 @@ def recv_string(self, flags: int = 0, encoding: str = 'utf-8') -> str: ZMQError for any of the reasons :func:`Socket.recv` might fail """ - msg = self.recv(flags=flags) + msg = await self.recv(flags=flags) return self._deserialize(msg, lambda buf: buf.decode(encoding)) recv_unicode = recv_string - def send_pyobj( + async def send_pyobj( self, obj: Any, flags: int = 0, protocol: int = DEFAULT_PROTOCOL, **kwargs ) -> zmq.Frame | None: """ @@ -965,9 +965,9 @@ def send_pyobj( where defined, and pickle.HIGHEST_PROTOCOL elsewhere. """ msg = pickle.dumps(obj, protocol) - return self.send(msg, flags=flags, **kwargs) + return await self.send(msg, flags=flags, **kwargs) - def recv_pyobj(self, flags: int = 0) -> Any: + async def recv_pyobj(self, flags: int = 0) -> Any: """ Receive a Python object as a message using UNSAFE pickle to serialize. @@ -995,10 +995,10 @@ def recv_pyobj(self, flags: int = 0) -> Any: ZMQError for any of the reasons :func:`~Socket.recv` might fail """ - msg = self.recv(flags) + msg = await self.recv(flags) return self._deserialize(msg, pickle.loads) - def send_json(self, obj: Any, flags: int = 0, **kwargs) -> None: + async def send_json(self, obj: Any, flags: int = 0, **kwargs) -> None: """Send a Python object as a message using json to serialize. Keyword arguments are passed on to json.dumps @@ -1015,9 +1015,11 @@ def send_json(self, obj: Any, flags: int = 0, **kwargs) -> None: if key in kwargs: send_kwargs[key] = kwargs.pop(key) msg = jsonapi.dumps(obj, **kwargs) - return self.send(msg, flags=flags, **send_kwargs) + return await self.send(msg, flags=flags, **send_kwargs) - def recv_json(self, flags: int = 0, **kwargs) -> list | str | int | float | dict: + async def recv_json( + self, flags: int = 0, **kwargs + ) -> list | str | int | float | dict: """Receive a Python object as a message using json to serialize. Keyword arguments are passed on to json.loads @@ -1037,7 +1039,7 @@ def recv_json(self, flags: int = 0, **kwargs) -> list | str | int | float | dict ZMQError for any of the reasons :func:`~Socket.recv` might fail """ - msg = self.recv(flags) + msg = await self.recv(flags) return self._deserialize(msg, lambda buf: jsonapi.loads(buf, **kwargs)) _poller_class = Poller