asyncio.wait not returning on first exception

I have an AMQP publisher class with the methods below. on_response is the callback that runs when a consumer sends a message back to the RPC queue I set up, i.e. the queue whose name (self.callback_queue.name) is passed as reply_to on the Message. publish sends to a direct exchange with a routing key that has multiple consumers (very similar to a fanout), so multiple responses come back. I create one future per response I expect and asyncio.wait for those futures to complete. As responses arrive on the queue and are consumed, I set each future's result.

    async def on_response(self, message: IncomingMessage):
        if message.correlation_id is None:
            logger.error(f"Bad message {message!r}")
            await message.ack()
            return
        body = message.body.decode('UTF-8')
        future = self.futures[message.correlation_id].pop()
        if hasattr(body, 'error'):
            future.set_execption(body)
        else:
            future.set_result(body)
        await message.ack()

    async def publish(self, routing_key, expected_response_count, msg, timeout=None, return_partial=False):
        if not self.connected:
            logger.info("Publisher not connected. Waiting to connect first.")
            await self.connect()
        
        correlation_id = str(uuid.uuid4())
        futures = [self.loop.create_future() for _ in range(expected_response_count)]
        self.futures[correlation_id] = futures
        await self.exchange.publish(
            Message(
                str(msg).encode(),
                content_type="text/plain",
                correlation_id=correlation_id,
                reply_to=self.callback_queue.name,
            ),
            routing_key=routing_key,
        )
        done, pending = await asyncio.wait(futures, timeout=timeout, return_when=asyncio.FIRST_EXCEPTION)
        if not return_partial and pending:
            raise asyncio.TimeoutError(f'Failed to return all results for publish to {routing_key}')

        for f in pending:
            f.cancel()
        del self.futures[correlation_id]

        results = []
        for future in done:
            try:
                results.append(json.loads(future.result()))
            except json.decoder.JSONDecodeError as e:
                logger.error(f'Client did not return JSON!! {e!r}')
                logger.info(future.result())
        return results
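The bookkeeping described above (one future per expected reply, grouped under a correlation ID and resolved one at a time as replies are consumed) can be exercised without a broker. The following is a minimal sketch, not the asker's code: FakeRpc, register, deliver and collect are hypothetical names, and loop.call_later stands in for replies arriving from RabbitMQ. It shows asyncio.wait returning either when every future is done or when the timeout expires, with unfinished futures left in pending.

```python
import asyncio
import uuid


class FakeRpc:
    """Hypothetical stand-in for the publisher: no broker, replies are injected by hand."""

    def __init__(self) -> None:
        # correlation_id -> futures still waiting for a reply
        self.futures: dict[str, list[asyncio.Future]] = {}

    def register(self, expected_response_count: int) -> tuple[str, list[asyncio.Future]]:
        # One future per expected reply, grouped under a fresh correlation ID.
        loop = asyncio.get_running_loop()
        correlation_id = str(uuid.uuid4())
        futures = [loop.create_future() for _ in range(expected_response_count)]
        # Store a copy, so deliver()'s pop() cannot remove a future from the
        # list that the caller later hands to asyncio.wait().
        self.futures[correlation_id] = list(futures)
        return correlation_id, futures

    def deliver(self, correlation_id: str, body: str) -> None:
        # Plays the role of on_response: each reply resolves one waiting future.
        self.futures[correlation_id].pop().set_result(body)


async def collect(expected: int, replies: list[str], timeout: float) -> tuple[list[str], int]:
    rpc = FakeRpc()
    correlation_id, futures = rpc.register(expected)
    loop = asyncio.get_running_loop()
    # Simulated consumers: reply i arrives (i + 1) * 10 ms after the "publish".
    for i, body in enumerate(replies):
        loop.call_later(0.01 * (i + 1), rpc.deliver, correlation_id, body)
    done, pending = await asyncio.wait(futures, timeout=timeout)
    for f in pending:
        f.cancel()  # give up on replies that never came
    del rpc.futures[correlation_id]
    return sorted(f.result() for f in done), len(pending)


print(asyncio.run(collect(2, ["a", "b"], timeout=1.0)))  # all replies arrive
print(asyncio.run(collect(3, ["a", "b"], timeout=0.2)))  # one reply missing
```

Keeping a separate copy of the list in the registry is a precaution: in the sketch, a reply consumed before asyncio.wait starts still resolves a future that the caller is waiting on.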

My goal is to either wait until all futures are finished or a timeout occurs, and that part works nicely at the moment. What doesn't work: after I added return_when=asyncio.FIRST_EXCEPTION, asyncio.wait(...) does not return after the first call to future.set_exception(...), as I expected it would.

What do I need to do with the future so that, when a response comes back showing that an error occurred on the consumer side (before the timeout, or even before the other responses arrive), await asyncio.wait(...) stops blocking? The documentation says:

The function will return when any future finishes by raising an exception

when return_when=asyncio.FIRST_EXCEPTION. My first thought was that I'm not setting the exception on my future correctly, but I'm having trouble finding out exactly how I should do it instead. From the API documentation for the Future class, it looks like I'm doing the right thing.
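The Future side can be checked in isolation. This minimal sketch (not from the question; inspect_failed_future is a hypothetical helper) fails a bare future with set_exception() and inspects it with the documented done(), exception() and result() methods. The state it ends up in (done, not cancelled, exception() not None) is the condition that return_when=asyncio.FIRST_EXCEPTION reacts to.

```python
import asyncio


async def inspect_failed_future() -> tuple[bool, str, str]:
    loop = asyncio.get_running_loop()
    future = loop.create_future()

    # set_exception() expects an exception object, e.g. an Exception instance.
    future.set_exception(RuntimeError("consumer reported an error"))

    done = future.done()          # True: a failed future counts as finished
    stored = future.exception()   # returns the exception without raising it
    raised = ""
    try:
        future.result()           # re-raises the stored exception
    except RuntimeError as exc:
        raised = str(exc)
    return done, type(stored).__name__, raised


print(asyncio.run(inspect_failed_future()))
# prints (True, 'RuntimeError', 'consumer reported an error')
```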



Solution 1:[1]

When I created a minimal example, I realized I was actually doing things MOSTLY right after all, and had overlooked other errors that kept this from working.

The most important change was to pass an actual exception object (an instance of Exception, which is a subclass of BaseException) to the set_exception method. Here is my minimal example:

    import asyncio

    async def set_after(future, t, body, raise_exception):
        await asyncio.sleep(t)
        if raise_exception:
            # set_exception needs an exception object, not a plain string
            future.set_exception(Exception("problem"))
        else:
            future.set_result(body)
        print(body)

    async def main():
        loop = asyncio.get_running_loop()
        futures = [loop.create_future() for _ in range(2)]
        # Keep references to the tasks so they are not garbage-collected early
        tasks = [
            asyncio.create_task(set_after(futures[0], 3, 'hello', raise_exception=True)),
            asyncio.create_task(set_after(futures[1], 7, 'world', raise_exception=False)),
        ]
        print(futures)
        # Returns after about 3 s, when futures[0] fails, instead of waiting 10 s
        done, pending = await asyncio.wait(futures, timeout=10, return_when=asyncio.FIRST_EXCEPTION)
        print(done)
        print(pending)


    asyncio.run(main())

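In the line if hasattr(body, 'error'):, body was a string; I thought it was already parsed JSON at that point. hasattr checks for an attribute, not a key, so the condition was never true, even for error replies. I should have been using "error" in body as my condition in any case. Whoops!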
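Putting both fixes together, the sketch below (an illustration, not the asker's final code) parses the reply before looking for "error" and hands set_exception a real exception instance. It assumes consumers reply with JSON objects and signal failure with an "error" key; ConsumerError, resolve_reply and gather_replies are hypothetical names, and loop.call_later stands in for messages arriving from the broker. Parsing first is a design choice: a substring test on the raw text would also match a successful reply that merely contains the word. It also checks exception() before calling result(), because result() re-raises the stored exception.

```python
import asyncio
import json


class ConsumerError(Exception):
    """Hypothetical exception type for replies that report a consumer-side error."""


def resolve_reply(future: asyncio.Future, raw_body: bytes) -> None:
    # A late reply may arrive after the caller has already cancelled the future.
    if future.done():
        return
    try:
        payload = json.loads(raw_body.decode("utf-8"))
    except json.JSONDecodeError as exc:
        future.set_exception(exc)  # a non-JSON reply is also treated as a failure
        return
    if isinstance(payload, dict) and "error" in payload:
        # Hand set_exception an exception instance, never the raw body string.
        future.set_exception(ConsumerError(payload["error"]))
    else:
        future.set_result(payload)


async def gather_replies(bodies: list[bytes], timeout: float) -> tuple[list, list[str]]:
    loop = asyncio.get_running_loop()
    futures = [loop.create_future() for _ in bodies]
    # Simulated consumers: reply i arrives (i + 1) * 50 ms after the "publish".
    for i, (future, body) in enumerate(zip(futures, bodies)):
        loop.call_later(0.05 * (i + 1), resolve_reply, future, body)
    done, pending = await asyncio.wait(
        futures, timeout=timeout, return_when=asyncio.FIRST_EXCEPTION
    )
    for f in pending:
        f.cancel()
    results, errors = [], []
    for f in done:
        # Check exception() first: result() would re-raise the stored exception.
        if f.exception() is not None:
            errors.append(repr(f.exception()))
        else:
            results.append(f.result())
    return results, errors


bodies = [b'{"value": 1}', b'{"error": "disk full"}', b'{"value": 3}']
print(asyncio.run(gather_replies(bodies, timeout=1.0)))
```

With these inputs the wait should return as soon as the second reply fails, roughly 100 ms in, well before the 1 s timeout, leaving the third future pending and cancelled.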

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

[1] Source of Solution 1: Stack Overflow