Deep Learning · Java · TensorFlow

Fine-Tuning a Handwritten
Digit Classifier

How a two-layer neural network learns, forgets, and corrects itself — from bulk MNIST pre-training through online single-sample fine-tuning, with live weight statistics at every step.

🧠 TensorTrain · MLP · Adam 📦 MNIST · EMNIST ☕ Java + TensorFlow Java API

Contents

  1. The Datasets — MNIST & EMNIST
  2. Model Architecture
  3. Bulk Pre-Training
  4. How Weights Change Across Epochs
  5. Online Fine-Tuning on a Single Sample
  6. Serving Predictions via HTTP
  7. Key Takeaways

Training a model on a large dataset gives it broad knowledge. But what happens when the model misclassifies a specific digit you drew? You can correct it — without starting over — by running a handful of gradient steps on just that one sample. This is online fine-tuning, and it's exactly what the TensorTrain system implements in its trainOnSample() method.

This post walks through the full pipeline: loading data in IDX binary format, building and training a multi-layer perceptron with TensorFlow Java, tracking weight statistics epoch-by-epoch, and finally fine-tuning on a single user-drawn sample. Code snippets are taken directly from the source.

1 · The Datasets

Two datasets are supported. Both are stored in the IDX binary format: a compact binary encoding where a 32-bit magic number identifies the data type, followed by dimension counts and raw pixel bytes.

  60 K    MNIST training images
  10 K    MNIST test images
  240 K   EMNIST training images
  28×28   Image resolution

Loading MNIST

MnistLoader downloads four gzip files from a Google Cloud mirror, decompresses them on the fly, and parses the IDX headers. Each pixel byte is normalized to the [0, 1] range by dividing by 255:

/** Returns float[numImages][784], values normalized to [0,1]. */
public static float[][] loadImages(String filename) throws IOException {
    try (DataInputStream dis = open(filename)) {
        int magic = dis.readInt();   // 2051
        int count = dis.readInt();
        int rows  = dis.readInt();   // 28
        int cols  = dis.readInt();   // 28
        float[][] images = new float[count][rows * cols];
        for (int i = 0; i < count; i++) {
            for (int j = 0; j < rows * cols; j++) {
                images[i][j] = dis.readUnsignedByte() / 255.0f;
            }
        }
        return images;
    }
}
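The same header layout applies to the label files. In the IDX format the third byte of the magic number encodes the element type (0x08 for unsigned bytes) and the fourth byte the number of dimensions, so 2051 (0x00000803) marks a 3-D image tensor and 2049 (0x00000801) a 1-D label vector. The label loader is not shown in this post; the sketch below is one plausible shape for it, assuming the bytes have already been decompressed. The names IdxSketch and parseLabels are hypothetical.

```java
import java.nio.ByteBuffer;

final class IdxSketch {
    /** Element type code: third byte of the magic number (0x08 = unsigned byte). */
    static int dataType(int magic) {
        return (magic >>> 8) & 0xFF;
    }

    /** Number of dimensions: lowest byte of the magic number. */
    static int dimensionCount(int magic) {
        return magic & 0xFF;
    }

    /** Parses an already-decompressed IDX label file into an int array. */
    static int[] parseLabels(byte[] idx) {
        ByteBuffer buf = ByteBuffer.wrap(idx);   // big-endian by default, like IDX
        int magic = buf.getInt();
        if (dataType(magic) != 0x08 || dimensionCount(magic) != 1) {
            throw new IllegalArgumentException("Unexpected IDX magic: " + magic);
        }
        int count = buf.getInt();
        if (count < 0 || count > buf.remaining()) {
            throw new IllegalArgumentException("Label count does not match payload: " + count);
        }
        int[] labels = new int[count];
        for (int i = 0; i < count; i++) {
            labels[i] = buf.get() & 0xFF;        // unsigned byte in the range 0..255
        }
        return labels;
    }
}
```

Checking the magic number up front turns a swapped image/label file into an immediate error instead of a silently misaligned training set.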

EMNIST — the Transposition Quirk

EMNIST images are stored column-major (transposed relative to MNIST). Read naively, each digit appears mirrored and rotated by 90°, so a model trained on them will not match normally oriented input. EmnistLoader corrects this with an explicit (r, c) → (c, r) swap during loading:

// EMNIST stores images transposed — swap (r, c) → (c, r)
for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
        images[i][c * rows + r] = (raw[r * cols + c] & 0xFF) / 255.0f;
    }
}
⚠ Dataset Pitfall: Skipping this transposition is a common, hard-to-spot bug. The model will still converge, but on a transposed orientation, causing inference errors on normally oriented user drawings.
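A quick way to catch this class of bug is to print an image before training. The sketch below reuses the index swap from the loader on a tiny 2×2 image and renders the result as ASCII art. The class EmnistOrientation and its methods are hypothetical helpers, not part of the TensorTrain source.

```java
final class EmnistOrientation {
    /** Transposes a row-major rows x cols byte image and scales pixels to [0, 1]. */
    static float[] transposeAndNormalize(byte[] raw, int rows, int cols) {
        float[] out = new float[rows * cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                out[c * rows + r] = (raw[r * cols + c] & 0xFF) / 255.0f;
            }
        }
        return out;
    }

    /** Renders an image as text: '#' for pixels above 0.5, '.' otherwise. */
    static String render(float[] image, int rows, int cols) {
        StringBuilder sb = new StringBuilder();
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                sb.append(image[r * cols + c] > 0.5f ? '#' : '.');
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // A vertical stroke in the left column of a 2x2 image.
        byte[] raw = {(byte) 255, 0, (byte) 255, 0};
        // After the swap the stroke runs along the top row instead.
        System.out.print(render(transposeAndNormalize(raw, 2, 2), 2, 2));
    }
}
```

For a square 28×28 image the same render call can be pointed at any row of the loaded array; a digit that looks mirrored and lying on its side means the swap was skipped.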

2 · Model Architecture

The classifier is a two-layer multi-layer perceptron (MLP) built directly on the TensorFlow Java low-level graph API — no Keras, no high-level wrappers.

Input (784 pixels) → MatMul W1 [784×256] → Add b1 [256] → ReLU h [?,256] → MatMul W2 [256×10] → Add b2 [10] → Softmax probs [?,10]

Layer            Shape        Parameters   Activation
W1               784 × 256    200,704
b1               256          256
ReLU             ? × 256      0            max(0, x)
W2               256 × 10     2,560
b2               10           10
Softmax output   ? × 10       0            Softmax
Total                         203,530
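To make the data flow concrete, here is the same forward pass written in plain Java without TensorFlow, together with the parameter-count arithmetic from the table. This is an illustrative sketch only; the class MlpForward and its method names are hypothetical, and the real model runs these operations inside the TensorFlow graph.

```java
final class MlpForward {
    /** Weights plus biases of both layers: in*hidden + hidden + hidden*out + out. */
    static int parameterCount(int in, int hidden, int out) {
        return in * hidden + hidden + hidden * out + out;
    }

    /** Computes y = x * W + b for a single input row. */
    static float[] affine(float[] x, float[][] w, float[] b) {
        float[] y = b.clone();
        for (int i = 0; i < x.length; i++) {
            for (int j = 0; j < y.length; j++) {
                y[j] += x[i] * w[i][j];
            }
        }
        return y;
    }

    /** Element-wise max(0, v). */
    static float[] relu(float[] v) {
        float[] out = new float[v.length];
        for (int i = 0; i < v.length; i++) {
            out[i] = Math.max(0f, v[i]);
        }
        return out;
    }

    /** Softmax with the maximum subtracted first for numerical stability. */
    static float[] softmax(float[] z) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : z) {
            max = Math.max(max, v);
        }
        double sum = 0.0;
        float[] p = new float[z.length];
        for (int i = 0; i < z.length; i++) {
            p[i] = (float) Math.exp(z[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] = (float) (p[i] / sum);
        }
        return p;
    }

    /** Input -> (W1, b1) -> ReLU -> (W2, b2) -> Softmax. */
    static float[] forward(float[] x, float[][] w1, float[] b1, float[][] w2, float[] b2) {
        return softmax(affine(relu(affine(x, w1, b1)), w2, b2));
    }
}
```

Calling parameterCount(784, 256, 10) reproduces the 203,530 total, and an all-zero network returns a uniform 0.1 probability for every class, which is a handy sanity check.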

Weight Initialization — He Normal

Weights are initialized with He (Kaiming) Normal scaling: scale = √(2 / fan_in). This keeps the variance of activations stable through ReLU layers, avoiding vanishing or exploding gradients at the start of training.

private void initRandomWeights() {
    Random rng = new Random(42);
    float s1 = (float) Math.sqrt(2.0 / INPUT_SIZE);   // He scale for W1
    float s2 = (float) Math.sqrt(2.0 / HIDDEN_SIZE);  // He scale for W2
    assignWeights(
        randMatrix(INPUT_SIZE,  HIDDEN_SIZE, rng, s1),
        new float[HIDDEN_SIZE],   // biases start at zero
        randMatrix(HIDDEN_SIZE, OUTPUT_SIZE, rng, s2),
        new float[OUTPUT_SIZE]);
}

private float[][] randMatrix(int r, int c, Random rng, float scale) {
    float[][] m = new float[r][c];
    for (int i = 0; i < r; i++)
        for (int j = 0; j < c; j++)
            m[i][j] = (float) rng.nextGaussian() * scale;
    return m;
}
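The effect of the scaling is easy to verify empirically. The sketch below draws the same number of Gaussian weights as W1 or W2 and measures their standard deviation, which should land close to √(2 / fan_in): about 0.0505 for W1 (fan_in = 784) and about 0.0884 for W2 (fan_in = 256). HeInitCheck is a hypothetical helper written for this illustration.

```java
import java.util.Random;

final class HeInitCheck {
    /** Draws fanIn * fanOut He-normal weights and returns their standard deviation. */
    static double sampleStd(int fanIn, int fanOut, long seed) {
        Random rng = new Random(seed);
        double scale = Math.sqrt(2.0 / fanIn);
        int n = fanIn * fanOut;
        double sum = 0.0;
        double sumSq = 0.0;
        for (int i = 0; i < n; i++) {
            double w = rng.nextGaussian() * scale;
            sum += w;
            sumSq += w * w;
        }
        double mean = sum / n;
        return Math.sqrt(sumSq / n - mean * mean);
    }

    public static void main(String[] args) {
        System.out.printf("W1 std: %.4f (target %.4f)%n",
                sampleStd(784, 256, 42L), Math.sqrt(2.0 / 784));
        System.out.printf("W2 std: %.4f (target %.4f)%n",
                sampleStd(256, 10, 42L), Math.sqrt(2.0 / 256));
    }
}
```

W2 has far fewer entries than W1, so its measured value fluctuates more around the target from seed to seed.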

Loss Function

The model uses categorical cross-entropy, computed through log-softmax + one-hot encoding rather than SparseSoftmaxCrossEntropyWithLogits — which lacks a registered C++ gradient in the TF Java 2.x build:

// log p(y | x) for each sample, then average over batch
var logSoftmax = tf.nn.logSoftmax(logits);
var oneHot     = tf.oneHot(yLabels, tf.constant(OUTPUT_SIZE),
                            tf.constant(1.0f), tf.constant(0.0f));
var meanLogProb = tf.math.mean(tf.math.mul(oneHot, logSoftmax),
                               tf.constant(new int[]{1}));
// Negate and scale to get standard cross-entropy loss
var loss = tf.math.mean(
    tf.math.neg(tf.math.mul(meanLogProb, tf.constant((float) OUTPUT_SIZE))),
    tf.constant(new int[]{0}));
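The mean-then-rescale step can look odd at first glance. Because the one-hot vector has a single 1, the mean over the 10 classes equals log p(y | x) / 10, and multiplying by OUTPUT_SIZE recovers the usual −log p(y | x). The plain-Java sketch below mirrors that arithmetic for one sample; CrossEntropySketch is a hypothetical name and the code is an illustration, not the graph the model actually runs.

```java
final class CrossEntropySketch {
    /** Numerically stable log-softmax: z_k - log(sum_j exp(z_j)). */
    static double[] logSoftmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) {
            max = Math.max(max, z);
        }
        double sumExp = 0.0;
        for (double z : logits) {
            sumExp += Math.exp(z - max);
        }
        double logSum = max + Math.log(sumExp);
        double[] out = new double[logits.length];
        for (int k = 0; k < logits.length; k++) {
            out[k] = logits[k] - logSum;
        }
        return out;
    }

    /** Mean over classes of oneHot * logSoftmax, then negated and scaled by the class count. */
    static double loss(double[] logits, int label) {
        double[] lsm = logSoftmax(logits);
        int classes = logits.length;
        double meanLogProb = 0.0;
        for (int k = 0; k < classes; k++) {
            double oneHot = (k == label) ? 1.0 : 0.0;
            meanLogProb += oneHot * lsm[k];
        }
        meanLogProb /= classes;
        return -meanLogProb * classes;
    }
}
```

For an untrained model that outputs equal logits for all 10 classes, this loss is ln(10) ≈ 2.303, a useful reference value for the first training step.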

3 · Bulk Pre-Training

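Training runs asynchronously via a servlet thread pool. The core loop in DigitClassifier.train() shuffles the entire training set each epoch, then processes mini-batches of 128 samples; because the batch count is n / BATCH_SIZE with integer division, any samples left over after the last full batch are skipped for that epoch. After each epoch it measures accuracy on a 2,000-sample subset of the same images passed to train() and records a weight snapshot.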

public void train(float[][] images, int[] labels, int epochs,
                   TrainingCallback cb) {
    int n       = images.length;
    int batches = n / BATCH_SIZE;   // BATCH_SIZE = 128

    for (int ep = 1; ep <= epochs; ep++) {
        int[] idx = shuffleIdx(n);   // fresh random permutation every epoch

        for (int b = 0; b < batches; b++) {
            float[][] xb = new float[BATCH_SIZE][INPUT_SIZE];
            long[]    yb = new long[BATCH_SIZE];
            for (int i = 0; i < BATCH_SIZE; i++) {
                int s = idx[b * BATCH_SIZE + i];
                xb[i] = images[s];
                yb[i] = labels[s];
            }
            try (TFloat32 xt = mat2Tensor(xb);
                 TInt64   yt = longVec2Tensor(yb)) {
                session.runner()
                    .feed(xInput.asOutput(), xt)
                    .feed(yLabels.asOutput(), yt)
                    .addTarget(trainOp)   // Adam minimise(loss)
                    .run();
            }
        }

        lastAccuracy = evaluate(images, labels, 2000);
        snapshots.add(recordSnapshot(ep, lastAccuracy));
        if (cb != null) cb.onEpoch(ep, epochs, lastAccuracy);
    }
    saveWeights();   // persist to ~/.tensortrain/model_weights.ser
}
💡 Optimizer: Adam with lr = 0.001. Adam adapts the learning rate per parameter using first and second moment estimates of the gradient. On MNIST it typically reaches ~97% accuracy in 5–10 epochs, compared to ~95% for vanilla SGD at the same learning rate.
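For readers who have not seen the update rule spelled out, here is textbook Adam for a single scalar parameter. This is a sketch of the published algorithm, not TensorFlow's implementation, which may differ in details such as how epsilon is applied. The post only states lr = 0.001; the beta and epsilon values used in main are the commonly cited defaults and are an assumption here.

```java
final class AdamSketch {
    private final double lr, beta1, beta2, eps;
    private double m;   // running mean of gradients (first moment)
    private double v;   // running mean of squared gradients (second moment)
    private int t;      // number of steps taken so far

    AdamSketch(double lr, double beta1, double beta2, double eps) {
        this.lr = lr;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.eps = eps;
    }

    /** Applies one Adam update to theta for gradient g and returns the new value. */
    double step(double theta, double g) {
        t++;
        m = beta1 * m + (1 - beta1) * g;
        v = beta2 * v + (1 - beta2) * g * g;
        double mHat = m / (1 - Math.pow(beta1, t));   // bias correction
        double vHat = v / (1 - Math.pow(beta2, t));
        return theta - lr * mHat / (Math.sqrt(vHat) + eps);
    }

    public static void main(String[] args) {
        AdamSketch small = new AdamSketch(0.001, 0.9, 0.999, 1e-8);
        AdamSketch large = new AdamSketch(0.001, 0.9, 0.999, 1e-8);
        // Very different gradients, nearly identical first step of about lr.
        System.out.println(small.step(1.0, 0.5));
        System.out.println(large.step(1.0, 50.0));
    }
}
```

The division by the square root of the second moment is what makes the step size roughly independent of the raw gradient scale, which is the per-parameter adaptation described above.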

Typical Accuracy Curve — MNIST (10 Epochs)

[Figure: accuracy (y-axis, 80% to 99%) vs. epoch (x-axis, 1 to 10); series: MNIST (60 K), EMNIST (240 K)]

4 · How Weights Change Across Epochs

After every epoch, recordSnapshot() computes four statistics for each weight tensor — mean, standard deviation, min, and max — and stores them in an EpochSnapshot. These are served by GET /api/graph/snapshots and rendered as live charts in the UI.

private EpochSnapshot recordSnapshot(int ep, float accuracy) {
    return new EpochSnapshot(ep, accuracy,
        computeStats("W1", new int[]{INPUT_SIZE,  HIDDEN_SIZE}, w1),
        computeStats("b1", new int[]{HIDDEN_SIZE},              b1),
        computeStats("W2", new int[]{HIDDEN_SIZE, OUTPUT_SIZE}, w2),
        computeStats("b2", new int[]{OUTPUT_SIZE},              b2));
}

The computeStats method also builds a 20-bin histogram over the full weight distribution, enabling the UI to draw weight-distribution charts. Here is what the statistics typically look like across 10 training epochs:

Epoch   W1 mean    W1 std    W1 min    W1 max    Accuracy
Init     0.0000    0.0500    −0.162    +0.159
1       −0.0003    0.0521    −0.201    +0.198    82.4%
2       −0.0005    0.0548    −0.238    +0.241    91.8%
3       −0.0006    0.0571    −0.261    +0.264    93.9%
5       −0.0007    0.0603    −0.288    +0.291    95.9%
10      −0.0008    0.0641    −0.315    +0.320    97.6%
📊 What to watch for: As training progresses, the standard deviation of W1 grows steadily (weights spread out as neurons specialize), while the mean stays near zero (the network is not shifting its "centre of gravity"). A std that grows sharply while the mean drifts away from zero is a warning sign of unstable training or overfitting.
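The body of computeStats is not shown in this post. A minimal sketch of how the four statistics and a fixed-width histogram could be computed over a flattened weight array follows; the class WeightStats, its fields, and the binning rule (equal-width bins between min and max) are assumptions made for illustration.

```java
final class WeightStats {
    final double mean, std, min, max;
    final int[] histogram;

    private WeightStats(double mean, double std, double min, double max, int[] histogram) {
        this.mean = mean;
        this.std = std;
        this.min = min;
        this.max = max;
        this.histogram = histogram;
    }

    /** Computes mean, population std, min, max and an equal-width histogram. */
    static WeightStats of(float[] values, int bins) {
        if (values.length == 0 || bins <= 0) {
            throw new IllegalArgumentException("Need at least one value and one bin");
        }
        double sum = 0.0, sumSq = 0.0;
        double min = Double.POSITIVE_INFINITY, max = Double.NEGATIVE_INFINITY;
        for (float v : values) {
            sum += v;
            sumSq += (double) v * v;
            min = Math.min(min, v);
            max = Math.max(max, v);
        }
        int n = values.length;
        double mean = sum / n;
        double std = Math.sqrt(Math.max(0.0, sumSq / n - mean * mean));

        int[] hist = new int[bins];
        double width = (max - min) / bins;
        for (float v : values) {
            int b = (width == 0.0) ? 0 : (int) ((v - min) / width);
            hist[Math.min(b, bins - 1)]++;   // the maximum value lands in the last bin
        }
        return new WeightStats(mean, std, min, max, hist);
    }
}
```

A two-dimensional weight matrix would be flattened row by row before being passed in; with bins = 20 this yields the kind of data a weight-distribution chart needs.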

W1 Weight Distribution — Evolution Over 10 Epochs

[Figure: weight standard deviation (y-axis, 0.049 to 0.069) vs. epoch (x-axis, 1 to 9, followed by a fine-tune point); series: W1 std (784×256), W2 std (256×10)]

W1 Weight Histogram — Before vs After Training (20 bins)

[Figure: histogram of W1 weight values (x-axis, −0.32 to +0.32) against count (y-axis, 0 to 15 K); series: Init (std ≈ 0.050), Epoch 10 (std ≈ 0.064)]

The key observation: as training proceeds, the weight distribution spreads out. Neurons are specializing — some learn to detect horizontal strokes, others vertical, others curves. A narrow distribution (all weights near zero) means the model has not yet differentiated its neurons. A wide, roughly bell-shaped distribution is a healthy sign of learned features.

5 · Online Fine-Tuning on a Single Sample

After bulk training the model is deployed as a web app where users draw digits on a canvas. If the model guesses wrong, the user can correct it. This triggers POST /api/train/sample, which calls trainOnSample() — running 20 Adam gradient steps on just that one 28×28 drawing.

The trainOnSample method

/**
 * Runs a small number of gradient steps on one user-supplied sample.
 * Used for online correction after an incorrect prediction.
 *
 * @param image  784-element pixel array (values in [0,1])
 * @param label  correct digit label (0-9)
 * @param steps  number of gradient steps (10-30 is a reasonable default)
 */
public synchronized void trainOnSample(float[] image, int label, int steps) {
    float[][] x = new float[][] { image };
    long[]    y = new long[]    { (long) label };

    try (TFloat32 xt = mat2Tensor(x);
         TInt64   yt = longVec2Tensor(y)) {

        for (int i = 0; i < steps; i++) {
            session.runner()
                .feed(xInput.asOutput(), xt)
                .feed(yLabels.asOutput(), yt)
                .addTarget(trainOp)
                .run();
        }
    }
    trained = true;
}

The servlet wires this up as a synchronous HTTP call so the UI immediately reflects the correction:

// TrainServlet.java — POST /api/train/sample
float[] pixels = new float[DigitClassifier.INPUT_SIZE];
for (int i = 0; i < pixels.length; i++)
    pixels[i] = ((Number) rawPixels.get(i)).floatValue();

clf.trainOnSample(pixels, label, 20);   // 20 gradient steps
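Since this endpoint changes model weights, it is worth rejecting malformed payloads before they reach trainOnSample(). The post does not show what validation TrainServlet performs, so the following is only a sketch of the kind of checks that could sit in front of the call. SamplePayload, toPixels, and checkLabel are hypothetical names, and clamping to [0, 1] is an assumed policy.

```java
import java.util.List;

final class SamplePayload {
    /** Converts a parsed JSON array to a pixel array, enforcing length and value range. */
    static float[] toPixels(List<?> rawPixels, int expectedLength) {
        if (rawPixels == null || rawPixels.size() != expectedLength) {
            throw new IllegalArgumentException("Expected " + expectedLength + " pixels");
        }
        float[] pixels = new float[expectedLength];
        for (int i = 0; i < expectedLength; i++) {
            Object o = rawPixels.get(i);
            if (!(o instanceof Number)) {
                throw new IllegalArgumentException("Pixel " + i + " is not a number");
            }
            float v = ((Number) o).floatValue();
            if (Float.isNaN(v)) {
                throw new IllegalArgumentException("Pixel " + i + " is NaN");
            }
            pixels[i] = Math.max(0f, Math.min(1f, v));   // clamp to [0, 1]
        }
        return pixels;
    }

    /** Returns the label unchanged if it is a valid class index, otherwise throws. */
    static int checkLabel(int label, int classes) {
        if (label < 0 || label >= classes) {
            throw new IllegalArgumentException("Label out of range: " + label);
        }
        return label;
    }
}
```

A servlet could translate the IllegalArgumentException into an HTTP 400 response so that a bad request never triggers a gradient step.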

What happens to the weights?

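Fine-tuning on a single sample produces a targeted, localized weight shift. Adam keeps per-parameter moment estimates, so when fine-tuning follows bulk training in the same session, each update is scaled by that parameter's gradient history. The size of a step therefore stays on the order of the learning rate instead of tracking the raw magnitude of the single-sample gradient.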

Weight Delta After 20-Step Fine-Tune — W2 (256 × 10), Target Digit: 7

[Figure: average W2 weight delta (y-axis, −0.012 to +0.012) per output digit class (x-axis, 0 to 9); labelled bars: +0.009 and −0.007]

LIVE: TensorGraph Δ Diff View — Weight Changes After Fine-Tune

[Image: TensorGraph neural network weight diff visualization showing green (increased) and red (decreased) connections after a single-sample fine-tune]
Δ Diff mode — Green connections indicate weights that increased after fine-tuning on a corrected sample; red connections indicate weights that decreased. Opacity encodes the magnitude of each change. The sampled view shows 12 of 784 input neurons, 10 of 256 hidden neurons, and all 10 output neurons. Notice how hidden neuron 28 fans out strong positive deltas reinforcing class 7 (the correct label), while neurons 170 and 198 drive the largest negative deltas away from class 2 (the wrong prediction) through the hidden→output layer.

The chart above shows the average weight delta in W2 (the output layer) after fine-tuning on a "7" that was misclassified as "2". The model pushes up the column of weights that votes for class 7, and nudges down the competing classes — with the largest penalty applied to class 2, the incorrect prediction. The W1 (hidden layer) changes are smaller but spread across all 200,704 parameters — those neurons whose activations were strong for this image get a slightly larger push.
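The sign pattern follows directly from the gradient of softmax cross-entropy with respect to the output weights: ∂L/∂W2[j][k] = h[j] · (p[k] − y[k]). For the true class p − y is negative, so its column is pushed up; for every other class it equals p[k], so the class the model wrongly favoured receives the largest push down. The sketch below computes one plain gradient-descent step to show this. Plain gradient descent is swapped in for clarity (the actual system uses Adam, which rescales each update), and the activations and probabilities are made-up illustrative values.

```java
final class OutputLayerDelta {
    /** Delta for one gradient-descent step: -lr * h[j] * (p[k] - y[k]). */
    static double[][] sgdDelta(double[] hidden, double[] probs, int label, double lr) {
        double[][] delta = new double[hidden.length][probs.length];
        for (int j = 0; j < hidden.length; j++) {
            for (int k = 0; k < probs.length; k++) {
                double target = (k == label) ? 1.0 : 0.0;
                delta[j][k] = -lr * hidden[j] * (probs[k] - target);
            }
        }
        return delta;
    }

    /** Average delta per output class, i.e. the column means of the delta matrix. */
    static double[] columnMeans(double[][] m) {
        double[] means = new double[m[0].length];
        for (double[] row : m) {
            for (int k = 0; k < row.length; k++) {
                means[k] += row[k] / m.length;
            }
        }
        return means;
    }

    public static void main(String[] args) {
        double[] hidden = {1.0, 0.0, 2.0};   // hypothetical post-ReLU activations
        // Hypothetical prediction: class 2 favoured, true class 7 second.
        double[] probs = {0.01, 0.02, 0.63, 0.01, 0.01, 0.01, 0.01, 0.28, 0.01, 0.01};
        double[] perClass = columnMeans(sgdDelta(hidden, probs, 7, 0.01));
        for (int k = 0; k < perClass.length; k++) {
            System.out.printf("class %d: %+.4f%n", k, perClass[k]);
        }
    }
}
```

Hidden units with zero activation contribute no change at all, which is why the update concentrates on the neurons that fired for this particular drawing.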

Fine-tuning vs. Catastrophic Forgetting

A single-sample fine-tune is surgical: it adjusts the weights to lower the loss for one input without rewriting the knowledge encoded across 60,000+ training examples. That said, 20 gradient steps at the same Adam learning rate can introduce minor regressions on similar-looking digits. Reducing the step count (10 instead of 20) or the learning rate limits these regressions, at the cost of a weaker correction.

⚠ Catastrophic Forgetting Risk: Repeated fine-tunes with many steps or a high learning rate can gradually erode bulk-training accuracy. If the correction rate is high, consider periodically re-training on the original dataset plus the accumulated correction samples.
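One way to act on that advice is to keep user corrections in a buffer and blend them into ordinary mini-batches during periodic re-training. The sketch below is not part of TensorTrain; CorrectionReplay, Batch, mixedBatch, and the replayFraction parameter are hypothetical names chosen for illustration. The batch uses the same shapes the training loop feeds (float[][] images, long[] labels).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

final class CorrectionReplay {
    /** One mini-batch: images and their labels. */
    static final class Batch {
        final float[][] x;
        final long[] y;

        Batch(float[][] x, long[] y) {
            this.x = x;
            this.y = y;
        }
    }

    private final List<float[]> images = new ArrayList<>();
    private final List<Integer> labels = new ArrayList<>();

    /** Stores a user-corrected sample for later replay. */
    void add(float[] image, int label) {
        images.add(image.clone());
        labels.add(label);
    }

    /**
     * Builds a batch in which each slot holds a stored correction with
     * probability replayFraction and a random original sample otherwise.
     */
    Batch mixedBatch(float[][] origImages, int[] origLabels,
                     int batchSize, double replayFraction, Random rng) {
        float[][] x = new float[batchSize][];
        long[] y = new long[batchSize];
        for (int i = 0; i < batchSize; i++) {
            boolean useCorrection = !images.isEmpty() && rng.nextDouble() < replayFraction;
            if (useCorrection) {
                int s = rng.nextInt(images.size());
                x[i] = images.get(s);
                y[i] = labels.get(s);
            } else {
                int s = rng.nextInt(origImages.length);
                x[i] = origImages[s];
                y[i] = origLabels[s];
            }
        }
        return new Batch(x, y);
    }
}
```

Keeping replayFraction small means most of each batch still comes from the original data, so the corrections are reinforced without dominating training.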

Prediction Confidence — Before vs After Fine-Tuning (sample: "7" drawn)

[Figure: predicted probability (y-axis, 0% to 100%) per digit class (x-axis, 0 to 9); series: before fine-tune (predicted "2", true "7") and after fine-tune (correct "7"); labelled bars: 63%, 28%, 10%, 91%]

6 · Serving Predictions via HTTP

The model is wrapped in a Java EE servlet application. The AppInitializer context listener boots a single DigitClassifier instance when the webapp starts:

@WebListener
public class AppInitializer implements ServletContextListener {

    @Override
    public void contextInitialized(ServletContextEvent sce) {
        DigitClassifier clf = new DigitClassifier();
        clf.initialize();   // build TF graph, try to load saved weights
        sce.getServletContext().setAttribute(CLASSIFIER_KEY, clf);
    }
}

Prediction requests arrive as JSON arrays of 784 floats. The PredictServlet returns the predicted digit, its confidence, and the full probability vector:

// POST /api/predict
// Request:  { "pixels": [0.0, 0.12, ..., 0.95] }   // 784 floats
// Response: { "prediction": 3, "confidence": 0.91,
//             "probabilities": [0.01, 0.01, 0.05, 0.91, ...] }

float[] probs      = clf.predictProbs(pixels);
int     prediction = 0;
float   confidence = probs[0];
for (int i = 1; i < probs.length; i++) {
    if (probs[i] > confidence) {
        confidence = probs[i];
        prediction = i;
    }
}
out.put("prediction",    prediction);
out.put("confidence",    confidence);
out.put("probabilities", probs);
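From the client side, a request matching the format above can be assembled with the JDK's built-in HTTP client (Java 11 or later). The sketch below only builds the request; the base URL is a placeholder, since the post does not state where the app is deployed, and PredictClient is a hypothetical helper.

```java
import java.net.URI;
import java.net.http.HttpRequest;
import java.util.StringJoiner;

final class PredictClient {
    /** Serialises pixels as {"pixels":[...]}; values must be finite to be valid JSON. */
    static String toJson(float[] pixels) {
        StringJoiner sj = new StringJoiner(",", "{\"pixels\":[", "]}");
        for (float p : pixels) {
            sj.add(Float.toString(p));
        }
        return sj.toString();
    }

    /** Builds a POST to {baseUrl}/api/predict with a JSON body. */
    static HttpRequest buildRequest(String baseUrl, float[] pixels) {
        return HttpRequest.newBuilder(URI.create(baseUrl + "/api/predict"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(toJson(pixels)))
                .build();
    }

    public static void main(String[] args) {
        // "http://localhost:8080" is a placeholder for wherever the webapp is deployed.
        HttpRequest request = buildRequest("http://localhost:8080", new float[784]);
        System.out.println(request.method() + " " + request.uri());
    }
}
```

Sending it is then a matter of HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString()) and parsing the JSON response with whatever library the client already uses.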

The inference path in DigitClassifier.predictProbs() wraps the pixel array in a single-row TensorFlow tensor, runs it through the softmax output, and copies the result back to a Java float array — all within a try-with-resources to ensure native TF memory is released:

public float[] predictProbs(float[] image) {
    try (TFloat32 xt     = mat2Tensor(new float[][]{image});
         TFloat32 result = (TFloat32) session.runner()
                 .feed(xInput.asOutput(), xt)
                 .fetch(softmaxOutput)
                 .run().get(0)) {

        float[] probs = new float[OUTPUT_SIZE];
        for (int i = 0; i < OUTPUT_SIZE; i++)
            probs[i] = result.getFloat(0, i);
        return probs;
    }
}

Weight Snapshots API

The GraphServlet also exposes the weights themselves for visualization. To keep the response small, it samples 12 input neurons and 10 hidden neurons, then returns the corresponding W1 and W2 sub-matrices as JSON:

// GET /api/graph/weights — sampled weight sub-matrices
int[] inputIdxs  = sampleIndices(DigitClassifier.INPUT_SIZE,  12);
int[] hiddenIdxs = sampleIndices(DigitClassifier.HIDDEN_SIZE, 10);

float[][] w1 = clf.getSampledW1(inputIdxs, hiddenIdxs);
float[][] w2 = clf.getSampledW2(hiddenIdxs);

// Response: { "w1": [[...],[...]], "w2": [[...],[...]], ... }
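The helpers behind this snippet are not shown in the post. One simple way to implement them is to pick evenly spaced indices and copy out the matching rows and columns, as sketched below. Whether the real sampleIndices is evenly spaced or random is not stated, so treat WeightSampling and both of its methods as assumptions.

```java
final class WeightSampling {
    /** Returns k evenly spaced indices in the range [0, n). */
    static int[] sampleIndices(int n, int k) {
        if (k <= 0 || k > n) {
            throw new IllegalArgumentException("Need 0 < k <= n");
        }
        int[] idx = new int[k];
        for (int i = 0; i < k; i++) {
            idx[i] = (int) ((long) i * n / k);
        }
        return idx;
    }

    /** Copies the selected rows and columns of w into a new, smaller matrix. */
    static float[][] subMatrix(float[][] w, int[] rows, int[] cols) {
        float[][] out = new float[rows.length][cols.length];
        for (int i = 0; i < rows.length; i++) {
            for (int j = 0; j < cols.length; j++) {
                out[i][j] = w[rows[i]][cols[j]];
            }
        }
        return out;
    }
}
```

With 12 input and 10 hidden neurons the W1 sample has 120 entries instead of 200,704, which keeps the JSON payload small enough to refresh frequently.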

7 · Key Takeaways

✅ Dataset integrity matters first: The EMNIST transpose fix is a textbook example of a subtle data-pipeline bug that produces a model which trains, converges, and still fails on real user input. Always visualize a sample of your training data before training.
✅ Watch weight statistics, not just accuracy: Growing std + stable mean = healthy specialization. An exploding max or a collapsing std signals a training problem even when the reported accuracy looks acceptable.
📐 He initialization sets the stage: Scaling initial weights by √(2/fan_in) is the standard choice for ReLU layers. It keeps the variance of activations roughly constant from layer to layer at initialization, which helps the first few epochs of Adam stay stable and fast.
⚠ Online fine-tuning is powerful but fragile: Twenty gradient steps correct a specific error quickly. But without regularization or experience replay, repeated corrections can silently erode accuracy on the broader distribution. Monitor global accuracy after each fine-tune session.

The TensorTrain codebase shows that fine-tuning does not require a framework abstraction: the same Adam optimizer, the same weight tensors, and the same session runner that were used for bulk training are all that is needed to adapt the model in real time to a user's handwriting. What changes is the batch size, from 128 samples to 1, and the amount of work, from full epochs to 20 gradient steps.