Skip to content

Commit

Permalink
move field factory class and minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerhutcherson committed Dec 13, 2023
1 parent 66f20ed commit e5f67d1
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 48 deletions.
42 changes: 41 additions & 1 deletion redisvl/schema/fields.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union
from typing_extensions import Literal

from pydantic import BaseModel, Field, validator
Expand Down Expand Up @@ -113,3 +113,43 @@ def as_field(self):
}
)
return RedisVectorField(self.name, self.algorithm, field_data, as_name=self.as_name)


class FieldFactory:
"""
Factory class to create fields from client data and kwargs.
"""
FIELD_TYPE_MAP = {
"tag": TagField,
"text": TextField,
"numeric": NumericField,
"geo": GeoField,
}

VECTOR_FIELD_TYPE_MAP = {
'flat': FlatVectorField,
'hnsw': HNSWVectorField,
}

@classmethod
def _get_vector_type(cls, **field_data: Dict[str, Any]) -> Union[FlatVectorField, HNSWVectorField]:
"""Get the vector field type from the field data."""
algorithm = field_data.get('algorithm', '').lower()
if algorithm not in cls.VECTOR_FIELD_TYPE_MAP:
raise ValueError(f"Unknown vector field algorithm: {algorithm}")

# default to FLAT
return cls.VECTOR_FIELD_TYPE_MAP.get(algorithm, FlatVectorField)(**field_data)

@classmethod
def create_field(cls, field_type: str, name: str, **kwargs) -> BaseField:
"""Create a field of a given type with provided attributes."""

if field_type == 'vector':
return cls._get_vector_type(name=name, **kwargs)

if field_type not in cls.FIELD_TYPE_MAP:
raise ValueError(f"Unknown field type: {field_type}")

field_class = cls.FIELD_TYPE_MAP[field_type]
return field_class(name=name, **kwargs)
56 changes: 9 additions & 47 deletions redisvl/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,59 +2,18 @@
import yaml
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Union, Tuple, Optional, Type
from typing import Any, Dict, List

from pydantic import BaseModel, ValidationError

from redisvl.schema.fields import (
BaseField,
TagField,
TextField,
NumericField,
FlatVectorField,
HNSWVectorField,
GeoField
)
from redisvl.schema.fields import BaseField, FieldFactory


class StorageType(Enum):
HASH = "hash"
JSON = "json"


def get_vector_type(**field_data: Dict[str, Any]) -> Union[FlatVectorField, HNSWVectorField]:
"""Get the vector field type from the field data."""

vector_field_classes = {
'flat': FlatVectorField,
'hnsw': HNSWVectorField
}
algorithm = field_data.get('algorithm', '').lower()
if algorithm not in vector_field_classes.keys():
raise ValueError(f"Unknown vector field algorithm: {algorithm}")

# default to FLAT
return vector_field_classes.get(algorithm, FlatVectorField)(**field_data)

class FieldFactory:
FIELD_TYPE_MAP = {
"tag": TagField,
"text": TextField,
"numeric": NumericField,
"geo": GeoField,
"vector": get_vector_type
}

@staticmethod
def create_field(field_type: str, name: str, **kwargs) -> BaseField:
field_class = FieldFactory.FIELD_TYPE_MAP.get(field_type)
if not field_class:
raise ValueError(f"Unknown field type: {field_type}")
return field_class(name=name, **kwargs)




class IndexSchema(BaseModel):
"""
RedisVL index schema for storing and indexing vectors and metadata
Expand Down Expand Up @@ -109,6 +68,7 @@ def add_field(self, field_type: str, **kwargs):
name = kwargs.get('name', None)
if name is None:
raise ValueError("Field name is required.")

new_field = FieldFactory.create_field(field_type, **kwargs)
if any(field.name == name for field in self.fields.get(field_type, [])):
raise ValueError(
Expand Down Expand Up @@ -154,14 +114,16 @@ def generate_fields(
which makes it tedious to manually define each field. This method
can be used to automatically generate fields from a sample of data.
Note: Vector fields are not generated by this method
Note: This method is a hueristic and may not always generate the
Note: Vector fields are not generated by this method.
Note: This method is a heuristic and may not always generate the
correct field type.
Args:
data (Dict[str, Any]): The sample data to generate fields from.
strict (bool): Whether to raise an error if a field type cannot be inferred.
ignore_fields (List[str]): A list of field names to ignore.
strict (bool, optional): Whether to raise an error if a field type
cannot be inferred. Defaults to False.
ignore_fields (List[str], optional): A list of field names to
ignore. Defaults to [].
Returns:
Dict[str, List[Dict[str, Any]]]: A dictionary of fields.
Expand Down

0 comments on commit e5f67d1

Please sign in to comment.