diff --git a/google/cloud/_http.py b/google/cloud/_http.py index fbc228e..719e986 100644 --- a/google/cloud/_http.py +++ b/google/cloud/_http.py @@ -20,6 +20,7 @@ except ImportError: import collections as collections_abc import json +import os import platform import warnings @@ -176,12 +177,56 @@ class JSONConnection(Connection): API_BASE_URL = None """The base of the API call URL.""" + API_BASE_MTLS_URL = None + """The base of the API call URL for mutual TLS.""" + + ALLOW_AUTO_SWITCH_TO_MTLS_URL = False + """Indicates if auto switch to mTLS url is allowed.""" + API_VERSION = None """The version of the API, used in building the API call's URL.""" API_URL_TEMPLATE = None """A template for the URL of a particular API call.""" + def get_api_base_url_for_mtls(self, api_base_url=None): + """Return the api base url for mutual TLS. + + Typically, you shouldn't need to use this method. + + The logic is as follows: + + If `api_base_url` is provided, just return this value; otherwise, the + return value depends `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable + value. + + If the environment variable value is "always", return `API_BASE_MTLS_URL`. + If the environment variable value is "never", return `API_BASE_URL`. + Otherwise, if `ALLOW_AUTO_SWITCH_TO_MTLS_URL` is True and the underlying + http is mTLS, then return `API_BASE_MTLS_URL`; otherwise return `API_BASE_URL`. + + :type api_base_url: str + :param api_base_url: User provided api base url. It takes precedence over + `API_BASE_URL` and `API_BASE_MTLS_URL`. + + :rtype: str + :returns: The api base url used for mTLS. + """ + if api_base_url: + return api_base_url + + env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if env == "always": + url_to_use = self.API_BASE_MTLS_URL + elif env == "never": + url_to_use = self.API_BASE_URL + else: + if self.ALLOW_AUTO_SWITCH_TO_MTLS_URL: + url_to_use = self.API_BASE_MTLS_URL if self.http.is_mtls else self.API_BASE_URL + else: + url_to_use = self.API_BASE_URL + return url_to_use + def build_api_url( self, path, query_params=None, api_base_url=None, api_version=None ): @@ -210,7 +255,7 @@ def build_api_url( :returns: The URL assembled from the pieces provided. """ url = self.API_URL_TEMPLATE.format( - api_base_url=(api_base_url or self.API_BASE_URL), + api_base_url=self.get_api_base_url_for_mtls(api_base_url), api_version=(api_version or self.API_VERSION), path=path, ) diff --git a/google/cloud/client.py b/google/cloud/client.py index a477717..fdb81a5 100644 --- a/google/cloud/client.py +++ b/google/cloud/client.py @@ -159,6 +159,7 @@ def __init__(self, credentials=None, _http=None, client_options=None): self._credentials = self._credentials.with_quota_project(client_options.quota_project_id) self._http_internal = _http + self._client_cert_source = client_options.client_cert_source def __getstate__(self): """Explicitly state that clients are not pickleable.""" @@ -183,6 +184,7 @@ def _http(self): self._credentials, refresh_timeout=_CREDENTIALS_REFRESH_TIMEOUT, ) + self._http_internal.configure_mtls_channel(self._client_cert_source) return self._http_internal diff --git a/setup.py b/setup.py index 6b9530d..66e115c 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ release_status = "Development Status :: 5 - Production/Stable" dependencies = [ "google-api-core >= 1.21.0, < 2.0.0dev", + "google-auth >= 1.24.0, < 2.0dev", # Support six==1.12.0 due to App Engine standard runtime. # https://github.com/googleapis/python-cloud-core/issues/45 "six >=1.12.0", diff --git a/tests/unit/test__http.py b/tests/unit/test__http.py index 069ddc0..32a4965 100644 --- a/tests/unit/test__http.py +++ b/tests/unit/test__http.py @@ -13,6 +13,7 @@ # limitations under the License. import json +import os import unittest import warnings @@ -165,6 +166,7 @@ def _make_mock_one(self, *args, **kw): class MockConnection(self._get_target_class()): API_URL_TEMPLATE = "{api_base_url}/mock/{api_version}{path}" API_BASE_URL = "http://mock" + API_BASE_MTLS_URL = "https://mock.mtls" API_VERSION = "vMOCK" return MockConnection(*args, **kw) @@ -230,6 +232,50 @@ def test_build_api_url_w_extra_query_params_tuples(self): self.assertEqual(parms["qux"], ["quux", "corge"]) self.assertEqual(parms["prettyPrint"], ["false"]) + def test_get_api_base_url_for_mtls_w_api_base_url(self): + client = object() + conn = self._make_mock_one(client) + uri = conn.get_api_base_url_for_mtls(api_base_url="http://foo") + self.assertEqual(uri, "http://foo") + + def test_get_api_base_url_for_mtls_env_always(self): + client = object() + conn = self._make_mock_one(client) + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + uri = conn.get_api_base_url_for_mtls() + self.assertEqual(uri, "https://mock.mtls") + + def test_get_api_base_url_for_mtls_env_never(self): + client = object() + conn = self._make_mock_one(client) + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + uri = conn.get_api_base_url_for_mtls() + self.assertEqual(uri, "http://mock") + + def test_get_api_base_url_for_mtls_env_auto(self): + client = mock.Mock() + client._http = mock.Mock() + client._http.is_mtls = False + conn = self._make_mock_one(client) + + # ALLOW_AUTO_SWITCH_TO_MTLS_URL is False, so use regular endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + uri = conn.get_api_base_url_for_mtls() + self.assertEqual(uri, "http://mock") + + # ALLOW_AUTO_SWITCH_TO_MTLS_URL is True, so now endpoint dependes + # on client._http.is_mtls + conn.ALLOW_AUTO_SWITCH_TO_MTLS_URL = True + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + uri = conn.get_api_base_url_for_mtls() + self.assertEqual(uri, "http://mock") + + client._http.is_mtls = True + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + uri = conn.get_api_base_url_for_mtls() + self.assertEqual(uri, "https://mock.mtls") + def test__make_request_no_data_no_content_type_no_headers(self): from google.cloud._http import CLIENT_INFO_HEADER diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 48e0144..8137826 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -125,20 +125,22 @@ def test_ctor__http_property_new(self): from google.cloud.client import _CREDENTIALS_REFRESH_TIMEOUT credentials = _make_credentials() - client = self._make_one(credentials=credentials) + mock_client_cert_source = mock.Mock() + client_options = {'client_cert_source': mock_client_cert_source} + client = self._make_one(credentials=credentials, client_options=client_options) self.assertIsNone(client._http_internal) - authorized_session_patch = mock.patch( - "google.auth.transport.requests.AuthorizedSession", - return_value=mock.sentinel.http, - ) - with authorized_session_patch as AuthorizedSession: - self.assertIs(client._http, mock.sentinel.http) + with mock.patch('google.auth.transport.requests.AuthorizedSession') as AuthorizedSession: + session = mock.Mock() + session.configure_mtls_channel = mock.Mock() + AuthorizedSession.return_value = session + self.assertIs(client._http, session) # Check the mock. AuthorizedSession.assert_called_once_with(credentials, refresh_timeout=_CREDENTIALS_REFRESH_TIMEOUT) + session.configure_mtls_channel.assert_called_once_with(mock_client_cert_source) # Make sure the cached value is used on subsequent access. - self.assertIs(client._http_internal, mock.sentinel.http) - self.assertIs(client._http, mock.sentinel.http) + self.assertIs(client._http_internal, session) + self.assertIs(client._http, session) self.assertEqual(AuthorizedSession.call_count, 1) def test_from_service_account_json(self):