1
0
mirror of https://github.com/django/django.git synced 2025-10-29 00:26:07 +00:00

Fixed #33646 -- Added async-compatible interface to QuerySet.

Thanks Simon Charette for reviews.

Co-authored-by: Carlton Gibson <carlton.gibson@noumenal.es>
Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
This commit is contained in:
Andrew Godwin
2021-09-08 17:01:53 +01:00
committed by Mariusz Felisiak
parent 27aa7035f5
commit 58b27e0dbb
9 changed files with 748 additions and 23 deletions

View File

@@ -7,6 +7,8 @@ import operator
import warnings
from itertools import chain, islice
from asgiref.sync import sync_to_async
import django
from django.conf import settings
from django.core import exceptions
@@ -45,6 +47,33 @@ class BaseIterable:
self.chunked_fetch = chunked_fetch
self.chunk_size = chunk_size
async def _async_generator(self):
# Generators don't actually start running until the first time you call
# next() on them, so make the generator object in the async thread and
# then repeatedly dispatch to it in a sync thread.
sync_generator = self.__iter__()
def next_slice(gen):
return list(islice(gen, self.chunk_size))
while True:
chunk = await sync_to_async(next_slice)(sync_generator)
for item in chunk:
yield item
if len(chunk) < self.chunk_size:
break
# __aiter__() is a *synchronous* method that has to then return an
# *asynchronous* iterator/generator. Thus, nest an async generator inside
# it.
# This is a generic iterable converter for now, and is going to suffer a
# performance penalty on large sets of items due to the cost of crossing
# over the sync barrier for each chunk. Custom __aiter__() methods should
# be added to each Iterable subclass, but that needs some work in the
# Compiler first.
def __aiter__(self):
return self._async_generator()
class ModelIterable(BaseIterable):
"""Iterable that yields a model instance for each row."""
@@ -321,6 +350,16 @@ class QuerySet:
self._fetch_all()
return iter(self._result_cache)
def __aiter__(self):
# Remember, __aiter__ itself is synchronous, it's the thing it returns
# that is async!
async def generator():
await self._async_fetch_all()
for item in self._result_cache:
yield item
return generator()
def __bool__(self):
self._fetch_all()
return bool(self._result_cache)
@@ -460,6 +499,25 @@ class QuerySet:
)
return self._iterator(use_chunked_fetch, chunk_size)
async def aiterator(self, chunk_size=2000):
"""
An asynchronous iterator over the results from applying this QuerySet
to the database.
"""
if self._prefetch_related_lookups:
raise NotSupportedError(
"Using QuerySet.aiterator() after prefetch_related() is not supported."
)
if chunk_size <= 0:
raise ValueError("Chunk size must be strictly positive.")
use_chunked_fetch = not connections[self.db].settings_dict.get(
"DISABLE_SERVER_SIDE_CURSORS"
)
async for item in self._iterable_class(
self, chunked_fetch=use_chunked_fetch, chunk_size=chunk_size
):
yield item
def aggregate(self, *args, **kwargs):
"""
Return a dictionary containing the calculations (aggregation)
@@ -502,6 +560,9 @@ class QuerySet:
)
return query.get_aggregation(self.db, kwargs)
async def aaggregate(self, *args, **kwargs):
return await sync_to_async(self.aggregate)(*args, **kwargs)
def count(self):
"""
Perform a SELECT COUNT() and return the number of records as an
@@ -515,6 +576,9 @@ class QuerySet:
return self.query.get_count(using=self.db)
async def acount(self):
return await sync_to_async(self.count)()
def get(self, *args, **kwargs):
"""
Perform the query and return a single object matching the given
@@ -550,6 +614,9 @@ class QuerySet:
)
)
async def aget(self, *args, **kwargs):
return await sync_to_async(self.get)(*args, **kwargs)
def create(self, **kwargs):
"""
Create a new object with the given kwargs, saving it to the database
@@ -560,6 +627,9 @@ class QuerySet:
obj.save(force_insert=True, using=self.db)
return obj
async def acreate(self, **kwargs):
return await sync_to_async(self.create)(**kwargs)
def _prepare_for_bulk_create(self, objs):
for obj in objs:
if obj.pk is None:
@@ -720,6 +790,13 @@ class QuerySet:
return objs
async def abulk_create(self, objs, batch_size=None, ignore_conflicts=False):
return await sync_to_async(self.bulk_create)(
objs=objs,
batch_size=batch_size,
ignore_conflicts=ignore_conflicts,
)
def bulk_update(self, objs, fields, batch_size=None):
"""
Update the given fields in each of the given objects in the database.
@@ -774,6 +851,15 @@ class QuerySet:
bulk_update.alters_data = True
async def abulk_update(self, objs, fields, batch_size=None):
return await sync_to_async(self.bulk_update)(
objs=objs,
fields=fields,
batch_size=batch_size,
)
abulk_update.alters_data = True
def get_or_create(self, defaults=None, **kwargs):
"""
Look up an object with the given kwargs, creating one if necessary.
@@ -799,6 +885,12 @@ class QuerySet:
pass
raise
async def aget_or_create(self, defaults=None, **kwargs):
return await sync_to_async(self.get_or_create)(
defaults=defaults,
**kwargs,
)
def update_or_create(self, defaults=None, **kwargs):
"""
Look up an object with the given kwargs, updating one with defaults
@@ -819,6 +911,12 @@ class QuerySet:
obj.save(using=self.db)
return obj, False
async def aupdate_or_create(self, defaults=None, **kwargs):
return await sync_to_async(self.update_or_create)(
defaults=defaults,
**kwargs,
)
def _extract_model_params(self, defaults, **kwargs):
"""
Prepare `params` for creating a model instance based on the given
@@ -873,21 +971,37 @@ class QuerySet:
raise TypeError("Cannot change a query once a slice has been taken.")
return self._earliest(*fields)
async def aearliest(self, *fields):
return await sync_to_async(self.earliest)(*fields)
def latest(self, *fields):
"""
Return the latest object according to fields (if given) or by the
model's Meta.get_latest_by.
"""
if self.query.is_sliced:
raise TypeError("Cannot change a query once a slice has been taken.")
return self.reverse()._earliest(*fields)
async def alatest(self, *fields):
return await sync_to_async(self.latest)(*fields)
def first(self):
"""Return the first object of a query or None if no match is found."""
for obj in (self if self.ordered else self.order_by("pk"))[:1]:
return obj
async def afirst(self):
return await sync_to_async(self.first)()
def last(self):
"""Return the last object of a query or None if no match is found."""
for obj in (self.reverse() if self.ordered else self.order_by("-pk"))[:1]:
return obj
async def alast(self):
return await sync_to_async(self.last)()
def in_bulk(self, id_list=None, *, field_name="pk"):
"""
Return a dictionary mapping each of the given IDs to the object with
@@ -930,6 +1044,12 @@ class QuerySet:
qs = self._chain()
return {getattr(obj, field_name): obj for obj in qs}
async def ain_bulk(self, id_list=None, *, field_name="pk"):
return await sync_to_async(self.in_bulk)(
id_list=id_list,
field_name=field_name,
)
def delete(self):
"""Delete the records in the current QuerySet."""
self._not_support_combined_queries("delete")
@@ -963,6 +1083,12 @@ class QuerySet:
delete.alters_data = True
delete.queryset_only = True
async def adelete(self):
return await sync_to_async(self.delete)()
adelete.alters_data = True
adelete.queryset_only = True
def _raw_delete(self, using):
"""
Delete objects found from the given queryset in single direct SQL
@@ -998,6 +1124,11 @@ class QuerySet:
update.alters_data = True
async def aupdate(self, **kwargs):
return await sync_to_async(self.update)(**kwargs)
aupdate.alters_data = True
def _update(self, values):
"""
A version of update() that accepts field objects instead of field names.
@@ -1018,12 +1149,21 @@ class QuerySet:
_update.queryset_only = False
def exists(self):
"""
Return True if the QuerySet would have any results, False otherwise.
"""
if self._result_cache is None:
return self.query.has_results(using=self.db)
return bool(self._result_cache)
async def aexists(self):
return await sync_to_async(self.exists)()
def contains(self, obj):
"""Return True if the queryset contains an object."""
"""
Return True if the QuerySet contains the provided obj,
False otherwise.
"""
self._not_support_combined_queries("contains")
if self._fields is not None:
raise TypeError(
@@ -1040,14 +1180,24 @@ class QuerySet:
return obj in self._result_cache
return self.filter(pk=obj.pk).exists()
async def acontains(self, obj):
return await sync_to_async(self.contains)(obj=obj)
def _prefetch_related_objects(self):
# This method can only be called once the result cache has been filled.
prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
self._prefetch_done = True
def explain(self, *, format=None, **options):
"""
Runs an EXPLAIN on the SQL query this QuerySet would perform, and
returns the results.
"""
return self.query.explain(using=self.db, format=format, **options)
async def aexplain(self, *, format=None, **options):
return await sync_to_async(self.explain)(format=format, **options)
##################################################
# PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS #
##################################################
@@ -1648,6 +1798,12 @@ class QuerySet:
if self._prefetch_related_lookups and not self._prefetch_done:
self._prefetch_related_objects()
async def _async_fetch_all(self):
if self._result_cache is None:
self._result_cache = [result async for result in self._iterable_class(self)]
if self._prefetch_related_lookups and not self._prefetch_done:
sync_to_async(self._prefetch_related_objects)()
def _next_is_sticky(self):
"""
Indicate that the next filter call and the one following that should