Skip to content

Commit 6c8af85

Browse files
feat: add MySQLEngine and Loader load functionality (#9)
1 parent 1c4f5a8 commit 6c8af85

File tree

6 files changed

+553
-0
lines changed

6 files changed

+553
-0
lines changed

integration.cloudbuild.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,14 @@ steps:
2222
name: python:3.11
2323
entrypoint: python
2424
args: ["-m", "pytest"]
25+
env:
26+
- 'PROJECT_ID=$PROJECT_ID'
27+
- 'INSTANCE_ID=$_INSTANCE_ID'
28+
- 'DB_NAME=$_DB_NAME'
29+
- 'TABLE_NAME=test-$BUILD_ID'
30+
- 'REGION=$_REGION'
31+
32+
substitutions:
33+
_INSTANCE_ID: test-instance
34+
_REGION: us-central1
35+
_DB_NAME: test

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,7 @@ warn_unused_configs = true
4040
exclude = [
4141
"owlbot.py"
4242
]
43+
44+
[[tool.mypy.overrides]]
45+
module="google.auth.*"
46+
ignore_missing_imports = true

src/langchain_google_cloud_sql_mysql/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from langchain_google_cloud_sql_mysql.mysql_engine import MySQLEngine
16+
from langchain_google_cloud_sql_mysql.mysql_loader import MySQLLoader
17+
18+
# Names re-exported as the public API of this package.
__all__ = ["MySQLEngine", "MySQLLoader"]
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# TODO: Remove below import when minimum supported Python version is 3.10
16+
from __future__ import annotations
17+
18+
from typing import TYPE_CHECKING, Dict, Optional
19+
20+
import google.auth
21+
import google.auth.transport.requests
22+
import requests
23+
import sqlalchemy
24+
from google.cloud.sql.connector import Connector
25+
26+
if TYPE_CHECKING:
27+
import google.auth.credentials
28+
import pymysql
29+
30+
31+
def _get_iam_principal_email(
    credentials: "google.auth.credentials.Credentials",
) -> str:
    """Get email address associated with current authenticated IAM principal.

    Email will be used for automatic IAM database authentication to Cloud SQL.

    Args:
        credentials (google.auth.credentials.Credentials):
            The credentials object to use in finding the associated IAM
            principal email address.

    Returns:
        email (str):
            The email address associated with the current authenticated IAM
            principal.

    Raises:
        requests.HTTPError: If the tokeninfo endpoint answers with an error
            status code.
        requests.Timeout: If the tokeninfo endpoint does not answer within
            the timeout.
        ValueError: If the tokeninfo response does not contain an email.
    """
    # refresh credentials if they are not valid
    if not credentials.valid:
        request = google.auth.transport.requests.Request()
        credentials.refresh(request)
    # if credentials are associated with a service account email, return early
    if hasattr(credentials, "_service_account_email"):
        return credentials._service_account_email
    # call OAuth2 api to get IAM principal email associated with OAuth2 token
    url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}"
    # Always bound the call: without a timeout `requests` waits forever on a
    # stalled connection, which would hang engine creation.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    response_json: Dict = response.json()
    email = response_json.get("email")
    # A missing *or empty* email cannot be used as the IAM database user.
    if not email:
        raise ValueError(
            "Failed to automatically obtain authenticated IAM principal's "
            "email address using environment's ADC credentials!"
        )
    return email
67+
68+
69+
class MySQLEngine:
    """Manage connections to a Cloud SQL for MySQL database."""

    # Cloud SQL Python Connector shared by every engine built through this
    # class; it is created lazily the first time an engine is constructed.
    _connector: Optional[Connector] = None

    def __init__(
        self,
        engine: sqlalchemy.engine.Engine,
    ) -> None:
        """Wrap an already configured SQLAlchemy engine.

        Args:
            engine (sqlalchemy.engine.Engine): The engine used to open
                connections to the database.
        """
        self.engine = engine

    @classmethod
    def from_instance(
        cls,
        project_id: str,
        region: str,
        instance: str,
        database: str,
    ) -> MySQLEngine:
        """Build a MySQLEngine from the details of a Cloud SQL instance.

        The connection goes through the Cloud SQL Python Connector and uses
        automatic IAM database authentication with the Google ADC credentials
        sourced from the environment.

        More details can be found at https://github.com/GoogleCloudPlatform/cloud-sql-python-connector#credentials

        Args:
            project_id (str): Project ID of the Google Cloud Project where
                the Cloud SQL instance is located.
            region (str): Region where the Cloud SQL instance is located.
            instance (str): The name of the Cloud SQL instance.
            database (str): The name of the database to connect to on the
                Cloud SQL instance.

        Returns:
            (MySQLEngine): The engine configured to connect to a
                Cloud SQL instance database.
        """
        connection_name = f"{project_id}:{region}:{instance}"
        return cls(
            engine=cls._create_connector_engine(
                instance_connection_name=connection_name,
                database=database,
            )
        )

    @classmethod
    def _create_connector_engine(
        cls, instance_connection_name: str, database: str
    ) -> sqlalchemy.engine.Engine:
        """Build a SQLAlchemy engine backed by the Cloud SQL Python Connector.

        The engine uses the "pymysql" driver and authenticates through
        automatic IAM database authentication as the IAM principal associated
        with the environment's Google Application Default Credentials.

        Args:
            instance_connection_name (str): The instance connection
                name of the Cloud SQL instance to establish a connection to.
                (ex. "project-id:instance-region:instance-name")
            database (str): The name of the database to connect to on the
                Cloud SQL instance.

        Returns:
            (sqlalchemy.engine.Engine): Engine configured using the Cloud SQL
                Python Connector.
        """
        # Resolve the database user from the application default credentials
        # before touching the shared connector.
        adc_credentials, _ = google.auth.default(
            scopes=["https://www.googleapis.com/auth/userinfo.email"]
        )
        db_user = _get_iam_principal_email(adc_credentials)

        # Lazily create the connector that all engines of this class share.
        if cls._connector is None:
            cls._connector = Connector()

        def open_connection() -> pymysql.Connection:
            # Handed to SQLAlchemy as its 'creator' callable: invoked whenever
            # the pool needs a new DBAPI connection.
            return cls._connector.connect(  # type: ignore
                instance_connection_name,
                "pymysql",
                user=db_user,
                db=database,
                enable_iam_auth=True,
            )

        return sqlalchemy.create_engine(
            "mysql+pymysql://",
            creator=open_connection,
        )

    def connect(self) -> sqlalchemy.engine.Connection:
        """Check a connection out of the SQLAlchemy connection pool.

        Returns:
            (sqlalchemy.engine.Connection): a single DBAPI connection checked
                out from the connection pool.
        """
        return self.engine.connect()
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import json
15+
from collections.abc import Iterable
16+
from typing import Any, Dict, List, Optional, Sequence, cast
17+
18+
import sqlalchemy
19+
from langchain_community.document_loaders.base import BaseLoader
20+
from langchain_core.documents import Document
21+
22+
from langchain_google_cloud_sql_mysql.mysql_engine import MySQLEngine
23+
24+
# Column whose value is JSON-decoded and merged into each Document's metadata
# (the column itself is removed from the metadata) by _parse_doc_from_table.
DEFAULT_METADATA_COL = "langchain_metadata"
25+
26+
27+
def _parse_doc_from_table(
    content_columns: Iterable[str],
    metadata_columns: Iterable[str],
    column_names: Iterable[str],
    rows: Sequence[Any],
) -> List[Document]:
    """Convert query result rows into langchain Documents.

    Args:
        content_columns (Iterable[str]): Columns whose values are joined with
            a single space to form `page_content`.
        metadata_columns (Iterable[str]): Columns copied into `metadata`.
        column_names (Iterable[str]): Columns actually present in the rows;
            requested columns that are not listed here are ignored.
        rows (Sequence[Any]): Rows exposing each column as an attribute.

    Returns:
        (List[Document]): One Document per row, in row order. When the
            `DEFAULT_METADATA_COL` column is among the metadata, it is removed
            and its JSON content is merged into the metadata instead.
    """
    # Resolve the usable columns once, up front: the arguments are only
    # promised to be iterables, so a one-shot iterator would be exhausted after
    # the first row, and a set keeps the membership test constant time.
    available = set(column_names)
    content_cols = [col for col in content_columns if col in available]
    metadata_cols = [col for col in metadata_columns if col in available]

    docs = []
    for row in rows:
        page_content = " ".join(str(getattr(row, col)) for col in content_cols)
        metadata = {col: getattr(row, col) for col in metadata_cols}
        if DEFAULT_METADATA_COL in metadata:
            raw_extra = metadata.pop(DEFAULT_METADATA_COL)
            if isinstance(raw_extra, dict):
                # Value already decoded; merge it as-is.
                extra_metadata = raw_extra
            elif raw_extra is None:
                # SQL NULL: nothing to merge (json.loads(None) would raise).
                extra_metadata = {}
            else:
                extra_metadata = json.loads(raw_extra)
            # dict.update rather than `|=`, which requires Python 3.9+.
            metadata.update(extra_metadata)
        docs.append(Document(page_content=page_content, metadata=metadata))
    return docs
52+
53+
54+
class MySQLLoader(BaseLoader):
    """Load langchain documents from a Cloud SQL MySQL database."""

    def __init__(
        self,
        engine: MySQLEngine,
        query: str,
        content_columns: Optional[List[str]] = None,
        metadata_columns: Optional[List[str]] = None,
    ):
        """Store the engine, the query and the optional column selections.

        Args:
            engine (MySQLEngine): MySQLEngine object to connect to the MySQL database.
            query (str): The query to execute in MySQL format.
            content_columns (List[str]): The columns to write into the `page_content`
                of the document. Optional.
            metadata_columns (List[str]): The columns to write into the `metadata`
                of the document. Optional.
        """
        self.engine = engine
        self.query = query
        self.content_columns = content_columns
        self.metadata_columns = metadata_columns

    def load(self) -> List[Document]:
        """Run the query and turn every returned row into a Document.

        Page content defaults to the first column present in the query result
        and metadata defaults to all other columns. Use content_columns to
        override the columns used for page content, and metadata_columns to
        select specific metadata columns rather than all remaining columns.

        If multiple content columns are specified, their values are joined
        into page_content with a single space.

        Returns:
            (List[langchain_core.documents.Document]): a list of Documents with
                metadata from specific columns.
        """
        with self.engine.connect() as conn:
            cursor = conn.execute(sqlalchemy.text(self.query))
            names = list(cursor.keys())
            fetched_rows = cursor.fetchall()

            # Fall back to the first result column when no (or an empty)
            # content column list was configured.
            if self.content_columns:
                content = self.content_columns
            else:
                content = [names[0]]

            # Fall back to every column that is not used as content.
            if self.metadata_columns:
                metadata = self.metadata_columns
            else:
                metadata = [name for name in names if name not in content]

            return _parse_doc_from_table(content, metadata, names, fetched_rows)

0 commit comments

Comments
 (0)