Commit 7b00d902 authored by Claude Paroz's avatar Claude Paroz
Browse files

[py3] Made GeoIP tests pass with Python 3

parent 465a29ab
Loading
Loading
Loading
Loading
+11 −14
Original line number Diff line number Diff line
@@ -137,9 +137,6 @@ class GeoIP(object):
        if not isinstance(query, six.string_types):
            raise TypeError('GeoIP query must be a string, not type %s' % type(query).__name__)

        # GeoIP only takes ASCII-encoded strings.
        query = query.encode('ascii')

        # Extra checks for the existence of country and city databases.
        if city_or_country and not (self._country or self._city):
            raise GeoIPException('Invalid GeoIP country and city data files.')
@@ -148,8 +145,8 @@ class GeoIP(object):
        elif city and not self._city:
            raise GeoIPException('Invalid GeoIP city data file: %s' % self._city_file)

        # Return the query string back to the caller.
        return query
        # Return the query string back to the caller. GeoIP only takes bytestrings.
        return force_bytes(query)

    def city(self, query):
        """
@@ -157,33 +154,33 @@ class GeoIP(object):
        Fully Qualified Domain Name (FQDN).  Some information in the dictionary
        may be undefined (None).
        """
        query = self._check_query(query, city=True)
        enc_query = self._check_query(query, city=True)
        if ipv4_re.match(query):
            # If an IP address was passed in
            return GeoIP_record_by_addr(self._city, c_char_p(query))
            return GeoIP_record_by_addr(self._city, c_char_p(enc_query))
        else:
            # If a FQDN was passed in.
            return GeoIP_record_by_name(self._city, c_char_p(query))
            return GeoIP_record_by_name(self._city, c_char_p(enc_query))

    def country_code(self, query):
        "Returns the country code for the given IP Address or FQDN."
        query = self._check_query(query, city_or_country=True)
        enc_query = self._check_query(query, city_or_country=True)
        if self._country:
            if ipv4_re.match(query):
                return GeoIP_country_code_by_addr(self._country, query)
                return GeoIP_country_code_by_addr(self._country, enc_query)
            else:
                return GeoIP_country_code_by_name(self._country, query)
                return GeoIP_country_code_by_name(self._country, enc_query)
        else:
            return self.city(query)['country_code']

    def country_name(self, query):
        "Returns the country name for the given IP Address or FQDN."
        query = self._check_query(query, city_or_country=True)
        enc_query = self._check_query(query, city_or_country=True)
        if self._country:
            if ipv4_re.match(query):
                return GeoIP_country_name_by_addr(self._country, query)
                return GeoIP_country_name_by_addr(self._country, enc_query)
            else:
                return GeoIP_country_name_by_name(self._country, query)
                return GeoIP_country_name_by_name(self._country, enc_query)
        else:
            return self.city(query)['country_name']

+6 −1
Original line number Diff line number Diff line
@@ -92,7 +92,7 @@ def check_string(result, func, cargs):
        free(result)
    else:
        s = ''
    return s
    return s.decode()

GeoIP_database_info = lgeoip.GeoIP_database_info
GeoIP_database_info.restype = geoip_char_p
@@ -100,7 +100,12 @@ GeoIP_database_info.errcheck = check_string

# String output routines.
def string_output(func):
    def _err_check(result, func, cargs):
        if result:
            return result.decode()
        return result
    func.restype = c_char_p
    func.errcheck = _err_check
    return func

GeoIP_country_code_by_addr = string_output(lgeoip.GeoIP_country_code_by_addr)
+0 −6
Original line number Diff line number Diff line
@@ -106,12 +106,6 @@ class GeoIPTest(unittest.TestCase):
        d = g.city("www.osnabrueck.de")
        self.assertEqual('Osnabrück', d['city'])

    def test06_unicode_query(self):
        "Testing that GeoIP accepts unicode string queries, see #17059."
        g = GeoIP()
        d = g.country('whitehouse.gov')
        self.assertEqual('US', d['country_code'])


def suite():
    s = unittest.TestSuite()