From 36ef248137015761dc3cf2c600116fcece911bf6 Mon Sep 17 00:00:00 2001 From: Artur Koperkiewicz Date: Mon, 4 Jun 2018 01:48:52 +0200 Subject: [PATCH] SRBM : VIS : Fixed visualization snapshoot id for concurrent computations. Batch index in visualization file thread safe added. --- src/main/java/org/wit/snr/nn/srbm/SRBM.java | 8 ++++---- src/main/java/org/wit/snr/nn/srbm/SRBMMapReduceJSA.java | 7 ++++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/wit/snr/nn/srbm/SRBM.java b/src/main/java/org/wit/snr/nn/srbm/SRBM.java index 1324afb..1519886 100644 --- a/src/main/java/org/wit/snr/nn/srbm/SRBM.java +++ b/src/main/java/org/wit/snr/nn/srbm/SRBM.java @@ -83,9 +83,9 @@ private void initCanvas() { canvas.createBufferStrategy(3); } - void draw(Matrix W, Matrix X, Matrix negM, Matrix neghidprobs, Matrix vbias) { + void draw(int batchIndex, Matrix W, Matrix X, Matrix negM, Matrix neghidprobs, Matrix vbias) { displayVisualizationOnScreen(W, X, negM, neghidprobs, vbias); - saveVisualizationToFile(W, X, negM, neghidprobs, vbias); + saveVisualizationToFile(batchIndex, W, X, negM, neghidprobs, vbias); } private void displayVisualizationOnScreen(Matrix W, Matrix X, Matrix negM, Matrix neghidprobs, Matrix vbias) { @@ -112,7 +112,7 @@ private void renderVisualizationOnGraphicsComponent(Matrix W, Matrix X, Matrix n vbiasDraw.render(); } - public void saveVisualizationToFile(Matrix W, Matrix X, Matrix negM, Matrix neghidprobs, Matrix vbias) { + public void saveVisualizationToFile(int batchIndex, Matrix W, Matrix X, Matrix negM, Matrix neghidprobs, Matrix vbias) { try { BufferedImage image = new BufferedImage(canvas.getWidth(), canvas.getHeight(), BufferedImage.TYPE_INT_RGB); Graphics2D graphics = image.createGraphics(); @@ -122,7 +122,7 @@ public void saveVisualizationToFile(Matrix W, Matrix X, Matrix negM, Matrix negh + File.separatorChar + sessionId + "-" + currentEpoch.get() - + "-" + miniBatchIndex.get() + + "-" + batchIndex + ".jpg"; ImageIO.write(image, "JPEG", new File(pathname)); } catch (Exception e) { diff --git a/src/main/java/org/wit/snr/nn/srbm/SRBMMapReduceJSA.java b/src/main/java/org/wit/snr/nn/srbm/SRBMMapReduceJSA.java index b6f6503..7faa6d9 100644 --- a/src/main/java/org/wit/snr/nn/srbm/SRBMMapReduceJSA.java +++ b/src/main/java/org/wit/snr/nn/srbm/SRBMMapReduceJSA.java @@ -51,7 +51,7 @@ private void updateLayerData(MiniBatchTrainingResult reduce) { private MiniBatchTrainingResult trainMiniBatch(Matrix X) { - + final int batchIndex = miniBatchIndex.getAndIncrement(); timer.set(new Timer()); timer.get().start(); Matrix poshidprobs = getHidProbs(X); @@ -62,8 +62,9 @@ private MiniBatchTrainingResult trainMiniBatch(Matrix X) { Matrix vBiasDelta = updateVBias(X, negdata); updateError(X, negdata); Matrix hBiasDelta = updateHBias(X); - System.out.printf("E %s/%s | %s | %s %n", miniBatchIndex.getAndIncrement() * cfg.batchSize, currentEpoch, layer.error, timer.get().toString()); - draw(layer.W, X, negdata, layer.hbias.reshape(50), layer.vbias); + + System.out.printf("E %s/%s | %s | %s %n", batchIndex * cfg.batchSize, currentEpoch, layer.error, timer.get().toString()); + draw(batchIndex, layer.W, X, negdata, layer.hbias.reshape(50), layer.vbias); // timer.get().reset(); timer.remove(); return new MiniBatchTrainingResult(Wdelta, vBiasDelta, hBiasDelta);