MemGPT/letta/server/ws_api/server.py
2024-12-26 19:43:11 -08:00

141 lines
5.7 KiB
Python

import asyncio
import json
import signal
import sys
import traceback

import websockets

import letta.server.ws_api.protocol as protocol
from letta.server.constants import WS_DEFAULT_PORT
from letta.server.server import SyncServer
from letta.server.ws_api.interface import SyncWebSocketInterface
class WebSocketServer:
    """Expose a ``SyncServer`` to clients over WebSocket connections.

    Every connected client is registered with one shared
    ``SyncWebSocketInterface``, which is also handed to the ``SyncServer`` as
    its default interface. Each incoming frame is expected to be a JSON object
    whose ``"type"`` field selects the handler (``"command"`` or
    ``"user_message"``).
    """

    def __init__(self, host="localhost", port=WS_DEFAULT_PORT):
        """Create the shared interface and the backing ``SyncServer``.

        Args:
            host: Address the listening socket is bound to.
            port: TCP port to listen on.
        """
        self.host = host
        self.port = port
        self.interface = SyncWebSocketInterface()
        self.server = SyncServer(default_interface=self.interface)

    def shutdown_server(self):
        """Close the WebSocket interface, printing (never raising) any failure."""
        try:
            self.interface.close()
            print("Closed the WS interface")
        except Exception as e:
            # Best-effort: this runs on shutdown paths, so report the failure
            # instead of letting it replace the original exit reason.
            print(f"Closing the WS interface failed with: {e}")

    def initialize_server(self):
        """Print a startup banner that names the listen address."""
        print("Server is initializing...")
        print(f"Listening on {self.host}:{self.port}...")

    async def start_server(self):
        """Accept client connections on ``host:port`` until cancelled."""
        self.initialize_server()
        # Can play with ping_interval and ping_timeout
        # See: https://websockets.readthedocs.io/en/stable/topics/timeouts.html
        # and https://github.com/letta-ai/letta/issues/471
        async with websockets.serve(self.handle_client, self.host, self.port):
            await asyncio.Future()  # Never resolves: serve until cancelled.

    def run(self):
        """Return the ``start_server`` coroutine without awaiting it."""
        return self.start_server()

    async def handle_client(self, websocket, path=None):
        """Serve one client connection until it is closed.

        The client gets an error reply, and the connection stays open, for any
        frame that is not valid JSON, is not a JSON object, or has no
        ``"type"`` field.

        Args:
            websocket: Connection object supplied by ``websockets.serve``.
            path: Request path. It is unused, and optional because newer
                ``websockets`` releases call the handler with the connection only.
        """
        self.interface.register_client(websocket)
        try:
            # recv() raises ConnectionClosed once the client goes away; that is
            # what ends this loop (handled below).
            while True:
                message = await websocket.recv()
                try:
                    data = json.loads(message)
                except (ValueError, TypeError):
                    # Report the raw frame: ``data`` is unbound (or left over
                    # from a previous frame) at this point.
                    print(f"[server] bad data from client:\n{message}")
                    await websocket.send(protocol.server_command_response(f"Error: bad data from client - {message}"))
                    continue

                if not isinstance(data, dict) or "type" not in data:
                    print(f"[server] bad data from client (JSON but no type):\n{data}")
                    await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}"))
                elif data["type"] == "command":
                    await self._handle_command(websocket, data)
                elif data["type"] == "user_message":
                    await self._handle_user_message(websocket, data)
                # ... handle other message types as needed ...
                else:
                    print(f"[server] unrecognized client package data type: {data}")
                    await websocket.send(protocol.server_error(f"unrecognized client package data type: {data}"))
        except websockets.exceptions.ConnectionClosed:
            print("[server] connection with client was closed")
        finally:
            self.interface.unregister_client(websocket)

    async def _handle_command(self, websocket, data):
        """Handle a ``"command"`` frame; only ``create_agent`` is recognized."""
        if data.get("command") != "create_agent":
            print(f"[server] unrecognized client command type: {data}")
            await websocket.send(protocol.server_error(f"unrecognized client command type: {data}"))
            return
        try:
            # A missing "config" key raises KeyError here and is reported below.
            self.server.create_agent(user_id="NULL", agent_config=data["config"])
        except Exception as e:
            # Nothing in this file reads self.agent; the assignment is kept as-is.
            self.agent = None
            print(f"[server] self.create_new_agent failed with:\n{e}")
            print(f"{traceback.format_exc()}")
            await websocket.send(protocol.server_command_response(f"Error: Failed to init agent - {str(e)}"))
        else:
            await websocket.send(protocol.server_command_response("OK: Agent initialized"))

    async def _handle_user_message(self, websocket, data):
        """Forward a ``"user_message"`` frame to ``SyncServer.user_message``.

        The method first validates ``agent_id`` and ``message``. It then sends
        an agent-response start frame, makes the call, and after a 1 s pause
        sends an end frame. If the call fails, the failure is sent as an
        agent-response error between the start and end frames.
        """
        if data.get("agent_id") is None:
            await websocket.send(protocol.server_agent_response_error("agent_id was not specified in the request"))
            return
        if "message" not in data:
            await websocket.send(protocol.server_agent_response_error("message was not specified in the request"))
            return

        await websocket.send(protocol.server_agent_response_start())
        try:
            self.server.user_message(user_id="NULL", agent_id=data["agent_id"], message=data["message"])
        except Exception as e:
            print(f"[server] self.server.user_message failed with:\n{e}")
            print(f"{traceback.format_exc()}")
            await websocket.send(protocol.server_agent_response_error(f"server.user_message failed with: {e}"))
        await asyncio.sleep(1)  # pause before sending the terminating message, w/o this messages may be missed
        await websocket.send(protocol.server_agent_response_end())
def start_server():
    """Run a ``WebSocketServer`` in the foreground until it is stopped.

    The listen port is taken from ``sys.argv[1]`` when that is a valid
    integer; otherwise ``WS_DEFAULT_PORT`` is used. SIGTERM and Ctrl-C both
    stop the server, and the interface is closed once, in the ``finally``
    clause.
    """
    # Check if a port argument is provided
    port = WS_DEFAULT_PORT
    if len(sys.argv) > 1:
        try:
            port = int(sys.argv[1])
        except ValueError:
            print(f"Invalid port number. Using default port {port}.")

    server = WebSocketServer(port=port)

    def handle_sigterm(signum, frame):
        # Only raise SystemExit here. It unwinds out of asyncio.run() into the
        # ``finally`` below, which performs the cleanup. Also cleaning up here
        # would close the interface twice.
        print("SIGTERM received, shutting down...")
        sys.exit(0)

    try:
        # Registered inside the try so any SIGTERM reaches the finally clause.
        signal.signal(signal.SIGTERM, handle_sigterm)
        asyncio.run(server.run())
    except KeyboardInterrupt:
        print("Shutting down the server...")
    finally:
        server.shutdown_server()
        print("Server has been shut down.")
if __name__ == "__main__":
    # Script entry point: parse the optional port argument and serve forever.
    start_server()