# Mirror of https://github.com/cpacker/MemGPT.git
# Synced 2025-06-03 04:30:22 +00:00 (141 lines, 5.7 KiB, Python)
import asyncio
import json
import signal
import sys
import traceback

import websockets

import letta.server.ws_api.protocol as protocol
from letta.server.constants import WS_DEFAULT_PORT
from letta.server.server import SyncServer
from letta.server.ws_api.interface import SyncWebSocketInterface


class WebSocketServer:
    """Websocket front-end that forwards client requests to a ``SyncServer``.

    Clients send JSON objects carrying a ``"type"`` field. Two types are handled:

    * ``"command"`` with ``"command": "create_agent"`` and a ``"config"`` payload,
      which calls ``SyncServer.create_agent``;
    * ``"user_message"`` with ``"agent_id"`` and ``"message"``, which calls
      ``SyncServer.user_message`` and brackets the step with start/end frames.

    Every connected client is registered with the single shared
    ``SyncWebSocketInterface`` for the lifetime of its connection.
    """

    def __init__(self, host="localhost", port=WS_DEFAULT_PORT):
        """Create the websocket interface and the backing ``SyncServer``.

        Args:
            host: Host/interface the websocket server binds to.
            port: TCP port the websocket server listens on.
        """
        self.host = host
        self.port = port
        self.interface = SyncWebSocketInterface()
        # NOTE(review): presumably SyncServer routes agent output through this
        # interface (passed as its default_interface) — confirm in SyncServer.
        self.server = SyncServer(default_interface=self.interface)

    def shutdown_server(self):
        """Close the websocket interface, printing (never raising) any failure."""
        try:
            self.interface.close()
            print("Closed the WS interface")
        except Exception as e:
            # Best-effort cleanup: this runs on the shutdown path, so it must not raise.
            print(f"Closing the WS interface failed with: {e}")

    def initialize_server(self):
        """Print the startup banner with the listening address."""
        print("Server is initializing...")
        print(f"Listening on {self.host}:{self.port}...")

    async def start_server(self):
        """Serve clients on ``self.host:self.port`` until this coroutine is cancelled."""
        self.initialize_server()
        # Can play with ping_interval and ping_timeout
        # See: https://websockets.readthedocs.io/en/stable/topics/timeouts.html
        # and https://github.com/letta-ai/letta/issues/471
        async with websockets.serve(self.handle_client, self.host, self.port):
            await asyncio.Future()  # Never completes: run until cancelled.

    def run(self):
        """Return the (not yet awaited) ``start_server`` coroutine, e.g. for ``asyncio.run``."""
        return self.start_server()

    async def handle_client(self, websocket, path=None):
        """Serve a single client connection until it is closed.

        Args:
            websocket: The client connection supplied by ``websockets.serve``.
            path: Request path. Legacy ``websockets`` handlers receive it as a
                second argument, while newer releases pass only the connection,
                so it defaults to ``None``. It is not used.
        """
        self.interface.register_client(websocket)
        try:
            while True:
                message = await websocket.recv()

                data = await self._parse_message(websocket, message)
                if data is None:
                    continue  # Rejected; an error reply was already sent.

                msg_type = data["type"]
                if msg_type == "command":
                    await self._handle_command(websocket, data)
                elif msg_type == "user_message":
                    await self._handle_user_message(websocket, data)
                # ... handle other message types as needed ...
                else:
                    print(f"[server] unrecognized client package data type: {data}")
                    await websocket.send(protocol.server_error(f"unrecognized client package data type: {data}"))
        except websockets.exceptions.ConnectionClosed:
            print("[server] connection with client was closed")
        finally:
            self.interface.unregister_client(websocket)

    async def _parse_message(self, websocket, message):
        """Decode one raw frame into a dict that has a ``"type"`` key.

        Returns:
            The decoded dict, or ``None`` when the frame was rejected. In that
            case an error response has already been sent to the client.
        """
        try:
            data = json.loads(message)
        except (ValueError, TypeError):  # JSONDecodeError subclasses ValueError
            # Nothing was decoded, so report the raw frame itself.
            print(f"[server] bad data from client:\n{message}")
            await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(message)}"))
            return None

        # Valid JSON that is not an object (list, string, number) cannot be indexed by key.
        if not isinstance(data, dict) or "type" not in data:
            print(f"[server] bad data from client (JSON but no type):\n{data}")
            await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}"))
            return None
        return data

    async def _handle_command(self, websocket, data):
        """Execute a ``"command"`` request; only ``create_agent`` is supported."""
        if data.get("command") != "create_agent":
            print(f"[server] unrecognized client command type: {data}")
            await websocket.send(protocol.server_error(f"unrecognized client command type: {data}"))
            return

        try:
            # A missing "config" key raises KeyError and is reported like any other failure.
            self.server.create_agent(user_id="NULL", agent_config=data["config"])
        except Exception as e:
            print(f"[server] self.server.create_agent failed with:\n{e}")
            print(f"{traceback.format_exc()}")
            await websocket.send(protocol.server_command_response(f"Error: Failed to init agent - {str(e)}"))
        else:
            # Sent outside the try so a closed socket is not misreported as an init failure.
            await websocket.send(protocol.server_command_response("OK: Agent initialized"))

    async def _handle_user_message(self, websocket, data):
        """Run one agent step for a ``"user_message"`` request, framed by start/end responses."""
        agent_id = data.get("agent_id")
        if agent_id is None:
            await websocket.send(protocol.server_agent_response_error("agent_id was not specified in the request"))
            return

        user_message = data.get("message")
        if user_message is None:
            await websocket.send(protocol.server_agent_response_error("message was not specified in the request"))
            return

        await websocket.send(protocol.server_agent_response_start())
        try:
            # NOTE(review): this call is synchronous and blocks the event loop for the whole step.
            self.server.user_message(user_id="NULL", agent_id=agent_id, message=user_message)
        except Exception as e:
            print(f"[server] self.server.user_message failed with:\n{e}")
            print(f"{traceback.format_exc()}")
            await websocket.send(protocol.server_agent_response_error(f"server.user_message failed with: {e}"))
        await asyncio.sleep(1)  # pause before sending the terminating message, w/o this messages may be missed
        await websocket.send(protocol.server_agent_response_end())


def start_server():
    """Run a ``WebSocketServer`` until interrupted, then shut it down exactly once.

    The listening port may be given as the first command-line argument.
    ``WS_DEFAULT_PORT`` is used when the argument is absent, is not an integer,
    or lies outside the TCP port range 0-65535.
    """
    port = WS_DEFAULT_PORT
    if len(sys.argv) > 1:
        try:
            requested = int(sys.argv[1])
        except ValueError:
            requested = None
        if requested is not None and 0 <= requested <= 65535:
            port = requested
        else:
            print(f"Invalid port number. Using default port {port}.")

    server = WebSocketServer(port=port)

    def handle_sigterm(*args):
        # Only raise SystemExit here. It unwinds asyncio.run() (cancelling the
        # serving task), and the ``finally`` below then performs the one and
        # only shutdown. Shutting down here as well closed the interface twice.
        print("SIGTERM received, shutting down...")
        print("Shutting down the server...")
        sys.exit(0)

    signal.signal(signal.SIGTERM, handle_sigterm)

    try:
        asyncio.run(server.run())
    except KeyboardInterrupt:
        print("Shutting down the server...")
    finally:
        server.shutdown_server()
        print("Server has been shut down.")


# Script entry point: start serving only when executed directly, not on import.
if __name__ == "__main__":
    start_server()