diff --git a/docker-compose.yml b/docker-compose.yml
index 8771aa0..4e97ac8 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -8,6 +8,7 @@ services:
vllm:
image: vllm/vllm-openai:latest
command: "--model Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4 --quantization gptq "
+ #command: "--model HuggingFaceTB/SmolVLM-Instruct --max_model_len 4098"
volumes:
- ~/.cache/huggingface:/root/.cache/huggingface
ports:
diff --git a/penparse/webui/tasks.py b/penparse/webui/tasks.py
index 122bc04..6c86af7 100644
--- a/penparse/webui/tasks.py
+++ b/penparse/webui/tasks.py
@@ -1,5 +1,6 @@
import base64
import litellm
+import openai
from loguru import logger
from celery import shared_task
@@ -26,7 +27,6 @@ Please include whitespace and formatting for headings too.
"""
-
@shared_task
def process_memo(memo_id: str):
"""Run OCR on a memo and store the output"""
@@ -70,15 +70,24 @@ def process_memo(memo_id: str):
litellm.api_base = settings.OPENAI_API_BASE # os.environ.get("OPENAI_API_BASE")
litellm.api_key = settings.OPENAI_API_KEY
- response = litellm.completion(
- model=settings.OPENAI_MODEL, #os.getenv("MODEL", "openai/gpt-4o"),
- messages=[message],
- temperature=0.01
- )
+ try:
+ response = litellm.completion(
+ model=settings.OPENAI_MODEL, #os.getenv("MODEL", "openai/gpt-4o"),
+ messages=[message],
+ temperature=0.01
+ )
- response.choices[0].message["content"]
+        # transcription text is read from response.choices[0].message below
- with transaction.atomic():
- memo.content = response.choices[0].message["content"]
- memo.status = MemoStatus.Done
- memo.save()
+ with transaction.atomic():
+ memo.content = response.choices[0].message["content"]
+ memo.status = MemoStatus.Done
+ memo.model_name = settings.OPENAI_MODEL
+ memo.save()
+ except openai.OpenAIError as e:
+
+ with transaction.atomic():
+ memo.status = MemoStatus.Error
+            memo.error_message = repr(e)
+ memo.save()
+            logger.exception(e)
diff --git a/penparse/webui/templates/dashboard.html b/penparse/webui/templates/dashboard.html
index b6734f4..ba88bcf 100644
--- a/penparse/webui/templates/dashboard.html
+++ b/penparse/webui/templates/dashboard.html
@@ -90,29 +90,91 @@
action="{% url 'upload_document' %}"
method="post"
enctype="multipart/form-data"
+ id="upload-form"
>
{% csrf_token %}
-
+
-
+
+
+
+
{% endblock %}
diff --git a/penparse/webui/test/test_tasks.py b/penparse/webui/test/test_tasks.py
new file mode 100644
index 0000000..0548403
--- /dev/null
+++ b/penparse/webui/test/test_tasks.py
@@ -0,0 +1,104 @@
+import pytest
+from django.core.files.base import ContentFile
+from django.core.files.storage import default_storage
+from unittest.mock import patch, MagicMock
+from penparse.webui.tasks import process_memo
+from penparse.webui.models import ImageMemo, MemoStatus
+
+
+@pytest.fixture
+def sample_image_memo(db):
+ memo = ImageMemo.objects.create(
+ status=MemoStatus.Pending, image_mimetype="image/jpeg"
+ )
+ memo.image.save("test_image.jpg", ContentFile(b"fake image content"))
+ return memo
+
+
+@pytest.mark.django_db
+def test_process_memo_success(sample_image_memo):
+ with patch("penparse.webui.tasks.litellm") as mock_litellm:
+ mock_response = MagicMock()
+ mock_response.choices[0].message = {"content": "Transcribed content"}
+ mock_litellm.completion.return_value = mock_response
+
+ process_memo(sample_image_memo.id)
+
+ processed_memo = ImageMemo.objects.get(id=sample_image_memo.id)
+ assert processed_memo.status == MemoStatus.Done
+ assert processed_memo.content == "Transcribed content"
+ assert processed_memo.error_message == ""
+
+
+@pytest.mark.django_db
+def test_process_memo_missing_image(sample_image_memo):
+ default_storage.delete(sample_image_memo.image.name)
+
+ process_memo(sample_image_memo.id)
+
+ processed_memo = ImageMemo.objects.get(id=sample_image_memo.id)
+ assert processed_memo.status == MemoStatus.Error
+ assert "Image file" in processed_memo.error_message
+
+
+@pytest.mark.django_db
+def test_process_memo_api_error(sample_image_memo):
+ with patch("penparse.webui.tasks.litellm") as mock_litellm:
+        mock_litellm.completion.side_effect = __import__("openai").OpenAIError("API Error")
+
+ process_memo(sample_image_memo.id)
+
+ processed_memo = ImageMemo.objects.get(id=sample_image_memo.id)
+ assert processed_memo.status == MemoStatus.Error
+ assert "API Error" in processed_memo.error_message
+
+
+@pytest.mark.django_db
+def test_process_memo_sets_model_name(sample_image_memo):
+ with (
+ patch("penparse.webui.tasks.litellm") as mock_litellm,
+ patch("penparse.webui.tasks.settings") as mock_settings,
+ ):
+ mock_response = MagicMock()
+ mock_response.choices[0].message = {"content": "Transcribed content"}
+ mock_litellm.completion.return_value = mock_response
+ mock_settings.OPENAI_MODEL = "test-model"
+
+ process_memo(sample_image_memo.id)
+
+ processed_memo = ImageMemo.objects.get(id=sample_image_memo.id)
+ assert processed_memo.model_name == "test-model"
+
+
+@pytest.mark.django_db
+def test_process_memo_uses_correct_api_settings(sample_image_memo):
+ with (
+ patch("penparse.webui.tasks.litellm") as mock_litellm,
+ patch("penparse.webui.tasks.settings") as mock_settings,
+ ):
+ mock_response = MagicMock()
+ mock_response.choices[0].message = {"content": "Transcribed content"}
+ mock_litellm.completion.return_value = mock_response
+ mock_settings.OPENAI_API_BASE = "https://test-api-base.com"
+ mock_settings.OPENAI_API_KEY = "test-api-key"
+ mock_settings.OPENAI_MODEL = "test-model"
+
+ process_memo(sample_image_memo.id)
+
+ assert mock_litellm.api_base == "https://test-api-base.com"
+ assert mock_litellm.api_key == "test-api-key"
+        # exact prompt text and base64 data-URL are deliberately not pinned here
+        from unittest.mock import ANY  # pytest has no ANY; approx is numeric-only
+        mock_litellm.completion.assert_called_once_with(
+            model="test-model",
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": ANY},
+                        {"type": "image_url", "image_url": {"url": ANY}},
+                    ],
+                }
+            ],
+            temperature=0.01,
+        )