Merge pull request #3 from nikit91/master
New Models Implementation
MichaelRoeder authored Oct 19, 2018
2 parents 07976ec + a9f0a91 commit e025956
Showing 21 changed files with 1,727 additions and 105 deletions.
7 changes: 6 additions & 1 deletion config/cfg.properties
@@ -2,4 +2,9 @@ org.aksw.word2vecrestful.web.Word2VecController.apikey: 1234
org.aksw.word2vecrestful.web.Word2VecController.maxN: 10000
org.aksw.word2vecrestful.word2vec.Word2VecModelLoader.bin: true
org.aksw.word2vecrestful.word2vec.Word2VecFactory.model: data/GoogleNews-vectors-negative300.bin
org.aksw.word2vecrestful.Application.inmemory: true
org.aksw.word2vecrestful.Application.inmemory: true
org.aksw.word2vecrestful.Application.subsetfiledir: data/subset-files-1/
org.aksw.word2vecrestful.word2vec.normalizedbinmodel.bin: true
org.aksw.word2vecrestful.word2vec.normalizedbinmodel.model: data/normalbinmodel/GoogleNews-vectors-negative300-normalized.bin
org.aksw.word2vecrestful.word2vec.stats.sdfile: data/normal/stat/normal-model-sd.csv
org.aksw.word2vecrestful.word2vec.W2VNrmlMemModelKMeans.filepath: data/kmeans/comparison-vecs.csv
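The new keys point at the artifacts the added models rely on: a directory for the generated subset files, the normalized binary model, a CSV of per-dimension standard deviations, and the k-means comparison vectors. Below is a minimal sketch of reading one of these keys through the project's Cfg helper (also used by the new classes further down); the wrapper class and printout are illustrative, not part of the repository:

```java
import org.aksw.word2vecrestful.utils.Cfg;

public class SubsetDirExample {
    public static void main(String[] args) {
        // Cfg.get returns the raw value configured in config/cfg.properties
        String subsetFileDir = Cfg.get("org.aksw.word2vecrestful.Application.subsetfiledir");
        System.out.println("Subset files are expected under: " + subsetFileDir);
    }
}
```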
2 changes: 1 addition & 1 deletion config/log4j.properties
@@ -4,7 +4,7 @@ log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
log4j.appender.stdout.layout.ConversionPattern=%d{HH:mm:ss} %5p [%t] (%F:%M:%L) - %m%n
log4j.appender.file=org.apache.log4j.RollingFileAppender
log4j.appender.file.File=log/root.log
log4j.appender.file.MaxFileSize=1MB
log4j.appender.file.MaxFileSize=150MB
log4j.appender.file.MaxBackupIndex=100
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{HH:mm:ss} %5p [%t] (%F:%M:%L) - %m%n
48 changes: 46 additions & 2 deletions pom.xml
@@ -11,6 +11,7 @@
<java.version>1.8</java.version>
<slf4j.version>1.7.10</slf4j.version>
</properties>

<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
@@ -51,7 +52,33 @@
<version>4.11</version>
<scope>test</scope>
</dependency>
<!-- ~~~~~~~~~~~~~~~~~~~ Logging ~~~~~~~~~~~~~~~~~~~~~~ -->

<!-- https://mvnrepository.com/artifact/com.fasterxml.jackson.core/jackson-core -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>2.9.6</version>
</dependency>

<dependency>
<groupId>com.opencsv</groupId>
<artifactId>opencsv</artifactId>
<version>4.1</version>
</dependency>

<!-- https://mvnrepository.com/artifact/commons-io/commons-io -->
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.6</version>
</dependency>

<dependency>
<groupId>org.dice-research</groupId>
<artifactId>topicmodeling.commons</artifactId>
<version>0.0.3-SNAPSHOT</version>
</dependency>
<!-- ~~~~~~~~~~~~~~~~~~~ Logging ~~~~~~~~~~~~~~~~~~~~~~ -->
<!-- slf4j: Logging API -->
<dependency>
<groupId>org.slf4j</groupId>
@@ -64,7 +91,12 @@
<artifactId>slf4j-log4j12</artifactId>
<version>${slf4j.version}</version>
</dependency>
<!-- ~~~~~~~~~~~~~~~~~~~ End Logging ~~~~~~~~~~~~~~~~~~~~~~ -->
<!-- ~~~~~~~~~~~~~~~~~~~ End Logging ~~~~~~~~~~~~~~~~~~~~~~ -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.5</version>
</dependency>
</dependencies>
<build>
<testSourceDirectory>src/test/java</testSourceDirectory>
@@ -82,6 +114,7 @@
<source>1.8</source>
<target>1.8</target>
<encoding>UTF-8</encoding>
<argLine>-Xmx15024m</argLine>
</configuration>
</plugin>
<plugin>
@@ -91,6 +124,7 @@
<configuration>
<parallel>methods</parallel>
<threadCount>10</threadCount>
<argLine>-Xmx15024m</argLine>
</configuration>
</plugin>
</plugins>
@@ -100,6 +134,16 @@
<id>spring-releases</id>
<url>https://repo.spring.io/libs-release</url>
</repository>
<repository>
<id>maven.aksw.internal</id>
<name>University Leipzig, AKSW Maven2 Repository</name>
<url>http://maven.aksw.org/repository/internal</url>
</repository>
<repository>
<id>maven.aksw.snapshots</id>
<name>University Leipzig, AKSW Maven2 Repository</name>
<url>http://maven.aksw.org/repository/snapshots</url>
</repository>
</repositories>
<pluginRepositories>
<pluginRepository>
146 changes: 146 additions & 0 deletions src/main/java/org/aksw/word2vecrestful/subset/DataSubsetGenerator.java
@@ -0,0 +1,146 @@
package org.aksw.word2vecrestful.subset;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;

import org.aksw.word2vecrestful.utils.Cfg;
import org.aksw.word2vecrestful.word2vec.Word2VecFactory;
import org.aksw.word2vecrestful.word2vec.Word2VecModel;
import org.apache.commons.io.output.FileWriterWithEncoding;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectReader;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;

/**
* Class to help generate and persist subsets of a word2vec model. Expected
* JSON config format:
*
* <pre>
*{
* "data" : [ {
* "key" : "xyz",
* "centroid" : [...],
* "sd" : [...]
* }...]
*}
* </pre>
* <dl>
* <dt>key</dt>
* <dd>identification literal for the subset</dd>
* <dt>centroid</dt>
* <dd>centroid array for the subset</dd>
* <dt>sd</dt>
* <dd>standard deviation array for the subset</dd>
* </dl>
*
* @author Nikit
*
*/
public class DataSubsetGenerator {

public static final String DATA_LABEL = "data";
public static final String KEY_LABEL = "key";
public static final String CENTROID_LABEL = "centroid";
public static final String SD_LABEL = "sd";
public static final ObjectMapper OBJ_MAPPER = new ObjectMapper();
public static final ObjectReader OBJ_READER = OBJ_MAPPER.reader();

/**
* Method to generate subset files (one word per line) for a given JSON
* configuration and word2vec model
*
* @param subsetConfig
* - configuration json file
* @param outputFileDir
* - output directory for the subset files
* @param word2vec
* - word2vec model map
* @param vectorSize
* - size of the vectors in model
* @throws JsonProcessingException
* @throws FileNotFoundException
* @throws IOException
*/
public static void generateSubsetFiles(File subsetConfig, String outputFileDir, Map<String, float[]> word2vec,
int vectorSize) throws JsonProcessingException, FileNotFoundException, IOException {
// Read file into a json
ObjectNode inpObj = (ObjectNode) OBJ_READER.readTree(new FileInputStream(subsetConfig));
ArrayNode inpDt = (ArrayNode) inpObj.get(DATA_LABEL);
// Traverse the json for keys
Iterator<JsonNode> inpIt = inpDt.iterator();
float[] maxlim = new float[vectorSize];
float[] minlim = new float[vectorSize];
while (inpIt.hasNext()) {

JsonNode curNode = inpIt.next();
// fetch value of key
String key = curNode.get(KEY_LABEL).asText();
// fetch value of centroid
ArrayNode centroid = (ArrayNode) curNode.get(CENTROID_LABEL);
// fetch value of standard deviation
ArrayNode stndrdDev = (ArrayNode) curNode.get(SD_LABEL);
// create an output file
File outputFile = new File(outputFileDir + "/" + key + ".txt");
outputFile.getParentFile().mkdirs();
// open an output stream
BufferedWriter bWriter = new BufferedWriter(new FileWriterWithEncoding(outputFile, StandardCharsets.UTF_8));
// compute the acceptance range per dimension: centroid +/- 3 * standard deviation
for (int i = 0; i < centroid.size(); i++) {
float centVal = centroid.get(i).floatValue();
float sdVal = stndrdDev.get(i).floatValue();
// maxlim = centroid + 3 * sd
maxlim[i] = centVal + 3 * sdVal;
// minlim = centroid - 3 * sd
minlim[i] = centVal - 3 * sdVal;
}
// loop through the model
for (Entry<String, float[]> wordEntry : word2vec.entrySet()) {
String word = wordEntry.getKey();
float[] wordvec = wordEntry.getValue();
boolean isValid = true;
// check that every dimension lies within [minlim, maxlim]
for (int i = 0; i < centroid.size(); i++) {
float curVal = wordvec[i];
if (curVal > maxlim[i] || curVal < minlim[i]) {
isValid = false;
break;
}
}
if (isValid) {
// write the word in the file
bWriter.write(word);
bWriter.newLine();
}
}
// Close the stream
bWriter.close();
}
}

/**
* Method to demonstrate example usage
*
* @param args
* @throws JsonProcessingException
* @throws FileNotFoundException
* @throws IOException
*/
public static void main(String[] args) throws JsonProcessingException, FileNotFoundException, IOException {
File subsetConfig = new File("word2vec-dump\\subsetconfig2.json");
Word2VecModel model = Word2VecFactory.getNormalBinModel();
generateSubsetFiles(subsetConfig, Cfg.get("org.aksw.word2vecrestful.Application.subsetfiledir"), model.word2vec,
model.vectorSize);
}
}
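DataSubsetGenerator keeps a word in a subset only if every dimension of its vector lies inside centroid ± 3·sd for that subset, and writes the surviving words to <outputFileDir>/<key>.txt. A hedged usage sketch with a tiny in-memory stand-in model follows; the config contents, words, vectors, and output directory are illustrative, not taken from the repository:

```java
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

import org.aksw.word2vecrestful.subset.DataSubsetGenerator;
import org.apache.commons.io.FileUtils;

public class DataSubsetGeneratorExample {
    public static void main(String[] args) throws Exception {
        // illustrative config with a single subset "animals": centroid and sd per dimension
        File config = new File("example-subsetconfig.json");
        FileUtils.write(config,
                "{ \"data\" : [ { \"key\" : \"animals\", \"centroid\" : [0.1, 0.2, 0.0], \"sd\" : [0.05, 0.05, 0.05] } ] }",
                StandardCharsets.UTF_8);

        // tiny stand-in model: word -> 3-dimensional vector
        Map<String, float[]> word2vec = new HashMap<>();
        word2vec.put("cat", new float[] { 0.11f, 0.19f, 0.01f });      // inside centroid +/- 3*sd, kept
        word2vec.put("carburetor", new float[] { 0.9f, -0.7f, 0.5f }); // outside the range, filtered out

        // writes example-subsets/animals.txt containing the surviving words
        DataSubsetGenerator.generateSubsetFiles(config, "example-subsets", word2vec, 3);
    }
}
```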
65 changes: 65 additions & 0 deletions src/main/java/org/aksw/word2vecrestful/subset/DataSubsetProvider.java
@@ -0,0 +1,65 @@
package org.aksw.word2vecrestful.subset;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import org.aksw.word2vecrestful.utils.Cfg;
import org.apache.commons.io.FileUtils;

import com.fasterxml.jackson.core.JsonProcessingException;

/**
* Class to help retrieve the set of words from a subset stored on disk
*
* @author Nikit
*
*/
public class DataSubsetProvider {

private String fileDir = Cfg.get("org.aksw.word2vecrestful.Application.subsetfiledir");
private final Map<String, Set<String>> SUBSET_MODELS = new HashMap<>();

/**
* Method to fetch the set of words in a subset
*
* @param subsetKey
* - key to identify the subset
* @return - the set of words in the related subset, or null if no such subset file exists
* @throws JsonProcessingException
* @throws FileNotFoundException
* @throws IOException
*/
public Set<String> fetchSubsetWords(String subsetKey) throws IOException {
// fetch from cache
Set<String> resList = SUBSET_MODELS.get(subsetKey);
// if not in cache then read from file and add to cache
if (resList == null) {
// logic to fetch the words from the stored subsets
File file1 = new File(fileDir + "/" + appendFileExtension(subsetKey));
if (file1.exists()) {
resList = new HashSet<>();
resList.addAll(FileUtils.readLines(file1, StandardCharsets.UTF_8));
SUBSET_MODELS.put(subsetKey, resList);
}
}
return resList;
}

/**
* Method to append txt extension at the end of a key
*
* @param name
* - key
* @return key appended with txt extension
*/
public static String appendFileExtension(String name) {
return name + ".txt";
}

}
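DataSubsetProvider reads a subset's word file once and then serves later lookups from its in-memory cache, returning null when no file exists for the key. A short usage sketch, assuming the subset files produced by the generator above sit in the configured subsetfiledir; the key "animals" is illustrative:

```java
import java.io.IOException;
import java.util.Set;

import org.aksw.word2vecrestful.subset.DataSubsetProvider;

public class DataSubsetProviderExample {
    public static void main(String[] args) throws IOException {
        DataSubsetProvider provider = new DataSubsetProvider();
        // first call reads <subsetfiledir>/animals.txt; subsequent calls hit the cache
        Set<String> words = provider.fetchSubsetWords("animals");
        if (words == null) {
            System.out.println("No subset file found for key 'animals'");
        } else {
            System.out.println("Subset 'animals' contains " + words.size() + " words");
        }
    }
}
```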
