diff --git a/src/main/java/org/aksw/word2vecrestful/word2vec/GenWord2VecModel.java b/src/main/java/org/aksw/word2vecrestful/word2vec/GenWord2VecModel.java index 4301c10..1d9251b 100644 --- a/src/main/java/org/aksw/word2vecrestful/word2vec/GenWord2VecModel.java +++ b/src/main/java/org/aksw/word2vecrestful/word2vec/GenWord2VecModel.java @@ -1,9 +1,10 @@ package org.aksw.word2vecrestful.word2vec; -import java.util.Map; +import java.io.IOException; public interface GenWord2VecModel { public int getVectorSize(); public String getClosestEntry(float[] vector); public String getClosestSubEntry(float[] vector, String subKey); + public void process() throws IOException; } diff --git a/src/main/java/org/aksw/word2vecrestful/word2vec/W2VNrmlMemModelBruteForce.java b/src/main/java/org/aksw/word2vecrestful/word2vec/W2VNrmlMemModelBruteForce.java index 556980a..fa73efd 100644 --- a/src/main/java/org/aksw/word2vecrestful/word2vec/W2VNrmlMemModelBruteForce.java +++ b/src/main/java/org/aksw/word2vecrestful/word2vec/W2VNrmlMemModelBruteForce.java @@ -1,16 +1,12 @@ package org.aksw.word2vecrestful.word2vec; import java.io.IOException; -import java.util.HashMap; import java.util.Map; -import java.util.Set; import org.aksw.word2vecrestful.subset.DataSubsetProvider; -import org.aksw.word2vecrestful.utils.TimeLogger; import org.aksw.word2vecrestful.utils.Word2VecMath; import org.apache.log4j.LogManager; import org.apache.log4j.Logger; -import org.dice_research.topicmodeling.commons.sort.AssociativeSort; /** * Class to encapsulate word2vec in-memory model and expose methods to perform @@ -24,13 +20,18 @@ public class W2VNrmlMemModelBruteForce implements GenWord2VecModel { private Map word2vec; private int vectorSize; + // for future use + @SuppressWarnings("unused") private DataSubsetProvider dataSubsetProvider; - // TODO : Remove this - private TimeLogger tl = new TimeLogger(); public W2VNrmlMemModelBruteForce(final Map word2vec, final int vectorSize) { this.word2vec = word2vec; this.vectorSize = vectorSize; + + } + + @Override + public void process() throws IOException { this.dataSubsetProvider = new DataSubsetProvider(); } @@ -74,51 +75,9 @@ public String getClosestSubEntry(float[] vector, String subKey) { * @return closest word to the given vector alongwith it's vector */ private String getClosestEntry(float[] vector, String subKey) { - Set wordSet = null; - String closestVec = null; - try { - if (subKey == null) { - wordSet = word2vec.keySet(); - } else { - tl.logTime(1); - wordSet = dataSubsetProvider.fetchSubsetWords(subKey); - tl.printTime(1, "fetchSubsetWords"); - } - // LOG.info("Normalizing input vector"); - // Normalize incoming vector - vector = Word2VecMath.normalize(vector); - - return Word2VecMath.findClosestNormalizedVec(word2vec, vector); - // LOG.info("fetching nearby vectors"); - // calculate cosine similarity of all distances -// String[] wordArr = new String[wordSet.size()]; -// int[] idArr = new int[wordSet.size()]; -// double[] cosineArr = new double[wordSet.size()]; -// int i = 0; -// for (String word : wordSet) { -// wordArr[i] = word; -// idArr[i] = i; -// float[] wordVec = word2vec.get(word); -// cosineArr[i] = Word2VecMath.cosineSimilarityNormalizedVecs(wordVec, vector); -// i++; -// } -// cosineArr = AssociativeSort.quickSort(cosineArr, idArr); -// double maxVal = cosineArr[cosineArr.length - 1]; -// for (int j = cosineArr.length - 1; j >= 0; j--) { -// if (cosineArr[j] == maxVal) { -// int closestWordId = idArr[j]; -// String closestWord = wordArr[closestWordId]; -// closestVec = closestWord; -// }else { -// break; -// } -// } - - } catch (IOException e) { - LOG.error(e.getStackTrace()); - } - // LOG.info("Closest word found is " + closestVec.keySet()); - return closestVec; + // Normalize incoming vector + vector = Word2VecMath.normalize(vector); + return Word2VecMath.findClosestNormalizedVec(word2vec, vector); } /** diff --git a/src/test/java/org/aksw/word2vecrestful/NrmlzdThetaMdlPrfmncTester.java b/src/test/java/org/aksw/word2vecrestful/NrmlzdThetaMdlPrfmncTester.java index e1245d3..f7609cf 100644 --- a/src/test/java/org/aksw/word2vecrestful/NrmlzdThetaMdlPrfmncTester.java +++ b/src/test/java/org/aksw/word2vecrestful/NrmlzdThetaMdlPrfmncTester.java @@ -5,8 +5,9 @@ import java.util.List; import org.aksw.word2vecrestful.utils.Cfg; -import org.aksw.word2vecrestful.word2vec.W2VNrmlMemModelBruteForce; +import org.aksw.word2vecrestful.word2vec.GenWord2VecModel; import org.aksw.word2vecrestful.word2vec.W2VNrmlMemModelBinSrch; +import org.aksw.word2vecrestful.word2vec.W2VNrmlMemModelBruteForce; import org.aksw.word2vecrestful.word2vec.Word2VecFactory; import org.aksw.word2vecrestful.word2vec.Word2VecModel; import org.apache.log4j.LogManager; @@ -39,7 +40,7 @@ public void testNbmTime() throws IOException { List correctWords = getCorrectWords(centroids, nbm); LOG.info("Correct Words are :" + correctWords); LOG.info("Initializing W2VNrmlMemModelBinSrch Model"); - final W2VNrmlMemModelBinSrch memModel = new W2VNrmlMemModelBinSrch(nbm.word2vec, nbm.vectorSize); + final GenWord2VecModel memModel = new W2VNrmlMemModelBinSrch(nbm.word2vec, nbm.vectorSize); memModel.process(); List lrModelWords = new ArrayList<>();