mirror of
				https://github.com/django/django.git
				synced 2025-10-25 06:36:07 +00:00 
			
		
		
		
	magic-removal: Created InBulkQuerySet, ValuesQuerySet and DateQuerySet, subclasses of QuerySet that provide custom iterator(). This lets you use iterator() with in_bulk(), values() and dates(). Also added unit tests.
git-svn-id: http://code.djangoproject.com/svn/django/branches/magic-removal@2200 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
		| @@ -85,6 +85,9 @@ class Manager(object): | |||||||
|     def in_bulk(self, *args, **kwargs): |     def in_bulk(self, *args, **kwargs): | ||||||
|         return QuerySet(self.model).in_bulk(*args, **kwargs) |         return QuerySet(self.model).in_bulk(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def iterator(self, *args, **kwargs): | ||||||
|  |         return QuerySet(self.model).iterator(*args, **kwargs) | ||||||
|  |  | ||||||
|     def order_by(self, *args, **kwargs): |     def order_by(self, *args, **kwargs): | ||||||
|         return QuerySet(self.model).order_by(*args, **kwargs) |         return QuerySet(self.model).order_by(*args, **kwargs) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -190,67 +190,33 @@ class QuerySet(object): | |||||||
|         _, sql, params = del_query._get_sql_clause(False) |         _, sql, params = del_query._get_sql_clause(False) | ||||||
|         cursor.execute("DELETE " + sql, params) |         cursor.execute("DELETE " + sql, params) | ||||||
|  |  | ||||||
|  |     ################################################## | ||||||
|  |     # PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS # | ||||||
|  |     ################################################## | ||||||
|  |  | ||||||
|     def in_bulk(self, id_list): |     def in_bulk(self, id_list): | ||||||
|         assert isinstance(id_list, list), "in_bulk() must be provided with a list of IDs." |         assert isinstance(id_list, list), "in_bulk() must be provided with a list of IDs." | ||||||
|         assert id_list != [], "in_bulk() cannot be passed an empty ID list." |         assert id_list != [], "in_bulk() cannot be passed an empty ID list." | ||||||
|         bulk_query = self._clone() |         return self._clone(klass=InBulkQuerySet, _id_list=id_list) | ||||||
|         bulk_query._where.append("%s.%s IN (%s)" % (backend.quote_name(self.model._meta.db_table), backend.quote_name(self.model._meta.pk.column), ",".join(['%s'] * len(id_list)))) |  | ||||||
|         bulk_query._params.extend(id_list) |  | ||||||
|         return dict([(obj._get_pk_val(), obj) for obj in bulk_query.iterator()]) |  | ||||||
|  |  | ||||||
|     def values(self, *fields): |     def values(self, *fields): | ||||||
|         # select_related and select aren't supported in values(). |         return self._clone(klass=ValuesQuerySet, _fields=fields) | ||||||
|         values_query = self._clone(_select_related=False, _select={}) |  | ||||||
|  |  | ||||||
|         # 'fields' is a list of field names to fetch. |  | ||||||
|         if fields: |  | ||||||
|             columns = [self.model._meta.get_field(f, many_to_many=False).column for f in fields] |  | ||||||
|         else: # Default to all fields. |  | ||||||
|             columns = [f.column for f in self.model._meta.fields] |  | ||||||
|  |  | ||||||
|         cursor = connection.cursor() |  | ||||||
|         select, sql, params = values_query._get_sql_clause(True) |  | ||||||
|         select = ['%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(c)) for c in columns] |  | ||||||
|         cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params) |  | ||||||
|         while 1: |  | ||||||
|             rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) |  | ||||||
|             if not rows: |  | ||||||
|                 raise StopIteration |  | ||||||
|             for row in rows: |  | ||||||
|                 yield dict(zip(columns, row)) |  | ||||||
|  |  | ||||||
|     def dates(self, field_name, kind, order='ASC'): |     def dates(self, field_name, kind, order='ASC'): | ||||||
|         """ |         """ | ||||||
|         Returns a list of datetime objects representing all available dates |         Returns a list of datetime objects representing all available dates | ||||||
|         for the given field_name, scoped to 'kind'. |         for the given field_name, scoped to 'kind'. | ||||||
|         """ |         """ | ||||||
|         from django.db.backends.util import typecast_timestamp |  | ||||||
|  |  | ||||||
|         assert kind in ("month", "year", "day"), "'kind' must be one of 'year', 'month' or 'day'." |         assert kind in ("month", "year", "day"), "'kind' must be one of 'year', 'month' or 'day'." | ||||||
|         assert order in ('ASC', 'DESC'), "'order' must be either 'ASC' or 'DESC'." |         assert order in ('ASC', 'DESC'), "'order' must be either 'ASC' or 'DESC'." | ||||||
|         # Let the FieldDoesNotExist exception propagate. |         # Let the FieldDoesNotExist exception propagate. | ||||||
|         field = self.model._meta.get_field(field_name, many_to_many=False) |         field = self.model._meta.get_field(field_name, many_to_many=False) | ||||||
|         assert isinstance(field, DateField), "%r isn't a DateField." % field_name |         assert isinstance(field, DateField), "%r isn't a DateField." % field_name | ||||||
|  |         return self._clone(klass=DateQuerySet, _field=field, _kind=kind, _order=order) | ||||||
|  |  | ||||||
|         date_query = self._clone() |     ################################################################## | ||||||
|         date_query._order_by = () # Clear this because it'll mess things up otherwise. |     # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET # | ||||||
|         if field.null: |     ################################################################## | ||||||
|             date_query._where.append('%s.%s IS NOT NULL' % \ |  | ||||||
|                 (backend.quote_name(self.model._meta.db_table), backend.quote_name(field.column))) |  | ||||||
|         select, sql, params = date_query._get_sql_clause(True) |  | ||||||
|         sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \ |  | ||||||
|             (backend.get_date_trunc_sql(kind, '%s.%s' % (backend.quote_name(self.model._meta.db_table), |  | ||||||
|             backend.quote_name(field.column))), sql, order) |  | ||||||
|         cursor = connection.cursor() |  | ||||||
|         cursor.execute(sql, params) |  | ||||||
|         # We have to manually run typecast_timestamp(str()) on the results, because |  | ||||||
|         # MySQL doesn't automatically cast the result of date functions as datetime |  | ||||||
|         # objects -- MySQL returns the values as strings, instead. |  | ||||||
|         return [typecast_timestamp(str(row[0])) for row in cursor.fetchall()] |  | ||||||
|  |  | ||||||
|     ############################################# |  | ||||||
|     # PUBLIC METHODS THAT RETURN A NEW QUERYSET # |  | ||||||
|     ############################################# |  | ||||||
|  |  | ||||||
|     def filter(self, *args, **kwargs): |     def filter(self, *args, **kwargs): | ||||||
|         "Returns a new QuerySet instance with the args ANDed to the existing set." |         "Returns a new QuerySet instance with the args ANDed to the existing set." | ||||||
| @@ -285,8 +251,10 @@ class QuerySet(object): | |||||||
|     # PRIVATE METHODS # |     # PRIVATE METHODS # | ||||||
|     ################### |     ################### | ||||||
|  |  | ||||||
|     def _clone(self, **kwargs): |     def _clone(self, klass=None, **kwargs): | ||||||
|         c = QuerySet() |         if klass is None: | ||||||
|  |             klass = self.__class__ | ||||||
|  |         c = klass() | ||||||
|         c.model = self.model |         c.model = self.model | ||||||
|         c._filters = self._filters |         c._filters = self._filters | ||||||
|         c._order_by = self._order_by |         c._order_by = self._order_by | ||||||
| @@ -402,6 +370,61 @@ class QuerySet(object): | |||||||
|  |  | ||||||
|         return select, " ".join(sql), params |         return select, " ".join(sql), params | ||||||
|  |  | ||||||
|  | class InBulkQuerySet(QuerySet): | ||||||
|  |     def iterator(self): | ||||||
|  |         self._where.append("%s.%s IN (%s)" % (backend.quote_name(self.model._meta.db_table), backend.quote_name(self.model._meta.pk.column), ",".join(['%s'] * len(self._id_list)))) | ||||||
|  |         self._params.extend(self._id_list) | ||||||
|  |         yield dict([(obj._get_pk_val(), obj) for obj in QuerySet.iterator(self)]) | ||||||
|  |  | ||||||
|  |     def _get_data(self): | ||||||
|  |         if self._result_cache is None: | ||||||
|  |             for i in self.iterator(): | ||||||
|  |                 self._result_cache = i | ||||||
|  |         return self._result_cache | ||||||
|  |  | ||||||
|  | class ValuesQuerySet(QuerySet): | ||||||
|  |     def iterator(self): | ||||||
|  |         # select_related and select aren't supported in values(). | ||||||
|  |         self._select_related = False | ||||||
|  |         self._select = {} | ||||||
|  |  | ||||||
|  |         # self._fields is a list of field names to fetch. | ||||||
|  |         if self._fields: | ||||||
|  |             columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields] | ||||||
|  |             field_names = [f.attname for f in self._fields] | ||||||
|  |         else: # Default to all fields. | ||||||
|  |             columns = [f.column for f in self.model._meta.fields] | ||||||
|  |             field_names = [f.attname for f in self.model._meta.fields] | ||||||
|  |  | ||||||
|  |         cursor = connection.cursor() | ||||||
|  |         select, sql, params = self._get_sql_clause(True) | ||||||
|  |         select = ['%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(c)) for c in columns] | ||||||
|  |         cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params) | ||||||
|  |         while 1: | ||||||
|  |             rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) | ||||||
|  |             if not rows: | ||||||
|  |                 raise StopIteration | ||||||
|  |             for row in rows: | ||||||
|  |                 yield dict(zip(field_names, row)) | ||||||
|  |  | ||||||
|  | class DateQuerySet(QuerySet): | ||||||
|  |     def iterator(self): | ||||||
|  |         from django.db.backends.util import typecast_timestamp | ||||||
|  |         self._order_by = () # Clear this because it'll mess things up otherwise. | ||||||
|  |         if self._field.null: | ||||||
|  |             date_query._where.append('%s.%s IS NOT NULL' % \ | ||||||
|  |                 (backend.quote_name(self.model._meta.db_table), backend.quote_name(self._field.column))) | ||||||
|  |         select, sql, params = self._get_sql_clause(True) | ||||||
|  |         sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \ | ||||||
|  |             (backend.get_date_trunc_sql(self._kind, '%s.%s' % (backend.quote_name(self.model._meta.db_table), | ||||||
|  |             backend.quote_name(self._field.column))), sql, self._order) | ||||||
|  |         cursor = connection.cursor() | ||||||
|  |         cursor.execute(sql, params) | ||||||
|  |         # We have to manually run typecast_timestamp(str()) on the results, because | ||||||
|  |         # MySQL doesn't automatically cast the result of date functions as datetime | ||||||
|  |         # objects -- MySQL returns the values as strings, instead. | ||||||
|  |         return [typecast_timestamp(str(row[0])) for row in cursor.fetchall()] | ||||||
|  |  | ||||||
| class QOperator: | class QOperator: | ||||||
|     "Base class for QAnd and QOr" |     "Base class for QAnd and QOr" | ||||||
|     def __init__(self, *args): |     def __init__(self, *args): | ||||||
|   | |||||||
| @@ -37,14 +37,13 @@ datetime.datetime(2005, 7, 28, 0, 0) | |||||||
| >>> a.headline = 'Area woman programs in Python' | >>> a.headline = 'Area woman programs in Python' | ||||||
| >>> a.save() | >>> a.save() | ||||||
|  |  | ||||||
| # Listing objects displays all the articles in the database. Note that the article | # Article.objects.all() returns all the articles in the database. Note that | ||||||
| # is represented by "<Article object>", because we haven't given the Article | # the article is represented by "<Article object>", because we haven't given | ||||||
| # model a __repr__() method. | # the Article model a __repr__() method. | ||||||
| >>> Article.objects.all() | >>> Article.objects.all() | ||||||
| [<Article object>] | [<Article object>] | ||||||
|  |  | ||||||
| # Django provides a rich database lookup API that's entirely driven by | # Django provides a rich database lookup API. | ||||||
| # keyword arguments. |  | ||||||
| >>> Article.objects.get(id__exact=1) | >>> Article.objects.get(id__exact=1) | ||||||
| <Article object> | <Article object> | ||||||
| >>> Article.objects.get(headline__startswith='Area woman') | >>> Article.objects.get(headline__startswith='Area woman') | ||||||
| @@ -56,7 +55,7 @@ datetime.datetime(2005, 7, 28, 0, 0) | |||||||
| >>> Article.objects.get(pub_date__year=2005, pub_date__month=7, pub_date__day=28) | >>> Article.objects.get(pub_date__year=2005, pub_date__month=7, pub_date__day=28) | ||||||
| <Article object> | <Article object> | ||||||
|  |  | ||||||
| # You can omit __exact if you want | # The "__exact" lookup type can be omitted, as a shortcut. | ||||||
| >>> Article.objects.get(id=1) | >>> Article.objects.get(id=1) | ||||||
| <Article object> | <Article object> | ||||||
| >>> Article.objects.get(headline='Area woman programs in Python') | >>> Article.objects.get(headline='Area woman programs in Python') | ||||||
| @@ -69,7 +68,8 @@ datetime.datetime(2005, 7, 28, 0, 0) | |||||||
| >>> Article.objects.filter(pub_date__year=2005, pub_date__month=7) | >>> Article.objects.filter(pub_date__year=2005, pub_date__month=7) | ||||||
| [<Article object>] | [<Article object>] | ||||||
|  |  | ||||||
| # Django raises an ArticleDoesNotExist exception for get() | # Django raises an Article.DoesNotExist exception for get() if the parameters | ||||||
|  | # don't match any object. | ||||||
| >>> Article.objects.get(id__exact=2) | >>> Article.objects.get(id__exact=2) | ||||||
| Traceback (most recent call last): | Traceback (most recent call last): | ||||||
|     ... |     ... | ||||||
| @@ -82,7 +82,7 @@ DoesNotExist: Article does not exist for ... | |||||||
|  |  | ||||||
| # Lookup by a primary key is the most common case, so Django provides a | # Lookup by a primary key is the most common case, so Django provides a | ||||||
| # shortcut for primary-key exact lookups. | # shortcut for primary-key exact lookups. | ||||||
| # The following is identical to articles.get(id__exact=1). | # The following is identical to articles.get(id=1). | ||||||
| >>> Article.objects.get(pk=1) | >>> Article.objects.get(pk=1) | ||||||
| <Article object> | <Article object> | ||||||
|  |  | ||||||
| @@ -93,7 +93,7 @@ DoesNotExist: Article does not exist for ... | |||||||
| True | True | ||||||
|  |  | ||||||
| # You can initialize a model instance using positional arguments, which should | # You can initialize a model instance using positional arguments, which should | ||||||
| # match the field order as defined in the model... | # match the field order as defined in the model. | ||||||
| >>> a2 = Article(None, 'Second article', datetime(2005, 7, 29)) | >>> a2 = Article(None, 'Second article', datetime(2005, 7, 29)) | ||||||
| >>> a2.save() | >>> a2.save() | ||||||
| >>> a2.id | >>> a2.id | ||||||
| @@ -126,7 +126,8 @@ Traceback (most recent call last): | |||||||
|     ... |     ... | ||||||
| TypeError: 'foo' is an invalid keyword argument for this function | TypeError: 'foo' is an invalid keyword argument for this function | ||||||
|  |  | ||||||
| # You can leave off the ID. | # You can leave off the value for an AutoField when creating an object, because | ||||||
|  | # it'll get filled in automatically when you save(). | ||||||
| >>> a5 = Article(headline='Article 6', pub_date=datetime(2005, 7, 31)) | >>> a5 = Article(headline='Article 6', pub_date=datetime(2005, 7, 31)) | ||||||
| >>> a5.save() | >>> a5.save() | ||||||
| >>> a5.id | >>> a5.id | ||||||
| @@ -154,7 +155,7 @@ datetime.datetime(2005, 7, 31, 12, 30, 45) | |||||||
| >>> a8.id | >>> a8.id | ||||||
| 8L | 8L | ||||||
|  |  | ||||||
| # Saving an object again shouldn't create a new object -- it just saves the old one. | # Saving an object again doesn't create a new object -- it just saves the old one. | ||||||
| >>> a8.save() | >>> a8.save() | ||||||
| >>> a8.id | >>> a8.id | ||||||
| 8L | 8L | ||||||
| @@ -174,6 +175,7 @@ True | |||||||
| >>> Article.objects.get(id__exact=8) == Article.objects.get(id__exact=7) | >>> Article.objects.get(id__exact=8) == Article.objects.get(id__exact=7) | ||||||
| False | False | ||||||
|  |  | ||||||
|  | # dates() returns a list of available dates of the given scope for the given field. | ||||||
| >>> Article.objects.dates('pub_date', 'year') | >>> Article.objects.dates('pub_date', 'year') | ||||||
| [datetime.datetime(2005, 1, 1, 0, 0)] | [datetime.datetime(2005, 1, 1, 0, 0)] | ||||||
| >>> Article.objects.dates('pub_date', 'month') | >>> Article.objects.dates('pub_date', 'month') | ||||||
| @@ -185,7 +187,7 @@ False | |||||||
| >>> Article.objects.dates('pub_date', 'day', order='DESC') | >>> Article.objects.dates('pub_date', 'day', order='DESC') | ||||||
| [datetime.datetime(2005, 7, 31, 0, 0), datetime.datetime(2005, 7, 30, 0, 0), datetime.datetime(2005, 7, 29, 0, 0), datetime.datetime(2005, 7, 28, 0, 0)] | [datetime.datetime(2005, 7, 31, 0, 0), datetime.datetime(2005, 7, 30, 0, 0), datetime.datetime(2005, 7, 29, 0, 0), datetime.datetime(2005, 7, 28, 0, 0)] | ||||||
|  |  | ||||||
| # Try some bad arguments to dates(). | # dates() requires valid arguments. | ||||||
|  |  | ||||||
| >>> Article.objects.dates() | >>> Article.objects.dates() | ||||||
| Traceback (most recent call last): | Traceback (most recent call last): | ||||||
| @@ -207,7 +209,16 @@ Traceback (most recent call last): | |||||||
|    ... |    ... | ||||||
| AssertionError: 'order' must be either 'ASC' or 'DESC'. | AssertionError: 'order' must be either 'ASC' or 'DESC'. | ||||||
|  |  | ||||||
| # You can combine queries with & and | | # Use iterator() with dates() to return a generator that lazily requests each | ||||||
|  | # result one at a time, to save memory. | ||||||
|  | >>> for a in Article.objects.dates('pub_date', 'day', order='DESC').iterator(): | ||||||
|  | ...     print repr(a) | ||||||
|  | datetime.datetime(2005, 7, 31, 0, 0) | ||||||
|  | datetime.datetime(2005, 7, 30, 0, 0) | ||||||
|  | datetime.datetime(2005, 7, 29, 0, 0) | ||||||
|  | datetime.datetime(2005, 7, 28, 0, 0) | ||||||
|  |  | ||||||
|  | # You can combine queries with & and |. | ||||||
| >>> s1 = Article.objects.filter(id__exact=1) | >>> s1 = Article.objects.filter(id__exact=1) | ||||||
| >>> s2 = Article.objects.filter(id__exact=2) | >>> s2 = Article.objects.filter(id__exact=2) | ||||||
| >>> tmp = [a.id for a in list(s1 | s2)] | >>> tmp = [a.id for a in list(s1 | s2)] | ||||||
| @@ -231,7 +242,7 @@ AssertionError: 'order' must be either 'ASC' or 'DESC'. | |||||||
| [<Article object>, <Article object>] | [<Article object>, <Article object>] | ||||||
|  |  | ||||||
| # An Article instance doesn't have access to the "objects" attribute. | # An Article instance doesn't have access to the "objects" attribute. | ||||||
| # That is only available as a class method. | # That's only available on the class. | ||||||
| >>> a7.objects.all() | >>> a7.objects.all() | ||||||
| Traceback (most recent call last): | Traceback (most recent call last): | ||||||
|     ... |     ... | ||||||
|   | |||||||
| @@ -33,7 +33,8 @@ API_TESTS = """ | |||||||
| >>> a7 = Article(headline='Article 7', pub_date=datetime(2005, 7, 27)) | >>> a7 = Article(headline='Article 7', pub_date=datetime(2005, 7, 27)) | ||||||
| >>> a7.save() | >>> a7.save() | ||||||
|  |  | ||||||
| # iterator() is a generator. | # Each QuerySet gets iterator(), which is a generator that "lazily" returns | ||||||
|  | # results using database-level iteration. | ||||||
| >>> for a in Article.objects.iterator(): | >>> for a in Article.objects.iterator(): | ||||||
| ...     print a.headline | ...     print a.headline | ||||||
| Article 5 | Article 5 | ||||||
| @@ -103,6 +104,20 @@ True | |||||||
| [('headline', 'Article 7'), ('id', 7)] | [('headline', 'Article 7'), ('id', 7)] | ||||||
| [('headline', 'Article 1'), ('id', 1)] | [('headline', 'Article 1'), ('id', 1)] | ||||||
|  |  | ||||||
|  | # You can use values() with iterator() for memory savings, because iterator() | ||||||
|  | # uses database-level iteration. | ||||||
|  | >>> for d in Article.objects.values('id', 'headline').iterator(): | ||||||
|  | ...     i = d.items() | ||||||
|  | ...     i.sort() | ||||||
|  | ...     i | ||||||
|  | [('headline', 'Article 5'), ('id', 5)] | ||||||
|  | [('headline', 'Article 6'), ('id', 6)] | ||||||
|  | [('headline', 'Article 4'), ('id', 4)] | ||||||
|  | [('headline', 'Article 2'), ('id', 2)] | ||||||
|  | [('headline', 'Article 3'), ('id', 3)] | ||||||
|  | [('headline', 'Article 7'), ('id', 7)] | ||||||
|  | [('headline', 'Article 1'), ('id', 1)] | ||||||
|  |  | ||||||
| # If you don't specify which fields, all are returned. | # If you don't specify which fields, all are returned. | ||||||
| >>> list(Article.objects.filter(id=5).values()) == [{'id': 5, 'headline': 'Article 5', 'pub_date': datetime(2005, 8, 1, 9, 0)}] | >>> list(Article.objects.filter(id=5).values()) == [{'id': 5, 'headline': 'Article 5', 'pub_date': datetime(2005, 8, 1, 9, 0)}] | ||||||
| True | True | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user