From 9461a9ce7cc7cbead4d4a7161940e83183078e1b Mon Sep 17 00:00:00 2001 From: James Ravenscroft Date: Wed, 18 Dec 2024 10:14:16 +0000 Subject: [PATCH] nicer image upload --- docker-compose.yml | 1 + penparse/webui/tasks.py | 31 ++++--- penparse/webui/templates/dashboard.html | 78 ++++++++++++++++-- penparse/webui/test/test_tasks.py | 104 ++++++++++++++++++++++++ 4 files changed, 195 insertions(+), 19 deletions(-) create mode 100644 penparse/webui/test/test_tasks.py diff --git a/docker-compose.yml b/docker-compose.yml index 8771aa0..4e97ac8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,6 +8,7 @@ services: vllm: image: vllm/vllm-openai:latest command: "--model Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4 --quantization gptq " + #command: "--model HuggingFaceTB/SmolVLM-Instruct --max_model_len 4098" volumes: - ~/.cache/huggingface:/root/.cache/huggingface ports: diff --git a/penparse/webui/tasks.py b/penparse/webui/tasks.py index 122bc04..6c86af7 100644 --- a/penparse/webui/tasks.py +++ b/penparse/webui/tasks.py @@ -1,5 +1,6 @@ import base64 import litellm +import openai from loguru import logger from celery import shared_task @@ -26,7 +27,6 @@ Please include whitespace and formatting for headings too. """ - @shared_task def process_memo(memo_id: str): """Run OCR on a memo and store the output""" @@ -70,15 +70,24 @@ def process_memo(memo_id: str): litellm.api_base = settings.OPENAI_API_BASE # os.environ.get("OPENAI_API_BASE") litellm.api_key = settings.OPENAI_API_KEY - response = litellm.completion( - model=settings.OPENAI_MODEL, #os.getenv("MODEL", "openai/gpt-4o"), - messages=[message], - temperature=0.01 - ) + try: + response = litellm.completion( + model=settings.OPENAI_MODEL, #os.getenv("MODEL", "openai/gpt-4o"), + messages=[message], + temperature=0.01 + ) - response.choices[0].message["content"] + response.choices[0].message["content"] - with transaction.atomic(): - memo.content = response.choices[0].message["content"] - memo.status = MemoStatus.Done - memo.save() + with transaction.atomic(): + memo.content = response.choices[0].message["content"] + memo.status = MemoStatus.Done + memo.model_name = settings.OPENAI_MODEL + memo.save() + except openai.OpenAIError as e: + + with transaction.atomic(): + memo.status = MemoStatus.Error + memo.error_message = e.__repr__() + memo.save() + logger.error(e) diff --git a/penparse/webui/templates/dashboard.html b/penparse/webui/templates/dashboard.html index b6734f4..ba88bcf 100644 --- a/penparse/webui/templates/dashboard.html +++ b/penparse/webui/templates/dashboard.html @@ -90,29 +90,91 @@ action="{% url 'upload_document' %}" method="post" enctype="multipart/form-data" + id="upload-form" > {% csrf_token %} -
+
-
+ + + + {% endblock %} diff --git a/penparse/webui/test/test_tasks.py b/penparse/webui/test/test_tasks.py new file mode 100644 index 0000000..0548403 --- /dev/null +++ b/penparse/webui/test/test_tasks.py @@ -0,0 +1,104 @@ +import pytest +from django.core.files.base import ContentFile +from django.core.files.storage import default_storage +from unittest.mock import patch, MagicMock +from penparse.webui.tasks import process_memo +from penparse.webui.models import ImageMemo, MemoStatus + + +@pytest.fixture +def sample_image_memo(db): + memo = ImageMemo.objects.create( + status=MemoStatus.Pending, image_mimetype="image/jpeg" + ) + memo.image.save("test_image.jpg", ContentFile(b"fake image content")) + return memo + + +@pytest.mark.django_db +def test_process_memo_success(sample_image_memo): + with patch("penparse.webui.tasks.litellm") as mock_litellm: + mock_response = MagicMock() + mock_response.choices[0].message = {"content": "Transcribed content"} + mock_litellm.completion.return_value = mock_response + + process_memo(sample_image_memo.id) + + processed_memo = ImageMemo.objects.get(id=sample_image_memo.id) + assert processed_memo.status == MemoStatus.Done + assert processed_memo.content == "Transcribed content" + assert processed_memo.error_message == "" + + +@pytest.mark.django_db +def test_process_memo_missing_image(sample_image_memo): + default_storage.delete(sample_image_memo.image.name) + + process_memo(sample_image_memo.id) + + processed_memo = ImageMemo.objects.get(id=sample_image_memo.id) + assert processed_memo.status == MemoStatus.Error + assert "Image file" in processed_memo.error_message + + +@pytest.mark.django_db +def test_process_memo_api_error(sample_image_memo): + with patch("penparse.webui.tasks.litellm") as mock_litellm: + mock_litellm.completion.side_effect = mock_litellm.APIError("API Error") + + process_memo(sample_image_memo.id) + + processed_memo = ImageMemo.objects.get(id=sample_image_memo.id) + assert processed_memo.status == MemoStatus.Error + assert "API Error" in processed_memo.error_message + + +@pytest.mark.django_db +def test_process_memo_sets_model_name(sample_image_memo): + with ( + patch("penparse.webui.tasks.litellm") as mock_litellm, + patch("penparse.webui.tasks.settings") as mock_settings, + ): + mock_response = MagicMock() + mock_response.choices[0].message = {"content": "Transcribed content"} + mock_litellm.completion.return_value = mock_response + mock_settings.OPENAI_MODEL = "test-model" + + process_memo(sample_image_memo.id) + + processed_memo = ImageMemo.objects.get(id=sample_image_memo.id) + assert processed_memo.model_name == "test-model" + + +@pytest.mark.django_db +def test_process_memo_uses_correct_api_settings(sample_image_memo): + with ( + patch("penparse.webui.tasks.litellm") as mock_litellm, + patch("penparse.webui.tasks.settings") as mock_settings, + ): + mock_response = MagicMock() + mock_response.choices[0].message = {"content": "Transcribed content"} + mock_litellm.completion.return_value = mock_response + mock_settings.OPENAI_API_BASE = "https://test-api-base.com" + mock_settings.OPENAI_API_KEY = "test-api-key" + mock_settings.OPENAI_MODEL = "test-model" + + process_memo(sample_image_memo.id) + + assert mock_litellm.api_base == "https://test-api-base.com" + assert mock_litellm.api_key == "test-api-key" + mock_litellm.completion.assert_called_once_with( + model="test-model", + messages=pytest.approx( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": pytest.ANY}, + {"type": "image_url", "image_url": {"url": pytest.ANY}}, + ], + } + ] + ), + temperature=0.01, + )