fix context optimizer to search downwards when baseline uses offload
@@ -332,37 +332,63 @@ def find_optimal_context(model_name: str, max_turns: Optional[int], overhead_gb:
             'error': 'Unknown failure during baseline test'
         }
 
-    # Turn 2: Test a higher context to calculate VRAM/context ratio
-    # Try doubling the context or 32K, whichever is smaller
-    test_ctx_2 = min(baseline_ctx * 2, 32768, max_context)
-    if test_ctx_2 <= baseline_ctx:
-        test_ctx_2 = min(baseline_ctx + 16384, max_context)
-    # Round to multiple of 2048
-    test_ctx_2 = (test_ctx_2 // 2048) * 2048
+    # Turn 2: Test a different context to calculate VRAM/context ratio
+    # If baseline already shows offloading OR is at max_context, test LOWER
+    # Otherwise, test HIGHER
+    baseline_has_offload = results[0]['offload_pct'] > 0
+
+    if baseline_has_offload or baseline_ctx >= max_context:
+        # Test lower context to find where it fits
+        test_ctx_2 = max(8192, baseline_ctx // 2)
+        # Round to multiple of 2048
+        test_ctx_2 = (test_ctx_2 // 2048) * 2048
+        calibration_label = "lower bound"
+    else:
+        # Try doubling the context or 32K, whichever is smaller
+        test_ctx_2 = min(baseline_ctx * 2, 32768, max_context)
+        if test_ctx_2 <= baseline_ctx:
+            test_ctx_2 = min(baseline_ctx + 16384, max_context)
+        # Round to multiple of 2048
+        test_ctx_2 = (test_ctx_2 // 2048) * 2048
+        calibration_label = "upper bound"
 
     turn_label = f"Turn 2/{max_turns}" if max_turns else "Turn 2"
-    print(f"{turn_label}: Testing num_ctx={test_ctx_2:,} (calibration)...", end=' ', flush=True)
+    print(f"{turn_label}: Testing num_ctx={test_ctx_2:,} ({calibration_label})...", end=' ', flush=True)
     result = test_context_size(model_name, test_ctx_2)
 
     if result and 'error' not in result:
         results.append(result)
         print(f"✓ VRAM: {result['vram_gb']:.2f} GB, Offload: {result['offload_pct']:.1f}% CPU" if result['offload_pct'] > 0 else f"✓ VRAM: {result['vram_gb']:.2f} GB, Offload: GPU only")
 
-        # Calculate VRAM per 1K context tokens
+        # Calculate VRAM per 1K context tokens (works for both higher and lower tests)
         vram_diff = result['vram_gb'] - baseline_vram
         ctx_diff = test_ctx_2 - baseline_ctx
-        if ctx_diff > 0:
-            vram_per_1k_ctx = (vram_diff / ctx_diff) * 1000
+        if ctx_diff != 0:  # Can be positive or negative
+            vram_per_1k_ctx = abs(vram_diff / ctx_diff) * 1000
             print(f" → Estimated VRAM usage: {vram_per_1k_ctx:.4f} GB per 1K context")
 
-            # Predict optimal context size
+            # Predict optimal context size based on available VRAM
             if target_vram and vram_per_1k_ctx > 0:
-                available_for_ctx = target_vram - baseline_vram
-                estimated_additional_ctx = (available_for_ctx / vram_per_1k_ctx) * 1000
-                predicted_optimal = baseline_ctx + int(estimated_additional_ctx)
-                # Round to multiple of 2048
+                # Find which result has no offload (if any) to use as reference
+                ref_result = result if result['offload_pct'] == 0 else (results[0] if results[0]['offload_pct'] == 0 else None)
+
+                if ref_result:
+                    # We have a point that fits - extrapolate from there
+                    available_for_ctx = target_vram - ref_result['vram_gb']
+                    estimated_additional_ctx = (available_for_ctx / vram_per_1k_ctx) * 1000
+                    predicted_optimal = ref_result['num_ctx'] + int(estimated_additional_ctx)
+                else:
+                    # Neither fits - need to find what would fit
+                    # Start from the smaller test and work backwards
+                    smaller_result = results[0] if results[0]['num_ctx'] < result['num_ctx'] else result
+                    vram_needed_reduction = smaller_result['vram_gb'] - target_vram
+                    ctx_reduction_needed = (vram_needed_reduction / vram_per_1k_ctx) * 1000
+                    predicted_optimal = smaller_result['num_ctx'] - int(ctx_reduction_needed)
+
+                # Round to multiple of 2048 and clamp to valid range
                 predicted_optimal = (predicted_optimal // 2048) * 2048
-                predicted_optimal = max(baseline_ctx, min(predicted_optimal, max_context))
+                predicted_optimal = max(2048, min(predicted_optimal, max_context))
 
                 print(f" → Predicted optimal context: {predicted_optimal:,}")
     else:
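The calibration math in this hunk reduces to a two-point linear fit: measure VRAM at two context sizes, take the absolute slope per 1K tokens, and extrapolate from whichever measurement still fit on the GPU. A minimal standalone sketch of that arithmetic follows; predict_optimal_ctx and its example numbers are illustrative, not part of the actual script.

    # Sketch of the two-point extrapolation above (names are illustrative).
    # Given two (num_ctx, vram_gb) measurements, estimate GB per 1K tokens
    # of context, then project how much context fits in a VRAM budget.
    def predict_optimal_ctx(p1, p2, target_vram_gb, max_context=131072):
        """p1, p2: (num_ctx, vram_gb) tuples; p1 must be the one that fit."""
        (ctx_a, vram_a), (ctx_b, vram_b) = p1, p2
        # Slope works whether the second probe was higher or lower than baseline
        vram_per_1k = abs((vram_b - vram_a) / (ctx_b - ctx_a)) * 1000
        # Extrapolate from the measurement that fit in VRAM (here: p1)
        headroom_gb = target_vram_gb - vram_a
        predicted = ctx_a + int((headroom_gb / vram_per_1k) * 1000)
        # Round down to a multiple of 2048 and clamp, as the diff does
        predicted = (predicted // 2048) * 2048
        return max(2048, min(predicted, max_context))

    # Example: 8K ctx → 9.5 GB, 16K ctx → 11.1 GB, 16 GB budget
    print(predict_optimal_ctx((8192, 9.5), (16384, 11.1), 16.0))  # → 40960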
@@ -384,8 +410,15 @@ def find_optimal_context(model_name: str, max_turns: Optional[int], overhead_gb:
         predicted_optimal = None
 
     # Remaining turns: Test predicted optimal or use VRAM-based refinement
-    min_ctx = baseline_ctx
-    max_ctx = max_context
+    # Initialize search bounds based on whether baseline has offload
+    if baseline_has_offload:
+        # Search downward from baseline to find what fits
+        min_ctx = 2048  # Minimum practical context
+        max_ctx = baseline_ctx
+    else:
+        # Search upward from baseline to find max that fits
+        min_ctx = baseline_ctx
+        max_ctx = max_context
 
     turn = 2
     while True:
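The bounds set here feed the refinement loop that begins at while True:, whose body is outside this hunk. Assuming that loop narrows the range probe by probe, the search amounts to a bisection over 2048-token steps. A hypothetical sketch, with fits_in_vram standing in for a probe such as test_context_size:

    # Hypothetical sketch of how (min_ctx, max_ctx) could drive refinement;
    # the real loop body is not shown in this commit.
    def refine(min_ctx, max_ctx, fits_in_vram, step=2048):
        """Bisect on context size; return the largest multiple of `step`
        in [min_ctx, max_ctx] for which fits_in_vram(ctx) is True."""
        best = None
        lo, hi = min_ctx // step, max_ctx // step
        while lo <= hi:
            mid = (lo + hi) // 2
            ctx = mid * step
            if fits_in_vram(ctx):
                best = ctx       # fits: remember it and search higher
                lo = mid + 1
            else:
                hi = mid - 1     # offloads: search lower
        return best

    # Example with a fake probe: everything up to 24K "fits"
    print(refine(2048, 65536, lambda c: c <= 24576))  # → 24576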