Skip to content

Commit

Permalink
removing duplicate methods
Browse files Browse the repository at this point in the history
  • Loading branch information
nikit-srivastava committed Oct 18, 2018
1 parent 1cd238d commit 4e2b02e
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 200 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.aksw.word2vecrestful.word2vec;

import java.io.IOException;
import java.util.BitSet;
import java.util.Map;

Expand All @@ -10,7 +11,7 @@

/**
* Class to encapsulate word2vec in-memory model and expose methods to perform
* search on the model.
* search on the model. (Only works with Normalized Model)
*
* This class selects {@link W2VNrmlMemModelBinSrch#compareVecCount} vectors (1
* mean vector and others on basis Map iterator) and then calculates the cosine
Expand All @@ -26,29 +27,33 @@
public class W2VNrmlMemModelBinSrch implements GenWord2VecModel {
public static Logger LOG = LogManager.getLogger(GenWord2VecModel.class);

private Map<String, float[]> word2vec;
private int vectorSize;
private float[][] comparisonVecs = null;
private String[] wordArr;
private float[][] vecArr;
private int[] indxArr;
private double[] simValArr;
private int compareVecCount = 150;
private int bucketCount = 10;
private BitSet[][] csBucketContainer;
protected Map<String, float[]> word2vec;
protected int vectorSize;
protected float[][] comparisonVecs = null;
protected String[] wordArr;
protected float[][] vecArr;
protected int[] indxArr;
protected double[] simValArr;
protected int compareVecCount = 150;
protected int bucketCount = 10;
protected BitSet[][] csBucketContainer;

public W2VNrmlMemModelBinSrch(final Map<String, float[]> word2vec, final int vectorSize) {
/**
* Creates the model from the given word-to-vector map, allocates the
* comparison-vector and bucket containers, and then runs {@link #process()}.
*
* NOTE(review): {@link #process()} is protected and overridable, and it is
* invoked from this constructor. In Java, a subclass's field initializers run
* only after the superclass constructor returns, so an overriding
* implementation will observe default values (null / 0) for fields declared in
* the subclass. Verify that overrides do not depend on such fields.
*
* @param word2vec
* - map of word to its vector
* @param vectorSize
* - length of each vector
* @throws IOException
* - propagated from {@link #process()}
*/
public W2VNrmlMemModelBinSrch(final Map<String, float[]> word2vec, final int vectorSize) throws IOException {
this.word2vec = word2vec;
this.vectorSize = vectorSize;
// One comparison vector (and one row of buckets) per compareVecCount entry
comparisonVecs = new float[compareVecCount][vectorSize];
csBucketContainer = new BitSet[compareVecCount][bucketCount];
// Overridable hook; see the note in the constructor documentation
process();
}

/**
* Prepares the model by calling, in this order, setMeanComparisonVec,
* processCosineSim and setAllComparisonVecs. Called from the constructor.
*
* The order of the calls is kept exactly as written; the helpers operate on
* shared instance state.
*
* @throws IOException
* - declared so that overriding implementations may perform I/O
*/
protected void process() throws IOException {
LOG.info("Process from BinSrch called");
// Setting mean as comparison vec
setMeanComparisonVec(word2vec, vectorSize);
// Initialize Arrays
processCosineSim();
// Set other comparison vecs
setAllComparisonVecs();

}

private void setBucketVals(int compVecIndex, float[] comparisonVec) {
Expand All @@ -73,7 +78,7 @@ private void setAllComparisonVecs() {
}
}

private int getBucketIndex(double cosineSimVal) {
/**
 * Maps a cosine similarity value to a bucket index using the linear formula
 * {@code ((bucketCount - 1) / 2) * (cosineSimVal + 1)}, rounded to the nearest
 * integer. For an input of -1 this yields 0 and for an input of 1 it yields
 * {@code bucketCount - 1}.
 *
 * @param cosineSimVal
 *            - cosine similarity value to convert
 * @return index of the bucket the value falls into
 */
protected int getBucketIndex(double cosineSimVal) {
    final double scaled = ((bucketCount - 1d) / 2d) * (cosineSimVal + 1d);
    // Narrow to float before rounding so that Math.round(float) yields an int
    return Math.round((float) scaled);
}
Expand All @@ -98,7 +103,7 @@ private void processCosineSim() {
AssociativeSort.quickSort(simValArr, indxArr);
}

private void setValToBucket(int wordIndex, double cosSimVal, BitSet[] meanComparisonVecBuckets) {
protected void setValToBucket(int wordIndex, double cosSimVal, BitSet[] meanComparisonVecBuckets) {
int bucketIndex = getBucketIndex(cosSimVal);
BitSet bitset = meanComparisonVecBuckets[bucketIndex];
if (bitset == null) {
Expand Down Expand Up @@ -166,7 +171,7 @@ public String getClosestSubEntry(float[] vector, String subKey) {
* - key to subset if any
* @return closest word to the given vector (only the word is returned, not its vector)
*/
private String getClosestEntry(float[] vector, String subKey) {
protected String getClosestEntry(float[] vector, String subKey) {
String closestWord = null;
try {
// Normalize incoming vector
Expand Down Expand Up @@ -197,7 +202,7 @@ private String getClosestEntry(float[] vector, String subKey) {
}
}
int nearbyWordsCount = finBitSet.cardinality();
//LOG.info("Number of nearby words: " + nearbyWordsCount);
// LOG.info("Number of nearby words: " + nearbyWordsCount);
int[] nearbyIndexes = new int[nearbyWordsCount];
int j = 0;
for (int i = finBitSet.nextSetBit(0); i >= 0; i = finBitSet.nextSetBit(i + 1), j++) {
Expand All @@ -212,11 +217,11 @@ private String getClosestEntry(float[] vector, String subKey) {
LOG.error("Exception has occured while finding closest word.");
e.printStackTrace();
}
//LOG.info("Closest word found is: " + closestWord);
// LOG.info("Closest word found is: " + closestWord);
return closestWord;
}

private String findClosestWord(int[] nearbyIndexes, float[] vector) {
protected String findClosestWord(int[] nearbyIndexes, float[] vector) {
double minDist = -2;
String minWord = null;
double tempDist;
Expand All @@ -243,7 +248,7 @@ private String findClosestWord(int[] nearbyIndexes, float[] vector) {
* - minimum distance constraint
* @return squared euclidean distance between the two vectors, or -1
*/
private double getSqEucDist(float[] arr1, float[] arr2, double minDist) {
protected double getSqEucDist(float[] arr1, float[] arr2, double minDist) {
double dist = 0;
for (int i = 0; i < vectorSize; i++) {
dist += Math.pow(arr1[i] - arr2[i], 2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

/**
* Class to encapsulate word2vec in-memory model and expose methods to perform
* search on the model.
* search on the model. (Only works with Normalized Model)
*
* This class selects {@link W2VNrmlMemModelKMeans#compareVecCount} vectors
* (centroids of the KMeans result on the model vectors) and then calculates the
Expand All @@ -35,25 +35,19 @@
* @author Nikit
*
*/
public class W2VNrmlMemModelKMeans implements GenWord2VecModel {
public class W2VNrmlMemModelKMeans extends W2VNrmlMemModelBinSrch {
public static Logger LOG = LogManager.getLogger(GenWord2VecModel.class);

private Map<String, float[]> word2vec;
private int vectorSize;
private float[][] comparisonVecs = null;
private String[] wordArr;
private float[][] vecArr;
private int compareVecCount = 100;
private int bucketCount = 10;
private int kMeansMaxItr = 5;
private BitSet[][] csBucketContainer;
private String vecFilePath = "data/kmeans/comparison-vecs.csv";

public W2VNrmlMemModelKMeans(final Map<String, float[]> word2vec, final int vectorSize) throws IOException {
this.word2vec = word2vec;
this.vectorSize = vectorSize;
comparisonVecs = new float[compareVecCount][vectorSize];
csBucketContainer = new BitSet[compareVecCount][bucketCount];
super(word2vec, vectorSize);
}

@Override
protected void process() throws IOException {
LOG.info("Process from KMeans called");
fetchComparisonVectors();
// Initialize Arrays
processCosineSim();
Expand Down Expand Up @@ -117,167 +111,6 @@ private void processCosineSim() {
}
}

/**
 * Maps a cosine similarity value to a bucket index using the linear formula
 * {@code ((bucketCount - 1) / 2) * (cosineSimVal + 1)}, rounded to the nearest
 * integer. For an input of -1 this yields 0 and for an input of 1 it yields
 * {@code bucketCount - 1}.
 *
 * @param cosineSimVal
 *            - cosine similarity value to convert
 * @return index of the bucket the value falls into
 */
private int getBucketIndex(double cosineSimVal) {
    final double scaled = ((bucketCount - 1d) / 2d) * (cosineSimVal + 1d);
    // Narrow to float before rounding so that Math.round(float) yields an int
    return Math.round((float) scaled);
}

/**
 * Marks a word as belonging to the bucket that corresponds to the given cosine
 * similarity value. The bucket's BitSet is created on first use, sized to the
 * number of entries in the word2vec map.
 *
 * @param wordIndex
 *            - bit position to set in the bucket
 * @param cosSimVal
 *            - cosine similarity value used to choose the bucket
 * @param meanComparisonVecBuckets
 *            - array of buckets to update
 */
private void setValToBucket(int wordIndex, double cosSimVal, BitSet[] meanComparisonVecBuckets) {
    final int slot = getBucketIndex(cosSimVal);
    if (meanComparisonVecBuckets[slot] == null) {
        // Lazily allocate the bucket the first time it receives a word
        meanComparisonVecBuckets[slot] = new BitSet(word2vec.size());
    }
    meanComparisonVecBuckets[slot].set(wordIndex);
}

/**
* Fetches the word closest to the given vector by delegating to
* {@code getClosestEntry(vector, null)}, i.e. without a subset key.
*
* @param vector
* - vector to find closest word to
*
* @return the closest word as returned by the two-argument overload (only the
* word is returned, not its vector)
*/
@Override
public String getClosestEntry(float[] vector) {
return getClosestEntry(vector, null);
}

/**
* Fetches the word closest to the given vector by delegating to
* {@code getClosestEntry(vector, subKey)}.
*
* @param vector
* - vector to find closest word to
* @param subKey
* - key to subset if any; passed through unchanged to the delegate
* @return the closest word as returned by the delegate (only the word is
* returned, not its vector)
*/
@Override
public String getClosestSubEntry(float[] vector, String subKey) {
return getClosestEntry(vector, subKey);
}

/**
 * Finds the word closest to the given vector.
 *
 * The incoming vector is normalized, and for every comparison vector the
 * cosine similarity to it is converted into a bucket index. The words
 * registered in that bucket or in either directly adjacent bucket form the
 * candidates for that comparison vector; the candidate sets of all comparison
 * vectors are intersected. The set-bit positions of the intersection are then
 * handed to {@code findClosestWord} together with the normalized vector.
 *
 * Any exception raised during the search is logged together with its cause and
 * {@code null} is returned.
 *
 * @param vector
 *            - vector to find closest word to
 * @param subKey
 *            - key to subset if any; currently not used by this implementation
 * @return the closest word, or {@code null} if no candidate was found or an
 *         error occurred
 */
private String getClosestEntry(float[] vector, String subKey) {
    String closestWord = null;
    try {
        // Normalize incoming vector
        final float[] normalized = Word2VecMath.normalize(vector);
        BitSet finBitSet = null;
        for (int i = 0; i < compareVecCount; i++) {
            final double cosSimVal = Word2VecMath.cosineSimilarityNormalizedVecs(comparisonVecs[i], normalized);
            final int indx = getBucketIndex(cosSimVal);
            final BitSet[] buckets = csBucketContainer[i];
            // Union of the matching bucket and its immediate neighbours
            final BitSet curBs = new BitSet(word2vec.size());
            for (int b = indx - 1; b <= indx + 1; b++) {
                if (b > -1 && b < buckets.length && buckets[b] != null) {
                    curBs.or(buckets[b]);
                }
            }
            // First comparison vector seeds the result; later ones narrow it
            if (finBitSet == null) {
                finBitSet = curBs;
            } else {
                finBitSet.and(curBs);
            }
        }
        if (finBitSet == null) {
            // No comparison vector was evaluated, hence there are no candidates
            return null;
        }
        final int[] nearbyIndexes = new int[finBitSet.cardinality()];
        int j = 0;
        for (int i = finBitSet.nextSetBit(0); i >= 0; i = finBitSet.nextSetBit(i + 1)) {
            nearbyIndexes[j++] = i;
            if (i == Integer.MAX_VALUE) {
                break; // or (i+1) would overflow
            }
        }
        closestWord = findClosestWord(nearbyIndexes, normalized);
    } catch (Exception e) {
        // Route the stack trace through the logger instead of stderr
        LOG.error("Exception has occured while finding closest word.", e);
    }
    return closestWord;
}

/**
 * Scans the candidate indexes and returns the word whose vector has the
 * smallest squared euclidean distance to the given vector. The best distance
 * found so far is passed to {@code getSqEucDist} as a cut-off; a result of -1
 * means the candidate exceeded the cut-off and is skipped.
 *
 * @param nearbyIndexes
 *            - indexes into the word and vector arrays to consider
 * @param vector
 *            - vector to compare the candidates against
 * @return the closest word, or {@code null} if there are no candidates
 */
private String findClosestWord(int[] nearbyIndexes, float[] vector) {
    String bestWord = null;
    // -2 signals "no cut-off yet" to getSqEucDist
    double bestDist = -2;
    for (int k = 0; k < nearbyIndexes.length; k++) {
        final int candidate = nearbyIndexes[k];
        final double dist = getSqEucDist(vector, vecArr[candidate], bestDist);
        if (dist == -1) {
            continue;
        }
        bestDist = dist;
        bestWord = wordArr[candidate];
    }
    return bestWord;
}

/**
 * Computes the squared euclidean distance between two vectors over the first
 * {@code vectorSize} components, aborting early with -1 as soon as the running
 * sum exceeds the supplied cut-off. A cut-off of -2 disables the early abort.
 *
 * @param arr1
 *            - first vector
 * @param arr2
 *            - second vector
 * @param minDist
 *            - cut-off distance, or -2 for no cut-off
 * @return squared euclidean distance between the two vectors, or -1 if it
 *         exceeds the cut-off
 */
private double getSqEucDist(float[] arr1, float[] arr2, double minDist) {
    final boolean bounded = minDist != -2;
    double sum = 0;
    int dim = 0;
    while (dim < vectorSize) {
        sum += Math.pow(arr1[dim] - arr2[dim], 2);
        if (bounded && sum > minDist) {
            return -1;
        }
        dim++;
    }
    return sum;
}

/**
* Returns the vector size this model was constructed with.
*
* @return - vectorSize
*/
@Override
public int getVectorSize() {
return this.vectorSize;
}

/**
* Returns the word2vec map held by this model. The same map instance is
* returned (no copy is made), so changes made by the caller are visible to
* the model.
*
* @return - word2vec map
*/
public Map<String, float[]> getWord2VecMap() {
return this.word2vec;
}

public static float[][] readVecsFromFile(File inputFile) throws IOException {
float[][] vecArr = null;
FileReader fileReader;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import java.util.List;

import org.aksw.word2vecrestful.utils.Cfg;
import org.aksw.word2vecrestful.word2vec.W2VNrmlMemModelBinSrch;
import org.aksw.word2vecrestful.word2vec.W2VNrmlMemModelBruteForce;
import org.aksw.word2vecrestful.word2vec.W2VNrmlMemModelKMeans;
import org.aksw.word2vecrestful.word2vec.Word2VecFactory;
import org.aksw.word2vecrestful.word2vec.Word2VecModel;
import org.apache.log4j.LogManager;
Expand Down Expand Up @@ -38,11 +38,11 @@ public void testNbmTime() throws IOException {
LOG.info("Starting BruteForce-Model Test");
List<String> correctWords = getCorrectWords(centroids, nbm);
LOG.info("Correct Words are :" + correctWords);
LOG.info("Initializing W2VNrmlMemModelBinSrch Model");
final W2VNrmlMemModelBinSrch memModel = new W2VNrmlMemModelBinSrch(nbm.word2vec, nbm.vectorSize);
LOG.info("Initializing W2VNrmlMemModelKMeans Model");
final W2VNrmlMemModelKMeans memModel = new W2VNrmlMemModelKMeans(nbm.word2vec, nbm.vectorSize);
List<String> lrModelWords = new ArrayList<>();

LOG.info("Starting W2VNrmlMemModelBinSrch Test");
LOG.info("Starting W2VNrmlMemModelKMeans Test");

for (int i = 0; i < centroids.length; i++) {
LOG.info("Sending query for Centroid " + (i + 1));
Expand All @@ -52,7 +52,7 @@ public void testNbmTime() throws IOException {
totTime += diff;
LOG.info("Query time recorded for Centroid " + (i + 1) + " is " + diff + " milliseconds.");
}
LOG.info("Average query time for W2VNrmlMemModelBinSrch is : " + (totTime / centroids.length) + " milliseconds");
LOG.info("Average query time for W2VNrmlMemModelKMeans is : " + (totTime / centroids.length) + " milliseconds");
LOG.info("Predicted Words are :" + lrModelWords);
float percVal = calcPercScore(correctWords, lrModelWords);
LOG.info("Score for Test is : " + percVal + "%");
Expand Down

0 comments on commit 4e2b02e

Please sign in to comment.