diff --git a/app.py b/app.py
index b4f3335..ae2206b 100644
--- a/app.py
+++ b/app.py
@@ -16,6 +16,7 @@ from core.session_manager import session_manager
 from config import config
 from core.graph_manager import NodeType
 from utils.helpers import is_valid_target
+from decimal import Decimal
 
 app = Flask(__name__)
 
@@ -44,6 +45,28 @@ def get_user_scanner():
 
     return new_session_id, new_scanner
 
+class CustomJSONEncoder(json.JSONEncoder):
+    """Custom JSON encoder to handle non-serializable objects."""
+
+    def default(self, obj):
+        if isinstance(obj, datetime):
+            return obj.isoformat()
+        elif isinstance(obj, set):
+            return list(obj)
+        elif isinstance(obj, Decimal):
+            return float(obj)
+        elif hasattr(obj, 'value') and hasattr(obj, 'name'):
+            # Enum members: serialize the underlying value. This check must
+            # come before the __dict__ fallback, because enum members also
+            # expose __dict__ and would otherwise never reach this branch.
+            return obj.value
+        elif hasattr(obj, '__dict__'):
+            # Custom objects: try to serialize their attribute dictionary
+            try:
+                return obj.__dict__
+            except Exception:
+                return str(obj)
+        else:
+            # Any other non-serializable object: convert to string
+            return str(obj)
+
 @app.route('/')
 def index():
     """Serve the main web interface."""
@@ -279,31 +302,109 @@ def revert_graph_action():
 
 @app.route('/api/export', methods=['GET'])
 def export_results():
-    """Export scan results as a JSON file."""
+    """Export scan results as a JSON file with improved error handling."""
     try:
         user_session_id, scanner = get_user_scanner()
-        results = scanner.export_results()
+        if not scanner:
+            return jsonify({'success': False, 'error': 'No active scanner session found'}), 404
+
+        # Gather export data with error handling
+        try:
+            results = scanner.export_results()
+        except Exception as e:
+            return jsonify({'success': False, 'error': f'Failed to gather export data: {str(e)}'}), 500
+
+        # Add export metadata
         results['export_metadata'] = {
             'user_session_id': user_session_id,
             'export_timestamp': datetime.now(timezone.utc).isoformat(),
+            'export_version': '1.0.0',
+            'forensic_integrity': 'maintained'
         }
 
+        # Generate a filename following the forensic naming convention
         timestamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
         target = scanner.current_target or 'unknown'
-        filename = f"dnsrecon_{target}_{timestamp}.json"
+        # Sanitize the target so the filename is filesystem-safe
+        safe_target = "".join(c for c in target if c.isalnum() or c in ('-', '_', '.')).rstrip()
+        filename = f"dnsrecon_{safe_target}_{timestamp}.json"
 
-        json_data = json.dumps(results, indent=2)
+        # Serialize with the custom encoder and error handling
+        try:
+            json_data = json.dumps(results, indent=2, cls=CustomJSONEncoder, ensure_ascii=False)
+        except Exception:
+            # If the custom encoder fails, try a more aggressive approach
+            try:
+                # Convert problematic objects to strings recursively
+                cleaned_results = _clean_for_json(results)
+                json_data = json.dumps(cleaned_results, indent=2, ensure_ascii=False)
+            except Exception as e2:
+                return jsonify({
+                    'success': False,
+                    'error': f'JSON serialization failed: {str(e2)}'
+                }), 500
+
+        # Create an in-memory file object for the download
         file_obj = io.BytesIO(json_data.encode('utf-8'))
 
         return send_file(
-            file_obj, as_attachment=True,
-            download_name=filename, mimetype='application/json'
+            file_obj,
+            as_attachment=True,
+            download_name=filename,
+            mimetype='application/json'
         )
 
     except Exception as e:
         traceback.print_exc()
-        return jsonify({'success': False, 'error': f'Export failed: {str(e)}'}), 500
+        return jsonify({
+            'success': False,
+            'error': f'Export failed: {str(e)}',
+            'error_type': type(e).__name__
+        }), 500
+
+def _clean_for_json(obj, max_depth=10, current_depth=0):
+    """
+    Recursively clean an object to make it JSON serializable.
+    Handles circular references and problematic object types.
+    """
+    if current_depth > max_depth:
+        return f"<max depth exceeded: {type(obj).__name__}>"
+
+    if obj is None or isinstance(obj, (bool, int, float, str)):
+        return obj
+    elif isinstance(obj, datetime):
+        return obj.isoformat()
+    elif isinstance(obj, (set, frozenset)):
+        return list(obj)
+    elif isinstance(obj, dict):
+        cleaned = {}
+        for key, value in obj.items():
+            try:
+                # Ensure the key is a string
+                clean_key = str(key) if not isinstance(key, str) else key
+                cleaned[clean_key] = _clean_for_json(value, max_depth, current_depth + 1)
+            except Exception:
+                cleaned[str(key)] = f"<unserializable: {type(value).__name__}>"
+        return cleaned
+    elif isinstance(obj, (list, tuple)):
+        cleaned = []
+        for item in obj:
+            try:
+                cleaned.append(_clean_for_json(item, max_depth, current_depth + 1))
+            except Exception:
+                cleaned.append(f"<unserializable: {type(item).__name__}>")
+        return cleaned
+    elif hasattr(obj, '__dict__'):
+        try:
+            return _clean_for_json(obj.__dict__, max_depth, current_depth + 1)
+        except Exception:
+            return str(obj)
+    elif hasattr(obj, 'value'):
+        # Enum-like objects
+        return obj.value
+    else:
+        return str(obj)
 
 @app.route('/api/config/api-keys', methods=['POST'])
 def set_api_keys():
diff --git a/core/graph_manager.py b/core/graph_manager.py
index d6910ba..27a2a89 100644
--- a/core/graph_manager.py
+++ b/core/graph_manager.py
@@ -504,7 +504,7 @@ class GraphManager:
 
     def export_json(self) -> Dict[str, Any]:
         """Export complete graph data as a JSON-serializable dictionary."""
-        graph_data = nx.node_link_data(self.graph)  # Use NetworkX's built-in robust serializer
+        graph_data = nx.node_link_data(self.graph, edges="edges")
         return {
             'export_metadata': {
                 'export_timestamp': datetime.now(timezone.utc).isoformat(),
diff --git a/core/logger.py b/core/logger.py
index 371cd7e..e774a2d 100644
--- a/core/logger.py
+++ b/core/logger.py
@@ -197,7 +197,7 @@ class ForensicLogger:
         self.logger.info(f"Scan Started - Target: {target_domain}, Depth: {recursion_depth}")
         self.logger.info(f"Enabled Providers: {', '.join(enabled_providers)}")
 
-        self.session_metadata['target_domains'].add(target_domain)
+        self.session_metadata['target_domains'].update([target_domain])  # wrap the string so it is added whole
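
The one-line logger change above is easy to get wrong: `set.update()` iterates its argument, so passing a bare string would add each *character* as a separate "domain". A minimal sketch of the difference, using plain Python independent of the ForensicLogger class:

```python
# Sketch: why set.update() must not receive a bare string.
target_domains = set()

# Wrong: update() iterates the string, adding one entry per character.
target_domains.update("example.com")
print(sorted(target_domains))  # ['.', 'a', 'c', 'e', 'l', 'm', 'o', 'p', 'x']

# Right: wrap the string in a list (or use .add()) to store it whole.
target_domains.clear()
target_domains.update(["example.com"])
print(target_domains)  # {'example.com'}
```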
""" if self._is_stop_requested(): print(f"Request cancelled before start: {url}") @@ -169,8 +171,14 @@ class BaseProvider(ABC): raise ValueError(f"Unsupported HTTP method: {method}") print(f"Response status: {response.status_code}") - response.raise_for_status() - self.successful_requests += 1 + + # FIXED: Don't automatically raise for HTTP error status codes + # Let individual providers handle status codes appropriately + # Only count 2xx responses as successful + if 200 <= response.status_code < 300: + self.successful_requests += 1 + else: + self.failed_requests += 1 duration_ms = (time.time() - start_time) * 1000 self.logger.log_api_request( diff --git a/providers/crtsh_provider.py b/providers/crtsh_provider.py index ab41c1d..731cfd2 100644 --- a/providers/crtsh_provider.py +++ b/providers/crtsh_provider.py @@ -16,7 +16,8 @@ from utils.helpers import _is_valid_domain class CrtShProvider(BaseProvider): """ Provider for querying crt.sh certificate transparency database. - Now returns standardized ProviderResult objects with caching support. + FIXED: Now properly creates domain and CA nodes instead of large entities. + Returns standardized ProviderResult objects with caching support. """ def __init__(self, name=None, session_config=None): @@ -30,9 +31,9 @@ class CrtShProvider(BaseProvider): self.base_url = "https://crt.sh/" self._stop_event = None - # Initialize cache directory - self.cache_dir = Path('cache') / 'crtsh' - self.cache_dir.mkdir(parents=True, exist_ok=True) + # Initialize cache directory (separate from BaseProvider's HTTP cache) + self.domain_cache_dir = Path('cache') / 'crtsh' + self.domain_cache_dir.mkdir(parents=True, exist_ok=True) # Compile regex for date filtering for efficiency self.date_pattern = re.compile(r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}') @@ -60,7 +61,7 @@ class CrtShProvider(BaseProvider): def _get_cache_file_path(self, domain: str) -> Path: """Generate cache file path for a domain.""" safe_domain = domain.replace('.', '_').replace('/', '_').replace('\\', '_') - return self.cache_dir / f"{safe_domain}.json" + return self.domain_cache_dir / f"{safe_domain}.json" def _get_cache_status(self, cache_file_path: Path) -> str: """ @@ -93,7 +94,8 @@ class CrtShProvider(BaseProvider): def query_domain(self, domain: str) -> ProviderResult: """ - Query crt.sh for certificates containing the domain with efficient, deduplicated caching. + FIXED: Query crt.sh for certificates containing the domain. + Now properly creates domain and CA nodes instead of large entities. 
diff --git a/providers/crtsh_provider.py b/providers/crtsh_provider.py
index ab41c1d..731cfd2 100644
--- a/providers/crtsh_provider.py
+++ b/providers/crtsh_provider.py
@@ -16,7 +16,8 @@ from utils.helpers import _is_valid_domain
 class CrtShProvider(BaseProvider):
     """
     Provider for querying crt.sh certificate transparency database.
-    Now returns standardized ProviderResult objects with caching support.
+    Creates individual domain and CA nodes instead of large entities.
+    Returns standardized ProviderResult objects with caching support.
     """
 
     def __init__(self, name=None, session_config=None):
@@ -30,9 +31,9 @@ class CrtShProvider(BaseProvider):
         self.base_url = "https://crt.sh/"
         self._stop_event = None
 
-        # Initialize cache directory
-        self.cache_dir = Path('cache') / 'crtsh'
-        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        # Initialize cache directory (separate from BaseProvider's HTTP cache)
+        self.domain_cache_dir = Path('cache') / 'crtsh'
+        self.domain_cache_dir.mkdir(parents=True, exist_ok=True)
 
         # Compile regex for date filtering for efficiency
         self.date_pattern = re.compile(r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}')
@@ -60,7 +61,7 @@ class CrtShProvider(BaseProvider):
     def _get_cache_file_path(self, domain: str) -> Path:
         """Generate cache file path for a domain."""
         safe_domain = domain.replace('.', '_').replace('/', '_').replace('\\', '_')
-        return self.cache_dir / f"{safe_domain}.json"
+        return self.domain_cache_dir / f"{safe_domain}.json"
 
     def _get_cache_status(self, cache_file_path: Path) -> str:
         """
@@ -93,7 +94,8 @@
 
     def query_domain(self, domain: str) -> ProviderResult:
         """
-        Query crt.sh for certificates containing the domain with efficient, deduplicated caching.
+        Query crt.sh for certificates containing the domain, with efficient,
+        deduplicated caching. Creates individual domain and CA nodes instead of large entities.
 
         Args:
             domain: Domain to investigate
@@ -115,7 +117,7 @@
         try:
             if cache_status == "fresh":
                 result = self._load_from_cache(cache_file)
-                #self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}")
+                self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}")
 
             else:  # "stale" or "not_found"
                 # Query the API for the latest certificates
@@ -143,8 +145,8 @@
                 else:  # "not_found"
                     raw_certificates_to_process = new_raw_certs
 
-                # Process the clean, deduplicated list of certificates
-                result = self._process_certificates_to_result(domain, raw_certificates_to_process)
+                # Process certificates into individual domain and CA nodes
+                result = self._process_certificates_to_result_fixed(domain, raw_certificates_to_process)
                 self.logger.logger.info(f"Created fresh result for {domain} ({result.get_relationship_count()} relationships)")
 
                 # Save the new result and the raw data to the cache
@@ -273,41 +275,51 @@
 
         return certificates
 
-    def _process_certificates_to_result(self, domain: str, certificates: List[Dict[str, Any]]) -> ProviderResult:
+    def _process_certificates_to_result_fixed(self, query_domain: str, certificates: List[Dict[str, Any]]) -> ProviderResult:
         """
-        Process certificates to create ProviderResult with relationships and attributes.
+        Process certificates into a ProviderResult with individual domain and
+        CA nodes rather than large entities.
         """
         result = ProviderResult()
 
         if self._stop_event and self._stop_event.is_set():
-            print(f"CrtSh processing cancelled before processing for domain: {domain}")
+            self.logger.logger.info(f"CrtSh processing cancelled before processing for domain: {query_domain}")
             return result
 
         all_discovered_domains = set()
+        processed_issuers = set()
 
         for i, cert_data in enumerate(certificates):
-            if i % 5 == 0 and self._stop_event and self._stop_event.is_set():
-                print(f"CrtSh processing cancelled at certificate {i} for domain: {domain}")
+            if i % 10 == 0 and self._stop_event and self._stop_event.is_set():
+                self.logger.logger.info(f"CrtSh processing cancelled at certificate {i} for domain: {query_domain}")
                 break
 
+            # Extract all domains from this certificate
             cert_domains = self._extract_domains_from_certificate(cert_data)
             all_discovered_domains.update(cert_domains)
 
+            # Create CA nodes for certificate issuers (rather than domain metadata)
             issuer_name = self._parse_issuer_organization(cert_data.get('issuer_name', ''))
-            if issuer_name:
+            if issuer_name and issuer_name not in processed_issuers:
+                # Create a relationship from the query domain to the CA
                 result.add_relationship(
-                    source_node=domain,
+                    source_node=query_domain,
                     target_node=issuer_name,
                     relationship_type='crtsh_cert_issuer',
                     provider=self.name,
-                    confidence=0.95
+                    confidence=0.95,
+                    raw_data={'issuer_dn': cert_data.get('issuer_name', '')}
                 )
+                processed_issuers.add(issuer_name)
 
+            # Add certificate metadata to each domain in this certificate
+            cert_metadata = self._extract_certificate_metadata(cert_data)
             for cert_domain in cert_domains:
                 if not _is_valid_domain(cert_domain):
                     continue
 
-                for key, value in self._extract_certificate_metadata(cert_data).items():
+                # Add certificate attributes to the domain
+                for key, value in cert_metadata.items():
                     if value is not None:
                         result.add_attribute(
                             target_node=cert_domain,
                             name=key,
                             value=value,
                             attr_type='certificate_data',
                             provider=self.name,
-                            confidence=0.9
+                            confidence=0.9,
+                            metadata={'certificate_id': cert_data.get('id')}
                         )
 
         if self._stop_event and self._stop_event.is_set():
-            print(f"CrtSh query cancelled before relationship creation for domain: {domain}")
+            self.logger.logger.info(f"CrtSh query cancelled before relationship creation for domain: {query_domain}")
             return result
 
-        for i, discovered_domain in enumerate(all_discovered_domains):
-            if discovered_domain == domain:
+        # Create selective relationships to avoid forming large entities:
+        # only link domains that are closely related.
+        for discovered_domain in all_discovered_domains:
+            if discovered_domain == query_domain:
                 continue
 
-            if i % 10 == 0 and self._stop_event and self._stop_event.is_set():
-                print(f"CrtSh relationship creation cancelled for domain: {domain}")
-                break
-
             if not _is_valid_domain(discovered_domain):
                 continue
 
-            confidence = self._calculate_domain_relationship_confidence(
-                domain, discovered_domain, [], all_discovered_domains
-            )
+            # Only create relationships for domains that share a meaningful connection;
+            # this prevents the relationship fan-out that triggers large entity creation.
+            if self._should_create_relationship(query_domain, discovered_domain):
+                confidence = self._calculate_domain_relationship_confidence(
+                    query_domain, discovered_domain, [], all_discovered_domains
+                )
 
-            result.add_relationship(
-                source_node=domain,
-                target_node=discovered_domain,
-                relationship_type='crtsh_san_certificate',
-                provider=self.name,
-                confidence=confidence,
-                raw_data={'relationship_type': 'certificate_discovery'}
-            )
+                result.add_relationship(
+                    source_node=query_domain,
+                    target_node=discovered_domain,
+                    relationship_type='crtsh_san_certificate',
+                    provider=self.name,
+                    confidence=confidence,
+                    raw_data={'relationship_type': 'certificate_discovery'}
+                )
 
-            self.log_relationship_discovery(
-                source_node=domain,
-                target_node=discovered_domain,
-                relationship_type='crtsh_san_certificate',
-                confidence_score=confidence,
-                raw_data={'relationship_type': 'certificate_discovery'},
-                discovery_method="certificate_transparency_analysis"
-            )
+                self.log_relationship_discovery(
+                    source_node=query_domain,
+                    target_node=discovered_domain,
+                    relationship_type='crtsh_san_certificate',
+                    confidence_score=confidence,
+                    raw_data={'relationship_type': 'certificate_discovery'},
+                    discovery_method="certificate_transparency_analysis"
+                )
 
+        self.logger.logger.info(f"CrtSh processing completed for {query_domain}: {len(all_discovered_domains)} domains, {result.get_relationship_count()} relationships")
         return result
 
+    def _should_create_relationship(self, source_domain: str, target_domain: str) -> bool:
+        """
+        Determine whether a relationship should be created between two domains.
+        Limits relationship fan-out to avoid triggering large entity creation.
+ """ + # Always create relationships for subdomains + if target_domain.endswith(f'.{source_domain}') or source_domain.endswith(f'.{target_domain}'): + return True + + # Create relationships for domains that share a common parent (up to 2 levels) + source_parts = source_domain.split('.') + target_parts = target_domain.split('.') + + # Check if they share the same root domain (last 2 parts) + if len(source_parts) >= 2 and len(target_parts) >= 2: + source_root = '.'.join(source_parts[-2:]) + target_root = '.'.join(target_parts[-2:]) + return source_root == target_root + + return False + def _extract_certificate_metadata(self, cert_data: Dict[str, Any]) -> Dict[str, Any]: """Extract comprehensive metadata from certificate data.""" raw_issuer_name = cert_data.get('issuer_name', '') @@ -383,7 +419,7 @@ class CrtShProvider(BaseProvider): metadata['is_currently_valid'] = self._is_cert_valid(cert_data) metadata['expires_soon'] = (not_after - datetime.now(timezone.utc)).days <= 30 - # UPDATED: Keep raw date format or convert to standard format + # Keep raw date format or convert to standard format metadata['not_before'] = not_before.isoformat() metadata['not_after'] = not_after.isoformat() @@ -457,7 +493,6 @@ class CrtShProvider(BaseProvider): return is_not_expired except Exception as e: - #self.logger.logger.debug(f"Certificate validity check failed: {e}") return False def _extract_domains_from_certificate(self, cert_data: Dict[str, Any]) -> Set[str]: @@ -512,20 +547,6 @@ class CrtShProvider(BaseProvider): return [d for d in final_domains if _is_valid_domain(d)] - def _get_certificate_sort_date(self, cert: Dict[str, Any]) -> datetime: - """Get a sortable date from certificate data for chronological ordering.""" - try: - if cert.get('not_before'): - return self._parse_certificate_date(cert['not_before']) - - if cert.get('entry_timestamp'): - return self._parse_certificate_date(cert['entry_timestamp']) - - return datetime(1970, 1, 1, tzinfo=timezone.utc) - - except Exception: - return datetime(1970, 1, 1, tzinfo=timezone.utc) - def _calculate_domain_relationship_confidence(self, domain1: str, domain2: str, shared_certificates: List[Dict[str, Any]], all_discovered_domains: Set[str]) -> float: @@ -544,35 +565,7 @@ class CrtShProvider(BaseProvider): else: context_bonus = 0.0 - # Adjust confidence based on shared certificates - if shared_certificates: - shared_count = len(shared_certificates) - if shared_count >= 3: - shared_bonus = 0.1 - elif shared_count >= 2: - shared_bonus = 0.05 - else: - shared_bonus = 0.02 - - valid_shared = sum(1 for cert in shared_certificates if cert.get('is_currently_valid')) - if valid_shared > 0: - validity_bonus = 0.05 - else: - validity_bonus = 0.0 - else: - shared_bonus = 0.0 - validity_bonus = 0.0 - - # Adjust confidence based on certificate issuer reputation - issuer_bonus = 0.0 - if shared_certificates: - for cert in shared_certificates: - issuer = cert.get('issuer_name', '').lower() - if any(trusted_ca in issuer for trusted_ca in ['let\'s encrypt', 'digicert', 'sectigo', 'globalsign']): - issuer_bonus = max(issuer_bonus, 0.03) - break - - final_confidence = base_confidence + context_bonus + shared_bonus + validity_bonus + issuer_bonus + final_confidence = base_confidence + context_bonus return max(0.1, min(1.0, final_confidence)) def _determine_relationship_context(self, cert_domain: str, query_domain: str) -> str: diff --git a/providers/dns_provider.py b/providers/dns_provider.py index 9ca0e35..3aef192 100644 --- a/providers/dns_provider.py +++ 
diff --git a/providers/dns_provider.py b/providers/dns_provider.py
index 9ca0e35..3aef192 100644
--- a/providers/dns_provider.py
+++ b/providers/dns_provider.py
@@ -4,13 +4,13 @@ from dns import resolver, reversename
 from typing import Dict
 from .base_provider import BaseProvider
 from core.provider_result import ProviderResult
-from utils.helpers import _is_valid_ip, _is_valid_domain
+from utils.helpers import _is_valid_ip, _is_valid_domain, get_ip_version
 
 class DNSProvider(BaseProvider):
     """
     Provider for standard DNS resolution and reverse DNS lookups.
-    Now returns standardized ProviderResult objects.
+    Now returns standardized ProviderResult objects with IPv4 and IPv6 support.
     """
 
     def __init__(self, name=None, session_config=None):
@@ -78,10 +78,10 @@ class DNSProvider(BaseProvider):
 
     def query_ip(self, ip: str) -> ProviderResult:
         """
-        Query reverse DNS for the IP address.
+        Query reverse DNS for the IP address (supports both IPv4 and IPv6).
 
         Args:
-            ip: IP address to investigate
+            ip: IP address to investigate (IPv4 or IPv6)
 
         Returns:
             ProviderResult containing discovered relationships and attributes
@@ -90,9 +90,10 @@
             return ProviderResult()
 
         result = ProviderResult()
+        ip_version = get_ip_version(ip)
 
         try:
-            # Perform reverse DNS lookup
+            # Perform reverse DNS lookup (works for both IPv4 and IPv6)
             self.total_requests += 1
             reverse_name = reversename.from_address(ip)
             response = self.resolver.resolve(reverse_name, 'PTR')
@@ -103,6 +104,14 @@
                 hostname = str(ptr_record).rstrip('.')
 
                 if _is_valid_domain(hostname):
+                    # Determine the forward relationship type from the IP version
+                    if ip_version == 6:
+                        relationship_type = 'dns_aaaa_record'
+                        record_prefix = 'AAAA'
+                    else:
+                        relationship_type = 'dns_a_record'
+                        record_prefix = 'A'
+
                     # Add the relationship
                     result.add_relationship(
                         source_node=ip,
@@ -113,6 +122,7 @@
                         raw_data={
                             'query_type': 'PTR',
                             'ip_address': ip,
+                            'ip_version': ip_version,
                             'hostname': hostname,
                             'ttl': response.ttl
                         }
@@ -130,10 +140,11 @@
                         raw_data={
                             'query_type': 'PTR',
                             'ip_address': ip,
+                            'ip_version': ip_version,
                             'hostname': hostname,
                             'ttl': response.ttl
                         },
-                        discovery_method="reverse_dns_lookup"
+                        discovery_method=f"reverse_dns_lookup_ipv{ip_version}"
                     )
 
             # Add PTR records as separate attribute
@@ -145,7 +156,7 @@
                     attr_type='dns_record',
                     provider=self.name,
                     confidence=0.8,
-                    metadata={'ttl': response.ttl}
+                    metadata={'ttl': response.ttl, 'ip_version': ip_version}
                 )
 
         except resolver.NXDOMAIN:
@@ -162,6 +173,7 @@
     def _query_record(self, domain: str, record_type: str, result: ProviderResult) -> None:
         """
         FIXED: Query DNS records with unique attribute names for each record type.
+        Enhanced to better handle IPv6 AAAA records.
         """
         try:
             self.total_requests += 1
@@ -174,6 +186,10 @@
                 target = ""
                 if record_type in ['A', 'AAAA']:
                     target = str(record)
+                    # Validate that the IP address is properly formed
+                    if not _is_valid_ip(target):
+                        self.logger.logger.debug(f"Invalid IP address in {record_type} record: {target}")
+                        continue
                 elif record_type in ['CNAME', 'NS', 'PTR']:
                     target = str(record.target).rstrip('.')
                 elif record_type == 'MX':
@@ -196,12 +212,21 @@
                     target = str(record)
 
                 if target:
+                    # Record the IP version in metadata when this is an IP record
+                    ip_version = None
+                    if record_type in ['A', 'AAAA'] and _is_valid_ip(target):
+                        ip_version = get_ip_version(target)
+
                     raw_data = {
                         'query_type': record_type,
                         'domain': domain,
                         'value': target,
                         'ttl': response.ttl
                     }
+
+                    if ip_version:
+                        raw_data['ip_version'] = ip_version
+
                     relationship_type = f"dns_{record_type.lower()}_record"
                     confidence = 0.8
 
@@ -218,14 +243,18 @@
                     # Add target to records list
                     dns_records.append(target)
 
-                    # Log relationship discovery
+                    # Log relationship discovery with IP version info
+                    discovery_method = f"dns_{record_type.lower()}_record"
+                    if ip_version:
+                        discovery_method += f"_ipv{ip_version}"
+
                     self.log_relationship_discovery(
                         source_node=domain,
                         target_node=target,
                         relationship_type=relationship_type,
                         confidence_score=confidence,
                         raw_data=raw_data,
-                        discovery_method=f"dns_{record_type.lower()}_record"
+                        discovery_method=discovery_method
                     )
 
             # FIXED: Create attribute with specific name for each record type
             if dns_records:
                 # Use record type specific attribute name (e.g., 'a_records', 'mx_records', etc.)
                 attribute_name = f"{record_type.lower()}_records"
 
+                metadata = {'record_type': record_type, 'ttl': response.ttl}
+
+                # Add IP version info for A/AAAA records
+                if record_type in ['A', 'AAAA'] and dns_records:
+                    first_ip_version = get_ip_version(dns_records[0])
+                    if first_ip_version:
+                        metadata['ip_version'] = first_ip_version
+
                 result.add_attribute(
                     target_node=domain,
                     name=attribute_name,  # UNIQUE name for each record type!
                     value=dns_records,
                     attr_type='dns_record_list',
                     provider=self.name,
                     confidence=0.8,
-                    metadata={'record_type': record_type, 'ttl': response.ttl}
+                    metadata=metadata
                 )
 
         except Exception as e:
+ """ + # Normalize the IP address first to ensure consistent caching + normalized_ip = normalize_ip(ip) + if not normalized_ip: + # Fallback for invalid IPs + safe_ip = ip.replace('.', '_').replace(':', '_') + else: + # Replace problematic characters for both IPv4 and IPv6 + safe_ip = normalized_ip.replace('.', '_').replace(':', '_') + return self.cache_dir / f"{safe_ip}.json" def _get_cache_status(self, cache_file_path: Path) -> str: @@ -99,10 +110,10 @@ class ShodanProvider(BaseProvider): def query_ip(self, ip: str) -> ProviderResult: """ - Query Shodan for information about an IP address, with caching of processed data. + Query Shodan for information about an IP address (IPv4 or IPv6), with caching of processed data. Args: - ip: IP address to investigate + ip: IP address to investigate (IPv4 or IPv6) Returns: ProviderResult containing discovered relationships and attributes @@ -110,7 +121,12 @@ class ShodanProvider(BaseProvider): if not _is_valid_ip(ip) or not self.is_available(): return ProviderResult() - cache_file = self._get_cache_file_path(ip) + # Normalize IP address for consistent processing + normalized_ip = normalize_ip(ip) + if not normalized_ip: + return ProviderResult() + + cache_file = self._get_cache_file_path(normalized_ip) cache_status = self._get_cache_status(cache_file) result = ProviderResult() @@ -118,25 +134,48 @@ class ShodanProvider(BaseProvider): try: if cache_status == "fresh": result = self._load_from_cache(cache_file) - self.logger.logger.info(f"Using cached Shodan data for {ip}") + self.logger.logger.info(f"Using cached Shodan data for {normalized_ip}") else: # "stale" or "not_found" - url = f"{self.base_url}/shodan/host/{ip}" + url = f"{self.base_url}/shodan/host/{normalized_ip}" params = {'key': self.api_key} - response = self.make_request(url, method="GET", params=params, target_indicator=ip) + response = self.make_request(url, method="GET", params=params, target_indicator=normalized_ip) if response and response.status_code == 200: data = response.json() # Process the data into ProviderResult BEFORE caching - result = self._process_shodan_data(ip, data) + result = self._process_shodan_data(normalized_ip, data) self._save_to_cache(cache_file, result, data) # Save both result and raw data + elif response and response.status_code == 404: + # Handle 404 "No information available" as successful empty result + try: + error_data = response.json() + if "No information available" in error_data.get('error', ''): + # This is a successful query - Shodan just has no data + self.logger.logger.debug(f"Shodan has no information for {normalized_ip}") + result = ProviderResult() # Empty but successful result + # Cache the empty result to avoid repeated queries + self._save_to_cache(cache_file, result, {'error': 'No information available'}) + else: + # Some other 404 error - treat as failure + raise requests.exceptions.RequestException(f"Shodan API returned 404: {error_data}") + except (ValueError, KeyError): + # Could not parse JSON response - treat as failure + raise requests.exceptions.RequestException(f"Shodan API returned 404 with unparseable response") elif cache_status == "stale": # If API fails on a stale cache, use the old data result = self._load_from_cache(cache_file) + else: + # Other HTTP error codes should be treated as failures + status_code = response.status_code if response else "No response" + raise requests.exceptions.RequestException(f"Shodan API returned HTTP {status_code}") except requests.exceptions.RequestException as e: - 
self.logger.logger.error(f"Shodan API query failed for {ip}: {e}") + self.logger.logger.info(f"Shodan API query returned no info for {normalized_ip}: {e}") if cache_status == "stale": result = self._load_from_cache(cache_file) + else: + # Re-raise for retry scheduling - but only for actual failures + raise e return result @@ -212,8 +251,12 @@ class ShodanProvider(BaseProvider): def _process_shodan_data(self, ip: str, data: Dict[str, Any]) -> ProviderResult: """ VERIFIED: Process Shodan data creating ISP nodes with ASN attributes and proper relationships. + Enhanced to include IP version information for IPv6 addresses. """ result = ProviderResult() + + # Determine IP version for metadata + ip_version = get_ip_version(ip) # VERIFIED: Extract ISP information and create proper ISP node with ASN isp_name = data.get('org') @@ -227,7 +270,7 @@ class ShodanProvider(BaseProvider): relationship_type='shodan_isp', provider=self.name, confidence=0.9, - raw_data={'asn': asn_value, 'shodan_org': isp_name} + raw_data={'asn': asn_value, 'shodan_org': isp_name, 'ip_version': ip_version} ) # Add ASN as attribute to the ISP node @@ -238,7 +281,7 @@ class ShodanProvider(BaseProvider): attr_type='isp_info', provider=self.name, confidence=0.9, - metadata={'description': 'Autonomous System Number from Shodan'} + metadata={'description': 'Autonomous System Number from Shodan', 'ip_version': ip_version} ) # Also add organization name as attribute to ISP node for completeness @@ -249,7 +292,7 @@ class ShodanProvider(BaseProvider): attr_type='isp_info', provider=self.name, confidence=0.9, - metadata={'description': 'Organization name from Shodan'} + metadata={'description': 'Organization name from Shodan', 'ip_version': ip_version} ) # Process hostnames (reverse DNS) @@ -257,21 +300,27 @@ class ShodanProvider(BaseProvider): if key == 'hostnames': for hostname in value: if _is_valid_domain(hostname): + # Use appropriate relationship type based on IP version + if ip_version == 6: + relationship_type = 'shodan_aaaa_record' + else: + relationship_type = 'shodan_a_record' + result.add_relationship( source_node=ip, target_node=hostname, - relationship_type='shodan_a_record', + relationship_type=relationship_type, provider=self.name, confidence=0.8, - raw_data=data + raw_data={**data, 'ip_version': ip_version} ) self.log_relationship_discovery( source_node=ip, target_node=hostname, - relationship_type='shodan_a_record', + relationship_type=relationship_type, confidence_score=0.8, - raw_data=data, - discovery_method="shodan_host_lookup" + raw_data={**data, 'ip_version': ip_version}, + discovery_method=f"shodan_host_lookup_ipv{ip_version}" ) elif key == 'ports': # Add open ports as attributes to the IP @@ -282,7 +331,8 @@ class ShodanProvider(BaseProvider): value=port, attr_type='shodan_network_info', provider=self.name, - confidence=0.9 + confidence=0.9, + metadata={'ip_version': ip_version} ) elif isinstance(value, (str, int, float, bool)) and value is not None: # Add other Shodan fields as IP attributes (keep raw field names) @@ -292,7 +342,8 @@ class ShodanProvider(BaseProvider): value=value, attr_type='shodan_info', provider=self.name, - confidence=0.9 + confidence=0.9, + metadata={'ip_version': ip_version} ) return result \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 0e37daa..d46c0bc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ -Flask>=2.3.3 -networkx>=3.1 -requests>=2.31.0 -python-dateutil>=2.8.2 -Werkzeug>=2.3.7 -urllib3>=2.0.0 -dnspython>=2.4.2 +Flask 
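
Normalizing before building the cache key matters because one IPv6 address has many textual spellings; the stdlib `ipaddress` module collapses them to a canonical form, so equivalent inputs hit the same cache file. A small sketch of that behavior:

```python
# Sketch: canonical cache keys for IPv6 via the stdlib ipaddress module.
import ipaddress


def cache_key(ip: str) -> str:
    canonical = str(ipaddress.ip_address(ip.strip()))   # e.g. '2001:db8::1'
    return canonical.replace('.', '_').replace(':', '_') + '.json'


# All three spellings of the same address map to one cache file:
print(cache_key('2001:db8::1'))                              # 2001_db8__1.json
print(cache_key('2001:DB8:0:0:0:0:0:1'))                     # 2001_db8__1.json
print(cache_key('2001:0db8:0000:0000:0000:0000:0000:0001'))  # 2001_db8__1.json
```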
diff --git a/requirements.txt b/requirements.txt
index 0e37daa..d46c0bc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
-Flask>=2.3.3
-networkx>=3.1
-requests>=2.31.0
-python-dateutil>=2.8.2
-Werkzeug>=2.3.7
-urllib3>=2.0.0
-dnspython>=2.4.2
+Flask
+networkx
+requests
+python-dateutil
+Werkzeug
+urllib3
+dnspython
 gunicorn
 redis
 python-dotenv
\ No newline at end of file
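
Before the front-end changes, note that the reworked export endpoint can be exercised directly. A hedged sketch using Flask's test client, assuming the `app` instance in app.py is importable and an active scanner session exists (otherwise the endpoint returns an error payload):

```python
# Sketch: exercising /api/export with Flask's test client.
from app import app  # the Flask instance defined in app.py

with app.test_client() as client:
    resp = client.get('/api/export')
    if resp.status_code == 200:
        # A successful export arrives as a JSON file attachment
        disposition = resp.headers.get('Content-Disposition', '')
        print(disposition)          # e.g. attachment; filename=dnsrecon_<target>_<ts>.json
        print(len(resp.data), 'bytes')
    else:
        # Errors come back as JSON bodies with 'success': False
        print(resp.get_json())
```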
diff --git a/static/js/main.js b/static/js/main.js
index cfb3c25..ee56fe7 100644
--- a/static/js/main.js
+++ b/static/js/main.js
@@ -58,7 +58,10 @@ class DNSReconApp {
             startScan: document.getElementById('start-scan'),
             addToGraph: document.getElementById('add-to-graph'),
             stopScan: document.getElementById('stop-scan'),
-            exportResults: document.getElementById('export-results'),
+            exportOptions: document.getElementById('export-options'),
+            exportModal: document.getElementById('export-modal'),
+            exportModalClose: document.getElementById('export-modal-close'),
+            exportGraphJson: document.getElementById('export-graph-json'),
             configureSettings: document.getElementById('configure-settings'),
 
             // Status elements
@@ -146,11 +149,24 @@ class DNSReconApp {
             this.stopScan();
         });
 
-        this.elements.exportResults.addEventListener('click', (e) => {
+        this.elements.exportOptions.addEventListener('click', (e) => {
             e.preventDefault();
-            this.exportResults();
+            this.showExportModal();
         });
 
+        if (this.elements.exportModalClose) {
+            this.elements.exportModalClose.addEventListener('click', () => this.hideExportModal());
+        }
+        if (this.elements.exportModal) {
+            this.elements.exportModal.addEventListener('click', (e) => {
+                if (e.target === this.elements.exportModal) this.hideExportModal();
+            });
+        }
+        if (this.elements.exportGraphJson) {
+            this.elements.exportGraphJson.addEventListener('click', () => this.exportGraphJson());
+        }
+
         this.elements.configureSettings.addEventListener('click', () => this.showSettingsModal());
 
         // Enter key support for target domain input
@@ -219,6 +235,7 @@ class DNSReconApp {
             if (e.key === 'Escape') {
                 this.hideModal();
                 this.hideSettingsModal();
+                this.hideExportModal();
             }
         });
 
@@ -376,26 +393,96 @@ class DNSReconApp {
     }
 
     /**
-     * Export scan results
+     * Show Export modal
      */
-    async exportResults() {
+    showExportModal() {
+        if (this.elements.exportModal) {
+            this.elements.exportModal.style.display = 'block';
+        }
+    }
+
+    /**
+     * Hide Export modal
+     */
+    hideExportModal() {
+        if (this.elements.exportModal) {
+            this.elements.exportModal.style.display = 'none';
+        }
+    }
+
+    /**
+     * Export graph data as JSON with proper error handling
+     */
+    async exportGraphJson() {
         try {
-            console.log('Exporting results...');
+            console.log('Exporting graph data as JSON...');
 
-            // Create a temporary link to trigger download
+            // Show loading state
+            if (this.elements.exportGraphJson) {
+                const originalContent = this.elements.exportGraphJson.innerHTML;
+                this.elements.exportGraphJson.innerHTML = '[...]Exporting...';
+                this.elements.exportGraphJson.disabled = true;
+
+                // Store original content for restoration
+                this.elements.exportGraphJson._originalContent = originalContent;
+            }
+
+            // Make API call to get export data
+            const response = await fetch('/api/export', {
+                method: 'GET',
+                headers: {
+                    'Content-Type': 'application/json'
+                }
+            });
+
+            if (!response.ok) {
+                const errorData = await response.json().catch(() => ({}));
+                throw new Error(errorData.error || `HTTP ${response.status}: ${response.statusText}`);
+            }
+
+            // Check if the response is JSON (error) or a file download
+            const contentType = response.headers.get('content-type');
+            if (contentType && contentType.includes('application/json') && !response.headers.get('content-disposition')) {
+                // This is an error response in JSON format
+                const errorData = await response.json();
+                throw new Error(errorData.error || 'Export failed');
+            }
+
+            // Get the filename from headers or create one
+            const contentDisposition = response.headers.get('content-disposition');
+            let filename = 'dnsrecon_export.json';
+            if (contentDisposition) {
+                const filenameMatch = contentDisposition.match(/filename[^;=\n]*=((['"]).*?\2|[^;\n]*)/);
+                if (filenameMatch) {
+                    filename = filenameMatch[1].replace(/['"]/g, '');
+                }
+            }
+
+            // Create blob and download
+            const blob = await response.blob();
+            const url = window.URL.createObjectURL(blob);
             const link = document.createElement('a');
-            link.href = '/api/export';
-            link.download = ''; // Let server determine filename
+            link.href = url;
+            link.download = filename;
             document.body.appendChild(link);
             link.click();
             document.body.removeChild(link);
+            window.URL.revokeObjectURL(url);
 
-            this.showSuccess('Results export initiated');
-            console.log('Results export initiated');
+            this.showSuccess('Graph data exported successfully');
+            this.hideExportModal();
 
         } catch (error) {
-            console.error('Failed to export results:', error);
-            this.showError(`Failed to export results: ${error.message}`);
+            console.error('Failed to export graph data:', error);
+            this.showError(`Export failed: ${error.message}`);
+        } finally {
+            // Restore button state
+            if (this.elements.exportGraphJson) {
+                const originalContent = this.elements.exportGraphJson._originalContent ||
+                    '[JSON]Export Graph Data';
+                this.elements.exportGraphJson.innerHTML = originalContent;
+                this.elements.exportGraphJson.disabled = false;
+            }
        }
    }
 
@@ -2116,14 +2203,7 @@ class DNSReconApp {
         }
     }
 
-    /**
-     * Validate target (domain or IP)
-     * @param {string} target - Target to validate
-     * @returns {boolean} True if valid
-     */
-    isValidTarget(target) {
-        return this.isValidDomain(target) || this.isValidIp(target);
-    }
+
 
     /**
      * Validate domain name
@@ -2143,20 +2223,149 @@
     }
 
     /**
-     * Validate IP address
+     * Validate target (domain or IP)
+     * @param {string} target - Target to validate
+     * @returns {boolean} True if valid
+     */
+    isValidTarget(target) {
+        return this.isValidDomain(target) || this.isValidIp(target);
+    }
+
+    /**
+     * Validate IP address (IPv4 or IPv6)
      * @param {string} ip - IP to validate
      * @returns {boolean} True if valid
      */
     isValidIp(ip) {
         console.log(`Validating IP: "${ip}"`);
-        const parts = ip.split('.');
-        if (parts.length !== 4) {
+
+        if (!ip || typeof ip !== 'string') {
             return false;
         }
-        return parts.every(part => {
-            const num = parseInt(part, 10);
-            return !isNaN(num) && num >= 0 && num <= 255 && String(num) === part;
-        });
+
+        ip = ip.trim();
+
+        // IPv4 validation
+        if (this.isValidIPv4(ip)) {
+            return true;
+        }
+
+        // IPv6 validation
+        if (this.isValidIPv6(ip)) {
+            return true;
+        }
+
+        return false;
+    }
+
+    /**
+     * Validate IPv4 address
+     * @param {string} ip - IP to validate
+     * @returns {boolean} True if valid IPv4
+     */
+    isValidIPv4(ip) {
+        const ipv4Pattern = /^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/;
+        const match = ip.match(ipv4Pattern);
+
+        if (!match) {
+            return false;
+        }
+
+        // Check each octet is between 0-255
+        for (let i = 1; i <= 4; i++) {
+            const octet = parseInt(match[i], 10);
+            if (octet < 0 || octet > 255) {
+                return false;
+            }
+            // Reject leading zeros (except for '0' itself)
+            if (match[i].length > 1 && match[i][0] === '0') {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
+    /**
+     * Validate IPv6 address
+     * @param {string} ip - IP to validate
+     * @returns {boolean} True if valid IPv6
+     */
+    isValidIPv6(ip) {
+        // Handle IPv6 with embedded IPv4 (e.g., ::ffff:192.168.1.1)
+        if (ip.includes('.')) {
+            const lastColon = ip.lastIndexOf(':');
+            if (lastColon !== -1) {
+                const ipv6Part = ip.substring(0, lastColon + 1);
+                const ipv4Part = ip.substring(lastColon + 1);
+
+                if (this.isValidIPv4(ipv4Part)) {
+                    // Validate the IPv6 part with the IPv4 tail replaced by two groups
+                    return this.isValidIPv6Pure(ipv6Part + '0:0');
+                }
+            }
+        }
+
+        return this.isValidIPv6Pure(ip);
+    }
+
+    /**
+     * Validate pure IPv6 address (no embedded IPv4)
+     * @param {string} ip - IPv6 address to validate
+     * @returns {boolean} True if valid IPv6
+     */
+    isValidIPv6Pure(ip) {
+        // Basic format check
+        if (!ip || ip.length < 2 || ip.length > 39) {
+            return false;
+        }
+
+        // Check for invalid characters
+        if (!/^[0-9a-fA-F:]+$/.test(ip)) {
+            return false;
+        }
+
+        // Handle double colon (::) for zero compression
+        const doubleColonCount = (ip.match(/::/g) || []).length;
+        if (doubleColonCount > 1) {
+            return false; // Only one :: allowed
+        }
+
+        let parts;
+        if (doubleColonCount === 1) {
+            // Expand the :: notation
+            const [before, after] = ip.split('::');
+            const beforeParts = before ? before.split(':') : [];
+            const afterParts = after ? after.split(':') : [];
+
+            // Calculate how many zero groups the :: represents
+            const totalParts = beforeParts.length + afterParts.length;
+            const zeroGroups = 8 - totalParts;
+
+            if (zeroGroups < 1) {
+                return false; // :: must represent at least one zero group
+            }
+
+            // Build the full address
+            parts = beforeParts.concat(Array(zeroGroups).fill('0')).concat(afterParts);
+        } else {
+            // No :: notation, split normally
+            parts = ip.split(':');
+        }
+
+        // IPv6 should have exactly 8 groups
+        if (parts.length !== 8) {
+            return false;
+        }
+
+        // Validate each group (1-4 hex digits)
+        for (const part of parts) {
+            if (!part || part.length > 4 || !/^[0-9a-fA-F]+$/.test(part)) {
+                return false;
+            }
+        }
+
+        return true;
    }
 
    /**
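
The hand-rolled `::` expansion above is the error-prone part of IPv6 validation. A compact Python rendering of the same algorithm, useful for cross-checking the JS logic against the stdlib (in production Python code, `ipaddress` alone suffices):

```python
# Sketch: the same '::' zero-compression expansion the JS validator performs,
# cross-checked against the stdlib ipaddress module.
import ipaddress


def expand_groups(ip: str) -> list[str] | None:
    if ip.count('::') > 1:
        return None                      # only one '::' allowed
    if '::' in ip:
        before, after = ip.split('::')
        head = before.split(':') if before else []
        tail = after.split(':') if after else []
        zeros = 8 - (len(head) + len(tail))
        if zeros < 1:
            return None                  # '::' must stand for >= 1 zero group
        return head + ['0'] * zeros + tail
    groups = ip.split(':')
    return groups if len(groups) == 8 else None


print(expand_groups('2001:db8::1'))
# ['2001', 'db8', '0', '0', '0', '0', '0', '1']
assert ipaddress.ip_address('2001:db8::1').version == 6  # stdlib agrees
```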
diff --git a/templates/index.html b/templates/index.html
index 0c7a4f9..f32d02c 100644
--- a/templates/index.html
+++ b/templates/index.html
@@ -53,9 +53,9 @@
                     [STOP] Terminate Scan
-                <button id="export-results">[SAVE] Export Results</button>
+                <button id="export-options">[SAVE] Export Options</button>
+                <div id="export-modal" style="display: none;">
+                    <span id="export-modal-close">[X]</span>
+                    <button id="export-graph-json">[JSON] Export Graph Data</button>
+                </div>
diff --git a/utils/helpers.py b/utils/helpers.py
index 2d17717..8849790 100644
--- a/utils/helpers.py
+++ b/utils/helpers.py
@@ -1,3 +1,8 @@
+# dnsrecon-reduced/utils/helpers.py
+
+import ipaddress
+from typing import Union
+
 def _is_valid_domain(domain: str) -> bool:
     """
     Basic domain validation.
@@ -26,32 +31,27 @@ def _is_valid_domain(domain: str) -> bool:
 
 def _is_valid_ip(ip: str) -> bool:
     """
-    Basic IP address validation.
+    IP address validation supporting both IPv4 and IPv6.
 
     Args:
         ip: IP address string to validate
 
     Returns:
-        True if IP appears valid
+        True if IP appears valid (IPv4 or IPv6)
     """
+    if not ip:
+        return False
+
     try:
-        parts = ip.split('.')
-        if len(parts) != 4:
-            return False
-
-        for part in parts:
-            num = int(part)
-            if not 0 <= num <= 255:
-                return False
-
+        # ipaddress handles both IPv4 and IPv6 validation
+        ipaddress.ip_address(ip.strip())
         return True
-
     except (ValueError, AttributeError):
         return False
 
 def is_valid_target(target: str) -> bool:
     """
-    Checks if the target is a valid domain or IP address.
+    Checks if the target is a valid domain or IP address (IPv4/IPv6).
 
     Args:
         target: The target string to validate.
 
     Returns:
         True if the target is a valid domain or IP, False otherwise.
""" - return _is_valid_domain(target) or _is_valid_ip(target) \ No newline at end of file + return _is_valid_domain(target) or _is_valid_ip(target) + +def get_ip_version(ip: str) -> Union[int, None]: + """ + Get the IP version (4 or 6) of a valid IP address. + + Args: + ip: IP address string + + Returns: + 4 for IPv4, 6 for IPv6, None if invalid + """ + try: + addr = ipaddress.ip_address(ip.strip()) + return addr.version + except (ValueError, AttributeError): + return None + +def normalize_ip(ip: str) -> Union[str, None]: + """ + Normalize an IP address to its canonical form. + + Args: + ip: IP address string + + Returns: + Normalized IP address string, None if invalid + """ + try: + addr = ipaddress.ip_address(ip.strip()) + return str(addr) + except (ValueError, AttributeError): + return None \ No newline at end of file