fix context optimizer to search dowwards when baseline uses offload
This commit is contained in:
@@ -332,37 +332,63 @@ def find_optimal_context(model_name: str, max_turns: Optional[int], overhead_gb:
|
||||
'error': 'Unknown failure during baseline test'
|
||||
}
|
||||
|
||||
# Turn 2: Test a higher context to calculate VRAM/context ratio
|
||||
# Try doubling the context or 32K, whichever is smaller
|
||||
test_ctx_2 = min(baseline_ctx * 2, 32768, max_context)
|
||||
if test_ctx_2 <= baseline_ctx:
|
||||
test_ctx_2 = min(baseline_ctx + 16384, max_context)
|
||||
# Round to multiple of 2048
|
||||
test_ctx_2 = (test_ctx_2 // 2048) * 2048
|
||||
# Turn 2: Test a different context to calculate VRAM/context ratio
|
||||
# If baseline already shows offloading OR is at max_context, test LOWER
|
||||
# Otherwise, test HIGHER
|
||||
baseline_has_offload = results[0]['offload_pct'] > 0
|
||||
|
||||
if baseline_has_offload or baseline_ctx >= max_context:
|
||||
# Test lower context to find where it fits
|
||||
test_ctx_2 = max(8192, baseline_ctx // 2)
|
||||
# Round to multiple of 2048
|
||||
test_ctx_2 = (test_ctx_2 // 2048) * 2048
|
||||
calibration_label = "lower bound"
|
||||
else:
|
||||
# Try doubling the context or 32K, whichever is smaller
|
||||
test_ctx_2 = min(baseline_ctx * 2, 32768, max_context)
|
||||
if test_ctx_2 <= baseline_ctx:
|
||||
test_ctx_2 = min(baseline_ctx + 16384, max_context)
|
||||
# Round to multiple of 2048
|
||||
test_ctx_2 = (test_ctx_2 // 2048) * 2048
|
||||
calibration_label = "upper bound"
|
||||
|
||||
turn_label = f"Turn 2/{max_turns}" if max_turns else "Turn 2"
|
||||
print(f"{turn_label}: Testing num_ctx={test_ctx_2:,} (calibration)...", end=' ', flush=True)
|
||||
print(f"{turn_label}: Testing num_ctx={test_ctx_2:,} ({calibration_label})...", end=' ', flush=True)
|
||||
result = test_context_size(model_name, test_ctx_2)
|
||||
|
||||
if result and 'error' not in result:
|
||||
results.append(result)
|
||||
print(f"✓ VRAM: {result['vram_gb']:.2f} GB, Offload: {result['offload_pct']:.1f}% CPU" if result['offload_pct'] > 0 else f"✓ VRAM: {result['vram_gb']:.2f} GB, Offload: GPU only")
|
||||
|
||||
# Calculate VRAM per 1K context tokens
|
||||
# Calculate VRAM per 1K context tokens (works for both higher and lower tests)
|
||||
vram_diff = result['vram_gb'] - baseline_vram
|
||||
ctx_diff = test_ctx_2 - baseline_ctx
|
||||
if ctx_diff > 0:
|
||||
vram_per_1k_ctx = (vram_diff / ctx_diff) * 1000
|
||||
|
||||
if ctx_diff != 0: # Can be positive or negative
|
||||
vram_per_1k_ctx = abs(vram_diff / ctx_diff) * 1000
|
||||
print(f" → Estimated VRAM usage: {vram_per_1k_ctx:.4f} GB per 1K context")
|
||||
|
||||
# Predict optimal context size
|
||||
# Predict optimal context size based on available VRAM
|
||||
if target_vram and vram_per_1k_ctx > 0:
|
||||
available_for_ctx = target_vram - baseline_vram
|
||||
estimated_additional_ctx = (available_for_ctx / vram_per_1k_ctx) * 1000
|
||||
predicted_optimal = baseline_ctx + int(estimated_additional_ctx)
|
||||
# Round to multiple of 2048
|
||||
# Find which result has no offload (if any) to use as reference
|
||||
ref_result = result if result['offload_pct'] == 0 else (results[0] if results[0]['offload_pct'] == 0 else None)
|
||||
|
||||
if ref_result:
|
||||
# We have a point that fits - extrapolate from there
|
||||
available_for_ctx = target_vram - ref_result['vram_gb']
|
||||
estimated_additional_ctx = (available_for_ctx / vram_per_1k_ctx) * 1000
|
||||
predicted_optimal = ref_result['num_ctx'] + int(estimated_additional_ctx)
|
||||
else:
|
||||
# Neither fits - need to find what would fit
|
||||
# Start from the smaller test and work backwards
|
||||
smaller_result = results[0] if results[0]['num_ctx'] < result['num_ctx'] else result
|
||||
vram_needed_reduction = smaller_result['vram_gb'] - target_vram
|
||||
ctx_reduction_needed = (vram_needed_reduction / vram_per_1k_ctx) * 1000
|
||||
predicted_optimal = smaller_result['num_ctx'] - int(ctx_reduction_needed)
|
||||
|
||||
# Round to multiple of 2048 and clamp to valid range
|
||||
predicted_optimal = (predicted_optimal // 2048) * 2048
|
||||
predicted_optimal = max(baseline_ctx, min(predicted_optimal, max_context))
|
||||
predicted_optimal = max(2048, min(predicted_optimal, max_context))
|
||||
|
||||
print(f" → Predicted optimal context: {predicted_optimal:,}")
|
||||
else:
|
||||
@@ -384,8 +410,15 @@ def find_optimal_context(model_name: str, max_turns: Optional[int], overhead_gb:
|
||||
predicted_optimal = None
|
||||
|
||||
# Remaining turns: Test predicted optimal or use VRAM-based refinement
|
||||
min_ctx = baseline_ctx
|
||||
max_ctx = max_context
|
||||
# Initialize search bounds based on whether baseline has offload
|
||||
if baseline_has_offload:
|
||||
# Search downward from baseline to find what fits
|
||||
min_ctx = 2048 # Minimum practical context
|
||||
max_ctx = baseline_ctx
|
||||
else:
|
||||
# Search upward from baseline to find max that fits
|
||||
min_ctx = baseline_ctx
|
||||
max_ctx = max_context
|
||||
|
||||
turn = 2
|
||||
while True:
|
||||
|
||||
Reference in New Issue
Block a user