897 lines
30 KiB
JavaScript
897 lines
30 KiB
JavaScript
#!/usr/bin/env node
|
|
|
|
// efficient-embedding-comparison.js
|
|
// Proper embedding model evaluation with batch processing and vector search
|
|
// Run with: node efficient-embedding-comparison.js --config=config.json
|
|
|
|
import fs from 'fs/promises';
|
|
import yaml from 'js-yaml';
|
|
import path from 'path';
|
|
import crypto from 'crypto';
|
|
|
|
class EmbeddingCache {
  /**
   * On-disk cache of embedding vectors, one JSON file per (model, text) pair.
   *
   * Files are named `<sanitized model name>_<md5("<model.name>:<text>")>.json`, so entries
   * can be attributed to a model from the file name alone (see getCacheStats). Entries
   * written under the older bare `<md5>.json` naming are still read as a fallback.
   *
   * @param {string} [cacheDir='./embedding-cache'] - directory holding the cache files.
   */
  constructor(cacheDir = './embedding-cache') {
    this.cacheDir = cacheDir;
  }

  /** Creates the cache directory (including parents) if it does not exist yet. */
  async ensureCacheDir() {
    // `recursive: true` makes mkdir a no-op for an existing directory, so no separate
    // existence probe is needed (and no race between probe and create).
    await fs.mkdir(this.cacheDir, { recursive: true });
  }

  /**
   * File-name-safe form of the model name, used as the cache-file prefix.
   * @param {{name: string}} model
   * @returns {string} model name with every non-alphanumeric character replaced by `_`.
   */
  getModelPrefix(model) {
    return String(model.name).replace(/[^a-zA-Z0-9]/g, '_');
  }

  /**
   * Hash-only key used by earlier versions of this cache.
   * @param {{name: string}} model
   * @param {string} text
   * @returns {string} hex md5 of `${model.name}:${text}`.
   */
  getLegacyCacheKey(model, text) {
    return crypto.createHash('md5').update(`${model.name}:${text}`).digest('hex');
  }

  /**
   * Cache key (file name without `.json`) for a model/text pair.
   * @param {{name: string}} model
   * @param {string} text
   * @returns {string} `<model prefix>_<md5>`.
   */
  getCacheKey(model, text) {
    return `${this.getModelPrefix(model)}_${this.getLegacyCacheKey(model, text)}`;
  }

  /**
   * Reads a cached embedding.
   * @param {{name: string}} model
   * @param {string} text
   * @returns {Promise<any|null>} the parsed JSON value, or null when no readable entry exists.
   */
  async getCachedEmbedding(model, text) {
    await this.ensureCacheDir();

    for (const key of [this.getCacheKey(model, text), this.getLegacyCacheKey(model, text)]) {
      try {
        const data = await fs.readFile(path.join(this.cacheDir, `${key}.json`), 'utf8');
        return JSON.parse(data);
      } catch {
        // Missing or unreadable entry: deliberately treated as a miss; try the next name.
      }
    }
    return null;
  }

  /**
   * Stores an embedding under the model-prefixed key.
   * @param {{name: string}} model
   * @param {string} text
   * @param {any} embedding - JSON-serialisable value (the vector).
   */
  async setCachedEmbedding(model, text, embedding) {
    await this.ensureCacheDir();
    const cachePath = path.join(this.cacheDir, `${this.getCacheKey(model, text)}.json`);
    await fs.writeFile(cachePath, JSON.stringify(embedding));
  }

  /**
   * Counts cache entries attributable to `model`.
   *
   * Legacy hash-only entries cannot be attributed and are only included in `total`.
   * Models whose names sanitize to the same prefix are counted together.
   *
   * @param {{name: string}} model
   * @returns {Promise<{cached: number, total: number}>}
   */
  async getCacheStats(model) {
    await this.ensureCacheDir();
    const files = await fs.readdir(this.cacheDir);
    // Exact shape match so that model "mistral" does not also count "mistral_embed_<hash>" files.
    // The prefix only contains [A-Za-z0-9_], so it is safe to embed in the pattern.
    const ownFile = new RegExp(`^${this.getModelPrefix(model)}_[0-9a-f]{32}\\.json$`);
    const cached = files.filter((file) => ownFile.test(file)).length;
    return { cached, total: files.length };
  }
}
|
|
|
|
class SearchEvaluator {
  /** Retries allowed after an HTTP 429 when `model.maxRateLimitRetries` is not set. */
  static MAX_RATE_LIMIT_RETRIES = 5;

  /** Wait (ms) before retrying an HTTP 429 when `model.rateLimitRetryDelayMs` is not set. */
  static RATE_LIMIT_RETRY_DELAY_MS = 10000;

  /**
   * Fetches embeddings through an on-disk EmbeddingCache, runs brute-force cosine
   * similarity search and scores rankings with retrieval metrics.
   */
  constructor() {
    this.cache = new EmbeddingCache();
  }

  /**
   * Sleeps `model.rateLimitDelayMs` milliseconds when `model.rateLimit` is set.
   * @param {object} model - model config entry.
   */
  async rateLimitedDelay(model) {
    if (model.rateLimit && model.rateLimitDelayMs) {
      await new Promise((resolve) => setTimeout(resolve, model.rateLimitDelayMs));
    }
  }

  /**
   * Builds the POST request for `model.endpoint`.
   *
   * `mistral` models get `{ model, input: [text] }` and, when `model.apiKey` is set, a
   * Bearer token in which the literal placeholder `${AI_EMBEDDINGS_API_KEY}` is replaced by
   * that environment variable. All other model types get `{ model, prompt: text }`.
   *
   * @param {string} text
   * @param {object} model
   * @returns {{endpoint: string, headers: object, body: object}}
   */
  buildEmbeddingRequest(text, model) {
    const headers = { 'Content-Type': 'application/json' };

    if (model.type !== 'mistral') {
      return { endpoint: model.endpoint, headers, body: { model: model.name, prompt: text } };
    }

    if (model.apiKey) {
      const apiKey = model.apiKey.replace('${AI_EMBEDDINGS_API_KEY}', process.env.AI_EMBEDDINGS_API_KEY || '');
      headers['Authorization'] = `Bearer ${apiKey}`;
    }
    return { endpoint: model.endpoint, headers, body: { model: model.name, input: [text] } };
  }

  /**
   * Returns the embedding vector for `text`, serving it from the cache when possible and
   * caching freshly fetched vectors.
   *
   * On HTTP 429 from a model flagged `rateLimit`, waits `model.rateLimitRetryDelayMs`
   * (default 10 s) and retries at most `model.maxRateLimitRetries` times (default 5)
   * instead of retrying forever.
   *
   * @param {string} text
   * @param {object} model - { name, type, endpoint, apiKey?, rateLimit?, ... }
   * @returns {Promise<number[]>}
   * @throws {Error} on network errors, non-OK responses (including exhausted 429 retries),
   *   or a response that does not contain an embedding array.
   */
  async getEmbedding(text, model) {
    const cached = await this.cache.getCachedEmbedding(model, text);
    if (cached) return cached;

    const { endpoint, headers, body } = this.buildEmbeddingRequest(text, model);
    const maxRetries = model.maxRateLimitRetries ?? SearchEvaluator.MAX_RATE_LIMIT_RETRIES;
    const retryDelayMs = model.rateLimitRetryDelayMs ?? SearchEvaluator.RATE_LIMIT_RETRY_DELAY_MS;

    try {
      for (let attempt = 0; ; attempt += 1) {
        const response = await fetch(endpoint, {
          method: 'POST',
          headers,
          body: JSON.stringify(body),
        });

        if (response.ok) {
          const data = await response.json();
          const embedding = model.type === 'mistral' ? data?.data?.[0]?.embedding : data?.embedding;
          // Reject unexpected payloads instead of caching them: an empty vector would
          // otherwise be reused from disk and always score 0 in cosineSimilarity.
          if (!Array.isArray(embedding) || embedding.length === 0) {
            throw new Error(`Response from ${endpoint} did not contain an embedding vector`);
          }
          await this.cache.setCachedEmbedding(model, text, embedding);
          return embedding;
        }

        if (response.status === 429 && model.rateLimit && attempt < maxRetries) {
          console.log(` ⚠️ Rate limited, waiting...`);
          await new Promise((resolve) => setTimeout(resolve, retryDelayMs));
          continue;
        }

        throw new Error(`API error ${response.status}: ${await response.text()}`);
      }
    } catch (error) {
      console.error(`❌ Failed to get embedding: ${error.message}`);
      throw error;
    }
  }

  /**
   * Builds the lower-cased text that gets embedded for a tool (no truncation).
   *
   * @param {string|object} item - raw string, or a tool with name/description/tags/domains/phases.
   * @param {null} [maxLength=null] - unused; kept for call compatibility.
   * @returns {string}
   */
  // eslint-disable-next-line no-unused-vars
  constructToolText(item, maxLength = null) {
    if (typeof item === 'string') {
      // Even for string inputs, don't truncate to match real app behavior
      return item.toLowerCase();
    }

    // EXACT match to embeddings.ts createContentString() - NO TRUNCATION
    const parts = [
      item.name,
      item.description || '',
      ...(item.tags || []),
      ...(item.domains || []),
      ...(item.phases || []),
    ];

    return parts.filter(Boolean).join(' ').toLowerCase();
  }

  /**
   * Chooses how many items are grouped per progress line (requests are still sequential).
   * @param {object} model
   * @returns {number} 5 for heavily rate-limited models, 15 for ollama, 10 for mistral, else 15.
   */
  calculateOptimalBatchSize(model) {
    // Heavily rate-limited APIs: small batches so progress is reported often.
    if (model.rateLimit && model.rateLimitDelayMs > 2000) {
      return 5;
    }
    if (model.type === 'ollama') {
      return 15; // Local APIs are fast, can handle larger batches
    }
    if (model.type === 'mistral') {
      return 10; // Remote APIs might be slower, medium batches
    }
    return 15; // Good balance for ~13 progress updates on ~185 tools
  }

  /**
   * Embeds every item sequentially, logging progress per batch.
   *
   * Items that fail to embed are logged and skipped. The rate-limit delay is only applied
   * after real API requests, not after cache hits.
   *
   * @param {Array<object>} items - tools (see constructToolText).
   * @param {object} model
   * @returns {Promise<Map<string, {text: string, embedding: number[], metadata: object}>>}
   *   keyed by `item.id || item.name || text`.
   */
  async createBatchEmbeddings(items, model) {
    const batchSize = this.calculateOptimalBatchSize(model);
    const contextSize = model.contextSize || 2000; // Only for display/info

    console.log(` 📦 Creating embeddings for ${items.length} items`);
    console.log(` 📏 Model context: ${contextSize} chars (for reference - NOT truncating)`);
    console.log(` 📋 Batch size: ${batchSize} (for progress reporting)`);

    const embeddings = new Map();
    let apiCalls = 0;
    let cacheHits = 0;
    const totalBatches = Math.ceil(items.length / batchSize);

    for (let start = 0; start < items.length; start += batchSize) {
      const batch = items.slice(start, start + batchSize);
      const batchNum = Math.floor(start / batchSize) + 1;

      console.log(` 📋 Processing batch ${batchNum}/${totalBatches} (${batch.length} tools)`);

      for (const [position, item] of batch.entries()) {
        const text = this.constructToolText(item);

        // Show the full text length for the first few tools of the first batch.
        if (start === 0 && position < 3) {
          const truncatedDisplay = text.length > 100 ? `${text.slice(0, 100)}...` : text;
          console.log(` 📝 ${item.name}: ${text.length} chars (full) - "${truncatedDisplay}"`);
        }

        try {
          // Probe the cache BEFORE fetching: getEmbedding() writes every result to the cache,
          // so checking afterwards would count every request as a cache hit.
          const cached = await this.cache.getCachedEmbedding(model, text);
          const embedding = cached || (await this.getEmbedding(text, model));

          embeddings.set(item.id || item.name || text, { text, embedding, metadata: item });

          if (cached) {
            cacheHits++;
          } else {
            apiCalls++;
            await this.rateLimitedDelay(model);
          }
        } catch {
          // getEmbedding() already logged the underlying error; record which item failed.
          console.warn(` ⚠️ Failed to embed: ${item.name || text.slice(0, 50)}...`);
          if (text.length > 8000) {
            console.warn(` 📏 Text was ${text.length} chars - may exceed model limits`);
          }
        }
      }
    }

    const lengths = [...embeddings.values()].map((entry) => entry.text.length);
    if (lengths.length > 0) {
      const avgLength = lengths.reduce((a, b) => a + b, 0) / lengths.length;
      const maxLength = Math.max(...lengths);
      const minLength = Math.min(...lengths);
      console.log(` 📊 Content stats: avg ${avgLength.toFixed(0)} chars, range ${minLength}-${maxLength} chars`);
    }

    console.log(` ✅ Created ${embeddings.size} embeddings (${apiCalls} API calls, ${cacheHits} cache hits)`);

    return embeddings;
  }

  /**
   * Cosine similarity over the common prefix of two vectors.
   * @param {number[]} a
   * @param {number[]} b
   * @returns {number} 0 when either vector is missing, empty or all zeros.
   */
  cosineSimilarity(a, b) {
    if (!a || !b || a.length === 0 || b.length === 0) return 0;

    let dotProduct = 0;
    let normA = 0;
    let normB = 0;
    const minLength = Math.min(a.length, b.length);

    for (let i = 0; i < minLength; i++) {
      dotProduct += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }

    if (normA === 0 || normB === 0) return 0;
    return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
  }

  /**
   * Brute-force nearest-neighbour search.
   * @param {number[]} queryEmbedding
   * @param {Map<string, {embedding: number[], metadata: object, text: string}>} toolEmbeddings
   * @param {number} [topK=10]
   * @returns {Array<{id: string, similarity: number, metadata: object, text: string}>}
   *   sorted by descending similarity, at most `topK` entries.
   */
  searchSimilar(queryEmbedding, toolEmbeddings, topK = 10) {
    const similarities = [];

    for (const [id, data] of toolEmbeddings) {
      similarities.push({
        id,
        similarity: this.cosineSimilarity(queryEmbedding, data.embedding),
        metadata: data.metadata,
        text: data.text,
      });
    }

    return similarities.sort((a, b) => b.similarity - a.similarity).slice(0, topK);
  }

  /**
   * Binary-relevance retrieval metrics for the first `k` results.
   *
   * Duplicate ids in `relevantIds` are counted once (tools sharing a name map to the same
   * id), so they no longer inflate the recall and IDCG denominators.
   *
   * @param {Array<{id: string}>} results - ranked results.
   * @param {string[]} relevantIds - ids judged relevant.
   * @param {number} [k=10]
   * @returns {{precisionAtK: number, recallAtK: number, f1AtK: number, mrr: number,
   *   ndcgAtK: number, relevantRetrieved: number, totalRelevant: number}}
   */
  calculateRetrievalMetrics(results, relevantIds, k = 10) {
    const topK = results.slice(0, k);
    const relevantSet = new Set(relevantIds);
    const totalRelevant = relevantSet.size;

    // Precision@K (over the results actually returned) and Recall@K
    const relevantRetrieved = topK.filter((r) => relevantSet.has(r.id)).length;
    const precisionAtK = topK.length > 0 ? relevantRetrieved / topK.length : 0;
    const recallAtK = totalRelevant > 0 ? relevantRetrieved / totalRelevant : 0;

    // F1@K
    const f1AtK = precisionAtK + recallAtK > 0
      ? 2 * (precisionAtK * recallAtK) / (precisionAtK + recallAtK)
      : 0;

    // Reciprocal rank of the first relevant hit within the top K.
    const firstHit = topK.findIndex((r) => relevantSet.has(r.id));
    const mrr = firstHit === -1 ? 0 : 1 / (firstHit + 1);

    // NDCG@K with binary relevance; the ideal ranking puts all relevant ids first.
    let dcg = 0;
    let idcg = 0;
    for (let i = 0; i < k; i++) {
      const discount = Math.log2(i + 2); // log2(rank + 1) with rank = i + 1
      if (i < topK.length && relevantSet.has(topK[i].id)) {
        dcg += 1 / discount;
      }
      if (i < totalRelevant) {
        idcg += 1 / discount;
      }
    }
    const ndcgAtK = idcg > 0 ? dcg / idcg : 0;

    return {
      precisionAtK,
      recallAtK,
      f1AtK,
      mrr,
      ndcgAtK,
      relevantRetrieved,
      totalRelevant,
    };
  }
}
|
|
|
|
class EfficientEmbeddingComparison {
  /**
   * Builds a pattern that matches `term` only where a word starts. For example, the
   * keyword "ram" matches "ram dump" or "ramdisk" but not "program", and "aws" does not
   * match "laws". Callers pass lower-cased terms and lower-cased text.
   *
   * @param {string} term
   * @returns {RegExp}
   */
  static #wordStartPattern(term) {
    const escaped = term.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
    return new RegExp(`(?<![\\p{L}\\p{N}])${escaped}`, 'u');
  }

  /**
   * Drives the comparison. It loads the config and tools, then runs the search, semantic
   * and latency tests for every configured model, and finally prints a ranked summary.
   *
   * @param {string} [configPath='./embedding-test-config.json'] - JSON config with a
   *   `models` array and a `toolsYamlPath` string.
   */
  constructor(configPath = './embedding-test-config.json') {
    this.configPath = configPath;
    this.config = null;
    this.tools = [];
    this.evaluator = new SearchEvaluator();

    // Test queries tailored to the actual tools.yaml content
    this.testQueries = [
      {
        query: "memory forensics RAM analysis",
        keywords: ["memory", "forensics", "volatility", "ram", "dump", "analysis"],
        category: "memory_analysis"
      },
      {
        query: "network packet capture traffic analysis",
        keywords: ["network", "packet", "pcap", "wireshark", "traffic", "capture"],
        category: "network_analysis"
      },
      {
        query: "malware reverse engineering binary analysis",
        keywords: ["malware", "reverse", "engineering", "ghidra", "binary", "disassemble"],
        category: "malware_analysis"
      },
      {
        query: "digital forensics disk imaging",
        keywords: ["forensics", "disk", "imaging", "autopsy", "investigation", "evidence"],
        category: "disk_forensics"
      },
      {
        query: "incident response threat hunting",
        keywords: ["incident", "response", "threat", "hunting", "investigation", "compromise"],
        category: "incident_response"
      },
      {
        query: "mobile device smartphone forensics",
        keywords: ["mobile", "smartphone", "android", "ios", "device", "cellebrite"],
        category: "mobile_forensics"
      },
      {
        query: "timeline analysis event correlation",
        keywords: ["timeline", "analysis", "correlation", "events", "plaso", "timesketch"],
        category: "timeline_analysis"
      },
      {
        query: "registry analysis windows artifacts",
        keywords: ["registry", "windows", "artifacts", "regripper", "hives", "keys"],
        category: "registry_analysis"
      },
      {
        query: "cloud forensics container analysis",
        keywords: ["cloud", "container", "docker", "virtualization", "aws", "azure"],
        category: "cloud_forensics"
      },
      {
        query: "blockchain cryptocurrency investigation",
        keywords: ["blockchain", "cryptocurrency", "bitcoin", "chainalysis", "transaction"],
        category: "blockchain_analysis"
      }
    ];

    console.log('[INIT] Efficient embedding comparison initialized');
  }

  /**
   * Reads and validates the JSON config into `this.config`.
   * @throws {Error} when the file is unreadable, is not JSON, lacks a `models` array,
   *   or lacks a non-empty `toolsYamlPath` string.
   */
  async loadConfig() {
    try {
      const configData = await fs.readFile(this.configPath, 'utf8');
      const config = JSON.parse(configData);

      if (!Array.isArray(config?.models)) {
        throw new Error(`"models" must be an array in ${this.configPath}`);
      }
      if (typeof config.toolsYamlPath !== 'string' || config.toolsYamlPath === '') {
        throw new Error(`"toolsYamlPath" must be a non-empty string in ${this.configPath}`);
      }

      this.config = config;
      console.log(`[CONFIG] Loaded ${this.config.models.length} models`);
    } catch (error) {
      console.error('[CONFIG] Failed to load configuration:', error.message);
      throw error;
    }
  }

  /**
   * Loads tools from `config.toolsYamlPath` into `this.tools`.
   *
   * The tool list is taken from `tools`, `entries`, `applications` or the document root,
   * and an object of tools is also accepted. Concepts and entries without a name/title
   * or a description/summary are dropped, and the remaining tools are normalized.
   *
   * @throws {Error} when the file cannot be read/parsed or yields no valid tools.
   */
  async loadTools() {
    try {
      const yamlContent = await fs.readFile(this.config.toolsYamlPath, 'utf8');
      // An empty YAML document parses to undefined/null; treat it as "no tools".
      const data = yaml.load(yamlContent) ?? {};

      // Extract tools (flexible - handle different YAML structures)
      let rawTools = data.tools || data.entries || data.applications || data;
      if (!Array.isArray(rawTools)) {
        rawTools = rawTools !== null && typeof rawTools === 'object' ? Object.values(rawTools) : [];
      }

      // Filter out concepts and ensure required fields
      this.tools = rawTools
        .filter((tool) =>
          tool &&
          tool.type !== 'concept' &&
          (tool.name || tool.title) &&
          (tool.description || tool.summary))
        .map((tool, index) => ({
          id: tool.id || tool.name || tool.title || `tool_${index}`,
          name: tool.name || tool.title,
          description: tool.description || tool.summary || '',
          tags: tool.tags || [],
          domains: tool.domains || tool.categories || [],
          phases: tool.phases || [],
          platforms: tool.platforms || [],
          type: tool.type || 'tool',
          skillLevel: tool.skillLevel,
          license: tool.license
        }));

      console.log(`[DATA] Loaded ${this.tools.length} tools from ${this.config.toolsYamlPath}`);

      // Map instead of a plain object so arbitrary domain names cannot collide with prototype keys.
      const domainCounts = new Map();
      for (const { domains } of this.tools) {
        if (!Array.isArray(domains)) continue;
        for (const domain of domains) {
          domainCounts.set(domain, (domainCounts.get(domain) ?? 0) + 1);
        }
      }

      const topDomains = [...domainCounts]
        .sort(([, a], [, b]) => b - a)
        .slice(0, 5)
        .map(([domain, count]) => `${domain}(${count})`)
        .join(', ');

      console.log(`[DATA] Top domains: ${topDomains}`);
      console.log(`[DATA] Sample tools: ${this.tools.slice(0, 3).map(t => t.name).join(', ')}`);

      if (this.tools.length === 0) {
        throw new Error('No valid tools found in YAML file');
      }
    } catch (error) {
      console.error('[DATA] Failed to load tools:', error.message);
      throw error;
    }
  }

  /**
   * Heuristic ground truth: ids of tools considered relevant to a test query.
   *
   * A tool is relevant when any keyword or any query word longer than 3 characters starts
   * a word in its metadata text, or when one of its domains contains the query category
   * with underscores turned into hyphens.
   *
   * NOTE(review): a single generic word (e.g. "analysis", "forensics") is enough to mark a
   * tool relevant, so most DFIR tools qualify for many queries. Consider requiring several
   * matches if the metrics look saturated.
   *
   * @param {{query: string, keywords: string[], category?: string}} query
   * @returns {string[]} `tool.id || tool.name` of each relevant tool.
   */
  findRelevantTools(query) {
    const queryLower = query.query.toLowerCase();
    // Compile patterns once per query rather than once per tool.
    const keywordPatterns = query.keywords.map((keyword) =>
      EfficientEmbeddingComparison.#wordStartPattern(keyword.toLowerCase()));
    const queryWordPatterns = queryLower
      .split(' ')
      .filter((word) => word.length > 3)
      .map((word) => EfficientEmbeddingComparison.#wordStartPattern(word));
    const categorySlug = query.category ? query.category.replaceAll('_', '-') : null;

    const relevantTools = this.tools.filter((tool) => {
      // Build searchable text from all tool metadata
      const toolText = [
        tool.name || '',
        tool.description || '',
        (tool.tags || []).join(' '),
        (tool.domains || []).join(' '),
        (tool.phases || []).join(' '),
        (tool.platforms || []).join(' ')
      ].join(' ').toLowerCase();

      const hasKeywordMatch = keywordPatterns.some((pattern) => pattern.test(toolText));
      const hasQueryWordMatch = queryWordPatterns.some((pattern) => pattern.test(toolText));
      const isDomainRelevant = categorySlug !== null && Array.isArray(tool.domains) &&
        tool.domains.some((domain) => String(domain).includes(categorySlug));

      return hasKeywordMatch || hasQueryWordMatch || isDomainRelevant;
    });

    console.log(` 🎯 Found ${relevantTools.length} relevant tools for "${query.query}"`);

    // Log some examples for debugging
    if (relevantTools.length > 0) {
      console.log(` 📋 Examples: ${relevantTools.slice(0, 3).map(t => t.name).join(', ')}`);
    }

    return relevantTools.map((tool) => tool.id || tool.name);
  }

  /**
   * Embeds all tools, then runs every test query and scores the top-20 search results.
   *
   * Queries with no heuristically relevant tools are skipped.
   *
   * @param {object} model
   * @returns {Promise<{results: object[], totalApiCalls: number}>} `totalApiCalls` counts
   *   query-embedding requests, including ones the cache may have served.
   */
  async testSearchPerformance(model) {
    console.log(` 🔍 Testing search performance...`);

    // Create embeddings for all tools
    const toolEmbeddings = await this.evaluator.createBatchEmbeddings(this.tools, model);

    const results = [];
    let totalApiCalls = 0;

    for (const testQuery of this.testQueries) {
      console.log(` 📋 Query: "${testQuery.query}"`);

      const queryEmbedding = await this.evaluator.getEmbedding(testQuery.query, model);
      totalApiCalls++;
      await this.evaluator.rateLimitedDelay(model);

      const relevantIds = this.findRelevantTools(testQuery);
      console.log(` 📊 Found ${relevantIds.length} relevant tools`);

      if (relevantIds.length === 0) {
        console.log(` ⚠️ No relevant tools found, skipping metrics calculation`);
        continue;
      }

      const searchResults = this.evaluator.searchSimilar(queryEmbedding, toolEmbeddings, 20);

      // Calculate metrics for different k values
      const metrics = {};
      for (const k of [1, 3, 5, 10]) {
        metrics[`k${k}`] = this.evaluator.calculateRetrievalMetrics(searchResults, relevantIds, k);
      }

      results.push({
        query: testQuery.query,
        category: testQuery.category,
        relevantCount: relevantIds.length,
        searchResults: searchResults.slice(0, 5), // Top 5 for display
        metrics
      });

      const relevantSet = new Set(relevantIds);
      console.log(` 🎯 Top results:`);
      searchResults.slice(0, 3).forEach((result, i) => {
        const isRelevant = relevantSet.has(result.id) ? '✓' : '✗';
        console.log(` ${i+1}. ${isRelevant} ${result.metadata.name} (${(result.similarity*100).toFixed(1)}%)`);
      });

      console.log(` 📈 P@5: ${(metrics.k5.precisionAtK*100).toFixed(1)}% | R@5: ${(metrics.k5.recallAtK*100).toFixed(1)}% | NDCG@5: ${(metrics.k5.ndcgAtK*100).toFixed(1)}%`);
    }

    return { results, totalApiCalls };
  }

  /**
   * Checks whether the model places synonyms closer to a primary term than unrelated terms.
   *
   * Each term is embedded exactly once. The averages are computed from the similarities
   * already measured, instead of re-fetching every term (and re-paying the rate-limit
   * delay) through calculateAvgSimilarity.
   *
   * @param {object} model
   * @returns {Promise<{accuracy: number, correctTests: number, totalTests: number, apiCalls: number}>}
   */
  async testSemanticUnderstanding(model) {
    console.log(` 🧠 Testing semantic understanding...`);

    const semanticTests = [
      {
        primary: "memory forensics",
        synonyms: ["RAM analysis", "volatile memory examination", "memory dump investigation"],
        unrelated: ["file compression", "web browser", "text editor"]
      },
      {
        primary: "network analysis",
        synonyms: ["packet inspection", "traffic monitoring", "protocol analysis"],
        unrelated: ["image editing", "music player", "calculator"]
      },
      {
        primary: "malware detection",
        synonyms: ["virus scanning", "threat identification", "malicious code analysis"],
        unrelated: ["video converter", "password manager", "calendar app"]
      }
    ];

    const mean = (values) => values.reduce((sum, value) => sum + value, 0) / values.length;

    let totalCorrect = 0;
    let totalTests = 0;
    let apiCalls = 0;

    for (const test of semanticTests) {
      console.log(` 🔤 Testing: "${test.primary}"`);

      const primaryEmbedding = await this.evaluator.getEmbedding(test.primary, model);
      apiCalls++;
      await this.evaluator.rateLimitedDelay(model);

      // Synonyms should be similar
      const synonymSims = [];
      for (const synonym of test.synonyms) {
        const synonymEmbedding = await this.evaluator.getEmbedding(synonym, model);
        apiCalls++;
        const synonymSim = this.evaluator.cosineSimilarity(primaryEmbedding, synonymEmbedding);
        synonymSims.push(synonymSim);
        console.log(` ✓ "${synonym}": ${(synonymSim*100).toFixed(1)}%`);
        await this.evaluator.rateLimitedDelay(model);
      }

      // Unrelated terms should be dissimilar
      const unrelatedSims = [];
      for (const unrelated of test.unrelated) {
        const unrelatedEmbedding = await this.evaluator.getEmbedding(unrelated, model);
        apiCalls++;
        const unrelatedSim = this.evaluator.cosineSimilarity(primaryEmbedding, unrelatedEmbedding);
        unrelatedSims.push(unrelatedSim);
        console.log(` ✗ "${unrelated}": ${(unrelatedSim*100).toFixed(1)}%`);
        await this.evaluator.rateLimitedDelay(model);
      }

      const avgSynonymSim = mean(synonymSims);
      const avgUnrelatedSim = mean(unrelatedSims);

      const isCorrect = avgSynonymSim > avgUnrelatedSim;
      if (isCorrect) totalCorrect++;
      totalTests++;

      console.log(` 📊 Synonyms: ${(avgSynonymSim*100).toFixed(1)}% | Unrelated: ${(avgUnrelatedSim*100).toFixed(1)}% ${isCorrect ? '✓' : '✗'}`);
    }

    return {
      accuracy: totalTests > 0 ? totalCorrect / totalTests : 0,
      correctTests: totalCorrect,
      totalTests,
      apiCalls
    };
  }

  /**
   * Mean cosine similarity between `baseEmbedding` and the embeddings of `terms`.
   * Fetches each term's embedding (cache first) and applies the rate-limit delay per term.
   *
   * @param {number[]} baseEmbedding
   * @param {string[]} terms
   * @param {object} model
   * @returns {Promise<number>} NaN when `terms` is empty.
   */
  async calculateAvgSimilarity(baseEmbedding, terms, model) {
    let totalSim = 0;

    for (const term of terms) {
      const embedding = await this.evaluator.getEmbedding(term, model);
      totalSim += this.evaluator.cosineSimilarity(baseEmbedding, embedding);
      await this.evaluator.rateLimitedDelay(model);
    }

    return totalSim / terms.length;
  }

  /**
   * Measures request latency on up to 10 tool texts (each capped at 500 chars).
   *
   * The evaluator's cache is temporarily replaced by a no-op cache so that every timed
   * request reaches the endpoint. Otherwise re-runs would time disk reads, not the model.
   * The real cache is restored afterwards, even when a request fails.
   *
   * @param {object} model
   * @returns {Promise<{avgLatency: number, minLatency: number, maxLatency: number,
   *   throughput: number, apiCalls: number}>} latencies in ms; throughput in req/s
   *   (0 when no latency could be measured).
   */
  async benchmarkPerformance(model) {
    console.log(` ⚡ Benchmarking performance...`);

    const testTexts = this.tools.slice(0, 10).map(tool => `${tool.name} ${tool.description}`.slice(0, 500));
    const times = [];
    let apiCalls = 0;

    console.log(` 🏃 Processing ${testTexts.length} texts...`);

    const realCache = this.evaluator.cache;
    this.evaluator.cache = {
      getCachedEmbedding: async () => null,
      setCachedEmbedding: async () => {}
    };
    try {
      for (const text of testTexts) {
        const start = performance.now();
        await this.evaluator.getEmbedding(text, model);
        times.push(performance.now() - start);
        apiCalls++;

        await this.evaluator.rateLimitedDelay(model);
      }
    } finally {
      this.evaluator.cache = realCache;
    }

    const avgTime = times.length > 0 ? times.reduce((a, b) => a + b, 0) / times.length : 0;
    const minTime = times.length > 0 ? Math.min(...times) : 0;
    const maxTime = times.length > 0 ? Math.max(...times) : 0;

    console.log(` 📊 Avg: ${avgTime.toFixed(0)}ms | Min: ${minTime.toFixed(0)}ms | Max: ${maxTime.toFixed(0)}ms`);

    return {
      avgLatency: avgTime,
      minLatency: minTime,
      maxLatency: maxTime,
      throughput: avgTime > 0 ? 1000 / avgTime : 0, // requests per second
      apiCalls
    };
  }

  /**
   * Runs the search, semantic and latency tests for one model.
   * @param {object} model
   * @returns {Promise<object>} per-test results plus total time (ms) and request count.
   * @throws rethrows the first failing test's error after logging it.
   */
  async testModel(model) {
    console.log(`\n🧪 Testing ${model.name} (${model.type})...`);

    const startTime = Date.now();
    let totalApiCalls = 0;

    try {
      const searchResults = await this.testSearchPerformance(model);
      totalApiCalls += searchResults.totalApiCalls;

      const semanticResults = await this.testSemanticUnderstanding(model);
      totalApiCalls += semanticResults.apiCalls;

      const perfResults = await this.benchmarkPerformance(model);
      totalApiCalls += perfResults.apiCalls;

      const totalTime = Date.now() - startTime;

      console.log(` ✅ ${model.name} completed in ${(totalTime/1000).toFixed(1)}s (${totalApiCalls} API calls)`);

      return {
        searchPerformance: searchResults.results,
        semanticUnderstanding: semanticResults,
        performance: perfResults,
        totalTime,
        totalApiCalls
      };
    } catch (error) {
      console.error(` ❌ ${model.name} failed:`, error.message);
      throw error;
    }
  }

  /**
   * Weighted overall score in [0, 1].
   *
   * Weights: P@5 25%, R@5 25%, NDCG@5 20%, semantic accuracy 20%, and throughput 10%
   * (throughput is normalized so that 10 req/s maps to 1). MRR is reported but not weighted.
   *
   * @param {object} results - output of testModel().
   * @returns {{overall: number, components: object, warning?: string}}
   */
  calculateOverallScore(results) {
    const searchMetrics = results.searchPerformance.filter(r => r.metrics && Object.keys(r.metrics).length > 0);

    if (searchMetrics.length === 0) {
      console.warn('⚠️ No search metrics available for scoring - may indicate relevance matching issues');
      return {
        overall: 0,
        components: {
          precision5: 0,
          recall5: 0,
          ndcg5: 0,
          mrr: 0,
          semanticAccuracy: results.semanticUnderstanding?.accuracy || 0,
          throughput: results.performance?.throughput || 0
        },
        warning: 'No search metrics available'
      };
    }

    console.log(`📊 Calculating score from ${searchMetrics.length} valid search results`);

    const averageOf = (field) =>
      searchMetrics.reduce((sum, r) => sum + (r.metrics.k5?.[field] || 0), 0) / searchMetrics.length;

    const avgPrecision5 = averageOf('precisionAtK');
    const avgRecall5 = averageOf('recallAtK');
    const avgNDCG5 = averageOf('ndcgAtK');
    const avgMRR = averageOf('mrr');

    const semanticAccuracy = results.semanticUnderstanding?.accuracy || 0;
    const throughput = results.performance?.throughput || 0;

    const weights = {
      precision: 0.25,
      recall: 0.25,
      ndcg: 0.20,
      semantic: 0.20,
      speed: 0.10
    };

    const normalizedThroughput = Math.min(throughput / 10, 1); // 10 req/s = 1.0

    const overall = (
      avgPrecision5 * weights.precision +
      avgRecall5 * weights.recall +
      avgNDCG5 * weights.ndcg +
      semanticAccuracy * weights.semantic +
      normalizedThroughput * weights.speed
    );

    return {
      overall,
      components: {
        precision5: avgPrecision5,
        recall5: avgRecall5,
        ndcg5: avgNDCG5,
        mrr: avgMRR,
        semanticAccuracy,
        throughput
      }
    };
  }

  /**
   * Prints rankings, per-metric breakdowns, the winner and a run summary.
   * @param {Array<{model: object, results: object}>} modelResults - must be non-empty.
   */
  printResults(modelResults) {
    console.log(`\n${'='.repeat(80)}`);
    console.log("🏆 EFFICIENT EMBEDDING MODEL COMPARISON RESULTS");
    console.log(`${'='.repeat(80)}`);

    const scores = modelResults.map(mr => ({
      model: mr.model,
      score: this.calculateOverallScore(mr.results),
      results: mr.results
    })).sort((a, b) => b.score.overall - a.score.overall);

    console.log(`\n🥇 OVERALL RANKINGS:`);
    scores.forEach((entry, index) => {
      console.log(` ${index + 1}. ${entry.model.name}: ${(entry.score.overall * 100).toFixed(1)}% overall`);
    });

    console.log(`\n📊 DETAILED METRICS:`);

    const percentSections = [
      [`\n 🎯 Search Performance (Precision@5):`, 'precision5'],
      [`\n 🔍 Search Performance (Recall@5):`, 'recall5'],
      [`\n 📈 Search Quality (NDCG@5):`, 'ndcg5'],
      [`\n 🧠 Semantic Understanding:`, 'semanticAccuracy']
    ];
    for (const [heading, key] of percentSections) {
      console.log(heading);
      for (const entry of scores) {
        console.log(` ${entry.model.name}: ${(entry.score.components[key] * 100).toFixed(1)}%`);
      }
    }

    console.log(`\n ⚡ Performance (req/s):`);
    for (const entry of scores) {
      console.log(` ${entry.model.name}: ${entry.score.components.throughput.toFixed(1)} req/s`);
    }

    const winner = scores[0];
    console.log(`\n🏆 WINNER: ${winner.model.name}`);
    console.log(` Overall Score: ${(winner.score.overall * 100).toFixed(1)}%`);
    console.log(` Best for: ${this.getBestUseCase(winner.score.components)}`);

    const totalQueries = modelResults[0]?.results.searchPerformance.length || 0;
    const totalTools = this.tools.length;

    console.log(`\n📋 Test Summary:`);
    console.log(` Tools tested: ${totalTools}`);
    console.log(` Search queries: ${totalQueries}`);
    console.log(` Models compared: ${scores.length}`);
    console.log(` Total API calls: ${modelResults.reduce((sum, mr) => sum + mr.results.totalApiCalls, 0)}`);
  }

  /**
   * Short label of the winner's strengths based on fixed thresholds.
   * @param {object} components - `components` from calculateOverallScore().
   * @returns {string}
   */
  getBestUseCase(components) {
    const strengths = [];
    if (components.precision5 > 0.7) strengths.push("High precision");
    if (components.recall5 > 0.7) strengths.push("High recall");
    if (components.semanticAccuracy > 0.8) strengths.push("Semantic understanding");
    if (components.throughput > 5) strengths.push("High performance");

    return strengths.length > 0 ? strengths.join(", ") : "General purpose";
  }

  /**
   * Entry point. Models that fail are skipped. On overall failure, troubleshooting hints
   * are printed and `process.exitCode` is set to 1, so shells and CI can see the failure.
   */
  async run() {
    try {
      console.log("🚀 EFFICIENT EMBEDDING MODEL COMPARISON");
      console.log("=====================================");

      await this.loadConfig();
      await this.loadTools();

      console.log(`\n📋 Test Overview:`);
      console.log(` Models: ${this.config.models.length}`);
      console.log(` Tools: ${this.tools.length}`);
      console.log(` Search queries: ${this.testQueries.length}`);
      console.log(` Cache: ${this.evaluator.cache.cacheDir}`);

      const modelResults = [];

      for (const model of this.config.models) {
        try {
          const results = await this.testModel(model);
          modelResults.push({ model, results });
        } catch (error) {
          console.error(`❌ Skipping ${model.name}: ${error.message}`);
        }
      }

      if (modelResults.length === 0) {
        throw new Error('No models completed testing successfully');
      }

      this.printResults(modelResults);
    } catch (error) {
      console.error('\n❌ Test failed:', error.message);
      console.log('\nDebugging steps:');
      console.log('1. Verify tools.yaml exists and contains valid tool data');
      console.log('2. Check model endpoints are accessible');
      console.log('3. For Ollama: ensure models are pulled and ollama serve is running');
      console.log('4. For Mistral: verify AI_EMBEDDINGS_API_KEY environment variable');
      process.exitCode = 1;
    }
  }
}
|
|
|
|
// Execute. `--config=<path>` selects the config file. Everything after the first "=" is the
// path, so paths that themselves contain "=" are kept intact; an empty value uses the default.
const configArg = process.argv.find((arg) => arg.startsWith('--config='));
const configPath = configArg?.slice('--config='.length) || './embedding-test-config.json';

(async () => {
  const comparison = new EfficientEmbeddingComparison(configPath);
  await comparison.run();
})().catch((error) => {
  // Only reached for failures outside run()'s own error handling (e.g. construction).
  console.error(error);
  process.exitCode = 1;
});