Fixed #33646 -- Added async-compatible interface to QuerySet.
Thanks Simon Charette for reviews.

Co-authored-by: Carlton Gibson <carlton.gibson@noumenal.es>
Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
committed by Mariusz Felisiak
parent 27aa7035f5
commit 58b27e0dbb
@@ -7,6 +7,8 @@ import operator
import warnings
from itertools import chain, islice

from asgiref.sync import sync_to_async

import django
from django.conf import settings
from django.core import exceptions
@@ -45,6 +47,33 @@ class BaseIterable:
        self.chunked_fetch = chunked_fetch
        self.chunk_size = chunk_size

    async def _async_generator(self):
        # Generators don't actually start running until the first time you call
        # next() on them, so make the generator object in the async thread and
        # then repeatedly dispatch to it in a sync thread.
        sync_generator = self.__iter__()

        def next_slice(gen):
            return list(islice(gen, self.chunk_size))

        while True:
            chunk = await sync_to_async(next_slice)(sync_generator)
            for item in chunk:
                yield item
            if len(chunk) < self.chunk_size:
                break

    # __aiter__() is a *synchronous* method that has to then return an
    # *asynchronous* iterator/generator. Thus, nest an async generator inside
    # it.
    # This is a generic iterable converter for now, and is going to suffer a
    # performance penalty on large sets of items due to the cost of crossing
    # over the sync barrier for each chunk. Custom __aiter__() methods should
    # be added to each Iterable subclass, but that needs some work in the
    # Compiler first.
    def __aiter__(self):
        return self._async_generator()


class ModelIterable(BaseIterable):
    """Iterable that yields a model instance for each row."""
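The chunked sync-to-async bridge in _async_generator() is worth seeing in isolation. A minimal standalone sketch of the same pattern, assuming only asgiref is installed (all names below are illustrative, not part of this commit):

import asyncio
from itertools import islice

from asgiref.sync import sync_to_async


def blocking_numbers(n):
    # Stand-in for a sync generator that does blocking work per item.
    yield from range(n)


async def iterate_in_chunks(gen, chunk_size=3):
    # Pull chunk_size items per thread hop, mirroring
    # BaseIterable._async_generator() above.
    def next_slice():
        return list(islice(gen, chunk_size))

    while True:
        chunk = await sync_to_async(next_slice)()
        for item in chunk:
            yield item
        if len(chunk) < chunk_size:
            break


async def main():
    async for item in iterate_in_chunks(blocking_numbers(7)):
        print(item)


asyncio.run(main())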
@@ -321,6 +350,16 @@ class QuerySet:
        self._fetch_all()
        return iter(self._result_cache)

    def __aiter__(self):
        # Remember, __aiter__ itself is synchronous, it's the thing it returns
        # that is async!
        async def generator():
            await self._async_fetch_all()
            for item in self._result_cache:
                yield item

        return generator()

    def __bool__(self):
        self._fetch_all()
        return bool(self._result_cache)
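With QuerySet.__aiter__() in place, a queryset can be consumed directly with async for. A usage sketch, assuming a configured project and a hypothetical Author model with hypothetical is_active and name fields, e.g. from an async view:

async def active_author_names():
    # Author is a hypothetical model. Iteration fills the result cache
    # once, then yields from it, matching the sync iteration protocol.
    names = []
    async for author in Author.objects.filter(is_active=True):
        names.append(author.name)
    return names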
@@ -460,6 +499,25 @@ class QuerySet:
            )
        return self._iterator(use_chunked_fetch, chunk_size)

    async def aiterator(self, chunk_size=2000):
        """
        An asynchronous iterator over the results from applying this QuerySet
        to the database.
        """
        if self._prefetch_related_lookups:
            raise NotSupportedError(
                "Using QuerySet.aiterator() after prefetch_related() is not supported."
            )
        if chunk_size <= 0:
            raise ValueError("Chunk size must be strictly positive.")
        use_chunked_fetch = not connections[self.db].settings_dict.get(
            "DISABLE_SERVER_SIDE_CURSORS"
        )
        async for item in self._iterable_class(
            self, chunked_fetch=use_chunked_fetch, chunk_size=chunk_size
        ):
            yield item

    def aggregate(self, *args, **kwargs):
        """
        Return a dictionary containing the calculations (aggregation)
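aiterator() streams rows in chunks rather than building the result cache. A sketch with the same hypothetical Author model (handle_author() is a hypothetical stand-in):

async def export_authors():
    # Streams in chunks of 500; combining with prefetch_related()
    # raises NotSupportedError, per the check above.
    async for author in Author.objects.all().aiterator(chunk_size=500):
        await handle_author(author)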
@@ -502,6 +560,9 @@ class QuerySet:
        )
        return query.get_aggregation(self.db, kwargs)

    async def aaggregate(self, *args, **kwargs):
        return await sync_to_async(self.aggregate)(*args, **kwargs)

    def count(self):
        """
        Perform a SELECT COUNT() and return the number of records as an
@@ -515,6 +576,9 @@ class QuerySet:

        return self.query.get_count(using=self.db)

    async def acount(self):
        return await sync_to_async(self.count)()

    def get(self, *args, **kwargs):
        """
        Perform the query and return a single object matching the given
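aaggregate() and acount() just push the blocking call through sync_to_async. A usage sketch, assuming a hypothetical Book model with a price field:

from django.db.models import Avg


async def book_stats():
    # Book is a hypothetical model.
    total = await Book.objects.acount()
    aggregates = await Book.objects.aaggregate(avg_price=Avg("price"))
    return total, aggregates["avg_price"]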
@@ -550,6 +614,9 @@ class QuerySet:
            )
        )

    async def aget(self, *args, **kwargs):
        return await sync_to_async(self.get)(*args, **kwargs)

    def create(self, **kwargs):
        """
        Create a new object with the given kwargs, saving it to the database
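aget() keeps get()'s exception behaviour: DoesNotExist and MultipleObjectsReturned propagate out of the awaited call. A sketch with the hypothetical Author model:

async def author_or_none(pk):
    try:
        return await Author.objects.aget(pk=pk)
    except Author.DoesNotExist:
        return None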
@@ -560,6 +627,9 @@ class QuerySet:
        obj.save(force_insert=True, using=self.db)
        return obj

    async def acreate(self, **kwargs):
        return await sync_to_async(self.create)(**kwargs)

    def _prepare_for_bulk_create(self, objs):
        for obj in objs:
            if obj.pk is None:
@@ -720,6 +790,13 @@ class QuerySet:

        return objs

    async def abulk_create(self, objs, batch_size=None, ignore_conflicts=False):
        return await sync_to_async(self.bulk_create)(
            objs=objs,
            batch_size=batch_size,
            ignore_conflicts=ignore_conflicts,
        )

    def bulk_update(self, objs, fields, batch_size=None):
        """
        Update the given fields in each of the given objects in the database.
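acreate() and abulk_create() wrap their sync counterparts unchanged. A sketch, again with the hypothetical Author model and name field:

async def seed_authors():
    await Author.objects.acreate(name="Ada")
    await Author.objects.abulk_create(
        [Author(name="Grace"), Author(name="Hedy")],
        batch_size=100,
    )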
@@ -774,6 +851,15 @@ class QuerySet:

    bulk_update.alters_data = True

    async def abulk_update(self, objs, fields, batch_size=None):
        return await sync_to_async(self.bulk_update)(
            objs=objs,
            fields=fields,
            batch_size=batch_size,
        )

    abulk_update.alters_data = True

    def get_or_create(self, defaults=None, **kwargs):
        """
        Look up an object with the given kwargs, creating one if necessary.
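An abulk_update() usage sketch with the same hypothetical model (is_active is an assumed field):

async def deactivate(authors):
    # Mutate in memory, then persist the single field in bulk.
    for author in authors:
        author.is_active = False
    await Author.objects.abulk_update(authors, ["is_active"])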
@@ -799,6 +885,12 @@ class QuerySet:
                pass
            raise

    async def aget_or_create(self, defaults=None, **kwargs):
        return await sync_to_async(self.get_or_create)(
            defaults=defaults,
            **kwargs,
        )

    def update_or_create(self, defaults=None, **kwargs):
        """
        Look up an object with the given kwargs, updating one with defaults
@@ -819,6 +911,12 @@ class QuerySet:
            obj.save(using=self.db)
        return obj, False

    async def aupdate_or_create(self, defaults=None, **kwargs):
        return await sync_to_async(self.update_or_create)(
            defaults=defaults,
            **kwargs,
        )

    def _extract_model_params(self, defaults, **kwargs):
        """
        Prepare `params` for creating a model instance based on the given
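aget_or_create() and aupdate_or_create() keep the (object, created) return shape. A sketch with the hypothetical model:

async def ensure_author(name):
    # is_active is an assumed field on the hypothetical Author model.
    author, created = await Author.objects.aget_or_create(
        name=name, defaults={"is_active": True}
    )
    return author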
@@ -873,21 +971,37 @@ class QuerySet:
            raise TypeError("Cannot change a query once a slice has been taken.")
        return self._earliest(*fields)

    async def aearliest(self, *fields):
        return await sync_to_async(self.earliest)(*fields)

    def latest(self, *fields):
        """
        Return the latest object according to fields (if given) or by the
        model's Meta.get_latest_by.
        """
        if self.query.is_sliced:
            raise TypeError("Cannot change a query once a slice has been taken.")
        return self.reverse()._earliest(*fields)

    async def alatest(self, *fields):
        return await sync_to_async(self.latest)(*fields)

    def first(self):
        """Return the first object of a query or None if no match is found."""
        for obj in (self if self.ordered else self.order_by("pk"))[:1]:
            return obj

    async def afirst(self):
        return await sync_to_async(self.first)()

    def last(self):
        """Return the last object of a query or None if no match is found."""
        for obj in (self.reverse() if self.ordered else self.order_by("-pk"))[:1]:
            return obj

    async def alast(self):
        return await sync_to_async(self.last)()

    def in_bulk(self, id_list=None, *, field_name="pk"):
        """
        Return a dictionary mapping each of the given IDs to the object with
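afirst() and alast() return None on an empty queryset, while aearliest() and alatest() raise DoesNotExist, matching their sync counterparts. A sketch (created is an assumed timestamp field on the hypothetical model):

async def newest_author():
    return await Author.objects.order_by("-created").afirst()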
@@ -930,6 +1044,12 @@ class QuerySet:
            qs = self._chain()
        return {getattr(obj, field_name): obj for obj in qs}

    async def ain_bulk(self, id_list=None, *, field_name="pk"):
        return await sync_to_async(self.in_bulk)(
            id_list=id_list,
            field_name=field_name,
        )

    def delete(self):
        """Delete the records in the current QuerySet."""
        self._not_support_combined_queries("delete")
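ain_bulk() returns the same mapping as in_bulk(). A sketch:

async def authors_by_id(ids):
    # {id: Author} for the ids that exist; Author is hypothetical.
    return await Author.objects.ain_bulk(ids)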
@@ -963,6 +1083,12 @@ class QuerySet:
    delete.alters_data = True
    delete.queryset_only = True

    async def adelete(self):
        return await sync_to_async(self.delete)()

    adelete.alters_data = True
    adelete.queryset_only = True

    def _raw_delete(self, using):
        """
        Delete objects found from the given queryset in single direct SQL
@@ -998,6 +1124,11 @@ class QuerySet:

    update.alters_data = True

    async def aupdate(self, **kwargs):
        return await sync_to_async(self.update)(**kwargs)

    aupdate.alters_data = True

    def _update(self, values):
        """
        A version of update() that accepts field objects instead of field names.
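aupdate() and adelete() keep the sync return values: the number of rows matched, and a (total, per-model counts) tuple respectively. A sketch with the hypothetical model and fields:

async def prune_authors():
    # Reactivate every inactive author, then delete unnamed rows.
    touched = await Author.objects.filter(is_active=False).aupdate(is_active=True)
    deleted, by_model = await Author.objects.filter(name="").adelete()
    return touched, deleted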
@@ -1018,12 +1149,21 @@ class QuerySet:
    _update.queryset_only = False

    def exists(self):
        """
        Return True if the QuerySet would have any results, False otherwise.
        """
        if self._result_cache is None:
            return self.query.has_results(using=self.db)
        return bool(self._result_cache)

    async def aexists(self):
        return await sync_to_async(self.exists)()

    def contains(self, obj):
        """
        Return True if the QuerySet contains the provided obj,
        False otherwise.
        """
        self._not_support_combined_queries("contains")
        if self._fields is not None:
            raise TypeError(
@@ -1040,14 +1180,24 @@ class QuerySet:
            return obj in self._result_cache
        return self.filter(pk=obj.pk).exists()

    async def acontains(self, obj):
        return await sync_to_async(self.contains)(obj=obj)

    def _prefetch_related_objects(self):
        # This method can only be called once the result cache has been filled.
        prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
        self._prefetch_done = True

    def explain(self, *, format=None, **options):
        """
        Runs an EXPLAIN on the SQL query this QuerySet would perform, and
        returns the results.
        """
        return self.query.explain(using=self.db, format=format, **options)

    async def aexplain(self, *, format=None, **options):
        return await sync_to_async(self.explain)(format=format, **options)

    ##################################################
    # PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS #
    ##################################################
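aexists(), acontains(), and aexplain() in one usage sketch (queryset and obj are illustrative parameters):

async def inspect(queryset, obj):
    # Each await is a separate thread hop through sync_to_async.
    if await queryset.aexists() and await queryset.acontains(obj):
        return await queryset.aexplain()
    return None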
@@ -1648,6 +1798,12 @@ class QuerySet:
        if self._prefetch_related_lookups and not self._prefetch_done:
            self._prefetch_related_objects()

    async def _async_fetch_all(self):
        if self._result_cache is None:
            self._result_cache = [result async for result in self._iterable_class(self)]
        if self._prefetch_related_lookups and not self._prefetch_done:
            await sync_to_async(self._prefetch_related_objects)()

    def _next_is_sticky(self):
        """
        Indicate that the next filter call and the one following that should
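A sketch of the caching behaviour _async_fetch_all() gives async iteration (hypothetical Author model again): the second loop is served from _result_cache without further thread hops.

async def iterate_twice():
    queryset = Author.objects.all()
    first = [a async for a in queryset]  # fills queryset._result_cache
    second = [a async for a in queryset]  # read back from the cache
    assert first == second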