diff --git a/app.py b/app.py index 1a3e9fa..4e2f175 100644 --- a/app.py +++ b/app.py @@ -14,6 +14,7 @@ import io from core.session_manager import session_manager from config import config from core.graph_manager import NodeType +from utils.helpers import is_valid_target app = Flask(__name__) @@ -65,18 +66,20 @@ def start_scan(): try: data = request.get_json() - if not data or 'target_domain' not in data: - return jsonify({'success': False, 'error': 'Missing target_domain in request'}), 400 + if not data or 'target' not in data: + return jsonify({'success': False, 'error': 'Missing target in request'}), 400 - target_domain = data['target_domain'].strip() + target = data['target'].strip() max_depth = data.get('max_depth', config.default_recursion_depth) clear_graph = data.get('clear_graph', True) - print(f"Parsed - target_domain: '{target_domain}', max_depth: {max_depth}, clear_graph: {clear_graph}") + print(f"Parsed - target: '{target}', max_depth: {max_depth}, clear_graph: {clear_graph}") # Validation - if not target_domain: - return jsonify({'success': False, 'error': 'Target domain cannot be empty'}), 400 + if not target: + return jsonify({'success': False, 'error': 'Target cannot be empty'}), 400 + if not is_valid_target(target): + return jsonify({'success': False, 'error': 'Invalid target format. Please enter a valid domain or IP address.'}), 400 if not isinstance(max_depth, int) or not 1 <= max_depth <= 5: return jsonify({'success': False, 'error': 'Max depth must be an integer between 1 and 5'}), 400 @@ -101,7 +104,7 @@ def start_scan(): print(f"Using scanner {id(scanner)} in session {user_session_id}") - success = scanner.start_scan(target_domain, max_depth, clear_graph=clear_graph) + success = scanner.start_scan(target, max_depth, clear_graph=clear_graph) if success: return jsonify({ @@ -120,7 +123,7 @@ def start_scan(): print(f"ERROR: Exception in start_scan endpoint: {e}") traceback.print_exc() return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500 - + @app.route('/api/scan/stop', methods=['POST']) def stop_scan(): """Stop the current scan with immediate GUI feedback.""" diff --git a/core/scanner.py b/core/scanner.py index 9afab06..0f836c3 100644 --- a/core/scanner.py +++ b/core/scanner.py @@ -204,7 +204,7 @@ class Scanner: self._initialize_providers() print("Session configuration updated") - def start_scan(self, target_domain: str, max_depth: int = 2, clear_graph: bool = True) -> bool: + def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True) -> bool: """Start a new reconnaissance scan with proper cleanup of previous scans.""" print(f"=== STARTING SCAN IN SCANNER {id(self)} ===") print(f"Session ID: {self.session_id}") @@ -268,7 +268,7 @@ class Scanner: if clear_graph: self.graph.clear() - self.current_target = target_domain.lower().strip() + self.current_target = target.lower().strip() self.max_depth = max_depth self.current_depth = 0 @@ -304,76 +304,80 @@ class Scanner: self._update_session_state() return False - def _execute_scan(self, target_domain: str, max_depth: int) -> None: + def _execute_scan(self, target: str, max_depth: int) -> None: """Execute the reconnaissance scan with proper termination handling.""" - print(f"_execute_scan started for {target_domain} with depth {max_depth}") + print(f"_execute_scan started for {target} with depth {max_depth}") self.executor = ThreadPoolExecutor(max_workers=self.max_workers) processed_targets = set() - self.task_queue.append((target_domain, 0, False)) + self.task_queue.append((target, 0, False)) try: self.status = ScanStatus.RUNNING self._update_session_state() enabled_providers = [provider.get_name() for provider in self.providers] - self.logger.log_scan_start(target_domain, max_depth, enabled_providers) - self.graph.add_node(target_domain, NodeType.DOMAIN) - self._initialize_provider_states(target_domain) + self.logger.log_scan_start(target, max_depth, enabled_providers) + + # Determine initial node type + node_type = NodeType.IP if _is_valid_ip(target) else NodeType.DOMAIN + self.graph.add_node(target, node_type) + + self._initialize_provider_states(target) - # **IMPROVED**: Better termination checking in main loop + # Better termination checking in main loop while self.task_queue and not self._is_stop_requested(): try: - target, depth, is_large_entity_member = self.task_queue.popleft() + target_item, depth, is_large_entity_member = self.task_queue.popleft() except IndexError: # Queue became empty during processing break - if target in processed_targets: + if target_item in processed_targets: continue if depth > max_depth: continue - # **NEW**: Track this target as currently processing + # Track this target as currently processing with self.processing_lock: if self._is_stop_requested(): - print(f"Stop requested before processing {target}") + print(f"Stop requested before processing {target_item}") break - self.currently_processing.add(target) + self.currently_processing.add(target_item) try: self.current_depth = depth - self.current_indicator = target + self.current_indicator = target_item self._update_session_state() - # **IMPROVED**: More frequent stop checking during processing + # More frequent stop checking during processing if self._is_stop_requested(): - print(f"Stop requested during processing setup for {target}") + print(f"Stop requested during processing setup for {target_item}") break - new_targets, large_entity_members, success = self._query_providers_for_target(target, depth, is_large_entity_member) + new_targets, large_entity_members, success = self._query_providers_for_target(target_item, depth, is_large_entity_member) - # **NEW**: Check stop signal after provider queries + # Check stop signal after provider queries if self._is_stop_requested(): - print(f"Stop requested after querying providers for {target}") + print(f"Stop requested after querying providers for {target_item}") break if not success: - self.target_retries[target] += 1 - if self.target_retries[target] <= self.config.max_retries_per_target: - print(f"Re-queueing target {target} (attempt {self.target_retries[target]})") - self.task_queue.append((target, depth, is_large_entity_member)) + self.target_retries[target_item] += 1 + if self.target_retries[target_item] <= self.config.max_retries_per_target: + print(f"Re-queueing target {target_item} (attempt {self.target_retries[target_item]})") + self.task_queue.append((target_item, depth, is_large_entity_member)) self.tasks_re_enqueued += 1 else: - print(f"ERROR: Max retries exceeded for target {target}") + print(f"ERROR: Max retries exceeded for target {target_item}") self.scan_failed_due_to_retries = True - self._log_target_processing_error(target, "Max retries exceeded") + self._log_target_processing_error(target_item, "Max retries exceeded") else: - processed_targets.add(target) + processed_targets.add(target_item) self.indicators_completed += 1 - # **NEW**: Only add new targets if not stopped + # Only add new targets if not stopped if not self._is_stop_requested(): for new_target in new_targets: if new_target not in processed_targets: @@ -384,11 +388,11 @@ class Scanner: self.task_queue.append((member, depth, True)) finally: - # **NEW**: Always remove from processing set + # Always remove from processing set with self.processing_lock: - self.currently_processing.discard(target) + self.currently_processing.discard(target_item) - # **NEW**: Log termination reason + # Log termination reason if self._is_stop_requested(): print("Scan terminated due to stop request") self.logger.logger.info("Scan terminated by user request") @@ -402,7 +406,7 @@ class Scanner: self.status = ScanStatus.FAILED self.logger.logger.error(f"Scan failed: {e}") finally: - # **NEW**: Clear processing state on exit + # Clear processing state on exit with self.processing_lock: self.currently_processing.clear() diff --git a/static/js/main.js b/static/js/main.js index 977c3ea..611713a 100644 --- a/static/js/main.js +++ b/static/js/main.js @@ -49,7 +49,7 @@ class DNSReconApp { console.log('Initializing DOM elements...'); this.elements = { // Form elements - targetDomain: document.getElementById('target-domain'), + targetInput: document.getElementById('target-input'), maxDepth: document.getElementById('max-depth'), startScan: document.getElementById('start-scan'), addToGraph: document.getElementById('add-to-graph'), @@ -87,7 +87,7 @@ class DNSReconApp { }; // Verify critical elements exist - const requiredElements = ['targetDomain', 'startScan', 'scanStatus']; + const requiredElements = ['targetInput', 'startScan', 'scanStatus']; for (const elementName of requiredElements) { if (!this.elements[elementName]) { throw new Error(`Required element '${elementName}' not found in DOM`); @@ -156,7 +156,7 @@ class DNSReconApp { this.elements.configureSettings.addEventListener('click', () => this.showSettingsModal()); // Enter key support for target domain input - this.elements.targetDomain.addEventListener('keypress', (e) => { + this.elements.targetInput.addEventListener('keypress', (e) => { if (e.key === 'Enter' && !this.isScanning) { console.log('Enter key pressed in domain input'); this.startScan(); @@ -238,23 +238,23 @@ class DNSReconApp { console.log('=== STARTING SCAN ==='); try { - const targetDomain = this.elements.targetDomain.value.trim(); + const target = this.elements.targetInput.value.trim(); const maxDepth = parseInt(this.elements.maxDepth.value); - console.log(`Target domain: "${targetDomain}", Max depth: ${maxDepth}`); + console.log(`Target: "${target}", Max depth: ${maxDepth}`); // Validation - if (!targetDomain) { - console.log('Validation failed: empty domain'); - this.showError('Please enter a target domain'); - this.elements.targetDomain.focus(); + if (!target) { + console.log('Validation failed: empty target'); + this.showError('Please enter a target domain or IP'); + this.elements.targetInput.focus(); return; } - if (!this.isValidDomain(targetDomain)) { - console.log(`Validation failed: invalid domain format for "${targetDomain}"`); - this.showError('Please enter a valid domain name (e.g., example.com)'); - this.elements.targetDomain.focus(); + if (!this.isValidTarget(target)) { + console.log(`Validation failed: invalid target format for "${target}"`); + this.showError('Please enter a valid domain name (e.g., example.com) or IP address (e.g., 8.8.8.8)'); + this.elements.targetInput.focus(); return; } @@ -265,7 +265,7 @@ class DNSReconApp { console.log('Making API call to start scan...'); const requestData = { - target_domain: targetDomain, + target: target, max_depth: maxDepth, clear_graph: clearGraph }; @@ -284,7 +284,7 @@ class DNSReconApp { this.graphManager.clear(); } - console.log(`Scan started for ${targetDomain} with depth ${maxDepth}`); + console.log(`Scan started for ${target} with depth ${maxDepth}`); // Start polling immediately with faster interval for responsiveness this.startPolling(1000); @@ -685,7 +685,7 @@ class DNSReconApp { this.elements.stopScan.classList.remove('loading'); this.elements.stopScan.innerHTML = '[STOP]Terminate Scan'; } - if (this.elements.targetDomain) this.elements.targetDomain.disabled = true; + if (this.elements.targetInput) this.elements.targetInput.disabled = true; if (this.elements.maxDepth) this.elements.maxDepth.disabled = true; if (this.elements.configureSettings) this.elements.configureSettings.disabled = true; break; @@ -708,7 +708,7 @@ class DNSReconApp { this.elements.stopScan.disabled = true; this.elements.stopScan.innerHTML = '[STOP]Terminate Scan'; } - if (this.elements.targetDomain) this.elements.targetDomain.disabled = false; + if (this.elements.targetInput) this.elements.targetInput.disabled = false; if (this.elements.maxDepth) this.elements.maxDepth.disabled = false; if (this.elements.configureSettings) this.elements.configureSettings.disabled = false; break; @@ -2065,65 +2065,48 @@ class DNSReconApp { } /** - * Validate domain name - improved validation + * Validate target (domain or IP) + * @param {string} target - Target to validate + * @returns {boolean} True if valid + */ + isValidTarget(target) { + return this.isValidDomain(target) || this.isValidIp(target); + } + + /** + * Validate domain name * @param {string} domain - Domain to validate * @returns {boolean} True if valid */ isValidDomain(domain) { console.log(`Validating domain: "${domain}"`); - - // Basic checks - if (!domain || typeof domain !== 'string') { - console.log('Validation failed: empty or non-string domain'); + if (!domain || typeof domain !== 'string' || domain.length > 253 || /^\d{1,3}(\.\d{1,3}){3}$/.test(domain)) { return false; } - if (domain.length > 253) { - console.log('Validation failed: domain too long'); - return false; - } - if (domain.startsWith('.') || domain.endsWith('.')) { - console.log('Validation failed: domain starts or ends with dot'); - return false; - } - if (domain.includes('..')) { - console.log('Validation failed: domain contains double dots'); - return false; - } - - // Split into parts and validate each const parts = domain.split('.'); - if (parts.length < 2) { - console.log('Validation failed: domain has less than 2 parts'); + if (parts.length < 2 || parts.some(part => !/^[a-zA-Z0-9-]{1,63}$/.test(part) || part.startsWith('-') || part.endsWith('-'))) { return false; } - - // Check each part - for (const part of parts) { - if (!part || part.length > 63) { - console.log(`Validation failed: invalid part "${part}"`); - return false; - } - if (part.startsWith('-') || part.endsWith('-')) { - console.log(`Validation failed: part "${part}" starts or ends with hyphen`); - return false; - } - if (!/^[a-zA-Z0-9-]+$/.test(part)) { - console.log(`Validation failed: part "${part}" contains invalid characters`); - return false; - } - } - - // Check TLD (last part) is alphabetic - const tld = parts[parts.length - 1]; - if (!/^[a-zA-Z]{2,}$/.test(tld)) { - console.log(`Validation failed: invalid TLD "${tld}"`); - return false; - } - - console.log('Domain validation passed'); return true; } - + + /** + * Validate IP address + * @param {string} ip - IP to validate + * @returns {boolean} True if valid + */ + isValidIp(ip) { + console.log(`Validating IP: "${ip}"`); + const parts = ip.split('.'); + if (parts.length !== 4) { + return false; + } + return parts.every(part => { + const num = parseInt(part, 10); + return !isNaN(num) && num >= 0 && num <= 255 && String(num) === part; + }); + } + /** * Format status text for display * @param {string} status - Raw status diff --git a/templates/index.html b/templates/index.html index e063ab1..fdc927f 100644 --- a/templates/index.html +++ b/templates/index.html @@ -32,8 +32,8 @@
- - + +
diff --git a/utils/helpers.py b/utils/helpers.py index b105eb5..2d17717 100644 --- a/utils/helpers.py +++ b/utils/helpers.py @@ -48,3 +48,15 @@ def _is_valid_ip(ip: str) -> bool: except (ValueError, AttributeError): return False + +def is_valid_target(target: str) -> bool: + """ + Checks if the target is a valid domain or IP address. + + Args: + target: The target string to validate. + + Returns: + True if the target is a valid domain or IP, False otherwise. + """ + return _is_valid_domain(target) or _is_valid_ip(target) \ No newline at end of file