diff --git a/tests/test_build.py b/tests/test_build.py
index c764dbd..4c36dd9 100644
--- a/tests/test_build.py
+++ b/tests/test_build.py
@@ -1,9 +1,17 @@
+from nose.tools import eq_
 from wordfreq.build import load_all_data
+from wordfreq.query import wordlist_info
 from wordfreq.transfer import download_and_extract_raw_data
 from wordfreq import config
 import os
 import tempfile
 import shutil
+import sqlite3
+
+
+def flatten_list_of_dicts(list_of_dicts):
+    things = [sorted(d.items()) for d in list_of_dicts]
+    return sorted(things)
 
 
 def test_build():
@@ -17,7 +25,10 @@ def test_build():
     try:
         db_file = os.path.join(tempdir, 'test.db')
         load_all_data(config.RAW_DATA_DIR, db_file)
+        conn = sqlite3.connect(db_file)
 
-        assert open(db_file).read() == open(config.DB_FILENAME).read()
+        # Compare the information we got to the information in the default DB.
+        eq_(flatten_list_of_dicts(wordlist_info(conn)),
+            flatten_list_of_dicts(wordlist_info(None)))
     finally:
         shutil.rmtree(tempdir)
diff --git a/tests/test_queries.py b/tests/test_queries.py
index 165ff47..b307e0c 100644
--- a/tests/test_queries.py
+++ b/tests/test_queries.py
@@ -1,7 +1,7 @@
 from __future__ import unicode_literals
 from nose.tools import eq_, assert_almost_equal, assert_greater
 from wordfreq.query import (word_frequency, average_frequency, wordlist_size,
-                            get_wordlists, metanl_word_frequency)
+                            wordlist_info, metanl_word_frequency)
 
 
 def test_freq_examples():
@@ -40,7 +40,7 @@ def _check_normalized_frequencies(wordlist, lang):
 
 
 def test_normalized_frequencies():
-    for list_info in get_wordlists():
+    for list_info in wordlist_info():
         wordlist = list_info['wordlist']
         lang = list_info['lang']
         yield _check_normalized_frequencies, wordlist, lang
diff --git a/wordfreq/query.py b/wordfreq/query.py
index a33b7cc..d2cffbd 100644
--- a/wordfreq/query.py
+++ b/wordfreq/query.py
@@ -101,8 +101,18 @@ def iter_wordlist(wordlist='multi', lang=None):
     return results
 
 
-def get_wordlists():
-    c = CONN.cursor()
+def wordlist_info(connection=None):
+    """
+    Get info about all the wordlists in a database, returning their
+    list name, language, and number of words as 'wordlist', 'lang',
+    and 'count' respectively.
+
+    The database connection can be given as an argument, in order to get
+    information about a database other than the default configured one.
+    """
+    if connection is None:
+        connection = CONN
+    c = connection.cursor()
     results = c.execute(
         "SELECT wordlist, lang, count(*) from words GROUP BY wordlist, lang"
     )