# Ollama BGE reranker workaround service (FastAPI + Ollama embeddings).
import asyncio
import httpx
import logging
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
import numpy as np

# Module-wide logging: timestamped records, INFO and above.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# ASGI application; title flags that this service uses the
# concatenation+magnitude workaround rather than a true cross-encoder.
app = FastAPI(title="Ollama BGE Reranker (Working Workaround)")
class RerankRequest(BaseModel):
    """Request body for POST /v1/rerank (Cohere-style rerank API)."""

    model: str  # Ollama model name used to produce embeddings
    query: str  # search query the documents are ranked against
    documents: List[str]  # candidate documents to rerank
    top_n: Optional[int] = 3  # number of top results to return; None returns all
class RerankResult(BaseModel):
    """A single reranked document with its relevance score."""

    index: int  # position of the document in the original request list
    relevance_score: float  # score in [0, 1]; higher = more relevant
    document: Optional[str] = None  # original document text, echoed back
class RerankResponse(BaseModel):
    """Response body: results sorted by descending relevance."""

    results: List[RerankResult]
async def get_embedding(
    client: httpx.AsyncClient,
    model: str,
    text: str,
    *,
    base_url: str = "http://localhost:11434",
) -> Optional[List[float]]:
    """Fetch an embedding vector for *text* from Ollama.

    Args:
        client: Shared async HTTP client.
        model: Name of the Ollama embedding model.
        text: Input text to embed.
        base_url: Ollama server root (keyword-only; defaults to the
            local daemon, matching the previous hard-coded URL).

    Returns:
        The embedding as a list of floats, or ``None`` on any failure
        (network error, HTTP error status, or a response body without
        an ``embedding`` key).
    """
    try:
        response = await client.post(
            f"{base_url}/api/embeddings",
            json={"model": model, "prompt": text},
            timeout=30.0,
        )
        response.raise_for_status()
        # .get() (not []) so a well-formed response missing "embedding"
        # yields None rather than raising into the handler below.
        return response.json().get("embedding")
    except Exception:
        # Best-effort by design: callers treat None as "could not embed".
        # logger.exception preserves the traceback the old f-string log dropped.
        logger.exception("Error getting embedding")
        return None
async def score_document_cross_encoder_workaround(
    client: httpx.AsyncClient,
    model: str,
    query: str,
    doc: str,
    index: int,
    *,
    good_magnitude: float = 15.0,
    poor_magnitude: float = 25.0,
) -> dict:
    """Score one document against *query* via the Ollama/BGE workaround.

    Workaround for using BGE-reranker with Ollama.
    Based on: https://medium.com/@rosgluk/reranking-documents-with-ollama-and-qwen3-reranker-model-in-go-6dc9c2fb5f0b

    Key discovery: when embedding the concatenated query+document text,
    LOWER embedding magnitude empirically means MORE relevant. The score
    is inverted so higher values = more relevant (standard convention).

    Steps:
        1. Concatenate query and document in cross-encoder format.
        2. Get the embedding of the concatenated text.
        3. Compute its L2 magnitude (lower = more relevant).
        4. Invert and linearly normalize into [0, 1].

    Args:
        client: Shared async HTTP client.
        model: Ollama model used for embeddings.
        query: Search query.
        doc: Candidate document text.
        index: Position of the document in the original request.
        good_magnitude: Magnitude that maps to score 1.0
            (empirically ~15 for highly relevant documents).
        poor_magnitude: Magnitude that maps to score 0.0
            (empirically ~25 for irrelevant documents).

    Returns:
        Dict with ``index``, ``relevance_score`` (0.0 when the embedding
        could not be fetched) and the original ``document``.
    """
    # Format as a cross-encoder input — reranker models are sensitive to
    # the exact prompt pattern.
    combined = f"Query: {query}\n\nDocument: {doc}\n\nRelevance:"

    embedding = await get_embedding(client, model, combined)
    if embedding is None:
        # Fail soft: an unembeddable document gets the worst score instead
        # of aborting the whole rerank request.
        logger.warning("Failed to get embedding for document %d", index)
        return {"index": index, "relevance_score": 0.0, "document": doc}

    # L2 norm of the embedding vector. For BGE-reranker served through
    # Ollama's embeddings endpoint, lower magnitude = more relevant
    # (observed range roughly 15-25).
    magnitude = float(np.linalg.norm(np.array(embedding)))

    # Inverted linear interpolation, then clamp to [0, 1]:
    #   magnitude == poor_magnitude -> 0.0
    #   magnitude == good_magnitude -> 1.0
    score = (poor_magnitude - magnitude) / (poor_magnitude - good_magnitude)
    score = min(max(score, 0.0), 1.0)

    logger.debug("Doc %d: magnitude=%.2f, score=%.4f", index, magnitude, score)
    logger.info("Raw magnitude: %.2f", magnitude)

    return {"index": index, "relevance_score": score, "document": doc}
@app.on_event("startup")
async def check_ollama():
    """Verify Ollama is accessible on startup.

    Failures are logged rather than raised so the service still starts
    and /health stays reachable even when Ollama is down.
    """
    # NOTE(review): FastAPI has deprecated on_event in favor of lifespan
    # handlers; kept as-is to avoid restructuring app creation.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get("http://localhost:11434/api/tags", timeout=5.0)
            response.raise_for_status()
            logger.info("✓ Successfully connected to Ollama")
            logger.warning("⚠️ Using workaround: concatenation + magnitude")
            logger.warning("⚠️ This is less accurate than proper cross-encoder usage")
    except Exception:
        # logger.exception records the traceback the old f-string log dropped.
        logger.exception("✗ Cannot connect to Ollama")
@app.post("/v1/rerank", response_model=RerankResponse)
async def rerank(request: RerankRequest):
    """
    Rerank documents using BGE-reranker via Ollama workaround.

    NOTE: This uses a workaround (magnitude of concatenated embeddings)
    because Ollama doesn't expose BGE's classification head.
    For best accuracy, use sentence-transformers directly.

    Raises:
        HTTPException: 400 when the request contains no documents.
    """
    if not request.documents:
        raise HTTPException(status_code=400, detail="No documents provided")

    logger.info("Reranking %d documents (workaround method)", len(request.documents))
    # Only mark the query as truncated when it actually was (the old log
    # appended "..." unconditionally, which was misleading for short queries).
    preview = request.query[:100] + ("..." if len(request.query) > 100 else "")
    logger.info("Query: %s", preview)

    # One shared client for all concurrent scoring calls.
    async with httpx.AsyncClient() as client:
        tasks = [
            score_document_cross_encoder_workaround(
                client, request.model, request.query, doc, i
            )
            for i, doc in enumerate(request.documents)
        ]
        results = await asyncio.gather(*tasks)

    # Sort descending: scores were already inverted upstream, so
    # higher score = more relevant.
    results = sorted(results, key=lambda r: r["relevance_score"], reverse=True)

    # Slicing tolerates top_n=None (returns all results).
    top = results[: request.top_n]
    logger.info(
        "Top %d scores: %s",
        len(top),
        [f"{r['relevance_score']:.4f}" for r in top],
    )
    return {"results": top}
@app.get("/health")
def health_check():
    """Liveness probe: report service identity and the workaround caveat."""
    payload = {}
    payload["status"] = "healthy"
    payload["service"] = "ollama-bge-reranker-workaround"
    payload["note"] = "Using magnitude workaround - less accurate than native"
    return payload
if __name__ == "__main__":
    # uvicorn imported lazily so importing this module (e.g. for tests or
    # external ASGI servers) does not require it.
    import uvicorn

    # Startup banner: make the workaround nature obvious in the logs.
    logger.info("=" * 60)
    logger.info("Ollama BGE Reranker - WORKAROUND Implementation")
    logger.info("=" * 60)
    logger.info("Using concatenation + magnitude method")
    logger.info("This works but is less accurate than proper cross-encoders")
    logger.info("Starting on: http://0.0.0.0:8080")
    logger.info("=" * 60)

    # Blocking call: serves the app until interrupted.
    uvicorn.run(app, host="0.0.0.0", port=8080, log_level="info")