Mirror of https://github.com/cpacker/MemGPT.git
Synced 2025-06-03 04:30:22 +00:00
Refactor out the testing functions so we can use this for benchmarking
This commit is contained in:
parent
fade94b66b
commit
80b3946eba
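
The stated goal is to reuse the endpoint checks outside of pytest. Since each extracted check_* helper takes only a config filename and provisions its own client and agent, a benchmark driver can loop over model configs directly. A minimal sketch under that assumption follows; the driver itself (the loop, timing, and error handling) is illustrative and not part of this commit:

# benchmark_sketch.py -- hypothetical driver built on the helpers added below.
import os
import time

from tests.helpers.endpoints_helper import (
    check_first_response_is_valid_for_llm_endpoint,
    check_response_contains_keyword,
)

llm_config_dir = "configs/llm_model_configs"

for config_file in sorted(os.listdir(llm_config_dir)):
    filename = os.path.join(llm_config_dir, config_file)
    start = time.perf_counter()
    try:
        # Each helper raises (via the assert_* validators) on failure.
        check_first_response_is_valid_for_llm_endpoint(filename)
        check_response_contains_keyword(filename)
        outcome = "ok"
    except Exception as e:  # letta raises custom errors, not only AssertionError
        outcome = f"failed: {e}"
    elapsed = time.perf_counter() - start
    print(f"{config_file}: {outcome} ({elapsed:.1f}s)")
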
tests/helpers/endpoints_helper.py
@@ -2,9 +2,11 @@ import json
 import uuid
 from typing import Callable, List, Optional, Union

-from letta import LocalClient, RESTClient
+from letta import LocalClient, RESTClient, create_client
 from letta.agent import Agent
 from letta.config import LettaConfig
 from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
 from letta.embeddings import embedding_model
 from letta.errors import (
+    InvalidFunctionCallError,
+    InvalidInnerMonologueError,
@@ -12,7 +14,7 @@ from letta.errors import (
     MissingFunctionCallError,
     MissingInnerMonologueError,
 )
-from letta.llm_api.llm_api_tools import unpack_inner_thoughts_from_kwargs
+from letta.llm_api.llm_api_tools import create, unpack_inner_thoughts_from_kwargs
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG
 from letta.schemas.agent import AgentState
 from letta.schemas.embedding_config import EmbeddingConfig
@@ -26,11 +28,16 @@ from letta.schemas.llm_config import LLMConfig
 from letta.schemas.memory import ChatMemory
 from letta.schemas.openai.chat_completion_response import Choice, FunctionCall, Message
 from letta.utils import get_human_text, get_persona_text
+from tests.helpers.utils import cleanup

 # Generate uuid for agent name for this example
 namespace = uuid.NAMESPACE_DNS
 agent_uuid = str(uuid.uuid5(namespace, "test-endpoints-agent"))
+
+# defaults (letta hosted)
+embedding_config_path = "configs/embedding_model_configs/letta-hosted.json"
+llm_config_path = "configs/llm_model_configs/letta-hosted.json"


 # ======================================================================================================================
 # Section: Test Setup
@@ -41,7 +48,6 @@ agent_uuid = str(uuid.uuid5(namespace, "test-endpoints-agent"))
 def setup_agent(
     client: Union[LocalClient, RESTClient],
     filename: str,
     embedding_config_path: str,
     memory_human_str: str = get_human_text(DEFAULT_HUMAN),
     memory_persona_str: str = get_persona_text(DEFAULT_PERSONA),
     tools: Optional[List[str]] = None,
@@ -62,6 +68,181 @@ def setup_agent(
     return agent_state

+def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts_in_kwargs: bool = False):
+    """
+    Checks that the first response is valid:
+
+    1. Contains either send_message or archival_memory_search
+    2. Contains valid usage of the function
+    3. Contains inner monologue
+
+    Note: This is acting on the raw LLM response; note the usage of `create`
+    """
+    client = create_client()
+    cleanup(client=client, agent_uuid=agent_uuid)
+    agent_state = setup_agent(client, filename, embedding_config_path)
+
+    tools = [client.get_tool(client.get_tool_id(name=name)) for name in agent_state.tools]
+    agent = Agent(
+        interface=None,
+        tools=tools,
+        agent_state=agent_state,
+    )
+
+    response = create(
+        llm_config=agent_state.llm_config,
+        user_id=str(uuid.UUID(int=1)),  # dummy user_id
+        messages=agent._messages,
+        functions=agent.functions,
+        functions_python=agent.functions_python,
+    )
+
+    # Basic check
+    assert response is not None
+
+    # Select first choice
+    choice = response.choices[0]
+
+    # Ensure that the first message returns a "send_message" or "archival_memory_search"
+    validator_func = lambda function_call: function_call.name == "send_message" or function_call.name == "archival_memory_search"
+    assert_contains_valid_function_call(choice.message, validator_func)
+
+    # Assert that the message has an inner monologue
+    assert_contains_correct_inner_monologue(choice, inner_thoughts_in_kwargs)
+
+
+def check_response_contains_keyword(filename: str):
+    """
+    Checks that the prompted response from the LLM contains a chosen keyword
+
+    Note: This is acting on the Letta response; note the usage of `user_message`
+    """
+    client = create_client()
+    cleanup(client=client, agent_uuid=agent_uuid)
+    agent_state = setup_agent(client, filename, embedding_config_path)
+
+    keyword = "banana"
+    keyword_message = f'This is a test to see if you can see my message. If you can see my message, please respond by calling send_message using a message that includes the word "{keyword}"'
+    response = client.user_message(agent_id=agent_state.id, message=keyword_message)
+
+    # Basic checks
+    assert_sanity_checks(response)
+
+    # Make sure the message was sent
+    assert_invoked_send_message_with_keyword(response.messages, keyword)
+
+    # Make sure some inner monologue is present
+    assert_inner_monologue_is_present_and_valid(response.messages)
+
+
+def check_agent_uses_external_tool(filename: str):
+    """
+    Checks that the LLM will use external tools if instructed
+
+    Note: This is acting on the Letta response; note the usage of `user_message`
+    """
+    from crewai_tools import ScrapeWebsiteTool
+
+    from letta.schemas.tool import Tool
+
+    crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com")
+    tool = Tool.from_crewai(crewai_tool)
+    tool_name = tool.name
+
+    # Set up client
+    client = create_client()
+    cleanup(client=client, agent_uuid=agent_uuid)
+    client.add_tool(tool)
+
+    # Set up persona for tool usage
+    persona = f"""
+
+    My name is Letta.
+
+    I am a personal assistant who answers a user's questions about a website `example.com`. When a user asks me a question about `example.com`, I will use a tool called {tool_name} which will search `example.com` and answer the relevant question.
+
+    Don’t forget - inner monologue / inner thoughts should always be different than the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts.
+    """
+
+    agent_state = setup_agent(client, filename, embedding_config_path, memory_persona_str=persona, tools=[tool_name])
+
+    response = client.user_message(agent_id=agent_state.id, message="What's on the example.com website?")
+
+    # Basic checks
+    assert_sanity_checks(response)
+
+    # Make sure the tool was called
+    assert_invoked_function_call(response.messages, tool_name)
+
+    # Make sure some inner monologue is present
+    assert_inner_monologue_is_present_and_valid(response.messages)
+
+
+def check_agent_recall_chat_memory(filename: str):
+    """
+    Checks that the LLM will recall the chat memory, specifically the human persona.
+
+    Note: This is acting on the Letta response; note the usage of `user_message`
+    """
+    # Set up client
+    client = create_client()
+    cleanup(client=client, agent_uuid=agent_uuid)
+
+    human_name = "BananaBoy"
+    agent_state = setup_agent(client, filename, embedding_config_path, memory_human_str=f"My name is {human_name}")
+
+    response = client.user_message(agent_id=agent_state.id, message="Repeat my name back to me.")
+
+    # Basic checks
+    assert_sanity_checks(response)
+
+    # Make sure my name was repeated back to me
+    assert_invoked_send_message_with_keyword(response.messages, human_name)
+
+    # Make sure some inner monologue is present
+    assert_inner_monologue_is_present_and_valid(response.messages)
+
+
+def check_agent_archival_memory_retrieval(filename: str):
+    """
+    Checks that the LLM will execute an archival memory retrieval.
+
+    Note: This is acting on the Letta response; note the usage of `user_message`
+    """
+    # Set up client
+    client = create_client()
+    cleanup(client=client, agent_uuid=agent_uuid)
+    agent_state = setup_agent(client, filename, embedding_config_path)
+    secret_word = "banana"
+    client.insert_archival_memory(agent_state.id, f"The secret word is {secret_word}!")
+
+    response = client.user_message(agent_id=agent_state.id, message="Search archival memory for the secret word and repeat it back to me.")
+
+    # Basic checks
+    assert_sanity_checks(response)
+
+    # Make sure archival_memory_search was called
+    assert_invoked_function_call(response.messages, "archival_memory_search")
+
+    # Make sure the secret was repeated back to me
+    assert_invoked_send_message_with_keyword(response.messages, secret_word)
+
+    # Make sure some inner monologue is present
+    assert_inner_monologue_is_present_and_valid(response.messages)
+
+
+def run_embedding_endpoint(filename):
+    # load JSON file
+    config_data = json.load(open(filename, "r"))
+    print(config_data)
+    embedding_config = EmbeddingConfig(**config_data)
+    model = embedding_model(embedding_config)
+    query_text = "hello"
+    query_vec = model.get_text_embedding(query_text)
+    print("vector dim", len(query_vec))
+    assert query_vec is not None
+
+
 # ======================================================================================================================
 # Section: Letta Message Assertions
 # These functions are validating elements of a parsed Letta Message
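
With the checks extracted into the helper module, the test file below shrinks to imports plus thin wrappers. A plausible shape for one such wrapper is sketched here, using the letta-hosted config path from this diff; the test name and structure are hypothetical, not taken from the commit:

from tests.helpers.endpoints_helper import check_first_response_is_valid_for_llm_endpoint

llm_config_path = "configs/llm_model_configs/letta-hosted.json"

def test_first_response_is_valid():
    # Delegates all setup, the LLM call, and the assertions to the helper.
    check_first_response_is_valid_for_llm_endpoint(llm_config_path)
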
tests/test_endpoints.py
@@ -1,212 +1,19 @@
-import json
 import os
-import uuid

-from letta import create_client
-from letta.agent import Agent
-from letta.embeddings import embedding_model
-from letta.llm_api.llm_api_tools import create
-from letta.prompts import gpt_system
-from letta.schemas.embedding_config import EmbeddingConfig
-from letta.schemas.message import Message
 from tests.helpers.endpoints_helper import (
-    agent_uuid,
-    assert_contains_correct_inner_monologue,
-    assert_contains_valid_function_call,
-    assert_inner_monologue_is_present_and_valid,
-    assert_invoked_function_call,
-    assert_invoked_send_message_with_keyword,
-    assert_sanity_checks,
-    setup_agent,
+    check_agent_archival_memory_retrieval,
+    check_agent_recall_chat_memory,
+    check_agent_uses_external_tool,
+    check_first_response_is_valid_for_llm_endpoint,
+    check_response_contains_keyword,
+    run_embedding_endpoint,
 )
-from tests.helpers.utils import cleanup

-messages = [Message(role="system", text=gpt_system.get_system_text("memgpt_chat")), Message(role="user", text="How are you?")]

-# defaults (letta hosted)
-embedding_config_path = "configs/embedding_model_configs/letta-hosted.json"
-llm_config_path = "configs/llm_model_configs/letta-hosted.json"

 # directories
 embedding_config_dir = "configs/embedding_model_configs"
 llm_config_dir = "configs/llm_model_configs"

-def check_first_response_is_valid_for_llm_endpoint(filename: str, inner_thoughts_in_kwargs: bool = False):
-    """
-    Checks that the first response is valid:
-
-    1. Contains either send_message or archival_memory_search
-    2. Contains valid usage of the function
-    3. Contains inner monologue
-
-    Note: This is acting on the raw LLM response; note the usage of `create`
-    """
-    client = create_client()
-    cleanup(client=client, agent_uuid=agent_uuid)
-    agent_state = setup_agent(client, filename, embedding_config_path)
-
-    tools = [client.get_tool(client.get_tool_id(name=name)) for name in agent_state.tools]
-    agent = Agent(
-        interface=None,
-        tools=tools,
-        agent_state=agent_state,
-    )
-
-    response = create(
-        llm_config=agent_state.llm_config,
-        user_id=str(uuid.UUID(int=1)),  # dummy user_id
-        messages=agent._messages,
-        functions=agent.functions,
-        functions_python=agent.functions_python,
-    )
-
-    # Basic check
-    assert response is not None
-
-    # Select first choice
-    choice = response.choices[0]
-
-    # Ensure that the first message returns a "send_message" or "archival_memory_search"
-    validator_func = lambda function_call: function_call.name == "send_message" or function_call.name == "archival_memory_search"
-    assert_contains_valid_function_call(choice.message, validator_func)
-
-    # Assert that the message has an inner monologue
-    assert_contains_correct_inner_monologue(choice, inner_thoughts_in_kwargs)
-
-
-def check_response_contains_keyword(filename: str):
-    """
-    Checks that the prompted response from the LLM contains a chosen keyword
-
-    Note: This is acting on the Letta response; note the usage of `user_message`
-    """
-    client = create_client()
-    cleanup(client=client, agent_uuid=agent_uuid)
-    agent_state = setup_agent(client, filename, embedding_config_path)
-
-    keyword = "banana"
-    keyword_message = f'This is a test to see if you can see my message. If you can see my message, please respond by calling send_message using a message that includes the word "{keyword}"'
-    response = client.user_message(agent_id=agent_state.id, message=keyword_message)
-
-    # Basic checks
-    assert_sanity_checks(response)
-
-    # Make sure the message was sent
-    assert_invoked_send_message_with_keyword(response.messages, keyword)
-
-    # Make sure some inner monologue is present
-    assert_inner_monologue_is_present_and_valid(response.messages)
-
-
-def check_agent_uses_external_tool(filename: str):
-    """
-    Checks that the LLM will use external tools if instructed
-
-    Note: This is acting on the Letta response; note the usage of `user_message`
-    """
-    from crewai_tools import ScrapeWebsiteTool
-
-    from letta.schemas.tool import Tool
-
-    crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com")
-    tool = Tool.from_crewai(crewai_tool)
-    tool_name = tool.name
-
-    # Set up client
-    client = create_client()
-    cleanup(client=client, agent_uuid=agent_uuid)
-    client.add_tool(tool)
-
-    # Set up persona for tool usage
-    persona = f"""
-
-    My name is Letta.
-
-    I am a personal assistant who answers a user's questions about a website `example.com`. When a user asks me a question about `example.com`, I will use a tool called {tool_name} which will search `example.com` and answer the relevant question.
-
-    Don’t forget - inner monologue / inner thoughts should always be different than the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts.
-    """
-
-    agent_state = setup_agent(client, filename, embedding_config_path, memory_persona_str=persona, tools=[tool_name])
-
-    response = client.user_message(agent_id=agent_state.id, message="What's on the example.com website?")
-
-    # Basic checks
-    assert_sanity_checks(response)
-
-    # Make sure the tool was called
-    assert_invoked_function_call(response.messages, tool_name)
-
-    # Make sure some inner monologue is present
-    assert_inner_monologue_is_present_and_valid(response.messages)
-
-
-def check_agent_recall_chat_memory(filename: str):
-    """
-    Checks that the LLM will recall the chat memory, specifically the human persona.
-
-    Note: This is acting on the Letta response; note the usage of `user_message`
-    """
-    # Set up client
-    client = create_client()
-    cleanup(client=client, agent_uuid=agent_uuid)
-
-    human_name = "BananaBoy"
-    agent_state = setup_agent(client, filename, embedding_config_path, memory_human_str=f"My name is {human_name}")
-
-    response = client.user_message(agent_id=agent_state.id, message="Repeat my name back to me.")
-
-    # Basic checks
-    assert_sanity_checks(response)
-
-    # Make sure my name was repeated back to me
-    assert_invoked_send_message_with_keyword(response.messages, human_name)
-
-    # Make sure some inner monologue is present
-    assert_inner_monologue_is_present_and_valid(response.messages)
-
-
-def check_agent_archival_memory_retrieval(filename: str):
-    """
-    Checks that the LLM will execute an archival memory retrieval.
-
-    Note: This is acting on the Letta response; note the usage of `user_message`
-    """
-    # Set up client
-    client = create_client()
-    cleanup(client=client, agent_uuid=agent_uuid)
-    agent_state = setup_agent(client, filename, embedding_config_path)
-    secret_word = "banana"
-    client.insert_archival_memory(agent_state.id, f"The secret word is {secret_word}!")
-
-    response = client.user_message(agent_id=agent_state.id, message="Search archival memory for the secret word and repeat it back to me.")
-
-    # Basic checks
-    assert_sanity_checks(response)
-
-    # Make sure archival_memory_search was called
-    assert_invoked_function_call(response.messages, "archival_memory_search")
-
-    # Make sure the secret was repeated back to me
-    assert_invoked_send_message_with_keyword(response.messages, secret_word)
-
-    # Make sure some inner monologue is present
-    assert_inner_monologue_is_present_and_valid(response.messages)
-
-
-def run_embedding_endpoint(filename):
-    # load JSON file
-    config_data = json.load(open(filename, "r"))
-    print(config_data)
-    embedding_config = EmbeddingConfig(**config_data)
-    model = embedding_model(embedding_config)
-    query_text = "hello"
-    query_vec = model.get_text_embedding(query_text)
-    print("vector dim", len(query_vec))
-    assert query_vec is not None
-
-
 # ======================================================================================================================
 # OPENAI TESTS
 # ======================================================================================================================
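
The embedding-side helper can be exercised the same way; the config path below is the letta-hosted default defined in this diff, and the single call covers the whole flow:

from tests.helpers.endpoints_helper import run_embedding_endpoint

# Loads the JSON config, instantiates the embedding model, embeds "hello",
# and asserts a non-empty vector comes back.
run_embedding_endpoint("configs/embedding_model_configs/letta-hosted.json")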