Skip to content

Commit b107a43

Browse files
feat: add MySQLChatMessageHistory class (#13)
1 parent 6c8af85 commit b107a43

File tree

3 files changed

+144
-1
lines changed

3 files changed

+144
-1
lines changed

src/langchain_google_cloud_sql_mysql/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from langchain_google_cloud_sql_mysql.mysql_chat_message_history import (
16+
MySQLChatMessageHistory,
17+
)
1518
from langchain_google_cloud_sql_mysql.mysql_engine import MySQLEngine
1619
from langchain_google_cloud_sql_mysql.mysql_loader import MySQLLoader
1720

18-
__all__ = ["MySQLEngine", "MySQLLoader"]
21+
__all__ = ["MySQLChatMessageHistory", "MySQLEngine", "MySQLLoader"]
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import json
15+
from typing import List
16+
17+
import sqlalchemy
18+
from langchain_core.chat_history import BaseChatMessageHistory
19+
from langchain_core.messages import BaseMessage, messages_from_dict
20+
21+
from langchain_google_cloud_sql_mysql.mysql_engine import MySQLEngine
22+
23+
24+
class MySQLChatMessageHistory(BaseChatMessageHistory):
25+
"""Chat message history stored in a Cloud SQL MySQL database."""
26+
27+
def __init__(
28+
self,
29+
engine: MySQLEngine,
30+
session_id: str,
31+
table_name: str = "message_store",
32+
) -> None:
33+
self.engine = engine
34+
self.session_id = session_id
35+
self.table_name = table_name
36+
self._create_table_if_not_exists()
37+
38+
def _create_table_if_not_exists(self) -> None:
39+
create_table_query = f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
40+
id INT AUTO_INCREMENT PRIMARY KEY,
41+
session_id TEXT NOT NULL,
42+
data JSON NOT NULL,
43+
type TEXT NOT NULL
44+
);"""
45+
46+
with self.engine.connect() as conn:
47+
conn.execute(sqlalchemy.text(create_table_query))
48+
conn.commit()
49+
50+
@property
51+
def messages(self) -> List[BaseMessage]: # type: ignore
52+
"""Retrieve the messages from Cloud SQL"""
53+
query = f"SELECT data, type FROM {self.table_name} WHERE session_id = '{self.session_id}' ORDER BY id;"
54+
with self.engine.connect() as conn:
55+
results = conn.execute(sqlalchemy.text(query)).fetchall()
56+
# load SQLAlchemy row objects into dicts
57+
items = [
58+
{"data": json.loads(result[0]), "type": result[1]} for result in results
59+
]
60+
messages = messages_from_dict(items)
61+
return messages
62+
63+
def add_message(self, message: BaseMessage) -> None:
64+
"""Append the message to the record in Cloud SQL"""
65+
query = f"INSERT INTO {self.table_name} (session_id, data, type) VALUES (:session_id, :data, :type);"
66+
with self.engine.connect() as conn:
67+
conn.execute(
68+
sqlalchemy.text(query),
69+
{
70+
"session_id": self.session_id,
71+
"data": json.dumps(message.dict()),
72+
"type": message.type,
73+
},
74+
)
75+
conn.commit()
76+
77+
def clear(self) -> None:
78+
"""Clear session memory from Cloud SQL"""
79+
query = f"DELETE FROM {self.table_name} WHERE session_id = :session_id;"
80+
with self.engine.connect() as conn:
81+
conn.execute(sqlalchemy.text(query), {"session_id": self.session_id})
82+
conn.commit()
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
from typing import Generator
16+
17+
import pytest
18+
import sqlalchemy
19+
from langchain_core.messages.ai import AIMessage
20+
from langchain_core.messages.human import HumanMessage
21+
22+
from langchain_google_cloud_sql_mysql import MySQLChatMessageHistory, MySQLEngine
23+
24+
project_id = os.environ["PROJECT_ID"]
25+
region = os.environ["REGION"]
26+
instance_id = os.environ["INSTANCE_ID"]
27+
db_name = os.environ["DB_NAME"]
28+
29+
30+
@pytest.fixture(name="memory_engine")
31+
def setup() -> Generator:
32+
engine = MySQLEngine.from_instance(
33+
project_id=project_id, region=region, instance=instance_id, database=db_name
34+
)
35+
36+
yield engine
37+
# use default table for MySQLChatMessageHistory
38+
table_name = "message_store"
39+
with engine.connect() as conn:
40+
conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS `{table_name}`"))
41+
conn.commit()
42+
43+
44+
def test_chat_message_history(memory_engine: MySQLEngine) -> None:
45+
history = MySQLChatMessageHistory(engine=memory_engine, session_id="test")
46+
history.add_user_message("hi!")
47+
history.add_ai_message("whats up?")
48+
messages = history.messages
49+
50+
# verify messages are correct
51+
assert messages[0].content == "hi!"
52+
assert type(messages[0]) is HumanMessage
53+
assert messages[1].content == "whats up?"
54+
assert type(messages[1]) is AIMessage
55+
56+
# verify clear() clears message history
57+
history.clear()
58+
assert len(history.messages) == 0

0 commit comments

Comments
 (0)