From e5f67d1c9f058f921d762c91d698c7f21ee3d6d4 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 13 Dec 2023 10:28:31 -0500 Subject: [PATCH] move field factory class and minor updates --- redisvl/schema/fields.py | 42 +++++++++++++++++++++++++++++- redisvl/schema/schema.py | 56 +++++++--------------------------------- 2 files changed, 50 insertions(+), 48 deletions(-) diff --git a/redisvl/schema/fields.py b/redisvl/schema/fields.py index 227f13f8..568b20be 100644 --- a/redisvl/schema/fields.py +++ b/redisvl/schema/fields.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from typing_extensions import Literal from pydantic import BaseModel, Field, validator @@ -113,3 +113,43 @@ def as_field(self): } ) return RedisVectorField(self.name, self.algorithm, field_data, as_name=self.as_name) + + +class FieldFactory: + """ + Factory class to create fields from client data and kwargs. + """ + FIELD_TYPE_MAP = { + "tag": TagField, + "text": TextField, + "numeric": NumericField, + "geo": GeoField, + } + + VECTOR_FIELD_TYPE_MAP = { + 'flat': FlatVectorField, + 'hnsw': HNSWVectorField, + } + + @classmethod + def _get_vector_type(cls, **field_data: Dict[str, Any]) -> Union[FlatVectorField, HNSWVectorField]: + """Get the vector field type from the field data.""" + algorithm = field_data.get('algorithm', '').lower() + if algorithm not in cls.VECTOR_FIELD_TYPE_MAP: + raise ValueError(f"Unknown vector field algorithm: {algorithm}") + + # default to FLAT + return cls.VECTOR_FIELD_TYPE_MAP.get(algorithm, FlatVectorField)(**field_data) + + @classmethod + def create_field(cls, field_type: str, name: str, **kwargs) -> BaseField: + """Create a field of a given type with provided attributes.""" + + if field_type == 'vector': + return cls._get_vector_type(name=name, **kwargs) + + if field_type not in cls.FIELD_TYPE_MAP: + raise ValueError(f"Unknown field type: {field_type}") + + field_class = cls.FIELD_TYPE_MAP[field_type] + return field_class(name=name, **kwargs) \ No newline at end of file diff --git a/redisvl/schema/schema.py b/redisvl/schema/schema.py index 78b99fed..e7805879 100644 --- a/redisvl/schema/schema.py +++ b/redisvl/schema/schema.py @@ -2,19 +2,11 @@ import yaml from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Union, Tuple, Optional, Type +from typing import Any, Dict, List from pydantic import BaseModel, ValidationError -from redisvl.schema.fields import ( - BaseField, - TagField, - TextField, - NumericField, - FlatVectorField, - HNSWVectorField, - GeoField -) +from redisvl.schema.fields import BaseField, FieldFactory class StorageType(Enum): @@ -22,39 +14,6 @@ class StorageType(Enum): JSON = "json" -def get_vector_type(**field_data: Dict[str, Any]) -> Union[FlatVectorField, HNSWVectorField]: - """Get the vector field type from the field data.""" - - vector_field_classes = { - 'flat': FlatVectorField, - 'hnsw': HNSWVectorField - } - algorithm = field_data.get('algorithm', '').lower() - if algorithm not in vector_field_classes.keys(): - raise ValueError(f"Unknown vector field algorithm: {algorithm}") - - # default to FLAT - return vector_field_classes.get(algorithm, FlatVectorField)(**field_data) - -class FieldFactory: - FIELD_TYPE_MAP = { - "tag": TagField, - "text": TextField, - "numeric": NumericField, - "geo": GeoField, - "vector": get_vector_type - } - - @staticmethod - def create_field(field_type: str, name: str, **kwargs) -> BaseField: - field_class = FieldFactory.FIELD_TYPE_MAP.get(field_type) - if not field_class: - raise ValueError(f"Unknown field type: {field_type}") - return field_class(name=name, **kwargs) - - - - class IndexSchema(BaseModel): """ RedisVL index schema for storing and indexing vectors and metadata @@ -109,6 +68,7 @@ def add_field(self, field_type: str, **kwargs): name = kwargs.get('name', None) if name is None: raise ValueError("Field name is required.") + new_field = FieldFactory.create_field(field_type, **kwargs) if any(field.name == name for field in self.fields.get(field_type, [])): raise ValueError( @@ -154,14 +114,16 @@ def generate_fields( which makes it tedious to manually define each field. This method can be used to automatically generate fields from a sample of data. - Note: Vector fields are not generated by this method - Note: This method is a hueristic and may not always generate the + Note: Vector fields are not generated by this method. + Note: This method is a heuristic and may not always generate the correct field type. Args: data (Dict[str, Any]): The sample data to generate fields from. - strict (bool): Whether to raise an error if a field type cannot be inferred. - ignore_fields (List[str]): A list of field names to ignore. + strict (bool, optional): Whether to raise an error if a field type + cannot be inferred. Defaults to False. + ignore_fields (List[str], optional): A list of field names to + ignore. Defaults to []. Returns: Dict[str, List[Dict[str, Any]]]: A dictionary of fields.