
I use an object that needs a startup and a teardown step (load from / save to cache, for example) in a FastAPI endpoint. I used an asynccontextmanager to manage the object's context, but I also want to process the object in a later background task.

In my environment (fastapi==0.115.5) the context of this object ends before the response is returned, which is typically earlier than the end of the background task, so part of the background task runs outside the context. For example, if the teardown part of the context manager contains a "save to cache" step, later changes made in the background task are not saved, because they happen after the teardown has already run.

There is a minimal (but still ~150-line) working example on this gist; I'll also paste it here.

from fastapi import FastAPI, Depends, BackgroundTasks, Request
from typing import Annotated, AsyncIterator
from pydantic import BaseModel, Field
from uuid import uuid4
from contextlib import asynccontextmanager
import random
import asyncio

app = FastAPI()


class Chat(BaseModel):
    """
    An over-simplified chat history manager, as might be used in e.g. a LangChain-like system.
    There is an additional `total` field because history messages are serialized and cached individually, and we don't want to load the whole history when deserializing from cache/database.
    """

    id: str = Field(default_factory=lambda: uuid4().hex)
    meta: str = "some meta information"
    history: list[str] = []
    total: int = 0
    uncached: int = 0

    def add_message(self, msg: str):
        self.history.append(msg)
        self.total += 1
        self.uncached += 1

    async def save(self, cache: dict):
        # cache history that are not cached
        for imsg in range(-self.uncached, 0):
            cache[f"msg:{self.id}:{self.total + imsg}"] = self.history[-self.uncached]
        self.uncached = 0
        # cache everything except history
        cache[f"sess:{self.id}"] = self.model_dump(exclude={"history"})

        print(f"saved: {self}")

    @classmethod
    async def load(cls, sess_id: str, cache: dict, max_read: int = 30):
        sess_key = f"sess:{sess_id}"
        obj = cls.model_validate(cache.get(sess_key))
        for imsg in range(max(0, obj.total - max_read), obj.total):
            obj.history.append(cache.get(f"msg:{obj.id}:{imsg}"))

        print(f"loaded: {obj}")
        return obj

    async def chat(self, msg: str, cache: dict):
        """So this"""
        self.add_message(msg)

        async def get_chat():
            resp = []
            for i in range(random.randint(3, 5)):
                # simulate long network IO
                await asyncio.sleep(0.5)
                chunk = f"resp{i}:{random.randbytes(2).hex()};"

                resp.append(chunk)
                yield chunk

            self.add_message("".join(resp))

            # NOTE to make the message cache work properly, we have to manually save this:
            # await self.save(cache)

        return get_chat()


# use a simple dict to mimic an actual cache, e.g. Redis
cache = {}


async def get_cache():
    return cache


# I didn't figure out how to make Chat a dependable.
# I have read https://fastapi.tiangolo.com/advanced/advanced-dependencies/#parameterized-dependencies but still have no clue;
# the problem is that `sess_id` is passed by the user, not something we can fix in advance like the tutorial shows.
# As an alternative, I used this async context manager.
# Theoretically this would automatically save the Chat object after exiting the `async with` block
@asynccontextmanager
async def get_chat_from_cache(sess_id: str, cache: dict):
    """
    get object from cache (possibly create one), yield it, then save it back to cache
    """
    sess_key = f"sess:{sess_id}"
    if sess_key not in cache:
        obj = Chat()
        obj.id = sess_id
        await obj.save(cache)
    else:
        obj = await Chat.load(sess_id, cache)

    yield obj

    await obj.save(cache)


async def task(sess_id: str, task_id: int, resp_gen: AsyncIterator[str], cache: dict):
    """ """
    async for chunk in resp_gen:
        # do something with chunk, e.g. stream it to the client via a websocket
        await asyncio.sleep(0.5)
        cache[f"chunk:{sess_id}:{task_id}"] = chunk
        task_id += 1


@app.get("/{sess_id}/{task_id}/{prompt}")
async def get_chat(
    req: Request,
    sess_id: str,
    task_id: int,
    prompt: str,
    background_task: BackgroundTasks,
    cache: Annotated[dict, Depends(get_cache)],
):
    print(f"req incoming: {req.url}")
    async with get_chat_from_cache(sess_id=sess_id, cache=cache) as chat:
        resp_gen = await chat.chat(f"prompt:{prompt}", cache=cache)

        background_task.add_task(
            task, sess_id=sess_id, task_id=task_id, resp_gen=resp_gen, cache=cache
        )

    return "success"


@app.get("/{sess_id}")
async def get_sess(
    req: Request, sess_id: str, cache: Annotated[dict, Depends(get_cache)]
):
    print(f"req incoming: {req.url}")
    return (await Chat.load(sess_id=sess_id, cache=cache)).model_dump()

I found a close (but not identical) discussion about the lifespan of dependables. It seems the lifespan of a dependable could be relayed/extended into background tasks, though they consider this weird behavior. I did think of making get_chat_from_cache a yield-based dependable, though I didn't figure out how to do it correctly. In any case, this approach doesn't seem to be recommended by the FastAPI devs, because the actual timing of the teardown of dependables is undocumented behavior and might change in future versions.
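
For reference, the yield-based dependable I had in mind would look roughly like this (an untested sketch; the /dep prefix and the names chat_dependable / get_chat_dep are only placeholders to avoid clashing with the routes above). FastAPI matches sess_id against the endpoint's path parameter, so wiring it up is not the problem:

async def chat_dependable(
    sess_id: str, cache: Annotated[dict, Depends(get_cache)]
):
    # sess_id is resolved from the path parameter of the endpoint using this dependable
    async with get_chat_from_cache(sess_id=sess_id, cache=cache) as chat:
        yield chat


@app.get("/dep/{sess_id}/{task_id}/{prompt}")
async def get_chat_dep(
    sess_id: str,
    task_id: int,
    prompt: str,
    background_task: BackgroundTasks,
    cache: Annotated[dict, Depends(get_cache)],
    chat: Annotated[Chat, Depends(chat_dependable)],
):
    resp_gen = await chat.chat(f"prompt:{prompt}", cache=cache)
    background_task.add_task(
        task, sess_id=sess_id, task_id=task_id, resp_gen=resp_gen, cache=cache
    )
    return "success"

But since the teardown timing of dependables is undocumented, this still gives no guarantee that the save runs after the background task finishes.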

I know I could manually repeat the teardown step in the background task, but this feels like a hack. Are there more elegant ways to do this? If there is a better design pattern that avoids this issue completely, please let me know.
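
To be concrete, the hack I mean would be to pass the Chat object into the background task and repeat the save there once the response generator is exhausted; a sketch (the function name is made up):

async def task_with_manual_save(
    chat: Chat, task_id: int, resp_gen: AsyncIterator[str], cache: dict
):
    async for chunk in resp_gen:
        # do something with chunk, e.g. stream it to the client via a websocket
        await asyncio.sleep(0.5)
        cache[f"chunk:{chat.id}:{task_id}"] = chunk
        task_id += 1
    # repeat the teardown that the context manager already ran at the end of the
    # endpoint, so the messages added while streaming end up in the cache too
    await chat.save(cache)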

2 Answers


  1. Chosen as BEST ANSWER

    Thanks to Yurii for pointing out the behavior change in FastAPI 0.106.0. It seems impossible to extend the lifespan of the object into the background task (or at least it cannot be done in a clean way).

    After some research I realized the best practice is still to serialize and cache the object before the endpoint returns, and to start a new context in the background task. So there has to be a different context manager that doesn't load the object from the cache, but instead uses an existing object, and still saves it back to the cache.

    Here is my modification for reference: I added an async context manager method autosave and use it in the async generator. In this example it's equivalent to saving manually, but the pattern extends to more complicated scenarios.

        @asynccontextmanager
        async def autosave(self, cache: dict):
            # run the wrapped block, then persist the object to the cache
            yield
            await self.save(cache)
    
        async def chat(self, msg: str, cache: dict):
            self.add_message(msg)
    
            async def get_chat():
                async with self.autosave(cache=cache):
                    resp = []
                    for i in range(random.randint(3, 5)):
                        # simulate long network IO
                        await asyncio.sleep(0.1)
                        chunk = f"resp{i}:{random.randbytes(2).hex()};"
    
                        resp.append(chunk)
                        yield chunk
    
                    self.add_message("".join(resp))
    
    

  2. Background tasks are executed after your endpoint has finished execution.
    Thus, you cannot keep the context manager open until the background task is completed.

    Turning get_chat_from_cache into a dependency will not help you (it worked before FastAPI 0.106.0, but the behavior was changed and now you cannot use dependencies with yield in background tasks).

    You need to redesign your app with this in mind.
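
    One way to apply this redesign (a sketch, reusing get_chat_from_cache from the question; the /bg prefix and the names chat_task / get_chat_bg are made up): move the whole load/generate/save lifecycle into the background task, so the context manager is opened and closed entirely inside it:

        async def chat_task(sess_id: str, task_id: int, prompt: str, cache: dict):
            # the task owns the full lifecycle: load, generate, consume, save
            async with get_chat_from_cache(sess_id=sess_id, cache=cache) as chat:
                resp_gen = await chat.chat(f"prompt:{prompt}", cache=cache)
                async for chunk in resp_gen:
                    cache[f"chunk:{sess_id}:{task_id}"] = chunk
                    task_id += 1


        @app.get("/bg/{sess_id}/{task_id}/{prompt}")
        async def get_chat_bg(
            sess_id: str,
            task_id: int,
            prompt: str,
            background_task: BackgroundTasks,
            cache: Annotated[dict, Depends(get_cache)],
        ):
            background_task.add_task(
                chat_task, sess_id=sess_id, task_id=task_id, prompt=prompt, cache=cache
            )
            return "success"

    The endpoint responds immediately, and the save runs when the task's async with block exits, after the generator has been fully consumed.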
