From f3f19bf5f4cd9754ff0f759ade72057ca1e01fbc Mon Sep 17 00:00:00 2001 From: Slava Reznitsky <30771358+slreznit@users.noreply.github.com> Date: Wed, 15 May 2024 21:38:48 +0300 Subject: [PATCH] Simplify MS Defender integration & enable by default (#847) --- app.py | 16 ++++------------ backend/auth/auth_utils.py | 21 +-------------------- backend/security/__init__.py | 0 backend/security/ms_defender_utils.py | 11 +++++++++++ 4 files changed, 16 insertions(+), 32 deletions(-) create mode 100644 backend/security/__init__.py create mode 100644 backend/security/ms_defender_utils.py diff --git a/app.py b/app.py index 4407746203..da1d8454e8 100644 --- a/app.py +++ b/app.py @@ -17,7 +17,8 @@ from openai import AsyncAzureOpenAI from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider -from backend.auth.auth_utils import get_authenticated_user_details, get_tenantid +from backend.auth.auth_utils import get_authenticated_user_details +from backend.security.ms_defender_utils import get_msdefender_user_json from backend.history.cosmosdbservice import CosmosConversationClient from backend.utils import ( @@ -268,7 +269,7 @@ async def assets(path): "sanitize_answer": SANITIZE_ANSWER, } # Enable Microsoft Defender for Cloud Integration -MS_DEFENDER_ENABLED = os.environ.get("MS_DEFENDER_ENABLED", "false").lower() == "true" +MS_DEFENDER_ENABLED = os.environ.get("MS_DEFENDER_ENABLED", "true").lower() == "true" def should_use_data(): global DATASOURCE_TYPE @@ -737,16 +738,7 @@ def prepare_model_args(request_body, request_headers): user_json = None if (MS_DEFENDER_ENABLED): authenticated_user_details = get_authenticated_user_details(request_headers) - tenantId = get_tenantid(authenticated_user_details.get("client_principal_b64")) - conversation_id = request_body.get("conversation_id", None) - user_args = { - "EndUserId": authenticated_user_details.get('user_principal_id'), - "EndUserIdType": 'Entra', - "EndUserTenantId": tenantId, - "ConversationId": conversation_id, - "SourceIp": request_headers.get('X-Forwarded-For', request_headers.get('Remote-Addr', '')), - } - user_json = json.dumps(user_args) + user_json = get_msdefender_user_json(authenticated_user_details, request_headers) model_args = { "messages": messages, diff --git a/backend/auth/auth_utils.py b/backend/auth/auth_utils.py index 3a97e610a3..59dd02ea3b 100644 --- a/backend/auth/auth_utils.py +++ b/backend/auth/auth_utils.py @@ -1,7 +1,3 @@ -import base64 -import json -import logging - def get_authenticated_user_details(request_headers): user_object = {} @@ -21,19 +17,4 @@ def get_authenticated_user_details(request_headers): user_object['client_principal_b64'] = raw_user_object.get('X-Ms-Client-Principal') user_object['aad_id_token'] = raw_user_object.get('X-Ms-Token-Aad-Id-Token') - return user_object - -def get_tenantid(client_principal_b64): - tenant_id = '' - if client_principal_b64: - try: - # Decode the base64 header to get the JSON string - decoded_bytes = base64.b64decode(client_principal_b64) - decoded_string = decoded_bytes.decode('utf-8') - # Convert the JSON string1into a Python dictionary - user_info = json.loads(decoded_string) - # Extract the tenant ID - tenant_id = user_info.get('tid') # 'tid' typically holds the tenant ID - except Exception as ex: - logging.exception(ex) - return tenant_id \ No newline at end of file + return user_object \ No newline at end of file diff --git a/backend/security/__init__.py b/backend/security/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/security/ms_defender_utils.py b/backend/security/ms_defender_utils.py new file mode 100644 index 0000000000..1c62e782b2 --- /dev/null +++ b/backend/security/ms_defender_utils.py @@ -0,0 +1,11 @@ +import json + +def get_msdefender_user_json(authenticated_user_details, request_headers): + auth_provider = authenticated_user_details.get('auth_provider') + source_ip = request_headers.get('X-Forwarded-For', request_headers.get('Remote-Addr', '')) + user_args = { + "EndUserId": authenticated_user_details.get('user_principal_id'), + "EndUserIdType": "EntraId" if auth_provider == "aad" else auth_provider, + "SourceIp": source_ip.split(':')[0], #remove port + } + return json.dumps(user_args) \ No newline at end of file