1
0
mirror of https://github.com/django/django.git synced 2024-12-22 17:16:24 +00:00

Fixed #34280 -- Allowed specifying different field values for create operation in QuerySet.update_or_create().

This commit is contained in:
tschilling 2023-01-30 20:39:15 -06:00 committed by Mariusz Felisiak
parent ecafcaf634
commit c5808470aa
8 changed files with 212 additions and 27 deletions

View File

@ -926,25 +926,32 @@ class QuerySet(AltersData):
**kwargs, **kwargs,
) )
def update_or_create(self, defaults=None, **kwargs): def update_or_create(self, defaults=None, create_defaults=None, **kwargs):
""" """
Look up an object with the given kwargs, updating one with defaults Look up an object with the given kwargs, updating one with defaults
if it exists, otherwise create a new one. if it exists, otherwise create a new one. Optionally, an object can
be created with different values than defaults by using
create_defaults.
Return a tuple (object, created), where created is a boolean Return a tuple (object, created), where created is a boolean
specifying whether an object was created. specifying whether an object was created.
""" """
defaults = defaults or {} if create_defaults is None:
update_defaults = create_defaults = defaults or {}
else:
update_defaults = defaults or {}
self._for_write = True self._for_write = True
with transaction.atomic(using=self.db): with transaction.atomic(using=self.db):
# Lock the row so that a concurrent update is blocked until # Lock the row so that a concurrent update is blocked until
# update_or_create() has performed its save. # update_or_create() has performed its save.
obj, created = self.select_for_update().get_or_create(defaults, **kwargs) obj, created = self.select_for_update().get_or_create(
create_defaults, **kwargs
)
if created: if created:
return obj, created return obj, created
for k, v in resolve_callables(defaults): for k, v in resolve_callables(update_defaults):
setattr(obj, k, v) setattr(obj, k, v)
update_fields = set(defaults) update_fields = set(update_defaults)
concrete_field_names = self.model._meta._non_pk_concrete_field_names concrete_field_names = self.model._meta._non_pk_concrete_field_names
# update_fields does not support non-concrete fields. # update_fields does not support non-concrete fields.
if concrete_field_names.issuperset(update_fields): if concrete_field_names.issuperset(update_fields):
@ -964,9 +971,10 @@ class QuerySet(AltersData):
obj.save(using=self.db) obj.save(using=self.db)
return obj, False return obj, False
async def aupdate_or_create(self, defaults=None, **kwargs): async def aupdate_or_create(self, defaults=None, create_defaults=None, **kwargs):
return await sync_to_async(self.update_or_create)( return await sync_to_async(self.update_or_create)(
defaults=defaults, defaults=defaults,
create_defaults=create_defaults,
**kwargs, **kwargs,
) )

View File

@ -2263,14 +2263,18 @@ whenever a request to a page has a side effect on your data. For more, see
``update_or_create()`` ``update_or_create()``
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~
.. method:: update_or_create(defaults=None, **kwargs) .. method:: update_or_create(defaults=None, create_defaults=None, **kwargs)
.. method:: aupdate_or_create(defaults=None, **kwargs) .. method:: aupdate_or_create(defaults=None, create_defaults=None, **kwargs)
*Asynchronous version*: ``aupdate_or_create()`` *Asynchronous version*: ``aupdate_or_create()``
A convenience method for updating an object with the given ``kwargs``, creating A convenience method for updating an object with the given ``kwargs``, creating
a new one if necessary. The ``defaults`` is a dictionary of (field, value) a new one if necessary. Both ``create_defaults`` and ``defaults`` are
pairs used to update the object. The values in ``defaults`` can be callables. dictionaries of (field, value) pairs. The values in both ``create_defaults``
and ``defaults`` can be callables. ``defaults`` is used to update the object
while ``create_defaults`` are used for the create operation. If
``create_defaults`` is not supplied, ``defaults`` will be used for the create
operation.
Returns a tuple of ``(object, created)``, where ``object`` is the created or Returns a tuple of ``(object, created)``, where ``object`` is the created or
updated object and ``created`` is a boolean specifying whether a new object was updated object and ``created`` is a boolean specifying whether a new object was
@ -2283,6 +2287,7 @@ the given ``kwargs``. If a match is found, it updates the fields passed in the
This is meant as a shortcut to boilerplatish code. For example:: This is meant as a shortcut to boilerplatish code. For example::
defaults = {'first_name': 'Bob'} defaults = {'first_name': 'Bob'}
create_defaults = {'first_name': 'Bob', 'birthday': date(1940, 10, 9)}
try: try:
obj = Person.objects.get(first_name='John', last_name='Lennon') obj = Person.objects.get(first_name='John', last_name='Lennon')
for key, value in defaults.items(): for key, value in defaults.items():
@ -2290,7 +2295,7 @@ This is meant as a shortcut to boilerplatish code. For example::
obj.save() obj.save()
except Person.DoesNotExist: except Person.DoesNotExist:
new_values = {'first_name': 'John', 'last_name': 'Lennon'} new_values = {'first_name': 'John', 'last_name': 'Lennon'}
new_values.update(defaults) new_values.update(create_defaults)
obj = Person(**new_values) obj = Person(**new_values)
obj.save() obj.save()
@ -2300,6 +2305,7 @@ The above example can be rewritten using ``update_or_create()`` like so::
obj, created = Person.objects.update_or_create( obj, created = Person.objects.update_or_create(
first_name='John', last_name='Lennon', first_name='John', last_name='Lennon',
defaults={'first_name': 'Bob'}, defaults={'first_name': 'Bob'},
create_defaults={'first_name': 'Bob', 'birthday': date(1940, 10, 9)},
) )
For a detailed description of how names passed in ``kwargs`` are resolved, see For a detailed description of how names passed in ``kwargs`` are resolved, see
@ -2318,6 +2324,10 @@ exists in the database, an :exc:`~django.db.IntegrityError` is raised.
In older versions, ``update_or_create()`` didn't specify ``update_fields`` In older versions, ``update_or_create()`` didn't specify ``update_fields``
when calling :meth:`Model.save() <django.db.models.Model.save>`. when calling :meth:`Model.save() <django.db.models.Model.save>`.
.. versionchanged:: 5.0
The ``create_defaults`` argument was added.
``bulk_create()`` ``bulk_create()``
~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~

View File

@ -178,7 +178,9 @@ Migrations
Models Models
~~~~~~ ~~~~~~
* ... * The new ``create_defaults`` argument of :meth:`.QuerySet.update_or_create`
and :meth:`.QuerySet.aupdate_or_create` methods allows specifying a different
field values for the create operation.
Requests and Responses Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~
@ -238,6 +240,14 @@ backends.
* ... * ...
Using ``create_defaults__exact`` may now be required with ``QuerySet.update_or_create()``
-----------------------------------------------------------------------------------------
:meth:`.QuerySet.update_or_create` now supports the parameter
``create_defaults``. As a consequence, any models that have a field named
``create_defaults`` that are used with an ``update_or_create()`` should specify
the field in the lookup with ``create_defaults__exact``.
Miscellaneous Miscellaneous
------------- -------------

View File

@ -99,10 +99,17 @@ class AsyncQuerySetTest(TestCase):
id=self.s1.id, defaults={"field": 2} id=self.s1.id, defaults={"field": 2}
) )
self.assertEqual(instance, self.s1) self.assertEqual(instance, self.s1)
self.assertEqual(instance.field, 2)
self.assertIs(created, False) self.assertIs(created, False)
instance, created = await SimpleModel.objects.aupdate_or_create(field=4) instance, created = await SimpleModel.objects.aupdate_or_create(field=4)
self.assertEqual(await SimpleModel.objects.acount(), 4) self.assertEqual(await SimpleModel.objects.acount(), 4)
self.assertIs(created, True) self.assertIs(created, True)
instance, created = await SimpleModel.objects.aupdate_or_create(
field=5, defaults={"field": 7}, create_defaults={"field": 6}
)
self.assertEqual(await SimpleModel.objects.acount(), 5)
self.assertIs(created, True)
self.assertEqual(instance.field, 6)
@skipUnlessDBFeature("has_bulk_insert") @skipUnlessDBFeature("has_bulk_insert")
@async_to_sync @async_to_sync

View File

@ -44,12 +44,18 @@ class AsyncRelatedManagersOperationTest(TestCase):
self.assertIs(created, True) self.assertIs(created, True)
self.assertEqual(await self.mtm1.simples.acount(), 1) self.assertEqual(await self.mtm1.simples.acount(), 1)
self.assertEqual(new_simple.field, 2) self.assertEqual(new_simple.field, 2)
new_simple, created = await self.mtm1.simples.aupdate_or_create( new_simple1, created = await self.mtm1.simples.aupdate_or_create(
id=new_simple.id, defaults={"field": 3} id=new_simple.id, defaults={"field": 3}
) )
self.assertIs(created, False) self.assertIs(created, False)
self.assertEqual(await self.mtm1.simples.acount(), 1) self.assertEqual(new_simple1.field, 3)
self.assertEqual(new_simple.field, 3)
new_simple2, created = await self.mtm1.simples.aupdate_or_create(
field=4, defaults={"field": 6}, create_defaults={"field": 5}
)
self.assertIs(created, True)
self.assertEqual(new_simple2.field, 5)
self.assertEqual(await self.mtm1.simples.acount(), 2)
async def test_aupdate_or_create_reverse(self): async def test_aupdate_or_create_reverse(self):
new_relatedmodel, created = await self.s1.relatedmodel_set.aupdate_or_create() new_relatedmodel, created = await self.s1.relatedmodel_set.aupdate_or_create()

View File

@ -59,6 +59,19 @@ class GenericRelationsTests(TestCase):
self.assertTrue(created) self.assertTrue(created)
self.assertEqual(count + 1, self.bacon.tags.count()) self.assertEqual(count + 1, self.bacon.tags.count())
def test_generic_update_or_create_when_created_with_create_defaults(self):
count = self.bacon.tags.count()
tag, created = self.bacon.tags.update_or_create(
# Since, the "stinky" tag doesn't exist create
# a "juicy" tag.
create_defaults={"tag": "juicy"},
defaults={"tag": "uncured"},
tag="stinky",
)
self.assertEqual(tag.tag, "juicy")
self.assertIs(created, True)
self.assertEqual(count + 1, self.bacon.tags.count())
def test_generic_update_or_create_when_updated(self): def test_generic_update_or_create_when_updated(self):
""" """
Should be able to use update_or_create from the generic related manager Should be able to use update_or_create from the generic related manager
@ -74,6 +87,17 @@ class GenericRelationsTests(TestCase):
self.assertEqual(count + 1, self.bacon.tags.count()) self.assertEqual(count + 1, self.bacon.tags.count())
self.assertEqual(tag.tag, "juicy") self.assertEqual(tag.tag, "juicy")
def test_generic_update_or_create_when_updated_with_defaults(self):
count = self.bacon.tags.count()
tag = self.bacon.tags.create(tag="stinky")
self.assertEqual(count + 1, self.bacon.tags.count())
tag, created = self.bacon.tags.update_or_create(
create_defaults={"tag": "uncured"}, defaults={"tag": "juicy"}, id=tag.id
)
self.assertIs(created, False)
self.assertEqual(count + 1, self.bacon.tags.count())
self.assertEqual(tag.tag, "juicy")
async def test_generic_async_aupdate_or_create(self): async def test_generic_async_aupdate_or_create(self):
tag, created = await self.bacon.tags.aupdate_or_create( tag, created = await self.bacon.tags.aupdate_or_create(
id=self.fatty.id, defaults={"tag": "orange"} id=self.fatty.id, defaults={"tag": "orange"}
@ -86,6 +110,22 @@ class GenericRelationsTests(TestCase):
self.assertEqual(await self.bacon.tags.acount(), 3) self.assertEqual(await self.bacon.tags.acount(), 3)
self.assertEqual(tag.tag, "pink") self.assertEqual(tag.tag, "pink")
async def test_generic_async_aupdate_or_create_with_create_defaults(self):
tag, created = await self.bacon.tags.aupdate_or_create(
id=self.fatty.id,
create_defaults={"tag": "pink"},
defaults={"tag": "orange"},
)
self.assertIs(created, False)
self.assertEqual(tag.tag, "orange")
self.assertEqual(await self.bacon.tags.acount(), 2)
tag, created = await self.bacon.tags.aupdate_or_create(
tag="pink", create_defaults={"tag": "brown"}
)
self.assertIs(created, True)
self.assertEqual(await self.bacon.tags.acount(), 3)
self.assertEqual(tag.tag, "brown")
def test_generic_get_or_create_when_created(self): def test_generic_get_or_create_when_created(self):
""" """
Should be able to use get_or_create from the generic related manager Should be able to use get_or_create from the generic related manager
@ -550,6 +590,26 @@ class GenericRelationsTests(TestCase):
self.assertFalse(created) self.assertFalse(created)
self.assertEqual(tag.content_object.id, diamond.id) self.assertEqual(tag.content_object.id, diamond.id)
def test_update_or_create_defaults_with_create_defaults(self):
# update_or_create() should work with virtual fields (content_object).
quartz = Mineral.objects.create(name="Quartz", hardness=7)
diamond = Mineral.objects.create(name="Diamond", hardness=7)
tag, created = TaggedItem.objects.update_or_create(
tag="shiny",
create_defaults={"content_object": quartz},
defaults={"content_object": diamond},
)
self.assertIs(created, True)
self.assertEqual(tag.content_object.id, quartz.id)
tag, created = TaggedItem.objects.update_or_create(
tag="shiny",
create_defaults={"content_object": quartz},
defaults={"content_object": diamond},
)
self.assertIs(created, False)
self.assertEqual(tag.content_object.id, diamond.id)
def test_query_content_type(self): def test_query_content_type(self):
msg = "Field 'content_object' does not generate an automatic reverse relation" msg = "Field 'content_object' does not generate an automatic reverse relation"
with self.assertRaisesMessage(FieldError, msg): with self.assertRaisesMessage(FieldError, msg):

View File

@ -6,6 +6,7 @@ class Person(models.Model):
last_name = models.CharField(max_length=100) last_name = models.CharField(max_length=100)
birthday = models.DateField() birthday = models.DateField()
defaults = models.TextField() defaults = models.TextField()
create_defaults = models.TextField()
class DefaultPerson(models.Model): class DefaultPerson(models.Model):

View File

@ -330,15 +330,24 @@ class UpdateOrCreateTests(TestCase):
self.assertEqual(p.birthday, date(1940, 10, 10)) self.assertEqual(p.birthday, date(1940, 10, 10))
def test_create_twice(self): def test_create_twice(self):
params = { p, created = Person.objects.update_or_create(
"first_name": "John", first_name="John",
"last_name": "Lennon", last_name="Lennon",
"birthday": date(1940, 10, 10), create_defaults={"birthday": date(1940, 10, 10)},
} defaults={"birthday": date(1950, 2, 2)},
Person.objects.update_or_create(**params) )
# If we execute the exact same statement, it won't create a Person. self.assertIs(created, True)
p, created = Person.objects.update_or_create(**params) self.assertEqual(p.birthday, date(1940, 10, 10))
self.assertFalse(created) # If we execute the exact same statement, it won't create a Person, but
# will update the birthday.
p, created = Person.objects.update_or_create(
first_name="John",
last_name="Lennon",
create_defaults={"birthday": date(1940, 10, 10)},
defaults={"birthday": date(1950, 2, 2)},
)
self.assertIs(created, False)
self.assertEqual(p.birthday, date(1950, 2, 2))
def test_integrity(self): def test_integrity(self):
""" """
@ -391,8 +400,14 @@ class UpdateOrCreateTests(TestCase):
""" """
p = Publisher.objects.create(name="Acme Publishing") p = Publisher.objects.create(name="Acme Publishing")
book, created = p.books.update_or_create(name="The Book of Ed & Fred") book, created = p.books.update_or_create(name="The Book of Ed & Fred")
self.assertTrue(created) self.assertIs(created, True)
self.assertEqual(p.books.count(), 1) self.assertEqual(p.books.count(), 1)
book, created = p.books.update_or_create(
name="Basics of Django", create_defaults={"name": "Advanced Django"}
)
self.assertIs(created, True)
self.assertEqual(book.name, "Advanced Django")
self.assertEqual(p.books.count(), 2)
def test_update_with_related_manager(self): def test_update_with_related_manager(self):
""" """
@ -406,6 +421,14 @@ class UpdateOrCreateTests(TestCase):
book, created = p.books.update_or_create(defaults={"name": name}, id=book.id) book, created = p.books.update_or_create(defaults={"name": name}, id=book.id)
self.assertFalse(created) self.assertFalse(created)
self.assertEqual(book.name, name) self.assertEqual(book.name, name)
# create_defaults should be ignored.
book, created = p.books.update_or_create(
create_defaults={"name": "Basics of Django"},
defaults={"name": name},
id=book.id,
)
self.assertIs(created, False)
self.assertEqual(book.name, name)
self.assertEqual(p.books.count(), 1) self.assertEqual(p.books.count(), 1)
def test_create_with_many(self): def test_create_with_many(self):
@ -418,8 +441,16 @@ class UpdateOrCreateTests(TestCase):
book, created = author.books.update_or_create( book, created = author.books.update_or_create(
name="The Book of Ed & Fred", publisher=p name="The Book of Ed & Fred", publisher=p
) )
self.assertTrue(created) self.assertIs(created, True)
self.assertEqual(author.books.count(), 1) self.assertEqual(author.books.count(), 1)
book, created = author.books.update_or_create(
name="Basics of Django",
publisher=p,
create_defaults={"name": "Advanced Django"},
)
self.assertIs(created, True)
self.assertEqual(book.name, "Advanced Django")
self.assertEqual(author.books.count(), 2)
def test_update_with_many(self): def test_update_with_many(self):
""" """
@ -437,6 +468,14 @@ class UpdateOrCreateTests(TestCase):
) )
self.assertFalse(created) self.assertFalse(created)
self.assertEqual(book.name, name) self.assertEqual(book.name, name)
# create_defaults should be ignored.
book, created = author.books.update_or_create(
create_defaults={"name": "Basics of Django"},
defaults={"name": name},
id=book.id,
)
self.assertIs(created, False)
self.assertEqual(book.name, name)
self.assertEqual(author.books.count(), 1) self.assertEqual(author.books.count(), 1)
def test_defaults_exact(self): def test_defaults_exact(self):
@ -467,6 +506,34 @@ class UpdateOrCreateTests(TestCase):
self.assertFalse(created) self.assertFalse(created)
self.assertEqual(obj.defaults, "another testing") self.assertEqual(obj.defaults, "another testing")
def test_create_defaults_exact(self):
"""
If you have a field named create_defaults and want to use it as an
exact lookup, you need to use 'create_defaults__exact'.
"""
obj, created = Person.objects.update_or_create(
first_name="George",
last_name="Harrison",
create_defaults__exact="testing",
create_defaults={
"birthday": date(1943, 2, 25),
"create_defaults": "testing",
},
)
self.assertIs(created, True)
self.assertEqual(obj.create_defaults, "testing")
obj, created = Person.objects.update_or_create(
first_name="George",
last_name="Harrison",
create_defaults__exact="testing",
create_defaults={
"birthday": date(1943, 2, 25),
"create_defaults": "another testing",
},
)
self.assertIs(created, False)
self.assertEqual(obj.create_defaults, "testing")
def test_create_callable_default(self): def test_create_callable_default(self):
obj, created = Person.objects.update_or_create( obj, created = Person.objects.update_or_create(
first_name="George", first_name="George",
@ -476,6 +543,16 @@ class UpdateOrCreateTests(TestCase):
self.assertIs(created, True) self.assertIs(created, True)
self.assertEqual(obj.birthday, date(1943, 2, 25)) self.assertEqual(obj.birthday, date(1943, 2, 25))
def test_create_callable_create_defaults(self):
obj, created = Person.objects.update_or_create(
first_name="George",
last_name="Harrison",
defaults={},
create_defaults={"birthday": lambda: date(1943, 2, 25)},
)
self.assertIs(created, True)
self.assertEqual(obj.birthday, date(1943, 2, 25))
def test_update_callable_default(self): def test_update_callable_default(self):
Person.objects.update_or_create( Person.objects.update_or_create(
first_name="George", first_name="George",
@ -694,6 +771,12 @@ class InvalidCreateArgumentsTests(TransactionTestCase):
with self.assertRaisesMessage(FieldError, self.msg): with self.assertRaisesMessage(FieldError, self.msg):
Thing.objects.update_or_create(name="a", defaults={"nonexistent": "b"}) Thing.objects.update_or_create(name="a", defaults={"nonexistent": "b"})
def test_update_or_create_with_invalid_create_defaults(self):
with self.assertRaisesMessage(FieldError, self.msg):
Thing.objects.update_or_create(
name="a", create_defaults={"nonexistent": "b"}
)
def test_update_or_create_with_invalid_kwargs(self): def test_update_or_create_with_invalid_kwargs(self):
with self.assertRaisesMessage(FieldError, self.bad_field_msg): with self.assertRaisesMessage(FieldError, self.bad_field_msg):
Thing.objects.update_or_create(name="a", nonexistent="b") Thing.objects.update_or_create(name="a", nonexistent="b")