Skip to content

Commit

Permalink
refactoring the code
Browse files Browse the repository at this point in the history
  • Loading branch information
nikit-srivastava committed Oct 18, 2018
1 parent 7521898 commit 1cd238d
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 65 deletions.
31 changes: 0 additions & 31 deletions src/main/java/nikit/test/Word2VecTester.java

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package nikit.test;
package org.aksw.word2vecrestful.utils;

import java.util.HashMap;
import java.util.Map;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@
import org.apache.log4j.Logger;
import org.dice_research.topicmodeling.commons.sort.AssociativeSort;

import nikit.test.TimeLogger;

/**
* Class to encapsulate word2vec in-memory model and expose methods to perform
* search on the model
* search on the model.
*
* This class selects {@link W2VNrmlMemModelBinSrch#compareVecCount} vectors (1
* mean vector and others on basis Map iterator) and then calculates the cosine
* similarity of all words in model to those vectors.
*
* It uses the knowledge about pre-processed similarities with
* {@link W2VNrmlMemModelBinSrch#comparisonVecs} to narrow down the search of
* closest word for the user specified vector.
*
* @author Nikit
*
Expand All @@ -30,8 +36,6 @@ public class W2VNrmlMemModelBinSrch implements GenWord2VecModel {
private int compareVecCount = 150;
private int bucketCount = 10;
private BitSet[][] csBucketContainer;
// TODO : Remove this
private TimeLogger tl = new TimeLogger();

public W2VNrmlMemModelBinSrch(final Map<String, float[]> word2vec, final int vectorSize) {
this.word2vec = word2vec;
Expand Down Expand Up @@ -170,13 +174,12 @@ private String getClosestEntry(float[] vector, String subKey) {
// calculate cosine similarity of all distances
float[] curCompVec;
BitSet finBitSet = null;
tl.logTime(1);
for (int i = 0; i < compareVecCount; i++) {
curCompVec = comparisonVecs[i];
double cosSimVal = Word2VecMath.cosineSimilarityNormalizedVecs(curCompVec, vector);
int indx = getBucketIndex(cosSimVal);
BitSet curBs = new BitSet(word2vec.size());
if(csBucketContainer[i][indx]!=null) {
if (csBucketContainer[i][indx] != null) {
curBs.or(csBucketContainer[i][indx]);
}
int temIndx = indx + 1;
Expand All @@ -193,10 +196,8 @@ private String getClosestEntry(float[] vector, String subKey) {
finBitSet.and(curBs);
}
}
tl.printTime(1, "Setting Bits");
tl.logTime(1);
int nearbyWordsCount = finBitSet.cardinality();
LOG.info("Number of nearby words: "+nearbyWordsCount);
//LOG.info("Number of nearby words: " + nearbyWordsCount);
int[] nearbyIndexes = new int[nearbyWordsCount];
int j = 0;
for (int i = finBitSet.nextSetBit(0); i >= 0; i = finBitSet.nextSetBit(i + 1), j++) {
Expand All @@ -206,15 +207,12 @@ private String getClosestEntry(float[] vector, String subKey) {
break; // or (i+1) would overflow
}
}
tl.printTime(1, "Extracting words");
tl.logTime(1);
closestWord = findClosestWord(nearbyIndexes, vector);
tl.printTime(1, "finding closest word");
} catch (Exception e) {
LOG.error("Exception has occured while finding closest word.");
e.printStackTrace();
}
LOG.info("Closest word found is: "+closestWord);
//LOG.info("Closest word found is: " + closestWord);
return closestWord;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
import java.util.Set;

import org.aksw.word2vecrestful.subset.DataSubsetProvider;
import org.aksw.word2vecrestful.utils.TimeLogger;
import org.aksw.word2vecrestful.utils.Word2VecMath;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
import org.dice_research.topicmodeling.commons.sort.AssociativeSort;

import nikit.test.TimeLogger;

/**
* Class to encapsulate word2vec in-memory model and expose methods to perform
* search on the model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,17 @@
import com.opencsv.CSVReader;
import com.opencsv.CSVWriter;

import nikit.test.TimeLogger;

/**
* Class to encapsulate word2vec in-memory model and expose methods to perform
* search on the model
* search on the model.
*
* This class selects {@link W2VNrmlMemModelKMeans#compareVecCount} vectors
* (centroids of the KMeans result on the model vectors) and then calculates the
* cosine similarity of all words in model to those vectors.
*
* It uses the knowledge about pre-processed similarities with
* {@link W2VNrmlMemModelKMeans#comparisonVecs} to narrow down the search of
* closest word for the user specified vector.
*
* @author Nikit
*
Expand All @@ -42,8 +48,6 @@ public class W2VNrmlMemModelKMeans implements GenWord2VecModel {
private int kMeansMaxItr = 5;
private BitSet[][] csBucketContainer;
private String vecFilePath = "data/kmeans/comparison-vecs.csv";
// TODO : Remove this
private TimeLogger tl = new TimeLogger();

public W2VNrmlMemModelKMeans(final Map<String, float[]> word2vec, final int vectorSize) throws IOException {
this.word2vec = word2vec;
Expand Down Expand Up @@ -175,14 +179,13 @@ private String getClosestEntry(float[] vector, String subKey) {
// calculate cosine similarity of all distances
float[] curCompVec;
BitSet finBitSet = null;
tl.logTime(1);
for (int i = 0; i < compareVecCount; i++) {
curCompVec = comparisonVecs[i];
double cosSimVal = Word2VecMath.cosineSimilarityNormalizedVecs(curCompVec, vector);
int indx = getBucketIndex(cosSimVal);
BitSet curBs = new BitSet(word2vec.size());
BitSet tempBs = csBucketContainer[i][indx];
if(tempBs!=null) {
if (tempBs != null) {
curBs.or(tempBs);
}
int temIndx = indx + 1;
Expand All @@ -199,10 +202,8 @@ private String getClosestEntry(float[] vector, String subKey) {
finBitSet.and(curBs);
}
}
tl.printTime(1, "Setting Bits");
tl.logTime(1);
int nearbyWordsCount = finBitSet.cardinality();
LOG.info("Number of nearby words: "+nearbyWordsCount);
//LOG.info("Number of nearby words: " + nearbyWordsCount);
int[] nearbyIndexes = new int[nearbyWordsCount];
int j = 0;
for (int i = finBitSet.nextSetBit(0); i >= 0; i = finBitSet.nextSetBit(i + 1), j++) {
Expand All @@ -212,15 +213,12 @@ private String getClosestEntry(float[] vector, String subKey) {
break; // or (i+1) would overflow
}
}
tl.printTime(1, "Extracting words");
tl.logTime(1);
closestWord = findClosestWord(nearbyIndexes, vector);
tl.printTime(1, "finding closest word");
} catch (Exception e) {
LOG.error("Exception has occured while finding closest word.");
e.printStackTrace();
}
LOG.info("Closest word found is: "+closestWord);
//LOG.info("Closest word found is: " + closestWord);
return closestWord;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
import org.apache.log4j.PropertyConfigurator;
import org.junit.Test;

import nikit.test.TestConst;

public class NrmlzdThetaMdlPrfmncTester {
static {
PropertyConfigurator.configure(Cfg.LOG_FILE);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package nikit.test;
package org.aksw.word2vecrestful;

import java.util.HashMap;
import java.util.Map;
Expand Down

0 comments on commit 1cd238d

Please sign in to comment.