RobustMNIST (v1.0)

RobustMNIST is a lightweight, 11-class convolutional neural network designed for handwritten digit recognition. Unlike standard models, this architecture is built to handle out-of-distribution (OOD) inputs and extreme image corruption through a dedicated "Unknown" class.

Model Details

  • Developed by: MultivexAI
  • Version: 1.0
  • Task: Open-Set Handwritten Digit Recognition
  • Architecture: 6-Layer Gated CNN (approx. 430k parameters)
  • Classes: 11 (0–9 for standard digits, 10 for "Unknown")
  • Input: 1x28x28 grayscale image.
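
As a rough illustration of the input format, the sketch below scales an 8-bit 28x28 pixel grid to floats in [0, 1] and wraps it in a single channel axis. This mirrors what transforms.ToTensor() does in the test script further down. The helper name to_model_input is hypothetical, and plain nested lists stand in for tensors so the example runs without PyTorch.

```python
def to_model_input(pixels):
    """Scale a 28x28 grid of 0-255 ints to a 1x28x28 grid of floats in [0, 1]."""
    if len(pixels) != 28 or any(len(row) != 28 for row in pixels):
        raise ValueError("expected a 28x28 grid of pixel values")
    scaled = [[value / 255.0 for value in row] for row in pixels]
    return [scaled]  # leading axis of length 1 is the single grayscale channel


# Hypothetical input: a black image with one fully white pixel.
blank = [[0] * 28 for _ in range(28)]
blank[14][14] = 255
image = to_model_input(blank)
print(len(image), len(image[0]), len(image[0][0]))
```

A real pipeline would hand the model a tensor with an extra batch axis, as the test script does with unsqueeze(0).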

The "Unknown" Class (Class 10)

Traditional MNIST models often guess a digit confidently even when the input is just random noise or a shape that isn't a number.

RobustMNIST introduces Class 10, representing the "Unknown" domain.

  • In-Distribution: For clean digits, the model predicts classes 0–9 with high accuracy while still assigning roughly 15-20% of its probability mass to Class 10 as an uncertainty margin.
  • Out-of-Distribution: When an image is severely corrupted (noise, stains, blurs) or represents a non-digit shape, the model's confidence shifts entirely to Class 10.
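
One way to act on the two cases above is to read the Class 10 probability separately from the digit probabilities. The sketch below is only an illustration: the test script simply takes the argmax over all 11 classes, and the unknown_threshold parameter, the interpret helper, and the logits are all made up for this example.

```python
import math

UNKNOWN = 10  # index of the "Unknown" class


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def interpret(logits, unknown_threshold=0.5):
    """Return (label, probs), where label is "Unknown" or a digit string.

    unknown_threshold is a hypothetical knob, not part of the released model.
    """
    if len(logits) != 11:
        raise ValueError("expected 11 logits (digits 0-9 plus Unknown)")
    probs = softmax(logits)
    if probs[UNKNOWN] >= unknown_threshold:
        return "Unknown", probs
    digit = max(range(10), key=lambda i: probs[i])
    return str(digit), probs


# Made-up logits: a confident "7" and an input dominated by Class 10.
clean_logits = [0.0] * 11
clean_logits[7] = 6.0
noisy_logits = [0.0] * 11
noisy_logits[UNKNOWN] = 6.0

print(interpret(clean_logits)[0])
print(interpret(noisy_logits)[0])
```

Lowering the threshold makes the classifier reject more inputs, which trades clean accuracy for caution; a suitable value would have to be chosen on held-out data.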

Performance Metrics

Evaluation on standard MNIST and extreme corruption sets:

Set                         | Accuracy
Clean MNIST Test Set        | 99.51%
Extreme OOD / Corrupted Set | 92.33%
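
The card does not spell out how these figures are scored. One plausible reading, assumed here, is plain top-1 accuracy in which every sample in the corrupted set carries the ground-truth label 10. The function name and the toy predictions below are illustrative only and are not the reported results.

```python
UNKNOWN = 10  # index of the "Unknown" class


def top1_accuracy(predictions, labels):
    """Fraction of positions where the predicted class equals the label."""
    if len(predictions) != len(labels) or not labels:
        raise ValueError("predictions and labels must be non-empty and equal in length")
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


# Toy data only: four clean digits and four corrupted inputs.
clean_preds, clean_labels = [3, 1, 4, 1], [3, 1, 4, 7]
ood_preds = [UNKNOWN, UNKNOWN, 8, UNKNOWN]
ood_labels = [UNKNOWN] * len(ood_preds)  # assumed labeling for the OOD set

print(f"clean: {top1_accuracy(clean_preds, clean_labels):.2%}")
print(f"ood:   {top1_accuracy(ood_preds, ood_labels):.2%}")
```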

Limitations & Expectations

Although the model is named RobustMNIST, "robust" does not mean "invincible." This is a small-scale model designed to demonstrate OOD detection, not a perfect safety system.

  • No 100% Guarantee: Like all neural networks, this model can and will make mistakes.
  • The "Robust" Definition: In this context, robustness refers to the model's improved resistance to noise and its ability to express uncertainty via the "Unknown" class compared to standard classifiers. It is not an absolute shield against all possible adversarial or geometric attacks.
  • Semantic Edge Cases: Certain transformations, such as rotating a "6" until it looks like a "9" or mirroring asymmetric digits, create mathematical ambiguities. We acknowledge these limits; at this parameter count, the model prioritizes identifying structured digits over handling every possible topological distortion.
  • Research Scope: This is a 1.0 release focused on balancing clean accuracy with OOD calibration. We acknowledge that edge cases exist where the model may still fail or default to "Unknown" unexpectedly.
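
The rotation ambiguity mentioned above can be seen on a toy bitmap. In this sketch, the seven-segment-style glyphs SIX and NINE are hand-drawn for illustration and have nothing to do with MNIST samples; rotating one by 180 degrees reproduces the other pixel for pixel, so no classifier could tell them apart from the image alone.

```python
SIX = [
    "###",
    "#..",
    "###",
    "#.#",
    "###",
]

NINE = [
    "###",
    "#.#",
    "###",
    "..#",
    "###",
]


def rotate_180(bitmap):
    """Rotate a list-of-strings bitmap by 180 degrees."""
    # Reversing the row order and each row's characters is a half turn.
    return [row[::-1] for row in reversed(bitmap)]


print(rotate_180(SIX) == NINE)
```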

Usage

To use this model, place model.py and model.pt in your working directory.

Simple Test Script (test.py)

This script picks a random digit from the MNIST test set and runs a prediction.

Tested on Python 3.12.

import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from model import HierarchicalNetwork

# Execution configurations
PARAMETER_PATH = "model.pt"
HARDWARE_TARGET = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def execute():
    # initialize architecture
    processor = HierarchicalNetwork(out_dims=11).to(HARDWARE_TARGET)
    
    # load state parameters
    state_data = torch.load(PARAMETER_PATH, map_location=HARDWARE_TARGET)
    weights = state_data.get('state_dict', state_data) 
    processor.load_state_dict(weights)
    processor.eval()

    # pull random sample
    dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
    sample_index = torch.randint(0, len(dataset), (1,)).item()
    input_tensor, ground_truth = dataset[sample_index]
    
    # add a batch dimension and run a forward pass without gradient tracking
    formatted_input = input_tensor.unsqueeze(0).to(HARDWARE_TARGET)
    with torch.inference_mode():
        raw_outputs = processor(formatted_input)
        probabilities = F.softmax(raw_outputs, dim=1).cpu().numpy()[0]

    # pick the most probable class and map it to a readable name
    predicted_class = int(probabilities.argmax())
    category_names = [str(i) for i in range(10)] + ["Unknown"]

    print("\n" + "="*30)
    print(f"Sample Index : {sample_index}")
    print(f"True Label   : {ground_truth}")
    print(f"Prediction   : {category_names[predicted_class]}")
    print(f"Confidence   : {probabilities[predicted_class] * 100:.2f}%")
    print("=" * 30)

if __name__ == "__main__":
    execute()

Released by MultivexAI | Licensed under Apache-2.0
