Commit 05656f23 authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Add equality support for Project/ModelState

parent 9027da65
Loading
Loading
Loading
Loading
+21 −0
Original line number Diff line number Diff line
@@ -59,6 +59,14 @@ class ProjectState(object):
            models[(model_state.app_label, model_state.name.lower())] = model_state
        return cls(models)

    def __eq__(self, other):
        if set(self.models.keys()) != set(other.models.keys()):
            return False
        return all(model == other.models[key] for key, model in self.models.items())

    def __ne__(self, other):
        return not (self == other)


class ModelState(object):
    """
@@ -167,3 +175,16 @@ class ModelState(object):
            if fname == name:
                return field
        raise ValueError("No field called %s on model %s" % (name, self.name))

    def __eq__(self, other):
        return (
            (self.app_label == other.app_label) and
            (self.name == other.name) and
            (len(self.fields) == len(other.fields)) and
            all((k1 == k2 and (f1.deconstruct()[1:] == f2.deconstruct()[1:])) for (k1, f1), (k2, f2) in zip(self.fields, other.fields)) and
            (self.options == other.options) and
            (self.bases == other.bases)
        )

    def __ne__(self, other):
        return not (self == other)
+40 −0
Original line number Diff line number Diff line
@@ -175,3 +175,43 @@ class StateTests(TestCase):
        project_state.add_model_state(ModelState.from_model(F))
        with self.assertRaises(InvalidBasesError):
            project_state.render()

    def test_equality(self):
        """
        Tests that == and != are implemented correctly.
        """

        # Test two things that should be equal
        project_state = ProjectState()
        project_state.add_model_state(ModelState(
            "migrations",
            "Tag",
            [
                ("id", models.AutoField(primary_key=True)),
                ("name", models.CharField(max_length=100)),
                ("hidden", models.BooleanField()),
            ],
            {},
            None,
        ))
        other_state = project_state.clone()
        self.assertEqual(project_state, project_state)
        self.assertEqual(project_state, other_state)
        self.assertEqual(project_state != project_state, False)
        self.assertEqual(project_state != other_state, False)

        # Make a very small change (max_len 99) and see if that affects it
        project_state = ProjectState()
        project_state.add_model_state(ModelState(
            "migrations",
            "Tag",
            [
                ("id", models.AutoField(primary_key=True)),
                ("name", models.CharField(max_length=99)),
                ("hidden", models.BooleanField()),
            ],
            {},
            None,
        ))
        self.assertNotEqual(project_state, other_state)
        self.assertEqual(project_state == other_state, False)