From 2d4ba0b408ed50d1e92609f831178e0c0035f033 Mon Sep 17 00:00:00 2001 From: "nikit91@gmail.com" Date: Fri, 19 Oct 2018 07:51:29 +0200 Subject: [PATCH] Changing method calls --- .../word2vec/W2VNrmlMemModelBinSrch.java | 24 +++++++++++++++-- .../word2vec/W2VNrmlMemModelKMeans.java | 27 +++++++++++++++---- .../NrmlzdThetaMdlPrfmncTester.java | 7 ++--- 3 files changed, 48 insertions(+), 10 deletions(-) diff --git a/src/main/java/org/aksw/word2vecrestful/word2vec/W2VNrmlMemModelBinSrch.java b/src/main/java/org/aksw/word2vecrestful/word2vec/W2VNrmlMemModelBinSrch.java index f50c8a8..6f90405 100644 --- a/src/main/java/org/aksw/word2vecrestful/word2vec/W2VNrmlMemModelBinSrch.java +++ b/src/main/java/org/aksw/word2vecrestful/word2vec/W2VNrmlMemModelBinSrch.java @@ -44,12 +44,16 @@ protected W2VNrmlMemModelBinSrch() { public W2VNrmlMemModelBinSrch(final Map word2vec, final int vectorSize) throws IOException { this.word2vec = word2vec; this.vectorSize = vectorSize; + initVars(); + // process(); + } + + public void initVars() { comparisonVecs = new float[compareVecCount][vectorSize]; csBucketContainer = new BitSet[compareVecCount][bucketCount]; - process(); } - protected void process() throws IOException { + public void process() throws IOException { LOG.info("Process from BinSrch called"); // Setting mean as comparison vec setMeanComparisonVec(word2vec, vectorSize); @@ -280,4 +284,20 @@ public Map getWord2VecMap() { return this.word2vec; } + public int getCompareVecCount() { + return compareVecCount; + } + + public void setCompareVecCount(int compareVecCount) { + this.compareVecCount = compareVecCount; + } + + public int getBucketCount() { + return bucketCount; + } + + public void setBucketCount(int bucketCount) { + this.bucketCount = bucketCount; + } + } diff --git a/src/main/java/org/aksw/word2vecrestful/word2vec/W2VNrmlMemModelKMeans.java b/src/main/java/org/aksw/word2vecrestful/word2vec/W2VNrmlMemModelKMeans.java index d8ec809..1fc2316 100644 --- a/src/main/java/org/aksw/word2vecrestful/word2vec/W2VNrmlMemModelKMeans.java +++ b/src/main/java/org/aksw/word2vecrestful/word2vec/W2VNrmlMemModelKMeans.java @@ -39,15 +39,15 @@ public class W2VNrmlMemModelKMeans extends W2VNrmlMemModelBinSrch { public static Logger LOG = LogManager.getLogger(GenWord2VecModel.class); - private static final int KMEANS_MAX_ITR = 5; - private static final String VEC_FILEPATH = Cfg.get(W2VNrmlMemModelKMeans.class.getName().concat(".filepath")); + private int kMeansMaxItr = 5; + private String vecFilePath = Cfg.get(W2VNrmlMemModelKMeans.class.getName().concat(".filepath")); public W2VNrmlMemModelKMeans(final Map word2vec, final int vectorSize) throws IOException { super(word2vec, vectorSize); } @Override - protected void process() throws IOException { + public void process() throws IOException { LOG.info("Process from KMeans called"); fetchComparisonVectors(); // Initialize Arrays @@ -55,7 +55,7 @@ protected void process() throws IOException { } private void fetchComparisonVectors() throws IOException { - File vecFile = new File(VEC_FILEPATH); + File vecFile = new File(vecFilePath); if (vecFile.exists()) { LOG.info("Reading Comparsion vectors from the file."); // read the persisted vectors @@ -73,7 +73,7 @@ private void fetchComparisonVectors() throws IOException { private void generateComparisonVectors() { KMeansPlusPlusClusterer clusterer = new KMeansPlusPlusClusterer<>(compareVecCount, - KMEANS_MAX_ITR); + kMeansMaxItr); List vecList = new ArrayList<>(); for (float[] vec : word2vec.values()) { vecList.add(getClusterablePoint(vec)); @@ -177,4 +177,21 @@ public static float[] convertToFloatArr(String[] vec) { return resArr; } + // Getter and Setters + public int getkMeansMaxItr() { + return kMeansMaxItr; + } + + public void setkMeansMaxItr(int kMeansMaxItr) { + this.kMeansMaxItr = kMeansMaxItr; + } + + public String getVecFilePath() { + return vecFilePath; + } + + public void setVecFilePath(String vecFilePath) { + this.vecFilePath = vecFilePath; + } + } diff --git a/src/test/java/org/aksw/word2vecrestful/NrmlzdThetaMdlPrfmncTester.java b/src/test/java/org/aksw/word2vecrestful/NrmlzdThetaMdlPrfmncTester.java index ba797eb..afb52b3 100644 --- a/src/test/java/org/aksw/word2vecrestful/NrmlzdThetaMdlPrfmncTester.java +++ b/src/test/java/org/aksw/word2vecrestful/NrmlzdThetaMdlPrfmncTester.java @@ -34,12 +34,13 @@ public void testNbmTime() throws IOException { LOG.info("Starting InMemory Theta Model test!"); Word2VecModel nbm = Word2VecFactory.getNormalBinModel(); float[][] centroids = TEST_CENTROIDS; - //float[][] centroids = fetchWordsVec(TEST_WORDS, nbm); + // float[][] centroids = fetchWordsVec(TEST_WORDS, nbm); LOG.info("Starting BruteForce-Model Test"); List correctWords = getCorrectWords(centroids, nbm); LOG.info("Correct Words are :" + correctWords); LOG.info("Initializing W2VNrmlMemModelKMeans Model"); final W2VNrmlMemModelKMeans memModel = new W2VNrmlMemModelKMeans(nbm.word2vec, nbm.vectorSize); + memModel.process(); List lrModelWords = new ArrayList<>(); LOG.info("Starting W2VNrmlMemModelKMeans Test"); @@ -66,7 +67,7 @@ private static float[][] fetchWordsVec(String[] words, Word2VecModel nbm) { } return resVec; } - + public static float calcPercScore(List correctWordSet, List lrModelWords) { float percScore = 0; int len = correctWordSet.size(); @@ -79,7 +80,7 @@ public static float calcPercScore(List correctWordSet, List lrMo return percScore; } - + public static List getCorrectWords(float[][] centroids, Word2VecModel nbm) { List wordSet = new ArrayList<>(); W2VNrmlMemModelBruteForce bruteForce = new W2VNrmlMemModelBruteForce(nbm.word2vec, nbm.vectorSize);