Commit 0cc059cd authored by Ion Scerbatiuc's avatar Ion Scerbatiuc Committed by Tim Graham
Browse files

Fixed #25172 -- Fixed check framework to work with multiple databases.

parent d0bd5330
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -19,7 +19,7 @@ from django.core import checks, exceptions, validators
# django.core.exceptions. It is retained here for backwards compatibility
# purposes.
from django.core.exceptions import FieldDoesNotExist  # NOQA
from django.db import connection
from django.db import connection, connections, router
from django.db.models.lookups import (
    Lookup, RegisterLookupMixin, Transform, default_lookups,
)
@@ -315,7 +315,11 @@ class Field(RegisterLookupMixin):
            return []

    def _check_backend_specific_checks(self, **kwargs):
        return connection.validation.check_field(self, **kwargs)
        app_label = self.model._meta.app_label
        for db in connections:
            if router.allow_migrate(db, app_label, model=self.model):
                return connections[db].validation.check_field(self, **kwargs)
        return []

    def _check_deprecation_details(self):
        if self.system_check_removed_details is not None:
+43 −0
Original line number Diff line number Diff line
from django.db import connections, models
from django.test import TestCase, mock
from django.test.utils import override_settings

from .tests import IsolateModelsMixin


class TestRouter(object):
    """
    Routes to the 'other' database if the model name starts with 'Other'.
    """
    def allow_migrate(self, db, app_label, model=None, **hints):
        return db == ('other' if model._meta.verbose_name.startswith('other') else 'default')


@override_settings(DATABASE_ROUTERS=[TestRouter()])
class TestMultiDBChecks(IsolateModelsMixin, TestCase):
    multi_db = True

    def _patch_check_field_on(self, db):
        return mock.patch.object(connections[db].validation, 'check_field')

    def test_checks_called_on_the_default_database(self):
        class Model(models.Model):
            field = models.CharField(max_length=100)

        model = Model()
        with self._patch_check_field_on('default') as mock_check_field_default:
            with self._patch_check_field_on('other') as mock_check_field_other:
                model.check()
                self.assertTrue(mock_check_field_default.called)
                self.assertFalse(mock_check_field_other.called)

    def test_checks_called_on_the_other_database(self):
        class OtherModel(models.Model):
            field = models.CharField(max_length=100)

        model = OtherModel()
        with self._patch_check_field_on('other') as mock_check_field_other:
            with self._patch_check_field_on('default') as mock_check_field_default:
                model.check()
                self.assertTrue(mock_check_field_other.called)
                self.assertFalse(mock_check_field_default.called)
+10 −16
Original line number Diff line number Diff line
# -*- encoding: utf-8 -*-
from __future__ import unicode_literals

from types import MethodType

from django.core.checks import Error
from django.db import connection, models
from django.db import connections, models
from django.test import mock

from .base import IsolatedModelsTestCase


def dummy_allow_migrate(db, app_label, **hints):
    # Prevent checks from being run on the 'other' database, which doesn't have
    # its check_field() method mocked in the test.
    return db == 'default'


class BackendSpecificChecksTests(IsolatedModelsTestCase):

    @mock.patch('django.db.models.fields.router.allow_migrate', new=dummy_allow_migrate)
    def test_check_field(self):
        """ Test if backend specific checks are performed. """

        error = Error('an error', hint=None)

        def mock(self, field, **kwargs):
            return [error]

        class Model(models.Model):
            field = models.IntegerField()

        field = Model._meta.get_field('field')

        # Mock connection.validation.check_field method.
        v = connection.validation
        old_check_field = v.check_field
        v.check_field = MethodType(mock, v)
        try:
        with mock.patch.object(connections['default'].validation, 'check_field', return_value=[error]):
            errors = field.check()
        finally:
            # Unmock connection.validation.check_field method.
            v.check_field = old_check_field

        self.assertEqual(errors, [error])
+11 −6
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@ from . import PostgreSQLTestCase
from .models import (
    ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel,
    NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel,
    PostgreSQLModel,
)

try:
@@ -246,16 +247,20 @@ class TestQuerying(PostgreSQLTestCase):
class TestChecks(PostgreSQLTestCase):

    def test_field_checks(self):
        class MyModel(PostgreSQLModel):
            field = ArrayField(models.CharField())
        field.set_attributes_from_name('field')
        errors = field.check()

        model = MyModel()
        errors = model.check()
        self.assertEqual(len(errors), 1)
        self.assertEqual(errors[0].id, 'postgres.E001')

    def test_invalid_base_fields(self):
        class MyModel(PostgreSQLModel):
            field = ArrayField(models.ManyToManyField('postgres_tests.IntegerArrayModel'))
        field.set_attributes_from_name('field')
        errors = field.check()

        model = MyModel()
        errors = model.check()
        self.assertEqual(len(errors), 1)
        self.assertEqual(errors[0].id, 'postgres.E002')