Commit 244e2b71 authored by Anssi Kääriäinen's avatar Anssi Kääriäinen Committed by Andrew Godwin
Browse files

Fixed #20946 -- model inheritance + m2m failure

Cleaned up the internal implementation of m2m fields by removing
related.py _get_fk_val(). The _get_fk_val() was doing the wrong thing
if asked for the foreign key value on foreign key to parent model's
primary key when child model had different primary key field.
parent 7775ced9
Loading
Loading
Loading
Loading
+15 −23
Original line number Diff line number Diff line
@@ -501,8 +501,6 @@ def create_many_related_manager(superclass, rel):
            self.through = through
            self.prefetch_cache_name = prefetch_cache_name
            self.related_val = source_field.get_foreign_related_value(instance)
            # Used for single column related auto created models
            self._fk_val = self.related_val[0]
            if None in self.related_val:
                raise ValueError('"%r" needs to have a value for field "%s" before '
                                 'this many-to-many relationship can be used.' %
@@ -515,18 +513,6 @@ def create_many_related_manager(superclass, rel):
                                 "a many-to-many relationship can be used." %
                                 instance.__class__.__name__)

        def _get_fk_val(self, obj, field_name):
            """
            Returns the correct value for this relationship's foreign key. This
            might be something else than pk value when to_field is used.
            """
            fk = self.through._meta.get_field(field_name)
            if fk.rel.field_name and fk.rel.field_name != fk.rel.to._meta.pk.attname:
                attname = fk.rel.get_related_field().get_attname()
                return fk.get_prep_lookup('exact', getattr(obj, attname))
            else:
                return obj.pk

        def get_queryset(self):
            try:
                return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
@@ -624,11 +610,12 @@ def create_many_related_manager(superclass, rel):
                        if not router.allow_relation(obj, self.instance):
                            raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' %
                                               (obj, self.instance._state.db, obj._state.db))
                        fk_val = self._get_fk_val(obj, target_field_name)
                        fk_val = self.through._meta.get_field(
                            target_field_name).get_foreign_related_value(obj)[0]
                        if fk_val is None:
                            raise ValueError('Cannot add "%r": the value for field "%s" is None' %
                                             (obj, target_field_name))
                        new_ids.add(self._get_fk_val(obj, target_field_name))
                        new_ids.add(fk_val)
                    elif isinstance(obj, Model):
                        raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj))
                    else:
@@ -636,7 +623,7 @@ def create_many_related_manager(superclass, rel):
                db = router.db_for_write(self.through, instance=self.instance)
                vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True)
                vals = vals.filter(**{
                    source_field_name: self._fk_val,
                    source_field_name: self.related_val[0],
                    '%s__in' % target_field_name: new_ids,
                })
                new_ids = new_ids - set(vals)
@@ -650,7 +637,7 @@ def create_many_related_manager(superclass, rel):
                # Add the ones that aren't there already
                self.through._default_manager.using(db).bulk_create([
                    self.through(**{
                        '%s_id' % source_field_name: self._fk_val,
                        '%s_id' % source_field_name: self.related_val[0],
                        '%s_id' % target_field_name: obj_id,
                    })
                    for obj_id in new_ids
@@ -674,7 +661,9 @@ def create_many_related_manager(superclass, rel):
                old_ids = set()
                for obj in objs:
                    if isinstance(obj, self.model):
                        old_ids.add(self._get_fk_val(obj, target_field_name))
                        fk_val = self.through._meta.get_field(
                            target_field_name).get_foreign_related_value(obj)[0]
                        old_ids.add(fk_val)
                    else:
                        old_ids.add(obj)
                # Work out what DB we're operating on
@@ -688,7 +677,7 @@ def create_many_related_manager(superclass, rel):
                        model=self.model, pk_set=old_ids, using=db)
                # Remove the specified objects from the join table
                self.through._default_manager.using(db).filter(**{
                    source_field_name: self._fk_val,
                    source_field_name: self.related_val[0],
                    '%s__in' % target_field_name: old_ids
                }).delete()
                if self.reverse or source_field_name == self.source_field_name:
@@ -994,9 +983,12 @@ class ForeignObject(RelatedField):
            # Gotcha: in some cases (like fixture loading) a model can have
            # different values in parent_ptr_id and parent's id. So, use
            # instance.pk (that is, parent_ptr_id) when asked for instance.id.
            opts = instance._meta
            if field.primary_key:
                possible_parent_link = opts.get_ancestor_link(field.model)
                if not possible_parent_link or possible_parent_link.primary_key:
                    ret.append(instance.pk)
            else:
                    continue
            ret.append(getattr(instance, field.attname))
        return tuple(ret)

+6 −0
Original line number Diff line number Diff line
@@ -162,3 +162,9 @@ class Mixin(object):

class MixinModel(models.Model, Mixin):
    pass

class Base(models.Model):
    titles = models.ManyToManyField(Title)

class SubBase(Base):
    sub_id = models.IntegerField(primary_key=True)
+15 −1
Original line number Diff line number Diff line
@@ -10,7 +10,8 @@ from django.utils import six

from .models import (
    Chef, CommonInfo, ItalianRestaurant, ParkingLot, Place, Post,
    Restaurant, Student, StudentWorker, Supplier, Worker, MixinModel)
    Restaurant, Student, StudentWorker, Supplier, Worker, MixinModel,
    Title, Base, SubBase)


class ModelInheritanceTests(TestCase):
@@ -357,3 +358,16 @@ class ModelInheritanceTests(TestCase):
            [Place.objects.get(pk=s.pk)],
            lambda x: x
        )

    def test_custompk_m2m(self):
        b = Base.objects.create()
        b.titles.add(Title.objects.create(title="foof"))
        s = SubBase.objects.create(sub_id=b.id)
        b = Base.objects.get(pk=s.id)
        self.assertNotEqual(b.pk, s.pk)
        # Low-level test for related_val
        self.assertEqual(s.titles.related_val, (s.id,))
        # Higher level test for correct query values (title foof not
        # accidentally found).
        self.assertQuerysetEqual(
            s.titles.all(), [])