Skip to content

Commit

Permalink
support fp16
Browse files Browse the repository at this point in the history
  • Loading branch information
masajiro committed Feb 9, 2022
1 parent eeeb3d6 commit 9d2e058
Show file tree
Hide file tree
Showing 26 changed files with 5,327 additions and 110 deletions.
1 change: 0 additions & 1 deletion README-jp.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ Copyright (C) 2015 Yahoo Japan Corporation

http://www.apache.org/licenses/LICENSE-2.0

ヤフー株式会社は本ソフトウェアが利用している技術の特許権を取得しています。ただし、本ソフトウェアを介して権利化された技術を利用する場合に限り、Apacheライセンスバージョン2.0の下で特許権が行使されることはありません。

貢献者ライセンス同意(CLA)
-------------------------
Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.13.8
1.14.0
1 change: 1 addition & 0 deletions bin/ngt/README-jp.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ ANNGやBKNNGを指定した場合には登録データ(ノード)からエ
データオブジェクトの型を指定します。
- __c__: 1バイト整数
- __f__: 4バイト浮動小数点(デフォルト)
- __h__: 2バイト浮動小数点

**-D** *distance\_function*
距離関数を指定します。
Expand Down
1 change: 1 addition & 0 deletions bin/ngt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ Specify the number of edges at search time accompanying or following index gener
Specify the data object type.
- __c__: 1 byte unsigned integer
- __f__: 4 byte floating point number (default)
- __h__: 2 byte floating point number

**-D** *distance\_function*
Specify the distance function as follows.
Expand Down
16 changes: 16 additions & 0 deletions lib/NGT/Capi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ bool ngt_is_property_object_type_float(int32_t object_type) {
return (object_type == NGT::ObjectSpace::ObjectType::Float);
}

// Reports whether the given object-type code equals
// NGT::ObjectSpace::ObjectType::Float16.
bool ngt_is_property_object_type_float16(int32_t object_type) {
  const bool isFloat16 = (object_type == NGT::ObjectSpace::ObjectType::Float16);
  return isFloat16;
}

// Reports whether the given object-type code equals
// NGT::ObjectSpace::ObjectType::Uint8.
bool ngt_is_property_object_type_integer(int32_t object_type) {
  const bool isUint8 = (object_type == NGT::ObjectSpace::ObjectType::Uint8);
  return isUint8;
}
Expand All @@ -205,6 +209,18 @@ bool ngt_set_property_object_type_float(NGTProperty prop, NGTError error) {
return true;
}

// Sets the objectType field of the NGT::Property behind `prop` to
// NGT::ObjectSpace::ObjectType::Float16 and returns true.
// When `prop` is NULL, a message naming this function is built and handed to
// operate_error_string_() together with `error`, and false is returned.
bool ngt_set_property_object_type_float16(NGTProperty prop, NGTError error) {
  if (prop != NULL) {
    NGT::Property *property = static_cast<NGT::Property*>(prop);
    property->objectType = NGT::ObjectSpace::ObjectType::Float16;
    return true;
  }

  std::stringstream ss;
  ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
  operate_error_string_(ss, error);
  return false;
}

bool ngt_set_property_object_type_integer(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
Expand Down
4 changes: 4 additions & 0 deletions lib/NGT/Capi.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,14 @@ int32_t ngt_get_property_object_type(NGTProperty, NGTError);

bool ngt_is_property_object_type_float(int32_t);

bool ngt_is_property_object_type_float16(int32_t);

bool ngt_is_property_object_type_integer(int32_t);

bool ngt_set_property_object_type_float(NGTProperty, NGTError);

bool ngt_set_property_object_type_float16(NGTProperty, NGTError);

bool ngt_set_property_object_type_integer(NGTProperty, NGTError);

bool ngt_set_property_distance_type_l1(NGTProperty, NGTError);
Expand Down
13 changes: 12 additions & 1 deletion lib/NGT/Command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ using namespace std;
case 'c':
property.objectType = NGT::Index::Property::ObjectType::Uint8;
break;
#ifdef NGT_HALF_FLOAT
case 'h':
property.objectType = NGT::Index::Property::ObjectType::Float16;
break;
#endif
default:
std::stringstream msg;
msg << "Command::CreateParameter: Error: Invalid object type. " << objectType;
Expand Down Expand Up @@ -175,7 +180,13 @@ using namespace std;
const string usage = "Usage: ngt create "
"-d dimension [-p #-of-thread] [-i index-type(t|g)] [-g graph-type(a|k|b|o|i)] "
"[-t truncation-edge-limit] [-E edge-size] [-S edge-size-for-search] [-L edge-size-limit] "
"[-e epsilon] [-o object-type(f|c)] [-D distance-function(1|2|a|A|h|j|c|C|E|p|l)] [-n #-of-inserted-objects] " // added by Nyapicom
"[-e epsilon] "
#ifdef NGT_HALF_FLOAT
"[-o object-type(f|h|c)] "
#else
"[-o object-type(f|c)] "
#endif
"[-D distance-function(1|2|a|A|h|j|c|C|E|p|l)] [-n #-of-inserted-objects] " // added by Nyapicom
"[-P path-adjustment-interval] [-B dynamic-edge-size-base] [-A object-alignment(t|f)] "
"[-T build-time-limit] [-O outgoing x incoming] "
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
Expand Down
11 changes: 11 additions & 0 deletions lib/NGT/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,19 @@
#include "NGT/defines.h"
#include "NGT/SharedMemoryAllocator.h"

#ifdef NGT_HALF_FLOAT
#include "NGT/half.hpp"
#endif

#define ADVANCED_USE_REMOVED_LIST
#define SHARED_REMOVED_LIST

namespace NGT {
typedef unsigned int ObjectID;
typedef float Distance;
#ifdef NGT_HALF_FLOAT
typedef half_float::half float16;
#endif

#define NGTThrowException(MESSAGE) throw NGT::Exception(__FILE__, (size_t)__LINE__, MESSAGE)
#define NGTThrowSpecificException(MESSAGE, TYPE) throw NGT::TYPE(__FILE__, (size_t)__LINE__, MESSAGE)
Expand Down Expand Up @@ -2100,6 +2107,10 @@ namespace NGT {
delete static_cast<std::vector<double>*>(query);
} else if (*queryType == typeid(uint8_t)) {
delete static_cast<std::vector<uint8_t>*>(query);
#ifdef NGT_HALF_FLOAT
} else if (*queryType == typeid(float16)) {
delete static_cast<std::vector<float16>*>(query);
#endif
}
query = 0;
queryType = 0;
Expand Down
164 changes: 142 additions & 22 deletions lib/NGT/Graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,70 @@ NeighborhoodGraph::Search::jaccardUint8(NeighborhoodGraph &graph, NGT::SearchCon
graph.searchReadOnlyGraph<PrimitiveComparator::JaccardUint8, DistanceCheckedSet>(sc, seeds);
}

#ifdef NGT_HALF_FLOAT
// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::NormalizedCosineSimilarityFloat16 and DistanceCheckedSet.
void NeighborhoodGraph::Search::normalizedCosineSimilarityFloat16(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::NormalizedCosineSimilarityFloat16;
  using CheckedSet = DistanceCheckedSet;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::CosineSimilarityFloat16 and DistanceCheckedSet.
void NeighborhoodGraph::Search::cosineSimilarityFloat16(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::CosineSimilarityFloat16;
  using CheckedSet = DistanceCheckedSet;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::NormalizedAngleFloat16 and DistanceCheckedSet.
void NeighborhoodGraph::Search::normalizedAngleFloat16(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::NormalizedAngleFloat16;
  using CheckedSet = DistanceCheckedSet;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::AngleFloat16 and DistanceCheckedSet.
void NeighborhoodGraph::Search::angleFloat16(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::AngleFloat16;
  using CheckedSet = DistanceCheckedSet;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::L1Float16 and DistanceCheckedSet.
void NeighborhoodGraph::Search::l1Float16(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::L1Float16;
  using CheckedSet = DistanceCheckedSet;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::L2Float16 and DistanceCheckedSet.
void NeighborhoodGraph::Search::l2Float16(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::L2Float16;
  using CheckedSet = DistanceCheckedSet;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::NormalizedL2Float16 and DistanceCheckedSet.
void NeighborhoodGraph::Search::normalizedL2Float16(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::NormalizedL2Float16;
  using CheckedSet = DistanceCheckedSet;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::SparseJaccardFloat16 and DistanceCheckedSet.
void NeighborhoodGraph::Search::sparseJaccardFloat16(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::SparseJaccardFloat16;
  using CheckedSet = DistanceCheckedSet;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// added by Nyapicom
// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::PoincareFloat16 and DistanceCheckedSet.
void NeighborhoodGraph::Search::poincareFloat16(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::PoincareFloat16;
  using CheckedSet = DistanceCheckedSet;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// added by Nyapicom
// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::LorentzFloat16 and DistanceCheckedSet.
void NeighborhoodGraph::Search::lorentzFloat16(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::LorentzFloat16;
  using CheckedSet = DistanceCheckedSet;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}
#endif

////

void
Expand Down Expand Up @@ -236,7 +300,67 @@ NeighborhoodGraph::Search::jaccardUint8ForLargeDataset(NeighborhoodGraph &graph,
graph.searchReadOnlyGraph<PrimitiveComparator::JaccardUint8, DistanceCheckedSetForLargeDataset>(sc, seeds);
}

#ifdef NGT_HALF_FLOAT
// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::NormalizedCosineSimilarityFloat16 and
// DistanceCheckedSetForLargeDataset.
void NeighborhoodGraph::Search::normalizedCosineSimilarityFloat16ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::NormalizedCosineSimilarityFloat16;
  using CheckedSet = DistanceCheckedSetForLargeDataset;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::CosineSimilarityFloat16 and
// DistanceCheckedSetForLargeDataset.
void NeighborhoodGraph::Search::cosineSimilarityFloat16ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::CosineSimilarityFloat16;
  using CheckedSet = DistanceCheckedSetForLargeDataset;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::NormalizedAngleFloat16 and
// DistanceCheckedSetForLargeDataset.
void NeighborhoodGraph::Search::normalizedAngleFloat16ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::NormalizedAngleFloat16;
  using CheckedSet = DistanceCheckedSetForLargeDataset;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::AngleFloat16 and DistanceCheckedSetForLargeDataset.
void NeighborhoodGraph::Search::angleFloat16ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::AngleFloat16;
  using CheckedSet = DistanceCheckedSetForLargeDataset;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::L1Float16 and DistanceCheckedSetForLargeDataset.
void NeighborhoodGraph::Search::l1Float16ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::L1Float16;
  using CheckedSet = DistanceCheckedSetForLargeDataset;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::L2Float16 and DistanceCheckedSetForLargeDataset.
void NeighborhoodGraph::Search::l2Float16ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::L2Float16;
  using CheckedSet = DistanceCheckedSetForLargeDataset;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::NormalizedL2Float16 and
// DistanceCheckedSetForLargeDataset.
void NeighborhoodGraph::Search::normalizedL2Float16ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::NormalizedL2Float16;
  using CheckedSet = DistanceCheckedSetForLargeDataset;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::SparseJaccardFloat16 and
// DistanceCheckedSetForLargeDataset.
void NeighborhoodGraph::Search::sparseJaccardFloat16ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::SparseJaccardFloat16;
  using CheckedSet = DistanceCheckedSetForLargeDataset;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::PoincareFloat16 and DistanceCheckedSetForLargeDataset.
void NeighborhoodGraph::Search::poincareFloat16ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::PoincareFloat16;
  using CheckedSet = DistanceCheckedSetForLargeDataset;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}

// Forwards to graph.searchReadOnlyGraph instantiated with
// PrimitiveComparator::LorentzFloat16 and DistanceCheckedSetForLargeDataset.
void NeighborhoodGraph::Search::lorentzFloat16ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) {
  using Comparator = PrimitiveComparator::LorentzFloat16;
  using CheckedSet = DistanceCheckedSetForLargeDataset;
  graph.searchReadOnlyGraph<Comparator, CheckedSet>(sc, seeds);
}
#endif

#endif

Expand Down Expand Up @@ -333,6 +457,7 @@ NeighborhoodGraph::setupDistances(NGT::SearchContainer &sc, ObjectDistances &see
}

#ifdef NGT_DISTANCE_COMPUTATION_COUNT
sc.visitCount += seeds.size();
sc.distanceComputationCount += seeds.size();
#endif
}
Expand Down Expand Up @@ -416,7 +541,6 @@ NeighborhoodGraph::setupSeeds(NGT::SearchContainer &sc, ObjectDistances &seeds,
setupDistances(sc, seeds, COMPARATOR::compare);
setupSeeds(sc, seeds, results, unchecked, distanceChecked);


Distance explorationRadius = sc.explorationCoefficient * sc.radius;
const size_t dimension = objectSpace->getPaddedDimension();
ReadOnlyGraphNode *nodes = &searchRepository.front();
Expand All @@ -440,43 +564,39 @@ NeighborhoodGraph::setupSeeds(NGT::SearchContainer &sc, ObjectDistances &seeds,

pair<uint64_t, PersistentObject*>* nsPtrs[neighborSize];
size_t nsPtrsSize = 0;

for (; neighborptr < neighborendptr; ++neighborptr) {
if (!distanceChecked[(*(neighborptr)).first]) {
nsPtrs[nsPtrsSize] = neighborptr;
if (nsPtrsSize < prefetchOffset) {
unsigned char *ptr = reinterpret_cast<unsigned char*>((*(neighborptr)).second);
MemoryCache::prefetch(ptr, prefetchSize);
}
nsPtrsSize++;
}
#ifdef NGT_VISIT_COUNT
sc.visitCount++;
#endif
if (!distanceChecked[(*(neighborptr)).first]) {
distanceChecked.insert((*(neighborptr)).first);
nsPtrs[nsPtrsSize] = neighborptr;
if (nsPtrsSize < prefetchOffset) {
unsigned char *ptr = reinterpret_cast<unsigned char*>((*(neighborptr)).second);
MemoryCache::prefetch(ptr, prefetchSize);
}
nsPtrsSize++;
}
}
for (size_t idx = 0; idx < nsPtrsSize; idx++) {
neighborptr = nsPtrs[idx];
if (idx + prefetchOffset < nsPtrsSize) {
unsigned char *ptr = reinterpret_cast<unsigned char*>((*(nsPtrs[idx + prefetchOffset])).second);
MemoryCache::prefetch(ptr, prefetchSize);
}
#ifdef NGT_VISIT_COUNT
sc.visitCount++;
#endif
auto &neighbor = *neighborptr;
distanceChecked.insert(neighbor.first);

#ifdef NGT_DISTANCE_COMPUTATION_COUNT
sc.distanceComputationCount++;
#endif
Distance distance = COMPARATOR::compare((void*)&sc.object[0],
(void*)&(*static_cast<PersistentObject*>(neighbor.second))[0], dimension);
(void*)&(*static_cast<PersistentObject*>(neighborptr->second))[0], dimension);
if (distance <= explorationRadius) {
result.set(neighbor.first, distance);
result.set(neighborptr->first, distance);
unchecked.push(result);
if (distance <= sc.radius) {
results.push(result);
if (results.size() >= sc.size) {
if (results.size() > sc.size) {
results.pop();
}
if (results.size() > sc.size) {
results.pop();
sc.radius = results.top().distance;
explorationRadius = sc.explorationCoefficient * sc.radius;
}
Expand Down Expand Up @@ -805,7 +925,7 @@ NeighborhoodGraph::setupSeeds(NGT::SearchContainer &sc, ObjectDistances &seeds,
}
if (insertionA != insertionB) {
stringstream msg;
msg << "Graph::removeEdgeReliably:Warning. Lost conectivity! Isn't this ANNG? ID=" << id << ".";
msg << "Graph::removeEdgeReliably:Warning. Lost connectivity! Isn't this ANNG? ID=" << id << ".";
#ifdef NGT_FORCED_REMOVE
msg << " Anyway continue...";
cerr << msg.str() << endl;
Expand Down
Loading

0 comments on commit 9d2e058

Please sign in to comment.