Commit f105fbe5 authored by Anssi Kääriäinen's avatar Anssi Kääriäinen
Browse files

[1.5.x] Fixed #18823 -- Ensured m2m.clear() works when using through+to_field

There was a potential data-loss issue involved -- when clearing
instance's m2m assignments it was possible some other instance's
m2m data was deleted instead.

This commit also improved None handling for to_field cases.

Backpatch of 611c4d6f
parent 13b4d448
Loading
Loading
Loading
Loading
+36 −9
Original line number Diff line number Diff line
@@ -573,9 +573,31 @@ def create_many_related_manager(superclass, rel):
            self.reverse = reverse
            self.through = through
            self.prefetch_cache_name = prefetch_cache_name
            self._pk_val = self.instance.pk
            if self._pk_val is None:
                raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % instance.__class__.__name__)
            self._fk_val = self._get_fk_val(instance, source_field_name)
            if self._fk_val is None:
                raise ValueError('"%r" needs to have a value for field "%s" before '
                                 'this many-to-many relationship can be used.' %
                                 (instance, source_field_name))
            # Even if this relation is not to pk, we require still pk value.
            # The wish is that the instance has been already saved to DB,
            # although having a pk value isn't a guarantee of that.
            if instance.pk is None:
                raise ValueError("%r instance needs to have a primary key value before "
                                 "a many-to-many relationship can be used." %
                                 instance.__class__.__name__)


        def _get_fk_val(self, obj, field_name):
            """
            Returns the correct value for this relationship's foreign key. This
            might be something else than pk value when to_field is used.
            """
            fk = self.through._meta.get_field(field_name)
            if fk.rel.field_name and fk.rel.field_name != fk.rel.to._meta.pk.attname:
                attname = fk.rel.get_related_field().get_attname()
                return fk.get_prep_lookup('exact', getattr(obj, attname))
            else:
                return obj.pk

        def get_query_set(self):
            try:
@@ -677,7 +699,11 @@ def create_many_related_manager(superclass, rel):
                        if not router.allow_relation(obj, self.instance):
                            raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' %
                                               (obj, self.instance._state.db, obj._state.db))
                        new_ids.add(obj.pk)
                        fk_val = self._get_fk_val(obj, target_field_name)
                        if fk_val is None:
                            raise ValueError('Cannot add "%r": the value for field "%s" is None' %
                                             (obj, target_field_name))
                        new_ids.add(self._get_fk_val(obj, target_field_name))
                    elif isinstance(obj, Model):
                        raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj))
                    else:
@@ -685,7 +711,7 @@ def create_many_related_manager(superclass, rel):
                db = router.db_for_write(self.through, instance=self.instance)
                vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True)
                vals = vals.filter(**{
                    source_field_name: self._pk_val,
                    source_field_name: self._fk_val,
                    '%s__in' % target_field_name: new_ids,
                })
                new_ids = new_ids - set(vals)
@@ -699,11 +725,12 @@ def create_many_related_manager(superclass, rel):
                # Add the ones that aren't there already
                self.through._default_manager.using(db).bulk_create([
                    self.through(**{
                        '%s_id' % source_field_name: self._pk_val,
                        '%s_id' % source_field_name: self._fk_val,
                        '%s_id' % target_field_name: obj_id,
                    })
                    for obj_id in new_ids
                ])

                if self.reverse or source_field_name == self.source_field_name:
                    # Don't send the signal when we are inserting the
                    # duplicate data row for symmetrical reverse entries.
@@ -722,7 +749,7 @@ def create_many_related_manager(superclass, rel):
                old_ids = set()
                for obj in objs:
                    if isinstance(obj, self.model):
                        old_ids.add(obj.pk)
                        old_ids.add(self._get_fk_val(obj, target_field_name))
                    else:
                        old_ids.add(obj)
                # Work out what DB we're operating on
@@ -736,7 +763,7 @@ def create_many_related_manager(superclass, rel):
                        model=self.model, pk_set=old_ids, using=db)
                # Remove the specified objects from the join table
                self.through._default_manager.using(db).filter(**{
                    source_field_name: self._pk_val,
                    source_field_name: self._fk_val,
                    '%s__in' % target_field_name: old_ids
                }).delete()
                if self.reverse or source_field_name == self.source_field_name:
@@ -756,7 +783,7 @@ def create_many_related_manager(superclass, rel):
                    instance=self.instance, reverse=self.reverse,
                    model=self.model, pk_set=None, using=db)
            self.through._default_manager.using(db).filter(**{
                source_field_name: self._pk_val
                source_field_name: self._fk_val
            }).delete()
            if self.reverse or source_field_name == self.source_field_name:
                # Don't send the signal when we are clearing the
+4 −4
Original line number Diff line number Diff line
@@ -62,18 +62,18 @@ class B(models.Model):
# Using to_field on the through model
@python_2_unicode_compatible
class Car(models.Model):
    make = models.CharField(max_length=20, unique=True)
    make = models.CharField(max_length=20, unique=True, null=True)
    drivers = models.ManyToManyField('Driver', through='CarDriver')

    def __str__(self):
        return self.make
        return "%s" % self.make

@python_2_unicode_compatible
class Driver(models.Model):
    name = models.CharField(max_length=20, unique=True)
    name = models.CharField(max_length=20, unique=True, null=True)

    def __str__(self):
        return self.name
        return "%s" % self.name

@python_2_unicode_compatible
class CarDriver(models.Model):
+88 −2
Original line number Diff line number Diff line
@@ -123,6 +123,14 @@ class ToFieldThroughTests(TestCase):
        self.car = Car.objects.create(make="Toyota")
        self.driver = Driver.objects.create(name="Ryan Briscoe")
        CarDriver.objects.create(car=self.car, driver=self.driver)
        # We are testing if wrong objects get deleted due to using wrong
        # field value in m2m queries. So, it is essential that the pk
        # numberings do not match.
        # Create one intentionally unused driver to mix up the autonumbering
        self.unused_driver = Driver.objects.create(name="Barney Gumble")
        # And two intentionally unused cars.
        self.unused_car1 = Car.objects.create(make="Trabant")
        self.unused_car2 = Car.objects.create(make="Wartburg")

    def test_to_field(self):
        self.assertQuerysetEqual(
@@ -136,6 +144,84 @@ class ToFieldThroughTests(TestCase):
            ["<Car: Toyota>"]
        )

    def test_to_field_clear_reverse(self):
        self.driver.car_set.clear()
        self.assertQuerysetEqual(
            self.driver.car_set.all(),[])

    def test_to_field_clear(self):
        self.car.drivers.clear()
        self.assertQuerysetEqual(
            self.car.drivers.all(),[])

    # Low level tests for _add_items and _remove_items. We test these methods
    # because .add/.remove aren't available for m2m fields with through, but
    # through is the only way to set to_field currently. We do want to make
    # sure these methods are ready if the ability to use .add or .remove with
    # to_field relations is added some day.
    def test_add(self):
        self.assertQuerysetEqual(
            self.car.drivers.all(),
            ["<Driver: Ryan Briscoe>"]
        )
        # Yikes - barney is going to drive...
        self.car.drivers._add_items('car', 'driver', self.unused_driver)
        self.assertQuerysetEqual(
            self.car.drivers.all(),
            ["<Driver: Ryan Briscoe>", "<Driver: Barney Gumble>"]
        )

    def test_add_null(self):
        nullcar = Car.objects.create(make=None)
        with self.assertRaises(ValueError):
            nullcar.drivers._add_items('car', 'driver', self.unused_driver)

    def test_add_related_null(self):
        nulldriver = Driver.objects.create(name=None)
        with self.assertRaises(ValueError):
            self.car.drivers._add_items('car', 'driver', nulldriver)

    def test_add_reverse(self):
        car2 = Car.objects.create(make="Honda")
        self.assertQuerysetEqual(
            self.driver.car_set.all(),
            ["<Car: Toyota>"]
        )
        self.driver.car_set._add_items('driver', 'car', car2)
        self.assertQuerysetEqual(
            self.driver.car_set.all(),
            ["<Car: Toyota>", "<Car: Honda>"]
        )

    def test_add_null_reverse(self):
        nullcar = Car.objects.create(make=None)
        with self.assertRaises(ValueError):
            self.driver.car_set._add_items('driver', 'car', nullcar)

    def test_add_null_reverse_related(self):
        nulldriver = Driver.objects.create(name=None)
        with self.assertRaises(ValueError):
            nulldriver.car_set._add_items('driver', 'car', self.car)

    def test_remove(self):
        self.assertQuerysetEqual(
            self.car.drivers.all(),
            ["<Driver: Ryan Briscoe>"]
        )
        self.car.drivers._remove_items('car', 'driver', self.driver)
        self.assertQuerysetEqual(
            self.car.drivers.all(),[])

    def test_remove_reverse(self):
        self.assertQuerysetEqual(
            self.driver.car_set.all(),
            ["<Car: Toyota>"]
        )
        self.driver.car_set._remove_items('driver', 'car', self.car)
        self.assertQuerysetEqual(
            self.driver.car_set.all(),[])


class ThroughLoadDataTestCase(TestCase):
    fixtures = ["m2m_through"]