diff --git a/app.py b/app.py index 894304e..89af52a 100644 --- a/app.py +++ b/app.py @@ -374,17 +374,7 @@ def get_providers(): # Get user-specific scanner user_session_id, scanner = get_user_scanner() - provider_stats = scanner.get_provider_statistics() - - # Add configuration information - provider_info = {} - for provider_name, stats in provider_stats.items(): - provider_info[provider_name] = { - 'statistics': stats, - 'enabled': config.is_provider_enabled(provider_name), - 'rate_limit': config.get_rate_limit(provider_name), - 'requires_api_key': provider_name in ['shodan'] - } + provider_info = scanner.get_provider_info() return jsonify({ 'success': True, @@ -409,7 +399,7 @@ def set_api_keys(): try: data = request.get_json() - if not data: + if data is None: return jsonify({ 'success': False, 'error': 'No API keys provided' @@ -421,16 +411,23 @@ def set_api_keys(): updated_providers = [] - for provider, api_key in data.items(): - if provider in ['shodan'] and api_key.strip(): - success = session_config.set_api_key(provider, api_key.strip()) - if success: - updated_providers.append(provider) + # Iterate over the API keys provided in the request data + for provider_name, api_key in data.items(): + # This allows us to both set and clear keys. The config + # handles enabling/disabling based on if the key is empty. + api_key_value = str(api_key or '').strip() + success = session_config.set_api_key(provider_name.lower(), api_key_value) + + if success: + updated_providers.append(provider_name) if updated_providers: - # Reinitialize scanner providers for this session only + # Reinitialize scanner providers to apply the new keys scanner._initialize_providers() + # Persist the updated scanner object back to the user's session + session_manager.update_session_scanner(user_session_id, scanner) + return jsonify({ 'success': True, 'message': f'API keys updated for session {user_session_id}: {", ".join(updated_providers)}', @@ -440,7 +437,7 @@ def set_api_keys(): else: return jsonify({ 'success': False, - 'error': 'No valid API keys were provided' + 'error': 'No valid API keys were provided or provider names were incorrect.' }), 400 except Exception as e: @@ -450,14 +447,6 @@ def set_api_keys(): 'success': False, 'error': f'Internal server error: {str(e)}' }), 500 - - except Exception as e: - print(f"ERROR: Exception in set_api_keys endpoint: {e}") - traceback.print_exc() - return jsonify({ - 'success': False, - 'error': f'Internal server error: {str(e)}' - }), 500 @app.route('/api/session/info', methods=['GET']) diff --git a/core/scanner.py b/core/scanner.py index 54304c1..e53f0c1 100644 --- a/core/scanner.py +++ b/core/scanner.py @@ -3,6 +3,8 @@ import threading import traceback import time +import os +import importlib from typing import List, Set, Dict, Any, Tuple from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError, Future from collections import defaultdict, deque @@ -11,9 +13,7 @@ from datetime import datetime, timezone from core.graph_manager import GraphManager, NodeType, RelationshipType from core.logger import get_forensic_logger, new_session from utils.helpers import _is_valid_ip, _is_valid_domain -from providers.crtsh_provider import CrtShProvider -from providers.dns_provider import DNSProvider -from providers.shodan_provider import ShodanProvider +from providers.base_provider import BaseProvider class ScanStatus: @@ -61,13 +61,6 @@ class Scanner: self.max_workers = self.config.max_concurrent_requests self.executor = None - # Provider eligibility mapping - self.provider_eligibility = { - 'dns': {'domains': True, 'ips': True}, - 'crtsh': {'domains': True, 'ips': False}, - 'shodan': {'domains': True, 'ips': True} - } - # Initialize providers with session config print("Calling _initialize_providers with session config...") self._initialize_providers() @@ -163,25 +156,27 @@ class Scanner: self.providers = [] print("Initializing providers with session config...") - # Provider classes mapping - provider_classes = { - 'dns': DNSProvider, - 'crtsh': CrtShProvider, - 'shodan': ShodanProvider - } - - for provider_name, provider_class in provider_classes.items(): - if self.config.is_provider_enabled(provider_name): + provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers') + for filename in os.listdir(provider_dir): + if filename.endswith('_provider.py') and not filename.startswith('base'): + module_name = f"providers.{filename[:-3]}" try: - provider = provider_class(session_config=self.config) - if provider.is_available(): - provider.set_stop_event(self.stop_event) - self.providers.append(provider) - print(f"✓ {provider_name.title()} provider initialized successfully for session") - else: - print(f"✗ {provider_name.title()} provider is not available") + module = importlib.import_module(module_name) + for attribute_name in dir(module): + attribute = getattr(module, attribute_name) + if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider: + provider_class = attribute + provider_name = provider_class(session_config=self.config).get_name() + if self.config.is_provider_enabled(provider_name): + provider = provider_class(session_config=self.config) + if provider.is_available(): + provider.set_stop_event(self.stop_event) + self.providers.append(provider) + print(f"✓ {provider.get_display_name()} provider initialized successfully for session") + else: + print(f"✗ {provider.get_display_name()} provider is not available") except Exception as e: - print(f"✗ Failed to initialize {provider_name.title()} provider: {e}") + print(f"✗ Failed to initialize provider from {filename}: {e}") traceback.print_exc() print(f"Initialized {len(self.providers)} providers for session") @@ -417,13 +412,11 @@ class Scanner: target_key = 'ips' if is_ip else 'domains' for provider in self.providers: - provider_name = provider.get_name() - if provider_name in self.provider_eligibility: - if self.provider_eligibility[provider_name][target_key]: - if not self._already_queried_provider(target, provider_name): - eligible.append(provider) - else: - print(f"Skipping {provider_name} for {target} - already queried") + if provider.get_eligibility().get(target_key): + if not self._already_queried_provider(target, provider.get_name()): + eligible.append(provider) + else: + print(f"Skipping {provider.get_name()} for {target} - already queried") return eligible @@ -740,4 +733,36 @@ class Scanner: stats = {} for provider in self.providers: stats[provider.get_name()] = provider.get_statistics() - return stats \ No newline at end of file + return stats + + def get_provider_info(self) -> Dict[str, Dict[str, Any]]: + """Get information about all available providers.""" + info = {} + provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers') + for filename in os.listdir(provider_dir): + if filename.endswith('_provider.py') and not filename.startswith('base'): + module_name = f"providers.{filename[:-3]}" + try: + module = importlib.import_module(module_name) + for attribute_name in dir(module): + attribute = getattr(module, attribute_name) + if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider: + provider_class = attribute + # Instantiate to get metadata, even if not fully configured + temp_provider = provider_class(session_config=self.config) + provider_name = temp_provider.get_name() + + # Find the actual provider instance if it exists, to get live stats + live_provider = next((p for p in self.providers if p.get_name() == provider_name), None) + + info[provider_name] = { + 'display_name': temp_provider.get_display_name(), + 'requires_api_key': temp_provider.requires_api_key(), + 'statistics': live_provider.get_statistics() if live_provider else temp_provider.get_statistics(), + 'enabled': self.config.is_provider_enabled(provider_name), + 'rate_limit': self.config.get_rate_limit(provider_name), + } + except Exception as e: + print(f"✗ Failed to get info for provider from {filename}: {e}") + traceback.print_exc() + return info \ No newline at end of file diff --git a/providers/__init__.py b/providers/__init__.py index b56306c..a31a586 100644 --- a/providers/__init__.py +++ b/providers/__init__.py @@ -16,4 +16,4 @@ __all__ = [ 'ShodanProvider' ] -__version__ = "1.0.0-phase2" \ No newline at end of file +__version__ = "0.0.0-rc" \ No newline at end of file diff --git a/providers/base_provider.py b/providers/base_provider.py index 5bb4ccd..03f497e 100644 --- a/providers/base_provider.py +++ b/providers/base_provider.py @@ -126,6 +126,21 @@ class BaseProvider(ABC): """Return the provider name.""" pass + @abstractmethod + def get_display_name(self) -> str: + """Return the provider display name for the UI.""" + pass + + @abstractmethod + def requires_api_key(self) -> bool: + """Return True if the provider requires an API key.""" + pass + + @abstractmethod + def get_eligibility(self) -> Dict[str, bool]: + """Return a dictionary indicating if the provider can query domains and/or IPs.""" + pass + @abstractmethod def is_available(self) -> bool: """Check if the provider is available and properly configured.""" diff --git a/providers/crtsh_provider.py b/providers/crtsh_provider.py index 4b01864..b20bbff 100644 --- a/providers/crtsh_provider.py +++ b/providers/crtsh_provider.py @@ -36,6 +36,18 @@ class CrtShProvider(BaseProvider): """Return the provider name.""" return "crtsh" + def get_display_name(self) -> str: + """Return the provider display name for the UI.""" + return "crt.sh" + + def requires_api_key(self) -> bool: + """Return True if the provider requires an API key.""" + return False + + def get_eligibility(self) -> Dict[str, bool]: + """Return a dictionary indicating if the provider can query domains and/or IPs.""" + return {'domains': True, 'ips': False} + def is_available(self) -> bool: """ Check if the provider is configured to be used. diff --git a/providers/dns_provider.py b/providers/dns_provider.py index 11cb578..d3236cc 100644 --- a/providers/dns_provider.py +++ b/providers/dns_provider.py @@ -33,6 +33,18 @@ class DNSProvider(BaseProvider): """Return the provider name.""" return "dns" + def get_display_name(self) -> str: + """Return the provider display name for the UI.""" + return "DNS" + + def requires_api_key(self) -> bool: + """Return True if the provider requires an API key.""" + return False + + def get_eligibility(self) -> Dict[str, bool]: + """Return a dictionary indicating if the provider can query domains and/or IPs.""" + return {'domains': True, 'ips': True} + def is_available(self) -> bool: """DNS is always available - no API key required.""" return True diff --git a/providers/shodan_provider.py b/providers/shodan_provider.py index 4bc9d4a..306fc0d 100644 --- a/providers/shodan_provider.py +++ b/providers/shodan_provider.py @@ -15,7 +15,7 @@ class ShodanProvider(BaseProvider): Provider for querying Shodan API for IP address and hostname information. Now uses session-specific API keys. """ - + def __init__(self, session_config=None): """Initialize Shodan provider with session-specific configuration.""" super().__init__( @@ -26,32 +26,43 @@ class ShodanProvider(BaseProvider): ) self.base_url = "https://api.shodan.io" self.api_key = self.config.get_api_key('shodan') - + def is_available(self) -> bool: """Check if Shodan provider is available (has valid API key in this session).""" return self.api_key is not None and len(self.api_key.strip()) > 0 - + def get_name(self) -> str: """Return the provider name.""" return "shodan" - + def get_display_name(self) -> str: + """Return the provider display name for the UI.""" + return "shodan" + + def requires_api_key(self) -> bool: + """Return True if the provider requires an API key.""" + return True + + def get_eligibility(self) -> Dict[str, bool]: + """Return a dictionary indicating if the provider can query domains and/or IPs.""" + return {'domains': True, 'ips': True} + def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]: """ Query Shodan for information about a domain. Uses Shodan's hostname search to find associated IPs. - + Args: domain: Domain to investigate - + Returns: List of relationships discovered from Shodan data """ if not _is_valid_domain(domain) or not self.is_available(): return [] - + relationships = [] - + try: # Search for hostname in Shodan search_query = f"hostname:{domain}" @@ -61,22 +72,22 @@ class ShodanProvider(BaseProvider): 'query': search_query, 'minify': True # Get minimal data to reduce bandwidth } - + response = self.make_request(url, method="GET", params=params, target_indicator=domain) - + if not response or response.status_code != 200: return [] - + data = response.json() - + if 'matches' not in data: return [] - + # Process search results for match in data['matches']: ip_address = match.get('ip_str') hostnames = match.get('hostnames', []) - + if ip_address and domain in hostnames: raw_data = { 'ip_address': ip_address, @@ -88,7 +99,7 @@ class ShodanProvider(BaseProvider): 'ports': match.get('ports', []), 'last_update': match.get('last_update', '') } - + relationships.append(( domain, ip_address, @@ -96,7 +107,7 @@ class ShodanProvider(BaseProvider): RelationshipType.A_RECORD.default_confidence, raw_data )) - + self.log_relationship_discovery( source_node=domain, target_node=ip_address, @@ -105,7 +116,7 @@ class ShodanProvider(BaseProvider): raw_data=raw_data, discovery_method="shodan_hostname_search" ) - + # Also create relationships to other hostnames on the same IP for hostname in hostnames: if hostname != domain and _is_valid_domain(hostname): @@ -114,7 +125,7 @@ class ShodanProvider(BaseProvider): 'all_hostnames': hostnames, 'discovery_context': 'shared_hosting' } - + relationships.append(( domain, hostname, @@ -122,7 +133,7 @@ class ShodanProvider(BaseProvider): 0.6, # Lower confidence for shared hosting hostname_raw_data )) - + self.log_relationship_discovery( source_node=domain, target_node=hostname, @@ -131,39 +142,39 @@ class ShodanProvider(BaseProvider): raw_data=hostname_raw_data, discovery_method="shodan_shared_hosting" ) - + except json.JSONDecodeError as e: self.logger.logger.error(f"Failed to parse JSON response from Shodan: {e}") - + return relationships - + def query_ip(self, ip: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]: """ Query Shodan for information about an IP address. - + Args: ip: IP address to investigate - + Returns: List of relationships discovered from Shodan IP data """ if not _is_valid_ip(ip) or not self.is_available(): return [] - + relationships = [] - + try: # Query Shodan host information url = f"{self.base_url}/shodan/host/{ip}" params = {'key': self.api_key} - + response = self.make_request(url, method="GET", params=params, target_indicator=ip) - + if not response or response.status_code != 200: return [] - + data = response.json() - + # Extract hostname relationships hostnames = data.get('hostnames', []) for hostname in hostnames: @@ -180,7 +191,7 @@ class ShodanProvider(BaseProvider): 'last_update': data.get('last_update', ''), 'os': data.get('os', '') } - + relationships.append(( ip, hostname, @@ -188,7 +199,7 @@ class ShodanProvider(BaseProvider): RelationshipType.A_RECORD.default_confidence, raw_data )) - + self.log_relationship_discovery( source_node=ip, target_node=hostname, @@ -197,19 +208,25 @@ class ShodanProvider(BaseProvider): raw_data=raw_data, discovery_method="shodan_host_lookup" ) - + # Extract ASN relationship if available asn = data.get('asn') if asn: - asn_name = f"AS{asn}" - + # Ensure the ASN starts with "AS" + if isinstance(asn, str) and asn.startswith('AS'): + asn_name = asn + asn_number = asn[2:] + else: + asn_name = f"AS{asn}" + asn_number = str(asn) + asn_raw_data = { 'ip_address': ip, - 'asn': asn, + 'asn': asn_number, 'isp': data.get('isp', ''), 'org': data.get('org', '') } - + relationships.append(( ip, asn_name, @@ -217,7 +234,7 @@ class ShodanProvider(BaseProvider): RelationshipType.ASN_MEMBERSHIP.default_confidence, asn_raw_data )) - + self.log_relationship_discovery( source_node=ip, target_node=asn_name, @@ -226,25 +243,25 @@ class ShodanProvider(BaseProvider): raw_data=asn_raw_data, discovery_method="shodan_asn_lookup" ) - + except json.JSONDecodeError as e: self.logger.logger.error(f"Failed to parse JSON response from Shodan: {e}") return relationships - + def search_by_organization(self, org_name: str) -> List[Dict[str, Any]]: """ Search Shodan for hosts belonging to a specific organization. - + Args: org_name: Organization name to search for - + Returns: List of host information dictionaries """ if not self.is_available(): return [] - + try: search_query = f"org:\"{org_name}\"" url = f"{self.base_url}/shodan/host/search" @@ -253,42 +270,42 @@ class ShodanProvider(BaseProvider): 'query': search_query, 'minify': True } - + response = self.make_request(url, method="GET", params=params, target_indicator=org_name) - + if response and response.status_code == 200: data = response.json() return data.get('matches', []) - + except Exception as e: self.logger.logger.error(f"Error searching Shodan by organization {org_name}: {e}") - + return [] - + def get_host_services(self, ip: str) -> List[Dict[str, Any]]: """ Get service information for a specific IP address. - + Args: ip: IP address to query - + Returns: List of service information dictionaries """ if not _is_valid_ip(ip) or not self.is_available(): return [] - + try: url = f"{self.base_url}/shodan/host/{ip}" params = {'key': self.api_key} - + response = self.make_request(url, method="GET", params=params, target_indicator=ip) - + if response and response.status_code == 200: data = response.json() return data.get('data', []) # Service banners - + except Exception as e: self.logger.logger.error(f"Error getting Shodan services for IP {ip}: {e}") - + return [] \ No newline at end of file diff --git a/static/js/graph.js b/static/js/graph.js index 2afa598..6ec5e2d 100644 --- a/static/js/graph.js +++ b/static/js/graph.js @@ -13,7 +13,6 @@ class GraphManager { this.currentLayout = 'physics'; this.nodeInfoPopup = null; - // Enhanced graph options for Phase 2 this.options = { nodes: { shape: 'dot', @@ -214,20 +213,7 @@ class GraphManager { } }); - this.network.on('blurNode', (params) => { - this.hideNodeInfoPopup(); - this.clearHoverHighlights(); - }); - - // Double-click to focus on node - this.network.on('doubleClick', (params) => { - if (params.nodes.length > 0) { - const nodeId = params.nodes[0]; - this.focusOnNode(nodeId); - } - }); - - // Context menu (right-click) + // TODO Context menu (right-click) this.network.on('oncontext', (params) => { params.event.preventDefault(); if (params.nodes.length > 0) { diff --git a/static/js/main.js b/static/js/main.js index 96b109b..b7391c2 100644 --- a/static/js/main.js +++ b/static/js/main.js @@ -12,10 +12,8 @@ class DNSReconApp { this.pollInterval = null; this.currentSessionId = null; - // UI Elements this.elements = {}; - // Application state this.isScanning = false; this.lastGraphUpdate = null; @@ -80,7 +78,7 @@ class DNSReconApp { // API Key Modal elements apiKeyModal: document.getElementById('api-key-modal'), apiKeyModalClose: document.getElementById('api-key-modal-close'), - shodanApiKey: document.getElementById('shodan-api-key'), + apiKeyInputs: document.getElementById('api-key-inputs'), saveApiKeys: document.getElementById('save-api-keys'), resetApiKeys: document.getElementById('reset-api-keys'), @@ -732,6 +730,7 @@ class DNSReconApp { if (response.success) { this.updateProviderDisplay(response.providers); + this.buildApiKeyModal(response.providers); console.log('Providers loaded successfully'); } @@ -766,7 +765,7 @@ class DNSReconApp { providerItem.innerHTML = `
Provides infrastructure context and service information.
+ `; + } else { + // If the API key is not set + inputGroup.innerHTML = ` + + +Provides infrastructure context and service information.
+ `; + } + this.elements.apiKeyInputs.appendChild(inputGroup); + } + } + + // Add event listeners for the new clear buttons + this.elements.apiKeyInputs.querySelectorAll('.clear-api-key-btn').forEach(button => { + button.addEventListener('click', (e) => { + const provider = e.target.dataset.provider; + this.clearApiKey(provider); + }); + }); + + if (!hasApiKeyProviders) { + this.elements.apiKeyInputs.innerHTML = 'No providers require API keys.
'; + } + } + + /** + * Clear an API key for a specific provider + * @param {string} provider The name of the provider to clear the API key for + */ + async clearApiKey(provider) { + try { + const response = await this.apiCall('/api/config/api-keys', 'POST', { [provider]: '' }); + if (response.success) { + this.showSuccess(`API key for ${provider} has been cleared.`); + this.loadProviders(); // This will rebuild the modal with the updated state + } else { + throw new Error(response.error || 'Failed to clear API key'); + } + } catch (error) { + this.showError(`Error clearing API key: ${error.message}`); + } + } } // Add CSS animations for message toasts diff --git a/templates/index.html b/templates/index.html index 1e9bdda..7a189a3 100644 --- a/templates/index.html +++ b/templates/index.html @@ -186,7 +186,7 @@