Skip to content

Commit

Permalink
Changing method calls
Browse files Browse the repository at this point in the history
  • Loading branch information
nikit-srivastava committed Oct 19, 2018
1 parent d8d478a commit 2d4ba0b
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,16 @@ protected W2VNrmlMemModelBinSrch() {
/**
 * Creates a model over the supplied word-to-vector map.
 * <p>
 * Stores both arguments on the instance and then calls {@code initVars()}.
 * The direct call to {@code process()} is commented out in this body.
 * <p>
 * NOTE(review): {@code initVars()} is public and non-final, so a subclass
 * override would run before the subclass constructor body — confirm that
 * is acceptable.
 *
 * @param word2vec   map of {@code String} keys to {@code float[]} vectors; stored by reference, not copied
 * @param vectorSize stored as-is; {@code initVars()} uses it as the length of each comparison vector
 * @throws IOException declared by the signature
 */
public W2VNrmlMemModelBinSrch(final Map<String, float[]> word2vec, final int vectorSize) throws IOException {
this.word2vec = word2vec;
this.vectorSize = vectorSize;
initVars();
// process();
}

/**
 * Allocates the comparison-vector and bucket arrays, then calls {@code process()}.
 * <p>
 * {@code comparisonVecs} is sized {@code compareVecCount x vectorSize} and
 * {@code csBucketContainer} is sized {@code compareVecCount x bucketCount}.
 * These statements create the array slots only; no {@code BitSet} instance is
 * constructed here.
 */
public void initVars() {
comparisonVecs = new float[compareVecCount][vectorSize];
csBucketContainer = new BitSet[compareVecCount][bucketCount];
// NOTE(review): process() is declared with "throws IOException" but this method has
// no throws clause — confirm whether this call still belongs here.
process();
}

protected void process() throws IOException {
public void process() throws IOException {
LOG.info("Process from BinSrch called");
// Setting mean as comparison vec
setMeanComparisonVec(word2vec, vectorSize);
Expand Down Expand Up @@ -280,4 +284,20 @@ public Map<String, float[]> getWord2VecMap() {
return this.word2vec;
}

/**
 * Returns the number of comparison vectors, i.e. the value {@code initVars()}
 * uses as the first dimension of {@code comparisonVecs} and
 * {@code csBucketContainer}.
 *
 * @return the current value of the {@code compareVecCount} field
 */
public int getCompareVecCount() {
    return this.compareVecCount;
}

/**
 * Sets the {@code compareVecCount} field.
 * <p>
 * NOTE(review): {@code initVars()} sizes {@code comparisonVecs} and
 * {@code csBucketContainer} from this field, and this setter does not
 * re-allocate them — presumably callers must invoke {@code initVars()}
 * again after changing it; confirm.
 *
 * @param compareVecCount new value, stored without validation
 */
public void setCompareVecCount(int compareVecCount) {
this.compareVecCount = compareVecCount;
}

/**
 * Returns the bucket count, i.e. the value {@code initVars()} uses as the
 * second dimension of {@code csBucketContainer}.
 *
 * @return the current value of the {@code bucketCount} field
 */
public int getBucketCount() {
    return this.bucketCount;
}

/**
 * Sets the {@code bucketCount} field.
 * <p>
 * NOTE(review): {@code initVars()} sizes {@code csBucketContainer} from this
 * field, and this setter does not re-allocate it — presumably callers must
 * invoke {@code initVars()} again after changing it; confirm.
 *
 * @param bucketCount new value, stored without validation
 */
public void setBucketCount(int bucketCount) {
this.bucketCount = bucketCount;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@
public class W2VNrmlMemModelKMeans extends W2VNrmlMemModelBinSrch {
// Logger named after this class so that its messages (e.g. "Process from KMeans called")
// are attributed to W2VNrmlMemModelKMeans rather than to GenWord2VecModel.
public static Logger LOG = LogManager.getLogger(W2VNrmlMemModelKMeans.class);

private static final int KMEANS_MAX_ITR = 5;
private static final String VEC_FILEPATH = Cfg.get(W2VNrmlMemModelKMeans.class.getName().concat(".filepath"));
private int kMeansMaxItr = 5;
private String vecFilePath = Cfg.get(W2VNrmlMemModelKMeans.class.getName().concat(".filepath"));

/**
 * Creates the model by delegating directly to the superclass constructor;
 * no additional state is set up here.
 *
 * @param word2vec   passed through unchanged to the superclass constructor
 * @param vectorSize passed through unchanged to the superclass constructor
 * @throws IOException declared by the signature, matching the superclass constructor
 */
public W2VNrmlMemModelKMeans(final Map<String, float[]> word2vec, final int vectorSize) throws IOException {
super(word2vec, vectorSize);
}

/**
 * Logs that the KMeans variant is running, then calls
 * {@code fetchComparisonVectors()} followed by {@code processCosineSim()}.
 *
 * @throws IOException declared by the signature; {@code fetchComparisonVectors()} is declared to throw it
 */
// NOTE(review): two signature lines (protected and public) appear below; only one
// can be valid Java — confirm which one is current.
@Override
protected void process() throws IOException {
public void process() throws IOException {
LOG.info("Process from KMeans called");
fetchComparisonVectors();
// Initialize Arrays
processCosineSim();
}

private void fetchComparisonVectors() throws IOException {
File vecFile = new File(VEC_FILEPATH);
File vecFile = new File(vecFilePath);
if (vecFile.exists()) {
LOG.info("Reading Comparsion vectors from the file.");
// read the persisted vectors
Expand All @@ -73,7 +73,7 @@ private void fetchComparisonVectors() throws IOException {

private void generateComparisonVectors() {
KMeansPlusPlusClusterer<ClusterableVec> clusterer = new KMeansPlusPlusClusterer<>(compareVecCount,
KMEANS_MAX_ITR);
kMeansMaxItr);
List<ClusterableVec> vecList = new ArrayList<>();
for (float[] vec : word2vec.values()) {
vecList.add(getClusterablePoint(vec));
Expand Down Expand Up @@ -177,4 +177,21 @@ public static float[] convertToFloatArr(String[] vec) {
return resArr;
}

// Getters and Setters
/**
 * Returns the value that {@code generateComparisonVectors()} passes as the
 * second constructor argument of {@code KMeansPlusPlusClusterer}.
 *
 * @return the current value of the {@code kMeansMaxItr} field
 */
public int getkMeansMaxItr() {
    return this.kMeansMaxItr;
}

/**
 * Sets the {@code kMeansMaxItr} field, which {@code generateComparisonVectors()}
 * passes as the second constructor argument of {@code KMeansPlusPlusClusterer}.
 *
 * @param kMeansMaxItr new value, stored without validation
 */
public void setkMeansMaxItr(int kMeansMaxItr) {
this.kMeansMaxItr = kMeansMaxItr;
}

/**
 * Returns the path that {@code fetchComparisonVectors()} opens as a
 * {@code File} when looking for comparison vectors.
 *
 * @return the current value of the {@code vecFilePath} field
 */
public String getVecFilePath() {
    return this.vecFilePath;
}

/**
 * Sets the {@code vecFilePath} field, which {@code fetchComparisonVectors()}
 * uses to construct the {@code File} it checks for existence.
 *
 * @param vecFilePath new path, stored without validation
 */
public void setVecFilePath(String vecFilePath) {
this.vecFilePath = vecFilePath;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ public void testNbmTime() throws IOException {
LOG.info("Starting InMemory Theta Model test!");
Word2VecModel nbm = Word2VecFactory.getNormalBinModel();
float[][] centroids = TEST_CENTROIDS;
//float[][] centroids = fetchWordsVec(TEST_WORDS, nbm);
// float[][] centroids = fetchWordsVec(TEST_WORDS, nbm);
LOG.info("Starting BruteForce-Model Test");
List<String> correctWords = getCorrectWords(centroids, nbm);
LOG.info("Correct Words are :" + correctWords);
LOG.info("Initializing W2VNrmlMemModelKMeans Model");
final W2VNrmlMemModelKMeans memModel = new W2VNrmlMemModelKMeans(nbm.word2vec, nbm.vectorSize);
memModel.process();
List<String> lrModelWords = new ArrayList<>();

LOG.info("Starting W2VNrmlMemModelKMeans Test");
Expand All @@ -66,7 +67,7 @@ private static float[][] fetchWordsVec(String[] words, Word2VecModel nbm) {
}
return resVec;
}

public static float calcPercScore(List<String> correctWordSet, List<String> lrModelWords) {
float percScore = 0;
int len = correctWordSet.size();
Expand All @@ -79,7 +80,7 @@ public static float calcPercScore(List<String> correctWordSet, List<String> lrMo
return percScore;

}

public static List<String> getCorrectWords(float[][] centroids, Word2VecModel nbm) {
List<String> wordSet = new ArrayList<>();
W2VNrmlMemModelBruteForce bruteForce = new W2VNrmlMemModelBruteForce(nbm.word2vec, nbm.vectorSize);
Expand Down

0 comments on commit 2d4ba0b

Please sign in to comment.