fix: patch core memory edit bug/regression (#1695)

Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
Charles Packer 2024-08-29 13:12:24 -07:00 committed by GitHub
parent 619f990d48
commit ffa21cc623
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 20 additions and 12 deletions

View File

@ -61,7 +61,7 @@ class Message(BaseMessage):
id: str = BaseMessage.generate_id_field()
role: MessageRole = Field(..., description="The role of the participant.")
text: str = Field(..., description="The text of the message.")
text: Optional[str] = Field(None, description="The text of the message.")
user_id: str = Field(None, description="The unique identifier of the user.")
agent_id: str = Field(None, description="The unique identifier of the agent.")
model: Optional[str] = Field(None, description="The model used to make the function call.")

View File

@ -6,8 +6,6 @@ import traceback
import warnings
from abc import abstractmethod
from datetime import datetime
from functools import wraps
from threading import Lock
from typing import Callable, List, Optional, Tuple, Union
from fastapi import HTTPException
@ -1584,10 +1582,11 @@ class SyncServer(Server):
def create_tool(self, request: ToolCreate, user_id: Optional[str] = None, update: bool = True) -> Tool: # TODO: add other fields
"""Create a new tool"""
if request.tags and "memory" in request.tags:
# special modifications to memory functions
# self.memory -> self.memory.memory, since Agent.memory.memory needs to be modified (not BaseMemory.memory)
request.source_code = request.source_code.replace("self.memory", "self.memory.memory")
# NOTE: deprecated code that existed when we were trying to pretend that `self` was the memory object
# if request.tags and "memory" in request.tags:
# # special modifications to memory functions
# # self.memory -> self.memory.memory, since Agent.memory.memory needs to be modified (not BaseMemory.memory)
# request.source_code = request.source_code.replace("self.memory", "self.memory.memory")
if not request.json_schema:
# auto-generate openai schema
@ -1605,9 +1604,7 @@ class SyncServer(Server):
# TODO: not sure if this always works
func = env[functions[-1]]
print("FUNCTION", func)
json_schema = generate_schema(func, request.name)
print(json_schema)
else:
# provided by client
json_schema = request.json_schema

View File

@ -8,7 +8,7 @@ from dotenv import load_dotenv
from memgpt import Admin, create_client
from memgpt.constants import DEFAULT_PRESET
from memgpt.schemas.message import Message
from memgpt.schemas.memgpt_message import InternalMonologue
from memgpt.schemas.usage import MemGPTUsageStatistics
# from tests.utils import create_config
@ -127,7 +127,7 @@ def test_agent_interactions(client, agent):
assert response.usage.step_count == 1
assert response.usage.total_tokens > 0
assert response.usage.completion_tokens > 0
assert isinstance(response.messages[0], Message)
assert isinstance(response.messages[0], InternalMonologue)
print(response.messages)
# TODO: add streaming tests
@ -167,6 +167,14 @@ def test_archival_memory(client, agent):
client.get_archival_memory(agent.id)
def test_core_memory(client, agent):
response = client.send_message(agent_id=agent.id, message="Update your core memory to remember that my name is Timber!", role="user")
print("Response", response)
memory = client.get_in_context_memory(agent_id=agent.id)
assert "Timber" in memory.get_block("human").value, f"Updating core memory failed: {memory.get_block('human').value}"
def test_messages(client, agent):
# _reset_config()

View File

@ -1,6 +1,9 @@
from typing import Union
import pytest
from memgpt import create_client
from memgpt.client.client import LocalClient, RESTClient
from memgpt.schemas.block import Block
from memgpt.schemas.memory import BlockChatMemory, ChatMemory, Memory
@ -19,7 +22,7 @@ def agent(client):
assert client.get_agent(agent_state.id) is None, f"Failed to properly delete agent {agent_state.id}"
def test_agent(client):
def test_agent(client: Union[LocalClient, RESTClient]):
tools = client.list_tools()