# Source code for tests.chat.llm.test_openai

"""
Test the OpenAIModel class
"""
import unittest
from unittest.mock import MagicMock, patch

from pykoi.chat.llm.openai import OpenAIModel


class TestOpenAIModel(unittest.TestCase):
    """Test the OpenAIModel class."""

    def test_predict(self):
        """Test the predict method of the OpenAIModel class.

        Mocks ``openai.Completion.create`` so no real API call is made,
        then verifies both the arguments passed to the API and the text
        returned by ``OpenAIModel.predict``.
        """
        # Input prompt and the word the mocked API will "predict".
        message = "What is the meaning of life?"
        predicted_word = "42"

        # Build a fake response object shaped like the OpenAI completion
        # response: response.choices[0].text holds the generated text.
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].text = f"Answer: {predicted_word}"
        openai_completion_create_mock = MagicMock(return_value=mock_response)

        # Patch the create method as referenced by the module under test,
        # so OpenAIModel.predict hits the mock instead of the network.
        with patch(
            "pykoi.chat.llm.openai.openai.Completion.create",
            openai_completion_create_mock,
        ):
            openai_model = OpenAIModel(
                api_key="fake_api_key",
                engine="davinci",
                max_tokens=100,
                temperature=0.5,
            )
            result = openai_model.predict(message, 1)

            # Check that the API was called exactly once with the expected
            # prompt template and sampling parameters.
            openai_completion_create_mock.assert_called_once_with(
                engine="davinci",
                prompt=f"Question: {message}\nAnswer:",
                max_tokens=100,
                n=1,
                stop="\n",
                temperature=0.5,
            )
            # predict() returns a sequence; the first element is the text
            # of the mocked completion.
            self.assertEqual(result[0], f"Answer: {predicted_word}")
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()