Skip to content

Commit

Permalink
SRBM : VIS : Fixed visualization snapshoot id for concurrent computat…
Browse files Browse the repository at this point in the history
…ions.

Batch index in visualization file thread safe added.
  • Loading branch information
operixon committed Jun 3, 2018
1 parent c6df88c commit 36ef248
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
8 changes: 4 additions & 4 deletions src/main/java/org/wit/snr/nn/srbm/SRBM.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ private void initCanvas() {
canvas.createBufferStrategy(3);
}

void draw(Matrix W, Matrix X, Matrix negM, Matrix neghidprobs, Matrix vbias) {
void draw(int batchIndex, Matrix W, Matrix X, Matrix negM, Matrix neghidprobs, Matrix vbias) {
displayVisualizationOnScreen(W, X, negM, neghidprobs, vbias);
saveVisualizationToFile(W, X, negM, neghidprobs, vbias);
saveVisualizationToFile(batchIndex, W, X, negM, neghidprobs, vbias);
}

private void displayVisualizationOnScreen(Matrix W, Matrix X, Matrix negM, Matrix neghidprobs, Matrix vbias) {
Expand All @@ -112,7 +112,7 @@ private void renderVisualizationOnGraphicsComponent(Matrix W, Matrix X, Matrix n
vbiasDraw.render();
}

public void saveVisualizationToFile(Matrix W, Matrix X, Matrix negM, Matrix neghidprobs, Matrix vbias) {
public void saveVisualizationToFile(int batchIndex, Matrix W, Matrix X, Matrix negM, Matrix neghidprobs, Matrix vbias) {
try {
BufferedImage image = new BufferedImage(canvas.getWidth(), canvas.getHeight(), BufferedImage.TYPE_INT_RGB);
Graphics2D graphics = image.createGraphics();
Expand All @@ -122,7 +122,7 @@ public void saveVisualizationToFile(Matrix W, Matrix X, Matrix negM, Matrix negh
+ File.separatorChar
+ sessionId
+ "-" + currentEpoch.get()
+ "-" + miniBatchIndex.get()
+ "-" + batchIndex
+ ".jpg";
ImageIO.write(image, "JPEG", new File(pathname));
} catch (Exception e) {
Expand Down
7 changes: 4 additions & 3 deletions src/main/java/org/wit/snr/nn/srbm/SRBMMapReduceJSA.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ private void updateLayerData(MiniBatchTrainingResult reduce) {


private MiniBatchTrainingResult trainMiniBatch(Matrix X) {

final int batchIndex = miniBatchIndex.getAndIncrement();
timer.set(new Timer());
timer.get().start();
Matrix poshidprobs = getHidProbs(X);
Expand All @@ -62,8 +62,9 @@ private MiniBatchTrainingResult trainMiniBatch(Matrix X) {
Matrix vBiasDelta = updateVBias(X, negdata);
updateError(X, negdata);
Matrix hBiasDelta = updateHBias(X);
System.out.printf("E %s/%s | %s | %s %n", miniBatchIndex.getAndIncrement() * cfg.batchSize, currentEpoch, layer.error, timer.get().toString());
draw(layer.W, X, negdata, layer.hbias.reshape(50), layer.vbias);

System.out.printf("E %s/%s | %s | %s %n", batchIndex * cfg.batchSize, currentEpoch, layer.error, timer.get().toString());
draw(batchIndex, layer.W, X, negdata, layer.hbias.reshape(50), layer.vbias);
// timer.get().reset();
timer.remove();
return new MiniBatchTrainingResult(Wdelta, vBiasDelta, hBiasDelta);
Expand Down

0 comments on commit 36ef248

Please sign in to comment.