data-model #2

Merged
mstoeck3 merged 20 commits from data-model into main 2025-09-17 21:56:18 +00:00
3 changed files with 149 additions and 493 deletions
Showing only changes of commit d0ee415f0d

app.py (311 changed lines)

@ -18,38 +18,29 @@ from utils.helpers import is_valid_target
app = Flask(__name__) app = Flask(__name__)
# Use centralized configuration for Flask settings
app.config['SECRET_KEY'] = config.flask_secret_key app.config['SECRET_KEY'] = config.flask_secret_key
app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=config.flask_permanent_session_lifetime_hours) app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=config.flask_permanent_session_lifetime_hours)
def get_user_scanner(): def get_user_scanner():
""" """
- Retrieves the scanner for the current session, or creates a new session only if none exists.
+ Retrieves the scanner for the current session, or creates a new one if none exists.
""" """
current_flask_session_id = session.get('dnsrecon_session_id') current_flask_session_id = session.get('dnsrecon_session_id')
# Try to get existing session first
if current_flask_session_id: if current_flask_session_id:
existing_scanner = session_manager.get_session(current_flask_session_id) existing_scanner = session_manager.get_session(current_flask_session_id)
if existing_scanner: if existing_scanner:
#print(f"Reusing existing session: {current_flask_session_id}")
return current_flask_session_id, existing_scanner return current_flask_session_id, existing_scanner
else:
print(f"Session {current_flask_session_id} expired, will create new one")
# Only create new session if we absolutely don't have one
print("Creating new session (no valid session found)")
new_session_id = session_manager.create_session() new_session_id = session_manager.create_session()
new_scanner = session_manager.get_session(new_session_id) new_scanner = session_manager.get_session(new_session_id)
if not new_scanner: if not new_scanner:
raise Exception("Failed to create new scanner session") raise Exception("Failed to create new scanner session")
# Store in Flask session
session['dnsrecon_session_id'] = new_session_id session['dnsrecon_session_id'] = new_session_id
session.permanent = True session.permanent = True
print(f"Created new session: {new_session_id}")
return new_session_id, new_scanner return new_session_id, new_scanner
@app.route('/') @app.route('/')
@ -61,11 +52,8 @@ def index():
@app.route('/api/scan/start', methods=['POST']) @app.route('/api/scan/start', methods=['POST'])
def start_scan(): def start_scan():
""" """
- FIXED: Start a new reconnaissance scan while preserving session configuration.
- Only clears graph data, not the entire session with API keys.
+ Starts a new reconnaissance scan.
""" """
print("=== API: /api/scan/start called ===")
try: try:
data = request.get_json() data = request.get_json()
if not data or 'target' not in data: if not data or 'target' not in data:
@ -76,25 +64,18 @@ def start_scan():
clear_graph = data.get('clear_graph', True) clear_graph = data.get('clear_graph', True)
force_rescan_target = data.get('force_rescan_target', None) force_rescan_target = data.get('force_rescan_target', None)
print(f"Parsed - target: '{target}', max_depth: {max_depth}, clear_graph: {clear_graph}, force_rescan: {force_rescan_target}")
# Validation
if not target: if not target:
return jsonify({'success': False, 'error': 'Target cannot be empty'}), 400 return jsonify({'success': False, 'error': 'Target cannot be empty'}), 400
if not is_valid_target(target): if not is_valid_target(target):
- return jsonify({'success': False, 'error': 'Invalid target format. Please enter a valid domain or IP address.'}), 400
+ return jsonify({'success': False, 'error': 'Invalid target format.'}), 400
if not isinstance(max_depth, int) or not 1 <= max_depth <= 5: if not isinstance(max_depth, int) or not 1 <= max_depth <= 5:
return jsonify({'success': False, 'error': 'Max depth must be an integer between 1 and 5'}), 400 return jsonify({'success': False, 'error': 'Max depth must be an integer between 1 and 5'}), 400
# FIXED: Always reuse existing session, preserve API keys
user_session_id, scanner = get_user_scanner() user_session_id, scanner = get_user_scanner()
if not scanner: if not scanner:
return jsonify({'success': False, 'error': 'Failed to get scanner instance.'}), 500 return jsonify({'success': False, 'error': 'Failed to get scanner instance.'}), 500
print(f"Using scanner {id(scanner)} in session {user_session_id}")
# FIXED: Pass clear_graph flag to scanner, let it handle graph clearing internally
success = scanner.start_scan(target, max_depth, clear_graph=clear_graph, force_rescan_target=force_rescan_target) success = scanner.start_scan(target, max_depth, clear_graph=clear_graph, force_rescan_target=force_rescan_target)
if success: if success:
@ -102,8 +83,7 @@ def start_scan():
'success': True, 'success': True,
'message': 'Scan started successfully', 'message': 'Scan started successfully',
'scan_id': scanner.logger.session_id, 'scan_id': scanner.logger.session_id,
'user_session_id': user_session_id, 'user_session_id': user_session_id
'available_providers': [p.get_name() for p in scanner.providers] # Show which providers are active
}) })
else: else:
return jsonify({ return jsonify({
@ -112,206 +92,98 @@ def start_scan():
}), 409 }), 409
except Exception as e: except Exception as e:
print(f"ERROR: Exception in start_scan endpoint: {e}")
traceback.print_exc() traceback.print_exc()
return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500 return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500
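The trimmed endpoint still accepts the same JSON payload as before. A minimal client sketch for reference; the base URL is an assumption (the real host and port come from config.flask_host / config.flask_port):

    import requests

    BASE = "http://localhost:5000"  # hypothetical; substitute the configured host/port

    resp = requests.post(f"{BASE}/api/scan/start", json={
        "target": "example.com",      # domain or IP, validated by is_valid_target()
        "max_depth": 2,               # must be an integer between 1 and 5
        "clear_graph": True,          # optional, defaults to True
        "force_rescan_target": None,  # optional
    })
    # On success the response carries 'scan_id' and 'user_session_id'
    print(resp.status_code, resp.json())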
@app.route('/api/scan/stop', methods=['POST']) @app.route('/api/scan/stop', methods=['POST'])
def stop_scan(): def stop_scan():
"""Stop the current scan with immediate GUI feedback.""" """Stop the current scan."""
print("=== API: /api/scan/stop called ===")
try: try:
# Get user-specific scanner
user_session_id, scanner = get_user_scanner() user_session_id, scanner = get_user_scanner()
print(f"Stopping scan for session: {user_session_id}")
if not scanner: if not scanner:
- return jsonify({
- 'success': False,
- 'error': 'No scanner found for session'
- }), 404
+ return jsonify({'success': False, 'error': 'No scanner found for session'}), 404
# Ensure session ID is set
if not scanner.session_id: if not scanner.session_id:
scanner.session_id = user_session_id scanner.session_id = user_session_id
- # Use the stop mechanism
- success = scanner.stop_scan()
+ scanner.stop_scan()
# Also set the Redis stop signal directly for extra reliability
session_manager.set_stop_signal(user_session_id) session_manager.set_stop_signal(user_session_id)
# Force immediate status update
session_manager.update_scanner_status(user_session_id, 'stopped') session_manager.update_scanner_status(user_session_id, 'stopped')
# Update the full scanner state
session_manager.update_session_scanner(user_session_id, scanner) session_manager.update_session_scanner(user_session_id, scanner)
print(f"Stop scan completed. Success: {success}, Scanner status: {scanner.status}")
return jsonify({ return jsonify({
'success': True, 'success': True,
'message': 'Scan stop requested - termination initiated', 'message': 'Scan stop requested',
'user_session_id': user_session_id, 'user_session_id': user_session_id
'scanner_status': scanner.status,
'stop_method': 'cross_process'
}) })
except Exception as e: except Exception as e:
print(f"ERROR: Exception in stop_scan endpoint: {e}")
traceback.print_exc() traceback.print_exc()
- return jsonify({
- 'success': False,
- 'error': f'Internal server error: {str(e)}'
- }), 500
+ return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500
@app.route('/api/scan/status', methods=['GET']) @app.route('/api/scan/status', methods=['GET'])
def get_scan_status(): def get_scan_status():
"""Get current scan status with error handling.""" """Get current scan status."""
try: try:
# Get user-specific scanner
user_session_id, scanner = get_user_scanner() user_session_id, scanner = get_user_scanner()
if not scanner: if not scanner:
# Return default idle status if no scanner
return jsonify({ return jsonify({
'success': True, 'success': True,
'status': { 'status': {
'status': 'idle', 'status': 'idle', 'target_domain': None, 'current_depth': 0,
'target_domain': None, 'max_depth': 0, 'progress_percentage': 0.0,
'current_depth': 0,
'max_depth': 0,
'current_indicator': '',
'total_indicators_found': 0,
'indicators_processed': 0,
'progress_percentage': 0.0,
'enabled_providers': [],
'graph_statistics': {},
'user_session_id': user_session_id 'user_session_id': user_session_id
} }
}) })
# Ensure session ID is set
if not scanner.session_id: if not scanner.session_id:
scanner.session_id = user_session_id scanner.session_id = user_session_id
status = scanner.get_scan_status() status = scanner.get_scan_status()
status['user_session_id'] = user_session_id status['user_session_id'] = user_session_id
# Additional debug info return jsonify({'success': True, 'status': status})
status['debug_info'] = {
'scanner_object_id': id(scanner),
'session_id_set': bool(scanner.session_id),
'has_scan_thread': bool(scanner.scan_thread and scanner.scan_thread.is_alive()),
'provider_count': len(scanner.providers),
'provider_names': [p.get_name() for p in scanner.providers]
}
return jsonify({
'success': True,
'status': status
})
except Exception as e: except Exception as e:
print(f"ERROR: Exception in get_scan_status endpoint: {e}")
traceback.print_exc() traceback.print_exc()
return jsonify({ return jsonify({
'success': False, 'success': False, 'error': f'Internal server error: {str(e)}',
'error': f'Internal server error: {str(e)}', 'fallback_status': {'status': 'error', 'progress_percentage': 0.0}
'fallback_status': {
'status': 'error',
'target_domain': None,
'current_depth': 0,
'max_depth': 0,
'progress_percentage': 0.0
}
}), 500 }), 500
@app.route('/api/graph', methods=['GET']) @app.route('/api/graph', methods=['GET'])
def get_graph_data(): def get_graph_data():
"""Get current graph data with error handling and proper empty graph structure.""" """Get current graph data."""
try: try:
# Get user-specific scanner
user_session_id, scanner = get_user_scanner() user_session_id, scanner = get_user_scanner()
empty_graph = {
'nodes': [], 'edges': [],
'statistics': {'node_count': 0, 'edge_count': 0}
}
if not scanner: if not scanner:
# FIXED: Return proper empty graph structure instead of None return jsonify({'success': True, 'graph': empty_graph, 'user_session_id': user_session_id})
empty_graph = {
'nodes': [],
'edges': [],
'statistics': {
'node_count': 0,
'edge_count': 0,
'creation_time': datetime.now(timezone.utc).isoformat(),
'last_modified': datetime.now(timezone.utc).isoformat()
}
}
return jsonify({
'success': True,
'graph': empty_graph,
'user_session_id': user_session_id
})
graph_data = scanner.get_graph_data() graph_data = scanner.get_graph_data() or empty_graph
if not graph_data: return jsonify({'success': True, 'graph': graph_data, 'user_session_id': user_session_id})
graph_data = {
'nodes': [],
'edges': [],
'statistics': {
'node_count': 0,
'edge_count': 0,
'creation_time': datetime.now(timezone.utc).isoformat(),
'last_modified': datetime.now(timezone.utc).isoformat()
}
}
# FIXED: Ensure required fields exist
if 'nodes' not in graph_data:
graph_data['nodes'] = []
if 'edges' not in graph_data:
graph_data['edges'] = []
if 'statistics' not in graph_data:
graph_data['statistics'] = {
'node_count': len(graph_data['nodes']),
'edge_count': len(graph_data['edges']),
'creation_time': datetime.now(timezone.utc).isoformat(),
'last_modified': datetime.now(timezone.utc).isoformat()
}
return jsonify({
'success': True,
'graph': graph_data,
'user_session_id': user_session_id
})
except Exception as e: except Exception as e:
print(f"ERROR: Exception in get_graph_data endpoint: {e}")
traceback.print_exc() traceback.print_exc()
# FIXED: Return proper error structure with empty graph fallback
return jsonify({ return jsonify({
'success': False, 'success': False, 'error': f'Internal server error: {str(e)}',
'error': f'Internal server error: {str(e)}', 'fallback_graph': {'nodes': [], 'edges': [], 'statistics': {}}
'fallback_graph': {
'nodes': [],
'edges': [],
'statistics': {
'node_count': 0,
'edge_count': 0,
'creation_time': datetime.now(timezone.utc).isoformat(),
'last_modified': datetime.now(timezone.utc).isoformat()
}
}
}), 500 }), 500
@app.route('/api/graph/large-entity/extract', methods=['POST']) @app.route('/api/graph/large-entity/extract', methods=['POST'])
def extract_from_large_entity(): def extract_from_large_entity():
"""Extract a node from a large entity, making it a standalone node.""" """Extract a node from a large entity."""
try: try:
data = request.get_json() data = request.get_json()
large_entity_id = data.get('large_entity_id') large_entity_id = data.get('large_entity_id')
@ -333,13 +205,12 @@ def extract_from_large_entity():
return jsonify({'success': False, 'error': f'Failed to extract node {node_id}.'}), 500 return jsonify({'success': False, 'error': f'Failed to extract node {node_id}.'}), 500
except Exception as e: except Exception as e:
print(f"ERROR: Exception in extract_from_large_entity endpoint: {e}")
traceback.print_exc() traceback.print_exc()
return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500 return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500
@app.route('/api/graph/node/<node_id>', methods=['DELETE']) @app.route('/api/graph/node/<node_id>', methods=['DELETE'])
def delete_graph_node(node_id): def delete_graph_node(node_id):
"""Delete a node from the graph for the current user session.""" """Delete a node from the graph."""
try: try:
user_session_id, scanner = get_user_scanner() user_session_id, scanner = get_user_scanner()
if not scanner: if not scanner:
@ -348,14 +219,12 @@ def delete_graph_node(node_id):
success = scanner.graph.remove_node(node_id) success = scanner.graph.remove_node(node_id)
if success: if success:
# Persist the change
session_manager.update_session_scanner(user_session_id, scanner) session_manager.update_session_scanner(user_session_id, scanner)
return jsonify({'success': True, 'message': f'Node {node_id} deleted successfully.'}) return jsonify({'success': True, 'message': f'Node {node_id} deleted successfully.'})
else: else:
- return jsonify({'success': False, 'error': f'Node {node_id} not found in graph.'}), 404
+ return jsonify({'success': False, 'error': f'Node {node_id} not found.'}), 404
except Exception as e: except Exception as e:
print(f"ERROR: Exception in delete_graph_node endpoint: {e}")
traceback.print_exc() traceback.print_exc()
return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500 return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500
@ -376,7 +245,6 @@ def revert_graph_action():
action_data = data['data'] action_data = data['data']
if action_type == 'delete': if action_type == 'delete':
# Re-add the node
node_to_add = action_data.get('node') node_to_add = action_data.get('node')
if node_to_add: if node_to_add:
scanner.graph.add_node( scanner.graph.add_node(
@ -387,204 +255,121 @@ def revert_graph_action():
metadata=node_to_add.get('metadata') metadata=node_to_add.get('metadata')
) )
# Re-add the edges
edges_to_add = action_data.get('edges', []) edges_to_add = action_data.get('edges', [])
for edge in edges_to_add: for edge in edges_to_add:
# Add edge only if both nodes exist to prevent errors
if scanner.graph.graph.has_node(edge['from']) and scanner.graph.graph.has_node(edge['to']): if scanner.graph.graph.has_node(edge['from']) and scanner.graph.graph.has_node(edge['to']):
scanner.graph.add_edge( scanner.graph.add_edge(
source_id=edge['from'], source_id=edge['from'], target_id=edge['to'],
target_id=edge['to'],
relationship_type=edge['metadata']['relationship_type'], relationship_type=edge['metadata']['relationship_type'],
confidence_score=edge['metadata']['confidence_score'], confidence_score=edge['metadata']['confidence_score'],
source_provider=edge['metadata']['source_provider'], source_provider=edge['metadata']['source_provider'],
raw_data=edge.get('raw_data', {}) raw_data=edge.get('raw_data', {})
) )
# Persist the change
session_manager.update_session_scanner(user_session_id, scanner) session_manager.update_session_scanner(user_session_id, scanner)
return jsonify({'success': True, 'message': 'Delete action reverted successfully.'}) return jsonify({'success': True, 'message': 'Delete action reverted successfully.'})
return jsonify({'success': False, 'error': f'Unknown revert action type: {action_type}'}), 400 return jsonify({'success': False, 'error': f'Unknown revert action type: {action_type}'}), 400
except Exception as e: except Exception as e:
print(f"ERROR: Exception in revert_graph_action endpoint: {e}")
traceback.print_exc() traceback.print_exc()
return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500 return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500
@app.route('/api/export', methods=['GET']) @app.route('/api/export', methods=['GET'])
def export_results(): def export_results():
"""Export complete scan results as downloadable JSON for the user session.""" """Export scan results as a JSON file."""
try: try:
# Get user-specific scanner
user_session_id, scanner = get_user_scanner() user_session_id, scanner = get_user_scanner()
# Get complete results
results = scanner.export_results() results = scanner.export_results()
# Add session information to export
results['export_metadata'] = { results['export_metadata'] = {
'user_session_id': user_session_id, 'user_session_id': user_session_id,
'export_timestamp': datetime.now(timezone.utc).isoformat(), 'export_timestamp': datetime.now(timezone.utc).isoformat(),
'export_type': 'user_session_results'
} }
# Create filename with timestamp
timestamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S') timestamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
target = scanner.current_target or 'unknown' target = scanner.current_target or 'unknown'
filename = f"dnsrecon_{target}_{timestamp}_{user_session_id[:8]}.json" filename = f"dnsrecon_{target}_{timestamp}.json"
- # Create in-memory file
- json_data = json.dumps(results, indent=2, ensure_ascii=False)
+ json_data = json.dumps(results, indent=2)
file_obj = io.BytesIO(json_data.encode('utf-8')) file_obj = io.BytesIO(json_data.encode('utf-8'))
return send_file( return send_file(
file_obj, file_obj, as_attachment=True,
as_attachment=True, download_name=filename, mimetype='application/json'
download_name=filename,
mimetype='application/json'
) )
except Exception as e: except Exception as e:
print(f"ERROR: Exception in export_results endpoint: {e}")
traceback.print_exc() traceback.print_exc()
- return jsonify({
- 'success': False,
- 'error': f'Export failed: {str(e)}'
- }), 500
+ return jsonify({'success': False, 'error': f'Export failed: {str(e)}'}), 500
@app.route('/api/providers', methods=['GET']) @app.route('/api/providers', methods=['GET'])
def get_providers(): def get_providers():
"""Get information about available providers for the user session.""" """Get information about available providers."""
try: try:
# Get user-specific scanner
user_session_id, scanner = get_user_scanner() user_session_id, scanner = get_user_scanner()
if scanner and scanner.status == 'running':
status = scanner.get_scan_status()
currently_processing = status.get('currently_processing')
if currently_processing:
provider_name, target_item = currently_processing[0]
print(f"DEBUG: RUNNING Task - Provider: {provider_name}, Target: {target_item}")
print(f"DEBUG: Task Queue Status - In Queue: {status.get('tasks_in_queue', 0)}, Completed: {status.get('tasks_completed', 0)}, Skipped: {status.get('tasks_skipped', 0)}, Rescheduled: {status.get('tasks_rescheduled', 0)}")
elif not scanner:
print("DEBUG: No active scanner session found.")
provider_info = scanner.get_provider_info() provider_info = scanner.get_provider_info()
- return jsonify({
- 'success': True,
- 'providers': provider_info,
- 'user_session_id': user_session_id
- })
+ return jsonify({'success': True, 'providers': provider_info, 'user_session_id': user_session_id})
except Exception as e: except Exception as e:
print(f"ERROR: Exception in get_providers endpoint: {e}")
traceback.print_exc() traceback.print_exc()
- return jsonify({
- 'success': False,
- 'error': f'Internal server error: {str(e)}'
- }), 500
+ return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500
@app.route('/api/config/api-keys', methods=['POST']) @app.route('/api/config/api-keys', methods=['POST'])
def set_api_keys(): def set_api_keys():
""" """Set API keys for the current session."""
"""
try: try:
data = request.get_json() data = request.get_json()
if data is None: if data is None:
- return jsonify({
- 'success': False,
- 'error': 'No API keys provided'
- }), 400
+ return jsonify({'success': False, 'error': 'No API keys provided'}), 400
# Get user-specific scanner and config
user_session_id, scanner = get_user_scanner() user_session_id, scanner = get_user_scanner()
session_config = scanner.config session_config = scanner.config
print(f"Setting API keys for session {user_session_id}: {list(data.keys())}")
updated_providers = [] updated_providers = []
# Iterate over the API keys provided in the request data
for provider_name, api_key in data.items(): for provider_name, api_key in data.items():
# This allows us to both set and clear keys. The config
# handles enabling/disabling based on if the key is empty.
api_key_value = str(api_key or '').strip() api_key_value = str(api_key or '').strip()
success = session_config.set_api_key(provider_name.lower(), api_key_value) success = session_config.set_api_key(provider_name.lower(), api_key_value)
if success: if success:
updated_providers.append(provider_name) updated_providers.append(provider_name)
print(f"API key {'set' if api_key_value else 'cleared'} for {provider_name}")
if updated_providers: if updated_providers:
# FIXED: Reinitialize scanner providers to apply the new keys
print("Reinitializing providers with new API keys...")
old_provider_count = len(scanner.providers)
scanner._initialize_providers() scanner._initialize_providers()
new_provider_count = len(scanner.providers)
print(f"Providers reinitialized: {old_provider_count} -> {new_provider_count}")
print(f"Available providers: {[p.get_name() for p in scanner.providers]}")
# Persist the updated scanner object back to the user's session
session_manager.update_session_scanner(user_session_id, scanner) session_manager.update_session_scanner(user_session_id, scanner)
return jsonify({ return jsonify({
'success': True, 'success': True,
'message': f'API keys updated for session {user_session_id}: {", ".join(updated_providers)}', 'message': f'API keys updated for: {", ".join(updated_providers)}',
'updated_providers': updated_providers, 'user_session_id': user_session_id
'user_session_id': user_session_id,
'available_providers': [p.get_name() for p in scanner.providers]
}) })
else: else:
- return jsonify({
- 'success': False,
- 'error': 'No valid API keys were provided or provider names were incorrect.'
- }), 400
+ return jsonify({'success': False, 'error': 'No valid API keys were provided.'}), 400
except Exception as e: except Exception as e:
print(f"ERROR: Exception in set_api_keys endpoint: {e}")
traceback.print_exc() traceback.print_exc()
- return jsonify({
- 'success': False,
- 'error': f'Internal server error: {str(e)}'
- }), 500
+ return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500
@app.errorhandler(404) @app.errorhandler(404)
def not_found(error): def not_found(error):
"""Handle 404 errors.""" """Handle 404 errors."""
- return jsonify({
- 'success': False,
- 'error': 'Endpoint not found'
- }), 404
+ return jsonify({'success': False, 'error': 'Endpoint not found'}), 404
@app.errorhandler(500) @app.errorhandler(500)
def internal_error(error): def internal_error(error):
"""Handle 500 errors.""" """Handle 500 errors."""
print(f"ERROR: 500 Internal Server Error: {error}")
traceback.print_exc() traceback.print_exc()
- return jsonify({
- 'success': False,
- 'error': 'Internal server error'
- }), 500
+ return jsonify({'success': False, 'error': 'Internal server error'}), 500
if __name__ == '__main__': if __name__ == '__main__':
print("Starting DNSRecon Flask application with streamlined session management...")
# Load configuration from environment
config.load_from_env() config.load_from_env()
# Start Flask application
print(f"Starting server on {config.flask_host}:{config.flask_port}")
app.run( app.run(
host=config.flask_host, host=config.flask_host,
port=config.flask_port, port=config.flask_port,


@ -149,7 +149,7 @@ class GraphManager:
if self.graph.has_node(node_id) and not self.graph.has_edge(node_id, correlation_node_id): if self.graph.has_node(node_id) and not self.graph.has_edge(node_id, correlation_node_id):
# Format relationship label as "corr_provider_attribute" # Format relationship label as "corr_provider_attribute"
- relationship_label = f"corr_{provider}_{attribute}"
+ relationship_label = f"{provider}_{attribute}"
self.add_edge( self.add_edge(
source_id=node_id, source_id=node_id,
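The only change in this file is that the corr_ prefix is dropped from correlation edge labels. For illustration (the provider and attribute names here are hypothetical):

    provider, attribute = "dns", "ttl"
    old_label = f"corr_{provider}_{attribute}"  # "corr_dns_ttl" (previous format)
    new_label = f"{provider}_{attribute}"       # "dns_ttl" (format after this change)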


@ -5,6 +5,7 @@ import traceback
import os import os
import importlib import importlib
import redis import redis
import time
from typing import List, Set, Dict, Any, Tuple, Optional from typing import List, Set, Dict, Any, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict from collections import defaultdict
@ -30,13 +31,11 @@ class ScanStatus:
class Scanner: class Scanner:
""" """
Main scanning orchestrator for DNSRecon passive reconnaissance. Main scanning orchestrator for DNSRecon passive reconnaissance.
- FIXED: Now preserves session configuration including API keys when clearing graphs.
+ UNIFIED: Combines comprehensive features with improved display formatting.
""" """
def __init__(self, session_config=None): def __init__(self, session_config=None):
"""Initialize scanner with session-specific configuration.""" """Initialize scanner with session-specific configuration."""
print("Initializing Scanner instance...")
try: try:
# Use provided session config or create default # Use provided session config or create default
if session_config is None: if session_config is None:
@ -57,15 +56,18 @@ class Scanner:
self.target_retries = defaultdict(int) self.target_retries = defaultdict(int)
self.scan_failed_due_to_retries = False self.scan_failed_due_to_retries = False
- # **NEW**: Track currently processing tasks to prevent processing after stop
+ # Thread-safe processing tracking (from Document 1)
self.currently_processing = set() self.currently_processing = set()
self.processing_lock = threading.Lock() self.processing_lock = threading.Lock()
# Display-friendly processing list (from Document 2)
self.currently_processing_display = []
# Scanning progress tracking # Scanning progress tracking
self.total_indicators_found = 0 self.total_indicators_found = 0
self.indicators_processed = 0 self.indicators_processed = 0
self.indicators_completed = 0 self.indicators_completed = 0
self.tasks_re_enqueued = 0 self.tasks_re_enqueued = 0
self.tasks_skipped = 0 # BUGFIX: Initialize tasks_skipped
self.total_tasks_ever_enqueued = 0 self.total_tasks_ever_enqueued = 0
self.current_indicator = "" self.current_indicator = ""
@ -73,19 +75,19 @@ class Scanner:
self.max_workers = self.config.max_concurrent_requests self.max_workers = self.config.max_concurrent_requests
self.executor = None self.executor = None
# Status logger thread with improved formatting
self.status_logger_thread = None
self.status_logger_stop_event = threading.Event()
# Initialize providers with session config # Initialize providers with session config
print("Calling _initialize_providers with session config...")
self._initialize_providers() self._initialize_providers()
# Initialize logger # Initialize logger
print("Initializing forensic logger...")
self.logger = get_forensic_logger() self.logger = get_forensic_logger()
# Initialize global rate limiter # Initialize global rate limiter
self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0)) self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
print("Scanner initialization complete")
except Exception as e: except Exception as e:
print(f"ERROR: Scanner initialization failed: {e}") print(f"ERROR: Scanner initialization failed: {e}")
traceback.print_exc() traceback.print_exc()
@ -96,17 +98,14 @@ class Scanner:
Check if stop is requested using both local and Redis-based signals. Check if stop is requested using both local and Redis-based signals.
This ensures reliable termination across process boundaries. This ensures reliable termination across process boundaries.
""" """
# Check local threading event first (fastest)
if self.stop_event.is_set(): if self.stop_event.is_set():
return True return True
# Check Redis-based stop signal if session ID is available
if self.session_id: if self.session_id:
try: try:
from core.session_manager import session_manager from core.session_manager import session_manager
return session_manager.is_stop_requested(self.session_id) return session_manager.is_stop_requested(self.session_id)
except Exception as e: except Exception as e:
print(f"Error checking Redis stop signal: {e}")
# Fall back to local event # Fall back to local event
return self.stop_event.is_set() return self.stop_event.is_set()
@ -116,22 +115,19 @@ class Scanner:
""" """
Set stop signal both locally and in Redis. Set stop signal both locally and in Redis.
""" """
# Set local event
self.stop_event.set() self.stop_event.set()
# Set Redis signal if session ID is available
if self.session_id: if self.session_id:
try: try:
from core.session_manager import session_manager from core.session_manager import session_manager
session_manager.set_stop_signal(self.session_id) session_manager.set_stop_signal(self.session_id)
except Exception as e: except Exception as e:
print(f"Error setting Redis stop signal: {e}") pass
def __getstate__(self): def __getstate__(self):
"""Prepare object for pickling by excluding unpicklable attributes.""" """Prepare object for pickling by excluding unpicklable attributes."""
state = self.__dict__.copy() state = self.__dict__.copy()
# Remove unpicklable threading objects
unpicklable_attrs = [ unpicklable_attrs = [
'stop_event', 'stop_event',
'scan_thread', 'scan_thread',
@ -139,14 +135,15 @@ class Scanner:
'processing_lock', 'processing_lock',
'task_queue', 'task_queue',
'rate_limiter', 'rate_limiter',
'logger' 'logger',
'status_logger_thread',
'status_logger_stop_event'
] ]
for attr in unpicklable_attrs: for attr in unpicklable_attrs:
if attr in state: if attr in state:
del state[attr] del state[attr]
# Handle providers separately to ensure they're picklable
if 'providers' in state: if 'providers' in state:
for provider in state['providers']: for provider in state['providers']:
if hasattr(provider, '_stop_event'): if hasattr(provider, '_stop_event'):
@ -158,7 +155,6 @@ class Scanner:
"""Restore object after unpickling by reconstructing threading objects.""" """Restore object after unpickling by reconstructing threading objects."""
self.__dict__.update(state) self.__dict__.update(state)
# Reconstruct threading objects
self.stop_event = threading.Event() self.stop_event = threading.Event()
self.scan_thread = None self.scan_thread = None
self.executor = None self.executor = None
@ -166,15 +162,18 @@ class Scanner:
self.task_queue = PriorityQueue() self.task_queue = PriorityQueue()
self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0)) self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
self.logger = get_forensic_logger() self.logger = get_forensic_logger()
self.status_logger_thread = None
self.status_logger_stop_event = threading.Event()
if not hasattr(self, 'providers') or not self.providers: if not hasattr(self, 'providers') or not self.providers:
print("Providers not found after loading session, re-initializing...")
self._initialize_providers() self._initialize_providers()
if not hasattr(self, 'currently_processing'): if not hasattr(self, 'currently_processing'):
self.currently_processing = set() self.currently_processing = set()
- # Re-set stop events for providers
+ if not hasattr(self, 'currently_processing_display'):
+ self.currently_processing_display = []
if hasattr(self, 'providers'): if hasattr(self, 'providers'):
for provider in self.providers: for provider in self.providers:
if hasattr(provider, 'set_stop_event'): if hasattr(provider, 'set_stop_event'):
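Because the scanner object is persisted through the session manager, __getstate__/__setstate__ strip the threading primitives (events, locks, queues, executor, logger) before pickling and rebuild them afterwards. A minimal standalone sketch of the same pattern, not the project's class:

    import pickle
    import threading
    from queue import PriorityQueue

    class Job:
        def __init__(self):
            self.name = "scan"
            self.stop_event = threading.Event()   # unpicklable (wraps locks)
            self.task_queue = PriorityQueue()     # unpicklable (wraps locks)

        def __getstate__(self):
            state = self.__dict__.copy()
            for attr in ("stop_event", "task_queue"):
                state.pop(attr, None)             # drop unpicklable attributes
            return state

        def __setstate__(self, state):
            self.__dict__.update(state)
            self.stop_event = threading.Event()   # rebuild fresh primitives
            self.task_queue = PriorityQueue()

    restored = pickle.loads(pickle.dumps(Job()))
    assert restored.name == "scan" and not restored.stop_event.is_set()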
@ -183,8 +182,6 @@ class Scanner:
def _initialize_providers(self) -> None: def _initialize_providers(self) -> None:
"""Initialize all available providers based on session configuration.""" """Initialize all available providers based on session configuration."""
self.providers = [] self.providers = []
print("Initializing providers with session config...")
provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers') provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
for filename in os.listdir(provider_dir): for filename in os.listdir(provider_dir):
if filename.endswith('_provider.py') and not filename.startswith('base'): if filename.endswith('_provider.py') and not filename.startswith('base'):
@ -202,99 +199,98 @@ class Scanner:
if provider.is_available(): if provider.is_available():
provider.set_stop_event(self.stop_event) provider.set_stop_event(self.stop_event)
self.providers.append(provider) self.providers.append(provider)
print(f"{provider.get_display_name()} provider initialized successfully for session")
else:
print(f"{provider.get_display_name()} provider is not available")
except Exception as e: except Exception as e:
print(f"✗ Failed to initialize provider from {filename}: {e}")
traceback.print_exc() traceback.print_exc()
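_initialize_providers discovers providers by scanning the providers/ directory for *_provider.py modules and importing them; the import and instantiation details fall outside this hunk. A minimal sketch of that discovery pattern, with the class-naming convention assumed rather than taken from the diff:

    import importlib
    import os

    def discover_providers(provider_dir: str):
        """Illustrative dynamic loader; the real module/class resolution may differ."""
        providers = []
        for filename in os.listdir(provider_dir):
            if not filename.endswith('_provider.py') or filename.startswith('base'):
                continue
            module_name = filename[:-3]                          # e.g. "dns_provider"
            module = importlib.import_module(f"providers.{module_name}")
            # Assumed naming convention: dns_provider.py exposes DnsProvider.
            class_name = ''.join(part.title() for part in module_name.split('_'))
            provider_cls = getattr(module, class_name, None)
            if provider_cls is not None:
                providers.append(provider_cls())
        return providers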
print(f"Initialized {len(self.providers)} providers for session") def _status_logger_thread(self):
"""Periodically prints a clean, formatted scan status to the terminal."""
# Color codes for improved display (from Document 2)
HEADER = "\033[95m"
CYAN = "\033[96m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
ENDC = "\033[0m"
BOLD = "\033[1m"
last_status_str = ""
while not self.status_logger_stop_event.is_set():
try:
# Use thread-safe copy of currently processing
with self.processing_lock:
in_flight_tasks = list(self.currently_processing)
# Update display list for consistent formatting
self.currently_processing_display = in_flight_tasks.copy()
def update_session_config(self, new_config) -> None: status_str = (
"""Update session configuration and reinitialize providers.""" f"{BOLD}{HEADER}Scan Status: {self.status.upper()}{ENDC} | "
print("Updating session configuration...") f"{CYAN}Queued: {self.task_queue.qsize()}{ENDC} | "
self.config = new_config f"{YELLOW}In-Flight: {len(in_flight_tasks)}{ENDC} | "
self.max_workers = self.config.max_concurrent_requests f"{GREEN}Completed: {self.indicators_completed}{ENDC} | "
self._initialize_providers() f"Skipped: {self.tasks_skipped} | "
print("Session configuration updated") f"Rescheduled: {self.tasks_re_enqueued}"
)
if status_str != last_status_str:
print(f"\n{'-'*80}")
print(status_str)
if in_flight_tasks:
print(f"{BOLD}{YELLOW}Currently Processing:{ENDC}")
# Display up to 3 currently processing tasks
display_tasks = [f" - {p}: {t}" for p, t in in_flight_tasks[:3]]
print("\n".join(display_tasks))
if len(in_flight_tasks) > 3:
print(f" ... and {len(in_flight_tasks) - 3} more")
print(f"{'-'*80}")
last_status_str = status_str
except Exception:
# Silently fail to avoid crashing the logger
pass
time.sleep(2) # Update interval
def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool: def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool:
""" """
- FIXED: Start a new reconnaissance scan preserving session configuration.
- Only clears graph data when requested, never destroys session/API keys.
+ Starts a new reconnaissance scan.
""" """
print(f"=== STARTING SCAN IN SCANNER {id(self)} ===")
print(f"Session ID: {self.session_id}")
print(f"Initial scanner status: {self.status}")
print(f"Clear graph requested: {clear_graph}")
print(f"Current providers: {[p.get_name() for p in self.providers]}")
self.total_tasks_ever_enqueued = 0
# FIXED: Improved cleanup of previous scan without destroying session config
if self.scan_thread and self.scan_thread.is_alive(): if self.scan_thread and self.scan_thread.is_alive():
print("A previous scan thread is still alive. Forcing termination...")
# Set stop signals immediately
self._set_stop_signal() self._set_stop_signal()
self.status = ScanStatus.STOPPED self.status = ScanStatus.STOPPED
# Clear all processing state
with self.processing_lock: with self.processing_lock:
self.currently_processing.clear() self.currently_processing.clear()
self.currently_processing_display = []
self.task_queue = PriorityQueue() self.task_queue = PriorityQueue()
# Shutdown executor aggressively
if self.executor: if self.executor:
print("Shutting down executor forcefully...")
self.executor.shutdown(wait=False, cancel_futures=True) self.executor.shutdown(wait=False, cancel_futures=True)
self.executor = None self.executor = None
self.scan_thread.join(5.0)
# Wait for thread termination with shorter timeout
print("Waiting for previous scan thread to terminate...")
self.scan_thread.join(5.0) # Reduced from 10 seconds
if self.scan_thread.is_alive():
print("WARNING: Previous scan thread is still alive after 5 seconds")
self.logger.logger.warning("Previous scan thread failed to terminate cleanly")
# FIXED: Reset scan state but preserve session configuration (API keys, etc.)
print("Resetting scanner state for new scan (preserving session config)...")
self.status = ScanStatus.IDLE self.status = ScanStatus.IDLE
self.stop_event.clear() self.stop_event.clear()
# Clear Redis stop signal explicitly
if self.session_id: if self.session_id:
from core.session_manager import session_manager from core.session_manager import session_manager
session_manager.clear_stop_signal(self.session_id) session_manager.clear_stop_signal(self.session_id)
with self.processing_lock: with self.processing_lock:
self.currently_processing.clear() self.currently_processing.clear()
self.currently_processing_display = []
# Reset scan-specific state but keep providers and config intact
self.task_queue = PriorityQueue() self.task_queue = PriorityQueue()
self.target_retries.clear() self.target_retries.clear()
self.scan_failed_due_to_retries = False self.scan_failed_due_to_retries = False
self.tasks_skipped = 0 # BUGFIX: Reset tasks_skipped for new scan
# Update session state immediately for GUI feedback
self._update_session_state() self._update_session_state()
print("Scanner state reset complete (providers preserved).")
try: try:
if not hasattr(self, 'providers') or not self.providers: if not hasattr(self, 'providers') or not self.providers:
print(f"ERROR: No providers available in scanner {id(self)}, cannot start scan")
return False return False
print(f"Scanner {id(self)} validation passed, providers available: {[p.get_name() for p in self.providers]}")
# FIXED: Only clear graph if explicitly requested, don't destroy session
if clear_graph: if clear_graph:
print("Clearing graph data (preserving session configuration)")
self.graph.clear() self.graph.clear()
# Handle force rescan by clearing provider states for that specific node
if force_rescan_target and self.graph.graph.has_node(force_rescan_target): if force_rescan_target and self.graph.graph.has_node(force_rescan_target):
print(f"Forcing rescan of {force_rescan_target}, clearing provider states.")
node_data = self.graph.graph.nodes[force_rescan_target] node_data = self.graph.graph.nodes[force_rescan_target]
if 'metadata' in node_data and 'provider_states' in node_data['metadata']: if 'metadata' in node_data and 'provider_states' in node_data['metadata']:
node_data['metadata']['provider_states'] = {} node_data['metadata']['provider_states'] = {}
@ -307,17 +303,12 @@ class Scanner:
self.indicators_processed = 0 self.indicators_processed = 0
self.indicators_completed = 0 self.indicators_completed = 0
self.tasks_re_enqueued = 0 self.tasks_re_enqueued = 0
self.total_tasks_ever_enqueued = 0
self.current_indicator = self.current_target self.current_indicator = self.current_target
# Update GUI with scan preparation state
self._update_session_state() self._update_session_state()
# Start new forensic session (but don't reinitialize providers)
print(f"Starting new forensic session for scanner {id(self)}...")
self.logger = new_session() self.logger = new_session()
# Start scan in a separate thread
print(f"Starting scan thread for scanner {id(self)}...")
self.scan_thread = threading.Thread( self.scan_thread = threading.Thread(
target=self._execute_scan, target=self._execute_scan,
args=(self.current_target, max_depth), args=(self.current_target, max_depth),
@ -325,12 +316,14 @@ class Scanner:
) )
self.scan_thread.start() self.scan_thread.start()
print(f"=== SCAN STARTED SUCCESSFULLY IN SCANNER {id(self)} ===") # Start the status logger thread
print(f"Active providers for this scan: {[p.get_name() for p in self.providers]}") self.status_logger_stop_event.clear()
self.status_logger_thread = threading.Thread(target=self._status_logger_thread, daemon=True)
self.status_logger_thread.start()
return True return True
except Exception as e: except Exception as e:
print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}")
traceback.print_exc() traceback.print_exc()
self.status = ScanStatus.FAILED self.status = ScanStatus.FAILED
self._update_session_state() self._update_session_state()
@ -347,11 +340,9 @@ class Scanner:
def _execute_scan(self, target: str, max_depth: int) -> None: def _execute_scan(self, target: str, max_depth: int) -> None:
"""Execute the reconnaissance scan with proper termination handling.""" """Execute the reconnaissance scan with proper termination handling."""
print(f"_execute_scan started for {target} with depth {max_depth}")
self.executor = ThreadPoolExecutor(max_workers=self.max_workers) self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
processed_tasks = set() processed_tasks = set()
# Initial task population for the main target
is_ip = _is_valid_ip(target) is_ip = _is_valid_ip(target)
initial_providers = self._get_eligible_providers(target, is_ip, False) initial_providers = self._get_eligible_providers(target, is_ip, False)
for provider in initial_providers: for provider in initial_providers:
@ -366,23 +357,19 @@ class Scanner:
enabled_providers = [provider.get_name() for provider in self.providers] enabled_providers = [provider.get_name() for provider in self.providers]
self.logger.log_scan_start(target, max_depth, enabled_providers) self.logger.log_scan_start(target, max_depth, enabled_providers)
# Determine initial node type
node_type = NodeType.IP if is_ip else NodeType.DOMAIN node_type = NodeType.IP if is_ip else NodeType.DOMAIN
self.graph.add_node(target, node_type) self.graph.add_node(target, node_type)
self._initialize_provider_states(target) self._initialize_provider_states(target)
# Better termination checking in main loop
while not self.task_queue.empty() and not self._is_stop_requested(): while not self.task_queue.empty() and not self._is_stop_requested():
try: try:
priority, (provider_name, target_item, depth) = self.task_queue.get() priority, (provider_name, target_item, depth) = self.task_queue.get()
except IndexError: except IndexError:
# Queue became empty during processing
break break
task_tuple = (provider_name, target_item) task_tuple = (provider_name, target_item)
if task_tuple in processed_tasks: if task_tuple in processed_tasks:
- self.indicators_completed += 1
+ self.tasks_skipped += 1
continue continue
if depth > max_depth: if depth > max_depth:
@ -394,7 +381,6 @@ class Scanner:
with self.processing_lock: with self.processing_lock:
if self._is_stop_requested(): if self._is_stop_requested():
print(f"Stop requested before processing {target_item}")
break break
self.currently_processing.add(task_tuple) self.currently_processing.add(task_tuple)
@ -404,7 +390,6 @@ class Scanner:
self._update_session_state() self._update_session_state()
if self._is_stop_requested(): if self._is_stop_requested():
print(f"Stop requested during processing setup for {target_item}")
break break
provider = next((p for p in self.providers if p.get_name() == provider_name), None) provider = next((p for p in self.providers if p.get_name() == provider_name), None)
@ -413,18 +398,15 @@ class Scanner:
new_targets, large_entity_members, success = self._query_single_provider_for_target(provider, target_item, depth) new_targets, large_entity_members, success = self._query_single_provider_for_target(provider, target_item, depth)
if self._is_stop_requested(): if self._is_stop_requested():
print(f"Stop requested after querying providers for {target_item}")
break break
if not success: if not success:
self.target_retries[task_tuple] += 1 self.target_retries[task_tuple] += 1
if self.target_retries[task_tuple] <= self.config.max_retries_per_target: if self.target_retries[task_tuple] <= self.config.max_retries_per_target:
print(f"Re-queueing task {task_tuple} (attempt {self.target_retries[task_tuple]})")
self.task_queue.put((priority, (provider_name, target_item, depth))) self.task_queue.put((priority, (provider_name, target_item, depth)))
self.tasks_re_enqueued += 1 self.tasks_re_enqueued += 1
self.total_tasks_ever_enqueued += 1 self.total_tasks_ever_enqueued += 1
else: else:
print(f"ERROR: Max retries exceeded for task {task_tuple}")
self.scan_failed_due_to_retries = True self.scan_failed_due_to_retries = True
self._log_target_processing_error(str(task_tuple), "Max retries exceeded") self._log_target_processing_error(str(task_tuple), "Max retries exceeded")
else: else:
@ -446,21 +428,14 @@ class Scanner:
with self.processing_lock: with self.processing_lock:
self.currently_processing.discard(task_tuple) self.currently_processing.discard(task_tuple)
if self._is_stop_requested():
print("Scan terminated due to stop request")
self.logger.logger.info("Scan terminated by user request")
elif self.task_queue.empty():
print("Scan completed - no more targets to process")
self.logger.logger.info("Scan completed - all targets processed")
except Exception as e: except Exception as e:
print(f"ERROR: Scan execution failed with error: {e}")
traceback.print_exc() traceback.print_exc()
self.status = ScanStatus.FAILED self.status = ScanStatus.FAILED
self.logger.logger.error(f"Scan failed: {e}") self.logger.logger.error(f"Scan failed: {e}")
finally: finally:
with self.processing_lock: with self.processing_lock:
self.currently_processing.clear() self.currently_processing.clear()
self.currently_processing_display = []
if self._is_stop_requested(): if self._is_stop_requested():
self.status = ScanStatus.STOPPED self.status = ScanStatus.STOPPED
@ -469,31 +444,27 @@ class Scanner:
else: else:
self.status = ScanStatus.COMPLETED self.status = ScanStatus.COMPLETED
# Stop the status logger
self.status_logger_stop_event.set()
if self.status_logger_thread:
self.status_logger_thread.join()
self._update_session_state() self._update_session_state()
self.logger.log_scan_complete() self.logger.log_scan_complete()
if self.executor: if self.executor:
self.executor.shutdown(wait=False, cancel_futures=True) self.executor.shutdown(wait=False, cancel_futures=True)
self.executor = None self.executor = None
stats = self.graph.get_statistics()
print("Final scan statistics:")
print(f" - Total nodes: {stats['basic_metrics']['total_nodes']}")
print(f" - Total edges: {stats['basic_metrics']['total_edges']}")
print(f" - Tasks processed: {len(processed_tasks)}")
def _query_single_provider_for_target(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]: def _query_single_provider_for_target(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
""" """
Query a single provider and process the unified ProviderResult. Query a single provider and process the unified ProviderResult.
Now provider-agnostic - handles any provider that returns ProviderResult.
""" """
if self._is_stop_requested(): if self._is_stop_requested():
print(f"Stop requested before querying {provider.get_name()} for {target}")
return set(), set(), False return set(), set(), False
is_ip = _is_valid_ip(target) is_ip = _is_valid_ip(target)
target_type = NodeType.IP if is_ip else NodeType.DOMAIN target_type = NodeType.IP if is_ip else NodeType.DOMAIN
print(f"Querying {provider.get_name()} for {target_type.value}: {target} at depth {depth}")
# Ensure target node exists in graph
self.graph.add_node(target, target_type) self.graph.add_node(target, target_type)
self._initialize_provider_states(target) self._initialize_provider_states(target)
@ -502,13 +473,11 @@ class Scanner:
provider_successful = True provider_successful = True
try: try:
# Query provider - now returns unified ProviderResult
provider_result = self._query_single_provider_unified(provider, target, is_ip, depth) provider_result = self._query_single_provider_unified(provider, target, is_ip, depth)
if provider_result is None: if provider_result is None:
provider_successful = False provider_successful = False
elif not self._is_stop_requested(): elif not self._is_stop_requested():
# Process the unified result
discovered, is_large_entity = self._process_provider_result_unified( discovered, is_large_entity = self._process_provider_result_unified(
target, provider, provider_result, depth target, provider, provider_result, depth
) )
@ -517,8 +486,6 @@ class Scanner:
else: else:
new_targets.update(discovered) new_targets.update(discovered)
self.graph.process_correlations_for_node(target) self.graph.process_correlations_for_node(target)
else:
print(f"Stop requested after processing results from {provider.get_name()}")
except Exception as e: except Exception as e:
provider_successful = False provider_successful = False
self._log_provider_error(target, provider.get_name(), str(e)) self._log_provider_error(target, provider.get_name(), str(e))
@ -527,58 +494,45 @@ class Scanner:
def _query_single_provider_unified(self, provider: BaseProvider, target: str, is_ip: bool, current_depth: int) -> Optional[ProviderResult]: def _query_single_provider_unified(self, provider: BaseProvider, target: str, is_ip: bool, current_depth: int) -> Optional[ProviderResult]:
""" """
- Query a single provider with stop signal checking, now returns ProviderResult.
+ Query a single provider with stop signal checking.
""" """
provider_name = provider.get_name() provider_name = provider.get_name()
start_time = datetime.now(timezone.utc) start_time = datetime.now(timezone.utc)
if self._is_stop_requested(): if self._is_stop_requested():
print(f"Stop requested before querying {provider_name} for {target}")
return None return None
print(f"Querying {provider_name} for {target}")
self.logger.logger.info(f"Attempting {provider_name} query for {target} at depth {current_depth}")
try: try:
# Query the provider - returns unified ProviderResult
if is_ip: if is_ip:
result = provider.query_ip(target) result = provider.query_ip(target)
else: else:
result = provider.query_domain(target) result = provider.query_domain(target)
if self._is_stop_requested(): if self._is_stop_requested():
print(f"Stop requested after querying {provider_name} for {target}")
return None return None
# Update provider state with relationship count (more meaningful than raw result count)
relationship_count = result.get_relationship_count() if result else 0 relationship_count = result.get_relationship_count() if result else 0
self._update_provider_state(target, provider_name, 'success', relationship_count, None, start_time) self._update_provider_state(target, provider_name, 'success', relationship_count, None, start_time)
print(f"{provider_name} returned {relationship_count} relationships for {target}")
return result return result
except Exception as e: except Exception as e:
self._update_provider_state(target, provider_name, 'failed', 0, str(e), start_time) self._update_provider_state(target, provider_name, 'failed', 0, str(e), start_time)
print(f"{provider_name} failed for {target}: {e}")
return None return None
def _process_provider_result_unified(self, target: str, provider: BaseProvider, def _process_provider_result_unified(self, target: str, provider: BaseProvider,
provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]: provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]:
""" """
Process a unified ProviderResult object to update the graph. Process a unified ProviderResult object to update the graph.
Returns (discovered_targets, is_large_entity).
""" """
provider_name = provider.get_name() provider_name = provider.get_name()
discovered_targets = set() discovered_targets = set()
if self._is_stop_requested(): if self._is_stop_requested():
print(f"Stop requested before processing results from {provider_name} for {target}")
return discovered_targets, False return discovered_targets, False
attributes_by_node = defaultdict(list) attributes_by_node = defaultdict(list)
for attribute in provider_result.attributes: for attribute in provider_result.attributes:
# Convert the StandardAttribute object to a dictionary that the frontend can use
attr_dict = { attr_dict = {
"name": attribute.name, "name": attribute.name,
"value": attribute.value, "value": attribute.value,
@ -589,10 +543,8 @@ class Scanner:
} }
attributes_by_node[attribute.target_node].append(attr_dict) attributes_by_node[attribute.target_node].append(attr_dict)
# Add attributes to nodes
for node_id, node_attributes_list in attributes_by_node.items(): for node_id, node_attributes_list in attributes_by_node.items():
if self.graph.graph.has_node(node_id): if self.graph.graph.has_node(node_id):
# Determine node type
if _is_valid_ip(node_id): if _is_valid_ip(node_id):
node_type = NodeType.IP node_type = NodeType.IP
elif node_id.startswith('AS') and node_id[2:].isdigit(): elif node_id.startswith('AS') and node_id[2:].isdigit():
@ -600,26 +552,19 @@ class Scanner:
else: else:
node_type = NodeType.DOMAIN node_type = NodeType.DOMAIN
# Add node with the list of attributes
self.graph.add_node(node_id, node_type, attributes=node_attributes_list) self.graph.add_node(node_id, node_type, attributes=node_attributes_list)
# Check for large entity based on relationship count
if provider_result.get_relationship_count() > self.config.large_entity_threshold: if provider_result.get_relationship_count() > self.config.large_entity_threshold:
print(f"Large entity detected: {provider_name} returned {provider_result.get_relationship_count()} relationships for {target}")
members = self._create_large_entity_from_provider_result(target, provider_name, provider_result, current_depth) members = self._create_large_entity_from_provider_result(target, provider_name, provider_result, current_depth)
return members, True return members, True
# Process relationships
for i, relationship in enumerate(provider_result.relationships): for i, relationship in enumerate(provider_result.relationships):
if i % 5 == 0 and self._is_stop_requested(): # Check periodically for stop if i % 5 == 0 and self._is_stop_requested():
print(f"Stop requested while processing relationships from {provider_name} for {target}")
break break
# Add nodes for relationship endpoints
source_node = relationship.source_node source_node = relationship.source_node
target_node = relationship.target_node target_node = relationship.target_node
# Determine node types
source_type = NodeType.IP if _is_valid_ip(source_node) else NodeType.DOMAIN source_type = NodeType.IP if _is_valid_ip(source_node) else NodeType.DOMAIN
if target_node.startswith('AS') and target_node[2:].isdigit(): if target_node.startswith('AS') and target_node[2:].isdigit():
target_type = NodeType.ASN target_type = NodeType.ASN
@ -628,11 +573,9 @@ class Scanner:
else: else:
target_type = NodeType.DOMAIN target_type = NodeType.DOMAIN
# Add nodes to graph
self.graph.add_node(source_node, source_type) self.graph.add_node(source_node, source_type)
self.graph.add_node(target_node, target_type) self.graph.add_node(target_node, target_type)
# Add edge to graph
if self.graph.add_edge( if self.graph.add_edge(
source_node, target_node, source_node, target_node,
relationship.relationship_type, relationship.relationship_type,
@ -640,9 +583,8 @@ class Scanner:
provider_name, provider_name,
relationship.raw_data relationship.raw_data
): ):
print(f"Added relationship: {source_node} -> {target_node} ({relationship.relationship_type})") pass
# Track discovered targets for further processing
if _is_valid_domain(target_node) or _is_valid_ip(target_node): if _is_valid_domain(target_node) or _is_valid_ip(target_node):
discovered_targets.add(target_node) discovered_targets.add(target_node)
@ -651,11 +593,10 @@ class Scanner:
def _create_large_entity_from_provider_result(self, source: str, provider_name: str, def _create_large_entity_from_provider_result(self, source: str, provider_name: str,
provider_result: ProviderResult, current_depth: int) -> Set[str]: provider_result: ProviderResult, current_depth: int) -> Set[str]:
""" """
        Create a large entity node from a ProviderResult and return the members for DNS processing. Create a large entity node from a ProviderResult and return its member targets.
""" """
entity_id = f"large_entity_{provider_name}_{hash(source) & 0x7FFFFFFF}" entity_id = f"large_entity_{provider_name}_{hash(source) & 0x7FFFFFFF}"
# Extract target nodes from relationships
targets = [rel.target_node for rel in provider_result.relationships] targets = [rel.target_node for rel in provider_result.relationships]
node_type = 'unknown' node_type = 'unknown'
@ -665,7 +606,6 @@ class Scanner:
elif _is_valid_ip(targets[0]): elif _is_valid_ip(targets[0]):
node_type = 'ip' node_type = 'ip'
# Create nodes in graph (they exist but are grouped)
for target in targets: for target in targets:
target_node_type = NodeType.DOMAIN if node_type == 'domain' else NodeType.IP target_node_type = NodeType.DOMAIN if node_type == 'domain' else NodeType.IP
self.graph.add_node(target, target_node_type) self.graph.add_node(target, target_node_type)
@ -694,106 +634,76 @@ class Scanner:
self.graph.add_node(entity_id, NodeType.LARGE_ENTITY, attributes=attributes_list, description=description) self.graph.add_node(entity_id, NodeType.LARGE_ENTITY, attributes=attributes_list, description=description)
# Create edge from source to large entity
if provider_result.relationships: if provider_result.relationships:
rel_type = provider_result.relationships[0].relationship_type rel_type = provider_result.relationships[0].relationship_type
self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name, self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name,
{'large_entity_info': f'Contains {len(targets)} {node_type}s'}) {'large_entity_info': f'Contains {len(targets)} {node_type}s'})
self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(targets)} targets from {provider_name}") self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(targets)} targets from {provider_name}")
print(f"Created large entity {entity_id} for {len(targets)} {node_type}s from {provider_name}")
return set(targets) return set(targets)
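For reviewers unfamiliar with the large-entity flow, here is a minimal sketch of the grouping decision, using the same threshold idea as config.large_entity_threshold. The Relationship dataclass, the dict-based graph, and the provider name "crtsh" are illustrative stand-ins only; the real code operates on ProviderResult objects and the project's graph manager.

from dataclasses import dataclass
from typing import Dict, List, Set

@dataclass
class Relationship:  # stand-in for the project's relationship objects
    source_node: str
    target_node: str
    relationship_type: str

def group_if_large(source: str, provider: str, rels: List[Relationship],
                   threshold: int, graph: Dict[str, dict]) -> Set[str]:
    """Collapse a provider's targets into one pseudo-node once a threshold is crossed."""
    if len(rels) <= threshold:
        return set()  # below threshold: caller adds nodes and edges individually
    targets = {r.target_node for r in rels}
    entity_id = f"large_entity_{provider}_{hash(source) & 0x7FFFFFFF}"
    graph[entity_id] = {"type": "large_entity", "members": sorted(targets)}
    return targets  # members are still returned so follow-up processing can continue

graph: Dict[str, dict] = {}
rels = [Relationship("example.com", f"sub{i}.example.com", "dns_record") for i in range(120)]
members = group_if_large("example.com", "crtsh", rels, threshold=100, graph=graph)
print(len(members), list(graph.keys()))  # 120 members grouped under a single pseudo-node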
def stop_scan(self) -> bool: def stop_scan(self) -> bool:
"""Request immediate scan termination with proper cleanup.""" """Request immediate scan termination with proper cleanup."""
try: try:
print("=== INITIATING IMMEDIATE SCAN TERMINATION ===")
self.logger.logger.info("Scan termination requested by user") self.logger.logger.info("Scan termination requested by user")
# **IMPROVED**: More aggressive stop signal setting
self._set_stop_signal() self._set_stop_signal()
self.status = ScanStatus.STOPPED self.status = ScanStatus.STOPPED
# **NEW**: Clear processing state immediately
with self.processing_lock: with self.processing_lock:
currently_processing_copy = self.currently_processing.copy()
self.currently_processing.clear() self.currently_processing.clear()
print(f"Cleared {len(currently_processing_copy)} currently processing targets: {currently_processing_copy}") self.currently_processing_display = []
# **IMPROVED**: Clear task queue and log what was discarded
discarded_tasks = [] discarded_tasks = []
while not self.task_queue.empty(): while not self.task_queue.empty():
discarded_tasks.append(self.task_queue.get()) discarded_tasks.append(self.task_queue.get())
self.task_queue = PriorityQueue() self.task_queue = PriorityQueue()
print(f"Discarded {len(discarded_tasks)} pending tasks")
# **IMPROVED**: Aggressively shut down executor
if self.executor: if self.executor:
print("Shutting down executor with immediate cancellation...")
try: try:
# Cancel all pending futures
self.executor.shutdown(wait=False, cancel_futures=True) self.executor.shutdown(wait=False, cancel_futures=True)
print("Executor shutdown completed")
except Exception as e: except Exception as e:
print(f"Error during executor shutdown: {e}") pass
# Immediately update GUI with stopped status
self._update_session_state() self._update_session_state()
print("Termination signals sent. The scan will stop as soon as possible.")
return True return True
except Exception as e: except Exception as e:
print(f"ERROR: Exception in stop_scan: {e}")
self.logger.logger.error(f"Error during scan termination: {e}") self.logger.logger.error(f"Error during scan termination: {e}")
traceback.print_exc() traceback.print_exc()
return False return False
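Because hard termination is easy to get wrong, a self-contained sketch of the same pattern follows: a stop flag that workers poll, a queue drain, and ThreadPoolExecutor.shutdown(wait=False, cancel_futures=True), which requires Python 3.9+. The names stop_event, worker, and stop_everything are illustrative and are not the scanner's actual attributes.

import threading
import time
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, PriorityQueue

stop_event = threading.Event()
task_queue: "PriorityQueue[tuple]" = PriorityQueue()
executor = ThreadPoolExecutor(max_workers=4)

def worker(task_id: int) -> str:
    for _ in range(20):
        if stop_event.is_set():  # cooperative check, like _is_stop_requested()
            return f"task {task_id} aborted"
        time.sleep(0.05)
    return f"task {task_id} done"

futures = [executor.submit(worker, i) for i in range(8)]

def stop_everything() -> None:
    stop_event.set()                 # 1. signal running workers to bail out
    while True:                      # 2. discard work that was queued but never submitted
        try:
            task_queue.get_nowait()
        except Empty:
            break
    # 3. cancel_futures=True (Python 3.9+) drops submitted-but-not-started futures
    executor.shutdown(wait=False, cancel_futures=True)

stop_everything()
print([f.result() for f in futures if not f.cancelled()])  # running tasks report "aborted"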
def extract_node_from_large_entity(self, large_entity_id: str, node_id_to_extract: str) -> bool: def extract_node_from_large_entity(self, large_entity_id: str, node_id_to_extract: str) -> bool:
""" """
        Extracts a node from a large entity, re-creates its original edge, and Extracts a node from a large entity, re-creates its source edge, and re-queues it for scanning.
re-queues it for full scanning.
""" """
if not self.graph.graph.has_node(large_entity_id): if not self.graph.graph.has_node(large_entity_id):
print(f"ERROR: Large entity {large_entity_id} not found.")
return False return False
# 1. Get the original source node that discovered the large entity
predecessors = list(self.graph.graph.predecessors(large_entity_id)) predecessors = list(self.graph.graph.predecessors(large_entity_id))
if not predecessors: if not predecessors:
print(f"ERROR: No source node found for large entity {large_entity_id}.")
return False return False
source_node_id = predecessors[0] source_node_id = predecessors[0]
# Get the original edge data to replicate it for the extracted node
original_edge_data = self.graph.graph.get_edge_data(source_node_id, large_entity_id) original_edge_data = self.graph.graph.get_edge_data(source_node_id, large_entity_id)
if not original_edge_data: if not original_edge_data:
print(f"ERROR: Could not find original edge data from {source_node_id} to {large_entity_id}.")
return False return False
# 2. Modify the graph data structure first
success = self.graph.extract_node_from_large_entity(large_entity_id, node_id_to_extract) success = self.graph.extract_node_from_large_entity(large_entity_id, node_id_to_extract)
if not success: if not success:
print(f"ERROR: Node {node_id_to_extract} could not be removed from {large_entity_id}'s attributes.")
return False return False
# 3. Create the direct edge from the original source to the newly extracted node
print(f"Re-creating direct edge from {source_node_id} to extracted node {node_id_to_extract}")
self.graph.add_edge( self.graph.add_edge(
source_id=source_node_id, source_id=source_node_id,
target_id=node_id_to_extract, target_id=node_id_to_extract,
relationship_type=original_edge_data.get('relationship_type', 'extracted_from_large_entity'), relationship_type=original_edge_data.get('relationship_type', 'extracted_from_large_entity'),
confidence_score=original_edge_data.get('confidence_score', 0.85), # Slightly lower confidence confidence_score=original_edge_data.get('confidence_score', 0.85),
source_provider=original_edge_data.get('source_provider', 'unknown'), source_provider=original_edge_data.get('source_provider', 'unknown'),
raw_data={'context': f'Extracted from large entity {large_entity_id}'} raw_data={'context': f'Extracted from large entity {large_entity_id}'}
) )
# 4. Re-queue the extracted node for full processing by all eligible providers
print(f"Re-queueing extracted node {node_id_to_extract} for full reconnaissance...")
is_ip = _is_valid_ip(node_id_to_extract) is_ip = _is_valid_ip(node_id_to_extract)
# FIX: Correctly retrieve discovery_depth from the list of attributes
large_entity_attributes = self.graph.graph.nodes[large_entity_id].get('attributes', []) large_entity_attributes = self.graph.graph.nodes[large_entity_id].get('attributes', [])
discovery_depth_attr = next((attr for attr in large_entity_attributes if attr.get('name') == 'discovery_depth'), None) discovery_depth_attr = next((attr for attr in large_entity_attributes if attr.get('name') == 'discovery_depth'), None)
current_depth = discovery_depth_attr['value'] if discovery_depth_attr else 0 current_depth = discovery_depth_attr['value'] if discovery_depth_attr else 0
@ -804,9 +714,7 @@ class Scanner:
self.task_queue.put((self._get_priority(provider_name), (provider_name, node_id_to_extract, current_depth))) self.task_queue.put((self._get_priority(provider_name), (provider_name, node_id_to_extract, current_depth)))
self.total_tasks_ever_enqueued += 1 self.total_tasks_ever_enqueued += 1
# 5. If the scanner is not running, we need to kickstart it to process this one item.
if self.status != ScanStatus.RUNNING: if self.status != ScanStatus.RUNNING:
print("Scanner is idle. Starting a mini-scan to process the extracted node.")
self.status = ScanStatus.RUNNING self.status = ScanStatus.RUNNING
self._update_session_state() self._update_session_state()
@ -818,25 +726,21 @@ class Scanner:
) )
self.scan_thread.start() self.scan_thread.start()
print(f"Successfully extracted and re-queued {node_id_to_extract} from {large_entity_id}.")
return True return True
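The discovery_depth lookup in this method depends on node attributes being stored as a list of {'name': ..., 'value': ...} dicts rather than a flat mapping. A small, hypothetical helper (not part of this diff) makes the pattern and its default handling explicit; the attribute names and values below are examples only.

from typing import Any, Dict, List, Optional

def get_attribute_value(attributes: List[Dict[str, Any]], name: str, default: Any = None) -> Any:
    """Return the 'value' of the first attribute dict whose 'name' matches, else a default."""
    match: Optional[Dict[str, Any]] = next(
        (attr for attr in attributes if attr.get('name') == name), None
    )
    return match['value'] if match is not None else default

attrs = [{'name': 'discovery_depth', 'value': 2}, {'name': 'record_count', 'value': 57}]
print(get_attribute_value(attrs, 'discovery_depth', 0))  # -> 2
print(get_attribute_value(attrs, 'missing', 0))          # -> 0 (falls back to the default)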
def _update_session_state(self) -> None: def _update_session_state(self) -> None:
""" """
Update the scanner state in Redis for GUI updates. Update the scanner state in Redis for GUI updates.
This ensures the web interface sees real-time updates.
""" """
if self.session_id: if self.session_id:
try: try:
from core.session_manager import session_manager from core.session_manager import session_manager
success = session_manager.update_session_scanner(self.session_id, self) session_manager.update_session_scanner(self.session_id, self)
if not success:
print(f"WARNING: Failed to update session state for {self.session_id}")
except Exception as e: except Exception as e:
print(f"ERROR: Failed to update session state: {e}") pass
def get_scan_status(self) -> Dict[str, Any]: def get_scan_status(self) -> Dict[str, Any]:
"""Get current scan status with processing information.""" """Get current scan status with comprehensive processing information."""
try: try:
with self.processing_lock: with self.processing_lock:
currently_processing_count = len(self.currently_processing) currently_processing_count = len(self.currently_processing)
@ -860,31 +764,18 @@ class Scanner:
'currently_processing': currently_processing_list[:5], 'currently_processing': currently_processing_list[:5],
'tasks_in_queue': self.task_queue.qsize(), 'tasks_in_queue': self.task_queue.qsize(),
'tasks_completed': self.indicators_completed, 'tasks_completed': self.indicators_completed,
'tasks_skipped': self.total_tasks_ever_enqueued - self.task_queue.qsize() - self.indicators_completed - self.tasks_re_enqueued, 'tasks_skipped': self.tasks_skipped,
'tasks_rescheduled': self.tasks_re_enqueued, 'tasks_rescheduled': self.tasks_re_enqueued,
} }
except Exception as e: except Exception as e:
print(f"ERROR: Exception in get_scan_status: {e}")
traceback.print_exc() traceback.print_exc()
return { return {
'status': 'error', 'status': 'error', 'target_domain': None, 'current_depth': 0, 'max_depth': 0,
'target_domain': None, 'current_indicator': '', 'indicators_processed': 0, 'indicators_completed': 0,
'current_depth': 0, 'tasks_re_enqueued': 0, 'progress_percentage': 0.0, 'enabled_providers': [],
'max_depth': 0, 'graph_statistics': {}, 'task_queue_size': 0, 'currently_processing_count': 0,
'current_indicator': '', 'currently_processing': [], 'tasks_in_queue': 0, 'tasks_completed': 0,
'indicators_processed': 0, 'tasks_skipped': 0, 'tasks_rescheduled': 0,
'indicators_completed': 0,
'tasks_re_enqueued': 0,
'progress_percentage': 0.0,
'enabled_providers': [],
'graph_statistics': {},
'task_queue_size': 0,
'currently_processing_count': 0,
'currently_processing': [],
'tasks_in_queue': 0,
'tasks_completed': 0,
'tasks_skipped': 0,
'tasks_rescheduled': 0,
} }
def _initialize_provider_states(self, target: str) -> None: def _initialize_provider_states(self, target: str) -> None:
@ -910,8 +801,6 @@ class Scanner:
if provider.get_eligibility().get(target_key): if provider.get_eligibility().get(target_key):
if not self._already_queried_provider(target, provider.get_name()): if not self._already_queried_provider(target, provider.get_name()):
eligible.append(provider) eligible.append(provider)
else:
print(f"Skipping {provider.get_name()} for {target} - already queried")
return eligible return eligible
@ -923,7 +812,6 @@ class Scanner:
node_data = self.graph.graph.nodes[target] node_data = self.graph.graph.nodes[target]
provider_states = node_data.get('metadata', {}).get('provider_states', {}) provider_states = node_data.get('metadata', {}).get('provider_states', {})
# A provider has been successfully queried if a state exists and its status is 'success'
provider_state = provider_states.get(provider_name) provider_state = provider_states.get(provider_name)
return provider_state is not None and provider_state.get('status') == 'success' return provider_state is not None and provider_state.get('status') == 'success'
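To make the success-only semantics concrete, here is a tiny sketch of the same check against the node metadata shape used above. The provider names and the 'error' status value are illustrative assumptions; only 'success' is confirmed by the surrounding code.

from typing import Any, Dict

def already_queried(node_data: Dict[str, Any], provider_name: str) -> bool:
    """True only when a provider state exists for the node and its status is 'success'."""
    provider_states = node_data.get('metadata', {}).get('provider_states', {})
    state = provider_states.get(provider_name)
    return state is not None and state.get('status') == 'success'

node = {'metadata': {'provider_states': {'providerA': {'status': 'success'},
                                         'providerB': {'status': 'error'}}}}
print(already_queried(node, 'providerA'))  # True  -> this provider is skipped for the node
print(already_queried(node, 'providerB'))  # False -> failed attempts remain eligible
print(already_queried(node, 'providerC'))  # False -> never queried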
@ -947,8 +835,6 @@ class Scanner:
'duration_ms': (datetime.now(timezone.utc) - start_time).total_seconds() * 1000 'duration_ms': (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
} }
self.logger.logger.info(f"Provider state updated: {target} -> {provider_name} -> {status} ({results_count} results)")
def _log_target_processing_error(self, target: str, error: str) -> None: def _log_target_processing_error(self, target: str, error: str) -> None:
"""Log target processing errors for forensic trail.""" """Log target processing errors for forensic trail."""
self.logger.logger.error(f"Target processing failed for {target}: {error}") self.logger.logger.error(f"Target processing failed for {target}: {error}")
@ -957,11 +843,6 @@ class Scanner:
"""Log provider query errors for forensic trail.""" """Log provider query errors for forensic trail."""
self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}") self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}")
def _log_no_eligible_providers(self, target: str, is_ip: bool) -> None:
"""Log when no providers are eligible for a target."""
target_type = 'IP' if is_ip else 'domain'
self.logger.logger.warning(f"No eligible providers for {target_type}: {target}")
def _calculate_progress(self) -> float: def _calculate_progress(self) -> float:
"""Calculate scan progress percentage based on task completion.""" """Calculate scan progress percentage based on task completion."""
if self.total_tasks_ever_enqueued == 0: if self.total_tasks_ever_enqueued == 0:
@ -996,13 +877,6 @@ class Scanner:
} }
return export_data return export_data
def get_provider_statistics(self) -> Dict[str, Dict[str, Any]]:
"""Get statistics for all providers with forensic information."""
stats = {}
for provider in self.providers:
stats[provider.get_name()] = provider.get_statistics()
return stats
def get_provider_info(self) -> Dict[str, Dict[str, Any]]: def get_provider_info(self) -> Dict[str, Dict[str, Any]]:
"""Get information about all available providers.""" """Get information about all available providers."""
info = {} info = {}
@ -1016,11 +890,9 @@ class Scanner:
attribute = getattr(module, attribute_name) attribute = getattr(module, attribute_name)
if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider: if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
provider_class = attribute provider_class = attribute
# Instantiate to get metadata, even if not fully configured
temp_provider = provider_class(name=attribute_name, session_config=self.config) temp_provider = provider_class(name=attribute_name, session_config=self.config)
provider_name = temp_provider.get_name() provider_name = temp_provider.get_name()
# Find the actual provider instance if it exists, to get live stats
live_provider = next((p for p in self.providers if p.get_name() == provider_name), None) live_provider = next((p for p in self.providers if p.get_name() == provider_name), None)
info[provider_name] = { info[provider_name] = {
@ -1031,6 +903,5 @@ class Scanner:
'rate_limit': self.config.get_rate_limit(provider_name), 'rate_limit': self.config.get_rate_limit(provider_name),
} }
except Exception as e: except Exception as e:
print(f"✗ Failed to get info for provider from {filename}: {e}")
traceback.print_exc() traceback.print_exc()
return info return info
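The provider-info loop above discovers provider classes by scanning module attributes for BaseProvider subclasses and instantiating them for metadata. A compressed, self-contained sketch of that discovery pattern is below; the stand-in classes and the get_name convention are assumptions for illustration, whereas the real code walks the provider files and instantiates with a session config.

import inspect
from typing import Dict, Type

class BaseProvider:
    """Stand-in for the real base class imported in scanner.py."""
    def get_name(self) -> str:
        return self.__class__.__name__.replace('Provider', '').lower()

class DnsProvider(BaseProvider):
    pass

class WhoisProvider(BaseProvider):
    pass

def discover_providers(namespace: Dict[str, object]) -> Dict[str, Type[BaseProvider]]:
    """Collect every concrete BaseProvider subclass found in a module namespace."""
    found: Dict[str, Type[BaseProvider]] = {}
    for attr in namespace.values():
        if inspect.isclass(attr) and issubclass(attr, BaseProvider) and attr is not BaseProvider:
            found[attr().get_name()] = attr
    return found

print(discover_providers(dict(globals())))  # -> {'dns': ..., 'whois': ...}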