fix: tests should only use "openai" quickstart if `OPENAI_API_KEY is set (#801)

This commit is contained in:
Sarah Wooders 2024-01-09 19:13:16 -08:00 committed by GitHub
parent c18e10c067
commit 03f768868e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 7 deletions

View File

@ -1,7 +1,7 @@
from memgpt import MemGPT
from memgpt import constants
import memgpt.functions.function_sets.base as base_functions
import os
from .utils import wipe_config
@ -14,7 +14,10 @@ def create_test_agent():
"""Create a test agent that we can call functions on"""
wipe_config()
global client
client = MemGPT(quickstart="openai")
if os.getenv("OPENAI_API_KEY"):
client = MemGPT(quickstart="openai")
else:
client = MemGPT(quickstart="memgpt_hosted")
agent_state = client.create_agent(
agent_config={

View File

@ -16,7 +16,10 @@ from memgpt import MemGPT
def test_save_load():
# configure_memgpt() # rely on configure running first^
client = MemGPT(quickstart="openai")
if os.getenv("OPENAI_API_KEY"):
client = MemGPT(quickstart="openai")
else:
client = MemGPT(quickstart="memgpt_hosted")
child = pexpect.spawn("memgpt run --agent test_save_load --first --strip-ui")

View File

@ -1,7 +1,7 @@
from memgpt import MemGPT
from memgpt import constants
from memgpt.data_types import LLMConfig, EmbeddingConfig
import os
from .utils import wipe_config
@ -15,7 +15,10 @@ test_agent_state_post_message = None
def test_create_agent():
wipe_config()
global client
client = MemGPT(quickstart="openai")
if os.getenv("OPENAI_API_KEY"):
client = MemGPT(quickstart="openai")
else:
client = MemGPT(quickstart="memgpt_hosted")
global test_agent_state
test_agent_state = client.create_agent(
@ -51,7 +54,10 @@ def test_save_load():
# Create a new client (not thread safe), and load the same agent
# The agent state inside should correspond to the initial state pre-message
client2 = MemGPT(quickstart="openai")
if os.getenv("OPENAI_API_KEY"):
client2 = MemGPT(quickstart="openai")
else:
client2 = MemGPT(quickstart="memgpt_hosted")
client2_agent_obj = client2.server._get_or_load_agent(user_id="", agent_id=test_agent_state.id)
client2_agent_state = client2_agent_obj.to_agent_state()
@ -71,7 +77,10 @@ def test_save_load():
# This should persist the test message into the agent state
client.save()
client3 = MemGPT(quickstart="openai")
if os.getenv("OPENAI_API_KEY"):
client3 = MemGPT(quickstart="openai")
else:
client3 = MemGPT(quickstart="memgpt_hosted")
client3_agent_obj = client3.server._get_or_load_agent(user_id="", agent_id=test_agent_state.id)
client3_agent_state = client3_agent_obj.to_agent_state()