# Files: ollama-utils/plugins/reranking-endpoint/api.py
# 2026-01-20 20:44:48 +00:00
#
# 187 lines
# 6.1 KiB
# Python

import asyncio
import httpx
import logging
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
import numpy as np
# Configure logging at import time: INFO level, with each record formatted as
# "timestamp - logger name - level - message".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
# FastAPI application object; the route and startup decorators in this file
# register their handlers on it.
app = FastAPI(title="Ollama BGE Reranker (Working Workaround)")
# Request body accepted by the POST /v1/rerank endpoint.
class RerankRequest(BaseModel):
    # Model name forwarded to Ollama when requesting embeddings.
    model: str
    # Query text that every document is scored against.
    query: str
    # Candidate documents; each one is scored independently.
    documents: List[str]
    # Upper bound on the number of results returned (used as a slice bound on
    # the sorted results); defaults to 3. None means "return all results".
    top_n: Optional[int] = 3
# One scored document in the rerank response.
class RerankResult(BaseModel):
    # Position of the document in the request's `documents` list.
    index: int
    # Relevance score clamped to the 0-1 range (higher = more relevant);
    # 0.0 is also used when no embedding could be obtained for the document.
    relevance_score: float
    # The document text this score belongs to.
    document: Optional[str] = None
# Response body of POST /v1/rerank.
class RerankResponse(BaseModel):
    # Scored documents, sorted by relevance_score in descending order and
    # truncated to the requested top_n.
    results: List[RerankResult]
async def get_embedding(
    client: httpx.AsyncClient,
    model: str,
    text: str,
    *,
    base_url: str = "http://localhost:11434",
    timeout: float = 30.0,
) -> Optional[List[float]]:
    """Fetch an embedding vector for ``text`` from Ollama.

    Posts ``{"model": model, "prompt": text}`` to
    ``<base_url>/api/embeddings`` and extracts the ``"embedding"`` field of
    the JSON response.

    Args:
        client: HTTP client used to issue the request.
        model: Name of the Ollama model to embed with.
        text: Text to embed.
        base_url: Root URL of the Ollama server (keyword-only).
        timeout: Per-request timeout in seconds (keyword-only).

    Returns:
        The embedding as a non-empty list, or ``None`` when the request
        fails, the response cannot be parsed, or the response carries no
        usable ``"embedding"`` list. This function never raises for those
        failures; it logs them and returns ``None``.
    """
    url = f"{base_url.rstrip('/')}/api/embeddings"
    try:
        response = await client.post(
            url,
            json={"model": model, "prompt": text},
            timeout=timeout,
        )
        response.raise_for_status()
        payload = response.json()
    except Exception as exc:
        # Deliberately broad: callers gather many of these concurrently and
        # rely on a None result (not an exception) for a failed document.
        logger.error("Error getting embedding: %s", exc)
        return None

    embedding = payload.get("embedding") if isinstance(payload, dict) else None
    # A missing, non-list or empty embedding must not be passed on as a valid
    # vector: an empty vector has magnitude 0, which downstream scoring would
    # misread as "highly relevant".
    if not isinstance(embedding, list) or not embedding:
        logger.error(
            "Error getting embedding: response has no usable 'embedding' field"
        )
        return None
    return embedding
def _magnitude_to_score(
    magnitude: float,
    good_magnitude: float,
    poor_magnitude: float,
) -> float:
    """Linearly map a magnitude onto an inverted score clamped to 0-1."""
    score = (poor_magnitude - magnitude) / (poor_magnitude - good_magnitude)
    return min(max(score, 0.0), 1.0)


async def score_document_cross_encoder_workaround(
    client: httpx.AsyncClient,
    model: str,
    query: str,
    doc: str,
    index: int,
    *,
    good_magnitude: float = 15.0,
    poor_magnitude: float = 25.0,
) -> dict:
    """
    Workaround for using BGE-reranker with Ollama.

    Based on: https://medium.com/@rosgluk/reranking-documents-with-ollama-and-qwen3-reranker-model-in-go-6dc9c2fb5f0b

    Key discovery (empirical, per the original author): when using
    concatenated query+doc embeddings, LOWER magnitude = MORE relevant.
    The score is inverted so that higher values = more relevant
    (standard convention).

    Steps:
        1. Concatenate query and document in cross-encoder format.
        2. Get the embedding of the concatenated text.
        3. Calculate its magnitude (L2 norm).
        4. Invert and normalize to 0-1 (higher = more relevant).

    Args:
        client: HTTP client passed through to ``get_embedding``.
        model: Ollama model name.
        query: Query text.
        doc: Document text to score against the query.
        index: Position of ``doc`` in the caller's document list; echoed
            back in the result.
        good_magnitude: Magnitude mapped to score 1.0 (keyword-only).
        poor_magnitude: Magnitude mapped to score 0.0 (keyword-only).

    Returns:
        A dict with keys ``"index"``, ``"relevance_score"`` and
        ``"document"``. The score is 0.0 when no usable embedding could be
        obtained (request failure, empty vector, or non-finite magnitude).

    Raises:
        ValueError: If ``poor_magnitude`` is not greater than
            ``good_magnitude`` (the mapping would be undefined).
    """
    # Validate the bounds before doing any network work.
    if poor_magnitude <= good_magnitude:
        raise ValueError("poor_magnitude must be greater than good_magnitude")

    failed_result = {
        "index": index,
        "relevance_score": 0.0,
        "document": doc
    }

    # Format as cross-encoder input
    # The format matters - reranker models expect specific patterns
    combined = f"Query: {query}\n\nDocument: {doc}\n\nRelevance:"

    embedding = await get_embedding(client, model, combined)
    if embedding is None:
        logger.warning(f"Failed to get embedding for document {index}")
        return failed_result

    # Calculate magnitude (L2 norm) of the embedding vector
    vec = np.asarray(embedding, dtype=float)
    magnitude = float(np.linalg.norm(vec))

    # An empty vector has magnitude 0 (which would clamp to the *best*
    # score), and a NaN/inf magnitude would slip through min/max as NaN.
    # Both are failures, not relevance signals.
    if vec.size == 0 or not np.isfinite(magnitude):
        logger.warning("Unusable embedding for document %d", index)
        return failed_result

    # Linear interpolation (inverted), with the default bounds taken from the
    # original author's observed range of ~15-25 (lower = better):
    #   magnitude <= good_magnitude (15) -> score 1.0
    #   magnitude >= poor_magnitude (25) -> score 0.0
    score = _magnitude_to_score(magnitude, good_magnitude, poor_magnitude)

    logger.debug("Doc %d: magnitude=%.2f, score=%.4f", index, magnitude, score)
    logger.info("Raw magnitude: %.2f", magnitude)
    return {
        "index": index,
        "relevance_score": score,
        "document": doc
    }
@app.on_event("startup")
async def check_ollama():
    """Probe Ollama's tags endpoint at startup and log the outcome.

    A failed probe is only logged; it does not stop the application.
    """
    tags_url = "http://localhost:11434/api/tags"
    caveats = (
        "⚠️ Using workaround: concatenation + magnitude",
        "⚠️ This is less accurate than proper cross-encoder usage",
    )
    try:
        async with httpx.AsyncClient() as probe:
            reply = await probe.get(tags_url, timeout=5.0)
            reply.raise_for_status()
            logger.info("✓ Successfully connected to Ollama")
            # Remind operators that scoring here is an approximation.
            for caveat in caveats:
                logger.warning(caveat)
    except Exception as err:
        logger.error(f"✗ Cannot connect to Ollama: {err}")
@app.post("/v1/rerank", response_model=RerankResponse)
async def rerank(request: RerankRequest):
    """
    Rerank documents using BGE-reranker via Ollama workaround.

    NOTE: This uses a workaround (magnitude of concatenated embeddings)
    because Ollama doesn't expose BGE's classification head.
    For best accuracy, use sentence-transformers directly.

    Returns the documents sorted by relevance score (highest first),
    truncated to ``top_n`` entries; a null ``top_n`` returns all of them.
    Responds with HTTP 400 when no documents are provided or when
    ``top_n`` is negative.
    """
    if not request.documents:
        raise HTTPException(status_code=400, detail="No documents provided")

    top_n = request.top_n
    # A negative bound would make the slice below silently drop the
    # lowest-ranked results instead of limiting the count, so reject it.
    if top_n is not None and top_n < 0:
        raise HTTPException(status_code=400, detail="top_n must be non-negative")

    logger.info(f"Reranking {len(request.documents)} documents (workaround method)")
    logger.info(f"Query: {request.query[:100]}...")

    async with httpx.AsyncClient() as client:
        # Score all documents concurrently
        tasks = [
            score_document_cross_encoder_workaround(
                client, request.model, request.query, doc, i
            )
            for i, doc in enumerate(request.documents)
        ]
        results = await asyncio.gather(*tasks)

    # Sort by score DESCENDING (higher score = more relevant)
    results.sort(key=lambda x: x["relevance_score"], reverse=True)
    top_results = results[:top_n]

    # Log scores
    top_scores = [f"{r['relevance_score']:.4f}" for r in top_results]
    logger.info(f"Top {len(top_scores)} scores: {top_scores}")
    return {"results": top_results}
@app.get("/health")
def health_check():
    """Health check endpoint."""
    # Static payload: this handler does not contact Ollama, it only reports
    # that the service itself is up and which scoring method it uses.
    payload = {}
    payload["status"] = "healthy"
    payload["service"] = "ollama-bge-reranker-workaround"
    payload["note"] = "Using magnitude workaround - less accurate than native"
    return payload
if __name__ == "__main__":
    # Script entry point: print a startup banner, then serve the app with
    # uvicorn on all interfaces, port 8080.
    import uvicorn

    rule = "=" * 60
    banner = (
        rule,
        "Ollama BGE Reranker - WORKAROUND Implementation",
        rule,
        "Using concatenation + magnitude method",
        "This works but is less accurate than proper cross-encoders",
        "Starting on: http://0.0.0.0:8080",
        rule,
    )
    for banner_line in banner:
        logger.info(banner_line)

    uvicorn.run(app, host="0.0.0.0", port=8080, log_level="info")