FastAPI: reject a WebSocket connection with HTTP response

In a FastAPI-based web app, I have a WebSocket endpoint that should allow connections only if some conditions are met; otherwise it should return an HTTP 404 reply instead of upgrading the connection with HTTP 101.

As far as I understand, this is fully supported by the protocol, but I couldn't find any way to do it with FastAPI or Starlette.

If I have something like:

from fastapi import APIRouter, HTTPException, WebSocket

router = APIRouter()

@router.websocket("/foo")
async def ws_foo(request: WebSocket):
    if _user_is_allowed(request):
        await request.accept()
        _handle_ws_connection(request)
    else:
        raise HTTPException(status_code=404)

The exception isn't converted to a 404 response, as FastAPI's ExceptionMiddleware doesn't seem to handle such cases.

Is there any native / built-in way of supporting this kind of "reject" flow?



Solution 1:[1]

Once the handshake is completed, the protocol changes from HTTP to WebSocket. If you raise an HTTPException inside the websocket endpoint, it is not turned into an HTTP response, and if you return an HTTP response instead (e.g., return JSONResponse(..., status_code=404)), you get the internal error ASGI callable returned without sending handshake.
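At the ASGI level, the handshake is a short message exchange: the server delivers a websocket.connect event, and the application must answer with either websocket.accept (the server then completes the upgrade with HTTP 101) or websocket.close (per the ASGI specification, the server then rejects the handshake with HTTP 403). An application that returns without sending either message is what triggers the error above. The following minimal sketch is a raw ASGI application rather than FastAPI code; is_allowed() and its token=ok query-string rule are hypothetical placeholders, and fake_handshake() only imitates a server for demonstration.

```python
import asyncio


def is_allowed(scope: dict) -> bool:
    # Hypothetical access rule: allow only if the query string contains "token=ok"
    return b"token=ok" in scope.get("query_string", b"")


async def ws_app(scope, receive, send):
    """Minimal raw ASGI WebSocket app that decides during the handshake."""
    assert scope["type"] == "websocket"
    message = await receive()  # the first event is {"type": "websocket.connect"}
    assert message["type"] == "websocket.connect"
    if not is_allowed(scope):
        # Sent before "websocket.accept": the server rejects the handshake (HTTP 403)
        await send({"type": "websocket.close"})
        return
    # Accepting lets the server complete the upgrade (HTTP 101)
    await send({"type": "websocket.accept"})
    await send({"type": "websocket.send", "text": "hello"})
    await send({"type": "websocket.close", "code": 1000})


async def fake_handshake(query_string: bytes) -> list:
    """Drive ws_app the way a server would and record what it sends back."""
    sent = []

    async def receive():
        return {"type": "websocket.connect"}

    async def send(message):
        sent.append(message)

    scope = {"type": "websocket", "path": "/ws", "query_string": query_string}
    await ws_app(scope, receive, send)
    return sent


if __name__ == "__main__":
    print(asyncio.run(fake_handshake(b"")))          # rejected during the handshake
    print(asyncio.run(fake_handshake(b"token=ok")))  # accepted, then closed normally
```

Starlette's WebSocket.accept() and WebSocket.close(), which FastAPI uses, send these same websocket.accept and websocket.close message types.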

Option 1

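Thus, if you would like to have some kind of checking mechanism before the protocol is upgraded, you would need to use a middleware, as shown below. Inside the middleware you can't raise an exception, but you can return a response (e.g., Response, JSONResponse, PlainTextResponse), which is actually how FastAPI handles exceptions behind the scenes. As a reference, please have a look at this question, as well as the discussion here. Note, however, that @app.middleware("http") is built on Starlette's BaseHTTPMiddleware, which only processes scopes of type "http" and passes WebSocket connections (scope type "websocket") through untouched. The examples below therefore guard ordinary HTTP routes; intercepting the WebSocket handshake itself requires a pure ASGI middleware.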

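from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()

async def is_user_allowed(request: Request) -> bool:
    # Inspect the request; return False if the conditions are not met
    print(request.headers)
    print(request.client)
    return False

# Runs for HTTP requests only (WebSocket scopes bypass "http" middleware)
@app.middleware("http")
async def check_user_allowed(request: Request, call_next):
    if not await is_user_allowed(request):
        # Short-circuit: return the rejection without calling the endpoint
        return JSONResponse(content={"message": "User not allowed"}, status_code=404)
    response = await call_next(request)
    return response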

Alternatively, you can have the is_user_allowed() function raise a custom exception, which you then catch with a try-except block inside the middleware:

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()

class UserException(Exception):
    """Raised when a request does not meet the access conditions."""
    def __init__(self, message: str):
        self.message = message
        super().__init__(message)

async def is_user_allowed(request: Request) -> None:
    # If the conditions are not met, raise UserException
    raise UserException(message="User not allowed.")

@app.middleware("http")
async def check_user_allowed(request: Request, call_next):
    try:
        await is_user_allowed(request)
    except UserException as e:
        # Convert the exception into an HTTP response ourselves
        return JSONResponse(content={"message": e.message}, status_code=404)
    response = await call_next(request)
    return response
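Because @app.middleware("http") never sees WebSocket handshakes, a check that must run before the upgrade can instead be written as a pure ASGI middleware, which receives every scope type. The sketch below is hypothetical: the class name RejectWebSocketMiddleware and the x-token header rule in is_allowed() are placeholders. When the server advertises the ASGI WebSocket Denial Response extension ("websocket.http.response" in scope["extensions"]), it answers the handshake with a real HTTP 404; otherwise it falls back to websocket.close, which the server turns into an HTTP 403. Whether the extension is available depends on the server and its version, so treat the 404 path as an assumption to verify against your deployment.

```python
import asyncio
import json


def is_allowed(scope: dict) -> bool:
    # Hypothetical rule: require an "x-token: ok" request header
    return any(
        name == b"x-token" and value == b"ok"
        for name, value in scope.get("headers", [])
    )


class RejectWebSocketMiddleware:
    """Pure ASGI middleware that vets WebSocket handshakes before the upgrade."""

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        # Non-WebSocket scopes and allowed handshakes go to the wrapped app untouched
        if scope["type"] != "websocket" or is_allowed(scope):
            await self.app(scope, receive, send)
            return
        await receive()  # consume the initial {"type": "websocket.connect"} event
        extensions = scope.get("extensions") or {}
        if "websocket.http.response" in extensions:
            # Denial Response extension: answer the handshake with HTTP 404
            body = json.dumps({"message": "User not allowed"}).encode()
            await send({
                "type": "websocket.http.response.start",
                "status": 404,
                "headers": [
                    (b"content-type", b"application/json"),
                    (b"content-length", str(len(body)).encode()),
                ],
            })
            await send({"type": "websocket.http.response.body", "body": body})
        else:
            # No extension support: closing before accept makes the server send HTTP 403
            await send({"type": "websocket.close"})


async def _stand_in_app(scope, receive, send):
    # Stand-in for the real application (e.g. the FastAPI app) in this demo
    await send({"type": "stand-in-app-called"})


async def simulate(scope: dict) -> list:
    """Run the middleware against a fake server and return the messages it sends."""
    sent = []

    async def receive():
        return {"type": "websocket.connect"}

    async def send(message):
        sent.append(message)

    await RejectWebSocketMiddleware(_stand_in_app)(scope, receive, send)
    return sent


if __name__ == "__main__":
    denied = {"type": "websocket", "headers": [],
              "extensions": {"websocket.http.response": {}}}
    print(asyncio.run(simulate(denied)))
```

With FastAPI, a class like this would typically be registered via app.add_middleware(RejectWebSocketMiddleware). Since it passes HTTP requests through unchanged, it can sit alongside an "http" middleware that guards the ordinary routes.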

Option 2

If, however, you need to do the check using the websocket instance, you can reuse the raising variant of the logic above, pass the websocket instance to is_user_allowed() instead, and catch the exception inside the websocket endpoint (inspired by this).

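from fastapi import WebSocket

# Assumes app, UserException and is_user_allowed() from the snippets above,
# plus a handle_conn() coroutine of your own that serves the accepted connection
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    await ws.accept()  # complete the handshake (HTTP 101) before checking
    try:
        await is_user_allowed(ws)
        await handle_conn(ws)
    except UserException as e:
        await ws.send_text(e.message)  # optionally tell the client why before closing
        await ws.close()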

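In the above, the connection is accepted first and then closed if the exception is raised. The check can also run before accept(), as shown below, but the except block must still answer the handshake: a bare return there would raise the internal server error ASGI callable returned without sending handshake, as described earlier. Calling close() before accept() is allowed, and per the ASGI specification the server then rejects the handshake with an HTTP 403 (Forbidden) response instead of upgrading the connection.

@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    try:
        await is_user_allowed(ws)
    except UserException:
        # Not accepted yet: closing now makes the server reject the handshake (HTTP 403).
        # A bare `return` here would trigger "ASGI callable returned without sending handshake".
        await ws.close()
        return
    await ws.accept()
    await handle_conn(ws)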

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1