Build a Learner Class for Image Classification
An advanced image classification application built with PyTorch, featuring modern model architectures, training optimizations such as mixup augmentation, gradient clipping, learning-rate scheduling, and early stopping, and an interactive web interface for classifying images from custom datasets.
optimized_learner.py
import os
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np # Required for mixup implementation
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
optimized_learner.py
class Learner:
    """Wraps a model, data loaders, and training utilities for image classification."""

    def __init__(
        self,
        model: nn.Module,
        train_dataloader: DataLoader,
        test_dataloader: DataLoader,
        optimizer: torch.optim.Optimizer,
        loss: nn.Module,
        scheduler: torch.optim.lr_scheduler._LRScheduler,
        work_dir: str = "checkpoints",
        pre_train: bool = False,
        device: Optional[torch.device] = None,
        mixup_alpha: float = 1.2,
        use_mixup: bool = True,
    ):
        self.model = model
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.optimizer = optimizer
        self.criterion = loss
        self.scheduler = scheduler
        self.work_dir = work_dir
        self.mixup_alpha = mixup_alpha
        self.use_mixup = use_mixup
        self.class_names = (
            train_dataloader.dataset.classes
            if hasattr(train_dataloader.dataset, "classes")
            else None
        )
        self.device = (
            device
            if device
            else torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        )
        print(f"Using device: {self.device}")
        self.model.to(self.device)
        os.makedirs(self.work_dir, exist_ok=True)
        # Track the best validation metrics for checkpointing and early stopping
        self.best_acc = 0.0
        self.best_loss = float("inf")
        self.patience = 8
        self.counter = 0
        if pre_train:
            self._load_best_model()
Initialize the Learner with the model, data loaders, and training parameters.
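For orientation, here is a minimal sketch of wiring up a Learner. The ResNet18 backbone and the hyperparameter values are placeholders chosen for illustration; train.py below shows the project's full setup.

import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from optimized_learner import Learner, create_data_loaders

# Placeholder setup: a small pretrained backbone and default hyperparameters
train_loader, val_loader, class_names = create_data_loaders(
    "dataset", batch_size=16, img_size=224
)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, len(class_names))

optimizer = optim.AdamW(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)
learner = Learner(
    model=model,
    train_dataloader=train_loader,
    test_dataloader=val_loader,
    optimizer=optimizer,
    loss=nn.CrossEntropyLoss(),
    scheduler=scheduler,
)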
optimized_learner.py
    def mixup_data(self, x, y, alpha=1.2):
        """Return a mixed batch plus the two label sets and the mixing coefficient."""
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1
        batch_size = x.size(0)
        index = torch.randperm(batch_size).to(self.device)
        # Convex combination of each image with a randomly paired image
        mixed_x = lam * x + (1 - lam) * x[index, :]
        y_a, y_b = y, y[index]
        return mixed_x, y_a, y_b, lam

    def mixup_criterion(self, pred, y_a, y_b, lam):
        """Blend the loss against both label sets with the same coefficient."""
        return lam * self.criterion(pred, y_a) + (1 - lam) * self.criterion(pred, y_b)
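For intuition: mixup draws a coefficient lambda from a Beta(alpha, alpha) distribution and blends each image with a randomly paired image from the same batch, then blends the loss with the same coefficient. A standalone sketch of that arithmetic on dummy tensors (the shapes and the alpha value are illustrative):

import numpy as np
import torch

x = torch.randn(4, 3, 224, 224)          # dummy batch of four images
y = torch.tensor([0, 1, 2, 3])           # dummy labels
lam = float(np.random.beta(0.4, 0.4))    # alpha = 0.4, as used in train.py
index = torch.randperm(x.size(0))        # random pairing within the batch
mixed_x = lam * x + (1 - lam) * x[index]
# The training loss is then:
#   lam * criterion(pred, y) + (1 - lam) * criterion(pred, y[index])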
optimized_learner.py
    def train(self, epochs: int = 10, verbose: bool = True) -> Dict[str, list]:
        history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
        if verbose:
            print(f"Training for {epochs} epochs on {self.device}...")
        for epoch in range(epochs):
            # Training phase
            self.model.train()
            running_loss, correct, total = 0.0, 0, 0
            for inputs, targets in self.train_dataloader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                # Zero the parameter gradients
                self.optimizer.zero_grad()
                # Apply mixup if enabled
                if self.use_mixup:
                    inputs, targets_a, targets_b, lam = self.mixup_data(
                        inputs, targets, self.mixup_alpha
                    )
                    # Forward pass
                    outputs = self.model(inputs)
                    # Mixup loss
                    loss = self.mixup_criterion(outputs, targets_a, targets_b, lam)
                    # For accuracy with mixup, credit both labels in
                    # proportion to the mixing coefficient
                    _, predicted = outputs.max(1)
                    total += targets.size(0)
                    correct += (
                        lam * predicted.eq(targets_a).sum().float()
                        + (1 - lam) * predicted.eq(targets_b).sum().float()
                    ).item()
                else:
                    # Regular forward pass
                    outputs = self.model(inputs)
                    loss = self.criterion(outputs, targets)
                    # Statistics
                    _, predicted = outputs.max(1)
                    total += targets.size(0)
                    correct += predicted.eq(targets).sum().item()
                # Backward + optimize
                loss.backward()
                # Gradient clipping to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()
                # Batch-level schedulers such as OneCycleLR step once per batch
                if isinstance(self.scheduler, torch.optim.lr_scheduler.OneCycleLR):
                    self.scheduler.step()
                # Statistics
                running_loss += loss.item() * inputs.size(0)
            # Calculate epoch metrics
            train_loss = running_loss / len(self.train_dataloader.dataset)
            train_acc = 100.0 * correct / total
            # Save to history
            history["train_loss"].append(train_loss)
            history["train_acc"].append(train_acc)
            # Evaluation phase
            val_loss, val_metrics = self.test(self.test_dataloader)
            val_acc = val_metrics["accuracy"]
            history["val_loss"].append(val_loss)
            history["val_acc"].append(val_acc)
            # Print progress
            if verbose:
                print(
                    f"Epoch [{epoch+1}/{epochs}] - "
                    f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
                    f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%"
                )
            # Epoch-level learning rate scheduling
            if self.scheduler is not None:
                if isinstance(
                    self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
                ):
                    self.scheduler.step(val_loss)
                elif not isinstance(
                    self.scheduler, torch.optim.lr_scheduler.OneCycleLR
                ):
                    self.scheduler.step()
            # Save best model and early stopping
            if val_acc > self.best_acc or (
                val_acc == self.best_acc and val_loss < self.best_loss
            ):
                self.best_acc = val_acc
                self.best_loss = val_loss
                self._save_model()
                if verbose:
                    print(
                        f"Best model saved at epoch {epoch+1} with accuracy {val_acc:.2f}%"
                    )
                self.counter = 0
            else:
                self.counter += 1
                if self.counter >= self.patience:
                    if verbose:
                        print(f"⏳ Early stopping triggered after {epoch+1} epochs!")
                    break
        if verbose:
            print("Training completed!")
            print(f"Best validation accuracy: {self.best_acc:.2f}%")
        return history
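The history dict returned by train() is easy to visualize. A minimal sketch using matplotlib, which is not used elsewhere in these files and is assumed to be installed:

import matplotlib.pyplot as plt

def plot_history(history):
    """Plot loss and accuracy curves from Learner.train()'s history dict."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    ax1.plot(history["train_loss"], label="train")
    ax1.plot(history["val_loss"], label="val")
    ax1.set_title("Loss")
    ax1.legend()
    ax2.plot(history["train_acc"], label="train")
    ax2.plot(history["val_acc"], label="val")
    ax2.set_title("Accuracy (%)")
    ax2.legend()
    plt.tight_layout()
    plt.show()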
optimized_learner.py
    def test(
        self, test_dataloader: Optional[DataLoader] = None
    ) -> Tuple[float, Dict[str, float]]:
        import sklearn.metrics as metrics

        dataloader = (
            test_dataloader if test_dataloader is not None else self.test_dataloader
        )
        self.model.eval()
        running_loss = 0.0
        all_targets = []
        all_predictions = []
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
                running_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                # Store targets and predictions for computing metrics
                all_targets.extend(targets.cpu().numpy())
                all_predictions.extend(predicted.cpu().numpy())
        # Convert to numpy arrays for sklearn metrics
        all_targets = np.array(all_targets)
        all_predictions = np.array(all_predictions)
        # Calculate metrics
        avg_loss = running_loss / len(dataloader.dataset)
        accuracy = 100.0 * metrics.accuracy_score(all_targets, all_predictions)
        # Handle the case where only one class is present in the targets
        if len(np.unique(all_targets)) > 1:
            # For multi-class, calculate macro-averaged metrics (treats all classes equally)
            precision = 100.0 * metrics.precision_score(
                all_targets, all_predictions, average="macro", zero_division=0
            )
            recall = 100.0 * metrics.recall_score(
                all_targets, all_predictions, average="macro", zero_division=0
            )
            f1 = 100.0 * metrics.f1_score(
                all_targets, all_predictions, average="macro", zero_division=0
            )
            # Weighted metrics (accounts for class imbalance)
            precision_weighted = 100.0 * metrics.precision_score(
                all_targets, all_predictions, average="weighted", zero_division=0
            )
            recall_weighted = 100.0 * metrics.recall_score(
                all_targets, all_predictions, average="weighted", zero_division=0
            )
            f1_weighted = 100.0 * metrics.f1_score(
                all_targets, all_predictions, average="weighted", zero_division=0
            )
        else:
            # Single-class case: per-class averages are not meaningful
            precision = recall = f1 = 0.0
            precision_weighted = recall_weighted = f1_weighted = 0.0
        # Create confusion matrix
        confusion_mat = metrics.confusion_matrix(all_targets, all_predictions)
        # Return metrics in a dictionary
        metrics_dict = {
            "accuracy": accuracy,
            "precision_macro": precision,
            "recall_macro": recall,
            "f1_macro": f1,
            "precision_weighted": precision_weighted,
            "recall_weighted": recall_weighted,
            "f1_weighted": f1_weighted,
            "confusion_matrix": confusion_mat,
        }
        return avg_loss, metrics_dict
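Since test() returns the raw confusion matrix, labeling its rows and columns with class names makes it easier to read. A small helper sketch (print_confusion_matrix is a hypothetical name, not part of the original module):

def print_confusion_matrix(confusion_mat, class_names):
    """Pretty-print a confusion matrix with row/column class labels."""
    width = max(len(name) for name in class_names) + 2
    header = " " * width + "".join(f"{name:>{width}}" for name in class_names)
    print(header)
    for name, row in zip(class_names, confusion_mat):
        cells = "".join(f"{count:>{width}}" for count in row)
        print(f"{name:<{width}}{cells}")

# Usage:
#   loss, m = learner.test()
#   print_confusion_matrix(m["confusion_matrix"], learner.class_names)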
optimized_learner.py
    def inference(self, image_path: str) -> str:
        # Make sure the model is in evaluation mode
        self.model.eval()
        # Image preprocessing
        transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        # Load and preprocess the image; convert to RGB so grayscale or
        # RGBA inputs match the three-channel normalization
        image = Image.open(image_path).convert("RGB")
        image_tensor = transform(image).unsqueeze(0).to(self.device)
        # Make prediction
        with torch.no_grad():
            output = self.model(image_tensor)
            _, predicted = output.max(1)
            prediction_idx = predicted.item()
        # Return class name if available, otherwise return the index
        if self.class_names and prediction_idx < len(self.class_names):
            return self.class_names[prediction_idx]
        return str(prediction_idx)
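If ranked predictions are useful, a top-k variant of inference is straightforward. This is a hypothetical helper, not part of the original class; it reuses this module's existing imports (torch, transforms, Image):

def inference_topk(learner, image_path: str, k: int = 3):
    """Return the top-k (class_name, probability) pairs for one image."""
    learner.model.eval()
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    image = Image.open(image_path).convert("RGB")
    tensor = transform(image).unsqueeze(0).to(learner.device)
    with torch.no_grad():
        probs = torch.softmax(learner.model(tensor), dim=1)[0]
        top_probs, top_idx = probs.topk(k)
    names = learner.class_names or [str(i) for i in range(len(probs))]
    return [(names[i], p.item()) for i, p in zip(top_idx.tolist(), top_probs)]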
optimized_learner.py
    def _save_model(self, path: Optional[str] = None) -> None:
        save_path = path if path else os.path.join(self.work_dir, "best_model.pth")
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "best_acc": self.best_acc,
                "best_loss": self.best_loss,
            },
            save_path,
        )
optimized_learner.py
    def _load_best_model(self, path: Optional[str] = None) -> bool:
        load_path = path if path else os.path.join(self.work_dir, "best_model.pth")
        if os.path.exists(load_path):
            checkpoint = torch.load(load_path, map_location=self.device)
            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            self.best_acc = checkpoint.get("best_acc", 0)
            self.best_loss = checkpoint.get("best_loss", float("inf"))
            self.model.to(self.device)
            print(f"Model loaded from {load_path}")
            return True
        print(f"No model found at {load_path}")
        return False
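Both train.py and predict.py import create_data_loaders from this module, but the excerpts above do not show it. Here is a minimal sketch consistent with how it is called, returning (train_loader, val_loader, class_names) on success and (None, None, None) on failure; the train/val directory layout and the augmentation choices are assumptions:

def create_data_loaders(dataset_path: str, batch_size: int = 16, img_size: int = 224):
    """Build train/val DataLoaders from an ImageFolder-style dataset.

    Assumes dataset_path contains 'train' and 'val' subdirectories,
    each holding one folder per class.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_tf = transforms.Compose(
        [
            transforms.RandomResizedCrop(img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )
    val_tf = transforms.Compose(
        [
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            normalize,
        ]
    )
    try:
        train_ds = datasets.ImageFolder(os.path.join(dataset_path, "train"), train_tf)
        val_ds = datasets.ImageFolder(os.path.join(dataset_path, "val"), val_tf)
    except FileNotFoundError:
        return None, None, None
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4)
    return train_loader, val_loader, train_ds.classes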
train.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

from optimized_learner import Learner, create_data_loaders
def create_ensemble_model(num_classes):
    """Create a ResNet101-based model with a custom head for higher accuracy.

    Despite the name, this builds a single model, not an ensemble.
    """
    # Use ResNet101 for stronger feature extraction
    model = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
    # Freeze the first 24 parameter tensors (early layers) to reduce overfitting
    for name, param in list(model.named_parameters())[:24]:
        param.requires_grad = False
    # Replace the final fully connected layer with an improved classification
    # head. ResNet's forward already applies global average pooling and
    # flattening, so fc receives a (batch, num_features) tensor.
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.BatchNorm1d(num_features),
        nn.Dropout(0.4),
        nn.Linear(num_features, 1024),
        nn.LeakyReLU(0.2),
        nn.BatchNorm1d(1024),
        nn.Dropout(0.5),
        nn.Linear(1024, 512),
        nn.LeakyReLU(0.2),
        nn.BatchNorm1d(512),
        nn.Dropout(0.3),
        nn.Linear(512, num_classes),
    )
    return model
def main():
    # Define paths and hyperparameters
    dataset_path = "dataset"
    work_dir = "checkpoints"
    batch_size = 16  # Smaller batch size for better generalization
    epochs = 40  # More epochs for the deeper model
    learning_rate = 0.0002  # Lower learning rate for more stable training
    weight_decay = 5e-4  # Increased weight decay for better regularization
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Create data loaders
    train_loader, val_loader, class_names = create_data_loaders(
        dataset_path=dataset_path, batch_size=batch_size, img_size=224
    )
    if train_loader is None or val_loader is None:
        print("Failed to load dataset. Please check the dataset path.")
        return
    # Create the model
    model = create_ensemble_model(num_classes=len(class_names))
    # Collect only the parameters that require gradients (the unfrozen layers);
    # per-layer learning rates are sketched after this listing
    params_to_update = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
    # Define optimizer
    optimizer = optim.AdamW(
        params_to_update,
        lr=learning_rate,
        weight_decay=weight_decay,
        betas=(0.9, 0.999),
    )
    # Standard cross-entropy loss (class weights could be added for imbalanced data)
    criterion = nn.CrossEntropyLoss()
    # One-cycle learning rate scheduler for better convergence;
    # the Learner steps it once per batch
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=learning_rate * 10,
        steps_per_epoch=len(train_loader),
        epochs=epochs,
        pct_start=0.3,  # Spend 30% of the schedule warming up
    )
    # Create learner instance with mixup data augmentation
    learner = Learner(
        model=model,
        train_dataloader=train_loader,
        test_dataloader=val_loader,
        optimizer=optimizer,
        loss=criterion,
        scheduler=scheduler,
        work_dir=work_dir,
        device=device,
        mixup_alpha=0.4,  # Moderate mixup for robustness
        use_mixup=True,
    )
    # Train the model
    history = learner.train(epochs=epochs)
    # Load the best model for evaluation
    learner._load_best_model()
    # Test the model and print all metrics
    test_loss, metrics = learner.test()
    print("Final test results:")
    print(f"  Loss: {test_loss:.4f}")
    print(f"  Accuracy: {metrics['accuracy']:.2f}%")
    print(f"  Precision (macro): {metrics['precision_macro']:.2f}%")
    print(f"  Recall (macro): {metrics['recall_macro']:.2f}%")
    print(f"  F1 Score (macro): {metrics['f1_macro']:.2f}%")
    print(f"  Precision (weighted): {metrics['precision_weighted']:.2f}%")
    print(f"  Recall (weighted): {metrics['recall_weighted']:.2f}%")
    print(f"  F1 Score (weighted): {metrics['f1_weighted']:.2f}%")
    # Print confusion matrix
    print("\nConfusion Matrix:")
    print(metrics["confusion_matrix"])
    # Example inference
    try:
        prediction = learner.inference("./img_test/ant.jpg")
        print(f"Prediction: {prediction}")
    except FileNotFoundError:
        print("Test image not found. Skipping inference demonstration.")


if __name__ == "__main__":
    main()
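The parameter-collection comment in main() mentions per-layer learning rates, but all trainable parameters end up in one group with a single rate. If per-layer rates are wanted, AdamW accepts parameter groups; a hedged sketch of the change, meant to replace the single-group optimizer inside main() (the backbone/head split and the 10x ratio are assumptions):

# Separate pretrained backbone parameters from the new classification head
backbone_params = [
    p for n, p in model.named_parameters()
    if p.requires_grad and not n.startswith("fc")
]
head_params = [p for n, p in model.named_parameters() if n.startswith("fc")]

optimizer = optim.AdamW(
    [
        {"params": backbone_params, "lr": learning_rate / 10},  # gentle updates for pretrained layers
        {"params": head_params, "lr": learning_rate},           # full rate for the new head
    ],
    weight_decay=weight_decay,
)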
predict.py
import argparse
import os

import torch
import torchvision.transforms as transforms
from PIL import Image

from optimized_learner import create_data_loaders
from train import create_ensemble_model
def load_model(checkpoint_path, num_classes, device):
    # Create model with the same architecture used in training
    model = create_ensemble_model(num_classes=num_classes)
    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()
    print(f"Model loaded from {checkpoint_path}")
    print(f"Best accuracy: {checkpoint['best_acc']:.2f}%")
    return model
def predict_single_image(model, image_path, class_names, device):
    # Image preprocessing - matches the validation transform in create_data_loaders
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    # Load and preprocess the image
    try:
        image = Image.open(image_path).convert("RGB")
        image_tensor = transform(image).unsqueeze(0).to(device)
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return None, None
    # Make prediction
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        top_prob, top_class = torch.max(probabilities, 1)
    # Get class name and probability
    prediction_idx = top_class.item()
    probability = top_prob.item() * 100
    if class_names and prediction_idx < len(class_names):
        prediction = class_names[prediction_idx]
    else:
        prediction = str(prediction_idx)
    return prediction, probability
def batch_predict(model, image_folder, class_names, device):
    """
    Make predictions on all images in a folder.

    Args:
        model: Loaded PyTorch model
        image_folder: Path to folder containing images
        class_names: List of class names
        device: Device to perform inference on
    """
    supported_extensions = [".jpg", ".jpeg", ".png", ".bmp"]
    for filename in os.listdir(image_folder):
        file_path = os.path.join(image_folder, filename)
        if os.path.isfile(file_path) and any(
            filename.lower().endswith(ext) for ext in supported_extensions
        ):
            prediction, probability = predict_single_image(
                model, file_path, class_names, device
            )
            if prediction:
                print(
                    f"Image: {filename}, Prediction: {prediction}, "
                    f"Confidence: {probability:.2f}%"
                )
def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Predict images using a trained model")
    parser.add_argument(
        "--checkpoint",
        default="checkpoints/best_model.pth",
        help="Path to model checkpoint",
    )
    parser.add_argument(
        "--dataset_path", default="dataset", help="Path to dataset for class names"
    )
    parser.add_argument("--image", help="Path to single image for prediction")
    parser.add_argument(
        "--image_folder", help="Path to folder of images for batch prediction"
    )
    args = parser.parse_args()
    # Set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Create data loaders only to recover the class names
    _, _, class_names = create_data_loaders(
        dataset_path=args.dataset_path, batch_size=1, img_size=224
    )
    if not class_names:
        print("Failed to load class names from dataset. Using numeric indices instead.")
        class_names = None
    else:
        print(f"Classes: {class_names}")
    # Load model from checkpoint; default to 1000 ImageNet classes if no class names
    num_classes = len(class_names) if class_names else 1000
    model = load_model(args.checkpoint, num_classes, device)
    # Make predictions
    if args.image:
        # Single image prediction
        prediction, probability = predict_single_image(
            model, args.image, class_names, device
        )
        if prediction:
            print(f"Prediction: {prediction}")
            print(f"Confidence: {probability:.2f}%")
    elif args.image_folder:
        # Batch prediction on a folder of images
        print(f"Performing batch prediction on images in {args.image_folder}")
        batch_predict(model, args.image_folder, class_names, device)
    else:
        print("Please provide either --image or --image_folder argument")
        print("Example usage:")
        print("  python predict.py --image ./img_test/sample.jpg")
        print("  python predict.py --image_folder ./img_test")


if __name__ == "__main__":
    main()
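The overview mentions an interactive web interface, which these excerpts do not include. One possible way to wire predict_single_image into a small app is Gradio; everything below (the gradio dependency, the app layout) is an assumption, not something the original files show:

import gradio as gr
import torch
from optimized_learner import create_data_loaders
from predict import load_model, predict_single_image

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
_, _, class_names = create_data_loaders("dataset", batch_size=1, img_size=224)
model = load_model("checkpoints/best_model.pth", len(class_names), device)

def classify(image_path):
    prediction, probability = predict_single_image(model, image_path, class_names, device)
    return f"{prediction} ({probability:.2f}%)"

demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="filepath"),  # Gradio passes a temp file path to classify()
    outputs=gr.Textbox(label="Prediction"),
    title="Image Classifier",
)

if __name__ == "__main__":
    demo.launch()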