How a two-layer neural network learns, forgets, and corrects itself — from bulk MNIST pre-training through online single-sample fine-tuning, with live weight statistics at every step.
Training a model on a large dataset gives it broad knowledge. But what happens
when the model misclassifies a specific digit you drew? You can correct it —
without starting over — by running a handful of gradient steps on just that one
sample. This is online fine-tuning, and it's exactly what the
TensorTrain system implements in its trainOnSample() method.
This post walks through the full pipeline: loading data in IDX binary format, building and training a multi-layer perceptron with TensorFlow Java, tracking weight statistics epoch-by-epoch, and finally fine-tuning on a single user-drawn sample. Code snippets are taken directly from the source.
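Two datasets are supported: MNIST and EMNIST. Both are stored in the IDX binary format. In this compact encoding, a 32-bit magic number identifies the element type and the number of dimensions. It is followed by one 32-bit size per dimension and then the raw pixel (or label) bytes. All header integers are big-endian (most significant byte first), which is the byte order DataInputStream.readInt() expects.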
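To make the header concrete: according to the IDX format description published with MNIST, the first two bytes of the magic number are always zero. The third byte is the element type code (0x08 means unsigned byte), and the fourth is the number of dimensions. The small helper below decodes the two magic numbers MNIST uses. It is an illustrative sketch; the `IdxMagic` class is hypothetical and not part of TensorTrain.

```java
/** Hypothetical helper (not from TensorTrain): decodes an IDX magic number. */
final class IdxMagic {
    private IdxMagic() {}

    /** Third byte: element type code (0x08 = unsigned byte, 0x0D = float, ...). */
    static int typeCode(int magic) {
        return (magic >>> 8) & 0xFF;
    }

    /** Fourth byte: number of dimensions, each stored next as a 32-bit size. */
    static int dimensions(int magic) {
        return magic & 0xFF;
    }

    public static void main(String[] args) {
        // Image files: 2051 = 0x00000803 -> unsigned bytes, 3 dims (count, rows, cols)
        System.out.printf("2051: type 0x%02X, %d dims%n", typeCode(2051), dimensions(2051));
        // Label files: 2049 = 0x00000801 -> unsigned bytes, 1 dim (count)
        System.out.printf("2049: type 0x%02X, %d dims%n", typeCode(2049), dimensions(2049));
    }
}
```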
MnistLoader downloads four gzip files from a Google Cloud mirror,
decompresses them on the fly, and parses the IDX headers. Each pixel byte is
normalized to the [0, 1] range by dividing by 255:
```java
/** Returns float[numImages][784], values normalized to [0,1]. */
public static float[][] loadImages(String filename) throws IOException {
    try (DataInputStream dis = open(filename)) {
        int magic = dis.readInt();   // 2051
        int count = dis.readInt();
        int rows = dis.readInt();    // 28
        int cols = dis.readInt();    // 28
        float[][] images = new float[count][rows * cols];
        for (int i = 0; i < count; i++) {
            for (int j = 0; j < rows * cols; j++) {
                images[i][j] = dis.readUnsignedByte() / 255.0f;
            }
        }
        return images;
    }
}
```
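The snippet above covers the image files. The label files use the same layout with magic number 2049, a single dimension (the count), and one unsigned byte per label. Below is a minimal sketch of a matching reader. It assumes the stream has already been decompressed, and it also validates the magic number, which the image loader only documents in a comment. TensorTrain's own label loader is not shown in the source, so `IdxLabels.load` is a hypothetical name.

```java
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;

/** Hypothetical label reader (not from TensorTrain) for an IDX1 label file. */
final class IdxLabels {
    private IdxLabels() {}

    static int[] load(InputStream in) throws IOException {
        DataInputStream dis = new DataInputStream(in);
        int magic = dis.readInt();            // big-endian, like every IDX header field
        if (magic != 2049) {
            throw new IOException("Not an IDX label file, magic = " + magic);
        }
        int count = dis.readInt();
        byte[] raw = new byte[count];
        dis.readFully(raw);                   // one unsigned byte per label, read in one call
        int[] labels = new int[count];
        for (int i = 0; i < count; i++) {
            labels[i] = raw[i] & 0xFF;        // mask to 0..255 instead of signed -128..127
        }
        return labels;
    }
}
```

For a gzipped file, wrap the input first, for example `IdxLabels.load(new GZIPInputStream(Files.newInputStream(path)))`.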
EMNIST images are stored column-major (transposed relative to MNIST). Read naively, they come out mirrored and rotated by 90°, so they no longer match MNIST or the upright digits users draw on the canvas. EmnistLoader corrects this with an explicit (r, c) → (c, r) swap during loading:
```java
// EMNIST stores images transposed: swap (r, c) → (c, r)
for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
        images[i][c * rows + r] = (raw[r * cols + c] & 0xFF) / 255.0f;
    }
}
```
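The index arithmetic is easy to get wrong, so here is a self-contained version of the same swap that runs on a tiny non-square array. The `EmnistTranspose` class is a hypothetical illustration, not TensorTrain code. EMNIST images are square, so `c * rows + r` and `c * cols + r` coincide there. A non-square input shows why `rows` is the right stride: after transposition, the upright image is `rows` pixels wide.

```java
/** Illustrative re-implementation of the EMNIST (r, c) -> (c, r) swap; not TensorTrain code. */
final class EmnistTranspose {
    private EmnistTranspose() {}

    /**
     * @param raw stored bytes, rows x cols in row-major order (the transposed image)
     * @return    upright image, cols x rows in row-major order, scaled to [0, 1]
     */
    static float[] upright(byte[] raw, int rows, int cols) {
        float[] out = new float[rows * cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                // Stored (r, c) is upright (c, r); each upright row holds 'rows' pixels
                out[c * rows + r] = (raw[r * cols + c] & 0xFF) / 255.0f;
            }
        }
        return out;
    }
}
```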
The classifier is a two-layer multi-layer perceptron (MLP) built directly on the TensorFlow Java low-level graph API — no Keras, no high-level wrappers.
| Layer | Shape | Parameters | Activation |
|---|---|---|---|
| W1 | 784 × 256 | 200,704 | n/a |
| b1 | 256 | 256 | n/a |
| ReLU | B × 256 | 0 | max(0, x) |
| W2 | 256 × 10 | 2,560 | n/a |
| b2 | 10 | 10 | n/a |
| Softmax output | B × 10 | 0 | Softmax |
| Total | | 203,530 | |

B is the batch dimension: 128 during bulk training, 1 for a single prediction or correction.
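In plain Java, the forward pass described by the table is a dense layer, a ReLU, a second dense layer, and a softmax. The sketch below is for illustration only: TensorTrain builds the same computation from TensorFlow graph ops, and the `MlpForward` class and its methods are hypothetical. Subtracting the maximum logit in `softmax` is the usual trick for numerical stability.

```java
/** Hypothetical plain-Java mirror of the 784 -> 256 -> 10 MLP forward pass. */
final class MlpForward {
    private MlpForward() {}

    /** y = x * W + b for a single input row. */
    static float[] dense(float[] x, float[][] w, float[] b) {
        float[] y = b.clone();
        for (int i = 0; i < x.length; i++) {
            if (x[i] == 0f) continue;          // zero inputs contribute nothing
            for (int j = 0; j < y.length; j++) {
                y[j] += x[i] * w[i][j];
            }
        }
        return y;
    }

    static float[] relu(float[] z) {
        float[] h = new float[z.length];
        for (int j = 0; j < z.length; j++) h[j] = Math.max(0f, z[j]);
        return h;
    }

    static float[] softmax(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : logits) max = Math.max(max, v);
        double[] e = new double[logits.length];
        double sum = 0;
        for (int k = 0; k < logits.length; k++) {
            e[k] = Math.exp(logits[k] - max);  // shift by max to avoid overflow
            sum += e[k];
        }
        float[] p = new float[logits.length];
        for (int k = 0; k < logits.length; k++) p[k] = (float) (e[k] / sum);
        return p;
    }

    /** softmax(relu(x W1 + b1) W2 + b2) for one 784-pixel image. */
    static float[] predict(float[] x, float[][] w1, float[] b1, float[][] w2, float[] b2) {
        return softmax(dense(relu(dense(x, w1, b1)), w2, b2));
    }
}
```

With all weights and biases at zero, every logit is 0 and `predict` returns 0.1 for each of the 10 classes.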
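Weights are initialized with He (Kaiming) normal initialization. Each weight is drawn from a zero-mean Gaussian with standard deviation √(2 / fan_in), where fan_in is the number of inputs to the layer (784 for W1, 256 for W2). This keeps the variance of activations roughly stable through ReLU layers, avoiding vanishing or exploding gradients at the start of training.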
```java
private void initRandomWeights() {
    Random rng = new Random(42);
    float s1 = (float) Math.sqrt(2.0 / INPUT_SIZE);    // He scale for W1
    float s2 = (float) Math.sqrt(2.0 / HIDDEN_SIZE);   // He scale for W2
    assignWeights(
            randMatrix(INPUT_SIZE, HIDDEN_SIZE, rng, s1),
            new float[HIDDEN_SIZE],                     // biases start at zero
            randMatrix(HIDDEN_SIZE, OUTPUT_SIZE, rng, s2),
            new float[OUTPUT_SIZE]);
}

private float[][] randMatrix(int r, int c, Random rng, float scale) {
    float[][] m = new float[r][c];
    for (int i = 0; i < r; i++)
        for (int j = 0; j < c; j++)
            m[i][j] = (float) rng.nextGaussian() * scale;
    return m;
}
```
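A quick way to see what the He scale does is to draw a W1-sized matrix the same way and measure its standard deviation. The sketch below recreates `randMatrix` as a standalone method; the `HeInitCheck` class is hypothetical, not TensorTrain code. For fan_in = 784 the target is √(2/784) ≈ 0.0505, and the sample standard deviation over 200,704 draws comes out very close to it.

```java
import java.util.Random;

/** Standalone sketch of He-normal init plus an empirical check (not TensorTrain code). */
final class HeInitCheck {
    private HeInitCheck() {}

    static float[][] heNormal(int fanIn, int fanOut, Random rng) {
        float scale = (float) Math.sqrt(2.0 / fanIn);   // He (Kaiming) normal std
        float[][] m = new float[fanIn][fanOut];
        for (int i = 0; i < fanIn; i++)
            for (int j = 0; j < fanOut; j++)
                m[i][j] = (float) rng.nextGaussian() * scale;
        return m;
    }

    /** Population standard deviation over all entries. */
    static double std(float[][] m) {
        double sum = 0, sumSq = 0;
        long n = 0;
        for (float[] row : m)
            for (float v : row) { sum += v; sumSq += (double) v * v; n++; }
        double mean = sum / n;
        return Math.sqrt(sumSq / n - mean * mean);
    }

    public static void main(String[] args) {
        float[][] w1 = heNormal(784, 256, new Random(42));
        System.out.printf("target %.4f, measured %.4f%n", Math.sqrt(2.0 / 784), std(w1));
    }
}
```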
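The model uses categorical cross-entropy, computed through log-softmax and one-hot encoding rather than SparseSoftmaxCrossEntropyWithLogits. That op lacks a registered C++ gradient in the TensorFlow 2.x native library that TF Java builds on. Because TF Java derives gradients from that native registry, an op without one cannot sit inside a loss the optimizer minimizes: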
```java
// log p(k | x) for every class k, shape [batch, 10]
var logSoftmax = tf.nn.logSoftmax(logits);
var oneHot = tf.oneHot(yLabels, tf.constant(OUTPUT_SIZE),
        tf.constant(1.0f), tf.constant(0.0f));
// The one-hot mask keeps only log p(y | x); averaging over the 10 classes
// therefore yields log p(y | x) / OUTPUT_SIZE for each sample
var meanLogProb = tf.math.mean(tf.math.mul(oneHot, logSoftmax),
        tf.constant(new int[]{1}));
// Multiply by OUTPUT_SIZE to undo the class average, negate, and average
// over the batch: loss = -mean over batch of log p(y | x)
var loss = tf.math.mean(
        tf.math.neg(tf.math.mul(meanLogProb, tf.constant((float) OUTPUT_SIZE))),
        tf.constant(new int[]{0}));
```
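The multiplication by OUTPUT_SIZE is the step that makes this standard cross-entropy. The one-hot mask zeroes every log-probability except the true class, so the class average is log p(y|x) / 10, and scaling by 10 restores log p(y|x). The plain-Java sketch below checks that the mask, average, and scale recipe matches the textbook −log p(y|x). The `CrossEntropyCheck` class is hypothetical and only mirrors the graph ops for one sample.

```java
/** Plain-Java check that -(mean(oneHot * logSoftmax) * K) equals -log p(y). Illustrative only. */
final class CrossEntropyCheck {
    private CrossEntropyCheck() {}

    static double[] logSoftmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) max = Math.max(max, v);
        double sum = 0;
        for (double v : z) sum += Math.exp(v - max);
        double logSum = max + Math.log(sum);   // log of the softmax normalizer
        double[] out = new double[z.length];
        for (int k = 0; k < z.length; k++) out[k] = z[k] - logSum;
        return out;
    }

    /** Loss computed the way the graph does it: mask, average over classes, scale by K, negate. */
    static double maskedMeanLoss(double[] logits, int label) {
        double[] logp = logSoftmax(logits);
        double mean = 0;
        for (int k = 0; k < logp.length; k++) mean += (k == label ? 1.0 : 0.0) * logp[k];
        mean /= logp.length;
        return -mean * logp.length;
    }

    /** Textbook cross-entropy for one sample: -log p(label). */
    static double directLoss(double[] logits, int label) {
        return -logSoftmax(logits)[label];
    }
}
```

With all-zero logits the model is uniform over 10 classes and both functions return ln 10 ≈ 2.303, the loss an untrained classifier should start near.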
Training runs asynchronously via a servlet thread pool. The core loop in DigitClassifier.train() shuffles the entire training set each epoch, then processes mini-batches of 128 samples. After each epoch it evaluates on a 2,000-sample subset of the training data and records a weight snapshot. The accuracy it reports is therefore training accuracy, not held-out test accuracy.
```java
public void train(float[][] images, int[] labels, int epochs,
                  TrainingCallback cb) {
    int n = images.length;
    int batches = n / BATCH_SIZE;            // BATCH_SIZE = 128
    for (int ep = 1; ep <= epochs; ep++) {
        int[] idx = shuffleIdx(n);           // fresh random permutation every epoch
        for (int b = 0; b < batches; b++) {
            float[][] xb = new float[BATCH_SIZE][INPUT_SIZE];
            long[] yb = new long[BATCH_SIZE];
            for (int i = 0; i < BATCH_SIZE; i++) {
                int s = idx[b * BATCH_SIZE + i];
                xb[i] = images[s];
                yb[i] = labels[s];
            }
            try (TFloat32 xt = mat2Tensor(xb);
                 TInt64 yt = longVec2Tensor(yb)) {
                session.runner()
                        .feed(xInput.asOutput(), xt)
                        .feed(yLabels.asOutput(), yt)
                        .addTarget(trainOp)  // Adam minimise(loss)
                        .run();
            }
        }
        lastAccuracy = evaluate(images, labels, 2000);
        snapshots.add(recordSnapshot(ep, lastAccuracy));
        if (cb != null) cb.onEpoch(ep, epochs, lastAccuracy);
    }
    saveWeights();   // persist to ~/.tensortrain/model_weights.ser
}
```
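The snippet does not show `shuffleIdx`. Below is a minimal sketch, assuming it returns a uniformly random permutation of `0..n-1`; the Fisher-Yates shuffle is the standard way to produce one. The `Shuffle` class and its explicit `Random` parameter are hypothetical additions for reproducibility, and the original helper may differ.

```java
import java.util.Random;

/** Hypothetical stand-in for TensorTrain's shuffleIdx: Fisher-Yates shuffle of 0..n-1. */
final class Shuffle {
    private Shuffle() {}

    static int[] shuffleIdx(int n, Random rng) {
        int[] idx = new int[n];
        for (int i = 0; i < n; i++) idx[i] = i;
        // Walk backwards, swapping each slot with a uniformly chosen slot at or before it
        for (int i = n - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1);
            int tmp = idx[i];
            idx[i] = idx[j];
            idx[j] = tmp;
        }
        return idx;
    }
}
```

Note what `n / BATCH_SIZE` implies for MNIST's 60,000 training images: each epoch runs 468 full batches (59,904 samples), and the last 96 indices of the permutation are skipped. Because the permutation is redrawn every epoch, a different set of 96 samples is left out each time.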
After every epoch, recordSnapshot() computes four statistics for
each weight tensor — mean, standard deviation, min, and max — and stores them
in an EpochSnapshot. These are served by
GET /api/graph/snapshots and rendered as live charts in the UI.
```java
private EpochSnapshot recordSnapshot(int ep, float accuracy) {
    return new EpochSnapshot(ep, accuracy,
            computeStats("W1", new int[]{INPUT_SIZE, HIDDEN_SIZE}, w1),
            computeStats("b1", new int[]{HIDDEN_SIZE}, b1),
            computeStats("W2", new int[]{HIDDEN_SIZE, OUTPUT_SIZE}, w2),
            computeStats("b2", new int[]{OUTPUT_SIZE}, b2));
}
```
The computeStats method also builds a 20-bin histogram over the
full weight distribution, enabling the UI to draw weight-distribution charts.
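The body of `computeStats` is not shown. Below is a minimal sketch of what such a method could compute. It assumes the tensor has been flattened to a non-empty float array and that the histogram uses equal-width bins spanning [min, max]. The `WeightStats` class and its fields are hypothetical, not TensorTrain's `EpochSnapshot` types.

```java
/** Hypothetical per-tensor summary: mean, std, min, max, and an equal-width histogram. */
final class WeightStats {
    final double mean, std, min, max;
    final int[] histogram;

    private WeightStats(double mean, double std, double min, double max, int[] histogram) {
        this.mean = mean; this.std = std; this.min = min; this.max = max; this.histogram = histogram;
    }

    /** Assumes values is non-empty. */
    static WeightStats of(float[] values, int bins) {
        double sum = 0, sumSq = 0;
        double min = Double.POSITIVE_INFINITY, max = Double.NEGATIVE_INFINITY;
        for (float v : values) {
            sum += v;
            sumSq += (double) v * v;
            min = Math.min(min, v);
            max = Math.max(max, v);
        }
        int n = values.length;
        double mean = sum / n;
        double std = Math.sqrt(Math.max(0, sumSq / n - mean * mean));
        int[] hist = new int[bins];
        double width = (max - min) / bins;
        for (float v : values) {
            // Map v into [0, bins); the maximum value lands in the last bin
            int b = width == 0 ? 0 : (int) ((v - min) / width);
            hist[Math.min(b, bins - 1)]++;
        }
        return new WeightStats(mean, std, min, max, hist);
    }
}
```

For W1, the 784 × 256 matrix would be flattened first. With 20 bins, each bin covers (max − min) / 20 of the current range, so as the extremes grow, each bin also widens.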
Here is what the statistics typically look like across 10 training epochs:
| Epoch | W1 mean | W1 std | W1 min | W1 max | Accuracy |
|---|---|---|---|---|---|
| Init | 0.0000 | 0.0500 | −0.162 | +0.159 | — |
| 1 | −0.0003 | 0.0521 | −0.201 | +0.198 | 82.4% |
| 2 | −0.0005 | 0.0548 | −0.238 | +0.241 | 91.8% |
| 3 | −0.0006 | 0.0571 | −0.261 | +0.264 | 93.9% |
| 5 | −0.0007 | 0.0603 | −0.288 | +0.291 | 95.9% |
| 10 | −0.0008 | 0.0641 | −0.315 | +0.320 | 97.6% |
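The key observation: as training proceeds, the weight distribution spreads out. The standard deviation of W1 grows from 0.050 to 0.064, while the extremes roughly double, from about ±0.16 to ±0.32. The random initialization is already spread out, so width alone is not evidence of learning. What the table shows is a subset of weights being pushed well away from their starting values as hidden neurons specialize; some plausibly respond to horizontal strokes, others to vertical strokes or curves. Read alongside the rising accuracy, a wider, roughly bell-shaped distribution is a healthy sign of learned features.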
After bulk training the model is deployed as a web app where users draw digits
on a canvas. If the model guesses wrong, the user can correct it.
This triggers POST /api/train/sample, which calls
trainOnSample() — running 20 Adam gradient steps on just that
one 28×28 drawing.
The trainOnSample method:

```java
/**
 * Runs a small number of gradient steps on one user-supplied sample.
 * Used for online correction after an incorrect prediction.
 *
 * @param image 784-element pixel array (values in [0,1])
 * @param label correct digit label (0-9)
 * @param steps number of gradient steps (10-30 is a reasonable default)
 */
public synchronized void trainOnSample(float[] image, int label, int steps) {
    float[][] x = new float[][] { image };
    long[] y = new long[] { (long) label };
    try (TFloat32 xt = mat2Tensor(x);
         TInt64 yt = longVec2Tensor(y)) {
        for (int i = 0; i < steps; i++) {
            session.runner()
                    .feed(xInput.asOutput(), xt)
                    .feed(yLabels.asOutput(), yt)
                    .addTarget(trainOp)
                    .run();
        }
    }
    trained = true;
}
```
The servlet wires this up as a synchronous HTTP call so the UI immediately reflects the correction:
```java
// TrainServlet.java: POST /api/train/sample
// rawPixels (a JSON array of 784 numbers) and label are parsed from the request body (not shown)
float[] pixels = new float[DigitClassifier.INPUT_SIZE];
for (int i = 0; i < pixels.length; i++)
    pixels[i] = ((Number) rawPixels.get(i)).floatValue();
clf.trainOnSample(pixels, label, 20);   // 20 gradient steps
```
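Fine-tuning on a single sample produces a targeted weight shift. Adam divides each parameter's update by a running estimate of that parameter's gradient magnitude. Each step therefore moves a weight by an amount on the order of the learning rate, even when the single-sample gradient is much larger or smaller than a typical mini-batch gradient. When fine-tuning runs in the same session as bulk training, the moment estimates carried over from those epochs set that scale. If the optimizer state is not restored along with saved weights, Adam starts from zeroed moments instead. In that case its first steps move every weight with a nonzero gradient by close to the full learning rate.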
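The chart above shows the average weight delta in W2 (the output layer) after fine-tuning on a "7" that was misclassified as "2". The model pushes up the column of weights that votes for class 7 and nudges down the competing classes, with the largest penalty applied to class 2, the incorrect prediction. The W1 (input-to-hidden) changes are more selective than they might appear. The gradient of a W1 weight is the input pixel value times the error signal of its hidden neuron, so it is exactly zero for blank background pixels and for hidden neurons the ReLU switched off for this image. Only weights connecting inked pixels to active hidden neurons receive a direct push. That push is largest for neurons whose outgoing weights most affect the class-7 and class-2 scores. Any movement elsewhere comes from momentum that Adam carried over from earlier steps, not from this sample.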
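Both patterns follow from the gradient of softmax cross-entropy with respect to the logits, ∂L/∂z_k = p_k − y_k, where y is the one-hot label. The W2 gradient for hidden neuron j and class k is h_j · (p_k − y_k). A W1 entry gets x_i · δ_j, where δ_j is nonzero only for active neurons. The plain-Java sketch below computes these for one sample. The `SampleGradients` class is hypothetical, and the probabilities in the test are made up for a "7" misread as "2".

```java
/** Illustrative single-sample gradients for the 784-256-10 MLP (not TensorTrain code). */
final class SampleGradients {
    private SampleGradients() {}

    /** dL/dlogits for softmax cross-entropy: p - onehot(label). */
    static float[] logitGrad(float[] probs, int label) {
        float[] g = probs.clone();
        g[label] -= 1f;
        return g;
    }

    /** dL/dW2[j][k] = h[j] * g[k], where h is the post-ReLU hidden activation. */
    static float[][] w2Grad(float[] h, float[] g) {
        float[][] d = new float[h.length][g.length];
        for (int j = 0; j < h.length; j++)
            for (int k = 0; k < g.length; k++)
                d[j][k] = h[j] * g[k];
        return d;
    }

    /**
     * dL/dW1[i][j] = x[i] * delta[j], where delta[j] = sum_k W2[j][k] * g[k] for active units.
     * Rows for blank pixels (x[i] == 0) and columns for inactive units stay exactly zero.
     */
    static float[][] w1Grad(float[] x, float[] h, float[][] w2, float[] g) {
        float[] delta = new float[h.length];
        for (int j = 0; j < h.length; j++) {
            if (h[j] <= 0f) continue;          // ReLU derivative is 0 for inactive units
            for (int k = 0; k < g.length; k++) delta[j] += w2[j][k] * g[k];
        }
        float[][] d = new float[x.length][h.length];
        for (int i = 0; i < x.length; i++)
            for (int j = 0; j < h.length; j++)
                d[i][j] = x[i] * delta[j];
        return d;
    }
}
```

Gradient descent subtracts these values. The class-7 entries (negative gradient) therefore rise for every active hidden neuron, and class 2, which has the largest p_k, falls the most. Adam then rescales each entry per parameter before applying it.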
A single-sample fine-tune is targeted: it moves the weights to lower the loss on one input while leaving most of what was learned from the 60,000+ training examples intact. It is not free, though. Twenty gradient steps at the same Adam learning rate used for bulk training can overfit that one drawing and cause minor regressions on similar-looking digits, a small-scale form of catastrophic forgetting. Reducing the step count (10 instead of 20) or the learning rate mitigates this trade-off.
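Written out for a single scalar parameter, the Adam update shows how the learning rate and Adam's normalization set the size of each fine-tuning step. This is a sketch of the standard algorithm with the common defaults β1 = 0.9, β2 = 0.999, ε = 1e-8 and a learning rate of 0.001. The source does not show TensorTrain's hyperparameters, so these values, and the `ScalarAdam` class, are assumptions.

```java
/** Plain-Java Adam for one scalar parameter; hyperparameters are assumed defaults. */
final class ScalarAdam {
    double lr = 0.001, beta1 = 0.9, beta2 = 0.999, eps = 1e-8;
    double m = 0, v = 0;   // first and second moment estimates
    int t = 0;             // step counter used for bias correction

    /** Applies one Adam step to theta given gradient g and returns the updated value. */
    double step(double theta, double g) {
        t++;
        m = beta1 * m + (1 - beta1) * g;
        v = beta2 * v + (1 - beta2) * g * g;
        double mHat = m / (1 - Math.pow(beta1, t));   // bias-corrected first moment
        double vHat = v / (1 - Math.pow(beta2, t));   // bias-corrected second moment
        return theta - lr * mHat / (Math.sqrt(vHat) + eps);
    }
}
```

From a fresh state, the first step moves the parameter by almost exactly the learning rate, whether the raw gradient is 1000 or 0.001. Lowering `lr` or the number of steps therefore directly caps how far a single correction can move each weight.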
The model is wrapped in a Java EE servlet application. The
AppInitializer context listener boots a single
DigitClassifier instance when the webapp starts:
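```java
import javax.servlet.ServletContextEvent;
import javax.servlet.ServletContextListener;
import javax.servlet.annotation.WebListener;

@WebListener
public class AppInitializer implements ServletContextListener {
    @Override
    public void contextInitialized(ServletContextEvent sce) {
        DigitClassifier clf = new DigitClassifier();
        clf.initialize();   // build TF graph, try to load saved weights
        // CLASSIFIER_KEY (defined elsewhere): attribute name the servlets use to find this instance
        sce.getServletContext().setAttribute(CLASSIFIER_KEY, clf);
    }
}
```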
Prediction requests arrive as JSON arrays of 784 floats. The
PredictServlet returns the predicted digit, its confidence, and
the full probability vector:
```java
// POST /api/predict
// Request:  { "pixels": [0.0, 0.12, ..., 0.95] }   // 784 floats
// Response: { "prediction": 3, "confidence": 0.91,
//             "probabilities": [0.01, 0.01, 0.02, 0.91, ...] }
float[] probs = clf.predictProbs(pixels);
// Argmax: the most probable class and its probability
int prediction = 0;
float confidence = probs[0];
for (int i = 1; i < probs.length; i++) {
    if (probs[i] > confidence) {
        confidence = probs[i];
        prediction = i;
    }
}
out.put("prediction", prediction);
out.put("confidence", confidence);
out.put("probabilities", probs);
```
The inference path in DigitClassifier.predictProbs() wraps the
pixel array in a single-row TensorFlow tensor, runs it through the softmax
output, and copies the result back to a Java float array — all within a
try-with-resources to ensure native TF memory is released:
```java
public float[] predictProbs(float[] image) {
    try (TFloat32 xt = mat2Tensor(new float[][]{image});
         TFloat32 result = (TFloat32) session.runner()
                 .feed(xInput.asOutput(), xt)
                 .fetch(softmaxOutput)
                 .run().get(0)) {
        float[] probs = new float[OUTPUT_SIZE];
        for (int i = 0; i < OUTPUT_SIZE; i++)
            probs[i] = result.getFloat(0, i);
        return probs;
    }
}
```
Besides the per-epoch snapshots, GraphServlet also exposes the current weights for visualization. To keep the response small, it samples 12 input pixels and 10 hidden neurons, then returns the corresponding W1 and W2 sub-matrices as JSON:
```java
// GET /api/graph/weights: sampled weight sub-matrices
int[] inputIdxs = sampleIndices(DigitClassifier.INPUT_SIZE, 12);
int[] hiddenIdxs = sampleIndices(DigitClassifier.HIDDEN_SIZE, 10);
float[][] w1 = clf.getSampledW1(inputIdxs, hiddenIdxs);
float[][] w2 = clf.getSampledW2(hiddenIdxs);
// Response: { "w1": [[...],[...]], "w2": [[...],[...]], ... }
```
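The body of `sampleIndices` is not shown either. One plausible implementation returns `k` distinct indices in ascending order using a partial Fisher-Yates shuffle. Only the method name comes from the call above. The `IndexSampler` class, the sorting, and the explicit `Random` parameter are assumptions.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sampleIndices: k distinct indices from 0..n-1, sorted ascending. */
final class IndexSampler {
    private IndexSampler() {}

    static int[] sampleIndices(int n, int k, Random rng) {
        if (k > n) throw new IllegalArgumentException("k must not exceed n");
        int[] pool = new int[n];
        for (int i = 0; i < n; i++) pool[i] = i;
        // Partial Fisher-Yates: only the first k slots need to be randomized
        for (int i = 0; i < k; i++) {
            int j = i + rng.nextInt(n - i);
            int tmp = pool[i];
            pool[i] = pool[j];
            pool[j] = tmp;
        }
        int[] picked = Arrays.copyOf(pool, k);
        Arrays.sort(picked);   // stable ordering keeps chart rows and columns predictable
        return picked;
    }
}
```

With 12 pixels and 10 hidden neurons, the sampled W1 block holds 12 × 10 = 120 weights instead of all 200,704.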
The TensorTrain codebase shows that fine-tuning does not require a framework abstraction. The same Adam optimizer, the same weight tensors, and the same session runner used for bulk training are all that is needed to adapt the model in real time to a user's handwriting. The only thing that changes is the batch size, from 128 samples to 1.