Commit 93b3ef9b authored by Lukas Klein's avatar Lukas Klein Committed by Claude Paroz
Browse files

Fixed #24321 -- Improved `utils.http.same_origin` compliance with RFC6454

parent e2d6e146
Loading
Loading
Loading
Loading
+9 −2
Original line number Diff line number Diff line
@@ -33,6 +33,11 @@ ASCTIME_DATE = re.compile(r'^\w{3} %s %s %s %s$' % (__M, __D2, __T, __Y))
RFC3986_GENDELIMS = str(":/?#[]@")
RFC3986_SUBDELIMS = str("!$&'()*+,;=")

PROTOCOL_TO_PORT = {
    'http': 80,
    'https': 443,
}


def urlquote(url, safe='/'):
    """
@@ -253,8 +258,10 @@ def same_origin(url1, url2):
    """
    p1, p2 = urlparse(url1), urlparse(url2)
    try:
        return (p1.scheme, p1.hostname, p1.port) == (p2.scheme, p2.hostname, p2.port)
    except ValueError:
        o1 = (p1.scheme, p1.hostname, p1.port or PROTOCOL_TO_PORT[p1.scheme])
        o2 = (p2.scheme, p2.hostname, p2.port or PROTOCOL_TO_PORT[p2.scheme])
        return o1 == o2
    except (ValueError, KeyError):
        return False


+6 −0
Original line number Diff line number Diff line
@@ -18,6 +18,9 @@ class TestUtilsHttp(unittest.TestCase):
        self.assertTrue(http.same_origin('http://foo.com/', 'http://foo.com'))
        # With port
        self.assertTrue(http.same_origin('https://foo.com:8000', 'https://foo.com:8000/'))
        # No port given but according to RFC6454 still the same origin
        self.assertTrue(http.same_origin('http://foo.com', 'http://foo.com:80/'))
        self.assertTrue(http.same_origin('https://foo.com', 'https://foo.com:443/'))

    def test_same_origin_false(self):
        # Different scheme
@@ -28,6 +31,9 @@ class TestUtilsHttp(unittest.TestCase):
        self.assertFalse(http.same_origin('http://foo.com', 'http://foo.com.evil.com'))
        # Different port
        self.assertFalse(http.same_origin('http://foo.com:8000', 'http://foo.com:8001'))
        # No port given
        self.assertFalse(http.same_origin('http://foo.com', 'http://foo.com:8000/'))
        self.assertFalse(http.same_origin('https://foo.com', 'https://foo.com:8000/'))

    def test_urlencode(self):
        # 2-tuples (the norm)