Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion google/cloud/_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
except ImportError:
import collections as collections_abc
import json
import os
import platform
import warnings

Expand Down Expand Up @@ -176,12 +177,56 @@ class JSONConnection(Connection):
API_BASE_URL = None
"""The base of the API call URL."""

API_BASE_MTLS_URL = None
"""The base of the API call URL for mutual TLS."""

ALLOW_AUTO_SWITCH_TO_MTLS_URL = False
"""Indicates if auto switch to mTLS url is allowed."""

API_VERSION = None
"""The version of the API, used in building the API call's URL."""

API_URL_TEMPLATE = None
"""A template for the URL of a particular API call."""

def get_api_base_url_for_mtls(self, api_base_url=None):
"""Return the api base url for mutual TLS.

Typically, you shouldn't need to use this method.

The logic is as follows:

If `api_base_url` is provided, just return this value; otherwise, the
return value depends `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable
value.

If the environment variable value is "always", return `API_BASE_MTLS_URL`.
If the environment variable value is "never", return `API_BASE_URL`.
Otherwise, if `ALLOW_AUTO_SWITCH_TO_MTLS_URL` is True and the underlying
http is mTLS, then return `API_BASE_MTLS_URL`; otherwise return `API_BASE_URL`.

:type api_base_url: str
:param api_base_url: User provided api base url. It takes precedence over
`API_BASE_URL` and `API_BASE_MTLS_URL`.

:rtype: str
:returns: The api base url used for mTLS.
"""
if api_base_url:
return api_base_url

env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
if env == "always":
url_to_use = self.API_BASE_MTLS_URL
elif env == "never":
url_to_use = self.API_BASE_URL
else:
if self.ALLOW_AUTO_SWITCH_TO_MTLS_URL:
url_to_use = self.API_BASE_MTLS_URL if self.http.is_mtls else self.API_BASE_URL
else:
url_to_use = self.API_BASE_URL
return url_to_use

def build_api_url(
self, path, query_params=None, api_base_url=None, api_version=None
):
Expand Down Expand Up @@ -210,7 +255,7 @@ def build_api_url(
:returns: The URL assembled from the pieces provided.
"""
url = self.API_URL_TEMPLATE.format(
api_base_url=(api_base_url or self.API_BASE_URL),
api_base_url=self.get_api_base_url_for_mtls(api_base_url),
api_version=(api_version or self.API_VERSION),
path=path,
)
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def __init__(self, credentials=None, _http=None, client_options=None):
self._credentials = self._credentials.with_quota_project(client_options.quota_project_id)

self._http_internal = _http
self._client_cert_source = client_options.client_cert_source

def __getstate__(self):
"""Explicitly state that clients are not pickleable."""
Expand All @@ -183,6 +184,7 @@ def _http(self):
self._credentials,
refresh_timeout=_CREDENTIALS_REFRESH_TIMEOUT,
)
self._http_internal.configure_mtls_channel(self._client_cert_source)
return self._http_internal


Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
release_status = "Development Status :: 5 - Production/Stable"
dependencies = [
"google-api-core >= 1.21.0, < 2.0.0dev",
"google-auth >= 1.24.0, < 2.0dev",
# Support six==1.12.0 due to App Engine standard runtime.
# https://github.com/googleapis/python-cloud-core/issues/45
"six >=1.12.0",
Expand Down
46 changes: 46 additions & 0 deletions tests/unit/test__http.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import json
import os
import unittest
import warnings

Expand Down Expand Up @@ -165,6 +166,7 @@ def _make_mock_one(self, *args, **kw):
class MockConnection(self._get_target_class()):
API_URL_TEMPLATE = "{api_base_url}/mock/{api_version}{path}"
API_BASE_URL = "http://mock"
API_BASE_MTLS_URL = "https://mock.mtls"
API_VERSION = "vMOCK"

return MockConnection(*args, **kw)
Expand Down Expand Up @@ -230,6 +232,50 @@ def test_build_api_url_w_extra_query_params_tuples(self):
self.assertEqual(parms["qux"], ["quux", "corge"])
self.assertEqual(parms["prettyPrint"], ["false"])

def test_get_api_base_url_for_mtls_w_api_base_url(self):
client = object()
conn = self._make_mock_one(client)
uri = conn.get_api_base_url_for_mtls(api_base_url="http://foo")
self.assertEqual(uri, "http://foo")

def test_get_api_base_url_for_mtls_env_always(self):
client = object()
conn = self._make_mock_one(client)
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
uri = conn.get_api_base_url_for_mtls()
self.assertEqual(uri, "https://mock.mtls")

def test_get_api_base_url_for_mtls_env_never(self):
client = object()
conn = self._make_mock_one(client)
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
uri = conn.get_api_base_url_for_mtls()
self.assertEqual(uri, "http://mock")

def test_get_api_base_url_for_mtls_env_auto(self):
client = mock.Mock()
client._http = mock.Mock()
client._http.is_mtls = False
conn = self._make_mock_one(client)

# ALLOW_AUTO_SWITCH_TO_MTLS_URL is False, so use regular endpoint.
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}):
uri = conn.get_api_base_url_for_mtls()
self.assertEqual(uri, "http://mock")

# ALLOW_AUTO_SWITCH_TO_MTLS_URL is True, so now endpoint dependes
# on client._http.is_mtls
conn.ALLOW_AUTO_SWITCH_TO_MTLS_URL = True

with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}):
uri = conn.get_api_base_url_for_mtls()
self.assertEqual(uri, "http://mock")

client._http.is_mtls = True
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}):
uri = conn.get_api_base_url_for_mtls()
self.assertEqual(uri, "https://mock.mtls")

def test__make_request_no_data_no_content_type_no_headers(self):
from google.cloud._http import CLIENT_INFO_HEADER

Expand Down
20 changes: 11 additions & 9 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,20 +125,22 @@ def test_ctor__http_property_new(self):
from google.cloud.client import _CREDENTIALS_REFRESH_TIMEOUT

credentials = _make_credentials()
client = self._make_one(credentials=credentials)
mock_client_cert_source = mock.Mock()
client_options = {'client_cert_source': mock_client_cert_source}
client = self._make_one(credentials=credentials, client_options=client_options)
self.assertIsNone(client._http_internal)

authorized_session_patch = mock.patch(
"google.auth.transport.requests.AuthorizedSession",
return_value=mock.sentinel.http,
)
with authorized_session_patch as AuthorizedSession:
self.assertIs(client._http, mock.sentinel.http)
with mock.patch('google.auth.transport.requests.AuthorizedSession') as AuthorizedSession:
session = mock.Mock()
session.configure_mtls_channel = mock.Mock()
AuthorizedSession.return_value = session
self.assertIs(client._http, session)
# Check the mock.
AuthorizedSession.assert_called_once_with(credentials, refresh_timeout=_CREDENTIALS_REFRESH_TIMEOUT)
session.configure_mtls_channel.assert_called_once_with(mock_client_cert_source)
# Make sure the cached value is used on subsequent access.
self.assertIs(client._http_internal, mock.sentinel.http)
self.assertIs(client._http, mock.sentinel.http)
self.assertIs(client._http_internal, session)
self.assertIs(client._http, session)
self.assertEqual(AuthorizedSession.call_count, 1)

def test_from_service_account_json(self):
Expand Down