from enum import Enum
from typing import Sequence

from sqlalchemy import ARRAY, Text, cast, column, func, select, table, text
from sqlalchemy.dialects.postgresql import aggregate_order_by
from sqlalchemy.engine import Connectable

from .exc import NoSuchObjectError
from .types import ColumnInfo, FunctionInfo, RelationInfo, SchemaRelationInfo
# Public API: for each kind of object there is one function covering every
# object of that kind and one for a single named object.
__all__ = (
    'get_all_table_acls', 'get_table_acl',
    'get_all_column_acls', 'get_column_acls',
    'get_all_sequence_acls', 'get_sequence_acl',
    'get_all_function_acls', 'get_function_acl',
    'get_all_language_acls', 'get_language_acl',
    'get_all_schema_acls', 'get_schema_acl',
    'get_all_database_acls', 'get_database_acl',
    'get_all_tablespace_acls', 'get_tablespace_acl',
    'get_all_type_acls', 'get_type_acl',
)
# Shorthands for the SQL functions used by the queries in this module.

# Search-path visibility checks from pg_catalog.
pg_table_is_visible = func.pg_catalog.pg_table_is_visible
pg_function_is_visible = func.pg_catalog.pg_function_is_visible
pg_type_is_visible = func.pg_catalog.pg_type_is_visible

# Generic SQL functions.
array_agg, unnest, coalesce = func.array_agg, func.unnest, func.coalesce

# Helper living in the session's temporary schema; it is created by
# _make_canonical_type_function before the function queries are executed.
canonical_type = func.pg_temp.pg_grant_canonical_type
class PgRelKind(Enum):
    """Codes stored in ``pg_class.relkind``, used to filter relations by kind."""
    TABLE = 'r'
    INDEX = 'i'
    SEQUENCE = 'S'
    VIEW = 'v'
    MATERIALIZED_VIEW = 'm'
    COMPOSITE_TYPE = 'c'
    TOAST_TABLE = 't'
    FOREIGN_TABLE = 'f'
    PARTITIONED_TABLE = 'p'  # PostgreSQL 10+
# Lightweight (untyped) descriptions of the system catalogs queried by this
# module. Only the columns this module needs are declared.
pg_class = table('pg_class', *map(column, (
    'oid', 'relname', 'relacl', 'relnamespace', 'relkind', 'relowner',
)))

pg_namespace = table('pg_namespace', *map(column, (
    'oid', 'nspname', 'nspowner', 'nspacl',
)))

pg_roles = table('pg_roles', *map(column, (
    'oid', 'rolname',
)))

pg_proc = table('pg_proc', *map(column, (
    'oid', 'proname', 'proargtypes', 'pronamespace', 'proacl', 'proowner',
)))

pg_type = table('pg_type', *map(column, (
    'oid', 'typname', 'typnamespace', 'typowner', 'typacl',
)))

pg_language = table('pg_language', *map(column, (
    'oid', 'lanname', 'lanowner', 'lanacl',
)))

pg_database = table('pg_database', *map(column, (
    'oid', 'datname', 'datdba', 'datacl',
)))

pg_tablespace = table('pg_tablespace', *map(column, (
    'oid', 'spcname', 'spcowner', 'spcacl',
)))

pg_attribute = table('pg_attribute', *map(column, (
    'attrelid', 'attname', 'attnum', 'attisdropped', 'attacl',
)))
# Every relation with its schema, owner and ACL (cast to text[]). Outer joins
# keep a relation even when its namespace or owner row does not match.
_pg_class_stmt = (
    select(
        pg_class.c.oid,
        pg_namespace.c.nspname.label('schema'),
        pg_class.c.relname.label('name'),
        pg_roles.c.rolname.label('owner'),
        cast(pg_class.c.relacl, ARRAY(Text)).label('acl'),
    )
    .join(pg_namespace, pg_class.c.relnamespace == pg_namespace.c.oid, isouter=True)
    .join(pg_roles, pg_class.c.relowner == pg_roles.c.oid, isouter=True)
)
# Columns (attnum > 0, not dropped) of tables, views, materialized views,
# partitioned tables and foreign tables, together with the owning relation's
# oid, schema, name and owner, and the column ACL cast to text[].
_pg_attribute_stmt = (
    select(
        pg_class.c.oid.label('table_oid'),
        pg_namespace.c.nspname.label('schema'),
        pg_class.c.relname.label('table'),
        pg_attribute.c.attname.label('column'),
        pg_roles.c.rolname.label('owner'),
        cast(pg_attribute.c.attacl, ARRAY(Text)).label('acl'),
    )
    .select_from(pg_attribute)
    .join(pg_class, pg_attribute.c.attrelid == pg_class.c.oid)
    .join(pg_namespace, pg_class.c.relnamespace == pg_namespace.c.oid, isouter=True)
    .join(pg_roles, pg_class.c.relowner == pg_roles.c.oid, isouter=True)
    .where(
        pg_attribute.c.attnum > 0,
        ~pg_attribute.c.attisdropped,
        pg_class.c.relkind.in_([
            kind.value for kind in (
                PgRelKind.TABLE,
                PgRelKind.VIEW,
                PgRelKind.MATERIALIZED_VIEW,
                PgRelKind.PARTITIONED_TABLE,
                PgRelKind.FOREIGN_TABLE,
            )
        ]),
    )
)
# Argument types of the pg_proc row from the enclosing query, as a text[] of
# canonicalised type names (empty array for functions without arguments).
#
# proargtypes is unnested WITH ORDINALITY and the aggregate is ordered by that
# ordinal: without an explicit ORDER BY, array_agg's input order is
# unspecified (the join with pg_type may reorder rows), and argument order is
# significant when comparing function signatures.
_upat = (
    unnest(pg_proc.c.proargtypes)
    .table_valued('typ_oid', with_ordinality='ord', name='upat')
    .render_derived()
)
_pg_proc_argtypes = (
    select(
        coalesce(
            array_agg(aggregate_order_by(canonical_type(pg_type.c.typname), _upat.c.ord)),
            cast([], ARRAY(Text)),
        )
    )
    .join(_upat, _upat.c.typ_oid == pg_type.c.oid)
    .scalar_subquery()
)
# Every function with its schema, name, canonicalised argument types, owner and
# ACL (cast to text[]). Executing it requires pg_temp.pg_grant_canonical_type,
# which _pg_proc_argtypes calls.
_pg_proc_stmt = (
    select(
        pg_proc.c.oid,
        pg_namespace.c.nspname.label('schema'),
        pg_proc.c.proname.label('name'),
        _pg_proc_argtypes.label('arg_types'),
        pg_roles.c.rolname.label('owner'),
        cast(pg_proc.c.proacl, ARRAY(Text)).label('acl'),
    )
    .join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid, isouter=True)
    .join(pg_roles, pg_proc.c.proowner == pg_roles.c.oid, isouter=True)
)
def _catalog_object_stmt(catalog, name_col, owner_col, acl_col):
    """Build ``SELECT oid, name, owner, acl`` over one system catalog.

    The owner role is looked up with an outer join, so objects whose owner
    does not match a pg_roles row are still returned (with a NULL owner).
    The ACL column is cast to ``text[]``.
    """
    return (
        select(
            catalog.c.oid,
            name_col.label('name'),
            pg_roles.c.rolname.label('owner'),
            cast(acl_col, ARRAY(Text)).label('acl'),
        )
        .select_from(catalog)
        .join(pg_roles, owner_col == pg_roles.c.oid, isouter=True)
    )


_pg_lang_stmt = _catalog_object_stmt(
    pg_language,
    pg_language.c.lanname,
    pg_language.c.lanowner,
    pg_language.c.lanacl,
)
_pg_schema_stmt = _catalog_object_stmt(
    pg_namespace,
    pg_namespace.c.nspname,
    pg_namespace.c.nspowner,
    pg_namespace.c.nspacl,
)
_pg_db_stmt = _catalog_object_stmt(
    pg_database,
    pg_database.c.datname,
    pg_database.c.datdba,
    pg_database.c.datacl,
)
_pg_tablespace_stmt = _catalog_object_stmt(
    pg_tablespace,
    pg_tablespace.c.spcname,
    pg_tablespace.c.spcowner,
    pg_tablespace.c.spcacl,
)
# Every type with its schema, name, owner and ACL (cast to text[]).
_pg_type_stmt = (
    select(
        pg_type.c.oid,
        pg_namespace.c.nspname.label('schema'),
        pg_type.c.typname.label('name'),
        pg_roles.c.rolname.label('owner'),
        cast(pg_type.c.typacl, ARRAY(Text)).label('acl'),
    )
    .select_from(pg_type)
    .join(pg_namespace, pg_type.c.typnamespace == pg_namespace.c.oid, isouter=True)
    .join(pg_roles, pg_type.c.typowner == pg_roles.c.oid, isouter=True)
)
def _filter_pg_class_stmt(stmt, schema=None, rel_name=None):
    """Restrict a pg_class/pg_namespace based `stmt` by schema and/or name.

    When `rel_name` is given without `schema`, the relation must also be
    visible in the search path (``pg_table_is_visible``).
    """
    if schema is not None:
        stmt = stmt.where(pg_namespace.c.nspname == schema)

    if rel_name is None:
        return stmt

    if schema is None:
        # "pg_table_is_visible can also be used with views, materialized
        # views, indexes, sequences and foreign tables"
        stmt = stmt.where(pg_table_is_visible(pg_class.c.oid))
    return stmt.where(pg_class.c.relname == rel_name)
def _filter_pg_proc_stmt(schema=None, function_name=None, arg_types=None):
    """Build the pg_proc query, optionally restricted to a single function.

    Args:
        schema: If given, only functions in this schema are returned.
        function_name: Name of the function to look up. Must be given
            together with `arg_types`.
        arg_types: Iterable of argument type names; empty for a function
            without arguments. Both these names and the catalog's argument
            types go through ``canonical_type`` before being compared, so e.g.
            ``'int4'`` and ``'integer'`` match.

    When `function_name` is given without `schema`, the function must be
    visible in the search path.

    Raises:
        TypeError: If only one of `function_name` and `arg_types` is given.
    """
    if (function_name is None) != (arg_types is None):
        raise TypeError(
            'function_name and arg_types must both be specified')  # pragma: no cover

    stmt = _pg_proc_stmt
    if schema is not None:
        stmt = stmt.where(pg_namespace.c.nspname == schema)
    if function_name is None:
        return stmt

    arg_types = list(arg_types)
    if arg_types:
        # Unnest WITH ORDINALITY and order the aggregate by that ordinal:
        # without an explicit ORDER BY, array_agg's input order is unspecified,
        # and argument order matters when comparing signatures.
        typs = (
            func.unnest(arg_types)
            .table_valued('typ', with_ordinality='ord', name='typs')
            .render_derived()
        )
        arg_types_sub = (
            select(array_agg(aggregate_order_by(canonical_type(typs.c.typ), typs.c.ord)))
            .select_from(typs)
            .scalar_subquery()
        )
    else:
        # No arguments: compare against an empty text[].
        arg_types_sub = cast([], ARRAY(Text))

    if schema is None:
        stmt = stmt.where(pg_function_is_visible(pg_proc.c.oid))
    stmt = stmt.where(pg_proc.c.proname == function_name)
    return stmt.where(_pg_proc_argtypes == arg_types_sub)
def _filter_pg_type_stmt(schema=None, type_name=None):
    """Return the pg_type query restricted by schema and/or type name.

    When `type_name` is given without `schema`, the type must be visible in
    the search path (``pg_type_is_visible``).
    """
    criteria = []
    if schema is not None:
        criteria.append(pg_namespace.c.nspname == schema)
    if type_name is not None:
        if schema is None:
            criteria.append(pg_type_is_visible(pg_type.c.oid))
        criteria.append(pg_type.c.typname == type_name)

    if not criteria:
        return _pg_type_stmt
    return _pg_type_stmt.where(*criteria)
def _table_stmt(schema=None, table_name=None):
    """Relations query limited to table-like kinds.

    The kinds are tables, views, materialized views, partitioned tables and
    foreign tables.
    """
    table_like = [
        kind.value
        for kind in (
            PgRelKind.TABLE,
            PgRelKind.VIEW,
            PgRelKind.MATERIALIZED_VIEW,
            PgRelKind.PARTITIONED_TABLE,
            PgRelKind.FOREIGN_TABLE,
        )
    ]
    filtered = _filter_pg_class_stmt(
        _pg_class_stmt, schema=schema, rel_name=table_name)
    return filtered.where(pg_class.c.relkind.in_(table_like))
def _sequence_stmt(schema=None, sequence_name=None):
    """Relations query limited to sequences."""
    is_sequence = pg_class.c.relkind == PgRelKind.SEQUENCE.value
    filtered = _filter_pg_class_stmt(
        _pg_class_stmt, schema=schema, rel_name=sequence_name)
    return filtered.where(is_sequence)
def get_all_table_acls(conn, schema=None):
    """Get privileges for all tables, views, materialized views, and foreign
    tables. Partitioned tables are included.

    Specify `schema` to limit the results to that schema.

    Returns:
        List of :class:`~.types.SchemaRelationInfo` objects.
    """
    mappings = conn.execute(_table_stmt(schema=schema)).mappings()
    return [SchemaRelationInfo(**mapping) for mapping in mappings]
def get_table_acl(conn, name, schema=None):
    """Get privileges for the table, view, materialized view, or foreign table
    called `name`.

    Without `schema`, the relation must be visible in the search path.

    Returns:
        :class:`~.types.SchemaRelationInfo`

    Raises:
        NoSuchObjectError: If no matching relation is found.
    """
    found = conn.execute(
        _table_stmt(schema=schema, table_name=name)).mappings().first()
    if found is not None:
        return SchemaRelationInfo(**found)
    raise NoSuchObjectError(name)
def get_all_column_acls(conn, schema=None):
    """Get privileges for all table, view, materialized view, and foreign
    table columns.

    Specify `schema` to limit the results to that schema.

    Returns:
        List of :class:`~.types.ColumnInfo` objects.
    """
    query = _filter_pg_class_stmt(_pg_attribute_stmt, schema=schema)
    mappings = conn.execute(query).mappings()
    return [ColumnInfo(**mapping) for mapping in mappings]
def get_column_acls(conn, table_name, schema=None):
    """Get column privileges for the table, view, materialized view, or foreign
    table specified by `table_name`.

    Without `schema`, the relation must be visible in the search path.

    Returns:
        List of :class:`~.types.ColumnInfo` objects.

    Raises:
        NoSuchObjectError: If no columns are found. Note that a relation
            without any (non-dropped) columns is reported this way as well.
    """
    query = _filter_pg_class_stmt(
        _pg_attribute_stmt, schema=schema, rel_name=table_name)
    columns = [ColumnInfo(**mapping) for mapping in conn.execute(query).mappings()]
    if not columns:
        raise NoSuchObjectError(table_name)
    return columns
def get_all_sequence_acls(conn, schema=None):
    """Get privileges for sequences, from every schema unless `schema` is
    given.

    Returns:
        List of :class:`~.types.SchemaRelationInfo` objects.
    """
    mappings = conn.execute(_sequence_stmt(schema=schema)).mappings()
    return [SchemaRelationInfo(**mapping) for mapping in mappings]
def get_sequence_acl(conn, sequence, schema=None):
    """Get privileges for the sequence called `sequence`.

    Without `schema`, the sequence must be visible in the search path.

    Returns:
        :class:`~.types.SchemaRelationInfo`

    Raises:
        NoSuchObjectError: If no matching sequence is found.
    """
    found = conn.execute(
        _sequence_stmt(schema=schema, sequence_name=sequence)).mappings().first()
    if found is not None:
        return SchemaRelationInfo(**found)
    raise NoSuchObjectError(sequence)
def _make_canonical_type_function(conn):
    """Create ``pg_temp.pg_grant_canonical_type(text)`` on `conn`.

    The function casts a type name to ``regtype`` and back to text, e.g.
    ``'int4'`` -> ``'integer'``. When the cast fails, the input is returned
    unchanged. Handled failures:

    - names that do not parse as a type (e.g. ``'any'``): ``syntax_error``;
    - malformed names: ``invalid_name``;
    - types or schemas that cannot be resolved through the current search
      path: ``undefined_object`` / ``invalid_schema_name``.

    Previously only ``syntax_error`` was caught. Any function whose argument
    type lives in a schema outside the search path therefore made the whole
    function-ACL query fail.
    """
    # pg_temp is per-connection, so the function must be (re)created on the
    # connection that will run the queries calling it.
    stmt = text("""
        CREATE OR REPLACE FUNCTION pg_temp.pg_grant_canonical_type(typname text)
        RETURNS text AS $$
        BEGIN
            BEGIN
                typname := typname::regtype::text;
            EXCEPTION
                WHEN syntax_error OR invalid_name OR undefined_object
                     OR invalid_schema_name THEN
                    NULL;
            END;
            RETURN typname;
        END;
        $$
        LANGUAGE plpgsql
        STABLE
        RETURNS NULL ON NULL INPUT;
    """)
    conn.execute(stmt)
def get_all_function_acls(conn, schema=None):
    """Get privileges for functions, from every schema unless `schema` is
    given.

    Returns:
        List of :class:`~.types.FunctionInfo` objects.
    """
    # The query calls pg_temp.pg_grant_canonical_type, which lives in this
    # connection's temporary schema, so make sure it exists first.
    _make_canonical_type_function(conn)
    mappings = conn.execute(_filter_pg_proc_stmt(schema=schema)).mappings()
    return [FunctionInfo(**mapping) for mapping in mappings]
def get_function_acl(conn, function_name, arg_types: Sequence[str], schema=None):
    """Get privileges for the function `function_name` taking `arg_types`.

    If `schema` is not given, the function must be visible in the search path.

    Args:
        conn: Connection used to run the queries.
        function_name: Name of the function.
        arg_types: Sequence of argument type names, e.g. ``['text']``. Use an
            empty sequence for a function without arguments.
        schema: Optional schema containing the function.

    Returns:
        :class:`~.types.FunctionInfo`

    Raises:
        TypeError: If only one of `function_name`/`arg_types` is given, or
            `arg_types` is not a (non-string) sequence of strings.
        NoSuchObjectError: If no matching function is found.
    """
    # Validate before touching the database, so invalid input never causes
    # the temporary helper function to be created.
    if (function_name is None) != (arg_types is None):
        raise TypeError('function_name and arg_types must both be specified')
    if (not isinstance(arg_types, Sequence)
            or isinstance(arg_types, str)
            or not all(isinstance(arg_type, str) for arg_type in arg_types)):
        raise TypeError("arg_types should be a sequence of strings, e.g. ['text']")

    # We could ask the user to register an event on their connection pool which
    # creates this function on checkout, but that isn't a nice API for the
    # common case.
    _make_canonical_type_function(conn)
    stmt = _filter_pg_proc_stmt(schema, function_name, arg_types)
    row = conn.execute(stmt).mappings().first()
    if row is None:
        raise NoSuchObjectError(function_name)
    return FunctionInfo(**row)
def get_all_language_acls(conn):
    """Get privileges for every procedural language.

    Returns:
        List of :class:`~.types.RelationInfo` objects.
    """
    mappings = conn.execute(_pg_lang_stmt).mappings()
    return [RelationInfo(**mapping) for mapping in mappings]
def get_language_acl(conn, language):
    """Get privileges for the language called `language`.

    Returns:
        :class:`~.types.RelationInfo`

    Raises:
        NoSuchObjectError: If no such language exists.
    """
    query = _pg_lang_stmt.where(pg_language.c.lanname == language)
    found = conn.execute(query).mappings().first()
    if found is not None:
        return RelationInfo(**found)
    raise NoSuchObjectError(language)
def get_all_schema_acls(conn):
    """Get privileges for every schema.

    Returns:
        List of :class:`~.types.RelationInfo` objects.
    """
    mappings = conn.execute(_pg_schema_stmt).mappings()
    return [RelationInfo(**mapping) for mapping in mappings]
def get_schema_acl(conn, schema):
    """Get privileges for the schema called `schema`.

    Returns:
        :class:`~.types.RelationInfo`

    Raises:
        NoSuchObjectError: If no such schema exists.
    """
    query = _pg_schema_stmt.where(pg_namespace.c.nspname == schema)
    found = conn.execute(query).mappings().first()
    if found is not None:
        return RelationInfo(**found)
    raise NoSuchObjectError(schema)
def get_all_database_acls(conn):
    """Get privileges for every database.

    Returns:
        List of :class:`~.types.RelationInfo` objects.
    """
    mappings = conn.execute(_pg_db_stmt).mappings()
    return [RelationInfo(**mapping) for mapping in mappings]
def get_database_acl(conn, database):
    """Get privileges for the database called `database`.

    Returns:
        :class:`~.types.RelationInfo`

    Raises:
        NoSuchObjectError: If no such database exists.
    """
    query = _pg_db_stmt.where(pg_database.c.datname == database)
    found = conn.execute(query).mappings().first()
    if found is not None:
        return RelationInfo(**found)
    raise NoSuchObjectError(database)
def get_all_tablespace_acls(conn):
    """Get privileges for every tablespace.

    Returns:
        List of :class:`~.types.RelationInfo` objects.
    """
    mappings = conn.execute(_pg_tablespace_stmt).mappings()
    return [RelationInfo(**mapping) for mapping in mappings]
def get_tablespace_acl(conn, tablespace):
    """Get privileges for the tablespace called `tablespace`.

    Returns:
        :class:`~.types.RelationInfo`

    Raises:
        NoSuchObjectError: If no such tablespace exists.
    """
    query = _pg_tablespace_stmt.where(pg_tablespace.c.spcname == tablespace)
    found = conn.execute(query).mappings().first()
    if found is not None:
        return RelationInfo(**found)
    raise NoSuchObjectError(tablespace)
def get_all_type_acls(conn, schema=None):
    """Get privileges for types, from every schema unless `schema` is given.

    Returns:
        List of :class:`~.types.SchemaRelationInfo` objects.
    """
    mappings = conn.execute(_filter_pg_type_stmt(schema=schema)).mappings()
    return [SchemaRelationInfo(**mapping) for mapping in mappings]
def get_type_acl(conn, type_name, schema=None):
    """Get privileges for the type called `type_name`.

    Without `schema`, the type must be visible in the search path.

    Returns:
        :class:`~.types.SchemaRelationInfo`

    Raises:
        NoSuchObjectError: If no matching type is found.
    """
    query = _filter_pg_type_stmt(schema=schema, type_name=type_name)
    found = conn.execute(query).mappings().first()
    if found is not None:
        return SchemaRelationInfo(**found)
    raise NoSuchObjectError(type_name)