From 70d2ac8d368fb2ef8ab94d530d6feb3ce062e9d9bc4f65f7c6ae08dcdd5947b8 Mon Sep 17 00:00:00 2001
From: mstoeck3
Date: Mon, 19 Jan 2026 12:23:10 +0100
Subject: [PATCH] fix context optimizer to search downwards when baseline uses offload

---
 scripts/context-optimizer.py | 71 ++++++++++++++++++++++++++----------
 1 file changed, 52 insertions(+), 19 deletions(-)

diff --git a/scripts/context-optimizer.py b/scripts/context-optimizer.py
index 9866abd..84ddd6f 100755
--- a/scripts/context-optimizer.py
+++ b/scripts/context-optimizer.py
@@ -332,37 +332,63 @@ def find_optimal_context(model_name: str, max_turns: Optional[int], overhead_gb:
             'error': 'Unknown failure during baseline test'
         }
 
-    # Turn 2: Test a higher context to calculate VRAM/context ratio
-    # Try doubling the context or 32K, whichever is smaller
-    test_ctx_2 = min(baseline_ctx * 2, 32768, max_context)
-    if test_ctx_2 <= baseline_ctx:
-        test_ctx_2 = min(baseline_ctx + 16384, max_context)
-    # Round to multiple of 2048
-    test_ctx_2 = (test_ctx_2 // 2048) * 2048
+    # Turn 2: Test a different context to calculate VRAM/context ratio
+    # If baseline already shows offloading OR is at max_context, test LOWER
+    # Otherwise, test HIGHER
+    baseline_has_offload = results[0]['offload_pct'] > 0
+
+    if baseline_has_offload or baseline_ctx >= max_context:
+        # Test lower context to find where it fits
+        test_ctx_2 = max(8192, baseline_ctx // 2)
+        # Round to multiple of 2048
+        test_ctx_2 = (test_ctx_2 // 2048) * 2048
+        calibration_label = "lower bound"
+    else:
+        # Try doubling the context or 32K, whichever is smaller
+        test_ctx_2 = min(baseline_ctx * 2, 32768, max_context)
+        if test_ctx_2 <= baseline_ctx:
+            test_ctx_2 = min(baseline_ctx + 16384, max_context)
+        # Round to multiple of 2048
+        test_ctx_2 = (test_ctx_2 // 2048) * 2048
+        calibration_label = "upper bound"
 
     turn_label = f"Turn 2/{max_turns}" if max_turns else "Turn 2"
-    print(f"{turn_label}: Testing num_ctx={test_ctx_2:,} (calibration)...", end=' ', flush=True)
+    print(f"{turn_label}: Testing num_ctx={test_ctx_2:,} ({calibration_label})...", end=' ', flush=True)
 
     result = test_context_size(model_name, test_ctx_2)
     if result and 'error' not in result:
         results.append(result)
         print(f"✓ VRAM: {result['vram_gb']:.2f} GB, Offload: {result['offload_pct']:.1f}% CPU" if result['offload_pct'] > 0 else f"✓ VRAM: {result['vram_gb']:.2f} GB, Offload: GPU only")
 
-        # Calculate VRAM per 1K context tokens
+        # Calculate VRAM per 1K context tokens (works for both higher and lower tests)
         vram_diff = result['vram_gb'] - baseline_vram
         ctx_diff = test_ctx_2 - baseline_ctx
-        if ctx_diff > 0:
-            vram_per_1k_ctx = (vram_diff / ctx_diff) * 1000
+
+        if ctx_diff != 0:  # Can be positive or negative
+            vram_per_1k_ctx = abs(vram_diff / ctx_diff) * 1000
             print(f"  → Estimated VRAM usage: {vram_per_1k_ctx:.4f} GB per 1K context")
 
-            # Predict optimal context size
+            # Predict optimal context size based on available VRAM
            if target_vram and vram_per_1k_ctx > 0:
-                available_for_ctx = target_vram - baseline_vram
-                estimated_additional_ctx = (available_for_ctx / vram_per_1k_ctx) * 1000
-                predicted_optimal = baseline_ctx + int(estimated_additional_ctx)
-                # Round to multiple of 2048
+                # Find which result has no offload (if any) to use as reference
+                ref_result = result if result['offload_pct'] == 0 else (results[0] if results[0]['offload_pct'] == 0 else None)
+
+                if ref_result:
+                    # We have a point that fits - extrapolate from there
+                    available_for_ctx = target_vram - ref_result['vram_gb']
+                    estimated_additional_ctx = (available_for_ctx / vram_per_1k_ctx) * 1000
+                    predicted_optimal = ref_result['num_ctx'] + int(estimated_additional_ctx)
+                else:
+                    # Neither fits - need to find what would fit
+                    # Start from the smaller test and work backwards
+                    smaller_result = results[0] if results[0]['num_ctx'] < result['num_ctx'] else result
+                    vram_needed_reduction = smaller_result['vram_gb'] - target_vram
+                    ctx_reduction_needed = (vram_needed_reduction / vram_per_1k_ctx) * 1000
+                    predicted_optimal = smaller_result['num_ctx'] - int(ctx_reduction_needed)
+
+                # Round to multiple of 2048 and clamp to valid range
                 predicted_optimal = (predicted_optimal // 2048) * 2048
-                predicted_optimal = max(baseline_ctx, min(predicted_optimal, max_context))
+                predicted_optimal = max(2048, min(predicted_optimal, max_context))
 
                 print(f"  → Predicted optimal context: {predicted_optimal:,}")
             else:
@@ -384,8 +410,15 @@ def find_optimal_context(model_name: str, max_turns: Optional[int], overhead_gb:
         predicted_optimal = None
 
     # Remaining turns: Test predicted optimal or use VRAM-based refinement
-    min_ctx = baseline_ctx
-    max_ctx = max_context
+    # Initialize search bounds based on whether baseline has offload
+    if baseline_has_offload:
+        # Search downward from baseline to find what fits
+        min_ctx = 2048  # Minimum practical context
+        max_ctx = baseline_ctx
+    else:
+        # Search upward from baseline to find max that fits
+        min_ctx = baseline_ctx
+        max_ctx = max_context
 
     turn = 2
     while True:
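
Note for reviewers (not part of the patch): the prediction step is a
two-point linear model of VRAM vs. context length. The sketch below
restates it standalone under stated assumptions: a measurement is a dict
with 'num_ctx', 'vram_gb', and 'offload_pct' keys (mirroring what
test_context_size() appears to return), and predict_optimal_ctx plus the
example numbers are hypothetical.

def predict_optimal_ctx(baseline: dict, probe: dict,
                        target_vram: float, max_context: int) -> int:
    """Extrapolate the largest num_ctx expected to fit in target_vram GB."""
    ctx_diff = probe['num_ctx'] - baseline['num_ctx']
    if ctx_diff == 0:
        raise ValueError("need two distinct context sizes")

    # GB of VRAM per 1K tokens of context. abs() keeps the slope valid
    # whether the probe ran above or below the baseline, which is what
    # lets the patch calibrate downwards when the baseline offloads.
    vram_per_1k = abs((probe['vram_gb'] - baseline['vram_gb']) / ctx_diff) * 1000

    # Prefer a reference point that ran fully on the GPU (0% offload).
    on_gpu = [r for r in (probe, baseline) if r['offload_pct'] == 0]
    if on_gpu:
        ref = on_gpu[0]
        headroom_gb = target_vram - ref['vram_gb']
        predicted = ref['num_ctx'] + int(headroom_gb / vram_per_1k * 1000)
    else:
        # Neither measurement fits: shrink from the smaller one.
        ref = min(probe, baseline, key=lambda r: r['num_ctx'])
        excess_gb = ref['vram_gb'] - target_vram
        predicted = ref['num_ctx'] - int(excess_gb / vram_per_1k * 1000)

    # Round down to a multiple of 2048 and clamp, as the patch does.
    predicted = (predicted // 2048) * 2048
    return max(2048, min(predicted, max_context))


if __name__ == '__main__':
    # Made-up numbers: a 16K baseline offloading 12% to CPU and an 8K
    # downward probe that fits entirely on the GPU.
    baseline = {'num_ctx': 16384, 'vram_gb': 11.8, 'offload_pct': 12.0}
    probe = {'num_ctx': 8192, 'vram_gb': 9.4, 'offload_pct': 0.0}
    print(predict_optimal_ctx(baseline, probe, target_vram=11.0,
                              max_context=131072))  # prints 12288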