You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

84 lines
3.4 KiB

import asyncio
import time
import uuid
from typing import Any, Dict, List, Optional

from fastapi.websockets import WebSocket
from starlette.websockets import WebSocketState
from websockets.exceptions import ConnectionClosedOK, ConnectionClosedError
class ConnectionManager:
    """Track active WebSocket connections and broadcast messages to them.

    Messages sent with :meth:`send_message` carry a ``message_id`` and stay in
    ``unacknowledged_messages`` until :meth:`handle_acknowledgment` (or
    :meth:`handle_message`) removes them. The background task started by
    :meth:`start_message_check` re-sends any message whose last send is older
    than ``ack_timeout`` seconds.
    """

    def __init__(self, ack_timeout: float = 5.0, check_interval: float = 3.0):
        """Create a manager with no connections and no pending messages.

        Args:
            ack_timeout: Seconds to wait for an acknowledgment before a
                pending message is re-sent.
            check_interval: Seconds to sleep between scans of the pending
                messages.
        """
        self.active_connections: List[WebSocket] = []
        # message_id -> {"message": payload dict, "timestamp": time of last
        # send, "websocket": first connection the message was delivered to}
        self.unacknowledged_messages: Dict[str, Dict[str, Any]] = {}
        self.ack_timeout = ack_timeout
        self.check_interval = check_interval
        # Background retry task; None until start_message_check() runs.
        self.message_check_task: Optional[asyncio.Task] = None

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and start tracking the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*; a no-op if it is not tracked."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def broadcast_stream(self, data: bytes):
        """Send raw *data* to every active connection, dropping closed ones."""
        # Iterate a snapshot: disconnect() removes from active_connections,
        # and mutating a list while iterating it skips elements.
        for connection in list(self.active_connections):
            try:
                await connection.send_bytes(data)
            except (ConnectionClosedError, ConnectionClosedOK):
                self.disconnect(connection)

    async def broadcast_json(self, event: str, data: str):
        """Send ``{"event": event, "data": data}`` as JSON to every connection.

        Connections that raise a websockets close error are dropped.
        """
        # Snapshot for the same reason as in broadcast_stream.
        for connection in list(self.active_connections):
            try:
                await connection.send_json({"event": event, "data": data})
            except (ConnectionClosedError, ConnectionClosedOK):
                self.disconnect(connection)

    async def send_message(self, event: str, data: str, message_id: Optional[str] = None):
        """Broadcast an event that clients are expected to acknowledge.

        The payload ``{"event", "data", "message_id"}`` is sent to every
        connection whose ``client_state`` is ``CONNECTED``. A new message is
        recorded in ``unacknowledged_messages`` so the retry task can re-send
        it. A call with an already-pending ``message_id`` (a retry) only
        re-sends and never re-registers the message.

        Args:
            event: Event name placed in the payload.
            data: Event data placed in the payload.
            message_id: Identifier clients echo back to acknowledge; a random
                UUID4 string is generated when omitted.

        Returns:
            The ``message_id`` used for this message.
        """
        if message_id is None:
            message_id = str(uuid.uuid4())
        payload = {"event": event, "data": data, "message_id": message_id}

        # Register a new message *before* sending: a fast client can ack it
        # while sends to other clients are still being awaited, and
        # registering afterwards would resurrect an acknowledged message.
        entry: Optional[Dict[str, Any]] = None
        if message_id not in self.unacknowledged_messages:
            entry = {"message": payload, "timestamp": time.time(), "websocket": None}
            self.unacknowledged_messages[message_id] = entry

        delivered = False
        # Snapshot: disconnect() mutates active_connections during the loop.
        for connection in list(self.active_connections):
            if connection.client_state != WebSocketState.CONNECTED:
                continue
            try:
                await connection.send_json(payload)
            except (ConnectionClosedError, ConnectionClosedOK):
                # One dead client must not stop delivery to the others.
                self.disconnect(connection)
                continue
            if not delivered:
                delivered = True
                if entry is not None:
                    entry["websocket"] = connection

        # A newly registered message that reached nobody cannot be acked;
        # drop it unless it was already removed or replaced meanwhile.
        if entry is not None and not delivered:
            if self.unacknowledged_messages.get(message_id) is entry:
                del self.unacknowledged_messages[message_id]
        return message_id

    async def handle_acknowledgment(self, message_id: str):
        """Mark *message_id* as acknowledged; unknown ids are ignored."""
        self.unacknowledged_messages.pop(message_id, None)

    async def check_unacknowledged_messages(self):
        """Re-send unacknowledged messages forever (intended to run as a task).

        Every ``check_interval`` seconds, each pending message whose last send
        is more than ``ack_timeout`` seconds old is re-broadcast. Its
        timestamp is reset so that every retry gets a fresh timeout window.

        NOTE(review): there is no retry limit, so a message that is never
        acknowledged is re-sent indefinitely.
        """
        while True:
            now = time.time()
            for message_id, message_info in list(self.unacknowledged_messages.items()):
                if now - message_info["timestamp"] <= self.ack_timeout:
                    continue
                # It may have been acknowledged while an earlier resend awaited.
                if message_id not in self.unacknowledged_messages:
                    continue
                message_info["timestamp"] = time.time()
                message = message_info["message"]
                await self.send_message(
                    message["event"],
                    message["data"],
                    message["message_id"],
                )
            await asyncio.sleep(self.check_interval)

    async def handle_message(self, message: str):
        """Treat an incoming client text *message* as an acknowledged message_id."""
        message_id = message
        if message_id is not None:
            await self.handle_acknowledgment(message_id)

    async def start_message_check(self):
        """Start the background retry task unless one is already running."""
        task = self.message_check_task
        if task is not None and not task.done():
            return
        self.message_check_task = asyncio.create_task(self.check_unacknowledged_messages())

    async def stop_message_check(self):
        """Cancel the background retry task (if any) and wait for it to finish."""
        task = self.message_check_task
        if task is None:
            return
        self.message_check_task = None
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
# Module-level shared ConnectionManager instance; presumably imported by the
# WebSocket route handlers so they share one connection registry — TODO confirm.
manager: ConnectionManager = ConnectionManager()