|
| 1 | +# Copyright 2024 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +# TODO: Remove below import when minimum supported Python version is 3.10 |
| 16 | +from __future__ import annotations |
| 17 | + |
| 18 | +from typing import TYPE_CHECKING, Dict, Optional |
| 19 | + |
| 20 | +import google.auth |
| 21 | +import google.auth.transport.requests |
| 22 | +import requests |
| 23 | +import sqlalchemy |
| 24 | +from google.cloud.sql.connector import Connector |
| 25 | + |
| 26 | +if TYPE_CHECKING: |
| 27 | + import google.auth.credentials |
| 28 | + import pymysql |
| 29 | + |
| 30 | + |
| 31 | +def _get_iam_principal_email( |
| 32 | + credentials: google.auth.credentials.Credentials, |
| 33 | +) -> str: |
| 34 | + """Get email address associated with current authenticated IAM principal. |
| 35 | +
|
| 36 | + Email will be used for automatic IAM database authentication to Cloud SQL. |
| 37 | +
|
| 38 | + Args: |
| 39 | + credentials (google.auth.credentials.Credentials): |
| 40 | + The credentials object to use in finding the associated IAM |
| 41 | + principal email address. |
| 42 | +
|
| 43 | + Returns: |
| 44 | + email (str): |
| 45 | + The email address associated with the current authenticated IAM |
| 46 | + principal. |
| 47 | + """ |
| 48 | + # refresh credentials if they are not valid |
| 49 | + if not credentials.valid: |
| 50 | + request = google.auth.transport.requests.Request() |
| 51 | + credentials.refresh(request) |
| 52 | + # if credentials are associated with a service account email, return early |
| 53 | + if hasattr(credentials, "_service_account_email"): |
| 54 | + return credentials._service_account_email |
| 55 | + # call OAuth2 api to get IAM principal email associated with OAuth2 token |
| 56 | + url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}" |
| 57 | + response = requests.get(url) |
| 58 | + response.raise_for_status() |
| 59 | + response_json: Dict = response.json() |
| 60 | + email = response_json.get("email") |
| 61 | + if email is None: |
| 62 | + raise ValueError( |
| 63 | + "Failed to automatically obtain authenticated IAM princpal's " |
| 64 | + "email address using environment's ADC credentials!" |
| 65 | + ) |
| 66 | + return email |
| 67 | + |
| 68 | + |
| 69 | +class MySQLEngine: |
| 70 | + """A class for managing connections to a Cloud SQL for MySQL database.""" |
| 71 | + |
| 72 | + _connector: Optional[Connector] = None |
| 73 | + |
| 74 | + def __init__( |
| 75 | + self, |
| 76 | + engine: sqlalchemy.engine.Engine, |
| 77 | + ) -> None: |
| 78 | + self.engine = engine |
| 79 | + |
| 80 | + @classmethod |
| 81 | + def from_instance( |
| 82 | + cls, |
| 83 | + project_id: str, |
| 84 | + region: str, |
| 85 | + instance: str, |
| 86 | + database: str, |
| 87 | + ) -> MySQLEngine: |
| 88 | + """Create an instance of MySQLEngine from Cloud SQL instance |
| 89 | + details. |
| 90 | +
|
| 91 | + This method uses the Cloud SQL Python Connector to connect to Cloud SQL |
| 92 | + using automatic IAM database authentication with the Google ADC |
| 93 | + credentials sourced from the environment. |
| 94 | +
|
| 95 | + More details can be found at https://github.com/GoogleCloudPlatform/cloud-sql-python-connector#credentials |
| 96 | +
|
| 97 | + Args: |
| 98 | + project_id (str): Project ID of the Google Cloud Project where |
| 99 | + the Cloud SQL instance is located. |
| 100 | + region (str): Region where the Cloud SQL instance is located. |
| 101 | + instance (str): The name of the Cloud SQL instance. |
| 102 | + database (str): The name of the database to connect to on the |
| 103 | + Cloud SQL instance. |
| 104 | +
|
| 105 | + Returns: |
| 106 | + (MySQLEngine): The engine configured to connect to a |
| 107 | + Cloud SQL instance database. |
| 108 | + """ |
| 109 | + engine = cls._create_connector_engine( |
| 110 | + instance_connection_name=f"{project_id}:{region}:{instance}", |
| 111 | + database=database, |
| 112 | + ) |
| 113 | + return cls(engine=engine) |
| 114 | + |
| 115 | + @classmethod |
| 116 | + def _create_connector_engine( |
| 117 | + cls, instance_connection_name: str, database: str |
| 118 | + ) -> sqlalchemy.engine.Engine: |
| 119 | + """Create a SQLAlchemy engine using the Cloud SQL Python Connector. |
| 120 | +
|
| 121 | + Defaults to use "pymysql" driver and to connect using automatic IAM |
| 122 | + database authentication with the IAM principal associated with the |
| 123 | + environment's Google Application Default Credentials. |
| 124 | +
|
| 125 | + Args: |
| 126 | + instance_connection_name (str): The instance connection |
| 127 | + name of the Cloud SQL instance to establish a connection to. |
| 128 | + (ex. "project-id:instance-region:instance-name") |
| 129 | + database (str): The name of the database to connect to on the |
| 130 | + Cloud SQL instance. |
| 131 | + Returns: |
| 132 | + (sqlalchemy.engine.Engine): Engine configured using the Cloud SQL |
| 133 | + Python Connector. |
| 134 | + """ |
| 135 | + # get application default credentials |
| 136 | + credentials, _ = google.auth.default( |
| 137 | + scopes=["https://www.googleapis.com/auth/userinfo.email"] |
| 138 | + ) |
| 139 | + iam_database_user = _get_iam_principal_email(credentials) |
| 140 | + if cls._connector is None: |
| 141 | + cls._connector = Connector() |
| 142 | + |
| 143 | + # anonymous function to be used for SQLAlchemy 'creator' argument |
| 144 | + def getconn() -> pymysql.Connection: |
| 145 | + conn = cls._connector.connect( # type: ignore |
| 146 | + instance_connection_name, |
| 147 | + "pymysql", |
| 148 | + user=iam_database_user, |
| 149 | + db=database, |
| 150 | + enable_iam_auth=True, |
| 151 | + ) |
| 152 | + return conn |
| 153 | + |
| 154 | + return sqlalchemy.create_engine( |
| 155 | + "mysql+pymysql://", |
| 156 | + creator=getconn, |
| 157 | + ) |
| 158 | + |
| 159 | + def connect(self) -> sqlalchemy.engine.Connection: |
| 160 | + """Create a connection from SQLAlchemy connection pool. |
| 161 | +
|
| 162 | + Returns: |
| 163 | + (sqlalchemy.engine.Connection): a single DBAPI connection checked |
| 164 | + out from the connection pool. |
| 165 | + """ |
| 166 | + return self.engine.connect() |
0 commit comments