implemented LLM-based OCR
Run Tests / Run Tests (push) Successful in 38s Details

This commit is contained in:
James Ravenscroft 2024-12-10 12:04:26 +00:00
parent 6be976e376
commit 7b15955bf6
8 changed files with 1536 additions and 78 deletions

View File

@ -4,3 +4,21 @@ services:
ports:
- 5672:5672
- 15672:15672
vllm:
image: vllm/vllm-openai:latest
command: "--model Qwen/Qwen2-VL-2B-Instruct"
volumes:
- ~/.cache/huggingface:/root/.cache/huggingface
ports:
- 8002:8000
environment:
- HUGGING_FACE_HUB_TOKEN=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
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ["0"]
capabilities: [gpu]

View File

@ -10,7 +10,12 @@ For the full list of settings and their values, see
https://docs.djangoproject.com/en/4.2/ref/settings/
"""
import os
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent
@ -20,7 +25,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent
# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = 'django-insecure-t5yq5dlvztd^-oq%*($@lj$$33l_73e05093xw7s0)-ekqhtfn'
SECRET_KEY = "django-insecure-t5yq5dlvztd^-oq%*($@lj$$33l_73e05093xw7s0)-ekqhtfn"
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True
@ -31,53 +36,54 @@ ALLOWED_HOSTS = []
# Application definition
INSTALLED_APPS = [
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'webui'
"django.contrib.admin",
"django.contrib.auth",
"django.contrib.contenttypes",
"django.contrib.sessions",
"django.contrib.messages",
"django.contrib.staticfiles",
"webui",
"markdown_deux"
]
MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
"django.middleware.security.SecurityMiddleware",
"django.contrib.sessions.middleware.SessionMiddleware",
"django.middleware.common.CommonMiddleware",
"django.middleware.csrf.CsrfViewMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware",
"django.contrib.messages.middleware.MessageMiddleware",
"django.middleware.clickjacking.XFrameOptionsMiddleware",
]
ROOT_URLCONF = 'penparse.urls'
ROOT_URLCONF = "penparse.urls"
TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
"BACKEND": "django.template.backends.django.DjangoTemplates",
"DIRS": [],
"APP_DIRS": True,
"OPTIONS": {
"context_processors": [
"django.template.context_processors.debug",
"django.template.context_processors.request",
"django.contrib.auth.context_processors.auth",
"django.contrib.messages.context_processors.messages",
],
},
},
]
WSGI_APPLICATION = 'penparse.wsgi.application'
WSGI_APPLICATION = "penparse.wsgi.application"
# Database
# https://docs.djangoproject.com/en/4.2/ref/settings/#databases
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': BASE_DIR / 'db.sqlite3',
"default": {
"ENGINE": "django.db.backends.sqlite3",
"NAME": BASE_DIR / "db.sqlite3",
}
}
@ -87,31 +93,31 @@ DATABASES = {
AUTH_PASSWORD_VALIDATORS = [
{
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
"NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator",
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator",
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator",
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",
},
]
LOGIN_REDIRECT_URL = '/dashboard'
LOGIN_REDIRECT_URL = "/dashboard"
AUTH_USER_MODEL = 'webui.User'
AUTH_USER_MODEL = "webui.User"
AUTHENTICATION_BACKENDS = ['webui.auth.EmailBackend']
AUTHENTICATION_BACKENDS = ["webui.auth.EmailBackend"]
# Internationalization
# https://docs.djangoproject.com/en/4.2/topics/i18n/
LANGUAGE_CODE = 'en-gb'
LANGUAGE_CODE = "en-gb"
TIME_ZONE = 'UTC'
TIME_ZONE = "UTC"
USE_I18N = True
@ -121,12 +127,17 @@ USE_TZ = True
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/4.2/howto/static-files/
STATIC_URL = 'static/'
STATIC_URL = "static/"
# Default primary key field type
# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
CELERY_BROKER_URL = 'amqp://rabbit:rabbit@localhost//'
CELERY_BROKER_URL = "amqp://guest:guest@localhost/"
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "openai/gpt-4o")

View File

@ -38,6 +38,7 @@ class UserManager(BaseUserManager):
return self._create_user(email, password, **extra_fields)
class MemoStatus(models.TextChoices):
Pending = "pending"
Processing = "processing"
@ -47,23 +48,26 @@ class MemoStatus(models.TextChoices):
class ImageMemo(models.Model):
"""Model definition for ImageMemo."""
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
image_mimetype = models.CharField(max_length=256)
image = models.ImageField(upload_to='uploads/%Y/%m/%d')
image = models.ImageField(upload_to="uploads/%Y/%m/%d")
content = models.TextField()
author = models.ForeignKey(
'User', on_delete=models.CASCADE, related_name='memos')
author = models.ForeignKey("User", on_delete=models.CASCADE, related_name="memos")
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
updated_at = models.DateTimeField(
auto_now=True,
)
status = models.CharField(max_length=10, choices=MemoStatus.choices, default=MemoStatus.Pending)
status = models.CharField(
max_length=10, choices=MemoStatus.choices, default=MemoStatus.Pending
)
error_message = models.TextField(null=True)
class Meta:
ordering = ["-created_at"]
@ -74,8 +78,8 @@ class User(AbstractUser):
first_name = models.CharField(max_length=150, blank=False)
last_name = models.CharField(max_length=150, blank=False)
USERNAME_FIELD = 'email'
REQUIRED_FIELDS = ['full_name']
USERNAME_FIELD = "email"
REQUIRED_FIELDS = ["full_name"]
objects = UserManager() # type: ignore

77
penparse/webui/tasks.py Normal file
View File

@ -0,0 +1,77 @@
import requests
import base64
import litellm
import os
from loguru import logger
from celery import shared_task, Task
from django.db import transaction
from django.core.files.storage import default_storage
from django.conf import settings
from .models import ImageMemo, MemoStatus
from datetime import datetime
TRANSCRIBE_PROMPT = """Transcribe the hand written notes in the attached image and present them as markdown inside a fence like so
```markdown
<Content>
```
If any words or letters are unclear, denote them with a '?<word>?'. For example if you were not sure whether a word is blow or blew you would transcribe it as '?blow?'
"""
@shared_task
def process_memo(memo_id: str):
"""Run OCR on a memo and store the output"""
logger.info(f"Looking up memo with id={memo_id}")
memo = ImageMemo.objects.get(id=memo_id)
with transaction.atomic():
logger.info(f"Set status=processing for memo {memo.id}")
memo.status = MemoStatus.Processing
memo.save()
# check that the image exists
logger.info(f"Checking that image {memo.image.name} exists")
if not default_storage.exists(memo.image.name):
memo.status = MemoStatus.Error
memo.error_message = f"Image file {memo.image.name} does not exist"
memo.save()
return
# read the image into memory
logger.info(f"Reading image {memo.image.name}")
bytearray = default_storage.open(memo.image.name).read()
# call the OCR API
logger.info(f"Calling OCR API for memo {memo.id}")
b64img = base64.b64encode(bytearray).decode("utf-8")
message = {
"role": "user",
"content": [
{"type": "text", "text": TRANSCRIBE_PROMPT},
{
"type": "image_url",
"image_url": {"url": f"data:{memo.image_mimetype};base64,{b64img}"},
},
],
}
litellm.api_base = settings.OPENAI_API_BASE # os.environ.get("OPENAI_API_BASE")
litellm.api_key = settings.OPENAI_API_KEY
response = litellm.completion(
model=settings.OPENAI_MODEL, #os.getenv("MODEL", "openai/gpt-4o"),
messages=[message],
)
response.choices[0].message["content"]
with transaction.atomic():
memo.content = response.choices[0].message["content"]
memo.status = MemoStatus.Done
memo.save()

View File

@ -1,4 +1,6 @@
{% extends "main.html" %} {% block content %}
{% extends "main.html" %}
{% load markdown_deux_tags %}
{% block content %}
<section class="mb-16">
<h1 class="text-4xl font-bold text-gray-800 mb-4">Your Dashboard</h1>
<p class="text-xl text-gray-600 mb-8">
@ -35,6 +37,14 @@
<p class="text-gray-600 mb-4">
Last Updated: {{ document.updated_at }}
</p>
{% if document.content %}
<div class="text-gray-700 mb-4">
<h4 class="font-semibold mb-2">Content Preview:</h4>
<div class="prose prose-sm">
{{ document.content|truncatechars_html:100|markdown }}
</div>
</div>
{% endif %}
<div class="flex justify-between items-center">
<a
href="{% url 'view_document' document.id %}"

View File

@ -9,10 +9,14 @@ from ..models import ImageMemo
from django.http import HttpRequest
from django.db import transaction
from uuid import uuid4
from django.contrib.auth.decorators import login_required
from ..tasks import process_memo
logger = logging.getLogger(__name__)
@ -38,9 +42,13 @@ def upload_document(request: HttpRequest):
# Create an ImageMemo instance
image_memo = ImageMemo(
image=file_name,
image_mimetype=uploaded_file.content_type,
content="", # You can add initial content here if needed
author=request.user, # Assuming the user is authenticated
)
transaction.on_commit(lambda: process_memo.delay(image_memo.id))
image_memo.save()
messages.success(request, "Image uploaded successfully!")

View File

@ -6,11 +6,14 @@ readme = "README.md"
requires-python = ">=3.9"
dependencies = [
"celery>=5.4.0",
"django-markdown-deux>=1.0.6",
"django>=4.2.16",
"litellm>=1.54.1",
"loguru>=0.7.3",
"pillow>=11.0.0",
"pytest-django>=4.9.0",
"pytest-loguru>=0.4.0",
"pytest>=8.3.4",
"python-dotenv>=1.0.1",
"requests>=2.32.3",
]

1327
uv.lock

File diff suppressed because it is too large Load Diff