Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions api_app/analyzers_manager/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,20 @@ def _create_data_model_dictionary(self) -> Dict:
return result

def create_data_model(self) -> Optional[BaseDataModel]:
    """Create — or reuse via fingerprint deduplication — the data model for this report.

    The analyzer-produced dictionary is fingerprinted; if a data model with the
    same fingerprint already exists it is linked instead of creating a new row
    (content-addressed storage). Returns None when validation fails.

    Fix: the old unconditional ``objects.create()`` + ``merge`` left over from a
    previous revision ran AFTER the dedup branch, overwriting the deduplicated
    link with a fresh empty model on every call; it has been removed.
    """
    # TODO we don't need to actually create a new object every time.
    # if the report is the same of the previous one, we can just link it
    if not self._validation_before_data_model():
        return None
    dictionary = self._create_data_model_dictionary()
    # Fingerprint is computed on a throwaway unsaved instance so no row is
    # written unless we actually need a new one.
    temp_instance = self.data_model_class()
    fingerprint = temp_instance.generate_fingerprint(data=dictionary)
    existing_data_model = self.data_model_class.objects.filter(
        fingerprint=fingerprint
    ).first()

    if existing_data_model:
        logger.info(f"Deduplicated: Linking existing Data Model {existing_data_model.pk}")
        self.data_model = existing_data_model
    else:
        self.data_model: BaseDataModel = self.data_model_class.objects.create(fingerprint=fingerprint)
        self.data_model.merge(dictionary)

    self.save()
    return self.data_model

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from django.db import migrations, models

class Migration(migrations.Migration):
    """Schema migration: add an indexed ``fingerprint`` column to every
    concrete data model so content-addressed deduplication can key on it."""

    dependencies = [
        ("data_model_manager", "0011_data_model_date_index"),
    ]

    # The three concrete data models all receive an identical field, so the
    # operations list is generated rather than spelled out three times.
    operations = [
        migrations.AddField(
            model_name=target_model,
            name="fingerprint",
            field=models.CharField(
                blank=True, db_index=True, max_length=64, default=""
            ),
        )
        for target_model in ("domaindatamodel", "filedatamodel", "ipdatamodel")
    ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import hashlib
import json
import logging

from django.db import migrations

logger = logging.getLogger(__name__)

def normalize_dict(obj):
    """Recursively return *obj* with every nested dict's keys in sorted order.

    Lists are normalized element-wise (order preserved); scalars are returned
    unchanged. Guarantees a canonical structure for stable serialization.
    """
    if isinstance(obj, dict):
        return {key: normalize_dict(obj[key]) for key in sorted(obj)}
    if isinstance(obj, list):
        return [normalize_dict(item) for item in obj]
    return obj

def generate_fingerprint_from_instance(instance):
    """Compute a SHA-256 hex digest over an instance's content-bearing fields.

    Identity/bookkeeping columns (``id``, ``date``, ``fingerprint``) are
    excluded so identical content saved at different times hashes the same;
    datetimes are canonicalized via ``isoformat``.
    """
    excluded = ("id", "date", "fingerprint")
    payload = {}
    for field in instance._meta.fields:
        field_name = field.name
        if field_name in excluded:
            continue
        value = getattr(instance, field_name)
        # Canonical, timezone-preserving string form for date/datetime values.
        if hasattr(value, "isoformat"):
            value = value.isoformat()
        payload[field_name] = value
    canonical = json.dumps(normalize_dict(payload), sort_keys=True)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()

def populate_fingerprints(apps, schema_editor):
    """Backfill ``fingerprint`` for every existing data model row, then merge
    rows that share a fingerprint onto a single canonical instance.

    For each duplicated fingerprint the oldest row (by ``date``) is kept, the
    generic FKs on AnalyzerReport / Job / UserAnalyzableEvent are re-pointed
    to it, and the redundant rows are deleted.

    NOTE(review): assumes all three referencing models use the
    ``data_model_content_type`` / ``data_model_object_id`` generic-FK columns —
    confirm against the current schema before shipping. This data migration is
    not exercised by fresh-environment CI; test it against a populated DB.
    """
    from django.db.models import Count

    batch_size = 500
    # Per Django's data-migration guidance, use historical models obtained
    # from `apps`, never the live model classes (which may be ahead of this
    # migration's schema state). This includes ContentType.
    ContentType = apps.get_model("contenttypes", "ContentType")
    AnalyzerReport = apps.get_model("analyzers_manager", "AnalyzerReport")
    Job = apps.get_model("api_app", "Job")
    for model_name in ["IPDataModel", "DomainDataModel", "FileDataModel"]:
        Model = apps.get_model("data_model_manager", model_name)
        # Phase 1: compute fingerprints in bounded batches to cap memory use.
        queryset = Model.objects.filter(fingerprint="").iterator(chunk_size=batch_size)
        batch = []
        for instance in queryset:
            try:
                instance.fingerprint = generate_fingerprint_from_instance(instance)
                batch.append(instance)
            except Exception as e:
                # Best-effort: a row we cannot fingerprint keeps "" and is
                # simply excluded from deduplication below.
                logger.error(f"Failed to generate fingerprint for {model_name} {instance.pk}: {e}")
            if len(batch) >= batch_size:
                Model.objects.bulk_update(batch, ["fingerprint"])
                batch = []
        if batch:
            Model.objects.bulk_update(batch, ["fingerprint"])
        # Phase 2: collapse duplicates sharing a fingerprint.
        ct, _ = ContentType.objects.get_or_create(app_label="data_model_manager", model=model_name.lower())
        duplicates = Model.objects.values("fingerprint").annotate(c=Count("id")).filter(c__gt=1)
        for entry in duplicates:
            fp = entry["fingerprint"]
            if not fp:
                # Rows whose fingerprinting failed must not be merged together.
                continue
            instances = list(Model.objects.filter(fingerprint=fp).order_by("date"))
            canonical = instances[0]
            redundant_ids = [r.id for r in instances[1:]]
            AnalyzerReport.objects.filter(
                data_model_content_type_id=ct.id,
                data_model_object_id__in=redundant_ids,
            ).update(data_model_object_id=canonical.id)
            Job.objects.filter(
                data_model_content_type_id=ct.id,
                data_model_object_id__in=redundant_ids,
            ).update(data_model_object_id=canonical.id)
            try:
                # The events app may be absent in some deployments.
                UserAnalyzableEvent = apps.get_model("user_events_manager", "UserAnalyzableEvent")
                UserAnalyzableEvent.objects.filter(
                    data_model_content_type_id=ct.id,
                    data_model_object_id__in=redundant_ids,
                ).update(data_model_object_id=canonical.id)
            except LookupError:
                pass
            Model.objects.filter(id__in=redundant_ids).delete()

def reverse_populate_fingerprints(apps, schema_editor):
    """Reverse no-op: backfilled fingerprints are harmless extra data, and the
    duplicate rows deleted by the forward migration cannot be reconstructed,
    so rolling back leaves the remaining rows untouched."""
    pass

class Migration(migrations.Migration):
    """Data migration: backfill data-model fingerprints and merge duplicates."""

    dependencies = [
        ("data_model_manager", "0012_domaindatamodel_fingerprint_and_more"),
    ]

    operations = [
        migrations.RunPython(
            populate_fingerprints,
            reverse_populate_fingerprints,
        ),
    ]
36 changes: 35 additions & 1 deletion api_app/data_model_manager/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import hashlib
import json
import logging
from typing import Dict, Type, Union
from typing import Any, Dict, Type, Union

from django.contrib.contenttypes.fields import GenericRelation
from django.contrib.contenttypes.models import ContentType
Expand All @@ -26,6 +27,14 @@
logger = logging.getLogger(__name__)


def normalize_dict(obj: Any) -> Any:
    """Recursively sort dictionary keys so equal content always serializes
    to the same JSON, regardless of insertion order.

    Lists keep their element order (only their contents are normalized);
    scalars pass through unchanged.
    """
    if isinstance(obj, dict):
        return {key: normalize_dict(obj[key]) for key in sorted(obj)}
    if isinstance(obj, list):
        return [normalize_dict(entry) for entry in obj]
    return obj


class IETFReport(models.Model):
rrname = LowercaseCharField(max_length=100)
rrtype = LowercaseCharField(max_length=100)
Expand Down Expand Up @@ -91,6 +100,7 @@ class BaseDataModel(models.Model):
default=dict
) # field for additional information related to a specific analyzer
date = models.DateTimeField(default=now)
fingerprint = models.CharField(max_length=64, db_index=True, blank=True, default="")
analyzers_report = GenericRelation(
to="analyzers_manager.AnalyzerReport",
object_id_field="data_model_object_id",
Expand Down Expand Up @@ -125,6 +135,30 @@ def owner(self) -> User:
elif self.jobs.exists():
return self.jobs.first().user

def get_content_map(self, data: Dict = None) -> Dict:
    """Build the canonical content dictionary used for fingerprinting.

    This instance's field values form the baseline; entries in *data* (e.g.
    the analyzer-produced dictionary) take precedence over them. ``id``,
    ``date`` and ``fingerprint`` never participate, so identical content
    saved at different times yields the same fingerprint.

    Fixes over the previous revision:
    - the caller's *data* dict is no longer mutated in place (the same dict
      is later passed to ``merge()`` by ``create_data_model``);
    - instance default values no longer overwrite caller-supplied entries,
      which previously collapsed distinct content onto one fingerprint;
    - the redundant pops of keys already skipped by the loop were removed.

    NOTE(review): "caller data wins over instance values" is the inferred
    intent of the fingerprint-on-temp-instance flow — confirm with the author.
    """
    content: Dict = {}
    for field in self._meta.fields:
        name = field.name
        # Identity/bookkeeping columns must not influence the fingerprint.
        if name in ("id", "date", "fingerprint"):
            continue
        value = getattr(self, name)
        # Canonical string form for date/datetime values.
        if hasattr(value, "isoformat"):
            value = value.isoformat()
        content[name] = value
    if data:
        # Overlay caller-supplied content; copying into `content` leaves the
        # caller's dict untouched.
        content.update(data)
    # Strip bookkeeping/relational keys that may arrive via *data*.
    for key in ("id", "date", "fingerprint", "analyzers_report", "jobs", "user_events"):
        content.pop(key, None)
    return normalize_dict(content)

def generate_fingerprint(self, data: Dict = None) -> str:
    """Return the SHA-256 hex digest of this model's canonical content map.

    *data* is forwarded to :meth:`get_content_map`; sorted-key JSON makes the
    digest independent of dictionary insertion order.
    """
    canonical = self.get_content_map(data)
    serialized = json.dumps(canonical, sort_keys=True)
    return hashlib.sha256(serialized.encode("utf-8")).hexdigest()

def merge(self, other: Union["BaseDataModel", Dict], append: bool = True) -> "BaseDataModel":
if not self.pk:
raise ValueError("Unable to merge a model that was not saved.")
Expand Down
3 changes: 2 additions & 1 deletion api_app/user_events_manager/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
@receiver(models.signals.post_delete, sender=UserIPWildCardEvent)
@receiver(models.signals.post_delete, sender=UserAnalyzableEvent)
def post_delete_event_delete_data_model(sender, instance: UserDomainWildCardEvent, **kwargs):
    """On event deletion, also delete the linked data model — if one is set.

    The guard handles events whose data model was never created (or was
    already removed, e.g. by deduplication), which previously raised here.
    """
    linked_data_model = instance.data_model
    if linked_data_model:
        linked_data_model.delete()
71 changes: 71 additions & 0 deletions tests/api_app/data_model_manager/test_cas_deduplication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from kombu import uuid

from api_app.analyzables_manager.models import Analyzable
from api_app.analyzers_manager.models import AnalyzerConfig, AnalyzerReport
from api_app.choices import Classification
from api_app.data_model_manager.models import IPDataModel
from api_app.models import Job
from tests import CustomTestCase


class CASDeduplicationTestCase(CustomTestCase):
    """Content-addressed behaviour of ``AnalyzerReport.create_data_model``:
    identical analyzer output must link reports to a single data model row."""

    def setUp(self):
        super().setUp()
        self.analyzable = Analyzable.objects.get_or_create(
            name="1.1.1.1", defaults={"classification": Classification.IP.value}
        )[0]
        self.job1 = Job.objects.create(
            analyzable=self.analyzable,
            status=Job.STATUSES.ANALYZERS_RUNNING.value,
        )
        self.job2 = Job.objects.create(
            analyzable=self.analyzable,
            status=Job.STATUSES.ANALYZERS_RUNNING.value,
        )
        self.config = AnalyzerConfig.objects.first()

    def _make_report(self, job):
        # Helper: a successful report whose data-model dictionary the test
        # patches afterwards; factored out of both test methods.
        return AnalyzerReport.objects.create(
            job=job,
            config=self.config,
            status=AnalyzerReport.STATUSES.SUCCESS.value,
            task_id=str(uuid()),
            parameters={},
        )

    def test_smart_deduplication_via_fingerprint(self):
        """Two reports producing the same dictionary share one data model row."""
        report1 = self._make_report(self.job1)
        report1._create_data_model_dictionary = lambda: {"isp": "Google", "asn": 15169}
        dm1 = report1.create_data_model()
        # Guard: a silent None (validation failure) would make the pk
        # comparison below crash with a confusing AttributeError.
        self.assertIsNotNone(dm1)
        initial_count = IPDataModel.objects.count()

        report2 = self._make_report(self.job2)
        report2._create_data_model_dictionary = lambda: {"isp": "Google", "asn": 15169}
        dm2 = report2.create_data_model()
        self.assertIsNotNone(dm2)
        # Same content -> same row, and no new objects were created.
        self.assertEqual(dm1.pk, dm2.pk)
        self.assertEqual(IPDataModel.objects.count(), initial_count)

    def test_normalization_stability(self):
        """Key order of the analyzer dictionary must not change the fingerprint."""
        report1 = self._make_report(self.job1)
        report1._create_data_model_dictionary = lambda: {"asn": 15169, "isp": "Google"}
        dm1 = report1.create_data_model()
        self.assertIsNotNone(dm1)
        report2 = self._make_report(self.job2)
        report2._create_data_model_dictionary = lambda: {"isp": "Google", "asn": 15169}
        dm2 = report2.create_data_model()
        self.assertIsNotNone(dm2)
        self.assertEqual(dm1.fingerprint, dm2.fingerprint)
        self.assertEqual(dm1.pk, dm2.pk)
71 changes: 71 additions & 0 deletions tests/api_app/data_model_manager/test_migrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from django.db import connection
from django.db.migrations.executor import MigrationExecutor

from api_app.helpers import calculate_md5, calculate_sha1, calculate_sha256
from tests import CustomTestCase


class MigrationIntegrityTestCase(CustomTestCase):
    """Integration test for data migration 0013 (fingerprint backfill + dedup).

    Rolls the schema back to 0012, seeds two identical IPDataModel rows plus
    one event pointing at each, re-applies 0013, then checks that exactly one
    canonical row survives with both events re-pointed to it.
    """

    @property
    def app_name(self):
        # Django app whose migration graph is exercised.
        return "data_model_manager"

    @property
    def migration_from(self):
        # Last schema-only migration: adds the (empty) fingerprint column.
        return "0012_domaindatamodel_fingerprint_and_more"

    @property
    def migration_to(self):
        # The data migration under test.
        return "0013_populate_fingerprints"

    def setUp(self):
        super().setUp()
        # Migrate the test database down to the pre-backfill state so rows
        # can be created with blank fingerprints.
        self.executor = MigrationExecutor(connection)
        self.old_state = self.executor.migrate([(self.app_name, self.migration_from)])

    def test_migration_0013_deduplication_integrity(self):
        # Use historical (state-frozen) model classes, not the live ones.
        old_apps = self.old_state.apps
        IPDataModel = old_apps.get_model(self.app_name, "IPDataModel")
        UserAnalyzableEvent = old_apps.get_model("user_events_manager", "UserAnalyzableEvent")
        Analyzable = old_apps.get_model("analyzables_manager", "Analyzable")
        ContentType = old_apps.get_model("contenttypes", "ContentType")
        User = old_apps.get_model("certego_saas_user", "User")
        user = User.objects.create(username="test_migrator", email="test@intelowl.org")
        name1, name2 = "1.1.1.1", "8.8.8.8"
        # Two distinct analyzables whose events will reference the duplicates.
        az1 = Analyzable.objects.create(
            name=name1,
            classification="ip",
            md5=calculate_md5(name1.encode()),
            sha1=calculate_sha1(name1.encode()),
            sha256=calculate_sha256(name1.encode()),
        )
        az2 = Analyzable.objects.create(
            name=name2,
            classification="ip",
            md5=calculate_md5(name2.encode()),
            sha1=calculate_sha1(name2.encode()),
            sha256=calculate_sha256(name2.encode()),
        )
        # Identical content -> identical fingerprint -> should be merged.
        dm1 = IPDataModel.objects.create(evaluation="benign", reliability=5)
        dm2 = IPDataModel.objects.create(evaluation="benign", reliability=5)
        ct = ContentType.objects.get_for_model(IPDataModel)
        UserAnalyzableEvent.objects.create(
            user=user, analyzable=az1, data_model_content_type=ct, data_model_object_id=dm1.id
        )
        UserAnalyzableEvent.objects.create(
            user=user, analyzable=az2, data_model_content_type=ct, data_model_object_id=dm2.id
        )
        # Re-apply the data migration on top of the seeded rows.
        self.executor.loader.build_graph()
        new_state = self.executor.migrate([(self.app_name, self.migration_to)])
        new_apps = new_state.apps
        IPDataModelNew = new_apps.get_model(self.app_name, "IPDataModel")
        UserAnalyzableEventNew = new_apps.get_model("user_events_manager", "UserAnalyzableEvent")
        # NOTE(review): the count()==1 check assumes no other IPDataModel rows
        # exist at this state — confirm base fixtures create none.
        self.assertEqual(IPDataModelNew.objects.count(), 1)
        canonical = IPDataModelNew.objects.first()
        events = UserAnalyzableEventNew.objects.filter(data_model_object_id=canonical.id)
        self.assertEqual(events.count(), 2)
        self.assertFalse(IPDataModelNew.objects.filter(id=dm2.id).exists())

    def tearDown(self):
        # Restore the database to the latest migration state for later tests.
        self.executor.migrate(self.executor.loader.graph.leaf_nodes())
        super().tearDown()
Loading