mirror of
				https://github.com/django/django.git
				synced 2025-10-25 06:36:07 +00:00 
			
		
		
		
	Refs #373 -- Added tuple lookups.
This commit is contained in:
		
				
					committed by
					
						 Sarah Boyce
						Sarah Boyce
					
				
			
			
				
	
			
			
			
						parent
						
							3dac3271d2
						
					
				
				
					commit
					1eac690d25
				
			
							
								
								
									
										1
									
								
								AUTHORS
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								AUTHORS
									
									
									
									
									
								
							| @@ -152,6 +152,7 @@ answer newbie questions, and generally made Django that much better: | |||||||
|     Ben Lomax <lomax.on.the.run@gmail.com> |     Ben Lomax <lomax.on.the.run@gmail.com> | ||||||
|     Ben Slavin <benjamin.slavin@gmail.com> |     Ben Slavin <benjamin.slavin@gmail.com> | ||||||
|     Ben Sturmfels <ben@sturm.com.au> |     Ben Sturmfels <ben@sturm.com.au> | ||||||
|  |     Bendegúz Csirmaz <csirmazbendeguz@gmail.com> | ||||||
|     Berker Peksag <berker.peksag@gmail.com> |     Berker Peksag <berker.peksag@gmail.com> | ||||||
|     Bernd Schlapsi |     Bernd Schlapsi | ||||||
|     Bernhard Essl <me@bernhardessl.com> |     Bernhard Essl <me@bernhardessl.com> | ||||||
|   | |||||||
| @@ -1295,6 +1295,52 @@ class Col(Expression): | |||||||
|         ) + self.target.get_db_converters(connection) |         ) + self.target.get_db_converters(connection) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ColPairs(Expression): | ||||||
|  |     def __init__(self, alias, targets, sources, output_field): | ||||||
|  |         super().__init__(output_field=output_field) | ||||||
|  |         self.alias, self.targets, self.sources = alias, targets, sources | ||||||
|  |  | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(self.targets) | ||||||
|  |  | ||||||
|  |     def __iter__(self): | ||||||
|  |         return iter(self.get_cols()) | ||||||
|  |  | ||||||
|  |     def get_cols(self): | ||||||
|  |         return [ | ||||||
|  |             Col(self.alias, target, source) | ||||||
|  |             for target, source in zip(self.targets, self.sources) | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |     def get_source_expressions(self): | ||||||
|  |         return self.get_cols() | ||||||
|  |  | ||||||
|  |     def set_source_expressions(self, exprs): | ||||||
|  |         assert all(isinstance(expr, Col) and expr.alias == self.alias for expr in exprs) | ||||||
|  |         self.targets = [col.target for col in exprs] | ||||||
|  |         self.sources = [col.field for col in exprs] | ||||||
|  |  | ||||||
|  |     def as_sql(self, compiler, connection): | ||||||
|  |         cols_sql = [] | ||||||
|  |         cols_params = [] | ||||||
|  |         cols = self.get_cols() | ||||||
|  |  | ||||||
|  |         for col in cols: | ||||||
|  |             sql, params = col.as_sql(compiler, connection) | ||||||
|  |             cols_sql.append(sql) | ||||||
|  |             cols_params.extend(params) | ||||||
|  |  | ||||||
|  |         return ", ".join(cols_sql), cols_params | ||||||
|  |  | ||||||
|  |     def relabeled_clone(self, relabels): | ||||||
|  |         return self.__class__( | ||||||
|  |             relabels.get(self.alias, self.alias), self.targets, self.sources, self.field | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def resolve_expression(self, *args, **kwargs): | ||||||
|  |         return self | ||||||
|  |  | ||||||
|  |  | ||||||
| class Ref(Expression): | class Ref(Expression): | ||||||
|     """ |     """ | ||||||
|     Reference to column alias of the query. For example, Ref('sum_cost') in |     Reference to column alias of the query. For example, Ref('sum_cost') in | ||||||
|   | |||||||
| @@ -1,3 +1,6 @@ | |||||||
|  | from django.db import NotSupportedError | ||||||
|  | from django.db.models.expressions import ColPairs | ||||||
|  | from django.db.models.fields.tuple_lookups import TupleIn, tuple_lookups | ||||||
| from django.db.models.lookups import ( | from django.db.models.lookups import ( | ||||||
|     Exact, |     Exact, | ||||||
|     GreaterThan, |     GreaterThan, | ||||||
| @@ -9,34 +12,6 @@ from django.db.models.lookups import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class MultiColSource: |  | ||||||
|     contains_aggregate = False |  | ||||||
|     contains_over_clause = False |  | ||||||
|  |  | ||||||
|     def __init__(self, alias, targets, sources, field): |  | ||||||
|         self.targets, self.sources, self.field, self.alias = ( |  | ||||||
|             targets, |  | ||||||
|             sources, |  | ||||||
|             field, |  | ||||||
|             alias, |  | ||||||
|         ) |  | ||||||
|         self.output_field = self.field |  | ||||||
|  |  | ||||||
|     def __repr__(self): |  | ||||||
|         return "{}({}, {})".format(self.__class__.__name__, self.alias, self.field) |  | ||||||
|  |  | ||||||
|     def relabeled_clone(self, relabels): |  | ||||||
|         return self.__class__( |  | ||||||
|             relabels.get(self.alias, self.alias), self.targets, self.sources, self.field |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def get_lookup(self, lookup): |  | ||||||
|         return self.output_field.get_lookup(lookup) |  | ||||||
|  |  | ||||||
|     def resolve_expression(self, *args, **kwargs): |  | ||||||
|         return self |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_normalized_value(value, lhs): | def get_normalized_value(value, lhs): | ||||||
|     from django.db.models import Model |     from django.db.models import Model | ||||||
|  |  | ||||||
| @@ -64,7 +39,7 @@ def get_normalized_value(value, lhs): | |||||||
|  |  | ||||||
| class RelatedIn(In): | class RelatedIn(In): | ||||||
|     def get_prep_lookup(self): |     def get_prep_lookup(self): | ||||||
|         if not isinstance(self.lhs, MultiColSource): |         if not isinstance(self.lhs, ColPairs): | ||||||
|             if self.rhs_is_direct_value(): |             if self.rhs_is_direct_value(): | ||||||
|                 # If we get here, we are dealing with single-column relations. |                 # If we get here, we are dealing with single-column relations. | ||||||
|                 self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs] |                 self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs] | ||||||
| @@ -98,49 +73,33 @@ class RelatedIn(In): | |||||||
|         return super().get_prep_lookup() |         return super().get_prep_lookup() | ||||||
|  |  | ||||||
|     def as_sql(self, compiler, connection): |     def as_sql(self, compiler, connection): | ||||||
|         if isinstance(self.lhs, MultiColSource): |         if isinstance(self.lhs, ColPairs): | ||||||
|             # For multicolumn lookups we need to build a multicolumn where clause. |             # For multicolumn lookups we need to build a multicolumn where clause. | ||||||
|             # This clause is either a SubqueryConstraint (for values that need |             # This clause is either a SubqueryConstraint (for values that need | ||||||
|             # to be compiled to SQL) or an OR-combined list of |             # to be compiled to SQL) or an OR-combined list of | ||||||
|             # (col1 = val1 AND col2 = val2 AND ...) clauses. |             # (col1 = val1 AND col2 = val2 AND ...) clauses. | ||||||
|             from django.db.models.sql.where import ( |             from django.db.models.sql.where import SubqueryConstraint | ||||||
|                 AND, |  | ||||||
|                 OR, |  | ||||||
|                 SubqueryConstraint, |  | ||||||
|                 WhereNode, |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|             root_constraint = WhereNode(connector=OR) |  | ||||||
|             if self.rhs_is_direct_value(): |             if self.rhs_is_direct_value(): | ||||||
|                 values = [get_normalized_value(value, self.lhs) for value in self.rhs] |                 values = [get_normalized_value(value, self.lhs) for value in self.rhs] | ||||||
|                 for value in values: |                 lookup = TupleIn(self.lhs, values) | ||||||
|                     value_constraint = WhereNode() |                 return compiler.compile(lookup) | ||||||
|                     for source, target, val in zip( |  | ||||||
|                         self.lhs.sources, self.lhs.targets, value |  | ||||||
|                     ): |  | ||||||
|                         lookup_class = target.get_lookup("exact") |  | ||||||
|                         lookup = lookup_class( |  | ||||||
|                             target.get_col(self.lhs.alias, source), val |  | ||||||
|                         ) |  | ||||||
|                         value_constraint.add(lookup, AND) |  | ||||||
|                     root_constraint.add(value_constraint, OR) |  | ||||||
|             else: |             else: | ||||||
|                 root_constraint.add( |                 return compiler.compile( | ||||||
|                     SubqueryConstraint( |                     SubqueryConstraint( | ||||||
|                         self.lhs.alias, |                         self.lhs.alias, | ||||||
|                         [target.column for target in self.lhs.targets], |                         [target.column for target in self.lhs.targets], | ||||||
|                         [source.name for source in self.lhs.sources], |                         [source.name for source in self.lhs.sources], | ||||||
|                         self.rhs, |                         self.rhs, | ||||||
|                     ), |                     ), | ||||||
|                     AND, |  | ||||||
|                 ) |                 ) | ||||||
|             return root_constraint.as_sql(compiler, connection) |  | ||||||
|         return super().as_sql(compiler, connection) |         return super().as_sql(compiler, connection) | ||||||
|  |  | ||||||
|  |  | ||||||
| class RelatedLookupMixin: | class RelatedLookupMixin: | ||||||
|     def get_prep_lookup(self): |     def get_prep_lookup(self): | ||||||
|         if not isinstance(self.lhs, MultiColSource) and not hasattr( |         if not isinstance(self.lhs, ColPairs) and not hasattr( | ||||||
|             self.rhs, "resolve_expression" |             self.rhs, "resolve_expression" | ||||||
|         ): |         ): | ||||||
|             # If we get here, we are dealing with single-column relations. |             # If we get here, we are dealing with single-column relations. | ||||||
| @@ -158,20 +117,16 @@ class RelatedLookupMixin: | |||||||
|         return super().get_prep_lookup() |         return super().get_prep_lookup() | ||||||
|  |  | ||||||
|     def as_sql(self, compiler, connection): |     def as_sql(self, compiler, connection): | ||||||
|         if isinstance(self.lhs, MultiColSource): |         if isinstance(self.lhs, ColPairs): | ||||||
|             assert self.rhs_is_direct_value() |             if not self.rhs_is_direct_value(): | ||||||
|             self.rhs = get_normalized_value(self.rhs, self.lhs) |                 raise NotSupportedError( | ||||||
|             from django.db.models.sql.where import AND, WhereNode |                     f"'{self.lookup_name}' doesn't support multi-column subqueries." | ||||||
|  |  | ||||||
|             root_constraint = WhereNode() |  | ||||||
|             for target, source, val in zip( |  | ||||||
|                 self.lhs.targets, self.lhs.sources, self.rhs |  | ||||||
|             ): |  | ||||||
|                 lookup_class = target.get_lookup(self.lookup_name) |  | ||||||
|                 root_constraint.add( |  | ||||||
|                     lookup_class(target.get_col(self.lhs.alias, source), val), AND |  | ||||||
|                 ) |                 ) | ||||||
|             return root_constraint.as_sql(compiler, connection) |             self.rhs = get_normalized_value(self.rhs, self.lhs) | ||||||
|  |             lookup_class = tuple_lookups[self.lookup_name] | ||||||
|  |             lookup = lookup_class(self.lhs, self.rhs) | ||||||
|  |             return compiler.compile(lookup) | ||||||
|  |  | ||||||
|         return super().as_sql(compiler, connection) |         return super().as_sql(compiler, connection) | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										244
									
								
								django/db/models/fields/tuple_lookups.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										244
									
								
								django/db/models/fields/tuple_lookups.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,244 @@ | |||||||
|  | import itertools | ||||||
|  |  | ||||||
|  | from django.core.exceptions import EmptyResultSet | ||||||
|  | from django.db.models.expressions import ColPairs, Func, Value | ||||||
|  | from django.db.models.lookups import ( | ||||||
|  |     Exact, | ||||||
|  |     GreaterThan, | ||||||
|  |     GreaterThanOrEqual, | ||||||
|  |     In, | ||||||
|  |     IsNull, | ||||||
|  |     LessThan, | ||||||
|  |     LessThanOrEqual, | ||||||
|  | ) | ||||||
|  | from django.db.models.sql.where import AND, OR, WhereNode | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Tuple(Func): | ||||||
|  |     function = "" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TupleLookupMixin: | ||||||
|  |     def get_prep_lookup(self): | ||||||
|  |         self.check_tuple_lookup() | ||||||
|  |         return super().get_prep_lookup() | ||||||
|  |  | ||||||
|  |     def check_tuple_lookup(self): | ||||||
|  |         assert isinstance(self.lhs, ColPairs) | ||||||
|  |         self.check_rhs_is_tuple_or_list() | ||||||
|  |         self.check_rhs_length_equals_lhs_length() | ||||||
|  |  | ||||||
|  |     def check_rhs_is_tuple_or_list(self): | ||||||
|  |         if not isinstance(self.rhs, (tuple, list)): | ||||||
|  |             raise ValueError( | ||||||
|  |                 f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " | ||||||
|  |                 "must be a tuple or a list" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def check_rhs_length_equals_lhs_length(self): | ||||||
|  |         if len(self.lhs) != len(self.rhs): | ||||||
|  |             raise ValueError( | ||||||
|  |                 f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " | ||||||
|  |                 f"must have {len(self.lhs)} elements" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def check_rhs_is_collection_of_tuples_or_lists(self): | ||||||
|  |         if not all(isinstance(vals, (tuple, list)) for vals in self.rhs): | ||||||
|  |             raise ValueError( | ||||||
|  |                 f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " | ||||||
|  |                 f"must be a collection of tuples or lists" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def check_rhs_elements_length_equals_lhs_length(self): | ||||||
|  |         if not all(len(self.lhs) == len(vals) for vals in self.rhs): | ||||||
|  |             raise ValueError( | ||||||
|  |                 f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " | ||||||
|  |                 f"must have {len(self.lhs)} elements each" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def as_sql(self, compiler, connection): | ||||||
|  |         # e.g.: (a, b, c) == (x, y, z) as SQL: | ||||||
|  |         # WHERE (a, b, c) = (x, y, z) | ||||||
|  |         vals = [ | ||||||
|  |             Value(val, output_field=col.output_field) | ||||||
|  |             for col, val in zip(self.lhs, self.rhs) | ||||||
|  |         ] | ||||||
|  |         lookup_class = self.__class__.__bases__[-1] | ||||||
|  |         lookup = lookup_class(Tuple(self.lhs), Tuple(*vals)) | ||||||
|  |         return lookup.as_sql(compiler, connection) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TupleExact(TupleLookupMixin, Exact): | ||||||
|  |     def as_oracle(self, compiler, connection): | ||||||
|  |         # e.g.: (a, b, c) == (x, y, z) as SQL: | ||||||
|  |         # WHERE a = x AND b = y AND c = z | ||||||
|  |         cols = self.lhs.get_cols() | ||||||
|  |         lookups = [Exact(col, val) for col, val in zip(cols, self.rhs)] | ||||||
|  |         root = WhereNode(lookups, connector=AND) | ||||||
|  |  | ||||||
|  |         return root.as_sql(compiler, connection) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TupleIsNull(IsNull): | ||||||
|  |     def as_sql(self, compiler, connection): | ||||||
|  |         # e.g.: (a, b, c) is None as SQL: | ||||||
|  |         # WHERE a IS NULL AND b IS NULL AND c IS NULL | ||||||
|  |         vals = self.rhs | ||||||
|  |         if isinstance(vals, bool): | ||||||
|  |             vals = [vals] * len(self.lhs) | ||||||
|  |  | ||||||
|  |         cols = self.lhs.get_cols() | ||||||
|  |         lookups = [IsNull(col, val) for col, val in zip(cols, vals)] | ||||||
|  |         root = WhereNode(lookups, connector=AND) | ||||||
|  |  | ||||||
|  |         return root.as_sql(compiler, connection) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TupleGreaterThan(TupleLookupMixin, GreaterThan): | ||||||
|  |     def as_oracle(self, compiler, connection): | ||||||
|  |         # e.g.: (a, b, c) > (x, y, z) as SQL: | ||||||
|  |         # WHERE a > x OR (a = x AND (b > y OR (b = y AND c > z))) | ||||||
|  |         cols = self.lhs.get_cols() | ||||||
|  |         lookups = itertools.cycle([GreaterThan, Exact]) | ||||||
|  |         connectors = itertools.cycle([OR, AND]) | ||||||
|  |         cols_list = [col for col in cols for _ in range(2)] | ||||||
|  |         vals_list = [val for val in self.rhs for _ in range(2)] | ||||||
|  |         cols_iter = iter(cols_list[:-1]) | ||||||
|  |         vals_iter = iter(vals_list[:-1]) | ||||||
|  |         col, val = next(cols_iter), next(vals_iter) | ||||||
|  |         lookup, connector = next(lookups), next(connectors) | ||||||
|  |         root = node = WhereNode([lookup(col, val)], connector=connector) | ||||||
|  |  | ||||||
|  |         for col, val in zip(cols_iter, vals_iter): | ||||||
|  |             lookup, connector = next(lookups), next(connectors) | ||||||
|  |             child = WhereNode([lookup(col, val)], connector=connector) | ||||||
|  |             node.children.append(child) | ||||||
|  |             node = child | ||||||
|  |  | ||||||
|  |         return root.as_sql(compiler, connection) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TupleGreaterThanOrEqual(TupleLookupMixin, GreaterThanOrEqual): | ||||||
|  |     def as_oracle(self, compiler, connection): | ||||||
|  |         # e.g.: (a, b, c) >= (x, y, z) as SQL: | ||||||
|  |         # WHERE a > x OR (a = x AND (b > y OR (b = y AND (c > z OR c = z)))) | ||||||
|  |         cols = self.lhs.get_cols() | ||||||
|  |         lookups = itertools.cycle([GreaterThan, Exact]) | ||||||
|  |         connectors = itertools.cycle([OR, AND]) | ||||||
|  |         cols_list = [col for col in cols for _ in range(2)] | ||||||
|  |         vals_list = [val for val in self.rhs for _ in range(2)] | ||||||
|  |         cols_iter = iter(cols_list) | ||||||
|  |         vals_iter = iter(vals_list) | ||||||
|  |         col, val = next(cols_iter), next(vals_iter) | ||||||
|  |         lookup, connector = next(lookups), next(connectors) | ||||||
|  |         root = node = WhereNode([lookup(col, val)], connector=connector) | ||||||
|  |  | ||||||
|  |         for col, val in zip(cols_iter, vals_iter): | ||||||
|  |             lookup, connector = next(lookups), next(connectors) | ||||||
|  |             child = WhereNode([lookup(col, val)], connector=connector) | ||||||
|  |             node.children.append(child) | ||||||
|  |             node = child | ||||||
|  |  | ||||||
|  |         return root.as_sql(compiler, connection) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TupleLessThan(TupleLookupMixin, LessThan): | ||||||
|  |     def as_oracle(self, compiler, connection): | ||||||
|  |         # e.g.: (a, b, c) < (x, y, z) as SQL: | ||||||
|  |         # WHERE a < x OR (a = x AND (b < y OR (b = y AND c < z))) | ||||||
|  |         cols = self.lhs.get_cols() | ||||||
|  |         lookups = itertools.cycle([LessThan, Exact]) | ||||||
|  |         connectors = itertools.cycle([OR, AND]) | ||||||
|  |         cols_list = [col for col in cols for _ in range(2)] | ||||||
|  |         vals_list = [val for val in self.rhs for _ in range(2)] | ||||||
|  |         cols_iter = iter(cols_list[:-1]) | ||||||
|  |         vals_iter = iter(vals_list[:-1]) | ||||||
|  |         col, val = next(cols_iter), next(vals_iter) | ||||||
|  |         lookup, connector = next(lookups), next(connectors) | ||||||
|  |         root = node = WhereNode([lookup(col, val)], connector=connector) | ||||||
|  |  | ||||||
|  |         for col, val in zip(cols_iter, vals_iter): | ||||||
|  |             lookup, connector = next(lookups), next(connectors) | ||||||
|  |             child = WhereNode([lookup(col, val)], connector=connector) | ||||||
|  |             node.children.append(child) | ||||||
|  |             node = child | ||||||
|  |  | ||||||
|  |         return root.as_sql(compiler, connection) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual): | ||||||
|  |     def as_oracle(self, compiler, connection): | ||||||
|  |         # e.g.: (a, b, c) <= (x, y, z) as SQL: | ||||||
|  |         # WHERE a < x OR (a = x AND (b < y OR (b = y AND (c < z OR c = z)))) | ||||||
|  |         cols = self.lhs.get_cols() | ||||||
|  |         lookups = itertools.cycle([LessThan, Exact]) | ||||||
|  |         connectors = itertools.cycle([OR, AND]) | ||||||
|  |         cols_list = [col for col in cols for _ in range(2)] | ||||||
|  |         vals_list = [val for val in self.rhs for _ in range(2)] | ||||||
|  |         cols_iter = iter(cols_list) | ||||||
|  |         vals_iter = iter(vals_list) | ||||||
|  |         col, val = next(cols_iter), next(vals_iter) | ||||||
|  |         lookup, connector = next(lookups), next(connectors) | ||||||
|  |         root = node = WhereNode([lookup(col, val)], connector=connector) | ||||||
|  |  | ||||||
|  |         for col, val in zip(cols_iter, vals_iter): | ||||||
|  |             lookup, connector = next(lookups), next(connectors) | ||||||
|  |             child = WhereNode([lookup(col, val)], connector=connector) | ||||||
|  |             node.children.append(child) | ||||||
|  |             node = child | ||||||
|  |  | ||||||
|  |         return root.as_sql(compiler, connection) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TupleIn(TupleLookupMixin, In): | ||||||
|  |     def check_tuple_lookup(self): | ||||||
|  |         assert isinstance(self.lhs, ColPairs) | ||||||
|  |         self.check_rhs_is_tuple_or_list() | ||||||
|  |         self.check_rhs_is_collection_of_tuples_or_lists() | ||||||
|  |         self.check_rhs_elements_length_equals_lhs_length() | ||||||
|  |  | ||||||
|  |     def as_sql(self, compiler, connection): | ||||||
|  |         if not self.rhs: | ||||||
|  |             raise EmptyResultSet | ||||||
|  |  | ||||||
|  |         # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL: | ||||||
|  |         # WHERE (a, b, c) IN ((x1, y1, z1), (x2, y2, z2)) | ||||||
|  |         rhs = [] | ||||||
|  |         for vals in self.rhs: | ||||||
|  |             rhs.append( | ||||||
|  |                 Tuple( | ||||||
|  |                     *[ | ||||||
|  |                         Value(val, output_field=col.output_field) | ||||||
|  |                         for col, val in zip(self.lhs, vals) | ||||||
|  |                     ] | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         lookup = In(Tuple(self.lhs), Tuple(*rhs)) | ||||||
|  |         return lookup.as_sql(compiler, connection) | ||||||
|  |  | ||||||
|  |     def as_sqlite(self, compiler, connection): | ||||||
|  |         if not self.rhs: | ||||||
|  |             raise EmptyResultSet | ||||||
|  |  | ||||||
|  |         # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL: | ||||||
|  |         # WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2) | ||||||
|  |         root = WhereNode([], connector=OR) | ||||||
|  |         cols = self.lhs.get_cols() | ||||||
|  |  | ||||||
|  |         for vals in self.rhs: | ||||||
|  |             lookups = [Exact(col, val) for col, val in zip(cols, vals)] | ||||||
|  |             root.children.append(WhereNode(lookups, connector=AND)) | ||||||
|  |  | ||||||
|  |         return root.as_sql(compiler, connection) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | tuple_lookups = { | ||||||
|  |     "exact": TupleExact, | ||||||
|  |     "gt": TupleGreaterThan, | ||||||
|  |     "gte": TupleGreaterThanOrEqual, | ||||||
|  |     "lt": TupleLessThan, | ||||||
|  |     "lte": TupleLessThanOrEqual, | ||||||
|  |     "in": TupleIn, | ||||||
|  |     "isnull": TupleIsNull, | ||||||
|  | } | ||||||
| @@ -23,6 +23,7 @@ from django.db.models.constants import LOOKUP_SEP | |||||||
| from django.db.models.expressions import ( | from django.db.models.expressions import ( | ||||||
|     BaseExpression, |     BaseExpression, | ||||||
|     Col, |     Col, | ||||||
|  |     ColPairs, | ||||||
|     Exists, |     Exists, | ||||||
|     F, |     F, | ||||||
|     OuterRef, |     OuterRef, | ||||||
| @@ -32,7 +33,6 @@ from django.db.models.expressions import ( | |||||||
|     Value, |     Value, | ||||||
| ) | ) | ||||||
| from django.db.models.fields import Field | from django.db.models.fields import Field | ||||||
| from django.db.models.fields.related_lookups import MultiColSource |  | ||||||
| from django.db.models.lookups import Lookup | from django.db.models.lookups import Lookup | ||||||
| from django.db.models.query_utils import ( | from django.db.models.query_utils import ( | ||||||
|     Q, |     Q, | ||||||
| @@ -1549,9 +1549,7 @@ class Query(BaseExpression): | |||||||
|             if len(targets) == 1: |             if len(targets) == 1: | ||||||
|                 col = self._get_col(targets[0], join_info.final_field, alias) |                 col = self._get_col(targets[0], join_info.final_field, alias) | ||||||
|             else: |             else: | ||||||
|                 col = MultiColSource( |                 col = ColPairs(alias, targets, join_info.targets, join_info.final_field) | ||||||
|                     alias, targets, join_info.targets, join_info.final_field |  | ||||||
|                 ) |  | ||||||
|         else: |         else: | ||||||
|             col = self._get_col(targets[0], join_info.final_field, alias) |             col = self._get_col(targets[0], join_info.final_field, alias) | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										242
									
								
								tests/foreign_object/test_tuple_lookups.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										242
									
								
								tests/foreign_object/test_tuple_lookups.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,242 @@ | |||||||
|  | import unittest | ||||||
|  |  | ||||||
|  | from django.db import NotSupportedError, connection | ||||||
|  | from django.test import TestCase | ||||||
|  |  | ||||||
|  | from .models import Contact, Customer | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TupleLookupsTests(TestCase): | ||||||
|  |     @classmethod | ||||||
|  |     def setUpTestData(cls): | ||||||
|  |         super().setUpTestData() | ||||||
|  |         cls.customer_1 = Customer.objects.create(customer_id=1, company="a") | ||||||
|  |         cls.customer_2 = Customer.objects.create(customer_id=1, company="b") | ||||||
|  |         cls.customer_3 = Customer.objects.create(customer_id=2, company="c") | ||||||
|  |         cls.customer_4 = Customer.objects.create(customer_id=3, company="d") | ||||||
|  |         cls.customer_5 = Customer.objects.create(customer_id=1, company="e") | ||||||
|  |         cls.contact_1 = Contact.objects.create(customer=cls.customer_1) | ||||||
|  |         cls.contact_2 = Contact.objects.create(customer=cls.customer_1) | ||||||
|  |         cls.contact_3 = Contact.objects.create(customer=cls.customer_2) | ||||||
|  |         cls.contact_4 = Contact.objects.create(customer=cls.customer_3) | ||||||
|  |         cls.contact_5 = Contact.objects.create(customer=cls.customer_1) | ||||||
|  |         cls.contact_6 = Contact.objects.create(customer=cls.customer_5) | ||||||
|  |  | ||||||
|  |     def test_exact(self): | ||||||
|  |         test_cases = ( | ||||||
|  |             (self.customer_1, (self.contact_1, self.contact_2, self.contact_5)), | ||||||
|  |             (self.customer_2, (self.contact_3,)), | ||||||
|  |             (self.customer_3, (self.contact_4,)), | ||||||
|  |             (self.customer_4, ()), | ||||||
|  |             (self.customer_5, (self.contact_6,)), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         for customer, contacts in test_cases: | ||||||
|  |             with self.subTest(customer=customer, contacts=contacts): | ||||||
|  |                 self.assertSequenceEqual( | ||||||
|  |                     Contact.objects.filter(customer=customer).order_by("id"), contacts | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |     def test_exact_subquery(self): | ||||||
|  |         with self.assertRaisesMessage( | ||||||
|  |             NotSupportedError, "'exact' doesn't support multi-column subqueries." | ||||||
|  |         ): | ||||||
|  |             subquery = Customer.objects.filter(id=self.customer_1.id)[:1] | ||||||
|  |             self.assertSequenceEqual( | ||||||
|  |                 Contact.objects.filter(customer=subquery).order_by("id"), () | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def test_in(self): | ||||||
|  |         cust_1, cust_2, cust_3, cust_4, cust_5 = ( | ||||||
|  |             self.customer_1, | ||||||
|  |             self.customer_2, | ||||||
|  |             self.customer_3, | ||||||
|  |             self.customer_4, | ||||||
|  |             self.customer_5, | ||||||
|  |         ) | ||||||
|  |         c1, c2, c3, c4, c5, c6 = ( | ||||||
|  |             self.contact_1, | ||||||
|  |             self.contact_2, | ||||||
|  |             self.contact_3, | ||||||
|  |             self.contact_4, | ||||||
|  |             self.contact_5, | ||||||
|  |             self.contact_6, | ||||||
|  |         ) | ||||||
|  |         test_cases = ( | ||||||
|  |             ((), ()), | ||||||
|  |             ((cust_1,), (c1, c2, c5)), | ||||||
|  |             ((cust_1, cust_2), (c1, c2, c3, c5)), | ||||||
|  |             ((cust_1, cust_2, cust_3), (c1, c2, c3, c4, c5)), | ||||||
|  |             ((cust_1, cust_2, cust_3, cust_4), (c1, c2, c3, c4, c5)), | ||||||
|  |             ((cust_1, cust_2, cust_3, cust_4, cust_5), (c1, c2, c3, c4, c5, c6)), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         for contacts, customers in test_cases: | ||||||
|  |             with self.subTest(contacts=contacts, customers=customers): | ||||||
|  |                 self.assertSequenceEqual( | ||||||
|  |                     Contact.objects.filter(customer__in=contacts).order_by("id"), | ||||||
|  |                     customers, | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |     @unittest.skipIf( | ||||||
|  |         connection.vendor == "mysql", | ||||||
|  |         "MySQL doesn't support LIMIT & IN/ALL/ANY/SOME subquery", | ||||||
|  |     ) | ||||||
|  |     def test_in_subquery(self): | ||||||
|  |         subquery = Customer.objects.filter(id=self.customer_1.id)[:1] | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             Contact.objects.filter(customer__in=subquery).order_by("id"), | ||||||
|  |             (self.contact_1, self.contact_2, self.contact_5), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_lt(self): | ||||||
|  |         c1, c2, c3, c4, c5, c6 = ( | ||||||
|  |             self.contact_1, | ||||||
|  |             self.contact_2, | ||||||
|  |             self.contact_3, | ||||||
|  |             self.contact_4, | ||||||
|  |             self.contact_5, | ||||||
|  |             self.contact_6, | ||||||
|  |         ) | ||||||
|  |         test_cases = ( | ||||||
|  |             (self.customer_1, ()), | ||||||
|  |             (self.customer_2, (c1, c2, c5)), | ||||||
|  |             (self.customer_5, (c1, c2, c3, c5)), | ||||||
|  |             (self.customer_3, (c1, c2, c3, c5, c6)), | ||||||
|  |             (self.customer_4, (c1, c2, c3, c4, c5, c6)), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         for customer, contacts in test_cases: | ||||||
|  |             with self.subTest(customer=customer, contacts=contacts): | ||||||
|  |                 self.assertSequenceEqual( | ||||||
|  |                     Contact.objects.filter(customer__lt=customer).order_by("id"), | ||||||
|  |                     contacts, | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |     def test_lt_subquery(self): | ||||||
|  |         with self.assertRaisesMessage( | ||||||
|  |             NotSupportedError, "'lt' doesn't support multi-column subqueries." | ||||||
|  |         ): | ||||||
|  |             subquery = Customer.objects.filter(id=self.customer_1.id)[:1] | ||||||
|  |             self.assertSequenceEqual( | ||||||
|  |                 Contact.objects.filter(customer__lt=subquery).order_by("id"), () | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def test_lte(self): | ||||||
|  |         c1, c2, c3, c4, c5, c6 = ( | ||||||
|  |             self.contact_1, | ||||||
|  |             self.contact_2, | ||||||
|  |             self.contact_3, | ||||||
|  |             self.contact_4, | ||||||
|  |             self.contact_5, | ||||||
|  |             self.contact_6, | ||||||
|  |         ) | ||||||
|  |         test_cases = ( | ||||||
|  |             (self.customer_1, (c1, c2, c5)), | ||||||
|  |             (self.customer_2, (c1, c2, c3, c5)), | ||||||
|  |             (self.customer_5, (c1, c2, c3, c5, c6)), | ||||||
|  |             (self.customer_3, (c1, c2, c3, c4, c5, c6)), | ||||||
|  |             (self.customer_4, (c1, c2, c3, c4, c5, c6)), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         for customer, contacts in test_cases: | ||||||
|  |             with self.subTest(customer=customer, contacts=contacts): | ||||||
|  |                 self.assertSequenceEqual( | ||||||
|  |                     Contact.objects.filter(customer__lte=customer).order_by("id"), | ||||||
|  |                     contacts, | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |     def test_lte_subquery(self): | ||||||
|  |         with self.assertRaisesMessage( | ||||||
|  |             NotSupportedError, "'lte' doesn't support multi-column subqueries." | ||||||
|  |         ): | ||||||
|  |             subquery = Customer.objects.filter(id=self.customer_1.id)[:1] | ||||||
|  |             self.assertSequenceEqual( | ||||||
|  |                 Contact.objects.filter(customer__lte=subquery).order_by("id"), () | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def test_gt(self): | ||||||
|  |         test_cases = ( | ||||||
|  |             (self.customer_1, (self.contact_3, self.contact_4, self.contact_6)), | ||||||
|  |             (self.customer_2, (self.contact_4, self.contact_6)), | ||||||
|  |             (self.customer_5, (self.contact_4,)), | ||||||
|  |             (self.customer_3, ()), | ||||||
|  |             (self.customer_4, ()), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         for customer, contacts in test_cases: | ||||||
|  |             with self.subTest(customer=customer, contacts=contacts): | ||||||
|  |                 self.assertSequenceEqual( | ||||||
|  |                     Contact.objects.filter(customer__gt=customer).order_by("id"), | ||||||
|  |                     contacts, | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |     def test_gt_subquery(self): | ||||||
|  |         with self.assertRaisesMessage( | ||||||
|  |             NotSupportedError, "'gt' doesn't support multi-column subqueries." | ||||||
|  |         ): | ||||||
|  |             subquery = Customer.objects.filter(id=self.customer_1.id)[:1] | ||||||
|  |             self.assertSequenceEqual( | ||||||
|  |                 Contact.objects.filter(customer__gt=subquery).order_by("id"), () | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def test_gte(self): | ||||||
|  |         c1, c2, c3, c4, c5, c6 = ( | ||||||
|  |             self.contact_1, | ||||||
|  |             self.contact_2, | ||||||
|  |             self.contact_3, | ||||||
|  |             self.contact_4, | ||||||
|  |             self.contact_5, | ||||||
|  |             self.contact_6, | ||||||
|  |         ) | ||||||
|  |         test_cases = ( | ||||||
|  |             (self.customer_1, (c1, c2, c3, c4, c5, c6)), | ||||||
|  |             (self.customer_2, (c3, c4, c6)), | ||||||
|  |             (self.customer_5, (c4, c6)), | ||||||
|  |             (self.customer_3, (c4,)), | ||||||
|  |             (self.customer_4, ()), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         for customer, contacts in test_cases: | ||||||
|  |             with self.subTest(customer=customer, contacts=contacts): | ||||||
|  |                 self.assertSequenceEqual( | ||||||
|  |                     Contact.objects.filter(customer__gte=customer).order_by("pk"), | ||||||
|  |                     contacts, | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |     def test_gte_subquery(self): | ||||||
|  |         with self.assertRaisesMessage( | ||||||
|  |             NotSupportedError, "'gte' doesn't support multi-column subqueries." | ||||||
|  |         ): | ||||||
|  |             subquery = Customer.objects.filter(id=self.customer_1.id)[:1] | ||||||
|  |             self.assertSequenceEqual( | ||||||
|  |                 Contact.objects.filter(customer__gte=subquery).order_by("id"), () | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def test_isnull(self): | ||||||
|  |         with self.subTest("customer__isnull=True"): | ||||||
|  |             self.assertSequenceEqual( | ||||||
|  |                 Contact.objects.filter(customer__isnull=True).order_by("id"), | ||||||
|  |                 (), | ||||||
|  |             ) | ||||||
|  |         with self.subTest("customer__isnull=False"): | ||||||
|  |             self.assertSequenceEqual( | ||||||
|  |                 Contact.objects.filter(customer__isnull=False).order_by("id"), | ||||||
|  |                 ( | ||||||
|  |                     self.contact_1, | ||||||
|  |                     self.contact_2, | ||||||
|  |                     self.contact_3, | ||||||
|  |                     self.contact_4, | ||||||
|  |                     self.contact_5, | ||||||
|  |                     self.contact_6, | ||||||
|  |                 ), | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def test_isnull_subquery(self): | ||||||
|  |         with self.assertRaisesMessage( | ||||||
|  |             NotSupportedError, "'isnull' doesn't support multi-column subqueries." | ||||||
|  |         ): | ||||||
|  |             subquery = Customer.objects.filter(id=0)[:1] | ||||||
|  |             self.assertSequenceEqual( | ||||||
|  |                 Contact.objects.filter(customer__isnull=subquery).order_by("id"), () | ||||||
|  |             ) | ||||||
		Reference in New Issue
	
	Block a user