Source code for pg_grant.query

import sys
import typing as t
from enum import Enum
from typing import Any, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union

from sqlalchemy import (
    ARRAY,
    Connection,
    Select,
    Text,
    cast,
    column,
    func,
    select,
    table,
    text,
)
from sqlalchemy.dialects.postgresql import aggregate_order_by
from sqlalchemy.orm import Session

from ._typing_sqlalchemy import ArgTypesInput
from .exc import NoSuchObjectError
from .types import (
    ColumnInfo,
    FunctionInfo,
    ParameterInfo,
    RelationInfo,
    SchemaRelationInfo,
)

if sys.version_info >= (3, 10):
    from typing import TypeAlias
else:
    from typing_extensions import TypeAlias

__all__ = (
    # Tables, views, materialized views, foreign tables and their columns.
    "get_all_table_acls",
    "get_table_acl",
    "get_all_column_acls",
    "get_column_acls",
    # Sequences and functions.
    "get_all_sequence_acls",
    "get_sequence_acl",
    "get_all_function_acls",
    "get_function_acl",
    # Languages, schemas, databases and tablespaces.
    "get_all_language_acls",
    "get_language_acl",
    "get_all_schema_acls",
    "get_schema_acl",
    "get_all_database_acls",
    "get_database_acl",
    "get_all_tablespace_acls",
    "get_tablespace_acl",
    # Types and configuration parameters.
    "get_all_type_acls",
    "get_type_acl",
    "get_all_parameter_acls",
    "get_parameter_acl",
)

# Search-path visibility checks. ``func.pg_catalog.<name>`` renders the
# schema-qualified call ``pg_catalog.<name>(...)``.
_pg_catalog_func = func.pg_catalog
pg_table_is_visible = _pg_catalog_func.pg_table_is_visible
pg_function_is_visible = _pg_catalog_func.pg_function_is_visible
pg_type_is_visible = _pg_catalog_func.pg_type_is_visible

# Short aliases for generic SQL functions used by the statements below.
array_agg, unnest, coalesce = func.array_agg, func.unnest, func.coalesce

# Per-connection helper created by ``_make_canonical_type_function``.
canonical_type = func.pg_temp.pg_grant_canonical_type

# Row-tuple type parameter for ``Select`` statements passed through filters.
TP = TypeVar("TP", bound=Tuple[Any, ...])
# Anything able to ``execute`` a statement: a Core connection or ORM session.
Connectable: TypeAlias = Union[Connection, Session]


class PgRelKind(Enum):
    """Values stored in ``pg_class.relkind``, used to filter relations by kind."""

    TABLE = "r"
    INDEX = "i"
    SEQUENCE = "S"
    VIEW = "v"
    MATERIALIZED_VIEW = "m"
    COMPOSITE_TYPE = "c"
    TOAST_TABLE = "t"
    FOREIGN_TABLE = "f"
    PARTITIONED_TABLE = "p"  # PostgreSQL 10+


# Minimal ``table()``/``column()`` descriptions of the PostgreSQL system
# catalogs queried below. Only the listed columns are reachable through
# ``.c``. The columns carry no SQL type (SQLAlchemy ``NullType``); the
# statements below cast ACL columns to ``text[]`` explicitly.

# Relations (tables, views, sequences, ...); the kind is in ``relkind``.
pg_class = table(
    "pg_class",
    column("oid"),
    column("relname"),
    column("relacl"),
    column("relnamespace"),
    column("relkind"),
    column("relowner"),
)

# Schemas.
pg_namespace = table(
    "pg_namespace",
    column("oid"),
    column("nspname"),
    column("nspowner"),
    column("nspacl"),
)

# Roles; used here only to turn owner OIDs into role names.
pg_roles = table(
    "pg_roles",
    column("oid"),
    column("rolname"),
)

# Functions. ``proargtypes`` holds the argument type OIDs, matched against
# ``pg_type.oid`` below.
pg_proc = table(
    "pg_proc",
    column("oid"),
    column("proname"),
    column("proargtypes"),
    column("pronamespace"),
    column("proacl"),
    column("proowner"),
)

# Data types.
pg_type = table(
    "pg_type",
    column("oid"),
    column("typname"),
    column("typnamespace"),
    column("typowner"),
    column("typacl"),
)

# Privileges on configuration parameters (no owner column).
pg_parameter_acl = table(
    "pg_parameter_acl",
    column("oid"),
    column("parname"),
    column("paracl"),
)

# Languages.
pg_language = table(
    "pg_language",
    column("oid"),
    column("lanname"),
    column("lanowner"),
    column("lanacl"),
)

# Databases; the owner column is named ``datdba``.
pg_database = table(
    "pg_database",
    column("oid"),
    column("datname"),
    column("datdba"),
    column("datacl"),
)

# Tablespaces.
pg_tablespace = table(
    "pg_tablespace",
    column("oid"),
    column("spcname"),
    column("spcowner"),
    column("spcacl"),
)

# Columns; ``attrelid`` references ``pg_class.oid``.
pg_attribute = table(
    "pg_attribute",
    column("attrelid"),
    column("attname"),
    column("attnum"),
    column("attisdropped"),
    column("attacl"),
)

# One row per relation: oid, schema name, relation name, owner role name and
# the ACL cast to text[]. ``_table_stmt`` / ``_sequence_stmt`` add relkind
# and name filters on top. The LEFT OUTER JOINs keep a relation in the
# result even when no matching namespace/role row is found
# (NOTE(review): presumably defensive; confirm intent).
_pg_class_stmt = (
    select(
        pg_class.c.oid,
        pg_namespace.c.nspname.label("schema"),
        pg_class.c.relname.label("name"),
        pg_roles.c.rolname.label("owner"),
        cast(pg_class.c.relacl, ARRAY(Text)).label("acl"),
    )
    .outerjoin(pg_namespace, pg_class.c.relnamespace == pg_namespace.c.oid)
    .outerjoin(pg_roles, pg_class.c.relowner == pg_roles.c.oid)
)

# One row per live column of a table-like relation, with the column ACL
# cast to text[]. The reported owner is the owning relation's owner
# (``pg_class.relowner``).
_pg_attribute_stmt = (
    select(
        pg_class.c.oid.label("table_oid"),
        pg_namespace.c.nspname.label("schema"),
        pg_class.c.relname.label("table"),
        pg_attribute.c.attname.label("column"),
        pg_roles.c.rolname.label("owner"),
        cast(pg_attribute.c.attacl, ARRAY(Text)).label("acl"),
    )
    .select_from(pg_attribute)
    .join(pg_class, pg_attribute.c.attrelid == pg_class.c.oid)
    .outerjoin(pg_namespace, pg_class.c.relnamespace == pg_namespace.c.oid)
    .outerjoin(pg_roles, pg_class.c.relowner == pg_roles.c.oid)
    # attnum > 0 skips system columns, which PostgreSQL numbers negatively.
    .where(pg_attribute.c.attnum > 0)
    # Dropped columns remain in pg_attribute with attisdropped set.
    .where(~pg_attribute.c.attisdropped)
    .where(
        # need to cast for PostgreSQL < 13 on psycopg3
        cast(pg_class.c.relkind, Text).in_(
            [
                PgRelKind.TABLE.value,
                PgRelKind.VIEW.value,
                PgRelKind.MATERIALIZED_VIEW.value,
                PgRelKind.PARTITIONED_TABLE.value,
                PgRelKind.FOREIGN_TABLE.value,
            ]
        )
    )
)

# ``proargtypes`` is unnested WITH ORDINALITY so every argument type keeps
# its declaration position. PostgreSQL feeds rows to an aggregate in an
# unspecified order unless the aggregate call has its own ORDER BY. The join
# against pg_type may also be planned in a way that reorders rows. Ordering
# ``array_agg`` by the ordinality column keeps f(int, text) distinct from
# f(text, int).
_upat = (
    unnest(pg_proc.c.proargtypes)
    .table_valued("arg_oid", with_ordinality="arg_pos")
    .render_derived(name="upat")
)

# Correlated scalar subquery yielding the function's argument types as an
# ordered text[] of canonical type names (via ``canonical_type``). A function
# with no arguments yields '{}' instead of NULL.
_pg_proc_argtypes = (
    select(
        coalesce(
            array_agg(
                aggregate_order_by(
                    canonical_type(pg_type.c.typname), _upat.c.arg_pos
                )
            ),
            # array_agg over zero rows is NULL; normalize to an empty array.
            cast([], ARRAY(Text)),
        )
    )
    .join(_upat, _upat.c.arg_oid == pg_type.c.oid)
    .scalar_subquery()
)

# One row per function: oid, schema, name, argument types as an ordered
# text[] of canonical names (``_pg_proc_argtypes``), owner and ACL as text[].
# Executing it requires ``pg_temp.pg_grant_canonical_type`` to exist on the
# connection (created by ``_make_canonical_type_function``).
_pg_proc_stmt = (
    select(
        pg_proc.c.oid,
        pg_namespace.c.nspname.label("schema"),
        pg_proc.c.proname.label("name"),
        _pg_proc_argtypes.label("arg_types"),
        pg_roles.c.rolname.label("owner"),
        cast(pg_proc.c.proacl, ARRAY(Text)).label("acl"),
    )
    .outerjoin(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid)
    .outerjoin(pg_roles, pg_proc.c.proowner == pg_roles.c.oid)
)

def _owned_object_stmt(
    catalog: Any, name_col: Any, owner_col: Any, acl_col: Any
) -> Select[Any]:
    """Select ``oid``, ``name``, ``owner`` and ``acl`` (as ``text[]``) from a
    catalog whose objects have no schema, left-joining the owner role.

    Args:
        catalog: The catalog ``table()`` to select from.
        name_col: Column holding the object name.
        owner_col: Column holding the owner role OID.
        acl_col: Column holding the ACL array.
    """
    return (
        select(
            catalog.c.oid,
            name_col.label("name"),
            pg_roles.c.rolname.label("owner"),
            cast(acl_col, ARRAY(Text)).label("acl"),
        )
        .select_from(catalog)
        .outerjoin(pg_roles, owner_col == pg_roles.c.oid)
    )


_pg_lang_stmt = _owned_object_stmt(
    pg_language,
    pg_language.c.lanname,
    pg_language.c.lanowner,
    pg_language.c.lanacl,
)

_pg_schema_stmt = _owned_object_stmt(
    pg_namespace,
    pg_namespace.c.nspname,
    pg_namespace.c.nspowner,
    pg_namespace.c.nspacl,
)

_pg_db_stmt = _owned_object_stmt(
    pg_database,
    pg_database.c.datname,
    pg_database.c.datdba,
    pg_database.c.datacl,
)

_pg_tablespace_stmt = _owned_object_stmt(
    pg_tablespace,
    pg_tablespace.c.spcname,
    pg_tablespace.c.spcowner,
    pg_tablespace.c.spcacl,
)

# Types live in schemas, so this statement additionally reports the schema.
_pg_type_stmt = (
    select(
        pg_type.c.oid,
        pg_namespace.c.nspname.label("schema"),
        pg_type.c.typname.label("name"),
        pg_roles.c.rolname.label("owner"),
        cast(pg_type.c.typacl, ARRAY(Text)).label("acl"),
    )
    .select_from(pg_type)
    .outerjoin(pg_namespace, pg_type.c.typnamespace == pg_namespace.c.oid)
    .outerjoin(pg_roles, pg_type.c.typowner == pg_roles.c.oid)
)

# pg_parameter_acl has no owner column, hence no role join.
_pg_parameter_stmt = select(
    pg_parameter_acl.c.oid,
    pg_parameter_acl.c.parname.label("name"),
    cast(pg_parameter_acl.c.paracl, ARRAY(Text)).label("acl"),
)


def _filter_pg_class_stmt(
    stmt: Select[TP], schema: Optional[str] = None, rel_name: Optional[str] = None
) -> Select[TP]:
    """Narrow a ``pg_class``-based statement by schema and/or relation name.

    Args:
        stmt: Statement expected to join ``pg_class`` and ``pg_namespace``.
        schema: Keep only rows whose ``pg_namespace.nspname`` equals this.
        rel_name: Keep only rows whose ``pg_class.relname`` equals this. When
            `schema` is not given, the relation must additionally be visible
            in the search path (``pg_table_is_visible``).

    Returns:
        The narrowed statement, or `stmt` itself when no filter applies.
    """
    criteria: List[Any] = []
    if schema is not None:
        criteria.append(pg_namespace.c.nspname == schema)
    if rel_name is not None:
        if schema is None:
            # "pg_table_is_visible can also be used with views, materialized
            # views, indexes, sequences and foreign tables"
            criteria.append(pg_table_is_visible(pg_class.c.oid))
        criteria.append(pg_class.c.relname == rel_name)
    return stmt.where(*criteria) if criteria else stmt


def _filter_pg_proc_stmt(
    schema: Optional[str] = None,
    function_name: Optional[str] = None,
    arg_types: Optional[ArgTypesInput] = None,
) -> Select[Any]:
    """Build the ``pg_proc`` query, optionally narrowed to a schema and/or a
    single function signature.

    Args:
        schema: Restrict results to functions in this schema.
        function_name: Function name to match; must be given together with
            `arg_types`. Without `schema`, the function must also be visible
            in the search path.
        arg_types: Argument type names in declaration order. Each name is
            passed through ``pg_temp.pg_grant_canonical_type`` (e.g. ``int4``
            becomes ``integer``), which must already exist on the connection.

    Returns:
        The (possibly filtered) select statement.

    Raises:
        TypeError: if exactly one of `function_name` / `arg_types` is given.
    """
    stmt = _pg_proc_stmt

    if (function_name is None) != (arg_types is None):
        raise TypeError(
            "function_name and arg_types must both be specified"
        )  # pragma: no cover

    if schema is not None:
        stmt = stmt.where(pg_namespace.c.nspname == schema)

    if function_name is not None:
        assert arg_types is not None  # guaranteed by the check above
        arg_types = list(arg_types)

        arg_types_sub: Any
        if arg_types:
            # Unnest WITH ORDINALITY and order the aggregate by position:
            # without an ORDER BY inside the aggregate, PostgreSQL does not
            # guarantee array_agg's input order, and argument order matters.
            typs = (
                func.unnest(cast(arg_types, ARRAY(Text)))
                .table_valued("typ", with_ordinality="typ_pos")
                .render_derived(name="typs")
            )
            arg_types_sub = (
                select(
                    array_agg(
                        aggregate_order_by(canonical_type(typs.c.typ), typs.c.typ_pos)
                    )
                )
                .select_from(typs)
                .scalar_subquery()
            )
        else:
            arg_types_sub = cast([], ARRAY(Text))

        if schema is None:
            stmt = stmt.where(pg_function_is_visible(pg_proc.c.oid))
        stmt = stmt.where(pg_proc.c.proname == function_name)
        # Overloads are told apart by their ordered argument type arrays.
        stmt = stmt.where(_pg_proc_argtypes == arg_types_sub)

    return stmt


def _filter_pg_type_stmt(
    schema: Optional[str] = None, type_name: Optional[str] = None
) -> Select[Any]:
    """Build the ``pg_type`` query, optionally narrowed by schema and/or name.

    Args:
        schema: Keep only types in this schema.
        type_name: Keep only types with this name. Without `schema`, the type
            must also be visible in the search path (``pg_type_is_visible``).

    Returns:
        The narrowed statement, or the unfiltered ``_pg_type_stmt``.
    """
    criteria: List[Any] = []
    if schema is not None:
        criteria.append(pg_namespace.c.nspname == schema)
    if type_name is not None:
        if schema is None:
            criteria.append(pg_type_is_visible(pg_type.c.oid))
        criteria.append(pg_type.c.typname == type_name)
    return _pg_type_stmt.where(*criteria) if criteria else _pg_type_stmt


def _table_stmt(
    schema: Optional[str] = None, table_name: Optional[str] = None
) -> Select[Any]:
    """Build the ``pg_class`` query restricted to table-like relations.

    Table-like means ordinary, partitioned and foreign tables, views, and
    materialized views.

    Args:
        schema: Optional schema restriction.
        table_name: Optional relation name; without `schema` it must be
            visible in the search path.
    """
    table_like_kinds = [
        kind.value
        for kind in (
            PgRelKind.TABLE,
            PgRelKind.VIEW,
            PgRelKind.MATERIALIZED_VIEW,
            PgRelKind.PARTITIONED_TABLE,
            PgRelKind.FOREIGN_TABLE,
        )
    ]
    # need to cast for PostgreSQL < 13 on psycopg3
    is_table_like = cast(pg_class.c.relkind, Text).in_(table_like_kinds)

    narrowed = _filter_pg_class_stmt(_pg_class_stmt, schema=schema, rel_name=table_name)
    return narrowed.where(is_table_like)


def _sequence_stmt(
    schema: Optional[str] = None, sequence_name: Optional[str] = None
) -> Select[Any]:
    """Build the ``pg_class`` query restricted to sequences.

    Args:
        schema: Optional schema restriction.
        sequence_name: Optional sequence name; without `schema` it must be
            visible in the search path.
    """
    only_sequences = pg_class.c.relkind == PgRelKind.SEQUENCE.value
    narrowed = _filter_pg_class_stmt(
        _pg_class_stmt, schema=schema, rel_name=sequence_name
    )
    return narrowed.where(only_sequences)


def get_all_table_acls(
    conn: Connectable, schema: Optional[str] = None
) -> List[SchemaRelationInfo]:
    """Return privileges of every table, view, materialized view, and foreign
    table.

    Args:
        conn: Connection or session to query.
        schema: If given, only relations in this schema are returned.

    Returns:
        List of :class:`~.types.SchemaRelationInfo` objects.
    """
    mappings = conn.execute(_table_stmt(schema=schema)).mappings()
    return [SchemaRelationInfo(**t.cast("Mapping[str, Any]", m)) for m in mappings]
def get_table_acl(
    conn: Connectable, name: str, schema: Optional[str] = None
) -> SchemaRelationInfo:
    """Return privileges of one table, view, materialized view, or foreign
    table.

    Args:
        conn: Connection or session to query.
        name: Relation name.
        schema: Schema containing the relation. When omitted, the relation
            must be visible in the search path.

    Returns:
        :class:`~.types.SchemaRelationInfo`

    Raises:
        NoSuchObjectError: if no matching relation exists.
    """
    query = _table_stmt(schema=schema, table_name=name)
    found = conn.execute(query).mappings().one_or_none()
    if found is not None:
        return SchemaRelationInfo(**t.cast("Mapping[str, Any]", found))
    raise NoSuchObjectError(name)
def get_all_column_acls(
    conn: Connectable, schema: Optional[str] = None
) -> List[ColumnInfo]:
    """Return column privileges for every table, view, materialized view, and
    foreign table.

    Args:
        conn: Connection or session to query.
        schema: If given, only columns of relations in this schema are
            returned.

    Returns:
        List of :class:`~.types.ColumnInfo` objects.
    """
    query = _filter_pg_class_stmt(_pg_attribute_stmt, schema=schema)
    result = conn.execute(query)
    return [ColumnInfo(**t.cast("Mapping[str, Any]", m)) for m in result.mappings()]
def get_column_acls(
    conn: Connectable, table_name: str, schema: Optional[str] = None
) -> List[ColumnInfo]:
    """Get column privileges for the table, view, materialized view, or
    foreign table specified by `table_name`.

    If `schema` is not given, the table or view must be visible in the search
    path.

    Returns:
        List of :class:`~.types.ColumnInfo` objects. The list is empty when
        the relation exists but has no live columns.

    Raises:
        NoSuchObjectError: if no matching relation exists.
    """
    stmt = _filter_pg_class_stmt(_pg_attribute_stmt, schema=schema, rel_name=table_name)
    rows = conn.execute(stmt).mappings().all()
    if not rows:
        # Zero column rows can also mean the relation exists with no live
        # columns (e.g. ``CREATE TABLE t ()``, or every column dropped), so
        # only report a missing object when the relation itself is absent.
        relation = conn.execute(
            _table_stmt(schema=schema, table_name=table_name)
        ).first()
        if relation is None:
            raise NoSuchObjectError(table_name)
    return [ColumnInfo(**t.cast("Mapping[str, Any]", row)) for row in rows]
def get_all_sequence_acls(
    conn: Connectable, schema: Optional[str] = None
) -> List[SchemaRelationInfo]:
    """Return privileges of all sequences.

    Args:
        conn: Connection or session to query.
        schema: If given, only sequences in this schema are returned;
            otherwise sequences from every schema are included.

    Returns:
        List of :class:`~.types.SchemaRelationInfo` objects.
    """
    mappings = conn.execute(_sequence_stmt(schema=schema)).mappings()
    return [SchemaRelationInfo(**t.cast("Mapping[str, Any]", m)) for m in mappings]
def get_sequence_acl(
    conn: Connectable, sequence: str, schema: Optional[str] = None
) -> SchemaRelationInfo:
    """Return privileges of a single sequence.

    Args:
        conn: Connection or session to query.
        sequence: Sequence name.
        schema: Schema containing the sequence. When omitted, the sequence
            must be visible in the search path.

    Returns:
        :class:`~.types.SchemaRelationInfo`

    Raises:
        NoSuchObjectError: if no matching sequence exists.
    """
    query = _sequence_stmt(schema=schema, sequence_name=sequence)
    found = conn.execute(query).mappings().one_or_none()
    if found is not None:
        return SchemaRelationInfo(**t.cast("Mapping[str, Any]", found))
    raise NoSuchObjectError(sequence)
def _make_canonical_type_function(conn: Connectable) -> None:
    """Create ``pg_temp.pg_grant_canonical_type``, which canonicalizes a type
    name, returning the input unchanged when casting to REGTYPE fails.

    Casting fails, for example, for the 'any' type (a syntax error), for a
    type that is not visible in the search path (undefined object), or for a
    name qualified with a nonexistent schema. Normal examples include
    'int4' -> 'integer'.

    Without catching the "not visible" case, a single function whose
    argument type lives in a schema outside the search path would abort the
    whole ``get_all_function_acls`` query.
    """
    # pg_temp is per-connection
    stmt = text(
        """
        CREATE OR REPLACE FUNCTION pg_temp.pg_grant_canonical_type(typname text)
        RETURNS text AS $$
        BEGIN
            BEGIN
                typname := typname::regtype::text;
            EXCEPTION
                WHEN syntax_error OR undefined_object OR invalid_schema_name THEN
            END;
            RETURN typname;
        END;
        $$ LANGUAGE plpgsql STABLE RETURNS NULL ON NULL INPUT;
        """
    )
    conn.execute(stmt)
def get_all_function_acls(
    conn: Connectable, schema: Optional[str] = None
) -> List[FunctionInfo]:
    """Return privileges of all functions.

    Args:
        conn: Connection or session to query. A temporary helper function is
            created on it first.
        schema: If given, only functions in this schema are returned;
            otherwise functions from every schema are included.

    Returns:
        List of :class:`~.types.FunctionInfo` objects.
    """
    # The query calls pg_temp.pg_grant_canonical_type, so it must exist first.
    _make_canonical_type_function(conn)
    mappings = conn.execute(_filter_pg_proc_stmt(schema=schema)).mappings()
    return [FunctionInfo(**t.cast("Mapping[str, Any]", m)) for m in mappings]
def get_function_acl(
    conn: Connectable,
    function_name: str,
    arg_types: ArgTypesInput,
    schema: Optional[str] = None,
) -> FunctionInfo:
    """Return privileges of the function with the given name and argument
    types.

    If `schema` is not given, the function must be visible in the search
    path.

    Args:
        conn: Connection or session to query.
        function_name: Function name.
        arg_types: Argument type names in declaration order, e.g. ``['text']``.
        schema: Schema containing the function.

    Returns:
        :class:`~.types.FunctionInfo`

    Raises:
        TypeError: if `arg_types` is missing, a plain string, or not a
            sequence.
        NoSuchObjectError: if no matching function exists.
    """
    # Validate input before touching the database, so bad arguments do not
    # first run CREATE FUNCTION on the caller's connection.
    if (function_name is None) != (arg_types is None):
        raise TypeError("function_name and arg_types must both be specified")
    if not isinstance(arg_types, Sequence) or isinstance(arg_types, str):
        raise TypeError("arg_types should be a sequence of strings, e.g. ['text']")

    # We could ask the user to register an event on their connection pool which
    # creates this function on checkout, but that isn't a nice API for the
    # common case.
    _make_canonical_type_function(conn)

    stmt = _filter_pg_proc_stmt(schema, function_name, arg_types)
    row = conn.execute(stmt).mappings().one_or_none()
    if row is None:
        raise NoSuchObjectError(function_name)
    return FunctionInfo(**t.cast("Mapping[str, Any]", row))
def get_all_language_acls(conn: Connectable) -> List[RelationInfo]:
    """Return privileges of every language in ``pg_language``.

    Args:
        conn: Connection or session to query.

    Returns:
        List of :class:`~.types.RelationInfo` objects.
    """
    mappings = conn.execute(_pg_lang_stmt).mappings()
    return [RelationInfo(**t.cast("Mapping[str, Any]", m)) for m in mappings]
def get_language_acl(conn: Connectable, language: str) -> RelationInfo:
    """Return privileges of the named language.

    Args:
        conn: Connection or session to query.
        language: Language name.

    Returns:
        :class:`~.types.RelationInfo`

    Raises:
        NoSuchObjectError: if no such language exists.
    """
    query = _pg_lang_stmt.where(pg_language.c.lanname == language)
    found = conn.execute(query).mappings().one_or_none()
    if found is not None:
        return RelationInfo(**t.cast("Mapping[str, Any]", found))
    raise NoSuchObjectError(language)
def get_all_schema_acls(conn: Connectable) -> List[RelationInfo]:
    """Return privileges of every schema.

    Args:
        conn: Connection or session to query.

    Returns:
        List of :class:`~.types.RelationInfo` objects.
    """
    mappings = conn.execute(_pg_schema_stmt).mappings()
    return [RelationInfo(**t.cast("Mapping[str, Any]", m)) for m in mappings]
def get_schema_acl(conn: Connectable, schema: str) -> RelationInfo:
    """Return privileges of the named schema.

    Args:
        conn: Connection or session to query.
        schema: Schema name.

    Returns:
        :class:`~.types.RelationInfo`

    Raises:
        NoSuchObjectError: if no such schema exists.
    """
    query = _pg_schema_stmt.where(pg_namespace.c.nspname == schema)
    found = conn.execute(query).mappings().one_or_none()
    if found is not None:
        return RelationInfo(**t.cast("Mapping[str, Any]", found))
    raise NoSuchObjectError(schema)
def get_all_database_acls(conn: Connectable) -> List[RelationInfo]:
    """Return privileges of every database.

    Args:
        conn: Connection or session to query.

    Returns:
        List of :class:`~.types.RelationInfo` objects.
    """
    mappings = conn.execute(_pg_db_stmt).mappings()
    return [RelationInfo(**t.cast("Mapping[str, Any]", m)) for m in mappings]
def get_database_acl(conn: Connectable, database: str) -> RelationInfo:
    """Return privileges of the named database.

    Args:
        conn: Connection or session to query.
        database: Database name.

    Returns:
        :class:`~.types.RelationInfo`

    Raises:
        NoSuchObjectError: if no such database exists.
    """
    query = _pg_db_stmt.where(pg_database.c.datname == database)
    found = conn.execute(query).mappings().one_or_none()
    if found is not None:
        return RelationInfo(**t.cast("Mapping[str, Any]", found))
    raise NoSuchObjectError(database)
def get_all_tablespace_acls(conn: Connectable) -> List[RelationInfo]:
    """Return privileges of every tablespace.

    Args:
        conn: Connection or session to query.

    Returns:
        List of :class:`~.types.RelationInfo` objects.
    """
    mappings = conn.execute(_pg_tablespace_stmt).mappings()
    return [RelationInfo(**t.cast("Mapping[str, Any]", m)) for m in mappings]
def get_tablespace_acl(conn: Connectable, tablespace: str) -> RelationInfo:
    """Return privileges of the named tablespace.

    Args:
        conn: Connection or session to query.
        tablespace: Tablespace name.

    Returns:
        :class:`~.types.RelationInfo`

    Raises:
        NoSuchObjectError: if no such tablespace exists.
    """
    query = _pg_tablespace_stmt.where(pg_tablespace.c.spcname == tablespace)
    found = conn.execute(query).mappings().one_or_none()
    if found is not None:
        return RelationInfo(**t.cast("Mapping[str, Any]", found))
    raise NoSuchObjectError(tablespace)
def get_all_type_acls(
    conn: Connectable, schema: Optional[str] = None
) -> List[SchemaRelationInfo]:
    """Return privileges of all types.

    Args:
        conn: Connection or session to query.
        schema: If given, only types in this schema are returned; otherwise
            types from every schema are included.

    Returns:
        List of :class:`~.types.SchemaRelationInfo` objects.
    """
    mappings = conn.execute(_filter_pg_type_stmt(schema=schema)).mappings()
    return [SchemaRelationInfo(**t.cast("Mapping[str, Any]", m)) for m in mappings]
def get_type_acl(
    conn: Connectable, type_name: str, schema: Optional[str] = None
) -> SchemaRelationInfo:
    """Return privileges of a single type.

    Args:
        conn: Connection or session to query.
        type_name: Type name.
        schema: Schema containing the type. When omitted, the type must be
            visible in the search path.

    Returns:
        :class:`~.types.SchemaRelationInfo`

    Raises:
        NoSuchObjectError: if no matching type exists.
    """
    query = _filter_pg_type_stmt(schema=schema, type_name=type_name)
    found = conn.execute(query).mappings().one_or_none()
    if found is not None:
        return SchemaRelationInfo(**t.cast("Mapping[str, Any]", found))
    raise NoSuchObjectError(type_name)
def get_all_parameter_acls(conn: Connectable) -> List[ParameterInfo]:
    """Return every configuration parameter that has non-default privileges.

    Args:
        conn: Connection or session to query.

    Returns:
        List of :class:`~.types.ParameterInfo` objects
    """
    mappings = conn.execute(_pg_parameter_stmt).mappings()
    return [ParameterInfo(**t.cast("Mapping[str, Any]", m)) for m in mappings]
def get_parameter_acl(conn: Connectable, parameter: str) -> Optional[ParameterInfo]:
    """Look up privileges of a single configuration parameter.

    Args:
        conn: Connection or session to query.
        parameter: Parameter name.

    Returns:
        :class:`~.types.ParameterInfo` if the parameter exists and has
        non-default privileges, otherwise ``None`` (no exception is raised).
    """
    query = _pg_parameter_stmt.where(pg_parameter_acl.c.parname == parameter)
    found = conn.execute(query).mappings().one_or_none()
    return None if found is None else ParameterInfo(**t.cast("Mapping[str, Any]", found))