-
Notifications
You must be signed in to change notification settings - Fork 122
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Milvus-doc-bot
authored and
Milvus-doc-bot
committed
Jan 13, 2025
1 parent
0cd9ec7
commit ac8813c
Showing
2 changed files
with
6 additions
and
6 deletions.
There are no files selected for viewing
2 changes: 1 addition & 1 deletion
2
localization/v2.5.x/site/en/integrations/integrate_with_pytorch.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
{"codeList":["pip install pymilvus torch gdown torchvision tqdm\n","import gdown\nimport zipfile\n\nurl = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'\noutput = './paintings.zip'\ngdown.download(url, output)\n\nwith zipfile.ZipFile(\"./paintings.zip\",\"r\") as zip_ref:\n zip_ref.extractall(\"./paintings\")\n","# Milvus Setup Arguments\nCOLLECTION_NAME = 'image_search' # Collection name\nDIMENSION = 2048 # Embedding vector size in this example\nMILVUS_HOST = \"localhost\"\nMILVUS_PORT = \"19530\"\n\n# Inference Arguments\nBATCH_SIZE = 128\nTOP_K = 3\n","from pymilvus import connections\n\n# Connect to the instance\nconnections.connect(host=MILVUS_HOST, port=MILVUS_PORT)\n","from pymilvus import utility\n\n# Remove any previous collections with the same name\nif utility.has_collection(COLLECTION_NAME):\n utility.drop_collection(COLLECTION_NAME)\n","from pymilvus import FieldSchema, CollectionSchema, DataType, Collection\n\n# Create collection which includes the id, filepath of the image, and image embedding\nfields = [\n FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),\n FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200), # VARCHARS need a maximum length, so for this example they are set to 200 characters\n FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)\n]\nschema = CollectionSchema(fields=fields)\ncollection = Collection(name=COLLECTION_NAME, schema=schema)\n","# Create an AutoIndex index for collection\nindex_params = {\n'metric_type':'L2',\n'index_type':\"IVF_FLAT\",\n'params':{'nlist': 16384}\n}\ncollection.create_index(field_name=\"image_embedding\", index_params=index_params)\ncollection.load()\n","import glob\n\n# Get the filepaths of the images\npaths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True)\nlen(paths)\n","import torch\n\n# Load the embedding model with the last layer removed\nmodel = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)\nmodel = torch.nn.Sequential(*(list(model.children())[:-1]))\nmodel.eval()\n","from torchvision import transforms\n\n# Preprocessing for images\npreprocess = transforms.Compose([\n transforms.Resize(256),\n transforms.CenterCrop(224),\n transforms.ToTensor(),\n transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n])\n","from PIL import Image\nfrom tqdm import tqdm\n\n# Embed function that embeds the batch and inserts it\ndef embed(data):\n with torch.no_grad():\n output = model(torch.stack(data[0])).squeeze()\n collection.insert([data[1], output.tolist()])\n\ndata_batch = [[],[]]\n\n# Read the images into batches for embedding and insertion\nfor path in tqdm(paths):\n im = Image.open(path).convert('RGB')\n data_batch[0].append(preprocess(im))\n data_batch[1].append(path)\n if len(data_batch[0]) % BATCH_SIZE == 0:\n embed(data_batch)\n data_batch = [[],[]]\n\n# Embed and insert the remainder\nif len(data_batch[0]) != 0:\n embed(data_batch)\n\n# Call a flush to index any unsealed segments.\ncollection.flush()\n","import glob\n\n# Get the filepaths of the search images\nsearch_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)\nlen(search_paths)\n","import time\nfrom matplotlib import pyplot as plt\n\n# Embed the search images\ndef embed(data):\n with torch.no_grad():\n ret = model(torch.stack(data))\n # If more than one image, use squeeze\n if len(ret) > 1:\n return ret.squeeze().tolist()\n # Squeeze would remove batch for single image, so using flatten\n else:\n return torch.flatten(ret, start_dim=1).tolist()\n\ndata_batch = [[],[]]\n\nfor path in search_paths:\n im = Image.open(path).convert('RGB')\n data_batch[0].append(preprocess(im))\n data_batch[1].append(path)\n\nembeds = embed(data_batch[0])\nstart = time.time()\nres = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])\nfinish = time.time()\n","# Show the image results\nf, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)\n\nfor hits_i, hits in enumerate(res):\n axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))\n axarr[hits_i][0].set_axis_off()\n axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))\n for hit_i, hit in enumerate(hits):\n axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))\n axarr[hits_i][hit_i + 1].set_axis_off()\n axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))\n\n# Save the search result in a separate image file alongside your script.\nplt.savefig('search_result.png')\n"],"headingContent":"Image Search with Milvus","anchorList":[{"label":"Image Search with Milvus","href":"Image-Search-with-Milvus","type":1,"isActive":false},{"label":"Installing the requirements","href":"Installing-the-requirements","type":2,"isActive":false},{"label":"Grabbing the data","href":"Grabbing-the-data","type":2,"isActive":false},{"label":"Global Arguments","href":"Global-Arguments","type":2,"isActive":false},{"label":"Setting up Milvus","href":"Setting-up-Milvus","type":2,"isActive":false},{"label":"Inserting the data","href":"Inserting-the-data","type":2,"isActive":false},{"label":"Performing the search","href":"Performing-the-search","type":2,"isActive":false}]} | ||
{"codeList":["pip install pymilvus torch gdown torchvision tqdm\n","import gdown\nimport zipfile\n\nurl = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'\noutput = './paintings.zip'\ngdown.download(url, output)\n\nwith zipfile.ZipFile(\"./paintings.zip\",\"r\") as zip_ref:\n zip_ref.extractall(\"./paintings\")\n","# Milvus Setup Arguments\nCOLLECTION_NAME = 'image_search' # Collection name\nDIMENSION = 2048 # Embedding vector size in this example\nMILVUS_HOST = \"localhost\"\nMILVUS_PORT = \"19530\"\n\n# Inference Arguments\nBATCH_SIZE = 128\nTOP_K = 3\n","from pymilvus import connections\n\n# Connect to the instance\nconnections.connect(host=MILVUS_HOST, port=MILVUS_PORT)\n","from pymilvus import utility\n\n# Remove any previous collections with the same name\nif utility.has_collection(COLLECTION_NAME):\n utility.drop_collection(COLLECTION_NAME)\n","from pymilvus import FieldSchema, CollectionSchema, DataType, Collection\n\n# Create collection which includes the id, filepath of the image, and image embedding\nfields = [\n FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),\n FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200), # VARCHARS need a maximum length, so for this example they are set to 200 characters\n FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)\n]\nschema = CollectionSchema(fields=fields)\ncollection = Collection(name=COLLECTION_NAME, schema=schema)\n","# Create an AutoIndex index for collection\nindex_params = {\n'metric_type':'L2',\n'index_type':\"IVF_FLAT\",\n'params':{'nlist': 16384}\n}\ncollection.create_index(field_name=\"image_embedding\", index_params=index_params)\ncollection.load()\n","import glob\n\n# Get the filepaths of the images\npaths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True)\nlen(paths)\n","import torch\n\n# Load the embedding model with the last layer removed\nmodel = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)\nmodel = torch.nn.Sequential(*(list(model.children())[:-1]))\nmodel.eval()\n","from torchvision import transforms\n\n# Preprocessing for images\npreprocess = transforms.Compose([\n transforms.Resize(256),\n transforms.CenterCrop(224),\n transforms.ToTensor(),\n transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n])\n","from PIL import Image\nfrom tqdm import tqdm\n\n# Embed function that embeds the batch and inserts it\ndef embed(data):\n with torch.no_grad():\n output = model(torch.stack(data[0])).squeeze()\n collection.insert([data[1], output.tolist()])\n\ndata_batch = [[],[]]\n\n# Read the images into batches for embedding and insertion\nfor path in tqdm(paths):\n im = Image.open(path).convert('RGB')\n data_batch[0].append(preprocess(im))\n data_batch[1].append(path)\n if len(data_batch[0]) % BATCH_SIZE == 0:\n embed(data_batch)\n data_batch = [[],[]]\n\n# Embed and insert the remainder\nif len(data_batch[0]) != 0:\n embed(data_batch)\n\n# Call a flush to index any unsealed segments.\ncollection.flush()\n","import glob\n\n# Get the filepaths of the search images\nsearch_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)\nlen(search_paths)\n","import time\nfrom matplotlib import pyplot as plt\n\n# Embed the search images\ndef embed(data):\n with torch.no_grad():\n ret = model(torch.stack(data))\n # If more than one image, use squeeze\n if len(ret) > 1:\n return ret.squeeze().tolist()\n # Squeeze would remove batch for single image, so using flatten\n else:\n return torch.flatten(ret, start_dim=1).tolist()\n\ndata_batch = [[],[]]\n\nfor path in search_paths:\n im = Image.open(path).convert('RGB')\n data_batch[0].append(preprocess(im))\n data_batch[1].append(path)\n\nembeds = embed(data_batch[0])\nstart = time.time()\nres = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])\nfinish = time.time()\n","# Show the image results\nf, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)\n\nfor hits_i, hits in enumerate(res):\n axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))\n axarr[hits_i][0].set_axis_off()\n axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))\n for hit_i, hit in enumerate(hits):\n axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))\n axarr[hits_i][hit_i + 1].set_axis_off()\n axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))\n\n# Save the search result in a separate image file alongside your script.\nplt.savefig('search_result.png')\n"],"headingContent":"Image Search with PyTorch and Milvus","anchorList":[{"label":"Image Search with PyTorch and Milvus","href":"Image-Search-with-PyTorch-and-Milvus","type":1,"isActive":false},{"label":"Installing the requirements","href":"Installing-the-requirements","type":2,"isActive":false},{"label":"Grabbing the data","href":"Grabbing-the-data","type":2,"isActive":false},{"label":"Global Arguments","href":"Global-Arguments","type":2,"isActive":false},{"label":"Setting up Milvus","href":"Setting-up-Milvus","type":2,"isActive":false},{"label":"Inserting the data","href":"Inserting-the-data","type":2,"isActive":false},{"label":"Performing the search","href":"Performing-the-search","type":2,"isActive":false}]} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters