Skip to content

Commit

Permalink
Merge pull request #4 from nikit91/db-normal-model
Browse files Browse the repository at this point in the history
New search model
  • Loading branch information
nikit-srivastava authored Oct 19, 2018
2 parents c428d5d + 2df2539 commit a9f0a91
Show file tree
Hide file tree
Showing 19 changed files with 1,414 additions and 112 deletions.
6 changes: 5 additions & 1 deletion config/cfg.properties
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,8 @@ org.aksw.word2vecrestful.web.Word2VecController.maxN: 10000
org.aksw.word2vecrestful.word2vec.Word2VecModelLoader.bin: true
org.aksw.word2vecrestful.word2vec.Word2VecFactory.model: data/GoogleNews-vectors-negative300.bin
org.aksw.word2vecrestful.Application.inmemory: true
org.aksw.word2vecrestful.Application.subsetfiledir: data/subset-files-1/
org.aksw.word2vecrestful.Application.subsetfiledir: data/subset-files-1/
org.aksw.word2vecrestful.word2vec.normalizedbinmodel.bin: true
org.aksw.word2vecrestful.word2vec.normalizedbinmodel.model: data/normalbinmodel/GoogleNews-vectors-negative300-normalized.bin
org.aksw.word2vecrestful.word2vec.stats.sdfile: data/normal/stat/normal-model-sd.csv
org.aksw.word2vecrestful.word2vec.W2VNrmlMemModelKMeans.filepath: data/kmeans/comparison-vecs.csv
2 changes: 1 addition & 1 deletion config/log4j.properties
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
log4j.appender.stdout.layout.ConversionPattern=%d{HH:mm:ss} %5p [%t] (%F:%M:%L) - %m%n
log4j.appender.file=org.apache.log4j.RollingFileAppender
log4j.appender.file.File=log/root.log
log4j.appender.file.MaxFileSize=1MB
log4j.appender.file.MaxFileSize=150MB
log4j.appender.file.MaxBackupIndex=100
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{HH:mm:ss} %5p [%t] (%F:%M:%L) - %m%n
39 changes: 38 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
<java.version>1.8</java.version>
<slf4j.version>1.7.10</slf4j.version>
</properties>

<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
Expand Down Expand Up @@ -71,7 +72,31 @@
<artifactId>commons-io</artifactId>
<version>2.6</version>
</dependency>


<dependency>
<groupId>org.dice-research</groupId>
<artifactId>topicmodeling.commons</artifactId>
<version>0.0.3-SNAPSHOT</version>
</dependency>
<!-- ~~~~~~~~~~~~~~~~~~~ Logging ~~~~~~~~~~~~~~~~~~~~~~ -->
<!-- slf4j: Logging API -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>${slf4j.version}</version>
</dependency>
<!-- slf4j: Logging Binding -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
<version>${slf4j.version}</version>
</dependency>
<!-- ~~~~~~~~~~~~~~~~~~~ End Logging ~~~~~~~~~~~~~~~~~~~~~~ -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.5</version>
</dependency>
</dependencies>
<build>
<testSourceDirectory>src/test/java</testSourceDirectory>
Expand All @@ -89,6 +114,7 @@
<source>1.8</source>
<target>1.8</target>
<encoding>UTF-8</encoding>
<argLine>-Xmx15024m</argLine>
</configuration>
</plugin>
<plugin>
Expand All @@ -98,6 +124,7 @@
<configuration>
<parallel>methods</parallel>
<threadCount>10</threadCount>
<argLine>-Xmx15024m</argLine>
</configuration>
</plugin>
</plugins>
Expand All @@ -107,6 +134,16 @@
<id>spring-releases</id>
<url>https://repo.spring.io/libs-release</url>
</repository>
<repository>
<id>maven.aksw.internal</id>
<name>University Leipzig, AKSW Maven2 Repository</name>
<url>http://maven.aksw.org/repository/internal</url>
</repository>
<repository>
<id>maven.aksw.snapshots</id>
<name>University Leipzig, AKSW Maven2 Repository</name>
<url>http://maven.aksw.org/repository/snapshots</url>
</repository>
</repositories>
<pluginRepositories>
<pluginRepository>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ public static void generateSubsetFiles(File subsetConfig, String outputFileDir,
* @throws IOException
*/
public static void main(String[] args) throws JsonProcessingException, FileNotFoundException, IOException {
File subsetConfig = new File(".\\word2vec-dump\\subsetconfig2.json");
Word2VecModel model = Word2VecFactory.get();
File subsetConfig = new File("word2vec-dump\\subsetconfig2.json");
Word2VecModel model = Word2VecFactory.getNormalBinModel();
generateSubsetFiles(subsetConfig, Cfg.get("org.aksw.word2vecrestful.Application.subsetfiledir"), model.word2vec,
model.vectorSize);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import org.aksw.word2vecrestful.utils.Cfg;
import org.apache.commons.io.FileUtils;
Expand All @@ -22,10 +23,10 @@
public class DataSubsetProvider {

private String fileDir = Cfg.get("org.aksw.word2vecrestful.Application.subsetfiledir");
private final Map<String, List<String>> SUBSET_MODELS = new HashMap<>();
private final Map<String, Set<String>> SUBSET_MODELS = new HashMap<>();

/**
* Method to fetch the list of words in a subset
* Method to fetch the set of words in a subset
*
* @param subsetKey
* - key to identify the subset
Expand All @@ -34,15 +35,16 @@ public class DataSubsetProvider {
* @throws FileNotFoundException
* @throws IOException
*/
public List<String> fetchSubsetWords(String subsetKey) throws IOException {
public Set<String> fetchSubsetWords(String subsetKey) throws IOException {
// fetch from cache
List<String> resList = SUBSET_MODELS.get(subsetKey);
Set<String> resList = SUBSET_MODELS.get(subsetKey);
// if not in cache then read from file and add to cache
if (resList == null) {
// logic to fetch the words from the stored subsets
File file1 = new File(fileDir + "/" + appendFileExtension(subsetKey));
if (file1.exists()) {
resList = FileUtils.readLines(file1, StandardCharsets.UTF_8);
resList = new HashSet<>();
resList.addAll(FileUtils.readLines(file1, StandardCharsets.UTF_8));
SUBSET_MODELS.put(subsetKey, resList);
}
}
Expand Down
161 changes: 161 additions & 0 deletions src/main/java/org/aksw/word2vecrestful/tool/ModelNormalizer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package org.aksw.word2vecrestful.tool;

import java.io.BufferedOutputStream;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.sql.SQLException;

import org.aksw.word2vecrestful.utils.Cfg;
import org.aksw.word2vecrestful.utils.Word2VecMath;
import org.aksw.word2vecrestful.word2vec.Word2VecFactory;
import org.aksw.word2vecrestful.word2vec.Word2VecModelLoader;
import org.apache.commons.io.output.FileWriterWithEncoding;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;

public class ModelNormalizer {
    public static Logger LOG = LogManager.getLogger(ModelNormalizer.class);
    /** Record/field separators written into the generated model files (UTF-8). */
    public static final byte[] END_LINE_BA = "\n".getBytes(StandardCharsets.UTF_8);
    public static final byte[] WHITESPACE_BA = " ".getBytes(StandardCharsets.UTF_8);

    /**
     * Builds the text-model representation of a word: the word followed by the
     * components of its unit-normalized vector, space-separated.
     *
     * @param word
     *            - word label to prepend
     * @param vector
     *            - raw vector; it is normalized before serialization
     * @return - normalized line of the form {@code word v0 v1 ... vn}
     */
    public static String getNormalizedVecLine(String word, float[] vector) {
        // StringBuilder: no synchronization needed for this local buffer
        StringBuilder resStr = new StringBuilder(word);
        vector = Word2VecMath.normalize(vector);
        for (int i = 0; i < vector.length; i++) {
            resStr.append(" ").append(String.valueOf(vector[i]));
        }
        return resStr.toString();
    }

    /**
     * Serializes the unit-normalized form of a vector as little-endian IEEE-754
     * floats, matching the word2vec binary layout.
     *
     * @param vector
     *            - raw vector; it is normalized before serialization
     * @return - byte array of {@code vector.length * 4} little-endian floats
     */
    public static byte[] getNormalizedVecBA(float[] vector) {
        vector = Word2VecMath.normalize(vector);
        ByteBuffer buffer = ByteBuffer.allocate(vector.length * 4);
        buffer.order(ByteOrder.LITTLE_ENDIAN);
        for (int i = 0; i < vector.length; i++) {
            buffer.putFloat(vector[i]);
        }
        return buffer.array();
    }

    /**
     * Generates a text-format normalized model for a word2vec bin model.
     * Errors are logged and swallowed (best-effort), matching the original
     * contract; the declared {@code throws IOException} covers close failures.
     *
     * @param inputFile
     *            - word2vec file of the model to be normalized
     * @param outputFile
     *            - output file for the normalized model
     * @throws IOException
     */
    public void generateNormalizedModel(File inputFile, File outputFile) throws IOException {
        // ensure directory creation
        outputFile.getParentFile().mkdirs();
        // try-with-resources: both streams are closed even when the header
        // parse or a read fails part-way (the old explicit finally could NPE
        // or leak the writer when fin.close() threw)
        try (FileInputStream fin = new FileInputStream(inputFile)) {
            // header: "<wordCount> <vectorSize>"
            int words = Integer.parseInt(Word2VecModelLoader.readWord(fin));
            int vectorSize = Integer.parseInt(Word2VecModelLoader.readWord(fin));
            // writer opened only after a valid header, as before
            try (BufferedWriter bWriter = new BufferedWriter(
                    new FileWriterWithEncoding(outputFile, StandardCharsets.UTF_8))) {
                bWriter.write(words + " " + vectorSize);
                LOG.info("Expecting " + words + " words with " + vectorSize + " values per vector.");
                for (int w = 0; w < words; ++w) {
                    String word = Word2VecModelLoader.readWord(fin);
                    float[] vector = Word2VecModelLoader.readVector(fin, vectorSize);
                    bWriter.newLine();
                    bWriter.write(getNormalizedVecLine(word, vector));
                    // periodic flush to bound buffered memory on huge models
                    if (w % 10000 == 0) {
                        bWriter.flush();
                    }
                }
            }
        } catch (final IOException e) {
            // best-effort: log and return rather than propagate (original behavior)
            LOG.error(e.getLocalizedMessage(), e);
        }
    }

    /**
     * Generates a binary-format normalized model for a word2vec bin model.
     * Errors are logged and swallowed (best-effort), matching the original
     * contract; the declared {@code throws IOException} covers close failures.
     *
     * @param inputFile
     *            - word2vec file of the model to be normalized
     * @param outputFile
     *            - output file for the normalized model
     * @throws IOException
     */
    public void generateNormalizedBinModel(File inputFile, File outputFile) throws IOException {
        // ensure directory creation
        outputFile.getParentFile().mkdirs();
        // try-with-resources replaces the unsafe finally{close;close} pattern;
        // resources open in the original order (output first, then input)
        try (BufferedOutputStream bOutStrm = new BufferedOutputStream(new FileOutputStream(outputFile));
                FileInputStream fin = new FileInputStream(inputFile)) {
            // copy header: "<wordCount> <vectorSize>\n"
            String word = Word2VecModelLoader.readWord(fin);
            bOutStrm.write(word.getBytes(StandardCharsets.UTF_8));
            bOutStrm.write(WHITESPACE_BA);
            int words = Integer.parseInt(word);
            word = Word2VecModelLoader.readWord(fin);
            bOutStrm.write(word.getBytes(StandardCharsets.UTF_8));
            int vectorSize = Integer.parseInt(word);
            bOutStrm.write(END_LINE_BA);
            LOG.info("Expecting " + words + " words with " + vectorSize + " values per vector.");
            for (int w = 0; w < words; ++w) {
                word = Word2VecModelLoader.readWord(fin);
                float[] vector = Word2VecModelLoader.readVector(fin, vectorSize);
                // record layout: "<word> " followed by raw little-endian floats
                bOutStrm.write(word.getBytes(StandardCharsets.UTF_8));
                bOutStrm.write(WHITESPACE_BA);
                bOutStrm.write(getNormalizedVecBA(vector));
                if ((w + 1) % 10000 == 0) {
                    bOutStrm.flush();
                    LOG.info((w + 1) + " Records inserted.");
                }
            }
        } catch (final IOException e) {
            // best-effort: log and return rather than propagate (original behavior)
            LOG.error(e.getLocalizedMessage(), e);
        }
    }

    /**
     * Reads the source model path from configuration, normalizes it, and writes
     * the binary normalized model to the configured output location.
     */
    public static void main(String[] args) throws IOException, SQLException {
        String cfgKeyModel = Word2VecFactory.class.getName().concat(".model");
        String model = (Cfg.get(cfgKeyModel));
        ModelNormalizer modelNormalizer = new ModelNormalizer();
        File inputFile = new File(model);
        String outputModel = (Cfg.get("org.aksw.word2vecrestful.word2vec.normalizedbinmodel.model"));
        File outputFile = new File(outputModel);
        modelNormalizer.generateNormalizedBinModel(inputFile, outputFile);
    }

}
Loading

0 comments on commit a9f0a91

Please sign in to comment.