Skip to content

Commit

Permalink
Simplify MS Defender integration & enable by default (#847)
Browse files Browse the repository at this point in the history
  • Loading branch information
slreznit authored May 15, 2024
1 parent c12390a commit f3f19bf
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 32 deletions.
16 changes: 4 additions & 12 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

from openai import AsyncAzureOpenAI
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
from backend.auth.auth_utils import get_authenticated_user_details, get_tenantid
from backend.auth.auth_utils import get_authenticated_user_details
from backend.security.ms_defender_utils import get_msdefender_user_json
from backend.history.cosmosdbservice import CosmosConversationClient

from backend.utils import (
Expand Down Expand Up @@ -268,7 +269,7 @@ async def assets(path):
"sanitize_answer": SANITIZE_ANSWER,
}
# Enable Microsoft Defender for Cloud Integration
MS_DEFENDER_ENABLED = os.environ.get("MS_DEFENDER_ENABLED", "false").lower() == "true"
MS_DEFENDER_ENABLED = os.environ.get("MS_DEFENDER_ENABLED", "true").lower() == "true"

def should_use_data():
global DATASOURCE_TYPE
Expand Down Expand Up @@ -737,16 +738,7 @@ def prepare_model_args(request_body, request_headers):
user_json = None
if (MS_DEFENDER_ENABLED):
authenticated_user_details = get_authenticated_user_details(request_headers)
tenantId = get_tenantid(authenticated_user_details.get("client_principal_b64"))
conversation_id = request_body.get("conversation_id", None)
user_args = {
"EndUserId": authenticated_user_details.get('user_principal_id'),
"EndUserIdType": 'Entra',
"EndUserTenantId": tenantId,
"ConversationId": conversation_id,
"SourceIp": request_headers.get('X-Forwarded-For', request_headers.get('Remote-Addr', '')),
}
user_json = json.dumps(user_args)
user_json = get_msdefender_user_json(authenticated_user_details, request_headers)

model_args = {
"messages": messages,
Expand Down
21 changes: 1 addition & 20 deletions backend/auth/auth_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
import base64
import json
import logging

def get_authenticated_user_details(request_headers):
user_object = {}

Expand All @@ -21,19 +17,4 @@ def get_authenticated_user_details(request_headers):
user_object['client_principal_b64'] = raw_user_object.get('X-Ms-Client-Principal')
user_object['aad_id_token'] = raw_user_object.get('X-Ms-Token-Aad-Id-Token')

return user_object

def get_tenantid(client_principal_b64):
tenant_id = ''
if client_principal_b64:
try:
# Decode the base64 header to get the JSON string
decoded_bytes = base64.b64decode(client_principal_b64)
decoded_string = decoded_bytes.decode('utf-8')
# Convert the JSON string1into a Python dictionary
user_info = json.loads(decoded_string)
# Extract the tenant ID
tenant_id = user_info.get('tid') # 'tid' typically holds the tenant ID
except Exception as ex:
logging.exception(ex)
return tenant_id
return user_object
Empty file added backend/security/__init__.py
Empty file.
11 changes: 11 additions & 0 deletions backend/security/ms_defender_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import json

def get_msdefender_user_json(authenticated_user_details, request_headers):
auth_provider = authenticated_user_details.get('auth_provider')
source_ip = request_headers.get('X-Forwarded-For', request_headers.get('Remote-Addr', ''))
user_args = {
"EndUserId": authenticated_user_details.get('user_principal_id'),
"EndUserIdType": "EntraId" if auth_provider == "aad" else auth_provider,
"SourceIp": source_ip.split(':')[0], #remove port
}
return json.dumps(user_args)

0 comments on commit f3f19bf

Please sign in to comment.