mirror of
https://github.com/django/django.git
synced 2025-10-31 09:41:08 +00:00
Fixed #33308 -- Added support for psycopg version 3.
Thanks Simon Charette, Tim Graham, and Adam Johnson for reviews. Co-authored-by: Florian Apolloner <florian@apolloner.eu> Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
This commit is contained in:
committed by
Mariusz Felisiak
parent
d44ee518c4
commit
09ffc5c121
@@ -1,8 +1,6 @@
|
||||
"""
|
||||
This object provides quoting for GEOS geometries into PostgreSQL/PostGIS.
|
||||
"""
|
||||
from psycopg2.extensions import ISQLQuote
|
||||
|
||||
from django.contrib.gis.db.backends.postgis.pgraster import to_pgraster
|
||||
from django.contrib.gis.geos import GEOSGeometry
|
||||
from django.db.backends.postgresql.psycopg_any import sql
|
||||
@@ -27,6 +25,8 @@ class PostGISAdapter:
|
||||
|
||||
def __conform__(self, proto):
|
||||
"""Does the given protocol conform to what Psycopg2 expects?"""
|
||||
from psycopg2.extensions import ISQLQuote
|
||||
|
||||
if proto == ISQLQuote:
|
||||
return self
|
||||
else:
|
||||
|
||||
@@ -1,17 +1,93 @@
|
||||
from django.db.backends.base.base import NO_DB_ALIAS
|
||||
from django.db.backends.postgresql.base import (
|
||||
DatabaseWrapper as Psycopg2DatabaseWrapper,
|
||||
)
|
||||
from functools import lru_cache
|
||||
|
||||
from django.db.backends.base.base import NO_DB_ALIAS
|
||||
from django.db.backends.postgresql.base import DatabaseWrapper as PsycopgDatabaseWrapper
|
||||
from django.db.backends.postgresql.psycopg_any import is_psycopg3
|
||||
|
||||
from .adapter import PostGISAdapter
|
||||
from .features import DatabaseFeatures
|
||||
from .introspection import PostGISIntrospection
|
||||
from .operations import PostGISOperations
|
||||
from .schema import PostGISSchemaEditor
|
||||
|
||||
if is_psycopg3:
|
||||
from psycopg.adapt import Dumper
|
||||
from psycopg.pq import Format
|
||||
from psycopg.types import TypeInfo
|
||||
from psycopg.types.string import TextBinaryLoader, TextLoader
|
||||
|
||||
class DatabaseWrapper(Psycopg2DatabaseWrapper):
|
||||
class GeometryType:
|
||||
pass
|
||||
|
||||
class GeographyType:
|
||||
pass
|
||||
|
||||
class RasterType:
|
||||
pass
|
||||
|
||||
class BaseTextDumper(Dumper):
|
||||
def dump(self, obj):
|
||||
# Return bytes as hex for text formatting
|
||||
return obj.ewkb.hex().encode()
|
||||
|
||||
class BaseBinaryDumper(Dumper):
|
||||
format = Format.BINARY
|
||||
|
||||
def dump(self, obj):
|
||||
return obj.ewkb
|
||||
|
||||
@lru_cache
|
||||
def postgis_adapters(geo_oid, geog_oid, raster_oid):
|
||||
class BaseDumper(Dumper):
|
||||
def __init_subclass__(cls, base_dumper):
|
||||
super().__init_subclass__()
|
||||
|
||||
cls.GeometryDumper = type(
|
||||
"GeometryDumper", (base_dumper,), {"oid": geo_oid}
|
||||
)
|
||||
cls.GeographyDumper = type(
|
||||
"GeographyDumper", (base_dumper,), {"oid": geog_oid}
|
||||
)
|
||||
cls.RasterDumper = type(
|
||||
"RasterDumper", (BaseTextDumper,), {"oid": raster_oid}
|
||||
)
|
||||
|
||||
def get_key(self, obj, format):
|
||||
if obj.is_geometry:
|
||||
return GeographyType if obj.geography else GeometryType
|
||||
else:
|
||||
return RasterType
|
||||
|
||||
def upgrade(self, obj, format):
|
||||
if obj.is_geometry:
|
||||
if obj.geography:
|
||||
return self.GeographyDumper(GeographyType)
|
||||
else:
|
||||
return self.GeometryDumper(GeometryType)
|
||||
else:
|
||||
return self.RasterDumper(RasterType)
|
||||
|
||||
def dump(self, obj):
|
||||
raise NotImplementedError
|
||||
|
||||
class PostGISTextDumper(BaseDumper, base_dumper=BaseTextDumper):
|
||||
pass
|
||||
|
||||
class PostGISBinaryDumper(BaseDumper, base_dumper=BaseBinaryDumper):
|
||||
format = Format.BINARY
|
||||
|
||||
return PostGISTextDumper, PostGISBinaryDumper
|
||||
|
||||
|
||||
class DatabaseWrapper(PsycopgDatabaseWrapper):
|
||||
SchemaEditorClass = PostGISSchemaEditor
|
||||
|
||||
_type_infos = {
|
||||
"geometry": {},
|
||||
"geography": {},
|
||||
"raster": {},
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if kwargs.get("alias", "") != NO_DB_ALIAS:
|
||||
@@ -27,3 +103,45 @@ class DatabaseWrapper(Psycopg2DatabaseWrapper):
|
||||
if bool(cursor.fetchone()):
|
||||
return
|
||||
cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis")
|
||||
if is_psycopg3:
|
||||
# Ensure adapters are registers if PostGIS is used within this
|
||||
# connection.
|
||||
self.register_geometry_adapters(self.connection, True)
|
||||
|
||||
def get_new_connection(self, conn_params):
|
||||
connection = super().get_new_connection(conn_params)
|
||||
if is_psycopg3:
|
||||
self.register_geometry_adapters(connection)
|
||||
return connection
|
||||
|
||||
if is_psycopg3:
|
||||
|
||||
def _register_type(self, pg_connection, typename):
|
||||
registry = self._type_infos[typename]
|
||||
try:
|
||||
info = registry[self.alias]
|
||||
except KeyError:
|
||||
info = TypeInfo.fetch(pg_connection, typename)
|
||||
registry[self.alias] = info
|
||||
|
||||
if info: # Can be None if the type does not exist (yet).
|
||||
info.register(pg_connection)
|
||||
pg_connection.adapters.register_loader(info.oid, TextLoader)
|
||||
pg_connection.adapters.register_loader(info.oid, TextBinaryLoader)
|
||||
|
||||
return info.oid if info else None
|
||||
|
||||
def register_geometry_adapters(self, pg_connection, clear_caches=False):
|
||||
if clear_caches:
|
||||
for typename in self._type_infos:
|
||||
self._type_infos[typename].pop(self.alias, None)
|
||||
|
||||
geo_oid = self._register_type(pg_connection, "geometry")
|
||||
geog_oid = self._register_type(pg_connection, "geography")
|
||||
raster_oid = self._register_type(pg_connection, "raster")
|
||||
|
||||
PostGISTextDumper, PostGISBinaryDumper = postgis_adapters(
|
||||
geo_oid, geog_oid, raster_oid
|
||||
)
|
||||
pg_connection.adapters.register_dumper(PostGISAdapter, PostGISTextDumper)
|
||||
pg_connection.adapters.register_dumper(PostGISAdapter, PostGISBinaryDumper)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from django.contrib.gis.db.backends.base.features import BaseSpatialFeatures
|
||||
from django.db.backends.postgresql.features import (
|
||||
DatabaseFeatures as Psycopg2DatabaseFeatures,
|
||||
DatabaseFeatures as PsycopgDatabaseFeatures,
|
||||
)
|
||||
|
||||
|
||||
class DatabaseFeatures(BaseSpatialFeatures, Psycopg2DatabaseFeatures):
|
||||
class DatabaseFeatures(BaseSpatialFeatures, PsycopgDatabaseFeatures):
|
||||
supports_geography = True
|
||||
supports_3d_storage = True
|
||||
supports_3d_functions = True
|
||||
|
||||
@@ -11,6 +11,7 @@ from django.contrib.gis.measure import Distance
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import NotSupportedError, ProgrammingError
|
||||
from django.db.backends.postgresql.operations import DatabaseOperations
|
||||
from django.db.backends.postgresql.psycopg_any import is_psycopg3
|
||||
from django.db.models import Func, Value
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.version import get_version_tuple
|
||||
@@ -161,7 +162,8 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
||||
|
||||
unsupported_functions = set()
|
||||
|
||||
select = "%s::bytea"
|
||||
select = "%s" if is_psycopg3 else "%s::bytea"
|
||||
|
||||
select_extent = None
|
||||
|
||||
@cached_property
|
||||
@@ -407,6 +409,8 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
||||
geom_class = expression.output_field.geom_class
|
||||
|
||||
def converter(value, expression, connection):
|
||||
if isinstance(value, str): # Coming from hex strings.
|
||||
value = value.encode("ascii")
|
||||
return None if value is None else GEOSGeometryBase(read(value), geom_class)
|
||||
|
||||
return converter
|
||||
|
||||
@@ -237,7 +237,7 @@ class ArrayField(CheckFieldDefaultMixin, Field):
|
||||
|
||||
class ArrayRHSMixin:
|
||||
def __init__(self, lhs, rhs):
|
||||
# Don't wrap arrays that contains only None values, psycopg2 doesn't
|
||||
# Don't wrap arrays that contains only None values, psycopg doesn't
|
||||
# allow this.
|
||||
if isinstance(rhs, (tuple, list)) and any(self._rhs_not_none_values(rhs)):
|
||||
expressions = []
|
||||
|
||||
@@ -9,6 +9,7 @@ from django.db.backends.postgresql.psycopg_any import (
|
||||
NumericRange,
|
||||
Range,
|
||||
)
|
||||
from django.db.models.functions import Cast
|
||||
from django.db.models.lookups import PostgresOperatorLookup
|
||||
|
||||
from .utils import AttributeSetter
|
||||
@@ -208,7 +209,14 @@ class DateRangeField(RangeField):
|
||||
return "daterange"
|
||||
|
||||
|
||||
RangeField.register_lookup(lookups.DataContains)
|
||||
class RangeContains(lookups.DataContains):
|
||||
def get_prep_lookup(self):
|
||||
if not isinstance(self.rhs, (list, tuple, Range)):
|
||||
return Cast(self.rhs, self.lhs.field.base_field)
|
||||
return super().get_prep_lookup()
|
||||
|
||||
|
||||
RangeField.register_lookup(RangeContains)
|
||||
RangeField.register_lookup(lookups.ContainedBy)
|
||||
RangeField.register_lookup(lookups.Overlap)
|
||||
|
||||
|
||||
@@ -35,6 +35,10 @@ class CreateExtension(Operation):
|
||||
# installed, otherwise a subsequent data migration would use the same
|
||||
# connection.
|
||||
register_type_handlers(schema_editor.connection)
|
||||
if hasattr(schema_editor.connection, "register_geometry_adapters"):
|
||||
schema_editor.connection.register_geometry_adapters(
|
||||
schema_editor.connection.connection, True
|
||||
)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
if not router.allow_migrate(schema_editor.connection.alias, app_label):
|
||||
|
||||
@@ -39,6 +39,11 @@ class SearchQueryField(Field):
|
||||
return "tsquery"
|
||||
|
||||
|
||||
class _Float4Field(Field):
|
||||
def db_type(self, connection):
|
||||
return "float4"
|
||||
|
||||
|
||||
class SearchConfig(Expression):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
@@ -138,7 +143,11 @@ class SearchVector(SearchVectorCombinable, Func):
|
||||
if clone.weight:
|
||||
weight_sql, extra_params = compiler.compile(clone.weight)
|
||||
sql = "setweight({}, {})".format(sql, weight_sql)
|
||||
return sql, config_params + params + extra_params
|
||||
|
||||
# These parameters must be bound on the client side because we may
|
||||
# want to create an index on this expression.
|
||||
sql = connection.ops.compose_sql(sql, config_params + params + extra_params)
|
||||
return sql, []
|
||||
|
||||
|
||||
class CombinedSearchVector(SearchVectorCombinable, CombinedExpression):
|
||||
@@ -244,6 +253,8 @@ class SearchRank(Func):
|
||||
normalization=None,
|
||||
cover_density=False,
|
||||
):
|
||||
from .fields.array import ArrayField
|
||||
|
||||
if not hasattr(vector, "resolve_expression"):
|
||||
vector = SearchVector(vector)
|
||||
if not hasattr(query, "resolve_expression"):
|
||||
@@ -252,6 +263,7 @@ class SearchRank(Func):
|
||||
if weights is not None:
|
||||
if not hasattr(weights, "resolve_expression"):
|
||||
weights = Value(weights)
|
||||
weights = Cast(weights, ArrayField(_Float4Field()))
|
||||
expressions = (weights,) + expressions
|
||||
if normalization is not None:
|
||||
if not hasattr(normalization, "resolve_expression"):
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import functools
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import register_hstore
|
||||
|
||||
from django.db import connections
|
||||
from django.db.backends.base.base import NO_DB_ALIAS
|
||||
from django.db.backends.postgresql.psycopg_any import is_psycopg3
|
||||
|
||||
|
||||
def get_type_oids(connection_alias, type_name):
|
||||
@@ -32,30 +30,51 @@ def get_citext_oids(connection_alias):
|
||||
return get_type_oids(connection_alias, "citext")
|
||||
|
||||
|
||||
def register_type_handlers(connection, **kwargs):
|
||||
if connection.vendor != "postgresql" or connection.alias == NO_DB_ALIAS:
|
||||
return
|
||||
if is_psycopg3:
|
||||
from psycopg.types import TypeInfo, hstore
|
||||
|
||||
oids, array_oids = get_hstore_oids(connection.alias)
|
||||
# Don't register handlers when hstore is not available on the database.
|
||||
#
|
||||
# If someone tries to create an hstore field it will error there. This is
|
||||
# necessary as someone may be using PSQL without extensions installed but
|
||||
# be using other features of contrib.postgres.
|
||||
#
|
||||
# This is also needed in order to create the connection in order to install
|
||||
# the hstore extension.
|
||||
if oids:
|
||||
register_hstore(
|
||||
connection.connection, globally=True, oid=oids, array_oid=array_oids
|
||||
)
|
||||
def register_type_handlers(connection, **kwargs):
|
||||
if connection.vendor != "postgresql" or connection.alias == NO_DB_ALIAS:
|
||||
return
|
||||
|
||||
oids, citext_oids = get_citext_oids(connection.alias)
|
||||
# Don't register handlers when citext is not available on the database.
|
||||
#
|
||||
# The same comments in the above call to register_hstore() also apply here.
|
||||
if oids:
|
||||
array_type = psycopg2.extensions.new_array_type(
|
||||
citext_oids, "citext[]", psycopg2.STRING
|
||||
)
|
||||
psycopg2.extensions.register_type(array_type, None)
|
||||
oids, array_oids = get_hstore_oids(connection.alias)
|
||||
for oid, array_oid in zip(oids, array_oids):
|
||||
ti = TypeInfo("hstore", oid, array_oid)
|
||||
hstore.register_hstore(ti, connection.connection)
|
||||
|
||||
_, citext_oids = get_citext_oids(connection.alias)
|
||||
for array_oid in citext_oids:
|
||||
ti = TypeInfo("citext", 0, array_oid)
|
||||
ti.register(connection.connection)
|
||||
|
||||
else:
|
||||
import psycopg2
|
||||
from psycopg2.extras import register_hstore
|
||||
|
||||
def register_type_handlers(connection, **kwargs):
|
||||
if connection.vendor != "postgresql" or connection.alias == NO_DB_ALIAS:
|
||||
return
|
||||
|
||||
oids, array_oids = get_hstore_oids(connection.alias)
|
||||
# Don't register handlers when hstore is not available on the database.
|
||||
#
|
||||
# If someone tries to create an hstore field it will error there. This is
|
||||
# necessary as someone may be using PSQL without extensions installed but
|
||||
# be using other features of contrib.postgres.
|
||||
#
|
||||
# This is also needed in order to create the connection in order to install
|
||||
# the hstore extension.
|
||||
if oids:
|
||||
register_hstore(
|
||||
connection.connection, globally=True, oid=oids, array_oid=array_oids
|
||||
)
|
||||
|
||||
oids, citext_oids = get_citext_oids(connection.alias)
|
||||
# Don't register handlers when citext is not available on the database.
|
||||
#
|
||||
# The same comments in the above call to register_hstore() also apply here.
|
||||
if oids:
|
||||
array_type = psycopg2.extensions.new_array_type(
|
||||
citext_oids, "citext[]", psycopg2.STRING
|
||||
)
|
||||
psycopg2.extensions.register_type(array_type, None)
|
||||
|
||||
@@ -207,7 +207,7 @@ class Command(BaseCommand):
|
||||
self.models.add(obj.object.__class__)
|
||||
try:
|
||||
obj.save(using=self.using)
|
||||
# psycopg2 raises ValueError if data contains NUL chars.
|
||||
# psycopg raises ValueError if data contains NUL chars.
|
||||
except (DatabaseError, IntegrityError, ValueError) as e:
|
||||
e.args = (
|
||||
"Could not load %(object_label)s(pk=%(pk)s): %(error_msg)s"
|
||||
|
||||
@@ -164,6 +164,8 @@ class BaseDatabaseFeatures:
|
||||
# Can we roll back DDL in a transaction?
|
||||
can_rollback_ddl = False
|
||||
|
||||
schema_editor_uses_clientside_param_binding = False
|
||||
|
||||
# Does it support operations requiring references rename in a transaction?
|
||||
supports_atomic_references_rename = True
|
||||
|
||||
@@ -335,6 +337,9 @@ class BaseDatabaseFeatures:
|
||||
# Does the backend support the logical XOR operator?
|
||||
supports_logical_xor = False
|
||||
|
||||
# Set to (exception, message) if null characters in text are disallowed.
|
||||
prohibits_null_characters_in_text_exception = None
|
||||
|
||||
# Collation names for use by the Django test suite.
|
||||
test_collations = {
|
||||
"ci": None, # Case-insensitive.
|
||||
|
||||
@@ -525,6 +525,9 @@ class BaseDatabaseOperations:
|
||||
else:
|
||||
return value
|
||||
|
||||
def adapt_integerfield_value(self, value, internal_type):
|
||||
return value
|
||||
|
||||
def adapt_datefield_value(self, value):
|
||||
"""
|
||||
Transform a date value to an object compatible with what is expected
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
PostgreSQL database backend for Django.
|
||||
|
||||
Requires psycopg 2: https://www.psycopg.org/
|
||||
Requires psycopg2 >= 2.8.4 or psycopg >= 3.1
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -21,48 +21,63 @@ from django.utils.safestring import SafeString
|
||||
from django.utils.version import get_version_tuple
|
||||
|
||||
try:
|
||||
import psycopg2 as Database
|
||||
import psycopg2.extensions
|
||||
import psycopg2.extras
|
||||
except ImportError as e:
|
||||
raise ImproperlyConfigured("Error loading psycopg2 module: %s" % e)
|
||||
try:
|
||||
import psycopg as Database
|
||||
except ImportError:
|
||||
import psycopg2 as Database
|
||||
except ImportError:
|
||||
raise ImproperlyConfigured("Error loading psycopg2 or psycopg module")
|
||||
|
||||
|
||||
def psycopg2_version():
|
||||
version = psycopg2.__version__.split(" ", 1)[0]
|
||||
def psycopg_version():
|
||||
version = Database.__version__.split(" ", 1)[0]
|
||||
return get_version_tuple(version)
|
||||
|
||||
|
||||
PSYCOPG2_VERSION = psycopg2_version()
|
||||
|
||||
if PSYCOPG2_VERSION < (2, 8, 4):
|
||||
if psycopg_version() < (2, 8, 4):
|
||||
raise ImproperlyConfigured(
|
||||
"psycopg2 version 2.8.4 or newer is required; you have %s"
|
||||
% psycopg2.__version__
|
||||
f"psycopg2 version 2.8.4 or newer is required; you have {Database.__version__}"
|
||||
)
|
||||
if (3,) <= psycopg_version() < (3, 1):
|
||||
raise ImproperlyConfigured(
|
||||
f"psycopg version 3.1 or newer is required; you have {Database.__version__}"
|
||||
)
|
||||
|
||||
|
||||
# Some of these import psycopg2, so import them after checking if it's installed.
|
||||
from .client import DatabaseClient # NOQA
|
||||
from .creation import DatabaseCreation # NOQA
|
||||
from .features import DatabaseFeatures # NOQA
|
||||
from .introspection import DatabaseIntrospection # NOQA
|
||||
from .operations import DatabaseOperations # NOQA
|
||||
from .psycopg_any import IsolationLevel # NOQA
|
||||
from .schema import DatabaseSchemaEditor # NOQA
|
||||
from .psycopg_any import IsolationLevel, is_psycopg3 # NOQA isort:skip
|
||||
|
||||
psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
|
||||
psycopg2.extras.register_uuid()
|
||||
if is_psycopg3:
|
||||
from psycopg import adapters, sql
|
||||
from psycopg.pq import Format
|
||||
|
||||
# Register support for inet[] manually so we don't have to handle the Inet()
|
||||
# object on load all the time.
|
||||
INETARRAY_OID = 1041
|
||||
INETARRAY = psycopg2.extensions.new_array_type(
|
||||
(INETARRAY_OID,),
|
||||
"INETARRAY",
|
||||
psycopg2.extensions.UNICODE,
|
||||
)
|
||||
psycopg2.extensions.register_type(INETARRAY)
|
||||
from .psycopg_any import get_adapters_template, register_tzloader
|
||||
|
||||
TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
|
||||
|
||||
else:
|
||||
import psycopg2.extensions
|
||||
import psycopg2.extras
|
||||
|
||||
psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
|
||||
psycopg2.extras.register_uuid()
|
||||
|
||||
# Register support for inet[] manually so we don't have to handle the Inet()
|
||||
# object on load all the time.
|
||||
INETARRAY_OID = 1041
|
||||
INETARRAY = psycopg2.extensions.new_array_type(
|
||||
(INETARRAY_OID,),
|
||||
"INETARRAY",
|
||||
psycopg2.extensions.UNICODE,
|
||||
)
|
||||
psycopg2.extensions.register_type(INETARRAY)
|
||||
|
||||
# Some of these import psycopg, so import them after checking if it's installed.
|
||||
from .client import DatabaseClient # NOQA isort:skip
|
||||
from .creation import DatabaseCreation # NOQA isort:skip
|
||||
from .features import DatabaseFeatures # NOQA isort:skip
|
||||
from .introspection import DatabaseIntrospection # NOQA isort:skip
|
||||
from .operations import DatabaseOperations # NOQA isort:skip
|
||||
from .schema import DatabaseSchemaEditor # NOQA isort:skip
|
||||
|
||||
|
||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
@@ -209,6 +224,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
conn_params["host"] = settings_dict["HOST"]
|
||||
if settings_dict["PORT"]:
|
||||
conn_params["port"] = settings_dict["PORT"]
|
||||
if is_psycopg3:
|
||||
conn_params["context"] = get_adapters_template(
|
||||
settings.USE_TZ, self.timezone
|
||||
)
|
||||
# Disable prepared statements by default to keep connection poolers
|
||||
# working. Can be reenabled via OPTIONS in the settings dict.
|
||||
conn_params["prepare_threshold"] = conn_params.pop(
|
||||
"prepare_threshold", None
|
||||
)
|
||||
return conn_params
|
||||
|
||||
@async_unsafe
|
||||
@@ -232,17 +256,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
except ValueError:
|
||||
raise ImproperlyConfigured(
|
||||
f"Invalid transaction isolation level {isolation_level_value} "
|
||||
f"specified. Use one of the IsolationLevel values."
|
||||
f"specified. Use one of the psycopg.IsolationLevel values."
|
||||
)
|
||||
connection = Database.connect(**conn_params)
|
||||
connection = self.Database.connect(**conn_params)
|
||||
if set_isolation_level:
|
||||
connection.isolation_level = self.isolation_level
|
||||
# Register dummy loads() to avoid a round trip from psycopg2's decode
|
||||
# to json.dumps() to json.loads(), when using a custom decoder in
|
||||
# JSONField.
|
||||
psycopg2.extras.register_default_jsonb(
|
||||
conn_or_curs=connection, loads=lambda x: x
|
||||
)
|
||||
if not is_psycopg3:
|
||||
# Register dummy loads() to avoid a round trip from psycopg2's
|
||||
# decode to json.dumps() to json.loads(), when using a custom
|
||||
# decoder in JSONField.
|
||||
psycopg2.extras.register_default_jsonb(
|
||||
conn_or_curs=connection, loads=lambda x: x
|
||||
)
|
||||
connection.cursor_factory = Cursor
|
||||
return connection
|
||||
|
||||
def ensure_timezone(self):
|
||||
@@ -275,7 +301,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
)
|
||||
else:
|
||||
cursor = self.connection.cursor()
|
||||
cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
|
||||
|
||||
if is_psycopg3:
|
||||
# Register the cursor timezone only if the connection disagrees, to
|
||||
# avoid copying the adapter map.
|
||||
tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
|
||||
if self.timezone != tzloader.timezone:
|
||||
register_tzloader(self.timezone, cursor)
|
||||
else:
|
||||
cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
|
||||
return cursor
|
||||
|
||||
def tzinfo_factory(self, offset):
|
||||
@@ -379,11 +413,43 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
return CursorDebugWrapper(cursor, self)
|
||||
|
||||
|
||||
class CursorDebugWrapper(BaseCursorDebugWrapper):
|
||||
def copy_expert(self, sql, file, *args):
|
||||
with self.debug_sql(sql):
|
||||
return self.cursor.copy_expert(sql, file, *args)
|
||||
if is_psycopg3:
|
||||
|
||||
def copy_to(self, file, table, *args, **kwargs):
|
||||
with self.debug_sql(sql="COPY %s TO STDOUT" % table):
|
||||
return self.cursor.copy_to(file, table, *args, **kwargs)
|
||||
class Cursor(Database.Cursor):
|
||||
"""
|
||||
A subclass of psycopg cursor implementing callproc.
|
||||
"""
|
||||
|
||||
def callproc(self, name, args=None):
|
||||
if not isinstance(name, sql.Identifier):
|
||||
name = sql.Identifier(name)
|
||||
|
||||
qparts = [sql.SQL("SELECT * FROM "), name, sql.SQL("(")]
|
||||
if args:
|
||||
for item in args:
|
||||
qparts.append(sql.Literal(item))
|
||||
qparts.append(sql.SQL(","))
|
||||
del qparts[-1]
|
||||
|
||||
qparts.append(sql.SQL(")"))
|
||||
stmt = sql.Composed(qparts)
|
||||
self.execute(stmt)
|
||||
return args
|
||||
|
||||
class CursorDebugWrapper(BaseCursorDebugWrapper):
|
||||
def copy(self, statement):
|
||||
with self.debug_sql(statement):
|
||||
return self.cursor.copy(statement)
|
||||
|
||||
else:
|
||||
|
||||
Cursor = psycopg2.extensions.cursor
|
||||
|
||||
class CursorDebugWrapper(BaseCursorDebugWrapper):
|
||||
def copy_expert(self, sql, file, *args):
|
||||
with self.debug_sql(sql):
|
||||
return self.cursor.copy_expert(sql, file, *args)
|
||||
|
||||
def copy_to(self, file, table, *args, **kwargs):
|
||||
with self.debug_sql(sql="COPY %s TO STDOUT" % table):
|
||||
return self.cursor.copy_to(file, table, *args, **kwargs)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import operator
|
||||
|
||||
from django.db import InterfaceError
|
||||
from django.db import DataError, InterfaceError
|
||||
from django.db.backends.base.features import BaseDatabaseFeatures
|
||||
from django.db.backends.postgresql.psycopg_any import is_psycopg3
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
@@ -26,6 +27,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
can_introspect_materialized_views = True
|
||||
can_distinct_on_fields = True
|
||||
can_rollback_ddl = True
|
||||
schema_editor_uses_clientside_param_binding = True
|
||||
supports_combined_alters = True
|
||||
nulls_order_largest = True
|
||||
closed_cursor_error_class = InterfaceError
|
||||
@@ -81,6 +83,13 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
},
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def prohibits_null_characters_in_text_exception(self):
|
||||
if is_psycopg3:
|
||||
return DataError, "PostgreSQL text fields cannot contain NUL (0x00) bytes"
|
||||
else:
|
||||
return ValueError, "A string literal cannot contain NUL (0x00) characters."
|
||||
|
||||
@cached_property
|
||||
def introspected_field_types(self):
|
||||
return {
|
||||
|
||||
@@ -3,9 +3,16 @@ from functools import lru_cache, partial
|
||||
|
||||
from django.conf import settings
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
from django.db.backends.postgresql.psycopg_any import Inet, Jsonb, mogrify
|
||||
from django.db.backends.postgresql.psycopg_any import (
|
||||
Inet,
|
||||
Jsonb,
|
||||
errors,
|
||||
is_psycopg3,
|
||||
mogrify,
|
||||
)
|
||||
from django.db.backends.utils import split_tzname_delta
|
||||
from django.db.models.constants import OnConflict
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
|
||||
@lru_cache
|
||||
@@ -36,6 +43,18 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
"SmallAutoField": "smallint",
|
||||
}
|
||||
|
||||
if is_psycopg3:
|
||||
from psycopg.types import numeric
|
||||
|
||||
integerfield_type_map = {
|
||||
"SmallIntegerField": numeric.Int2,
|
||||
"IntegerField": numeric.Int4,
|
||||
"BigIntegerField": numeric.Int8,
|
||||
"PositiveSmallIntegerField": numeric.Int2,
|
||||
"PositiveIntegerField": numeric.Int4,
|
||||
"PositiveBigIntegerField": numeric.Int8,
|
||||
}
|
||||
|
||||
def unification_cast_sql(self, output_field):
|
||||
internal_type = output_field.get_internal_type()
|
||||
if internal_type in (
|
||||
@@ -56,19 +75,23 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
)
|
||||
return "%s"
|
||||
|
||||
# EXTRACT format cannot be passed in parameters.
|
||||
_extract_format_re = _lazy_re_compile(r"[A-Z_]+")
|
||||
|
||||
def date_extract_sql(self, lookup_type, sql, params):
|
||||
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
|
||||
extract_sql = f"EXTRACT(%s FROM {sql})"
|
||||
extract_param = lookup_type
|
||||
if lookup_type == "week_day":
|
||||
# For consistency across backends, we return Sunday=1, Saturday=7.
|
||||
extract_sql = f"EXTRACT(%s FROM {sql}) + 1"
|
||||
extract_param = "dow"
|
||||
return f"EXTRACT(DOW FROM {sql}) + 1", params
|
||||
elif lookup_type == "iso_week_day":
|
||||
extract_param = "isodow"
|
||||
return f"EXTRACT(ISODOW FROM {sql})", params
|
||||
elif lookup_type == "iso_year":
|
||||
extract_param = "isoyear"
|
||||
return extract_sql, (extract_param, *params)
|
||||
return f"EXTRACT(ISOYEAR FROM {sql})", params
|
||||
|
||||
lookup_type = lookup_type.upper()
|
||||
if not self._extract_format_re.fullmatch(lookup_type):
|
||||
raise ValueError(f"Invalid lookup type: {lookup_type!r}")
|
||||
return f"EXTRACT({lookup_type} FROM {sql})", params
|
||||
|
||||
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
@@ -100,10 +123,7 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
if lookup_type == "second":
|
||||
# Truncate fractional seconds.
|
||||
return (
|
||||
f"EXTRACT(%s FROM DATE_TRUNC(%s, {sql}))",
|
||||
("second", "second", *params),
|
||||
)
|
||||
return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
|
||||
return self.date_extract_sql(lookup_type, sql, params)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
|
||||
@@ -114,10 +134,7 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
def time_extract_sql(self, lookup_type, sql, params):
|
||||
if lookup_type == "second":
|
||||
# Truncate fractional seconds.
|
||||
return (
|
||||
f"EXTRACT(%s FROM DATE_TRUNC(%s, {sql}))",
|
||||
("second", "second", *params),
|
||||
)
|
||||
return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
|
||||
return self.date_extract_sql(lookup_type, sql, params)
|
||||
|
||||
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
@@ -137,6 +154,16 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
def lookup_cast(self, lookup_type, internal_type=None):
|
||||
lookup = "%s"
|
||||
|
||||
if lookup_type == "isnull" and internal_type in (
|
||||
"CharField",
|
||||
"EmailField",
|
||||
"TextField",
|
||||
"CICharField",
|
||||
"CIEmailField",
|
||||
"CITextField",
|
||||
):
|
||||
return "%s::text"
|
||||
|
||||
# Cast text lookups to text to allow things like filter(x__contains=4)
|
||||
if lookup_type in (
|
||||
"iexact",
|
||||
@@ -178,7 +205,7 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
return mogrify(sql, params, self.connection)
|
||||
|
||||
def set_time_zone_sql(self):
|
||||
return "SET TIME ZONE %s"
|
||||
return "SELECT set_config('TimeZone', %s, false)"
|
||||
|
||||
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
|
||||
if not tables:
|
||||
@@ -278,12 +305,22 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
else:
|
||||
return ["DISTINCT"], []
|
||||
|
||||
def last_executed_query(self, cursor, sql, params):
|
||||
# https://www.psycopg.org/docs/cursor.html#cursor.query
|
||||
# The query attribute is a Psycopg extension to the DB API 2.0.
|
||||
if cursor.query is not None:
|
||||
return cursor.query.decode()
|
||||
return None
|
||||
if is_psycopg3:
|
||||
|
||||
def last_executed_query(self, cursor, sql, params):
|
||||
try:
|
||||
return self.compose_sql(sql, params)
|
||||
except errors.DataError:
|
||||
return None
|
||||
|
||||
else:
|
||||
|
||||
def last_executed_query(self, cursor, sql, params):
|
||||
# https://www.psycopg.org/docs/cursor.html#cursor.query
|
||||
# The query attribute is a Psycopg extension to the DB API 2.0.
|
||||
if cursor.query is not None:
|
||||
return cursor.query.decode()
|
||||
return None
|
||||
|
||||
def return_insert_columns(self, fields):
|
||||
if not fields:
|
||||
@@ -303,6 +340,13 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
|
||||
return "VALUES " + values_sql
|
||||
|
||||
if is_psycopg3:
|
||||
|
||||
def adapt_integerfield_value(self, value, internal_type):
|
||||
if value is None or hasattr(value, "resolve_expression"):
|
||||
return value
|
||||
return self.integerfield_type_map[internal_type](value)
|
||||
|
||||
def adapt_datefield_value(self, value):
|
||||
return value
|
||||
|
||||
|
||||
@@ -1,31 +1,102 @@
|
||||
from enum import IntEnum
|
||||
import ipaddress
|
||||
from functools import lru_cache
|
||||
|
||||
from psycopg2 import errors, extensions, sql # NOQA
|
||||
from psycopg2.extras import DateRange, DateTimeRange, DateTimeTZRange, Inet # NOQA
|
||||
from psycopg2.extras import Json as Jsonb # NOQA
|
||||
from psycopg2.extras import NumericRange, Range # NOQA
|
||||
try:
|
||||
from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors, sql
|
||||
from psycopg.postgres import types
|
||||
from psycopg.types.datetime import TimestamptzLoader
|
||||
from psycopg.types.json import Jsonb
|
||||
from psycopg.types.range import Range, RangeDumper
|
||||
from psycopg.types.string import TextLoader
|
||||
|
||||
RANGE_TYPES = (DateRange, DateTimeRange, DateTimeTZRange, NumericRange)
|
||||
Inet = ipaddress.ip_address
|
||||
|
||||
DateRange = DateTimeRange = DateTimeTZRange = NumericRange = Range
|
||||
RANGE_TYPES = (Range,)
|
||||
|
||||
class IsolationLevel(IntEnum):
|
||||
READ_UNCOMMITTED = extensions.ISOLATION_LEVEL_READ_UNCOMMITTED
|
||||
READ_COMMITTED = extensions.ISOLATION_LEVEL_READ_COMMITTED
|
||||
REPEATABLE_READ = extensions.ISOLATION_LEVEL_REPEATABLE_READ
|
||||
SERIALIZABLE = extensions.ISOLATION_LEVEL_SERIALIZABLE
|
||||
TSRANGE_OID = types["tsrange"].oid
|
||||
TSTZRANGE_OID = types["tstzrange"].oid
|
||||
|
||||
def mogrify(sql, params, connection):
|
||||
return ClientCursor(connection.connection).mogrify(sql, params)
|
||||
|
||||
def _quote(value, connection=None):
|
||||
adapted = extensions.adapt(value)
|
||||
if hasattr(adapted, "encoding"):
|
||||
adapted.encoding = "utf8"
|
||||
# getquoted() returns a quoted bytestring of the adapted value.
|
||||
return adapted.getquoted().decode()
|
||||
# Adapters.
|
||||
class BaseTzLoader(TimestamptzLoader):
|
||||
"""
|
||||
Load a PostgreSQL timestamptz using the a specific timezone.
|
||||
The timezone can be None too, in which case it will be chopped.
|
||||
"""
|
||||
|
||||
timezone = None
|
||||
|
||||
sql.quote = _quote
|
||||
def load(self, data):
|
||||
res = super().load(data)
|
||||
return res.replace(tzinfo=self.timezone)
|
||||
|
||||
def register_tzloader(tz, context):
|
||||
class SpecificTzLoader(BaseTzLoader):
|
||||
timezone = tz
|
||||
|
||||
def mogrify(sql, params, connection):
|
||||
with connection.cursor() as cursor:
|
||||
return cursor.mogrify(sql, params).decode()
|
||||
context.adapters.register_loader("timestamptz", SpecificTzLoader)
|
||||
|
||||
class DjangoRangeDumper(RangeDumper):
|
||||
"""A Range dumper customized for Django."""
|
||||
|
||||
def upgrade(self, obj, format):
|
||||
# Dump ranges containing naive datetimes as tstzrange, because
|
||||
# Django doesn't use tz-aware ones.
|
||||
dumper = super().upgrade(obj, format)
|
||||
if dumper is not self and dumper.oid == TSRANGE_OID:
|
||||
dumper.oid = TSTZRANGE_OID
|
||||
return dumper
|
||||
|
||||
@lru_cache
|
||||
def get_adapters_template(use_tz, timezone):
|
||||
# Create at adapters map extending the base one.
|
||||
ctx = adapt.AdaptersMap(adapters)
|
||||
# Register a no-op dumper to avoid a round trip from psycopg version 3
|
||||
# decode to json.dumps() to json.loads(), when using a custom decoder
|
||||
# in JSONField.
|
||||
ctx.register_loader("jsonb", TextLoader)
|
||||
# Don't convert automatically from PostgreSQL network types to Python
|
||||
# ipaddress.
|
||||
ctx.register_loader("inet", TextLoader)
|
||||
ctx.register_loader("cidr", TextLoader)
|
||||
ctx.register_dumper(Range, DjangoRangeDumper)
|
||||
# Register a timestamptz loader configured on self.timezone.
|
||||
# This, however, can be overridden by create_cursor.
|
||||
register_tzloader(timezone, ctx)
|
||||
return ctx
|
||||
|
||||
is_psycopg3 = True
|
||||
|
||||
except ImportError:
|
||||
from enum import IntEnum
|
||||
|
||||
from psycopg2 import errors, extensions, sql # NOQA
|
||||
from psycopg2.extras import DateRange, DateTimeRange, DateTimeTZRange, Inet # NOQA
|
||||
from psycopg2.extras import Json as Jsonb # NOQA
|
||||
from psycopg2.extras import NumericRange, Range # NOQA
|
||||
|
||||
RANGE_TYPES = (DateRange, DateTimeRange, DateTimeTZRange, NumericRange)
|
||||
|
||||
class IsolationLevel(IntEnum):
|
||||
READ_UNCOMMITTED = extensions.ISOLATION_LEVEL_READ_UNCOMMITTED
|
||||
READ_COMMITTED = extensions.ISOLATION_LEVEL_READ_COMMITTED
|
||||
REPEATABLE_READ = extensions.ISOLATION_LEVEL_REPEATABLE_READ
|
||||
SERIALIZABLE = extensions.ISOLATION_LEVEL_SERIALIZABLE
|
||||
|
||||
def _quote(value, connection=None):
|
||||
adapted = extensions.adapt(value)
|
||||
if hasattr(adapted, "encoding"):
|
||||
adapted.encoding = "utf8"
|
||||
# getquoted() returns a quoted bytestring of the adapted value.
|
||||
return adapted.getquoted().decode()
|
||||
|
||||
sql.quote = _quote
|
||||
|
||||
def mogrify(sql, params, connection):
|
||||
with connection.cursor() as cursor:
|
||||
return cursor.mogrify(sql, params).decode()
|
||||
|
||||
is_psycopg3 = False
|
||||
|
||||
@@ -40,6 +40,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
)
|
||||
sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)"
|
||||
|
||||
def execute(self, sql, params=()):
|
||||
# Merge the query client-side, as PostgreSQL won't do it server-side.
|
||||
if params is None:
|
||||
return super().execute(sql, params)
|
||||
sql = self.connection.ops.compose_sql(str(sql), params)
|
||||
# Don't let the superclass touch anything.
|
||||
return super().execute(sql, None)
|
||||
|
||||
sql_add_identity = (
|
||||
"ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
|
||||
"GENERATED BY DEFAULT AS IDENTITY"
|
||||
|
||||
@@ -2019,6 +2019,10 @@ class IntegerField(Field):
|
||||
"Field '%s' expected a number but got %r." % (self.name, value),
|
||||
) from e
|
||||
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
value = super().get_db_prep_value(value, connection, prepared)
|
||||
return connection.ops.adapt_integerfield_value(value, self.get_internal_type())
|
||||
|
||||
def get_internal_type(self):
|
||||
return "IntegerField"
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Database functions that do comparisons or type conversions."""
|
||||
from django.db import NotSupportedError
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models.fields import TextField
|
||||
from django.db.models.fields.json import JSONField
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
@@ -158,7 +159,14 @@ class JSONObject(Func):
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
return self.as_sql(
|
||||
copy = self.copy()
|
||||
copy.set_source_expressions(
|
||||
[
|
||||
Cast(expression, TextField()) if index % 2 == 0 else expression
|
||||
for index, expression in enumerate(copy.get_source_expressions())
|
||||
]
|
||||
)
|
||||
return super(JSONObject, copy).as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
function="JSONB_BUILD_OBJECT",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from django.db import NotSupportedError
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models.fields import CharField, IntegerField
|
||||
from django.db.models.functions import Coalesce
|
||||
from django.db.models.fields import CharField, IntegerField, TextField
|
||||
from django.db.models.functions import Cast, Coalesce
|
||||
from django.db.models.lookups import Transform
|
||||
|
||||
|
||||
@@ -82,6 +82,20 @@ class ConcatPair(Func):
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
copy = self.copy()
|
||||
copy.set_source_expressions(
|
||||
[
|
||||
Cast(expression, TextField())
|
||||
for expression in copy.get_source_expressions()
|
||||
]
|
||||
)
|
||||
return super(ConcatPair, copy).as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
|
||||
return super().as_sql(
|
||||
|
||||
@@ -568,7 +568,7 @@ class IsNull(BuiltinLookup):
|
||||
raise ValueError(
|
||||
"The QuerySet value for an isnull lookup must be True or False."
|
||||
)
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
sql, params = self.process_lhs(compiler, connection)
|
||||
if self.rhs:
|
||||
return "%s IS NULL" % sql, params
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user