iteration on ws implementation

This commit is contained in:
overcuriousity
2025-09-20 16:52:05 +02:00
parent 75a595c9cb
commit c4e6a8998a
9 changed files with 1224 additions and 290 deletions

View File

@@ -15,6 +15,7 @@ class BaseProvider(ABC):
"""
Abstract base class for all DNSRecon data providers.
Now supports session-specific configuration and returns standardized ProviderResult objects.
FIXED: Enhanced pickle support to prevent weakref serialization errors.
"""
def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
@@ -53,22 +54,57 @@ class BaseProvider(ABC):
def __getstate__(self):
"""Prepare BaseProvider for pickling by excluding unpicklable objects."""
state = self.__dict__.copy()
# Exclude the unpickleable '_local' attribute (which holds the session) and stop event
unpicklable_attrs = ['_local', '_stop_event']
# Exclude unpickleable attributes that may contain weakrefs
unpicklable_attrs = [
'_local', # Thread-local storage (contains requests.Session)
'_stop_event', # Threading event
'logger', # Logger may contain weakrefs in handlers
]
for attr in unpicklable_attrs:
if attr in state:
del state[attr]
# Also handle any potential weakrefs in the config object
if 'config' in state and hasattr(state['config'], '__getstate__'):
# If config has its own pickle support, let it handle itself
pass
elif 'config' in state:
# Otherwise, ensure config doesn't contain unpicklable objects
try:
# Test if config can be pickled
import pickle
pickle.dumps(state['config'])
except (TypeError, AttributeError):
# If config can't be pickled, we'll recreate it during unpickling
state['_config_class'] = type(state['config']).__name__
del state['config']
return state
def __setstate__(self, state):
"""Restore BaseProvider after unpickling by reconstructing threading objects."""
self.__dict__.update(state)
# Re-initialize the '_local' attribute and stop event
# Re-initialize unpickleable attributes
self._local = threading.local()
self._stop_event = None
self.logger = get_forensic_logger()
# Recreate config if it was removed during pickling
if not hasattr(self, 'config') and hasattr(self, '_config_class'):
if self._config_class == 'Config':
from config import config as global_config
self.config = global_config
elif self._config_class == 'SessionConfig':
from core.session_config import create_session_config
self.config = create_session_config()
del self._config_class
@property
def session(self):
"""Get or create thread-local requests session."""
if not hasattr(self._local, 'session'):
self._local.session = requests.Session()
self._local.session.headers.update({

View File

@@ -10,6 +10,7 @@ from core.graph_manager import NodeType, GraphManager
class CorrelationProvider(BaseProvider):
"""
A provider that finds correlations between nodes in the graph.
FIXED: Enhanced pickle support to prevent weakref issues with graph references.
"""
def __init__(self, name: str = "correlation", session_config=None):
@@ -38,6 +39,38 @@ class CorrelationProvider(BaseProvider):
'query_timestamp',
]
def __getstate__(self):
"""
FIXED: Prepare CorrelationProvider for pickling by excluding graph reference.
"""
state = super().__getstate__()
# Remove graph reference to prevent circular dependencies and weakrefs
if 'graph' in state:
del state['graph']
# Also handle correlation_index which might contain complex objects
if 'correlation_index' in state:
# Clear correlation index as it will be rebuilt when needed
state['correlation_index'] = {}
return state
def __setstate__(self, state):
"""
FIXED: Restore CorrelationProvider after unpickling.
"""
super().__setstate__(state)
# Re-initialize graph reference (will be set by scanner)
self.graph = None
# Re-initialize correlation index
self.correlation_index = {}
# Re-compile regex pattern
self.date_pattern = re.compile(r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}')
def get_name(self) -> str:
"""Return the provider name."""
return "correlation"
@@ -79,13 +112,20 @@ class CorrelationProvider(BaseProvider):
def _find_correlations(self, node_id: str) -> ProviderResult:
"""
Find correlations for a given node.
FIXED: Added safety checks to prevent issues when graph is None.
"""
result = ProviderResult()
# FIXED: Ensure self.graph is not None before proceeding.
# FIXED: Ensure self.graph is not None before proceeding
if not self.graph or not self.graph.graph.has_node(node_id):
return result
node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])
try:
node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])
except Exception as e:
# If there's any issue accessing the graph, return empty result
print(f"Warning: Could not access graph for correlation analysis: {e}")
return result
for attr in node_attributes:
attr_name = attr.get('name')
@@ -134,6 +174,7 @@ class CorrelationProvider(BaseProvider):
if len(self.correlation_index[attr_value]['nodes']) > 1:
self._create_correlation_relationships(attr_value, self.correlation_index[attr_value], result)
return result
def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult):

View File

@@ -11,6 +11,7 @@ class DNSProvider(BaseProvider):
"""
Provider for standard DNS resolution and reverse DNS lookups.
Now returns standardized ProviderResult objects with IPv4 and IPv6 support.
FIXED: Enhanced pickle support to prevent resolver serialization issues.
"""
def __init__(self, name=None, session_config=None):
@@ -28,19 +29,20 @@ class DNSProvider(BaseProvider):
self.resolver.lifetime = 10
def __getstate__(self):
"""Prepare the object for pickling."""
state = self.__dict__.copy()
"""Prepare the object for pickling by excluding resolver."""
state = super().__getstate__()
# Remove the unpickleable 'resolver' attribute
if 'resolver' in state:
del state['resolver']
return state
def __setstate__(self, state):
"""Restore the object after unpickling."""
self.__dict__.update(state)
"""Restore the object after unpickling by reconstructing resolver."""
super().__setstate__(state)
# Re-initialize the 'resolver' attribute
self.resolver = resolver.Resolver()
self.resolver.timeout = 5
self.resolver.lifetime = 10
def get_name(self) -> str:
"""Return the provider name."""
@@ -121,10 +123,10 @@ class DNSProvider(BaseProvider):
if _is_valid_domain(hostname):
# Determine appropriate forward relationship type based on IP version
if ip_version == 6:
relationship_type = 'dns_aaaa_record'
relationship_type = 'shodan_aaaa_record'
record_prefix = 'AAAA'
else:
relationship_type = 'dns_a_record'
relationship_type = 'shodan_a_record'
record_prefix = 'A'
# Add the relationship