diff --git a/django/core/management.py b/django/core/management.py index 5590a92e8c..193cadeef3 100644 --- a/django/core/management.py +++ b/django/core/management.py @@ -156,21 +156,21 @@ def get_sql_create(app): for klass in app_models: opts = klass._meta for f in opts.many_to_many: - table_output = ['CREATE TABLE %s (' % backend.quote_name(f.get_m2m_db_table(opts))] + table_output = ['CREATE TABLE %s (' % backend.quote_name(f.m2m_db_table())] table_output.append(' %s %s NOT NULL PRIMARY KEY,' % (backend.quote_name('id'), data_types['AutoField'])) table_output.append(' %s %s NOT NULL REFERENCES %s (%s),' % \ - (backend.quote_name(opts.object_name.lower() + '_id'), + (backend.quote_name(f.m2m_column_name()), data_types[get_rel_data_type(opts.pk)] % opts.pk.__dict__, backend.quote_name(opts.db_table), backend.quote_name(opts.pk.column))) table_output.append(' %s %s NOT NULL REFERENCES %s (%s),' % \ - (backend.quote_name(f.rel.to._meta.object_name.lower() + '_id'), + (backend.quote_name(f.m2m_reverse_name()), data_types[get_rel_data_type(f.rel.to._meta.pk)] % f.rel.to._meta.pk.__dict__, backend.quote_name(f.rel.to._meta.db_table), backend.quote_name(f.rel.to._meta.pk.column))) table_output.append(' UNIQUE (%s, %s)' % \ - (backend.quote_name(opts.object_name.lower() + '_id'), - backend.quote_name(f.rel.to._meta.object_name.lower() + '_id'))) + (backend.quote_name(f.m2m_column_name()), + backend.quote_name(f.m2m_reverse_name()))) table_output.append(');') final_output.append('\n'.join(table_output)) return final_output @@ -249,11 +249,11 @@ def get_sql_delete(app): for f in opts.many_to_many: try: if cursor is not None: - cursor.execute("SELECT 1 FROM %s LIMIT 1" % backend.quote_name(f.get_m2m_db_table(opts))) + cursor.execute("SELECT 1 FROM %s LIMIT 1" % backend.quote_name(f.m2m_db_table())) except: connection.rollback() else: - output.append("DROP TABLE %s;" % backend.quote_name(f.get_m2m_db_table(opts))) + output.append("DROP TABLE %s;" % backend.quote_name(f.m2m_db_table())) app_label = app_models[0]._meta.app_label @@ -335,7 +335,7 @@ def get_sql_sequence_reset(app): backend.quote_name(klass._meta.db_table))) for f in klass._meta.many_to_many: output.append("SELECT setval('%s_id_seq', (SELECT max(%s) FROM %s));" % \ - (f.get_m2m_db_table(klass._meta), backend.quote_name('id'), f.get_m2m_db_table(klass._meta))) + (f.m2m_db_table(), backend.quote_name('id'), f.m2m_db_table())) return output get_sql_sequence_reset.help_doc = "Prints the SQL statements for resetting PostgreSQL sequences for the given app name(s)." get_sql_sequence_reset.args = APP_ARGS diff --git a/django/db/models/base.py b/django/db/models/base.py index 5bf552059e..2d3071edb2 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -312,16 +312,16 @@ class Model(object): def _set_related_many_to_many(self, rel_class, rel_field, id_list): id_list = map(int, id_list) # normalize to integers rel = rel_field.rel.to - m2m_table = rel_field.get_m2m_db_table(rel_opts) + m2m_table = rel_field.m2m_db_table() this_id = self._get_pk_val() cursor = connection.cursor() cursor.execute("DELETE FROM %s WHERE %s = %%s" % \ (backend.quote_name(m2m_table), - backend.quote_name(rel.object_name.lower() + '_id')), [this_id]) + backend.quote_name(rel_field.m2m_column_name())), [this_id]) sql = "INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \ (backend.quote_name(m2m_table), - backend.quote_name(rel.object_name.lower() + '_id'), - backend.quote_name(rel_opts.object_name.lower() + '_id')) + backend.quote_name(rel_field.m2m_column_name()), + backend.quote_name(rel_field.m2m_reverse_name())) cursor.executemany(sql, [(this_id, i) for i in id_list]) connection.commit() diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 38482a65c0..4106660685 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -199,9 +199,9 @@ class ManyRelatedObjectsDescriptor(object): qn = backend.quote_name this_opts = instance.__class__._meta rel_opts = rel_model._meta - join_table = backend.quote_name(self.related.field.get_m2m_db_table(rel_opts)) - this_col_name = this_opts.object_name.lower() + '_id' - rel_col_name = rel_opts.object_name.lower() + '_id' + join_table = qn(self.related.field.m2m_db_table()) + this_col_name = qn(self.related.field.m2m_reverse_name()) + rel_col_name = qn(self.related.field.m2m_column_name()) # Dynamically create a class that subclasses the related # model's default manager. @@ -210,6 +210,7 @@ class ManyRelatedObjectsDescriptor(object): class RelatedManager(superclass): def get_query_set(self): return superclass.get_query_set(self).filter(**(self.core_filters)) + if rel_type == "o2m": def add(self, **kwargs): kwargs.update({rel_field.name: instance}) @@ -257,7 +258,6 @@ class ReverseManyRelatedObjectsDescriptor(object): # ReverseManyRelatedObjectsDescriptor instance. def __init__(self, m2m_field): self.field = m2m_field - self.rel_model = m2m_field.rel.to def __get__(self, instance, instance_type=None): if instance is None: @@ -265,15 +265,15 @@ class ReverseManyRelatedObjectsDescriptor(object): qn = backend.quote_name this_opts = instance.__class__._meta - rel_model = self.rel_model - rel_opts = self.rel_model._meta - join_table = backend.quote_name(self.field.get_m2m_db_table(this_opts)) - this_col_name = this_opts.object_name.lower() + '_id' - rel_col_name = rel_opts.object_name.lower() + '_id' + rel_model = self.field.rel.to + rel_opts = rel_model._meta + join_table = qn(self.field.m2m_db_table()) + this_col_name = qn(self.field.m2m_column_name()) + rel_col_name = qn(self.field.m2m_reverse_name()) # Dynamically create a class that subclasses the related # model's default manager. - superclass = self.rel_model._default_manager.__class__ + superclass = rel_model._default_manager.__class__ class RelatedManager(superclass): def get_query_set(self): @@ -300,7 +300,7 @@ class ReverseManyRelatedObjectsDescriptor(object): clear.alters_data = True manager = RelatedManager() - manager.model = self.rel_model + manager.model = rel_model return manager @@ -464,9 +464,17 @@ class ManyToManyField(RelatedField, Field): def get_choices_default(self): return Field.get_choices(self, include_blank=False) - def get_m2m_db_table(self, original_opts): - "Returns the name of the many-to-many 'join' table." - return '%s_%s' % (original_opts.db_table, self.name) + def _get_m2m_db_table(self, opts): + "Function that can be curried to provide the m2m table name for this relation" + return '%s_%s' % (opts.db_table, self.name) + + def _get_m2m_column_name(self, related): + "Function that can be curried to provide the source column name for the m2m table" + return related.model._meta.object_name.lower() + '_id' + + def _get_m2m_reverse_name(self, related): + "Function that can be curried to provide the related column name for the m2m table" + return related.parent_model._meta.object_name.lower() + '_id' def isValidIDList(self, field_data, all_data): "Validates that the value is a valid list of foreign keys" @@ -504,12 +512,21 @@ class ManyToManyField(RelatedField, Field): def contribute_to_class(self, cls, name): super(ManyToManyField, self).contribute_to_class(cls, name) + # Add the descriptor for the m2m relation setattr(cls, self.name, ReverseManyRelatedObjectsDescriptor(self)) - + + # Set up the accessor for the m2m table name for the relation + self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta) + def contribute_to_related_class(self, cls, related): setattr(cls, related.get_accessor_name(), ManyRelatedObjectsDescriptor(related, 'm2m')) + # Add the descriptor for the m2m relation self.rel.singular = self.rel.singular or self.rel.to._meta.object_name.lower() + # Set up the accessors for the column names on the m2m table + self.m2m_column_name = curry(self._get_m2m_column_name, related) + self.m2m_reverse_name = curry(self._get_m2m_reverse_name, related) + def set_attributes_from_rel(self): pass diff --git a/django/db/models/query.py b/django/db/models/query.py index 7ecfea7b0c..4b40495f5a 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -647,9 +647,10 @@ def lookup_inner(path, clause, value, opts, table, column): # This process hijacks current_table/column to point to the # intermediate table. current_table = "m2m_" + new_table - join_column = new_opts.object_name.lower() + '_id' - intermediate_table = field.get_m2m_db_table(current_opts) - + intermediate_table = field.m2m_db_table() + join_column = field.m2m_reverse_name() + intermediate_column = field.m2m_column_name() + raise FieldFound # Does the name belong to a reverse defined many-to-many field? @@ -663,9 +664,10 @@ def lookup_inner(path, clause, value, opts, table, column): # This process hijacks current_table/column to point to the # intermediate table. current_table = "m2m_" + new_table - join_column = new_opts.object_name.lower() + '_id' - intermediate_table = field.field.get_m2m_db_table(new_opts) - + intermediate_table = field.field.m2m_db_table() + join_column = field.field.m2m_column_name() + intermediate_column = field.field.m2m_reverse_name() + raise FieldFound # Does the name belong to a one-to-many field? @@ -709,7 +711,7 @@ def lookup_inner(path, clause, value, opts, table, column): (backend.quote_name(table), backend.quote_name(current_opts.pk.column), backend.quote_name(current_table), - backend.quote_name(current_opts.object_name.lower() + '_id')) + backend.quote_name(intermediate_column)) ) if path: @@ -787,14 +789,14 @@ def delete_objects(seen_objs): pk_list = [pk for pk,instance in seen_objs[cls]] for related in cls._meta.get_all_related_many_to_many_objects(): cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \ - (backend.quote_name(related.field.get_m2m_db_table(related.opts)), - backend.quote_name(cls._meta.object_name.lower() + '_id'), + (backend.quote_name(related.field.m2m_db_table()), + backend.quote_name(related.field.m2m_reverse_name()), ','.join(['%s' for pk in pk_list])), pk_list) for f in cls._meta.many_to_many: cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \ - (backend.quote_name(f.get_m2m_db_table(cls._meta)), - backend.quote_name(cls._meta.object_name.lower() + '_id'), + (backend.quote_name(f.m2m_db_table()), + backend.quote_name(f.m2m_column_name()), ','.join(['%s' for pk in pk_list])), pk_list) for field in cls._meta.fields: