
I am using the Starlette framework for my website, with WebSockets. I have a class derived from WebSocketEndpoint that handles incoming WebSocket connections, and another class that manages Pub/Sub on Redis.

I am trying to get the callback that I register when subscribing to be called, but it never fires. The documentation for the subscribe method says:

    Subscribe to channels. Channels supplied as keyword arguments expect
    a channel name as the key and a callable as the value. A channel's
    callable will be invoked automatically when a message is received on
    that channel rather than producing a message via ``listen()`` or
    ``get_message()``.

What am I missing here?

WSEndpoint Class:

from starlette.endpoints import WebSocketEndpoint
from starlette.websockets import WebSocket
import json
from blueprints import redis_pubsub_manager

class WSEndpoint(WebSocketEndpoint):
    encoding = 'json'

    def __init__(self, scope, receive, send):
        super().__init__(scope, receive, send)

        self.global_chatroom_channel_name = "globalchat"
        self.connected_users = []
        self.redis_manager_sub = redis_pubsub_manager.RedisPubSubManager(room=self.global_chatroom_channel_name)
        self.redis_manager_pub = redis_pubsub_manager.RedisPubSubManager(room=self.global_chatroom_channel_name)

        self.global_subscription = None

        print("done")

    async def on_connect(self, websocket: WebSocket):
        await websocket.accept()
        self.connected_users.append(websocket)
        if self.global_subscription is None:
            await self.redis_manager_sub.connect()
            await self.redis_manager_pub.connect()
            self.global_subscription = await self.redis_manager_sub.subscribe(callback = self.publish_message_to_subscribers)

        print(f"Socket Connected: {websocket}")


    async def on_receive(self, websocket: WebSocket, data):
        await self.redis_manager_pub._publish(json.dumps(data))

    async def publish_message_to_subscribers(self, msg):
        print("inside publish_message_to_subscribers")
        print(msg)

RedisPubSubManager Class:

import asyncio
import redis.asyncio as aioredis
import json

class RedisPubSubManager:
    def __init__(self, host='localhost', port=6379, room="globalchat"):
        self.redis_host = host
        self.redis_port = port
        self.pubsub = None
        self.room = room

    async def _get_redis_connection(self) -> aioredis.Redis:
        return aioredis.Redis(host=self.redis_host,
                              port=self.redis_port,
                              auto_close_connection_pool=False)
    
    async def connect(self) -> None:
        self.redis_connection = await self._get_redis_connection()
        self.pubsub = self.redis_connection.pubsub()


    async def _publish(self, message) -> None:
        await self.redis_connection.publish(self.room, message)

    async def subscribe(self, callback) -> aioredis.client.PubSub:
        # Register the callback for this manager's room rather than a hard-coded channel name
        await self.pubsub.subscribe(**{self.room: callback})

        return self.pubsub
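To see why the callback never fires, the dispatch rule quoted from the subscribe docstring can be modeled without a Redis server. The sketch below uses a hypothetical MiniPubSub class (not part of redis-py) that mimics only the handler lookup: a pushed message sits in a buffer until get_message() reads it, and only then is the registered handler called. It deliberately skips details of the real client, such as subscribe confirmation messages.

```python
import asyncio
from collections import deque
from typing import Any, Callable, Deque, Dict, Optional


class MiniPubSub:
    """Hypothetical stand-in for a Redis PubSub object (no server needed).

    It mimics only the dispatch rule from the docstring: a message for a
    channel with a registered handler goes to the handler instead of being
    returned, but only when someone actually reads from the connection.
    """

    def __init__(self) -> None:
        self.channels: Dict[str, Callable[[dict], Any]] = {}
        self._incoming: Deque[dict] = deque()  # plays the role of the socket buffer

    async def subscribe(self, **handlers: Callable[[dict], Any]) -> None:
        # Registering a handler only stores it; nothing starts reading messages.
        self.channels.update(handlers)

    def _server_push(self, channel: str, data: str) -> None:
        # Simulates Redis delivering a published message to this connection.
        self._incoming.append({"type": "message", "channel": channel, "data": data})

    async def get_message(self) -> Optional[dict]:
        if not self._incoming:
            return None
        message = self._incoming.popleft()
        handler = self.channels.get(message["channel"])
        if handler is not None:
            result = handler(message)
            if asyncio.iscoroutine(result):  # await async handlers
                await result
            return None  # handled: nothing is returned to the caller
        return message


received = []


async def on_message(msg: dict) -> None:
    received.append(msg["data"])


async def main() -> None:
    pubsub = MiniPubSub()
    await pubsub.subscribe(globalchat=on_message)
    pubsub._server_push("globalchat", "hello")
    print("before reading:", received)  # the handler has not run yet
    await pubsub.get_message()          # reading the message triggers the handler
    print("after reading:", received)


asyncio.run(main())
```

The first print shows an empty list and the second shows ['hello']: the callback runs as a side effect of reading, so some coroutine has to keep reading.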


Answers


  1. I recently faced this exact issue and, after checking a lot of resources, came to the following conclusion.

    If you look at the internal implementation of the Redis library, specifically the .listen() method (shown below for aioredis==2.0.1; the implementation in the redis-py SDK should be the same or similar, but cross-check it), you'll notice that it fetches a Redis message, does some processing, looks up the handler (which can be passed when calling .subscribe(), as you have done correctly), and then either calls the handler with the message or yields the message directly. In other words, the handler is only invoked while something is reading from the PubSub connection.

    async def listen(self) -> AsyncIterator:
        """Listen for messages on channels this client has been subscribed to"""
        while self.subscribed:
            response = self.handle_message(await self.parse_response(block=True))
            if response is not None:
                yield response
    
    
    
    
    def handle_message(self, response, ignore_subscribe_messages=False):
        """
        Parses a pub/sub message. If the channel or pattern was subscribed to
        with a message handler, the handler is invoked instead of a parsed
        message being returned.
        """
        message_type = str_if_bytes(response[0])
        if message_type == "pmessage":
            message = {
                "type": message_type,
                "pattern": response[1],
                "channel": response[2],
                "data": response[3],
            }
        elif message_type == "pong":
            message = {
                "type": message_type,
                "pattern": None,
                "channel": None,
                "data": response[1],
            }
        else:
            message = {
                "type": message_type,
                "pattern": None,
                "channel": response[1],
                "data": response[2],
            }
    
        # if this is an unsubscribe message, remove it from memory
        if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
            if message_type == "punsubscribe":
                pattern = response[1]
                if pattern in self.pending_unsubscribe_patterns:
                    self.pending_unsubscribe_patterns.remove(pattern)
                    self.patterns.pop(pattern, None)
            else:
                channel = response[1]
                if channel in self.pending_unsubscribe_channels:
                    self.pending_unsubscribe_channels.remove(channel)
                    self.channels.pop(channel, None)
    
        if message_type in self.PUBLISH_MESSAGE_TYPES:
            # if there's a message handler, invoke it
            if message_type == "pmessage":
                handler = self.patterns.get(message["pattern"], None)
            else:
                handler = self.channels.get(message["channel"], None)
            if handler:
                handler(message)
                return None
        elif message_type != "pong":
            # this is a subscribe/unsubscribe message. ignore if we don't
            # want them
            if ignore_subscribe_messages or self.ignore_subscribe_messages:
                return None
    
        return message
    

    So I am afraid you will have to create a coroutine that calls .listen(); you can then handle each Redis message either inside that coroutine or inside your handler.

    The final code would look something like this:

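    import asyncio

    from starlette.endpoints import WebSocketEndpoint
    from starlette.websockets import WebSocket
    from blueprints import redis_pubsub_manager

    class WSEndpoint(WebSocketEndpoint):
        encoding = 'json'

        def __init__(self, scope, receive, send):
            super().__init__(scope, receive, send)
            self.redis_manager_sub = redis_pubsub_manager.RedisPubSubManager(room="globalchat")
            self.reader_task = None
            # (redis_manager_pub and on_receive stay as in the question.)

        async def on_connect(self, websocket: WebSocket):
            await websocket.accept()
            await self.redis_manager_sub.connect()
            pubsub = await self.redis_manager_sub.subscribe(
                callback=self.publish_message_to_subscribers
            )
            # Start the reader only after subscribing: listen() loops only while
            # the PubSub has subscriptions. Keep a reference to the task.
            self.reader_task = asyncio.create_task(self.__pubsub_data_reader(pubsub))

        async def on_disconnect(self, websocket: WebSocket, close_code: int):
            # Stop the background reader when the client goes away.
            if self.reader_task is not None:
                self.reader_task.cancel()

        async def publish_message_to_subscribers(self, msg):
            print(msg)

        async def __pubsub_data_reader(self, pubsub):
            """
            Reads messages received from Redis PubSub.

            Args:
            pubsub (redis.asyncio.client.PubSub): PubSub object for the subscribed channel.
            """
            print("started __pubsub_data_reader")
            try:
                # Messages with a registered handler are dispatched to it and not
                # yielded here; only unhandled ones (e.g. subscribe confirmations) are.
                async for message in pubsub.listen():
                    pass
            except Exception as exc:
                print("Got error: ", exc)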
    

    See if this helps. If not, I can share the complete Pub/Sub code, which may help you understand the flow further.

    PS: In Python this does not feel truly asynchronous, because in the end there is a while loop that keeps checking for messages (compared to counterparts like Node.js). If anyone knows better, I would love to hear their views. Also, if anyone knows a better solution, please let me know, because the one above seems kind of hacky 🙂

  2. I ran into the same issue while converting my previous solution (non-async Redis subscriptions running in threads) to an async one.

    I found this while reading the aioredis code (specifically the PubSub class):

    async def run(
        self,
        *,
        exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
        poll_timeout: float = 1.0,
    ) -> None:
        """Process pub/sub messages using registered callbacks.

        This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
        redis-py, but it is a coroutine. To launch it as a separate task, use
        ``asyncio.create_task``:

            >>> task = asyncio.create_task(pubsub.run())

        To shut it down, use asyncio cancellation:

            >>> task.cancel()
            >>> await task
        """
    

    I followed the approach above (recommended by the authors of aioredis), and it works just fine with uvicorn and FastAPI with Socket.IO. You might have to adapt it to your own app flow if you use FastAPI's WebSockets, but it should still work.

    Right after calling subscribe() with the handler passed in as a callable, just add the following line:

    pubsub_task = asyncio.create_task(pubsub.run())  # keep a reference so the task isn't garbage-collected
    
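The pattern both answers describe (subscribe with a handler, start a background task that drains the connection, and cancel that task on shutdown) can be sketched without a Redis server. QueuePubSub below is a hypothetical in-memory stand-in, not the redis-py class: its run() coroutine only mimics what PubSub.run() does for registered callbacks, and an asyncio.Queue stands in for the Redis socket.

```python
import asyncio
from typing import Awaitable, Callable, Dict


class QueuePubSub:
    """Hypothetical in-memory PubSub used only to illustrate the pattern.

    run() plays the role of PubSub.run() or a listen() loop: it waits for
    messages and dispatches each one to the handler registered for its channel.
    """

    def __init__(self) -> None:
        self.channels: Dict[str, Callable[[dict], Awaitable[None]]] = {}
        self._queue: "asyncio.Queue[dict]" = asyncio.Queue()

    async def subscribe(self, **handlers: Callable[[dict], Awaitable[None]]) -> None:
        self.channels.update(handlers)

    async def publish(self, channel: str, data: str) -> None:
        await self._queue.put({"type": "message", "channel": channel, "data": data})

    async def run(self) -> None:
        while True:
            # Awaiting the queue suspends this task until a message arrives.
            message = await self._queue.get()
            handler = self.channels.get(message["channel"])
            if handler is not None:
                await handler(message)


received = []


async def publish_message_to_subscribers(msg: dict) -> None:
    received.append(msg["data"])


async def main() -> None:
    pubsub = QueuePubSub()
    await pubsub.subscribe(globalchat=publish_message_to_subscribers)
    # Start the reader right after subscribing and keep a reference to the task.
    reader = asyncio.create_task(pubsub.run())

    await pubsub.publish("globalchat", "hi")
    await pubsub.publish("globalchat", "there")
    await asyncio.sleep(0.1)  # give the reader task a chance to dispatch

    # Shut the reader down with cancellation, as the run() docstring suggests.
    reader.cancel()
    try:
        await reader
    except asyncio.CancelledError:
        pass


asyncio.run(main())
print(received)
```

This should print ['hi', 'there']. Keeping the task reference matters because the event loop holds only a weak reference to tasks created with asyncio.create_task.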