diff options
author | Florian Bruhin <me@the-compiler.org> | 2021-07-09 17:06:23 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-09 17:06:23 +0200 |
commit | ae6d9009716c85c679b490ab4df92e80b77b3fa5 (patch) | |
tree | 3c03eb22447c3233533dc91679979ee3de8f8991 /tests | |
parent | 8bdab79011680c13b3b2072a3b18a4471c168131 (diff) | |
parent | 71a7674a706b73dfaac5958e9c3bca414c4e8665 (diff) | |
download | qutebrowser-ae6d9009716c85c679b490ab4df92e80b77b3fa5.tar.gz qutebrowser-ae6d9009716c85c679b490ab4df92e80b77b3fa5.zip |
Merge pull request #6567 from lufte/issue6039
Database class
Diffstat (limited to 'tests')
-rw-r--r-- | tests/helpers/fixtures.py | 21 | ||||
-rw-r--r-- | tests/unit/browser/test_history.py | 97 | ||||
-rw-r--r-- | tests/unit/completion/test_histcategory.py | 28 | ||||
-rw-r--r-- | tests/unit/misc/test_sql.py | 156 |
4 files changed, 166 insertions, 136 deletions
diff --git a/tests/helpers/fixtures.py b/tests/helpers/fixtures.py index 7106698be..cd3778b8a 100644 --- a/tests/helpers/fixtures.py +++ b/tests/helpers/fixtures.py @@ -639,15 +639,6 @@ def short_tmpdir(): yield py.path.local(tdir) # pylint: disable=no-member -@pytest.fixture -def init_sql(data_tmpdir): - """Initialize the SQL module, and shut it down after the test.""" - path = str(data_tmpdir / 'test.db') - sql.init(path) - yield - sql.close() - - class ModelValidator: """Validates completion models.""" @@ -682,12 +673,20 @@ def download_stub(win_registry, tmpdir, stubs): @pytest.fixture -def web_history(fake_save_manager, tmpdir, init_sql, config_stub, stubs, +def database(data_tmpdir): + """Create a Database object.""" + db = sql.Database(str(data_tmpdir / 'test.db')) + yield db + db.close() + + +@pytest.fixture +def web_history(fake_save_manager, tmpdir, database, config_stub, stubs, monkeypatch): """Create a WebHistory object.""" config_stub.val.completion.timestamp_format = '%Y-%m-%d' config_stub.val.completion.web_history.max_items = -1 - web_history = history.WebHistory(stubs.FakeHistoryProgress()) + web_history = history.WebHistory(database, stubs.FakeHistoryProgress()) monkeypatch.setattr(history, 'web_history', web_history) return web_history diff --git a/tests/unit/browser/test_history.py b/tests/unit/browser/test_history.py index 1a46c5be0..7906d385c 100644 --- a/tests/unit/browser/test_history.py +++ b/tests/unit/browser/test_history.py @@ -31,7 +31,7 @@ from qutebrowser.misc import sql, objects @pytest.fixture(autouse=True) -def prerequisites(config_stub, fake_save_manager, init_sql, fake_args): +def prerequisites(config_stub, fake_save_manager, fake_args): """Make sure everything is ready to initialize a WebHistory.""" config_stub.data = {'general': {'private-browsing': False}} @@ -311,14 +311,14 @@ class TestInit: @pytest.mark.parametrize('backend', [usertypes.Backend.QtWebEngine, usertypes.Backend.QtWebKit]) - def test_init(self, backend, qapp, tmpdir, monkeypatch, cleanup_init): + def test_init(self, backend, qapp, tmpdir, data_tmpdir, monkeypatch, cleanup_init): if backend == usertypes.Backend.QtWebKit: pytest.importorskip('PyQt5.QtWebKitWidgets') else: assert backend == usertypes.Backend.QtWebEngine monkeypatch.setattr(history.objects, 'backend', backend) - history.init(qapp) + history.init(data_tmpdir / f'test_init_{backend}', qapp) assert history.web_history.parent() is qapp try: @@ -368,44 +368,40 @@ class TestDump: class TestRebuild: - # FIXME: Some of those tests might be a bit misleading, as creating a new - # history.WebHistory will regenerate the completion either way with the SQL changes - # in v2.0.0 (because the user version changed from 0 -> 3). - # - # They should be revisited once we can actually create two independent sqlite - # databases and copy the data over, for a "real" test. - - def test_user_version(self, web_history, stubs, monkeypatch): + def test_user_version(self, database, stubs, monkeypatch): """Ensure that completion is regenerated if user_version changes.""" + web_history = history.WebHistory(database, stubs.FakeHistoryProgress()) web_history.add_url(QUrl('example.com/1'), redirect=False, atime=1) web_history.add_url(QUrl('example.com/2'), redirect=False, atime=2) web_history.completion.delete('url', 'example.com/2') - # User version always changes, so this won't work - # hist2 = history.WebHistory(progress=stubs.FakeHistoryProgress()) - # assert list(hist2.completion) == [('example.com/1', '', 1)] + hist2 = history.WebHistory(database, progress=stubs.FakeHistoryProgress()) + assert list(hist2.completion) == [('example.com/1', '', 1)] - monkeypatch.setattr(sql, 'user_version_changed', lambda: True) + monkeypatch.setattr(web_history.database, 'user_version_changed', lambda: True) - hist3 = history.WebHistory(progress=stubs.FakeHistoryProgress()) + hist3 = history.WebHistory(web_history.database, + progress=stubs.FakeHistoryProgress()) assert list(hist3.completion) == [ ('example.com/1', '', 1), ('example.com/2', '', 2), ] assert not hist3.metainfo['force_rebuild'] - def test_force_rebuild(self, web_history, stubs): + def test_force_rebuild(self, database, stubs): """Ensure that completion is regenerated if we force a rebuild.""" + web_history = history.WebHistory(database, stubs.FakeHistoryProgress()) web_history.add_url(QUrl('example.com/1'), redirect=False, atime=1) web_history.add_url(QUrl('example.com/2'), redirect=False, atime=2) web_history.completion.delete('url', 'example.com/2') - hist2 = history.WebHistory(progress=stubs.FakeHistoryProgress()) - # User version always changes, so this won't work - # assert list(hist2.completion) == [('example.com/1', '', 1)] + hist2 = history.WebHistory(web_history.database, + progress=stubs.FakeHistoryProgress()) + assert list(hist2.completion) == [('example.com/1', '', 1)] hist2.metainfo['force_rebuild'] = True - hist3 = history.WebHistory(progress=stubs.FakeHistoryProgress()) + hist3 = history.WebHistory(web_history.database, + progress=stubs.FakeHistoryProgress()) assert list(hist3.completion) == [ ('example.com/1', '', 1), ('example.com/2', '', 2), @@ -424,7 +420,8 @@ class TestRebuild: web_history.add_url(QUrl('http://example.org'), redirect=False, atime=2) - hist2 = history.WebHistory(progress=stubs.FakeHistoryProgress()) + hist2 = history.WebHistory(web_history.database, + progress=stubs.FakeHistoryProgress()) assert list(hist2.completion) == [('http://example.com', '', 1)] def test_pattern_change_rebuild(self, config_stub, web_history, stubs): @@ -436,14 +433,16 @@ class TestRebuild: web_history.add_url(QUrl('http://example.org'), redirect=False, atime=2) - hist2 = history.WebHistory(progress=stubs.FakeHistoryProgress()) + hist2 = history.WebHistory(web_history.database, + progress=stubs.FakeHistoryProgress()) assert list(hist2.completion) == [ ('http://example.com', '', 1), ] config_stub.val.completion.web_history.exclude = [] - hist3 = history.WebHistory(progress=stubs.FakeHistoryProgress()) + hist3 = history.WebHistory(web_history.database, + progress=stubs.FakeHistoryProgress()) assert list(hist3.completion) == [ ('http://example.com', '', 1), ('http://example.org', '', 2) @@ -454,37 +453,39 @@ class TestRebuild: web_history.add_url(QUrl('example.com/2'), redirect=False, atime=2) # Trigger a completion rebuild - monkeypatch.setattr(sql, 'user_version_changed', lambda: True) + monkeypatch.setattr(web_history.database, 'user_version_changed', lambda: True) progress = stubs.FakeHistoryProgress() - history.WebHistory(progress=progress) + history.WebHistory(web_history.database, progress=progress) assert progress._value == 2 assert progress._started assert progress._finished - def test_interrupted(self, stubs, web_history, monkeypatch): + def test_interrupted(self, stubs, database, monkeypatch): """If we interrupt the rebuilding process, force_rebuild should still be set.""" + web_history = history.WebHistory(database, stubs.FakeHistoryProgress()) web_history.add_url(QUrl('example.com/1'), redirect=False, atime=1) + web_history.completion.delete('url', 'example.com/1') progress = stubs.FakeHistoryProgress(raise_on_tick=True) # Trigger a completion rebuild - monkeypatch.setattr(sql, 'user_version_changed', lambda: True) + monkeypatch.setattr(web_history.database, 'user_version_changed', lambda: True) with pytest.raises(Exception, match='tick-tock'): - history.WebHistory(progress=progress) + history.WebHistory(web_history.database, progress=progress) assert web_history.metainfo['force_rebuild'] - # If we now try again, we should get another rebuild. But due to user_version - # always changing, we can't test this at the moment (see the FIXME in the - # docstring for details) + hist2 = history.WebHistory(web_history.database, + progress=stubs.FakeHistoryProgress()) + assert list(hist2.completion) == [('example.com/1', '', 1)] class TestCompletionMetaInfo: @pytest.fixture - def metainfo(self): - return history.CompletionMetaInfo() + def metainfo(self, database): + return history.CompletionMetaInfo(database) def test_contains_keyerror(self, metainfo): with pytest.raises(KeyError): @@ -507,27 +508,27 @@ class TestCompletionMetaInfo: metainfo['excluded_patterns'] = value assert metainfo['excluded_patterns'] == value - # FIXME: It'd be good to test those two things via WebHistory (and not just - # CompletionMetaInfo in isolation), but we can't do that right now - see the - # docstring of TestRebuild for details. - - def test_recovery_no_key(self, metainfo): - metainfo.delete('key', 'force_rebuild') + def test_recovery_no_key(self, caplog, database, stubs): + web_history = history.WebHistory(database, stubs.FakeHistoryProgress()) + web_history.metainfo.delete('key', 'force_rebuild') with pytest.raises(sql.BugError, match='No result for single-result query'): - metainfo['force_rebuild'] + web_history.metainfo['force_rebuild'] - metainfo.try_recover() - assert not metainfo['force_rebuild'] + with caplog.at_level(logging.WARNING): + web_history2 = history.WebHistory(database, stubs.FakeHistoryProgress()) + assert not web_history2.metainfo['force_rebuild'] - def test_recovery_no_table(self, metainfo): - sql.Query("DROP TABLE CompletionMetaInfo").run() + def test_recovery_no_table(self, caplog, database, stubs): + web_history = history.WebHistory(database, stubs.FakeHistoryProgress()) + web_history.metainfo.database.query("DROP TABLE CompletionMetaInfo").run() with pytest.raises(sql.BugError, match='no such table: CompletionMetaInfo'): - metainfo['force_rebuild'] + web_history.metainfo['force_rebuild'] - metainfo.try_recover() - assert not metainfo['force_rebuild'] + with caplog.at_level(logging.WARNING): + web_history2 = history.WebHistory(database, stubs.FakeHistoryProgress()) + assert not web_history2.metainfo['force_rebuild'] class TestHistoryProgress: diff --git a/tests/unit/completion/test_histcategory.py b/tests/unit/completion/test_histcategory.py index e0a12943b..cb37fb784 100644 --- a/tests/unit/completion/test_histcategory.py +++ b/tests/unit/completion/test_histcategory.py @@ -32,10 +32,11 @@ from qutebrowser.utils import usertypes @pytest.fixture -def hist(init_sql, config_stub): +def hist(data_tmpdir, config_stub): + db = sql.Database(str(data_tmpdir / 'test_histcategory.db')) config_stub.val.completion.timestamp_format = '%Y-%m-%d' config_stub.val.completion.web_history.max_items = -1 - return sql.SqlTable('CompletionHistory', ['url', 'title', 'last_atime']) + return sql.SqlTable(db, 'CompletionHistory', ['url', 'title', 'last_atime']) @pytest.mark.parametrize('pattern, before, after', [ @@ -99,7 +100,7 @@ def test_set_pattern(pattern, before, after, model_validator, hist): """Validate the filtering and sorting results of set_pattern.""" for row in before: hist.insert({'url': row[0], 'title': row[1], 'last_atime': 1}) - cat = histcategory.HistoryCategory() + cat = histcategory.HistoryCategory(database=hist.database) model_validator.set_model(cat) cat.set_pattern(pattern) model_validator.validate(after) @@ -110,7 +111,7 @@ def test_set_pattern_repeated(model_validator, hist): hist.insert({'url': 'example.com/foo', 'title': 'title1', 'last_atime': 1}) hist.insert({'url': 'example.com/bar', 'title': 'title2', 'last_atime': 1}) hist.insert({'url': 'example.com/baz', 'title': 'title3', 'last_atime': 1}) - cat = histcategory.HistoryCategory() + cat = histcategory.HistoryCategory(database=hist.database) model_validator.set_model(cat) cat.set_pattern('b') @@ -143,7 +144,7 @@ def test_set_pattern_repeated(model_validator, hist): ], ids=['numbers', 'characters']) def test_set_pattern_long(hist, message_mock, caplog, pattern): hist.insert({'url': 'example.com/foo', 'title': 'title1', 'last_atime': 1}) - cat = histcategory.HistoryCategory() + cat = histcategory.HistoryCategory(database=hist.database) with caplog.at_level(logging.ERROR): cat.set_pattern(pattern) msg = message_mock.getmsg(usertypes.MessageLevel.error) @@ -153,7 +154,7 @@ def test_set_pattern_long(hist, message_mock, caplog, pattern): @hypothesis.given(pat=strategies.text()) def test_set_pattern_hypothesis(hist, pat, caplog): hist.insert({'url': 'example.com/foo', 'title': 'title1', 'last_atime': 1}) - cat = histcategory.HistoryCategory() + cat = histcategory.HistoryCategory(database=hist.database) with caplog.at_level(logging.ERROR): cat.set_pattern(pat) @@ -202,7 +203,7 @@ def test_sorting(max_items, before, after, model_validator, hist, config_stub): for url, title, atime in before: timestamp = datetime.datetime.strptime(atime, '%Y-%m-%d').timestamp() hist.insert({'url': url, 'title': title, 'last_atime': timestamp}) - cat = histcategory.HistoryCategory() + cat = histcategory.HistoryCategory(database=hist.database) model_validator.set_model(cat) cat.set_pattern('') model_validator.validate(after) @@ -211,7 +212,7 @@ def test_sorting(max_items, before, after, model_validator, hist, config_stub): def test_remove_rows(hist, model_validator): hist.insert({'url': 'foo', 'title': 'Foo', 'last_atime': 0}) hist.insert({'url': 'bar', 'title': 'Bar', 'last_atime': 0}) - cat = histcategory.HistoryCategory() + cat = histcategory.HistoryCategory(database=hist.database) model_validator.set_model(cat) cat.set_pattern('') hist.delete('url', 'foo') @@ -227,7 +228,7 @@ def test_remove_rows_fetch(hist): 'title': [str(i) for i in range(300)], 'last_atime': [0] * 300, }) - cat = histcategory.HistoryCategory() + cat = histcategory.HistoryCategory(database=hist.database) cat.set_pattern('') # sanity check that we didn't fetch everything up front @@ -245,20 +246,21 @@ def test_remove_rows_fetch(hist): ('%m/%d/%Y %H:%M', '02/27/2018 08:30'), ('', ''), ]) -def test_timestamp_fmt(fmt, expected, model_validator, config_stub, init_sql): +def test_timestamp_fmt(fmt, expected, model_validator, config_stub, data_tmpdir): """Validate the filtering and sorting results of set_pattern.""" config_stub.val.completion.timestamp_format = fmt - hist = sql.SqlTable('CompletionHistory', ['url', 'title', 'last_atime']) + db = sql.Database(str(data_tmpdir / 'test_timestamp_fmt.db')) + hist = sql.SqlTable(db, 'CompletionHistory', ['url', 'title', 'last_atime']) atime = datetime.datetime(2018, 2, 27, 8, 30) hist.insert({'url': 'foo', 'title': '', 'last_atime': atime.timestamp()}) - cat = histcategory.HistoryCategory() + cat = histcategory.HistoryCategory(database=hist.database) model_validator.set_model(cat) cat.set_pattern('') model_validator.validate([('foo', '', expected)]) def test_skip_duplicate_set(message_mock, caplog, hist): - cat = histcategory.HistoryCategory() + cat = histcategory.HistoryCategory(database=hist.database) cat.set_pattern('foo') cat.set_pattern('foobarbaz') msg = caplog.messages[-1] diff --git a/tests/unit/misc/test_sql.py b/tests/unit/misc/test_sql.py index f6fa68869..80ab7513c 100644 --- a/tests/unit/misc/test_sql.py +++ b/tests/unit/misc/test_sql.py @@ -23,12 +23,12 @@ import pytest import hypothesis from hypothesis import strategies -from PyQt5.QtSql import QSqlError +from PyQt5.QtSql import QSqlDatabase, QSqlError, QSqlQuery from qutebrowser.misc import sql -pytestmark = pytest.mark.usefixtures('init_sql') +pytestmark = pytest.mark.usefixtures('data_tmpdir') class TestUserVersion: @@ -120,23 +120,23 @@ class TestSqlError: assert err.text() == "db text" -def test_init(): - sql.SqlTable('Foo', ['name', 'val', 'lucky']) +def test_init_table(database): + database.table('Foo', ['name', 'val', 'lucky']) # should not error if table already exists - sql.SqlTable('Foo', ['name', 'val', 'lucky']) + database.table('Foo', ['name', 'val', 'lucky']) -def test_insert(qtbot): - table = sql.SqlTable('Foo', ['name', 'val', 'lucky']) +def test_insert(qtbot, database): + table = database.table('Foo', ['name', 'val', 'lucky']) with qtbot.wait_signal(table.changed): table.insert({'name': 'one', 'val': 1, 'lucky': False}) with qtbot.wait_signal(table.changed): table.insert({'name': 'wan', 'val': 1, 'lucky': False}) -def test_insert_replace(qtbot): - table = sql.SqlTable('Foo', ['name', 'val', 'lucky'], - constraints={'name': 'PRIMARY KEY'}) +def test_insert_replace(qtbot, database): + table = database.table('Foo', ['name', 'val', 'lucky'], + constraints={'name': 'PRIMARY KEY'}) with qtbot.wait_signal(table.changed): table.insert({'name': 'one', 'val': 1, 'lucky': False}, replace=True) with qtbot.wait_signal(table.changed): @@ -147,8 +147,8 @@ def test_insert_replace(qtbot): table.insert({'name': 'one', 'val': 11, 'lucky': True}, replace=False) -def test_insert_batch(qtbot): - table = sql.SqlTable('Foo', ['name', 'val', 'lucky']) +def test_insert_batch(qtbot, database): + table = database.table('Foo', ['name', 'val', 'lucky']) with qtbot.wait_signal(table.changed): table.insert_batch({'name': ['one', 'nine', 'thirteen'], @@ -160,9 +160,9 @@ def test_insert_batch(qtbot): ('thirteen', 13, True)] -def test_insert_batch_replace(qtbot): - table = sql.SqlTable('Foo', ['name', 'val', 'lucky'], - constraints={'name': 'PRIMARY KEY'}) +def test_insert_batch_replace(qtbot, database): + table = database.table('Foo', ['name', 'val', 'lucky'], + constraints={'name': 'PRIMARY KEY'}) with qtbot.wait_signal(table.changed): table.insert_batch({'name': ['one', 'nine', 'thirteen'], @@ -185,8 +185,8 @@ def test_insert_batch_replace(qtbot): 'lucky': [True, True]}) -def test_iter(): - table = sql.SqlTable('Foo', ['name', 'val', 'lucky']) +def test_iter(database): + table = database.table('Foo', ['name', 'val', 'lucky']) table.insert({'name': 'one', 'val': 1, 'lucky': False}) table.insert({'name': 'nine', 'val': 9, 'lucky': False}) table.insert({'name': 'thirteen', 'val': 13, 'lucky': True}) @@ -205,15 +205,15 @@ def test_iter(): ([{"a": 2, "b": 5}, {"a": 1, "b": 6}, {"a": 3, "b": 4}], 'a', 'asc', -1, [(1, 6), (2, 5), (3, 4)]), ]) -def test_select(rows, sort_by, sort_order, limit, result): - table = sql.SqlTable('Foo', ['a', 'b']) +def test_select(rows, sort_by, sort_order, limit, result, database): + table = database.table('Foo', ['a', 'b']) for row in rows: table.insert(row) assert list(table.select(sort_by, sort_order, limit)) == result -def test_delete(qtbot): - table = sql.SqlTable('Foo', ['name', 'val', 'lucky']) +def test_delete(qtbot, database): + table = database.table('Foo', ['name', 'val', 'lucky']) table.insert({'name': 'one', 'val': 1, 'lucky': False}) table.insert({'name': 'nine', 'val': 9, 'lucky': False}) table.insert({'name': 'thirteen', 'val': 13, 'lucky': True}) @@ -227,8 +227,8 @@ def test_delete(qtbot): assert not list(table) -def test_len(): - table = sql.SqlTable('Foo', ['name', 'val', 'lucky']) +def test_len(database): + table = database.table('Foo', ['name', 'val', 'lucky']) assert len(table) == 0 table.insert({'name': 'one', 'val': 1, 'lucky': False}) assert len(table) == 1 @@ -238,15 +238,15 @@ def test_len(): assert len(table) == 3 -def test_bool(): - table = sql.SqlTable('Foo', ['name']) +def test_bool(database): + table = database.table('Foo', ['name']) assert not table table.insert({'name': 'one'}) assert table -def test_bool_benchmark(benchmark): - table = sql.SqlTable('Foo', ['number']) +def test_bool_benchmark(benchmark, database): + table = database.table('Foo', ['number']) # Simulate a history table table.create_index('NumberIndex', 'number') @@ -258,8 +258,8 @@ def test_bool_benchmark(benchmark): benchmark(run) -def test_contains(): - table = sql.SqlTable('Foo', ['name', 'val', 'lucky']) +def test_contains(database): + table = database.table('Foo', ['name', 'val', 'lucky']) table.insert({'name': 'one', 'val': 1, 'lucky': False}) table.insert({'name': 'nine', 'val': 9, 'lucky': False}) table.insert({'name': 'thirteen', 'val': 13, 'lucky': True}) @@ -279,8 +279,8 @@ def test_contains(): assert not val_query.run(val=10).value() -def test_delete_all(qtbot): - table = sql.SqlTable('Foo', ['name', 'val', 'lucky']) +def test_delete_all(qtbot, database): + table = database.table('Foo', ['name', 'val', 'lucky']) table.insert({'name': 'one', 'val': 1, 'lucky': False}) table.insert({'name': 'nine', 'val': 9, 'lucky': False}) table.insert({'name': 'thirteen', 'val': 13, 'lucky': True}) @@ -295,90 +295,118 @@ def test_version(): class TestSqlQuery: - def test_prepare_error(self): + def test_prepare_error(self, database): with pytest.raises(sql.BugError) as excinfo: - sql.Query('invalid') + database.query('invalid') expected = ('Failed to prepare query "invalid": "near "invalid": ' 'syntax error Unable to execute statement"') assert str(excinfo.value) == expected @pytest.mark.parametrize('forward_only', [True, False]) - def test_forward_only(self, forward_only): - q = sql.Query('SELECT 0 WHERE 0', forward_only=forward_only) + def test_forward_only(self, forward_only, database): + q = database.query('SELECT 0 WHERE 0', forward_only=forward_only) assert q.query.isForwardOnly() == forward_only - def test_iter_inactive(self): - q = sql.Query('SELECT 0') + def test_iter_inactive(self, database): + q = database.query('SELECT 0') with pytest.raises(sql.BugError, match='Cannot iterate inactive query'): next(iter(q)) - def test_iter_empty(self): - q = sql.Query('SELECT 0 AS col WHERE 0') + def test_iter_empty(self, database): + q = database.query('SELECT 0 AS col WHERE 0') q.run() with pytest.raises(StopIteration): next(iter(q)) - def test_iter(self): - q = sql.Query('SELECT 0 AS col') + def test_iter(self, database): + q = database.query('SELECT 0 AS col') q.run() result = next(iter(q)) assert result.col == 0 - def test_iter_multiple(self): - q = sql.Query('VALUES (1), (2), (3);') + def test_iter_multiple(self, database): + q = database.query('VALUES (1), (2), (3);') res = list(q.run()) assert len(res) == 3 assert res[0].column1 == 1 - def test_run_binding(self): - q = sql.Query('SELECT :answer') + def test_run_binding(self, database): + q = database.query('SELECT :answer') q.run(answer=42) assert q.value() == 42 - def test_run_missing_binding(self): - q = sql.Query('SELECT :answer') + def test_run_missing_binding(self, database): + q = database.query('SELECT :answer') with pytest.raises(sql.BugError, match='Missing bound values!'): q.run() - def test_run_batch(self): - q = sql.Query('SELECT :answer') + def test_run_batch(self, database): + q = database.query('SELECT :answer') q.run_batch(values={'answer': [42]}) assert q.value() == 42 - def test_run_batch_missing_binding(self): - q = sql.Query('SELECT :answer') + def test_run_batch_missing_binding(self, database): + q = database.query('SELECT :answer') with pytest.raises(sql.BugError, match='Missing bound values!'): q.run_batch(values={}) - def test_value_missing(self): - q = sql.Query('SELECT 0 WHERE 0') + def test_value_missing(self, database): + q = database.query('SELECT 0 WHERE 0') q.run() - with pytest.raises(sql.BugError, - match='No result for single-result query'): + with pytest.raises(sql.BugError, match='No result for single-result query'): q.value() - def test_num_rows_affected_not_active(self): + def test_num_rows_affected_not_active(self, database): with pytest.raises(AssertionError): - q = sql.Query('SELECT 0') + q = database.query('SELECT 0') q.rows_affected() - def test_num_rows_affected_select(self): + def test_num_rows_affected_select(self, database): with pytest.raises(AssertionError): - q = sql.Query('SELECT 0') + q = database.query('SELECT 0') q.run() q.rows_affected() @pytest.mark.parametrize('condition', [0, 1]) - def test_num_rows_affected(self, condition): - table = sql.SqlTable('Foo', ['name']) + def test_num_rows_affected(self, condition, database): + table = database.table('Foo', ['name']) table.insert({'name': 'helloworld'}) - q = sql.Query(f'DELETE FROM Foo WHERE {condition}') + q = database.query(f'DELETE FROM Foo WHERE {condition}') q.run() assert q.rows_affected() == condition - def test_bound_values(self): - q = sql.Query('SELECT :answer') + def test_bound_values(self, database): + q = database.query('SELECT :answer') q.run(answer=42) assert q.bound_values() == {':answer': 42} + + +class TestTransaction: + + def test_successful_transaction(self, database): + my_table = database.table('my_table', ['column']) + with database.transaction(): + my_table.insert({'column': 1}) + my_table.insert({'column': 2}) + + db2 = QSqlDatabase.addDatabase('QSQLITE', 'db2') + db2.setDatabaseName(database.qt_database().databaseName()) + db2.open() + query = QSqlQuery(db2) + query.exec('select count(*) from my_table') + query.next() + assert query.record().value(0) == 0 + assert database.query('select count(*) from my_table').run().value() == 2 + + def test_failed_transaction(self, database): + my_table = database.table('my_table', ['column']) + try: + with database.transaction(): + my_table.insert({'column': 1}) + my_table.insert({'column': 2}) + raise Exception('something went horribly wrong') + except Exception: + pass + assert database.query('select count(*) from my_table').run().value() == 0 |