# Source code for tests.chat.db.test_qa_database

"""Test the QuestionAnswerDatabase class"""
import csv
import datetime
import os
import sqlite3
import unittest

from pykoi.chat.db.qa_database import QuestionAnswerDatabase

# Define a temporary database file for testing.
# Each test opens a QuestionAnswerDatabase on this path in setUp and deletes
# the file in tearDown. The path is relative, so the file lands in the
# current working directory.
TEST_DB_FILE = "test_qd.db"


class TestQuestionAnswerDatabase(unittest.TestCase):
    """
    Test the QuestionAnswerDatabase class.

    Every test runs against a database opened on TEST_DB_FILE in setUp.
    That file is deleted in tearDown, so row IDs start fresh in each test.
    """

    def setUp(self):
        """Create a temporary database for testing."""
        self.qadb = QuestionAnswerDatabase(db_file=TEST_DB_FILE, debug=False)

    def tearDown(self):
        """Close the database connection and remove the database file."""
        try:
            self.qadb.close_connection()
        finally:
            # Always delete the file. Otherwise a failed close would leave
            # stale rows behind for the next test.
            os.remove(TEST_DB_FILE)

    @staticmethod
    def _remove_if_exists(path):
        """Delete *path*, ignoring the case where it was never created."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    def test_create_table(self):
        """
        Test whether the table is created correctly.
        """
        # Inspect the schema through an independent connection, so the check
        # does not rely on QuestionAnswerDatabase's own query methods.
        conn = sqlite3.connect(TEST_DB_FILE)
        try:
            cursor = conn.execute(
                "SELECT name FROM sqlite_master "
                "WHERE type='table' AND name='question_answer'"
            )
            table_exists = cursor.fetchone()
            cursor.close()
        finally:
            # Close even if the query raises, so tearDown can delete the file.
            conn.close()
        # fetchone() returns None when no matching table row exists.
        self.assertIsNotNone(table_exists)

    def test_insert_and_retrieve_question_answer(self):
        """
        Test inserting and retrieving a question-answer pair
        """
        question = "What is the meaning of life?"
        answer = "42"

        # Insert data and get the ID
        qa_id = self.qadb.insert_question_answer(question, answer)

        # Retrieve the data
        rows = self.qadb.retrieve_all_question_answers()

        # Check if the data was inserted correctly
        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0][0], qa_id)
        self.assertEqual(rows[0][1], question)
        self.assertEqual(rows[0][2], answer)
        self.assertEqual(rows[0][3], "n/a")  # Default vote status

    def test_update_vote_status(self):
        """
        Test updating the vote status of a question-answer pair.
        """
        question = "What is the meaning of life?"
        answer = "42"

        # Insert data and get the ID
        qa_id = self.qadb.insert_question_answer(question, answer)

        # Update the vote status
        new_vote_status = "up"
        self.qadb.update_vote_status(qa_id, new_vote_status)

        # Retrieve the data
        rows = self.qadb.retrieve_all_question_answers()

        # Check if the vote status was updated correctly
        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0][0], qa_id)
        self.assertEqual(rows[0][3], new_vote_status)

    def test_save_to_csv(self):
        """
        Test saving data to a CSV file
        """
        csv_path = "test_csv_file.csv"
        # Register cleanup before the file exists, so it is removed even if
        # an assertion below fails.
        self.addCleanup(self._remove_if_exists, csv_path)

        question1 = "What is the meaning of life?"
        answer1 = "42"
        question2 = "What is the best programming language?"
        answer2 = "Python"

        # Record times before and after the inserts, so the exported timestamp
        # can be range-checked. The old string-prefix comparison failed when
        # the minute rolled over or when microseconds happened to be 0.
        start = datetime.datetime.now()
        self.qadb.insert_question_answer(question1, answer1)
        self.qadb.insert_question_answer(question2, answer2)

        # Save to CSV
        self.qadb.save_to_csv(csv_path)
        end = datetime.datetime.now()

        # Check if the CSV file was created and contains the correct data
        self.assertTrue(os.path.exists(csv_path))
        with open(csv_path, "r", newline="") as file:
            rows = list(csv.reader(file))

        self.assertEqual(len(rows), 3)  # Header + 2 rows
        self.assertEqual(
            rows[0], ["ID", "Question", "Answer", "Vote Status", "Timestamp"]
        )

        expected_rows = [
            ["1", question1, answer1, "n/a"],  # Default vote status
            ["2", question2, answer2, "n/a"],  # Default vote status
        ]
        for row, expected in zip(rows[1:], expected_rows):
            with self.subTest(row_id=expected[0]):
                self.assertEqual(len(row), 5)
                self.assertEqual(row[:4], expected)
                # NOTE(review): assumes the timestamp column is a naive local
                # ISO-style datetime (the str(datetime.now()) form the
                # original test compared against). Confirm against
                # QuestionAnswerDatabase.
                stored = datetime.datetime.fromisoformat(row[4])
                # Allow for storage truncated to whole seconds.
                self.assertLessEqual(start.replace(microsecond=0), stored)
                self.assertLessEqual(stored, end)
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()