diff --git a/api/main.py b/api/main.py index 62d06d1..7ca9c5e 100644 --- a/api/main.py +++ b/api/main.py @@ -1131,7 +1131,6 @@ async def import_csv_with_leak_id(leak_id: int, """ - i = 0 inserted_ids = [] for r in df.reset_index().to_dict(orient = 'records'): sql = """ @@ -1151,13 +1150,10 @@ async def import_csv_with_leak_id(leak_id: int, r['domain'], r['browser'], r['malware_name'], r['infected_machine'], r['dg'])) leak_data_id = int(cur.fetchone()['id']) inserted_ids.append(leak_data_id) - i += 1 except Exception as ex: return Answer(success = False, errormsg = str(ex), data = []) t1 = time.time() d = round(t1 - t0, 3) - num_deduped = len(inserted_ids) - # logger.info("inserted %d rows, %d duplicates, %d new rows" % (i, i - num_deduped, num_deduped)) # now get the data of all the IDs / dedup try: @@ -1181,6 +1177,12 @@ async def import_csv_with_leak_id(leak_id: int, async def enrich_dg_by_email(email: EmailStr, response: Response, api_key: APIKey = Depends(validate_api_key_header)) -> Answer: + """ + Enricher function: insert an email, returns the DG. + + :param email: + :return: The DG or "Unknown" + """ t0 = time.time() le = LDAPEnricher() retval = le.email_to_dg(email) diff --git a/tests/test_main.py b/tests/test_main.py index 6bd1b6a..11ee0f4 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -469,6 +469,23 @@ def test_import_csv_spycloud_invalid_ticket_id(self): def test_import_csv_spycloud(self): fixtures_file = "./tests/fixtures/data_anonymized_spycloud.csv" f = open(fixtures_file, "rb") - response = client.post('/import/csv/spycloud/%s?summary=test2' % ("ticket99",), files = {"_file": f}, headers = VALID_AUTH) + response = client.post('/import/csv/spycloud/%s?summary=test2' % ("ticket99",), files = {"_file": f}, + headers = VALID_AUTH) assert 200 <= response.status_code < 300 assert response.json()['meta']['count'] >= 0 + + +class TestEnricherEmailToDG(unittest.TestCase): + response = None + + def test_enrich_dg_by_email(self): + email = "aaron@example.com" + if not os.getenv('CED_SERVER'): + with self.assertRaises(Exception): + client.get('/enrich/email_to_dg/%s' % (email,), headers = VALID_AUTH) + else: + response = client.get('/enrich/email_to_dg/%s' % (email,), headers = VALID_AUTH) + assert response.status_code == 200 + data = response.json() + assert data['meta']['count'] >= 1 + assert data['data'][0]['dg']