PenParse/penparse/webui/test/test_tasks.py

105 lines
4.2 KiB
Python
Raw Normal View History

2024-12-18 10:14:16 +00:00
import pytest
2024-12-18 11:00:07 +00:00
import litellm
from unittest.mock import ANY
2024-12-18 10:14:16 +00:00
from django.core.files.base import ContentFile
from django.core.files.storage import default_storage
from unittest.mock import patch, MagicMock
2024-12-18 11:00:07 +00:00
from ..tasks import process_memo
from ..models import ImageMemo, MemoStatus, User
2024-12-18 10:14:16 +00:00
@pytest.mark.django_db
class TestProcessMemo:
    """Tests for the ``process_memo`` task.

    Every test patches ``webui.tasks.litellm`` so no real LLM request is made,
    and the ``sample_image_memo`` fixture removes the image file it writes.
    """

    @pytest.fixture
    def sample_image_memo(self, db):
        """Yield a pending JPEG ``ImageMemo`` backed by a placeholder image file.

        The image is written through the field's storage backend and removed
        again on teardown so repeated test runs do not leave stray files behind.
        """
        user1 = User.objects.create_user(email="user1@test.com", password="password1")
        memo = ImageMemo.objects.create(
            author=user1,
            status=MemoStatus.Pending,
            image_mimetype="image/jpeg",
        )
        memo.image.save("test_image.jpg", ContentFile(b"fake image content"))
        # Capture storage and file name up front; the task under test may
        # modify the database row, but the file we wrote is what needs cleanup.
        storage, image_name = memo.image.storage, memo.image.name
        yield memo
        # Some tests delete the file themselves, so only remove it if present.
        if storage.exists(image_name):
            storage.delete(image_name)

    @staticmethod
    def _completion_response(content="Transcribed content"):
        """Return a mock completion whose ``choices[0].message`` holds *content*."""
        response = MagicMock()
        # MagicMock returns the same child for any index, so choices[0] is stable.
        response.choices[0].message = {"content": content}
        return response

    def test_process_memo_success(self, sample_image_memo):
        """A successful completion marks the memo Done and stores the text."""
        with patch("webui.tasks.litellm") as mock_litellm:
            mock_litellm.completion.return_value = self._completion_response()
            process_memo(sample_image_memo.id)

        processed_memo = ImageMemo.objects.get(id=sample_image_memo.id)
        assert processed_memo.status == MemoStatus.Done
        assert processed_memo.content == "Transcribed content"
        assert processed_memo.error_message is None

    def test_process_memo_missing_image(self, sample_image_memo):
        """A memo whose image file is gone ends in the Error state."""
        default_storage.delete(sample_image_memo.image.name)

        # Patch the LLM client so a regression in the task's ordering can
        # never turn this test into a real network call.
        with patch("webui.tasks.litellm"):
            process_memo(sample_image_memo.id)

        processed_memo = ImageMemo.objects.get(id=sample_image_memo.id)
        assert processed_memo.status == MemoStatus.Error
        assert "Image file" in processed_memo.error_message

    def test_process_memo_api_error(self, sample_image_memo):
        """An ``APIError`` from litellm is recorded on the memo."""
        with patch("webui.tasks.litellm") as mock_litellm:
            mock_litellm.completion.side_effect = litellm.APIError(
                400, "API Error", "openai", "any"
            )
            process_memo(sample_image_memo.id)

        processed_memo = ImageMemo.objects.get(id=sample_image_memo.id)
        assert processed_memo.status == MemoStatus.Error
        assert "API Error" in processed_memo.error_message

    def test_process_memo_sets_model_name(self, sample_image_memo):
        """The configured model name is persisted on the processed memo."""
        with (
            patch("webui.tasks.litellm") as mock_litellm,
            patch("webui.tasks.settings") as mock_settings,
        ):
            mock_litellm.completion.return_value = self._completion_response()
            mock_settings.OPENAI_MODEL = "test-model"
            process_memo(sample_image_memo.id)

        processed_memo = ImageMemo.objects.get(id=sample_image_memo.id)
        assert processed_memo.model_name == "test-model"

    def test_process_memo_uses_correct_api_settings(self, sample_image_memo):
        """API base/key come from settings and the request has the expected shape."""
        with (
            patch("webui.tasks.litellm") as mock_litellm,
            patch("webui.tasks.settings") as mock_settings,
        ):
            mock_litellm.completion.return_value = self._completion_response()
            mock_settings.OPENAI_API_BASE = "https://test-api-base.com"
            mock_settings.OPENAI_API_KEY = "test-api-key"
            mock_settings.OPENAI_MODEL = "test-model"
            process_memo(sample_image_memo.id)

        assert mock_litellm.api_base == "https://test-api-base.com"
        assert mock_litellm.api_key == "test-api-key"
        # Plain structural equality: ANY matches the prompt text and the
        # encoded image URL; nothing here is numeric, so approx is not needed.
        mock_litellm.completion.assert_called_once_with(
            model="test-model",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": ANY},
                        {"type": "image_url", "image_url": {"url": ANY}},
                    ],
                }
            ],
            temperature=0.01,
        )