diff --git a/spanner/google/cloud/spanner_v1/pool.py b/spanner/google/cloud/spanner_v1/pool.py index 823681fbc864..4ef5aee9baab 100644 --- a/spanner/google/cloud/spanner_v1/pool.py +++ b/spanner/google/cloud/spanner_v1/pool.py @@ -17,9 +17,9 @@ import datetime from six.moves import queue -from six.moves import xrange from google.cloud.exceptions import NotFound +from google.cloud.spanner_v1._helpers import _metadata_with_prefix _NOW = datetime.datetime.utcnow # unit tests may replace @@ -166,11 +166,20 @@ def bind(self, database): when needed. """ self._database = database + api = database.spanner_api + metadata = _metadata_with_prefix(database.name) while not self._sessions.full(): - session = self._new_session() - session.create() - self._sessions.put(session) + resp = api.batch_create_sessions( + database.name, + self.size - self._sessions.qsize(), + timeout=self.default_timeout, + metadata=metadata, + ) + for session_pb in resp.session: + session = self._new_session() + session._session_id = session_pb.name.split("/")[-1] + self._sessions.put(session) def get(self, timeout=None): # pylint: disable=arguments-differ """Check a session out from the pool. @@ -350,11 +359,22 @@ def bind(self, database): when needed. """ self._database = database - - for _ in xrange(self.size): - session = self._new_session() - session.create() - self.put(session) + api = database.spanner_api + metadata = _metadata_with_prefix(database.name) + created_session_count = 0 + + while created_session_count < self.size: + resp = api.batch_create_sessions( + database.name, + self.size - created_session_count, + timeout=self.default_timeout, + metadata=metadata, + ) + for session_pb in resp.session: + session = self._new_session() + session._session_id = session_pb.name.split("/")[-1] + self.put(session) + created_session_count += len(resp.session) def get(self, timeout=None): # pylint: disable=arguments-differ """Check a session out from the pool. diff --git a/spanner/tests/unit/test_pool.py b/spanner/tests/unit/test_pool.py index 549044b1f423..eded02ea4e6d 100644 --- a/spanner/tests/unit/test_pool.py +++ b/spanner/tests/unit/test_pool.py @@ -156,8 +156,10 @@ def test_bind(self): self.assertEqual(pool.default_timeout, 10) self.assertTrue(pool._sessions.full()) + api = database.spanner_api + self.assertEqual(api.batch_create_sessions.call_count, 5) for session in SESSIONS: - self.assertTrue(session._created) + session.create.assert_not_called() def test_get_non_expired(self): pool = self._make_one(size=4) @@ -183,7 +185,7 @@ def test_get_expired(self): session = pool.get() self.assertIs(session, SESSIONS[4]) - self.assertTrue(session._created) + session.create.assert_called() self.assertTrue(SESSIONS[0]._exists_checked) self.assertFalse(pool._sessions.full()) @@ -243,8 +245,10 @@ def test_clear(self): pool.bind(database) self.assertTrue(pool._sessions.full()) + api = database.spanner_api + self.assertEqual(api.batch_create_sessions.call_count, 5) for session in SESSIONS: - self.assertTrue(session._created) + session.create.assert_not_called() pool.clear() @@ -286,7 +290,7 @@ def test_get_empty(self): self.assertIsInstance(session, _Session) self.assertIs(session._database, database) - self.assertTrue(session._created) + session.create.assert_called() self.assertTrue(pool._sessions.empty()) def test_get_non_empty_session_exists(self): @@ -299,7 +303,7 @@ def test_get_non_empty_session_exists(self): session = pool.get() self.assertIs(session, previous) - self.assertFalse(session._created) + session.create.assert_not_called() self.assertTrue(session._exists_checked) self.assertTrue(pool._sessions.empty()) @@ -316,7 +320,7 @@ def test_get_non_empty_session_expired(self): self.assertTrue(previous._exists_checked) self.assertIs(session, newborn) - self.assertTrue(session._created) + session.create.assert_called() self.assertFalse(session._exists_checked) self.assertTrue(pool._sessions.empty()) @@ -405,7 +409,6 @@ def test_bind(self): database = _Database("name") SESSIONS = [_Session(database)] * 10 database._sessions.extend(SESSIONS) - pool.bind(database) self.assertIs(pool._database, database) @@ -414,8 +417,10 @@ def test_bind(self): self.assertEqual(pool._delta.seconds, 3000) self.assertTrue(pool._sessions.full()) + api = database.spanner_api + self.assertEqual(api.batch_create_sessions.call_count, 5) for session in SESSIONS: - self.assertTrue(session._created) + session.create.assert_not_called() def test_get_hit_no_ping(self): pool = self._make_one(size=4) @@ -470,7 +475,7 @@ def test_get_hit_w_ping_expired(self): session = pool.get() self.assertIs(session, SESSIONS[4]) - self.assertTrue(session._created) + session.create.assert_called() self.assertTrue(SESSIONS[0]._exists_checked) self.assertFalse(pool._sessions.full()) @@ -538,8 +543,10 @@ def test_clear(self): pool.bind(database) self.assertTrue(pool._sessions.full()) + api = database.spanner_api + self.assertEqual(api.batch_create_sessions.call_count, 5) for session in SESSIONS: - self.assertTrue(session._created) + session.create.assert_not_called() pool.clear() @@ -595,7 +602,7 @@ def test_ping_oldest_stale_and_not_exists(self): pool.ping() self.assertTrue(SESSIONS[0]._exists_checked) - self.assertTrue(SESSIONS[1]._created) + SESSIONS[1].create.assert_called() class TestTransactionPingingPool(unittest.TestCase): @@ -635,7 +642,6 @@ def test_bind(self): database = _Database("name") SESSIONS = [_Session(database) for _ in range(10)] database._sessions.extend(SESSIONS) - pool.bind(database) self.assertIs(pool._database, database) @@ -644,8 +650,10 @@ def test_bind(self): self.assertEqual(pool._delta.seconds, 3000) self.assertTrue(pool._sessions.full()) + api = database.spanner_api + self.assertEqual(api.batch_create_sessions.call_count, 5) for session in SESSIONS: - self.assertTrue(session._created) + session.create.assert_not_called() txn = session._transaction self.assertTrue(txn._begun) @@ -671,8 +679,10 @@ def test_bind_w_timestamp_race(self): self.assertEqual(pool._delta.seconds, 3000) self.assertTrue(pool._sessions.full()) + api = database.spanner_api + self.assertEqual(api.batch_create_sessions.call_count, 5) for session in SESSIONS: - self.assertTrue(session._created) + session.create.assert_not_called() txn = session._transaction self.assertTrue(txn._begun) @@ -843,16 +853,13 @@ def __init__(self, database, exists=True, transaction=None): self._database = database self._exists = exists self._exists_checked = False - self._created = False + self.create = mock.Mock() self._deleted = False self._transaction = transaction def __lt__(self, other): return id(self) < id(other) - def create(self): - self._created = True - def exists(self): self._exists_checked = True return self._exists @@ -874,6 +881,22 @@ def __init__(self, name): self.name = name self._sessions = [] + def mock_batch_create_sessions(db, session_count=10, timeout=10, metadata=[]): + from google.cloud.spanner_v1.proto import spanner_pb2 + + response = spanner_pb2.BatchCreateSessionsResponse() + if session_count < 2: + response.session.add() + else: + response.session.add() + response.session.add() + return response + + from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient + + self.spanner_api = mock.create_autospec(SpannerClient, instance=True) + self.spanner_api.batch_create_sessions.side_effect = mock_batch_create_sessions + def session(self): return self._sessions.pop()