368 lines
16 KiB
Python
368 lines
16 KiB
Python
# File: src/dns_resolver.py
|
|
"""DNS resolution functionality with enhanced TLD testing and forensic operation tracking."""
|
|
|
|
import dns.resolver
|
|
import dns.reversename
|
|
import dns.query
|
|
import dns.zone
|
|
from typing import List, Dict, Optional, Set
|
|
import socket
|
|
import time
|
|
import logging
|
|
import uuid
|
|
from .data_structures import DNSRecord, ReconData
|
|
from .config import Config
|
|
|
|
# Shared module-level logger, named after this module, used by DNSResolver for its diagnostics.
logger = logging.getLogger(name=__name__)
|
|
|
|
class DNSResolver:
    """DNS resolution and record lookup with optimized TLD testing and forensic tracking.

    Record lookups use resolvers pinned to a single server taken from
    ``config.DNS_SERVERS``. Every query issued by the throttled methods passes
    through :meth:`_rate_limit`, which spaces queries according to
    ``config.DNS_RATE_LIMIT`` and counts them in ``query_count``.
    """

    # All DNS record types to query
    RECORD_TYPES = [
        'A', 'AAAA', 'MX', 'NS', 'TXT', 'CNAME', 'SOA', 'PTR',
        'SRV', 'CAA', 'DNSKEY', 'DS', 'RRSIG', 'NSEC', 'NSEC3'
    ]

    # SPF (RFC 7208) mechanisms/modifiers whose argument is a domain-spec.
    _SPF_DOMAIN_PREFIXES = ('include:', 'redirect=', 'exists:', 'a:', 'mx:', 'ptr:')

    # SPF qualifiers that may precede a mechanism, e.g. "~include:example.com".
    _SPF_QUALIFIERS = ('+', '-', '~', '?')

    # Punctuation trimmed from both ends of a TXT token before it is inspected
    # (the presentation format wraps TXT strings in double quotes).
    _TXT_TOKEN_PUNCTUATION = '",\'()[]{}'

    # Characters allowed in a hostname in addition to letters and digits.
    _HOSTNAME_EXTRA_CHARS = frozenset('-_.')

    def __init__(self, config: "Config"):
        """Store the configuration and reset query bookkeeping.

        Args:
            config: Settings object; this class reads ``DNS_SERVERS``,
                ``DNS_TIMEOUT`` and ``DNS_RATE_LIMIT`` from it.
        """
        self.config = config
        # time.monotonic() timestamp of the previous throttled query (0 = none yet).
        self.last_request = 0
        # Number of queries that have passed through _rate_limit().
        self.query_count = 0

        logger.info(f"🌐 DNS resolver initialized with {len(config.DNS_SERVERS)} servers: {config.DNS_SERVERS}")
        logger.info(f"⚡ DNS rate limit: {config.DNS_RATE_LIMIT}/s, timeout: {config.DNS_TIMEOUT}s")

    def _rate_limit(self):
        """Throttle to at most ``DNS_RATE_LIMIT`` queries per second and count the query.

        A non-positive rate limit disables throttling (instead of dividing by
        zero). Query statistics are logged every 100 counted queries.
        """
        rate = self.config.DNS_RATE_LIMIT
        if rate > 0:
            min_interval = 1.0 / rate
            # Monotonic clock: a wall-clock step backwards must not cause a long sleep.
            elapsed = time.monotonic() - self.last_request
            if elapsed < min_interval:
                sleep_time = min_interval - elapsed
                # Only log if sleep is significant to reduce spam
                if sleep_time > 0.1:
                    logger.debug(f"⏸️ DNS rate limiting: sleeping for {sleep_time:.2f}s")
                time.sleep(sleep_time)

        self.last_request = time.monotonic()
        self.query_count += 1

        if self.query_count % 100 == 0:
            logger.info(f"📊 DNS query statistics: {self.query_count} total queries performed")

    @staticmethod
    def _build_resolver(nameserver: str, timeout: float,
                        lifetime: Optional[float] = None) -> "dns.resolver.Resolver":
        """Create a dnspython resolver that queries only ``nameserver``.

        Args:
            nameserver: Address of the single server to query.
            timeout: Per-attempt timeout in seconds.
            lifetime: Total time budget for one ``resolve`` call. When omitted it
                is raised to at least ``timeout``.
        """
        resolver = dns.resolver.Resolver()
        resolver.nameservers = [nameserver]
        resolver.timeout = timeout
        if lifetime is None:
            # dnspython bounds every attempt by the remaining lifetime (5 s by
            # default), which would silently truncate a larger configured timeout.
            lifetime = max(resolver.lifetime, timeout)
        resolver.lifetime = lifetime
        return resolver

    def resolve_hostname_fast(self, hostname: str) -> List[str]:
        """Fast hostname resolution optimized for TLD testing.

        Queries only the first configured server, for A records only, with a
        2-second budget. This path is deliberately not rate-limited or counted.

        Returns:
            IPv4 address strings; empty on any failure or when no DNS servers
            are configured.
        """
        ips = []

        logger.debug(f"🚀 Fast resolving hostname: {hostname}")

        if not self.config.DNS_SERVERS:
            logger.warning(f"⚠️ No DNS servers configured; cannot fast resolve {hostname}")
            return ips

        # Use only the primary DNS server and a shorter timeout for TLD testing
        resolver = self._build_resolver(self.config.DNS_SERVERS[0], timeout=2, lifetime=2)

        try:
            # Try A records only for speed (most common)
            answers = resolver.resolve(hostname, 'A')
        except dns.resolver.NXDOMAIN:
            logger.debug(f"❌ NXDOMAIN for {hostname}")
        except dns.resolver.NoAnswer:
            logger.debug(f"⚠️ No A record for {hostname}")
        except dns.resolver.Timeout:
            logger.debug(f"⏱️ Timeout for {hostname}")
        except Exception as e:
            logger.debug(f"⚠️ Error fast resolving {hostname}: {e}")
        else:
            for answer in answers:
                ips.append(str(answer))
                logger.debug(f"⚡ Fast A record for {hostname}: {answer}")

        if ips:
            logger.debug(f"⚡ Fast resolved {hostname} to {len(ips)} IPs: {ips}")

        return ips

    def resolve_hostname(self, hostname: str, operation_id: Optional[str] = None) -> List[str]:
        """Resolve a hostname to its IPv4 and IPv6 addresses using every configured server.

        Each server is asked for A and then AAAA records. Every individual query
        is throttled and counted. Failures are logged at debug level and skipped.

        Args:
            hostname: Name to resolve.
            operation_id: Accepted for interface symmetry; not used here.

        Returns:
            Unique address strings in first-seen order.
        """
        ips = []

        logger.debug(f"🔍 Resolving hostname: {hostname}")

        for dns_server in self.config.DNS_SERVERS:
            resolver = self._build_resolver(dns_server, self.config.DNS_TIMEOUT)

            for record_type in ('A', 'AAAA'):
                # Throttle and count each query, not just each server.
                self._rate_limit()
                try:
                    answers = resolver.resolve(hostname, record_type)
                except dns.resolver.NXDOMAIN:
                    logger.debug(f"❌ NXDOMAIN for {hostname} {record_type} record on {dns_server}")
                    continue
                except dns.resolver.NoAnswer:
                    logger.debug(f"⚠️ No {record_type} record for {hostname} on {dns_server}")
                    continue
                except Exception as e:
                    logger.debug(f"⚠️ Error resolving {record_type} record for {hostname} on {dns_server}: {e}")
                    continue

                for answer in answers:
                    ips.append(str(answer))
                    logger.debug(f"✅ {record_type} record for {hostname}: {answer}")

        # dict.fromkeys de-duplicates while keeping a deterministic order.
        unique_ips = list(dict.fromkeys(ips))
        if unique_ips:
            logger.info(f"✅ Resolved {hostname} to {len(unique_ips)} unique IPs: {unique_ips}")
        else:
            logger.debug(f"❌ No IPs found for {hostname}")

        return unique_ips

    def get_all_dns_records(self, hostname: str, operation_id: Optional[str] = None) -> List["DNSRecord"]:
        """Get all DNS records for a hostname with forensic tracking.

        For each type in ``RECORD_TYPES`` the configured servers are tried in
        order. The first server that answers supplies the records for that type.
        NXDOMAIN stops trying further servers for that type only. Later types
        are still queried, because e.g. a dangling CNAME yields NXDOMAIN for
        A/AAAA while the CNAME record itself exists.

        Args:
            hostname: Name to query.
            operation_id: Forensic tag stored on every record; a UUID4 is
                generated when omitted.

        Returns:
            One ``DNSRecord`` per answer, at most one server's answers per type.
        """
        records = []

        # Generate operation ID if not provided
        if operation_id is None:
            operation_id = str(uuid.uuid4())

        logger.debug(f"📋 Getting all DNS records for: {hostname} (operation: {operation_id})")

        for record_type in self.RECORD_TYPES:
            for dns_server in self.config.DNS_SERVERS:
                self._rate_limit()
                resolver = self._build_resolver(dns_server, self.config.DNS_TIMEOUT)

                try:
                    answers = resolver.resolve(hostname, record_type)
                    values = [str(answer) for answer in answers]
                    # Build this server's records first so a failure part-way
                    # through cannot leave a partial set behind.
                    found = [
                        DNSRecord(
                            record_type=record_type,
                            value=value,
                            ttl=answers.ttl,
                            operation_id=operation_id,  # Forensic tracking
                        )
                        for value in values
                    ]
                except dns.resolver.NXDOMAIN:
                    logger.debug(f"❌ NXDOMAIN for {hostname} {record_type} on {dns_server}")
                    break  # Domain doesn't exist, no point checking other servers
                except dns.resolver.NoAnswer:
                    logger.debug(f"⚠️ No {record_type} record for {hostname} on {dns_server}")
                    continue  # Try next DNS server
                except dns.resolver.Timeout:
                    logger.debug(f"⏱️ Timeout for {hostname} {record_type} on {dns_server}")
                    continue  # Try next DNS server
                except Exception as e:
                    logger.debug(f"⚠️ Error querying {record_type} for {hostname} on {dns_server}: {e}")
                    continue  # Try next DNS server

                records.extend(found)
                if values:
                    logger.debug(f"✅ Found {record_type} record for {hostname}: {values[0]}")
                break  # Found records, no need to query other DNS servers for this type

        logger.info(f"📋 Found {len(records)} DNS records for {hostname} across {len(set(r.record_type for r in records))} record types")

        return records

    def query_specific_record_type(self, hostname: str, record_type: str, operation_id: Optional[str] = None) -> List["DNSRecord"]:
        """Query a single DNS record type with forensic tracking.

        Servers are tried in order until one answers. NXDOMAIN stops the search.
        NoAnswer, timeouts and other errors fall through to the next server.

        Args:
            hostname: Name to query.
            record_type: DNS type mnemonic, e.g. ``'MX'``.
            operation_id: Forensic tag stored on every record; a UUID4 is
                generated when omitted.

        Returns:
            One ``DNSRecord`` per answer from the first answering server.
        """
        records = []

        # Generate operation ID if not provided
        if operation_id is None:
            operation_id = str(uuid.uuid4())

        logger.debug(f"🎯 Querying {record_type} records for {hostname} (operation: {operation_id})")

        for dns_server in self.config.DNS_SERVERS:
            self._rate_limit()
            resolver = self._build_resolver(dns_server, self.config.DNS_TIMEOUT)

            try:
                answers = resolver.resolve(hostname, record_type)
                values = [str(answer) for answer in answers]
                found = [
                    DNSRecord(
                        record_type=record_type,
                        value=value,
                        ttl=answers.ttl,
                        operation_id=operation_id,  # Forensic tracking
                    )
                    for value in values
                ]
            except dns.resolver.NXDOMAIN:
                logger.debug(f"❌ NXDOMAIN for {hostname} {record_type} on {dns_server}")
                break  # Domain doesn't exist, no point checking other servers
            except dns.resolver.NoAnswer:
                logger.debug(f"⚠️ No {record_type} record for {hostname} on {dns_server}")
                continue  # Try next DNS server
            except dns.resolver.Timeout:
                logger.debug(f"⏱️ Timeout for {hostname} {record_type} on {dns_server}")
                continue  # Try next DNS server
            except Exception as e:
                logger.debug(f"⚠️ Error querying {record_type} for {hostname} on {dns_server}: {e}")
                continue  # Try next DNS server

            records.extend(found)
            for value in values:
                logger.debug(f"✅ {record_type} record for {hostname}: {value}")
            break  # Found records, no need to query other DNS servers

        logger.debug(f"🎯 Found {len(records)} {record_type} records for {hostname}")
        return records

    def reverse_dns_lookup(self, ip: str, operation_id: Optional[str] = None) -> Optional[str]:
        """Perform a reverse (PTR) lookup for ``ip``; return the hostname or None.

        Uses ``socket.gethostbyaddr``, i.e. the host's own resolver, not
        ``config.DNS_SERVERS``.
        NOTE(review): ``dns.reversename`` is imported by this module and could
        be used to honor the configured servers instead — confirm intent.

        Args:
            ip: IPv4 or IPv6 address string.
            operation_id: Only included in the debug log.
        """
        logger.debug(f"🔍 Reverse DNS lookup for: {ip} (operation: {operation_id or 'auto'})")

        try:
            self._rate_limit()
            hostname = socket.gethostbyaddr(ip)[0]
            logger.info(f"✅ Reverse DNS for {ip}: {hostname}")
            return hostname
        except socket.herror:
            logger.debug(f"❌ No reverse DNS for {ip}")
            return None
        except Exception as e:
            logger.debug(f"⚠️ Error in reverse DNS for {ip}: {e}")
            return None

    @classmethod
    def _spf_domain_argument(cls, token: str) -> Optional[tuple]:
        """Return ``(mechanism, domain)`` for an SPF token such as ``~include:x.com``, else None.

        A leading SPF qualifier is ignored, and a CIDR suffix (``a:host/24``)
        is dropped from the domain.
        """
        body = token[1:] if token[:1] in cls._SPF_QUALIFIERS else token
        for prefix in cls._SPF_DOMAIN_PREFIXES:
            if body.startswith(prefix):
                return prefix[:-1], body[len(prefix):].split('/', 1)[0]
        return None

    def _txt_hostname_candidates(self, value: str) -> List[tuple]:
        """Return ``(source_label, candidate)`` pairs found in a lower-cased TXT value.

        Candidates are not validated here; the label is used only for logging.
        """
        candidates = []
        for raw_token in value.replace(',', ' ').replace(';', ' ').split():
            # Quotes from the presentation format cling to the first/last word.
            token = raw_token.strip(self._TXT_TOKEN_PUNCTUATION)
            spf = self._spf_domain_argument(token)
            if spf is not None:
                mechanism, target = spf
                candidates.append((f"TXT {mechanism}", target.rstrip('.')))
            # Skip URLs, but not hostnames that merely begin with "http".
            elif '.' in token and not token.startswith(('http://', 'https://')):
                candidates.append(("TXT", token.rstrip('.')))
        return candidates

    def extract_subdomains_from_dns(self, records: List["DNSRecord"]) -> Set[str]:
        """Extract potential (sub)domain names referenced by DNS records.

        Handles MX (exchange), CNAME/NS (target), SRV (target) and TXT records.
        For TXT, domain arguments of SPF mechanisms (``include:``, ``redirect=``,
        ``a:``, ``mx:``, ...) are extracted, along with any other dotted token.
        Every candidate must pass :meth:`_is_valid_hostname`. A record that
        raises during parsing is logged and skipped.

        Args:
            records: Objects exposing ``record_type`` and a string ``value``.

        Returns:
            Lower-cased names without a trailing dot.
        """
        subdomains = set()

        logger.debug(f"🌿 Extracting subdomains from {len(records)} DNS records")

        for record in records:
            value = record.value.lower()

            try:
                if record.record_type == 'MX':
                    # MX record format: "priority hostname"
                    parts = value.split()
                    if len(parts) >= 2:
                        hostname = parts[-1].rstrip('.')
                        if self._is_valid_hostname(hostname):
                            subdomains.add(hostname)
                            logger.debug(f"🌿 Found subdomain from MX: {hostname}")

                elif record.record_type in ['CNAME', 'NS']:
                    # Direct hostname records
                    hostname = value.rstrip('.')
                    if self._is_valid_hostname(hostname):
                        subdomains.add(hostname)
                        logger.debug(f"🌿 Found subdomain from {record.record_type}: {hostname}")

                elif record.record_type == 'TXT':
                    for source, hostname in self._txt_hostname_candidates(value):
                        if self._is_valid_hostname(hostname):
                            subdomains.add(hostname)
                            logger.debug(f"🌿 Found subdomain from {source}: {hostname}")

                elif record.record_type == 'SRV':
                    # SRV record format: "priority weight port target"
                    parts = value.split()
                    if len(parts) >= 4:
                        hostname = parts[-1].rstrip('.')
                        if self._is_valid_hostname(hostname):
                            subdomains.add(hostname)
                            logger.debug(f"🌿 Found subdomain from SRV: {hostname}")

            except Exception as e:
                logger.debug(f"⚠️ Error extracting subdomain from {record.record_type} record '{value}': {e}")
                continue

        if subdomains:
            logger.info(f"🌿 Extracted {len(subdomains)} potential subdomains")
        else:
            logger.debug("❌ No subdomains extracted from DNS records")

        return subdomains

    def _is_valid_hostname(self, hostname: str) -> bool:
        """Return True if ``hostname`` is a plausible dotted host name.

        Requirements:
        - non-empty and at most 255 characters;
        - contains a dot and is not an IP literal;
        - consists only of letters, digits, ``-``, ``_`` and ``.``;
        - every label is 1-63 characters long.

        ``str.isalnum`` also accepts non-ASCII letters, so internationalized
        (U-label) names are still allowed.
        """
        if not hostname or len(hostname) > 255:
            return False

        # Must contain at least one dot
        if '.' not in hostname:
            return False

        # Must not be an IP address
        if self._looks_like_ip(hostname):
            return False

        # Reject punctuation such as quotes, '=', ':', '@' or '/'.
        if not all(ch.isalnum() or ch in self._HOSTNAME_EXTRA_CHARS for ch in hostname):
            return False

        return all(1 <= len(label) <= 63 for label in hostname.split('.'))

    def _looks_like_ip(self, text: str) -> bool:
        """Return True if ``text`` parses as an IPv4 (``inet_aton``) or IPv6 address."""
        try:
            socket.inet_aton(text)
            return True
        except (OSError, ValueError):  # ValueError: e.g. embedded NUL characters
            pass

        try:
            socket.inet_pton(socket.AF_INET6, text)
            return True
        except (OSError, ValueError):
            pass

        return False