diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index e09b294dff..84cb2dc7a5 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -151,7 +151,7 @@ # DDL statements follow # https://cloud.google.com/spanner/docs/data-definition-language -RE_DDL = re.compile(r"^\s*(CREATE|ALTER|DROP)", re.IGNORECASE | re.DOTALL) +RE_DDL = re.compile(r"^\s*(CREATE|ALTER|DROP|GRANT|REVOKE)", re.IGNORECASE | re.DOTALL) RE_IS_INSERT = re.compile(r"^\s*(INSERT)", re.IGNORECASE | re.DOTALL) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 7d2384beed..0d27763432 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -27,9 +27,12 @@ from google.cloud.exceptions import NotFound from google.api_core.exceptions import Aborted from google.api_core import gapic_v1 +from google.iam.v1 import iam_policy_pb2 +from google.iam.v1 import options_pb2 from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest from google.cloud.spanner_admin_database_v1 import Database as DatabasePB +from google.cloud.spanner_admin_database_v1 import ListDatabaseRolesRequest from google.cloud.spanner_admin_database_v1 import EncryptionConfig from google.cloud.spanner_admin_database_v1 import RestoreDatabaseEncryptionConfig from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest @@ -119,7 +122,8 @@ class Database(object): :class:`~google.cloud.spanner_admin_database_v1.types.DatabaseDialect` :param database_dialect: (Optional) database dialect for the database - + :type database_role: str or None + :param database_role: (Optional) user-assigned database_role for the session. """ _spanner_api = None @@ -133,6 +137,7 @@ def __init__( logger=None, encryption_config=None, database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED, + database_role=None, ): self.database_id = database_id self._instance = instance @@ -149,9 +154,10 @@ def __init__( self._logger = logger self._encryption_config = encryption_config self._database_dialect = database_dialect + self._database_role = database_role if pool is None: - pool = BurstyPool() + pool = BurstyPool(database_role=database_role) self._pool = pool pool.bind(self) @@ -314,6 +320,14 @@ def database_dialect(self): """ return self._database_dialect + @property + def database_role(self): + """User-assigned database_role for sessions created by the pool. + :rtype: str + :returns: a str with the name of the database role. + """ + return self._database_role + @property def logger(self): """Logger used by the database. @@ -584,16 +598,22 @@ def execute_pdml(): return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)() - def session(self, labels=None): + def session(self, labels=None, database_role=None): """Factory to create a session for this database. :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for the session. + :type database_role: str + :param database_role: (Optional) user-assigned database_role for the session. + :rtype: :class:`~google.cloud.spanner_v1.session.Session` :returns: a session bound to this database. """ - return Session(self, labels=labels) + # If role is specified in param, then that role is used + # instead. + role = database_role or self._database_role + return Session(self, labels=labels, database_role=role) def snapshot(self, **kw): """Return an object which wraps a snapshot. @@ -772,6 +792,29 @@ def list_database_operations(self, filter_="", page_size=None): filter_=database_filter, page_size=page_size ) + def list_database_roles(self, page_size=None): + """Lists Cloud Spanner database roles. + + :type page_size: int + :param page_size: + Optional. The maximum number of database roles in each page of results + from this request. Non-positive values are ignored. Defaults to a + sensible value set by the API. + + :type: Iterable + :returns: + Iterable of :class:`~google.cloud.spanner_admin_database_v1.types.spanner_database_admin.DatabaseRole` + resources within the current database. + """ + api = self._instance._client.database_admin_api + metadata = _metadata_with_prefix(self.name) + + request = ListDatabaseRolesRequest( + parent=self.name, + page_size=page_size, + ) + return api.list_database_roles(request=request, metadata=metadata) + def table(self, table_id): """Factory to create a table object within this database. @@ -811,6 +854,54 @@ def list_tables(self): for row in results: yield self.table(row[0]) + def get_iam_policy(self, policy_version=None): + """Gets the access control policy for a database resource. + + :type policy_version: int + :param policy_version: + (Optional) the maximum policy version that will be + used to format the policy. Valid values are 0, 1 ,3. + + :rtype: :class:`~google.iam.v1.policy_pb2.Policy` + :returns: + returns an Identity and Access Management (IAM) policy. It is used to + specify access control policies for Cloud Platform + resources. + """ + api = self._instance._client.database_admin_api + metadata = _metadata_with_prefix(self.name) + + request = iam_policy_pb2.GetIamPolicyRequest( + resource=self.name, + options=options_pb2.GetPolicyOptions( + requested_policy_version=policy_version + ), + ) + response = api.get_iam_policy(request=request, metadata=metadata) + return response + + def set_iam_policy(self, policy): + """Sets the access control policy on a database resource. + Replaces any existing policy. + + :type policy: :class:`~google.iam.v1.policy_pb2.Policy` + :param policy_version: + the complete policy to be applied to the resource. + + :rtype: :class:`~google.iam.v1.policy_pb2.Policy` + :returns: + returns the new Identity and Access Management (IAM) policy. + """ + api = self._instance._client.database_admin_api + metadata = _metadata_with_prefix(self.name) + + request = iam_policy_pb2.SetIamPolicyRequest( + resource=self.name, + policy=policy, + ) + response = api.set_iam_policy(request=request, metadata=metadata) + return response + class BatchCheckout(object): """Context manager for using a batch from a database. diff --git a/google/cloud/spanner_v1/instance.py b/google/cloud/spanner_v1/instance.py index 6a9517a0e8..f972f817b3 100644 --- a/google/cloud/spanner_v1/instance.py +++ b/google/cloud/spanner_v1/instance.py @@ -431,6 +431,7 @@ def database( logger=None, encryption_config=None, database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED, + database_role=None, ): """Factory to create a database within this instance. @@ -477,6 +478,7 @@ def database( logger=logger, encryption_config=encryption_config, database_dialect=database_dialect, + database_role=database_role, ) def list_databases(self, page_size=None): diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 56a78ef672..216ba5aeff 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -18,6 +18,8 @@ import queue from google.cloud.exceptions import NotFound +from google.cloud.spanner_v1 import BatchCreateSessionsRequest +from google.cloud.spanner_v1 import Session from google.cloud.spanner_v1._helpers import _metadata_with_prefix @@ -30,14 +32,18 @@ class AbstractSessionPool(object): :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for sessions created by the pool. + + :type database_role: str + :param database_role: (Optional) user-assigned database_role for the session. """ _database = None - def __init__(self, labels=None): + def __init__(self, labels=None, database_role=None): if labels is None: labels = {} self._labels = labels + self._database_role = database_role @property def labels(self): @@ -48,6 +54,15 @@ def labels(self): """ return self._labels + @property + def database_role(self): + """User-assigned database_role for sessions created by the pool. + + :rtype: str + :returns: database_role assigned by the user + """ + return self._database_role + def bind(self, database): """Associate the pool with a database. @@ -104,9 +119,9 @@ def _new_session(self): :rtype: :class:`~google.cloud.spanner_v1.session.Session` :returns: new session instance. """ - if self.labels: - return self._database.session(labels=self.labels) - return self._database.session() + return self._database.session( + labels=self.labels, database_role=self.database_role + ) def session(self, **kwargs): """Check out a session from the pool. @@ -146,13 +161,22 @@ class FixedSizePool(AbstractSessionPool): :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for sessions created by the pool. + + :type database_role: str + :param database_role: (Optional) user-assigned database_role for the session. """ DEFAULT_SIZE = 10 DEFAULT_TIMEOUT = 10 - def __init__(self, size=DEFAULT_SIZE, default_timeout=DEFAULT_TIMEOUT, labels=None): - super(FixedSizePool, self).__init__(labels=labels) + def __init__( + self, + size=DEFAULT_SIZE, + default_timeout=DEFAULT_TIMEOUT, + labels=None, + database_role=None, + ): + super(FixedSizePool, self).__init__(labels=labels, database_role=database_role) self.size = size self.default_timeout = default_timeout self._sessions = queue.LifoQueue(size) @@ -167,9 +191,14 @@ def bind(self, database): self._database = database api = database.spanner_api metadata = _metadata_with_prefix(database.name) + self._database_role = self._database_role or self._database.database_role + request = BatchCreateSessionsRequest( + session_template=Session(creator_role=self.database_role), + ) while not self._sessions.full(): resp = api.batch_create_sessions( + request=request, database=database.name, session_count=self.size - self._sessions.qsize(), metadata=metadata, @@ -243,10 +272,13 @@ class BurstyPool(AbstractSessionPool): :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for sessions created by the pool. + + :type database_role: str + :param database_role: (Optional) user-assigned database_role for the session. """ - def __init__(self, target_size=10, labels=None): - super(BurstyPool, self).__init__(labels=labels) + def __init__(self, target_size=10, labels=None, database_role=None): + super(BurstyPool, self).__init__(labels=labels, database_role=database_role) self.target_size = target_size self._database = None self._sessions = queue.LifoQueue(target_size) @@ -259,6 +291,7 @@ def bind(self, database): when needed. """ self._database = database + self._database_role = self._database_role or self._database.database_role def get(self): """Check a session out from the pool. @@ -340,10 +373,20 @@ class PingingPool(AbstractSessionPool): :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for sessions created by the pool. + + :type database_role: str + :param database_role: (Optional) user-assigned database_role for the session. """ - def __init__(self, size=10, default_timeout=10, ping_interval=3000, labels=None): - super(PingingPool, self).__init__(labels=labels) + def __init__( + self, + size=10, + default_timeout=10, + ping_interval=3000, + labels=None, + database_role=None, + ): + super(PingingPool, self).__init__(labels=labels, database_role=database_role) self.size = size self.default_timeout = default_timeout self._delta = datetime.timedelta(seconds=ping_interval) @@ -360,9 +403,15 @@ def bind(self, database): api = database.spanner_api metadata = _metadata_with_prefix(database.name) created_session_count = 0 + self._database_role = self._database_role or self._database.database_role + + request = BatchCreateSessionsRequest( + session_template=Session(creator_role=self.database_role), + ) while created_session_count < self.size: resp = api.batch_create_sessions( + request=request, database=database.name, session_count=self.size - created_session_count, metadata=metadata, @@ -470,13 +519,27 @@ class TransactionPingingPool(PingingPool): :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for sessions created by the pool. + + :type database_role: str + :param database_role: (Optional) user-assigned database_role for the session. """ - def __init__(self, size=10, default_timeout=10, ping_interval=3000, labels=None): + def __init__( + self, + size=10, + default_timeout=10, + ping_interval=3000, + labels=None, + database_role=None, + ): self._pending_sessions = queue.Queue() super(TransactionPingingPool, self).__init__( - size, default_timeout, ping_interval, labels=labels + size, + default_timeout, + ping_interval, + labels=labels, + database_role=database_role, ) self.begin_pending_transactions() @@ -489,6 +552,7 @@ def bind(self, database): when needed. """ super(TransactionPingingPool, self).bind(database) + self._database_role = self._database_role or self._database.database_role self.begin_pending_transactions() def put(self, session): diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 1ab6a93626..c210f8f61d 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -52,16 +52,20 @@ class Session(object): :type labels: dict (str -> str) :param labels: (Optional) User-assigned labels for the session. + + :type database_role: str + :param database_role: (Optional) user-assigned database_role for the session. """ _session_id = None _transaction = None - def __init__(self, database, labels=None): + def __init__(self, database, labels=None, database_role=None): self._database = database if labels is None: labels = {} self._labels = labels + self._database_role = database_role def __lt__(self, other): return self._session_id < other._session_id @@ -71,6 +75,14 @@ def session_id(self): """Read-only ID, set by the back-end during :meth:`create`.""" return self._session_id + @property + def database_role(self): + """User-assigned database-role for the session. + + :rtype: str + :returns: the database role str (None if no database role were assigned).""" + return self._database_role + @property def labels(self): """User-assigned labels for the session. @@ -115,6 +127,8 @@ def create(self): metadata = _metadata_with_prefix(self._database.name) request = CreateSessionRequest(database=self._database.name) + if self._database.database_role is not None: + request.session.creator_role = self._database.database_role if self._labels: request.session.labels = self._labels diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index 35f348939e..ad138b3a1c 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -31,6 +31,8 @@ from google.cloud import spanner from google.cloud.spanner_admin_instance_v1.types import spanner_instance_admin from google.cloud.spanner_v1 import param_types +from google.type import expr_pb2 +from google.iam.v1 import policy_pb2 from google.cloud.spanner_v1.data_types import JsonObject from google.protobuf import field_mask_pb2 # type: ignore OPERATION_TIMEOUT_SECONDS = 240 @@ -2310,6 +2312,122 @@ def list_instance_config_operations(): # [END spanner_list_instance_config_operations] +def add_and_drop_database_roles(instance_id, database_id): + """Showcases how to manage a user defined database role.""" + # [START spanner_add_and_drop_database_roles] + # instance_id = "your-spanner-instance" + # database_id = "your-spanner-db-id" + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + role_parent = "new_parent" + role_child = "new_child" + + operation = database.update_ddl( + [ + "CREATE ROLE {}".format(role_parent), + "GRANT SELECT ON TABLE Singers TO ROLE {}".format(role_parent), + "CREATE ROLE {}".format(role_child), + "GRANT ROLE {} TO ROLE {}".format(role_parent, role_child), + ] + ) + operation.result(OPERATION_TIMEOUT_SECONDS) + print( + "Created roles {} and {} and granted privileges".format(role_parent, role_child) + ) + + operation = database.update_ddl( + [ + "REVOKE ROLE {} FROM ROLE {}".format(role_parent, role_child), + "DROP ROLE {}".format(role_child), + ] + ) + operation.result(OPERATION_TIMEOUT_SECONDS) + print("Revoked privileges and dropped role {}".format(role_child)) + + # [END spanner_add_and_drop_database_roles] + + +def read_data_with_database_role(instance_id, database_id): + """Showcases how a user defined database role is used by member.""" + # [START spanner_read_data_with_database_role] + # instance_id = "your-spanner-instance" + # database_id = "your-spanner-db-id" + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + role = "new_parent" + database = instance.database(database_id, database_role=role) + + with database.snapshot() as snapshot: + results = snapshot.execute_sql("SELECT * FROM Singers") + for row in results: + print("SingerId: {}, FirstName: {}, LastName: {}".format(*row)) + + # [END spanner_read_data_with_database_role] + + +def list_database_roles(instance_id, database_id): + """Showcases how to list Database Roles.""" + # [START spanner_list_database_roles] + # instance_id = "your-spanner-instance" + # database_id = "your-spanner-db-id" + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + # List database roles. + print("Database Roles are:") + for role in database.list_database_roles(): + print(role.name.split("/")[-1]) + # [END spanner_list_database_roles] + + +def enable_fine_grained_access( + instance_id, + database_id, + iam_member="user:alice@example.com", + database_role="new_parent", + title="condition title", +): + """Showcases how to enable fine grained access control.""" + # [START spanner_enable_fine_grained_access] + # instance_id = "your-spanner-instance" + # database_id = "your-spanner-db-id" + # iam_member = "user:alice@example.com" + # database_role = "new_parent" + # title = "condition title" + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + # The policy in the response from getDatabaseIAMPolicy might use the policy version + # that you specified, or it might use a lower policy version. For example, if you + # specify version 3, but the policy has no conditional role bindings, the response + # uses version 1. Valid values are 0, 1, and 3. + policy = database.get_iam_policy(3) + if policy.version < 3: + policy.version = 3 + + new_binding = policy_pb2.Binding( + role="roles/spanner.fineGrainedAccessUser", + members=[iam_member], + condition=expr_pb2.Expr( + title=title, + expression=f'resource.name.endsWith("/databaseRoles/{database_role}")', + ), + ) + + policy.version = 3 + policy.bindings.append(new_binding) + database.set_iam_policy(policy) + + new_policy = database.get_iam_policy(3) + print( + f"Enabled fine-grained access in IAM. New policy has version {new_policy.version}" + ) + # [END spanner_enable_fine_grained_access] + + if __name__ == "__main__": # noqa: C901 parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter @@ -2419,6 +2537,23 @@ def list_instance_config_operations(): "create_client_with_query_options", help=create_client_with_query_options.__doc__, ) + subparsers.add_parser( + "add_and_drop_database_roles", help=add_and_drop_database_roles.__doc__ + ) + subparsers.add_parser( + "read_data_with_database_role", help=read_data_with_database_role.__doc__ + ) + subparsers.add_parser("list_database_roles", help=list_database_roles.__doc__) + enable_fine_grained_access_parser = subparsers.add_parser( + "enable_fine_grained_access", help=enable_fine_grained_access.__doc__ + ) + enable_fine_grained_access_parser.add_argument( + "--iam_member", default="user:alice@example.com" + ) + enable_fine_grained_access_parser.add_argument( + "--database_role", default="new_parent" + ) + enable_fine_grained_access_parser.add_argument("--title", default="condition title") args = parser.parse_args() @@ -2534,3 +2669,17 @@ def list_instance_config_operations(): query_data_with_query_options(args.instance_id, args.database_id) elif args.command == "create_client_with_query_options": create_client_with_query_options(args.instance_id, args.database_id) + elif args.command == "add_and_drop_database_roles": + add_and_drop_database_roles(args.instance_id, args.database_id) + elif args.command == "read_data_with_database_role": + read_data_with_database_role(args.instance_id, args.database_id) + elif args.command == "list_database_roles": + list_database_roles(args.instance_id, args.database_id) + elif args.command == "enable_fine_grained_access": + enable_fine_grained_access( + args.instance_id, + args.database_id, + args.iam_member, + args.database_role, + args.title, + ) diff --git a/samples/samples/snippets_test.py b/samples/samples/snippets_test.py index 05cfedfdde..6d5822e37b 100644 --- a/samples/samples/snippets_test.py +++ b/samples/samples/snippets_test.py @@ -759,3 +759,25 @@ def test_set_request_tag(capsys, instance_id, sample_database): snippets.set_request_tag(instance_id, sample_database.database_id) out, _ = capsys.readouterr() assert "SingerId: 1, AlbumId: 1, AlbumTitle: Total Junk" in out + + +@pytest.mark.dependency(name="add_and_drop_database_roles", depends=["insert_data"]) +def test_add_and_drop_database_roles(capsys, instance_id, sample_database): + snippets.add_and_drop_database_roles(instance_id, sample_database.database_id) + out, _ = capsys.readouterr() + assert "Created roles new_parent and new_child and granted privileges" in out + assert "Revoked privileges and dropped role new_child" in out + + +@pytest.mark.dependency(depends=["add_and_drop_database_roles"]) +def test_read_data_with_database_role(capsys, instance_id, sample_database): + snippets.read_data_with_database_role(instance_id, sample_database.database_id) + out, _ = capsys.readouterr() + assert "ingerId: 1, FirstName: Marc, LastName: Richards" in out + + +@pytest.mark.dependency(depends=["add_and_drop_database_roles"]) +def test_list_database_roles(capsys, instance_id, sample_database): + snippets.list_database_roles(instance_id, sample_database.database_id) + out, _ = capsys.readouterr() + assert "new_parent" in out diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index e9e6c69287..9fac10ed4d 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -18,7 +18,9 @@ import pytest from google.api_core import exceptions +from google.iam.v1 import policy_pb2 from google.cloud import spanner_v1 +from google.type import expr_pb2 from . import _helpers from . import _sample_data @@ -164,6 +166,53 @@ def test_create_database_with_default_leader_success( assert result[0] == default_leader +def test_iam_policy( + not_emulator, + shared_instance, + databases_to_delete, + not_postgres, +): + pool = spanner_v1.BurstyPool(labels={"testcase": "iam_policy"}) + temp_db_id = _helpers.unique_id("iam_db", separator="_") + create_table = ( + "CREATE TABLE policy (\n" + + " Id STRING(36) NOT NULL,\n" + + " Field1 STRING(36) NOT NULL\n" + + ") PRIMARY KEY (Id)" + ) + create_role = "CREATE ROLE parent" + + temp_db = shared_instance.database( + temp_db_id, + ddl_statements=[create_table, create_role], + pool=pool, + ) + create_op = temp_db.create() + databases_to_delete.append(temp_db) + create_op.result(DBAPI_OPERATION_TIMEOUT) + policy = temp_db.get_iam_policy(3) + + assert policy.version == 0 + assert policy.etag == b"\x00 \x01" + + new_binding = policy_pb2.Binding( + role="roles/spanner.fineGrainedAccessUser", + members=["user:asthamohta@google.com"], + condition=expr_pb2.Expr( + title="condition title", + expression='resource.name.endsWith("/databaseRoles/parent")', + ), + ) + + policy.version = 3 + policy.bindings.append(new_binding) + temp_db.set_iam_policy(policy) + + new_policy = temp_db.get_iam_policy(3) + assert new_policy.version == 3 + assert new_policy.bindings == [new_binding] + + def test_table_not_found(shared_instance): temp_db_id = _helpers.unique_id("tbl_not_found", separator="_") @@ -301,6 +350,87 @@ def test_update_ddl_w_default_leader_success( assert len(temp_db.ddl_statements) == len(ddl_statements) +def test_create_role_grant_access_success( + not_emulator, + shared_instance, + databases_to_delete, + not_postgres, +): + creator_role_parent = _helpers.unique_id("role_parent", separator="_") + creator_role_orphan = _helpers.unique_id("role_orphan", separator="_") + + temp_db_id = _helpers.unique_id("dfl_ldrr_upd_ddl", separator="_") + temp_db = shared_instance.database(temp_db_id) + + create_op = temp_db.create() + databases_to_delete.append(temp_db) + create_op.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + # Create role and grant select permission on table contacts for parent role. + ddl_statements = _helpers.DDL_STATEMENTS + [ + f"CREATE ROLE {creator_role_parent}", + f"CREATE ROLE {creator_role_orphan}", + f"GRANT SELECT ON TABLE contacts TO ROLE {creator_role_parent}", + ] + operation = temp_db.update_ddl(ddl_statements) + operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + # Perform select with orphan role on table contacts. + # Expect PermissionDenied exception. + temp_db = shared_instance.database(temp_db_id, database_role=creator_role_orphan) + with pytest.raises(exceptions.PermissionDenied): + with temp_db.snapshot() as snapshot: + results = snapshot.execute_sql("SELECT * FROM contacts") + for row in results: + pass + + # Perform select with parent role on table contacts. Expect success. + temp_db = shared_instance.database(temp_db_id, database_role=creator_role_parent) + with temp_db.snapshot() as snapshot: + snapshot.execute_sql("SELECT * FROM contacts") + + ddl_remove_roles = [ + f"REVOKE SELECT ON TABLE contacts FROM ROLE {creator_role_parent}", + f"DROP ROLE {creator_role_parent}", + f"DROP ROLE {creator_role_orphan}", + ] + # Revoke permission and Delete roles. + operation = temp_db.update_ddl(ddl_remove_roles) + operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + +def test_list_database_role_success( + not_emulator, + shared_instance, + databases_to_delete, + not_postgres, +): + creator_role_parent = _helpers.unique_id("role_parent", separator="_") + creator_role_orphan = _helpers.unique_id("role_orphan", separator="_") + + temp_db_id = _helpers.unique_id("dfl_ldrr_upd_ddl", separator="_") + temp_db = shared_instance.database(temp_db_id) + + create_op = temp_db.create() + databases_to_delete.append(temp_db) + create_op.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + # Create role and grant select permission on table contacts for parent role. + ddl_statements = _helpers.DDL_STATEMENTS + [ + f"CREATE ROLE {creator_role_parent}", + f"CREATE ROLE {creator_role_orphan}", + ] + operation = temp_db.update_ddl(ddl_statements) + operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + # List database roles. + roles_list = [] + for role in temp_db.list_database_roles(): + roles_list.append(role.name.split("/")[-1]) + assert creator_role_parent in roles_list + assert creator_role_orphan in roles_list + + def test_db_batch_insert_then_db_snapshot_read(shared_database): _helpers.retry_has_all_dll(shared_database.reload)() sd = _sample_data diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 511ad838cf..ddd1d5572a 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -54,6 +54,10 @@ def test_classify_stmt(self): "CREATE INDEX AlbumsByAlbumTitle2 ON Albums(AlbumTitle) STORING (MarketingBudget)", STMT_DDL, ), + ("CREATE ROLE parent", STMT_DDL), + ("GRANT SELECT ON TABLE Singers TO ROLE parent", STMT_DDL), + ("REVOKE SELECT ON TABLE Singers TO ROLE parent", STMT_DDL), + ("GRANT ROLE parent TO ROLE child", STMT_DDL), ("INSERT INTO table (col1) VALUES (1)", STMT_INSERT), ("UPDATE table SET col1 = 1 WHERE col1 = NULL", STMT_UPDATING), ) diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index bd47a2ac31..bff89320c7 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -61,6 +61,7 @@ class _BaseTest(unittest.TestCase): BACKUP_ID = "backup_id" BACKUP_NAME = INSTANCE_NAME + "/backups/" + BACKUP_ID TRANSACTION_TAG = "transaction-tag" + DATABASE_ROLE = "dummy-role" def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) @@ -112,6 +113,7 @@ def test_ctor_defaults(self): self.assertIsNone(database._logger) # BurstyPool does not create sessions during 'bind()'. self.assertTrue(database._pool._sessions.empty()) + self.assertIsNone(database.database_role) def test_ctor_w_explicit_pool(self): instance = _Instance(self.INSTANCE_NAME) @@ -123,6 +125,15 @@ def test_ctor_w_explicit_pool(self): self.assertIs(database._pool, pool) self.assertIs(pool._bound, database) + def test_ctor_w_database_role(self): + instance = _Instance(self.INSTANCE_NAME) + database = self._make_one( + self.DATABASE_ID, instance, database_role=self.DATABASE_ROLE + ) + self.assertEqual(database.database_id, self.DATABASE_ID) + self.assertIs(database._instance, instance) + self.assertIs(database.database_role, self.DATABASE_ROLE) + def test_ctor_w_ddl_statements_non_string(self): with self.assertRaises(ValueError): @@ -1527,6 +1538,51 @@ def test_list_database_operations_explicit_filter(self): filter_=expected_filter_, page_size=page_size ) + def test_list_database_roles_grpc_error(self): + from google.api_core.exceptions import Unknown + from google.cloud.spanner_admin_database_v1 import ListDatabaseRolesRequest + + client = _Client() + api = client.database_admin_api = self._make_database_admin_api() + api.list_database_roles.side_effect = Unknown("testing") + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + with self.assertRaises(Unknown): + database.list_database_roles() + + expected_request = ListDatabaseRolesRequest( + parent=database.name, + ) + + api.list_database_roles.assert_called_once_with( + request=expected_request, + metadata=[("google-cloud-resource-prefix", database.name)], + ) + + def test_list_database_roles_defaults(self): + from google.cloud.spanner_admin_database_v1 import ListDatabaseRolesRequest + + client = _Client() + api = client.database_admin_api = self._make_database_admin_api() + instance = _Instance(self.INSTANCE_NAME, client=client) + instance.list_database_roles = mock.MagicMock(return_value=[]) + pool = _Pool() + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + resp = database.list_database_roles() + + expected_request = ListDatabaseRolesRequest( + parent=database.name, + ) + + api.list_database_roles.assert_called_once_with( + request=expected_request, + metadata=[("google-cloud-resource-prefix", database.name)], + ) + self.assertIsNotNone(resp) + def test_table_factory_defaults(self): from google.cloud.spanner_v1.table import Table diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index c715fb2ee1..e0a0f663cf 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -13,7 +13,6 @@ # limitations under the License. import unittest - import mock @@ -544,6 +543,7 @@ def test_database_factory_defaults(self): self.assertIsNone(database._logger) pool = database._pool self.assertIs(pool._database, database) + self.assertIsNone(database.database_role) def test_database_factory_explicit(self): from logging import Logger @@ -553,6 +553,7 @@ def test_database_factory_explicit(self): client = _Client(self.PROJECT) instance = self._make_one(self.INSTANCE_ID, client, self.CONFIG_NAME) DATABASE_ID = "database-id" + DATABASE_ROLE = "dummy-role" pool = _Pool() logger = mock.create_autospec(Logger, instance=True) encryption_config = {"kms_key_name": "kms_key_name"} @@ -563,6 +564,7 @@ def test_database_factory_explicit(self): pool=pool, logger=logger, encryption_config=encryption_config, + database_role=DATABASE_ROLE, ) self.assertIsInstance(database, Database) @@ -573,6 +575,7 @@ def test_database_factory_explicit(self): self.assertIs(database._logger, logger) self.assertIs(pool._bound, database) self.assertIs(database._encryption_config, encryption_config) + self.assertIs(database.database_role, DATABASE_ROLE) def test_list_databases(self): from google.cloud.spanner_admin_database_v1 import Database as DatabasePB diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index 593420187d..1a53aa1604 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -44,12 +44,15 @@ def test_ctor_defaults(self): pool = self._make_one() self.assertIsNone(pool._database) self.assertEqual(pool.labels, {}) + self.assertIsNone(pool.database_role) def test_ctor_explicit(self): labels = {"foo": "bar"} - pool = self._make_one(labels=labels) + database_role = "dummy-role" + pool = self._make_one(labels=labels, database_role=database_role) self.assertIsNone(pool._database) self.assertEqual(pool.labels, labels) + self.assertEqual(pool.database_role, database_role) def test_bind_abstract(self): pool = self._make_one() @@ -82,7 +85,7 @@ def test__new_session_wo_labels(self): new_session = pool._new_session() self.assertIs(new_session, session) - database.session.assert_called_once_with() + database.session.assert_called_once_with(labels={}, database_role=None) def test__new_session_w_labels(self): labels = {"foo": "bar"} @@ -94,7 +97,19 @@ def test__new_session_w_labels(self): new_session = pool._new_session() self.assertIs(new_session, session) - database.session.assert_called_once_with(labels=labels) + database.session.assert_called_once_with(labels=labels, database_role=None) + + def test__new_session_w_database_role(self): + database_role = "dummy-role" + pool = self._make_one(database_role=database_role) + database = pool._database = _make_database("name") + session = _make_session() + database.session.return_value = session + + new_session = pool._new_session() + + self.assertIs(new_session, session) + database.session.assert_called_once_with(labels={}, database_role=database_role) def test_session_wo_kwargs(self): from google.cloud.spanner_v1.pool import SessionCheckout @@ -133,26 +148,34 @@ def test_ctor_defaults(self): self.assertEqual(pool.default_timeout, 10) self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, {}) + self.assertIsNone(pool.database_role) def test_ctor_explicit(self): labels = {"foo": "bar"} - pool = self._make_one(size=4, default_timeout=30, labels=labels) + database_role = "dummy-role" + pool = self._make_one( + size=4, default_timeout=30, labels=labels, database_role=database_role + ) self.assertIsNone(pool._database) self.assertEqual(pool.size, 4) self.assertEqual(pool.default_timeout, 30) self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, labels) + self.assertEqual(pool.database_role, database_role) def test_bind(self): + database_role = "dummy-role" pool = self._make_one() database = _Database("name") SESSIONS = [_Session(database)] * 10 + database._database_role = database_role database._sessions.extend(SESSIONS) pool.bind(database) self.assertIs(pool._database, database) self.assertEqual(pool.size, 10) + self.assertEqual(pool.database_role, database_role) self.assertEqual(pool.default_timeout, 10) self.assertTrue(pool._sessions.full()) @@ -272,14 +295,25 @@ def test_ctor_defaults(self): self.assertEqual(pool.target_size, 10) self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, {}) + self.assertIsNone(pool.database_role) def test_ctor_explicit(self): labels = {"foo": "bar"} - pool = self._make_one(target_size=4, labels=labels) + database_role = "dummy-role" + pool = self._make_one(target_size=4, labels=labels, database_role=database_role) self.assertIsNone(pool._database) self.assertEqual(pool.target_size, 4) self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, labels) + self.assertEqual(pool.database_role, database_role) + + def test_ctor_explicit_w_database_role_in_db(self): + database_role = "dummy-role" + pool = self._make_one() + database = pool._database = _Database("name") + database._database_role = database_role + pool.bind(database) + self.assertEqual(pool.database_role, database_role) def test_get_empty(self): pool = self._make_one() @@ -392,11 +426,17 @@ def test_ctor_defaults(self): self.assertEqual(pool._delta.seconds, 3000) self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, {}) + self.assertIsNone(pool.database_role) def test_ctor_explicit(self): labels = {"foo": "bar"} + database_role = "dummy-role" pool = self._make_one( - size=4, default_timeout=30, ping_interval=1800, labels=labels + size=4, + default_timeout=30, + ping_interval=1800, + labels=labels, + database_role=database_role, ) self.assertIsNone(pool._database) self.assertEqual(pool.size, 4) @@ -404,6 +444,17 @@ def test_ctor_explicit(self): self.assertEqual(pool._delta.seconds, 1800) self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, labels) + self.assertEqual(pool.database_role, database_role) + + def test_ctor_explicit_w_database_role_in_db(self): + database_role = "dummy-role" + pool = self._make_one() + database = pool._database = _Database("name") + SESSIONS = [_Session(database)] * 10 + database._sessions.extend(SESSIONS) + database._database_role = database_role + pool.bind(database) + self.assertEqual(pool.database_role, database_role) def test_bind(self): pool = self._make_one() @@ -624,11 +675,17 @@ def test_ctor_defaults(self): self.assertTrue(pool._sessions.empty()) self.assertTrue(pool._pending_sessions.empty()) self.assertEqual(pool.labels, {}) + self.assertIsNone(pool.database_role) def test_ctor_explicit(self): labels = {"foo": "bar"} + database_role = "dummy-role" pool = self._make_one( - size=4, default_timeout=30, ping_interval=1800, labels=labels + size=4, + default_timeout=30, + ping_interval=1800, + labels=labels, + database_role=database_role, ) self.assertIsNone(pool._database) self.assertEqual(pool.size, 4) @@ -637,6 +694,17 @@ def test_ctor_explicit(self): self.assertTrue(pool._sessions.empty()) self.assertTrue(pool._pending_sessions.empty()) self.assertEqual(pool.labels, labels) + self.assertEqual(pool.database_role, database_role) + + def test_ctor_explicit_w_database_role_in_db(self): + database_role = "dummy-role" + pool = self._make_one() + database = pool._database = _Database("name") + SESSIONS = [_Session(database)] * 10 + database._sessions.extend(SESSIONS) + database._database_role = database_role + pool.bind(database) + self.assertEqual(pool.database_role, database_role) def test_bind(self): pool = self._make_one() @@ -794,10 +862,12 @@ def test_ctor_wo_kwargs(self): def test_ctor_w_kwargs(self): pool = _Pool() - checkout = self._make_one(pool, foo="bar") + checkout = self._make_one(pool, foo="bar", database_role="dummy-role") self.assertIs(checkout._pool, pool) self.assertIsNone(checkout._session) - self.assertEqual(checkout._kwargs, {"foo": "bar"}) + self.assertEqual( + checkout._kwargs, {"foo": "bar", "database_role": "dummy-role"} + ) def test_context_manager_wo_kwargs(self): session = object() @@ -885,17 +955,31 @@ class _Database(object): def __init__(self, name): self.name = name self._sessions = [] + self._database_role = None def mock_batch_create_sessions( - database=None, session_count=10, timeout=10, metadata=[] + request=None, + database=None, + session_count=10, + timeout=10, + metadata=[], + labels={}, ): from google.cloud.spanner_v1 import BatchCreateSessionsResponse from google.cloud.spanner_v1 import Session + database_role = request.session_template.creator_role if request else None if session_count < 2: - response = BatchCreateSessionsResponse(session=[Session()]) + response = BatchCreateSessionsResponse( + session=[Session(creator_role=database_role, labels=labels)] + ) else: - response = BatchCreateSessionsResponse(session=[Session(), Session()]) + response = BatchCreateSessionsResponse( + session=[ + Session(creator_role=database_role, labels=labels), + Session(creator_role=database_role, labels=labels), + ] + ) return response from google.cloud.spanner_v1 import SpannerClient @@ -903,7 +987,16 @@ def mock_batch_create_sessions( self.spanner_api = mock.create_autospec(SpannerClient, instance=True) self.spanner_api.batch_create_sessions.side_effect = mock_batch_create_sessions - def session(self): + @property + def database_role(self): + """Database role used in sessions to connect to this database. + + :rtype: str + :returns: an str with the name of the database role. + """ + return self._database_role + + def session(self, **kwargs): # always return first session in the list # to avoid reversing the order of putting # sessions into pool (important for order tests) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 0f297654bb..005cd0cd1f 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -45,6 +45,7 @@ class TestSession(OpenTelemetryBase): DATABASE_NAME = INSTANCE_NAME + "/databases/" + DATABASE_ID SESSION_ID = "session-id" SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID + DATABASE_ROLE = "dummy-role" BASE_ATTRIBUTES = { "db.type": "spanner", "db.url": "spanner.googleapis.com", @@ -61,19 +62,20 @@ def _make_one(self, *args, **kwargs): return self._getTargetClass()(*args, **kwargs) @staticmethod - def _make_database(name=DATABASE_NAME): + def _make_database(name=DATABASE_NAME, database_role=None): from google.cloud.spanner_v1.database import Database database = mock.create_autospec(Database, instance=True) database.name = name database.log_commit_stats = False + database.database_role = database_role return database @staticmethod - def _make_session_pb(name, labels=None): + def _make_session_pb(name, labels=None, database_role=None): from google.cloud.spanner_v1 import Session - return Session(name=name, labels=labels) + return Session(name=name, labels=labels, creator_role=database_role) def _make_spanner_api(self): from google.cloud.spanner_v1 import SpannerClient @@ -87,6 +89,20 @@ def test_constructor_wo_labels(self): self.assertIs(session._database, database) self.assertEqual(session.labels, {}) + def test_constructor_w_database_role(self): + database = self._make_database(database_role=self.DATABASE_ROLE) + session = self._make_one(database, database_role=self.DATABASE_ROLE) + self.assertIs(session.session_id, None) + self.assertIs(session._database, database) + self.assertEqual(session.database_role, self.DATABASE_ROLE) + + def test_constructor_wo_database_role(self): + database = self._make_database() + session = self._make_one(database) + self.assertIs(session.session_id, None) + self.assertIs(session._database, database) + self.assertIs(session.database_role, None) + def test_constructor_w_labels(self): database = self._make_database() labels = {"foo": "bar"} @@ -126,6 +142,65 @@ def test_create_w_session_id(self): self.assertNoSpans() + def test_create_w_database_role(self): + from google.cloud.spanner_v1 import CreateSessionRequest + from google.cloud.spanner_v1 import Session as SessionRequestProto + + session_pb = self._make_session_pb( + self.SESSION_NAME, database_role=self.DATABASE_ROLE + ) + gax_api = self._make_spanner_api() + gax_api.create_session.return_value = session_pb + database = self._make_database(database_role=self.DATABASE_ROLE) + database.spanner_api = gax_api + session = self._make_one(database, database_role=self.DATABASE_ROLE) + + session.create() + + self.assertEqual(session.session_id, self.SESSION_ID) + self.assertEqual(session.database_role, self.DATABASE_ROLE) + session_template = SessionRequestProto(creator_role=self.DATABASE_ROLE) + + request = CreateSessionRequest( + database=database.name, + session=session_template, + ) + + gax_api.create_session.assert_called_once_with( + request=request, + metadata=[("google-cloud-resource-prefix", database.name)], + ) + + self.assertSpanAttributes( + "CloudSpanner.CreateSession", attributes=TestSession.BASE_ATTRIBUTES + ) + + def test_create_wo_database_role(self): + from google.cloud.spanner_v1 import CreateSessionRequest + + session_pb = self._make_session_pb(self.SESSION_NAME) + gax_api = self._make_spanner_api() + gax_api.create_session.return_value = session_pb + database = self._make_database() + database.spanner_api = gax_api + session = self._make_one(database) + session.create() + + self.assertEqual(session.session_id, self.SESSION_ID) + self.assertIsNone(session.database_role) + + request = CreateSessionRequest( + database=database.name, + ) + + gax_api.create_session.assert_called_once_with( + request=request, metadata=[("google-cloud-resource-prefix", database.name)] + ) + + self.assertSpanAttributes( + "CloudSpanner.CreateSession", attributes=TestSession.BASE_ATTRIBUTES + ) + def test_create_ok(self): from google.cloud.spanner_v1 import CreateSessionRequest