iteration on ws implementation
This commit is contained in:
@@ -15,6 +15,7 @@ class BaseProvider(ABC):
|
||||
"""
|
||||
Abstract base class for all DNSRecon data providers.
|
||||
Now supports session-specific configuration and returns standardized ProviderResult objects.
|
||||
FIXED: Enhanced pickle support to prevent weakref serialization errors.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
|
||||
@@ -53,22 +54,57 @@ class BaseProvider(ABC):
|
||||
def __getstate__(self):
|
||||
"""Prepare BaseProvider for pickling by excluding unpicklable objects."""
|
||||
state = self.__dict__.copy()
|
||||
# Exclude the unpickleable '_local' attribute (which holds the session) and stop event
|
||||
unpicklable_attrs = ['_local', '_stop_event']
|
||||
|
||||
# Exclude unpickleable attributes that may contain weakrefs
|
||||
unpicklable_attrs = [
|
||||
'_local', # Thread-local storage (contains requests.Session)
|
||||
'_stop_event', # Threading event
|
||||
'logger', # Logger may contain weakrefs in handlers
|
||||
]
|
||||
|
||||
for attr in unpicklable_attrs:
|
||||
if attr in state:
|
||||
del state[attr]
|
||||
|
||||
# Also handle any potential weakrefs in the config object
|
||||
if 'config' in state and hasattr(state['config'], '__getstate__'):
|
||||
# If config has its own pickle support, let it handle itself
|
||||
pass
|
||||
elif 'config' in state:
|
||||
# Otherwise, ensure config doesn't contain unpicklable objects
|
||||
try:
|
||||
# Test if config can be pickled
|
||||
import pickle
|
||||
pickle.dumps(state['config'])
|
||||
except (TypeError, AttributeError):
|
||||
# If config can't be pickled, we'll recreate it during unpickling
|
||||
state['_config_class'] = type(state['config']).__name__
|
||||
del state['config']
|
||||
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
"""Restore BaseProvider after unpickling by reconstructing threading objects."""
|
||||
self.__dict__.update(state)
|
||||
# Re-initialize the '_local' attribute and stop event
|
||||
|
||||
# Re-initialize unpickleable attributes
|
||||
self._local = threading.local()
|
||||
self._stop_event = None
|
||||
self.logger = get_forensic_logger()
|
||||
|
||||
# Recreate config if it was removed during pickling
|
||||
if not hasattr(self, 'config') and hasattr(self, '_config_class'):
|
||||
if self._config_class == 'Config':
|
||||
from config import config as global_config
|
||||
self.config = global_config
|
||||
elif self._config_class == 'SessionConfig':
|
||||
from core.session_config import create_session_config
|
||||
self.config = create_session_config()
|
||||
del self._config_class
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
"""Get or create thread-local requests session."""
|
||||
if not hasattr(self._local, 'session'):
|
||||
self._local.session = requests.Session()
|
||||
self._local.session.headers.update({
|
||||
|
||||
@@ -10,6 +10,7 @@ from core.graph_manager import NodeType, GraphManager
|
||||
class CorrelationProvider(BaseProvider):
|
||||
"""
|
||||
A provider that finds correlations between nodes in the graph.
|
||||
FIXED: Enhanced pickle support to prevent weakref issues with graph references.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str = "correlation", session_config=None):
|
||||
@@ -38,6 +39,38 @@ class CorrelationProvider(BaseProvider):
|
||||
'query_timestamp',
|
||||
]
|
||||
|
||||
def __getstate__(self):
|
||||
"""
|
||||
FIXED: Prepare CorrelationProvider for pickling by excluding graph reference.
|
||||
"""
|
||||
state = super().__getstate__()
|
||||
|
||||
# Remove graph reference to prevent circular dependencies and weakrefs
|
||||
if 'graph' in state:
|
||||
del state['graph']
|
||||
|
||||
# Also handle correlation_index which might contain complex objects
|
||||
if 'correlation_index' in state:
|
||||
# Clear correlation index as it will be rebuilt when needed
|
||||
state['correlation_index'] = {}
|
||||
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
"""
|
||||
FIXED: Restore CorrelationProvider after unpickling.
|
||||
"""
|
||||
super().__setstate__(state)
|
||||
|
||||
# Re-initialize graph reference (will be set by scanner)
|
||||
self.graph = None
|
||||
|
||||
# Re-initialize correlation index
|
||||
self.correlation_index = {}
|
||||
|
||||
# Re-compile regex pattern
|
||||
self.date_pattern = re.compile(r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}')
|
||||
|
||||
def get_name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
return "correlation"
|
||||
@@ -79,13 +112,20 @@ class CorrelationProvider(BaseProvider):
|
||||
def _find_correlations(self, node_id: str) -> ProviderResult:
|
||||
"""
|
||||
Find correlations for a given node.
|
||||
FIXED: Added safety checks to prevent issues when graph is None.
|
||||
"""
|
||||
result = ProviderResult()
|
||||
# FIXED: Ensure self.graph is not None before proceeding.
|
||||
|
||||
# FIXED: Ensure self.graph is not None before proceeding
|
||||
if not self.graph or not self.graph.graph.has_node(node_id):
|
||||
return result
|
||||
|
||||
node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])
|
||||
try:
|
||||
node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])
|
||||
except Exception as e:
|
||||
# If there's any issue accessing the graph, return empty result
|
||||
print(f"Warning: Could not access graph for correlation analysis: {e}")
|
||||
return result
|
||||
|
||||
for attr in node_attributes:
|
||||
attr_name = attr.get('name')
|
||||
@@ -134,6 +174,7 @@ class CorrelationProvider(BaseProvider):
|
||||
|
||||
if len(self.correlation_index[attr_value]['nodes']) > 1:
|
||||
self._create_correlation_relationships(attr_value, self.correlation_index[attr_value], result)
|
||||
|
||||
return result
|
||||
|
||||
def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult):
|
||||
|
||||
@@ -11,6 +11,7 @@ class DNSProvider(BaseProvider):
|
||||
"""
|
||||
Provider for standard DNS resolution and reverse DNS lookups.
|
||||
Now returns standardized ProviderResult objects with IPv4 and IPv6 support.
|
||||
FIXED: Enhanced pickle support to prevent resolver serialization issues.
|
||||
"""
|
||||
|
||||
def __init__(self, name=None, session_config=None):
|
||||
@@ -28,19 +29,20 @@ class DNSProvider(BaseProvider):
|
||||
self.resolver.lifetime = 10
|
||||
|
||||
def __getstate__(self):
|
||||
"""Prepare the object for pickling."""
|
||||
state = self.__dict__.copy()
|
||||
"""Prepare the object for pickling by excluding resolver."""
|
||||
state = super().__getstate__()
|
||||
# Remove the unpickleable 'resolver' attribute
|
||||
if 'resolver' in state:
|
||||
del state['resolver']
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
"""Restore the object after unpickling."""
|
||||
self.__dict__.update(state)
|
||||
"""Restore the object after unpickling by reconstructing resolver."""
|
||||
super().__setstate__(state)
|
||||
# Re-initialize the 'resolver' attribute
|
||||
self.resolver = resolver.Resolver()
|
||||
self.resolver.timeout = 5
|
||||
self.resolver.lifetime = 10
|
||||
|
||||
def get_name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
@@ -121,10 +123,10 @@ class DNSProvider(BaseProvider):
|
||||
if _is_valid_domain(hostname):
|
||||
# Determine appropriate forward relationship type based on IP version
|
||||
if ip_version == 6:
|
||||
relationship_type = 'dns_aaaa_record'
|
||||
relationship_type = 'shodan_aaaa_record'
|
||||
record_prefix = 'AAAA'
|
||||
else:
|
||||
relationship_type = 'dns_a_record'
|
||||
relationship_type = 'shodan_a_record'
|
||||
record_prefix = 'A'
|
||||
|
||||
# Add the relationship
|
||||
|
||||
Reference in New Issue
Block a user