Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom Gliner Analyzer #477

Draft
wants to merge 3 commits into
base: pebblo-0.1.18
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pebblo/app/service/doc_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,14 @@ def _get_classifier_response(self, doc):
)
try:
if doc_info.data:
topics, topic_count = topic_classifier_obj.predict(doc_info.data)
topics, topic_count, topic_details = topic_classifier_obj.predict(
doc_info.data
)
(
entities,
entity_count,
anonymized_doc,
entity_details,
) = self.entity_classifier_obj.presidio_entity_classifier_and_anonymizer(
doc_info.data,
anonymize_snippets=ClassifierConstants.anonymize_snippets.value,
Expand Down
1 change: 1 addition & 0 deletions pebblo/app/service/prompt_gov.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def _get_classifier_response(self):
entities,
entity_count,
anonymized_doc,
entity_details,
) = self.entity_classifier_obj.presidio_entity_classifier_and_anonymizer(
self.input.get("prompt"),
anonymize_snippets=False,
Expand Down
5 changes: 4 additions & 1 deletion pebblo/app/service/prompt_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _fetch_classified_data(self, input_data, input_type=""):
entities,
entity_count,
_,
_,
) = self.entity_classifier_obj.presidio_entity_classifier_and_anonymizer(
input_data
)
Expand All @@ -53,7 +54,9 @@ def _fetch_classified_data(self, input_data, input_type=""):

# Topic classification is performed only for the response.
if input_type == "response":
topics, topic_count = self.topic_classifier_obj.predict(input_data)
topics, topic_count, topic_details = self.topic_classifier_obj.predict(
input_data
)
data["topicCount"] = topic_count
data["topics"] = topics

Expand Down
2 changes: 1 addition & 1 deletion pebblo/entity_classifier/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ from pebblo.entity_classifier.entity_classifier import EntityClassifier

text = <Input Data>
entity_classifier_obj = EntityClassifier()
entities, total_count, anonymized_text = entity_classifier_obj.presidio_entity_classifier_and_anonymizer(text,anonymize_snippets)
entities, total_count, anonymized_text, entity_details = entity_classifier_obj.presidio_entity_classifier_and_anonymizer(text,anonymize_snippets)
print(f"Entities: {entities}")
print(f"Entity Count: {total_count}")
print(f"Anonymized Text: {anonymized_text}")
Expand Down
Empty file.
47 changes: 47 additions & 0 deletions pebblo/entity_classifier/custom_analyzer/gliner_recognizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import logging

from gliner import GLiNER
from presidio_analyzer import EntityRecognizer, RecognizerResult

logger = logging.getLogger()


class GlinerRecognizer(EntityRecognizer):
    """
    Custom Presidio recognizer that delegates entity detection to a GLiNER model.

    Instances are registered with the Presidio ``AnalyzerEngine`` registry and
    translate GLiNER predictions into Presidio ``RecognizerResult`` objects.
    """

    # Defaults kept as class attributes so the recognizer can be reused with a
    # different model or label set without code changes.
    DEFAULT_MODEL_NAME = "urchade/gliner_mediumv2.1"
    DEFAULT_ENTITIES = ["SECRET_KEY_TOKEN"]

    def __init__(self, model_name=None, supported_entities=None):
        """
        Args:
            model_name (str, optional): Hugging Face model id of the GLiNER
                model to load. Defaults to ``DEFAULT_MODEL_NAME``.
            supported_entities (list, optional): Entity labels this recognizer
                detects. Defaults to ``DEFAULT_ENTITIES``.
        """
        entities = supported_entities or list(self.DEFAULT_ENTITIES)
        super().__init__(supported_entities=entities, supported_language="en")
        # Initialize the GLiNER model (downloads weights on first use).
        self.model = GLiNER.from_pretrained(model_name or self.DEFAULT_MODEL_NAME)

    def load(self):
        # The model is eagerly loaded in __init__; nothing further to load here.
        pass

    def analyze(self, text, entities, nlp_artifacts=None):
        """
        Detect entities in ``text`` with GLiNER.

        Args:
            text (str): Text to analyze.
            entities (list): Entity types requested by the analyzer engine
                (unused; GLiNER is queried with ``self.supported_entities``).
            nlp_artifacts: Presidio NLP artifacts (unused by this recognizer).

        Returns:
            list[RecognizerResult]: One result per GLiNER prediction, carrying
            the predicted label, character span, and confidence score.
        """
        results = []

        # Query GLiNER only for the labels this recognizer supports.
        predictions = self.model.predict_entities(text, self.supported_entities)
        for entity in predictions:
            results.append(
                RecognizerResult(
                    entity_type=entity["label"],  # label assigned by GLiNER
                    start=entity["start"],  # span start offset in text
                    end=entity["end"],  # span end offset in text
                    score=entity["score"],  # GLiNER confidence score
                )
            )

        return results

    def validate_result(self, pattern_name, text, start, end):
        # No validation beyond the model's own confidence score.
        return True
108 changes: 88 additions & 20 deletions pebblo/entity_classifier/entity_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from presidio_analyzer.context_aware_enhancers import LemmaContextAwareEnhancer
from presidio_anonymizer import AnonymizerEngine

from pebblo.entity_classifier.custom_analyzer.gliner_recognizer import GlinerRecognizer
from pebblo.entity_classifier.utils.config import (
ConfidenceScore,
Entities,
SecretEntities,
entity_conf_mapping,
)
from pebblo.entity_classifier.utils.utils import (
add_custom_regex_analyzer_registry,
Expand All @@ -19,6 +21,8 @@
class EntityClassifier:
def __init__(self):
self.analyzer = AnalyzerEngine()
# Create an instance of the GLiNER recognizer

self.anonymizer = AnonymizerEngine()
self.entities = list(Entities.__members__.keys())
self.entities.extend(list(SecretEntities.__members__.keys()))
Expand All @@ -39,22 +43,57 @@ def custom_analyze(self):
),
),
)
gliner_recognizer = GlinerRecognizer()
# Add the GLiNER recognizer to the Presidio Analyzer
self.analyzer.registry.add_recognizer(gliner_recognizer)

def analyze_response(self, input_text, anonymize_all_entities=True):
# Returns analyzed output
"""
Analyze the given input text to detect and classify entities based on predefined criteria.

Args:
input_text (str): The text to be analyzed for detecting entities.
anonymize_all_entities (bool): Flag to determine if all detected entities should be anonymized.
(Currently not used in the function logic.)

Returns:
list: A list of detected entities that meet the criteria for classification.
"""
# Analyze the text to detect entities using the Presidio analyzer
analyzer_results = self.analyzer.analyze(text=input_text, language="en")
analyzer_results = [
result
for result in analyzer_results
if result.score >= float(ConfidenceScore.Entity.value)
]
if not anonymize_all_entities: # Condition for anonymized document
analyzer_results = [
result
for result in analyzer_results
if result.entity_type in self.entities
]
return analyzer_results

# Initialize the list to hold the final classified entities
final_results = []

# Iterate through the detected entities
for entity in analyzer_results:
try:
mapped_entity = None

# Map entity type to predefined entities if it exists in the Entities enumeration
if entity.entity_type in Entities.__members__:
mapped_entity = Entities[entity.entity_type].value

# Check if the entity type exists in SecretEntities enumeration
elif entity.entity_type in SecretEntities.__members__:
mapped_entity = SecretEntities[entity.entity_type].value

# Append entity to final results if it meets the confidence threshold and is in the desired entities list
if (
mapped_entity
and entity.score >= float(entity_conf_mapping[mapped_entity])
and entity.entity_type in self.entities
):
final_results.append(entity)

# Handle any exceptions that occur during entity classification
except Exception as ex:
logger.warning(
f"Error in analyze_response in entity classification. {str(ex)}"
)

# Return the list of classified entities that met the criteria
return final_results

def anonymize_response(self, analyzer_results, input_text):
# Returns anonymized output
Expand All @@ -64,17 +103,37 @@ def anonymize_response(self, analyzer_results, input_text):

return anonymized_text.items, anonymized_text.text

@staticmethod
def get_analyzed_entities_response(data, anonymized_response=None):
# Returns entities with its location i.e. start to end and confidence score
response = []

for index, value in enumerate(data):
location = f"{value.start}_{value.end}"
if anonymized_response:
anonymized_data = anonymized_response[len(data) - index - 1]
location = f"{anonymized_data.start}_{anonymized_data.end}"
response.append(
{
"entity_type": value.entity_type,
"location": location,
"confidence_score": value.score,
}
)
return response

def presidio_entity_classifier_and_anonymizer(
self, input_text, anonymize_snippets=False
):
"""
Perform classification on the input data and return a dictionary with the count of each entity group.
And also returns plain input text as anonymized text output
:param anonymize_snippets: Flag whether to anonymize snippets in report.
:param input_text: Input string / document snippet
:param anonymize_snippets: Flag whether to anonymize snippets in report.
:return: entities: containing the entity group Name as key and its count as value.
total_count: Total count of entity groups.
anonymized_text: Input text in anonymized form.
entity_details: Entities with its details such as location and confidence score.
Example:

input_text = " My SSN is 222-85-4836.
Expand All @@ -89,21 +148,30 @@ def presidio_entity_classifier_and_anonymizer(
"""
entities = {}
total_count = 0
anonymized_text = ""
try:
logger.debug("Presidio Entity Classifier and Anonymizer Started.")

analyzer_results = self.analyze_response(input_text)
anonymized_response, anonymized_text = self.anonymize_response(
analyzer_results, input_text
)

if anonymize_snippets: # If Document snippet needs to be anonymized
anonymized_response, anonymized_text = self.anonymize_response(
analyzer_results, input_text
)
input_text = anonymized_text.replace("<", "&lt;").replace(">", "&gt;")
entities, total_count = get_entities(self.entities, anonymized_response)
entities_response = self.get_analyzed_entities_response(
analyzer_results, anonymized_response
)
else:
entities_response = self.get_analyzed_entities_response(
analyzer_results
)
entities, entity_details, total_count = get_entities(
self.entities, entities_response
)
logger.debug("Presidio Entity Classifier and Anonymizer Finished")
logger.debug(f"Entities: {entities}")
logger.debug(f"Entity Total count: {total_count}")
return entities, total_count, input_text
return entities, total_count, input_text, entity_details
except Exception as e:
logger.error(
f"Presidio Entity Classifier and Anonymizer Failed, Exception: {e}"
Expand Down
23 changes: 23 additions & 0 deletions pebblo/entity_classifier/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class Entities(Enum):
US_BANK_NUMBER = "us-bank-account-number"
IBAN_CODE = "iban-code"
US_ITIN = "us-itin"
SECRET_KEY_TOKEN = "gliner-entity"


class SecretEntities(Enum):
Expand All @@ -38,6 +39,28 @@ class SecretEntities(Enum):
GOOGLE_API_KEY = "google-api-key"


# Per-entity confidence thresholds (0.0-1.0). An analyzer hit for an entity is
# kept only when its score is >= the value mapped here; lower values make the
# corresponding detector more permissive.
entity_conf_mapping: dict = {
    # Identification
    Entities.US_SSN.value: 0.8,
    Entities.US_PASSPORT.value: 0.4,
    Entities.US_DRIVER_LICENSE.value: 0.4,
    # Financial
    Entities.US_ITIN.value: 0.8,
    Entities.CREDIT_CARD.value: 0.8,
    Entities.US_BANK_NUMBER.value: 0.4,
    Entities.IBAN_CODE.value: 0.8,
    Entities.SECRET_KEY_TOKEN.value: 0.5,
    # Secret
    SecretEntities.GITHUB_TOKEN.value: 0.8,
    SecretEntities.SLACK_TOKEN.value: 0.8,
    SecretEntities.AWS_ACCESS_KEY.value: 0.45,
    SecretEntities.AWS_SECRET_KEY.value: 0.8,
    SecretEntities.AZURE_KEY_ID.value: 0.8,
    SecretEntities.AZURE_CLIENT_SECRET.value: 0.8,
    SecretEntities.GOOGLE_API_KEY.value: 0.8,
}


class ConfidenceScore(Enum):
Entity = "0.8" # based on this score entity output is finalized
EntityMinScore = "0.45" # It denotes the pattern's strength
Expand Down
2 changes: 1 addition & 1 deletion pebblo/entity_classifier/utils/regex_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
"aws-access-key": r"""\b((?:AKIA|ABIA|ACCA|ASIA)[0-9A-Z]{16})\b""",
"aws-secret-key": r"""\b([A-Za-z0-9+/]{40})[ \r\n'"\x60]""",
"azure-key-id": r"""(?i)(%s).{0,20}([a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12})""",
"azure-client-secret": r"""(?i)(%s).{0,20}([a-z0-9_\.\-~]{34})""",
"azure-client-secret": r"""\b(?i)(%s).{0,20}([a-z0-9_\.\-~]{34})\b""",
"google-api-key": r"""(?i)(?:youtube)(?:.|[\n\r]){0,40}\bAIza[0-9A-Za-z\-_]{35}\b""",
}
28 changes: 21 additions & 7 deletions pebblo/entity_classifier/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,35 @@
secret_entities_context_mapping,
)
from pebblo.entity_classifier.utils.regex_pattern import regex_secrets_patterns
from pebblo.utils import get_confidence_score_label


def get_entities(entities_list, response):
    """
    Aggregate analyzed entities into per-group counts and details.

    Args:
        entities_list (list): Entity type names (enum member names) that
            should be counted.
        response (list): Analyzed entities, each a dict with
            "entity_type", "location" and "confidence_score" keys.

    Returns:
        tuple: (entity_groups, entity_details, total_count) where
            entity_groups maps each mapped entity name to its occurrence
            count, entity_details maps it to a list of
            {"location", "confidence_score"} dicts, and total_count is the
            total number of counted entities.
    """
    entity_groups: dict = dict()
    entity_details: dict = dict()
    total_count = 0

    for entity in response:
        entity_type = entity["entity_type"]
        if entity_type not in entities_list:
            continue
        # Map the raw analyzer entity type to its display name. The mapping is
        # resolved fresh on every iteration so an unmapped type can never be
        # attributed to the previous iteration's stale mapping.
        if entity_type in Entities.__members__:
            mapped_entity = Entities[entity_type].value
        elif entity_type in SecretEntities.__members__:
            mapped_entity = SecretEntities[entity_type].value
        else:
            # Requested but unknown type: skip rather than miscount it.
            continue
        entity_groups[mapped_entity] = entity_groups.get(mapped_entity, 0) + 1
        entity_data = {
            "location": entity["location"],
            "confidence_score": get_confidence_score_label(
                entity["confidence_score"]
            ),
        }
        entity_details.setdefault(mapped_entity, []).append(entity_data)
        total_count += 1

    return entity_groups, entity_details, total_count


def add_custom_regex_analyzer_registry():
Expand Down
2 changes: 1 addition & 1 deletion pebblo/topic_classifier/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ from pebblo.topic_classifier.topic_classifier import TopicClassifier

text = "Your sample text here."
topic_classifier_obj = TopicClassifier()
topics, total_topic_count = topic_classifier_obj.predict(text)
topics, total_topic_count, topic_details = topic_classifier_obj.predict(text)
print(f"Topic Response: {topics}")
print(f"Topic Count: {total_topic_count}")
```
Loading
Loading