implement tests for worker
Run Tests / Run Tests (push) Successful in 39s Details

This commit is contained in:
James Ravenscroft 2024-12-18 11:00:07 +00:00
parent 9461a9ce7c
commit 9496c637ce
2 changed files with 84 additions and 80 deletions

View File

@ -9,7 +9,7 @@ https://docs.djangoproject.com/en/4.2/topics/settings/
For the full list of settings and their values, see For the full list of settings and their values, see
https://docs.djangoproject.com/en/4.2/ref/settings/ https://docs.djangoproject.com/en/4.2/ref/settings/
""" """
import os
from pathlib import Path from pathlib import Path
# Build paths inside the project like this: BASE_DIR / 'subdir'. # Build paths inside the project like this: BASE_DIR / 'subdir'.
@ -130,3 +130,7 @@ DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
CELERY_BROKER_URL = "amqp://rabbit:rabbit@localhost//" CELERY_BROKER_URL = "amqp://rabbit:rabbit@localhost//"
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "openai/gpt-4o")

View File

@ -1,23 +1,29 @@
import pytest import pytest
import litellm
from unittest.mock import ANY
from django.core.files.base import ContentFile from django.core.files.base import ContentFile
from django.core.files.storage import default_storage from django.core.files.storage import default_storage
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
from penparse.webui.tasks import process_memo from ..tasks import process_memo
from penparse.webui.models import ImageMemo, MemoStatus from ..models import ImageMemo, MemoStatus, User
@pytest.mark.django_db
class TestProcessMemo:
@pytest.fixture @pytest.fixture
def sample_image_memo(db): def sample_image_memo(self, db):
user1 = User.objects.create_user(email="user1@test.com", password="password1")
memo = ImageMemo.objects.create( memo = ImageMemo.objects.create(
status=MemoStatus.Pending, image_mimetype="image/jpeg" author=user1,
status=MemoStatus.Pending,
image_mimetype="image/jpeg",
) )
memo.image.save("test_image.jpg", ContentFile(b"fake image content")) memo.image.save("test_image.jpg", ContentFile(b"fake image content"))
return memo return memo
def test_process_memo_success(self, sample_image_memo):
@pytest.mark.django_db with patch("webui.tasks.litellm") as mock_litellm:
def test_process_memo_success(sample_image_memo):
with patch("penparse.webui.tasks.litellm") as mock_litellm:
mock_response = MagicMock() mock_response = MagicMock()
mock_response.choices[0].message = {"content": "Transcribed content"} mock_response.choices[0].message = {"content": "Transcribed content"}
mock_litellm.completion.return_value = mock_response mock_litellm.completion.return_value = mock_response
@ -27,11 +33,9 @@ def test_process_memo_success(sample_image_memo):
processed_memo = ImageMemo.objects.get(id=sample_image_memo.id) processed_memo = ImageMemo.objects.get(id=sample_image_memo.id)
assert processed_memo.status == MemoStatus.Done assert processed_memo.status == MemoStatus.Done
assert processed_memo.content == "Transcribed content" assert processed_memo.content == "Transcribed content"
assert processed_memo.error_message == "" assert processed_memo.error_message == None
def test_process_memo_missing_image(self, sample_image_memo):
@pytest.mark.django_db
def test_process_memo_missing_image(sample_image_memo):
default_storage.delete(sample_image_memo.image.name) default_storage.delete(sample_image_memo.image.name)
process_memo(sample_image_memo.id) process_memo(sample_image_memo.id)
@ -40,11 +44,11 @@ def test_process_memo_missing_image(sample_image_memo):
assert processed_memo.status == MemoStatus.Error assert processed_memo.status == MemoStatus.Error
assert "Image file" in processed_memo.error_message assert "Image file" in processed_memo.error_message
def test_process_memo_api_error(self, sample_image_memo):
@pytest.mark.django_db with patch("webui.tasks.litellm") as mock_litellm:
def test_process_memo_api_error(sample_image_memo): mock_response = MagicMock()
with patch("penparse.webui.tasks.litellm") as mock_litellm: mock_response.choices[0].message = {"content": "Transcribed content"}
mock_litellm.completion.side_effect = mock_litellm.APIError("API Error") mock_litellm.completion.side_effect = litellm.APIError(400, "API Error", "openai", "any")
process_memo(sample_image_memo.id) process_memo(sample_image_memo.id)
@ -52,12 +56,10 @@ def test_process_memo_api_error(sample_image_memo):
assert processed_memo.status == MemoStatus.Error assert processed_memo.status == MemoStatus.Error
assert "API Error" in processed_memo.error_message assert "API Error" in processed_memo.error_message
def test_process_memo_sets_model_name(self, sample_image_memo):
@pytest.mark.django_db
def test_process_memo_sets_model_name(sample_image_memo):
with ( with (
patch("penparse.webui.tasks.litellm") as mock_litellm, patch("webui.tasks.litellm") as mock_litellm,
patch("penparse.webui.tasks.settings") as mock_settings, patch("webui.tasks.settings") as mock_settings,
): ):
mock_response = MagicMock() mock_response = MagicMock()
mock_response.choices[0].message = {"content": "Transcribed content"} mock_response.choices[0].message = {"content": "Transcribed content"}
@ -69,12 +71,10 @@ def test_process_memo_sets_model_name(sample_image_memo):
processed_memo = ImageMemo.objects.get(id=sample_image_memo.id) processed_memo = ImageMemo.objects.get(id=sample_image_memo.id)
assert processed_memo.model_name == "test-model" assert processed_memo.model_name == "test-model"
def test_process_memo_uses_correct_api_settings(self, sample_image_memo):
@pytest.mark.django_db
def test_process_memo_uses_correct_api_settings(sample_image_memo):
with ( with (
patch("penparse.webui.tasks.litellm") as mock_litellm, patch("webui.tasks.litellm") as mock_litellm,
patch("penparse.webui.tasks.settings") as mock_settings, patch("webui.tasks.settings") as mock_settings,
): ):
mock_response = MagicMock() mock_response = MagicMock()
mock_response.choices[0].message = {"content": "Transcribed content"} mock_response.choices[0].message = {"content": "Transcribed content"}
@ -94,8 +94,8 @@ def test_process_memo_uses_correct_api_settings(sample_image_memo):
{ {
"role": "user", "role": "user",
"content": [ "content": [
{"type": "text", "text": pytest.ANY}, {"type": "text", "text": ANY},
{"type": "image_url", "image_url": {"url": pytest.ANY}}, {"type": "image_url", "image_url": {"url": ANY}},
], ],
} }
] ]