diff --git a/AUTHORS b/AUTHORS index ba1f4036e9..99422f6926 100644 --- a/AUTHORS +++ b/AUTHORS @@ -530,6 +530,7 @@ answer newbie questions, and generally made Django that much better: Leo Shklovskii jason.sidabras@gmail.com MikoĊ‚aj Siedlarek + Karol Sikora Brenton Simpson Jozko Skrablin Ben Slavin diff --git a/django/db/models/manager.py b/django/db/models/manager.py index 065718249e..b369aedb64 100644 --- a/django/db/models/manager.py +++ b/django/db/models/manager.py @@ -154,6 +154,9 @@ class Manager(six.with_metaclass(RenameManagerMethods)): def get_or_create(self, **kwargs): return self.get_queryset().get_or_create(**kwargs) + def update_or_create(self, **kwargs): + return self.get_queryset().update_or_create(**kwargs) + def create(self, **kwargs): return self.get_queryset().create(**kwargs) diff --git a/django/db/models/query.py b/django/db/models/query.py index 086cc6dd71..811e917764 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -364,37 +364,84 @@ class QuerySet(object): return objs - def get_or_create(self, **kwargs): + def get_or_create(self, defaults=None, **kwargs): """ Looks up an object with the given kwargs, creating one if necessary. Returns a tuple of (object, created), where created is a boolean specifying whether an object was created. """ - defaults = kwargs.pop('defaults', {}) - lookup = kwargs.copy() - for f in self.model._meta.fields: - if f.attname in lookup: - lookup[f.name] = lookup.pop(f.attname) + lookup, params, _ = self._extract_model_params(defaults, **kwargs) try: self._for_write = True return self.get(**lookup), False except self.model.DoesNotExist: + return self._create_object_from_params(lookup, params) + + def update_or_create(self, defaults=None, **kwargs): + """ + Looks up an object with the given kwargs, updating one with defaults + if it exists, otherwise creates a new one. + Returns a tuple (object, created), where created is a boolean + specifying whether an object was created. + """ + lookup, params, filtered_defaults = self._extract_model_params(defaults, **kwargs) + try: + self._for_write = True + obj = self.get(**lookup) + except self.model.DoesNotExist: + obj, created = self._create_object_from_params(lookup, params) + if created: + return obj, created + for k, v in six.iteritems(filtered_defaults): + setattr(obj, k, v) + try: + sid = transaction.savepoint(using=self.db) + obj.save(update_fields=filtered_defaults.keys(), using=self.db) + transaction.savepoint_commit(sid, using=self.db) + return obj, False + except DatabaseError: + transaction.savepoint_rollback(sid, using=self.db) + six.reraise(sys.exc_info()) + + def _create_object_from_params(self, lookup, params): + """ + Tries to create an object using passed params. + Used by get_or_create and update_or_create + """ + try: + obj = self.model(**params) + sid = transaction.savepoint(using=self.db) + obj.save(force_insert=True, using=self.db) + transaction.savepoint_commit(sid, using=self.db) + return obj, True + except DatabaseError: + transaction.savepoint_rollback(sid, using=self.db) + exc_info = sys.exc_info() try: - params = dict((k, v) for k, v in kwargs.items() if LOOKUP_SEP not in k) - params.update(defaults) - obj = self.model(**params) - sid = transaction.savepoint(using=self.db) - obj.save(force_insert=True, using=self.db) - transaction.savepoint_commit(sid, using=self.db) - return obj, True - except DatabaseError: - transaction.savepoint_rollback(sid, using=self.db) - exc_info = sys.exc_info() - try: - return self.get(**lookup), False - except self.model.DoesNotExist: - # Re-raise the DatabaseError with its original traceback. - six.reraise(*exc_info) + return self.get(**lookup), False + except self.model.DoesNotExist: + # Re-raise the DatabaseError with its original traceback. + six.reraise(*exc_info) + + def _extract_model_params(self, defaults, **kwargs): + """ + Prepares `lookup` (kwargs that are valid model attributes), `params` + (for creating a model instance) and `filtered_defaults` (defaults + that are valid model attributes) based on given kwargs; for use by + get_or_create and update_or_create. + """ + defaults = defaults or {} + filtered_defaults = {} + lookup = kwargs.copy() + for f in self.model._meta.fields: + # Filter out fields that don't belongs to the model. + if f.attname in lookup: + lookup[f.name] = lookup.pop(f.attname) + if f.attname in defaults: + filtered_defaults[f.name] = defaults.pop(f.attname) + params = dict((k, v) for k, v in kwargs.items() if LOOKUP_SEP not in k) + params.update(filtered_defaults) + return lookup, params, filtered_defaults def _earliest_or_latest(self, field_name=None, direction="-"): """ diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index d3fe142d36..3963785733 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -1330,7 +1330,7 @@ prepared to handle the exception if you are using manual primary keys. get_or_create ~~~~~~~~~~~~~ -.. method:: get_or_create(**kwargs) +.. method:: get_or_create(defaults=None, **kwargs) A convenience method for looking up an object with the given ``kwargs`` (may be empty if your model has defaults for all fields), creating one if necessary. @@ -1366,7 +1366,6 @@ found, ``get_or_create()`` will instantiate and save a new object, returning a tuple of the new object and ``True``. The new object will be created roughly according to this algorithm:: - defaults = kwargs.pop('defaults', {}) params = dict([(k, v) for k, v in kwargs.items() if '__' not in k]) params.update(defaults) obj = self.model(**params) @@ -1447,6 +1446,49 @@ in the HTTP spec. chapter because it isn't related to that book, but it can't create it either because ``title`` field should be unique. +update_or_create +~~~~~~~~~~~~~~~~ + +.. method:: update_or_create(defaults=None, **kwargs) + +.. versionadded:: 1.7 + +A convenience method for updating an object with the given ``kwargs``, creating +a new one if necessary. The ``defaults`` is a dictionary of (field, value) +pairs used to update the object. + +Returns a tuple of ``(object, created)``, where ``object`` is the created or +updated object and ``created`` is a boolean specifying whether a new object was +created. + +The ``update_or_create`` method tries to fetch an object from database based on +the given ``kwargs``. If a match is found, it updates the fields passed in the +``defaults`` dictionary. + +This is meant as a shortcut to boilerplatish code. For example:: + + try: + obj = Person.objects.get(first_name='John', last_name='Lennon') + for key, value in updated_values.iteritems(): + setattr(obj, key, value) + obj.save() + except Person.DoesNotExist: + updated_values.update({'first_name': 'John', 'last_name': 'Lennon'}) + obj = Person(**updated_values) + obj.save() + +This pattern gets quite unwieldy as the number of fields in a model goes up. +The above example can be rewritten using ``update_or_create()`` like so:: + + obj, created = Person.objects.update_or_create( + first_name='John', last_name='Lennon', defaults=updated_values) + +For detailed description how names passed in ``kwargs`` are resolved see +:meth:`get_or_create`. + +As described above in :meth:`get_or_create`, this method is prone to a +race-condition which can result in multiple rows being inserted simultaneously +if uniqueness is not enforced at the database level. bulk_create ~~~~~~~~~~~ diff --git a/docs/releases/1.7.txt b/docs/releases/1.7.txt index 2cf1cd326e..f3defb37a3 100644 --- a/docs/releases/1.7.txt +++ b/docs/releases/1.7.txt @@ -41,6 +41,9 @@ Minor features * The ``enter`` argument was added to the :data:`~django.test.signals.setting_changed` signal. +* The :meth:`QuerySet.update_or_create() + ` method was added. + Backwards incompatible changes in 1.7 ===================================== diff --git a/tests/get_or_create/tests.py b/tests/get_or_create/tests.py index 847a6dec01..0f766ab128 100644 --- a/tests/get_or_create/tests.py +++ b/tests/get_or_create/tests.py @@ -131,3 +131,68 @@ class GetOrCreateThroughManyToMany(TestCase): Tag.objects.create(text='foo') a_thing = Thing.objects.create(name='a') self.assertRaises(IntegrityError, a_thing.tags.get_or_create, text='foo') + + +class UpdateOrCreateTests(TestCase): + + def test_update(self): + Person.objects.create( + first_name='John', last_name='Lennon', birthday=date(1940, 10, 9) + ) + p, created = Person.objects.update_or_create( + first_name='John', last_name='Lennon', defaults={ + 'birthday': date(1940, 10, 10) + } + ) + self.assertFalse(created) + self.assertEqual(p.first_name, 'John') + self.assertEqual(p.last_name, 'Lennon') + self.assertEqual(p.birthday, date(1940, 10, 10)) + + def test_create(self): + p, created = Person.objects.update_or_create( + first_name='John', last_name='Lennon', defaults={ + 'birthday': date(1940, 10, 10) + } + ) + self.assertTrue(created) + self.assertEqual(p.first_name, 'John') + self.assertEqual(p.last_name, 'Lennon') + self.assertEqual(p.birthday, date(1940, 10, 10)) + + def test_create_twice(self): + params = { + 'first_name': 'John', + 'last_name': 'Lennon', + 'birthday': date(1940, 10, 10), + } + Person.objects.update_or_create(**params) + # If we execute the exact same statement, it won't create a Person. + p, created = Person.objects.update_or_create(**params) + self.assertFalse(created) + + def test_integrity(self): + # If you don't specify a value or default value for all required + # fields, you will get an error. + self.assertRaises(IntegrityError, + Person.objects.update_or_create, first_name="Tom", last_name="Smith") + + def test_mananual_primary_key_test(self): + # If you specify an existing primary key, but different other fields, + # then you will get an error and data will not be updated. + ManualPrimaryKeyTest.objects.create(id=1, data="Original") + self.assertRaises(IntegrityError, + ManualPrimaryKeyTest.objects.update_or_create, id=1, data="Different" + ) + self.assertEqual(ManualPrimaryKeyTest.objects.get(id=1).data, "Original") + + def test_error_contains_full_traceback(self): + # update_or_create should raise IntegrityErrors with the full traceback. + # This is tested by checking that a known method call is in the traceback. + # We cannot use assertRaises/assertRaises here because we need to inspect + # the actual traceback. Refs #16340. + try: + ManualPrimaryKeyTest.objects.update_or_create(id=1, data="Different") + except IntegrityError as e: + formatted_traceback = traceback.format_exc() + self.assertIn('obj.save', formatted_traceback)