
Implementation of a CNN based Image Classifier using PyTorch

Last Updated : 25 Feb, 2022

Introduction:

Convolutional Neural Networks (also called CNNs or ConvNets), introduced in the 1980s by Yann LeCun, have come a long way. Once employed for simple digit classification, CNN-based architectures are now used widely across Deep Learning and Computer Vision tasks such as object detection, image segmentation, and gaze tracking. Using the PyTorch framework, this article implements a CNN-based image classifier on the popular CIFAR-10 dataset.

Before going ahead with the code and installation, the reader is expected to understand how CNNs work in theory, including related operations such as convolution and pooling. The article also assumes basic familiarity with the PyTorch workflow and its utilities, such as DataLoaders, Datasets, tensor transforms, and CUDA operations. For a quick refresher on these concepts, the reader is encouraged to go through the following articles:

  • Introduction to Convolutional Neural Network
  • Training Neural Networks with Validation using PyTorch
  • How to set up and Run CUDA Operations in Pytorch?

Installation

For the implementation of the CNN and downloading the CIFAR-10 dataset, we'll be requiring the torch and torchvision modules. Apart from that, we'll be using numpy and matplotlib for data analysis and plotting. The required libraries can be installed using the pip package manager through the following command:

    pip install torch torchvision numpy matplotlib

Stepwise implementation

Step 1: Downloading data and printing some sample images from the training set.

  • Before implementing the CNN, we first need to download the dataset we'll be training on. We use the torchvision utility for this and download the CIFAR-10 training and test sets into the directories "./CIFAR10/train" and "./CIFAR10/test", respectively. We also apply a normalizing transform that converts each image to a tensor and normalizes each of its three channels with a mean of 0.5 and a standard deviation of 0.5.
  • We now have a training dataset of 50000 images and a test dataset of 10000 images, each of dimension 32x32x3. We then wrap these datasets in data loaders with a batch size of 128, so the model is trained on mini-batches rather than on one image at a time.
  • Finally, we plot some sample images from the first training batch, using the make_grid utility from torchvision, to get an idea of the images we're dealing with.

Code:

Python3
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np

# The below two lines are optional and are just there to avoid any SSL
# related errors while downloading the CIFAR-10 dataset
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

#Defining plotting settings
plt.rcParams['figure.figsize'] = 14, 6

#Initializing normalizing transform for the dataset
normalize_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean = (0.5, 0.5, 0.5), 
                                     std = (0.5, 0.5, 0.5))])

#Downloading the CIFAR10 dataset into train and test sets
train_dataset = torchvision.datasets.CIFAR10(
    root="./CIFAR10/train", train=True,
    transform=normalize_transform,
    download=True)
  
test_dataset = torchvision.datasets.CIFAR10(
    root="./CIFAR10/test", train=False,
    transform=normalize_transform,
    download=True)
  
#Generating data loaders from the corresponding datasets
batch_size = 128
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

#Plotting 25 images from the 1st batch 
dataiter = iter(train_loader)
images, labels = next(dataiter)
plt.imshow(np.transpose(torchvision.utils.make_grid(
  images[:25], normalize=True, padding=1, nrow=5).numpy(), (1, 2, 0)))
plt.axis('off')

Output:

Figure 1: Some sample images from the training dataset
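
The normalizing transform in Step 1 subtracts 0.5 from every channel and divides by 0.5, so pixel values that ToTensor() has scaled to [0, 1] end up in [-1, 1]. The sketch below reproduces that arithmetic and its inverse in plain Python. The helper names normalize and unnormalize are illustrative and are not part of torchvision.

```python
# Per-channel statistics passed to Normalize() in Step 1
MEAN, STD = 0.5, 0.5

def normalize(x, mean=MEAN, std=STD):
    """Same arithmetic that Normalize applies to each pixel value."""
    return (x - mean) / std

def unnormalize(z, mean=MEAN, std=STD):
    """Inverse mapping, useful before displaying an image."""
    return z * std + mean

# ToTensor() yields values in [0, 1]; check both ends and the midpoint
samples = [0.0, 0.5, 1.0]
normalized = [normalize(x) for x in samples]
print(normalized)                      # [-1.0, 0.0, 1.0]
restored = [unnormalize(z) for z in normalized]
```

In the plotting code, make_grid(..., normalize=True) plays a similar role to unnormalize: it rescales the grid back into [0, 1] so that matplotlib can display it.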

Step 2: Plotting class distribution of the dataset

It's generally a good idea to plot out the class distribution of the training set. This helps in checking whether the provided dataset is balanced or not. To do this, we iterate over the entire training set in batches and collect the respective classes of each instance. Finally, we calculate the counts of the unique classes and plot them.

Code:

Python3
#Iterating over the training dataset and storing the target class for each sample
classes = []
for batch_idx, data in enumerate(train_loader, 0):
    x, y = data 
    classes.extend(y.tolist())
    
#Calculating the unique classes and the respective counts and plotting them
unique, counts = np.unique(classes, return_counts=True)
names = list(test_dataset.class_to_idx.keys())
plt.bar(names, counts)
plt.xlabel("Target Classes")
plt.ylabel("Number of training instances")

Output:

Figure 2: Class distribution of the training set

As shown in Figure 2, each of the ten classes has the same number of training samples (CIFAR-10 provides 5,000 training images per class). Thus we don't need to take additional steps to rebalance the dataset.
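
The counting step can also be sketched without NumPy. The snippet below uses collections.Counter on a small made-up label list (toy_labels is hypothetical data, not CIFAR-10) and summarizes balance as the ratio between the largest and smallest class counts.

```python
from collections import Counter

# Hypothetical labels standing in for the `classes` list built above
toy_labels = [0, 1, 2, 0, 1, 2, 0, 1, 2, 2]

counts = Counter(toy_labels)           # class index -> number of samples
largest = max(counts.values())
smallest = min(counts.values())

# A ratio close to 1 means the classes are roughly balanced
imbalance_ratio = largest / smallest
print(dict(sorted(counts.items())), round(imbalance_ratio, 2))
```

For a perfectly balanced dataset this ratio is exactly 1; how large a ratio warrants rebalancing is a judgment call that depends on the task.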

Step 3: Implementing the CNN architecture

On the architecture side, we'll be using a simple model that employs three convolution layers with depths 32, 64, and 64, respectively, followed by two fully connected layers for performing classification. 

  • Each convolutional block applies a 3x3 convolution filter, followed by a ReLU activation to introduce nonlinearity and a max-pooling operation with a 2x2 filter to reduce the spatial dimensions of the feature map.
  • After the convolutional blocks, we flatten the 64x4x4 feature map into a one-dimensional vector of 1024 values that feeds the classification layers. The first linear layer has 512 units, and the output layer (also a linear layer) has ten neurons, one for each of the ten classes in our dataset.

The architecture is as follows:

Figure 3: Architecture of the CNN
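
The feature-map sizes in this architecture follow from the standard output-size formula for convolution and pooling layers, out = floor((in + 2*padding - kernel) / stride) + 1. The sketch below traces a 32x32 input through the three blocks described above. The helper out_size is illustrative, and it assumes a stride of 1 for the convolutions and a stride of 2 for the 2x2 pooling, which are the PyTorch defaults for the layers used in this article.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

size = 32                      # CIFAR-10 images are 32 x 32
trace = [size]
for _ in range(3):             # three conv + pool blocks
    size = out_size(size, kernel=3, stride=1, padding=1)   # 3x3 conv, padding 1
    trace.append(size)
    size = out_size(size, kernel=2, stride=2)               # 2x2 max-pooling
    trace.append(size)

flattened = 64 * size * size   # 64 channels after the last block
print(trace, flattened)        # [32, 32, 16, 16, 8, 8, 4] 1024
```

Without padding, out_size(32, kernel=3) gives 30 rather than 32, so every convolution would shrink the feature map by two pixels in each dimension.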

To build our model, we'll define a CNN class that inherits from torch.nn.Module to take advantage of the PyTorch utilities. We'll also use the torch.nn.Sequential container to chain our layers one after the other.

  • The Conv2d(), ReLU(), and MaxPool2d() layers perform the convolution, activation, and pooling operations. We use a padding of 1 so that each 3x3 convolution preserves the spatial size of its input and the kernel also covers the pixels at the border of the image.
  • After the convolutional blocks, the Linear() fully connected layers perform classification.

Code:

Python3
class CNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
            #Input = 3 x 32 x 32, Output = 32 x 32 x 32
            torch.nn.Conv2d(in_channels = 3, out_channels = 32, kernel_size = 3, padding = 1), 
            torch.nn.ReLU(),
            #Input = 32 x 32 x 32, Output = 32 x 16 x 16
            torch.nn.MaxPool2d(kernel_size=2),

            #Input = 32 x 16 x 16, Output = 64 x 16 x 16
            torch.nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 3, padding = 1),
            torch.nn.ReLU(),
            #Input = 64 x 16 x 16, Output = 64 x 8 x 8
            torch.nn.MaxPool2d(kernel_size=2),
            
            #Input = 64 x 8 x 8, Output = 64 x 8 x 8
            torch.nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, padding = 1),
            torch.nn.ReLU(),
            #Input = 64 x 8 x 8, Output = 64 x 4 x 4
            torch.nn.MaxPool2d(kernel_size=2),

            torch.nn.Flatten(),
            torch.nn.Linear(64*4*4, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 10)
        )

    def forward(self, x):
        return self.model(x)
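
To get a feel for the size of this model, the number of learnable parameters can be worked out by hand. The sketch below assumes every layer keeps its bias term (the PyTorch default); the helpers conv_params and linear_params are illustrative names, not library functions.

```python
def conv_params(in_ch, out_ch, kernel):
    """Weights (kernel x kernel x in_ch per filter) plus one bias per filter."""
    return (kernel * kernel * in_ch + 1) * out_ch

def linear_params(in_features, out_features):
    """One weight per input-output pair plus one bias per output."""
    return (in_features + 1) * out_features

layers = {
    "conv1": conv_params(3, 32, 3),
    "conv2": conv_params(32, 64, 3),
    "conv3": conv_params(64, 64, 3),
    "fc1": linear_params(64 * 4 * 4, 512),
    "fc2": linear_params(512, 10),
}
total = sum(layers.values())
print(layers, total)
```

Most of the parameters sit in the first fully connected layer. The total should agree with sum(p.numel() for p in model.parameters()) on an instance of the CNN class above.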

Step 4: Defining the training parameters and beginning the training process

We begin the training process by selecting the device to train our model on, i.e., a CPU or a GPU. Then, we define our model hyperparameters, which are as follows:

  • We train our model for 50 epochs, and since this is a multiclass problem, we use the Cross-Entropy Loss as our objective function.
  • To optimize the objective function, we use the popular Adam optimizer with a learning rate of 0.001 and a weight_decay of 0.01, which adds L2 regularization to reduce overfitting.
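
CrossEntropyLoss takes the raw outputs (logits) of the last linear layer, applies a log-softmax, and returns the negative log-probability of the true class. The plain-Python sketch below computes this for a single sample; the function cross_entropy is an illustrative re-implementation, not the PyTorch code.

```python
import math

def cross_entropy(logits, target):
    """Negative log-softmax probability of the target class for one sample."""
    m = max(logits)                                   # subtract max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]

# A 10-class model that outputs equal logits is maximally unsure
uniform_loss = cross_entropy([0.0] * 10, target=3)
print(round(uniform_loss, 4))          # 2.3026, i.e. ln(10)

# A confident, correct prediction gives a loss close to zero
confident_loss = cross_entropy([8.0] + [0.0] * 9, target=0)
```

The value ln(10) ≈ 2.30 is a useful reference point: a ten-class model that has learned nothing yet should start with a training loss near it.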

Finally, we begin our training loop, which computes the outputs for each batch, calculates the loss by comparing the predictions with the true labels, and updates the weights. At the end, we plot the training loss for each epoch to check that training went as planned.

Code:

Python3
#Selecting the appropriate training device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CNN().to(device)

#Defining the model hyper parameters
num_epochs = 50
learning_rate = 0.001
weight_decay = 0.01
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

#Training process begins
train_loss_list = []
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}:', end = ' ')
    train_loss = 0
    
    #Iterating over the training dataset in batches
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        
        #Extracting images and target labels for the batch being iterated
        images = images.to(device)
        labels = labels.to(device)

        #Calculating the model output and the cross entropy loss
        outputs = model(images)
        loss = criterion(outputs, labels)

        #Updating weights according to calculated loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    
    #Printing loss for each epoch
    train_loss_list.append(train_loss/len(train_loader))
    print(f"Training loss = {train_loss_list[-1]}")   
    
#Plotting loss for all epochs
plt.plot(range(1,num_epochs+1), train_loss_list)
plt.xlabel("Number of epochs")
plt.ylabel("Training loss")

Output:

Figure 4: Plot of training loss vs. number of epochs

From Figure 4, we can see that the training loss decreases as the number of epochs increases, indicating that the model is fitting the training data.
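
The per-epoch value plotted here is the sum of the batch losses divided by len(train_loader), which is the number of batches and not the number of images. The sketch below works out that count for this dataset and then uses made-up loss values (batch_means is hypothetical) to show how a mean of batch means can differ from a per-sample mean when the last batch is smaller.

```python
import math

num_samples, batch_size = 50_000, 128

num_batches = math.ceil(num_samples / batch_size)        # what len(train_loader) returns
last_batch = num_samples - (num_batches - 1) * batch_size
print(num_batches, last_batch)                           # 391 80

# Toy illustration: averaging batch means gives the short last batch extra weight
batch_means = [1.0, 1.0, 4.0]          # hypothetical per-batch mean losses
batch_sizes = [128, 128, 80]
mean_of_means = sum(batch_means) / len(batch_means)
per_sample_mean = sum(m * n for m, n in zip(batch_means, batch_sizes)) / sum(batch_sizes)
```

With 391 batches per epoch, the effect of one short batch on the reported loss is small, so the simpler average used in the training loop is a reasonable choice for monitoring.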

Step 5: Calculating the model's accuracy on the test set

Now that our model's trained, we need to check its performance on the test set. To do that, we iterate over the entire test set in batches and calculate the accuracy score by comparing the true and predicted labels for each batch. 

Code:

Python3
test_acc=0
model.eval()

with torch.no_grad():
    #Iterating over the test dataset in batches
    for i, (images, labels) in enumerate(test_loader):
        
        images = images.to(device)
        y_true = labels.to(device)
        
        #Calculating outputs for the batch being iterated
        outputs = model(images)
        
        #Calculating prediction labels from the model outputs
        _, y_pred = torch.max(outputs, 1)
        
        #Comparing predicted and true labels
        test_acc += (y_pred == y_true).sum().item()
    
    print(f"Test set accuracy = {100 * test_acc / len(test_dataset)} %")

Output:

Figure 5: Accuracy on the test set
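
The accuracy computation above counts correct predictions batch by batch and divides by the dataset size at the end. The same bookkeeping can be sketched in plain Python on hypothetical logits for a three-class problem; the argmax helper and the batches data are made up for illustration.

```python
def argmax(values):
    """Index of the largest value, the plain-Python analogue of torch.max(..., 1)."""
    return max(range(len(values)), key=values.__getitem__)

# Hypothetical logits for two batches of a 3-class problem, with true labels
batches = [
    ([[2.0, 0.1, 0.3], [0.2, 1.5, 0.1]], [0, 1]),
    ([[0.1, 0.2, 3.0], [1.0, 0.9, 0.2]], [2, 1]),
]

correct, total = 0, 0
for logits, labels in batches:
    preds = [argmax(row) for row in logits]
    correct += sum(p == t for p, t in zip(preds, labels))
    total += len(labels)

accuracy = 100 * correct / total
print(f"Accuracy = {accuracy} %")      # Accuracy = 75.0 %
```

Accumulating counts rather than averaging per-batch accuracies keeps the result exact even when the last batch is smaller than the others.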

Step 6: Generating predictions for sample images in the test set

As shown in Figure 5, our model has achieved an accuracy of nearly 72%. As a further sanity check, we can generate predictions for some sample images. To do that, we take the first five images of the last batch of the test set (the images, y_true, and y_pred variables still hold that batch after the loop in Step 5) and plot them using the make_grid utility from torchvision. We then collect their true labels and predictions from the model and show them in the plot's title.

Code:

Python3
#Generating predictions for 'num_images' amount of images from the last batch of test set
num_images = 5
y_true_name = [names[y_true[idx]] for idx in range(num_images)] 
y_pred_name = [names[y_pred[idx]] for idx in range(num_images)] 

#Generating the title for the plot
title = f"Actual labels: {y_true_name}, Predicted labels: {y_pred_name}"

#Finally plotting the images with their actual and predicted labels in the title
plt.imshow(np.transpose(torchvision.utils.make_grid(images[:num_images].cpu(), normalize=True, padding=1).numpy(), (1, 2, 0)))
plt.title(title)
plt.axis("off")

Output:

Figure 6: Actual vs. Predicted labels for 5 sample images from the test set. Note that the labels are in the same order as the respective images, from left to right.

As can be seen from Figure 6, the model is producing correct predictions for all the images except the 2nd one as it misclassifies the dog as a cat!
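
The index-to-name lookup behind the plot title can be sketched on its own. The class names below are the ten CIFAR-10 classes in label order; the true_idx and pred_idx values are hypothetical and only illustrate how a mismatch shows up.

```python
# CIFAR-10 class names in label-index order (0 to 9)
cifar10_names = ["airplane", "automobile", "bird", "cat", "deer",
                 "dog", "frog", "horse", "ship", "truck"]

# Hypothetical label indices for five images
true_idx = [3, 5, 8, 0, 6]
pred_idx = [3, 3, 8, 0, 6]

# One (actual, predicted, is_correct) row per image
rows = [(cifar10_names[t], cifar10_names[p], t == p)
        for t, p in zip(true_idx, pred_idx)]
for actual, predicted, is_correct in rows:
    print(f"actual={actual:<10} predicted={predicted:<10} correct={is_correct}")
```

Printing one row per image is easier to read than a long plot title once more than a handful of images are shown.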

Conclusion:

This article covered the PyTorch implementation of a simple CNN on the popular CIFAR-10 dataset. The reader is encouraged to play around with the network architecture and model hyperparameters to increase the model accuracy even more!

author
adityasaini70
Improve
Article Tags :
  • Python Programs
  • Machine Learning
  • AI-ML-DS
  • Python-PyTorch
  • Artificial Intelligence
  • python
  • Deep-Learning
Practice Tags :
  • Machine Learning
  • python

Similar Reads

    Deep Learning Tutorial
    Deep Learning tutorial covers the basics and more advanced topics, making it perfect for beginners and those with experience. Whether you're just starting or looking to expand your knowledge, this guide makes it easy to learn about the different technologies of Deep Learning.Deep Learning is a branc
    5 min read

    Introduction to Deep Learning

    Introduction to Deep Learning
    Deep Learning is transforming the way machines understand, learn and interact with complex data. Deep learning mimics neural networks of the human brain, it enables computers to autonomously uncover patterns and make informed decisions from vast amounts of unstructured data. How Deep Learning Works?
    7 min read
    Difference Between Artificial Intelligence vs Machine Learning vs Deep Learning
    Artificial Intelligence is basically the mechanism to incorporate human intelligence into machines through a set of rules(algorithm). AI is a combination of two words: "Artificial" meaning something made by humans or non-natural things and "Intelligence" meaning the ability to understand or think ac
    14 min read

    Basic Neural Network

    Difference between ANN and BNN
    Do you ever think of what it's like to build anything like a brain, how these things work, or what they do? Let us look at how nodes communicate with neurons and what are some differences between artificial and biological neural networks. 1. Artificial Neural Network: Artificial Neural Network (ANN)
    3 min read
    Single Layer Perceptron in TensorFlow
    Single Layer Perceptron is inspired by biological neurons and their ability to process information. To understand the SLP we first need to break down the workings of a single artificial neuron which is the fundamental building block of neural networks. An artificial neuron is a simplified computatio
    4 min read
    Multi-Layer Perceptron Learning in Tensorflow
    Multi-Layer Perceptron (MLP) consists of fully connected dense layers that transform input data from one dimension to another. It is called multi-layer because it contains an input layer, one or more hidden layers and an output layer. The purpose of an MLP is to model complex relationships between i
    6 min read
    Deep Neural net with forward and back propagation from scratch - Python
    This article aims to implement a deep neural network from scratch. We will implement a deep neural network containing two input layers, a hidden layer with four units and one output layer. The implementation will go from scratch and the following steps will be implemented. Algorithm:1. Loading and v
    6 min read
    Understanding Multi-Layer Feed Forward Networks
    Let's understand how errors are calculated and weights are updated in backpropagation networks(BPNs). Consider the following network in the below figure. Backpropagation Network (BPN) The network in the above figure is a simple multi-layer feed-forward network or backpropagation network. It contains
    7 min read
    List of Deep Learning Layers
    Deep learning (DL) is characterized by the use of neural networks with multiple layers to model and solve complex problems. Each layer in the neural network plays a unique role in the process of converting input data into meaningful and insightful outputs. The article explores the layers that are us
    7 min read

    Activation Functions

    Activation Functions
    To put it in simple terms, an artificial neuron calculates the 'weighted sum' of its inputs and adds a bias, as shown in the figure below by the net input. Mathematically, \text{Net Input} =\sum \text{(Weight} \times \text{Input)+Bias} Now the value of net input can be any anything from -inf to +inf
    3 min read
    Types Of Activation Function in ANN
    The biological neural network has been modeled in the form of Artificial Neural Networks with artificial neurons simulating the function of a biological neuron. The artificial neuron is depicted in the below picture:Structure of an Artificial NeuronEach neuron consists of three major components: A s
    3 min read
    Activation Functions in Pytorch
    In this article, we will Understand PyTorch Activation Functions. What is an activation function and why to use them?Activation functions are the building blocks of Pytorch. Before coming to types of activation function, let us first understand the working of neurons in the human brain. In the Artif
    5 min read
    Understanding Activation Functions in Depth
    In artificial neural networks, the activation function of a neuron determines its output for a given input. This output serves as the input for subsequent neurons in the network, continuing the process until the network solves the original problem. Consider a binary classification problem, where the
    6 min read

    Artificial Neural Network

    Artificial Neural Networks and its Applications
    As you read this article, which organ in your body is thinking about it? It's the brain, of course! But do you know how the brain works? Well, it has neurons or nerve cells that are the primary units of both the brain and the nervous system. These neurons receive sensory input from the outside world
    9 min read
    Gradient Descent Optimization in Tensorflow
    Gradient descent is an optimization algorithm used to find the values of parameters (coefficients) of a function (f) that minimizes a cost function. In other words, gradient descent is an iterative algorithm that helps to find the optimal solution to a given problem.In this blog, we will discuss gra
    15+ min read
    Choose Optimal Number of Epochs to Train a Neural Network in Keras
    One of the critical issues while training a neural network on the sample data is Overfitting. When the number of epochs used to train a neural network model is more than necessary, the training model learns patterns that are specific to sample data to a great extent. This makes the model incapable t
    6 min read

    Classification

    Python | Classify Handwritten Digits with Tensorflow
    Classifying handwritten digits is the basic problem of the machine learning and can be solved in many ways here we will implement them by using TensorFlowUsing a Linear Classifier Algorithm with tf.contrib.learn linear classifier achieves the classification of handwritten digits by making a choice b
    4 min read
    Train a Deep Learning Model With Pytorch
    Neural Network is a type of machine learning model inspired by the structure and function of human brain. It consists of layers of interconnected nodes called neurons which process and transmit information. Neural networks are particularly well-suited for tasks such as image and speech recognition,
    6 min read

    Regression

    Linear Regression using PyTorch
    Linear Regression is a very commonly used statistical method that allows us to determine and study the relationship between two continuous variables. The various properties of linear regression and its Python implementation have been covered in this article previously. Now, we shall find out how to
    4 min read
    Linear Regression Using Tensorflow
    We will briefly summarize Linear Regression before implementing it using TensorFlow. Since we will not get into the details of either Linear Regression or Tensorflow, please read the following articles for more details: Linear Regression (Python Implementation)Introduction to TensorFlowIntroduction
    6 min read

    Hyperparameter tuning

    Hyperparameter Tuning
    Hyperparameter tuning is the process of selecting the optimal values for a machine learning model's hyperparameters. These are typically set before the actual training process begins and control aspects of the learning process itself. They influence the model's performance its complexity and how fas
    7 min read

    Introduction to Convolution Neural Network

    Introduction to Convolution Neural Network
    Convolutional Neural Network (CNN) is an advanced version of artificial neural networks (ANNs), primarily designed to extract features from grid-like matrix datasets. This is particularly useful for visual datasets such as images or videos, where data patterns play a crucial role. CNNs are widely us
    8 min read
    Digital Image Processing Basics
    Digital Image Processing means processing digital image by means of a digital computer. We can also say that it is a use of computer algorithms, in order to get enhanced image either to extract some useful information. Digital image processing is the use of algorithms and mathematical models to proc
    7 min read
    Difference between Image Processing and Computer Vision
    Image processing and Computer Vision both are very exciting field of Computer Science. Computer Vision: In Computer Vision, computers or machines are made to gain high-level understanding from the input digital images or videos with the purpose of automating tasks that the human visual system can do
    2 min read
    CNN | Introduction to Pooling Layer
    Pooling layer is used in CNNs to reduce the spatial dimensions (width and height) of the input feature maps while retaining the most important information. It involves sliding a two-dimensional filter over each channel of a feature map and summarizing the features within the region covered by the fi
    5 min read
    CIFAR-10 Image Classification in TensorFlow
    Prerequisites:Image ClassificationConvolution Neural Networks including basic pooling, convolution layers with normalization in neural networks, and dropout.Data Augmentation.Neural Networks.Numpy arrays.In this article, we are going to discuss how to classify images using TensorFlow. Image Classifi
    8 min read
    Implementation of a CNN based Image Classifier using PyTorch
    Introduction: Introduced in the 1980s by Yann LeCun, Convolution Neural Networks(also called CNNs or ConvNets) have come a long way. From being employed for simple digit classification tasks, CNN-based architectures are being used very profoundly over much Deep Learning and Computer Vision-related t
    9 min read
    Convolutional Neural Network (CNN) Architectures
    Convolutional Neural Network(CNN) is a neural network architecture in Deep Learning, used to recognize the pattern from structured arrays. However, over many years, CNN architectures have evolved. Many variants of the fundamental CNN Architecture This been developed, leading to amazing advances in t
    11 min read
    Object Detection vs Object Recognition vs Image Segmentation
    Object Recognition: Object recognition is the technique of identifying the object present in images and videos. It is one of the most important applications of machine learning and deep learning. The goal of this field is to teach machines to understand (recognize) the content of an image just like
    5 min read
    YOLO v2 - Object Detection
    In terms of speed, YOLO is one of the best models in object recognition, able to recognize objects and process frames at the rate up to 150 FPS for small networks. However, In terms of accuracy mAP, YOLO was not the state of the art model but has fairly good Mean average Precision (mAP) of 63% when
    7 min read

    Recurrent Neural Network

    Natural Language Processing (NLP) Tutorial
    Natural Language Processing (NLP) is the branch of Artificial Intelligence (AI) that gives the ability to machine understand and process human languages. Human languages can be in the form of text or audio format.Applications of NLPThe applications of Natural Language Processing are as follows:Voice
    5 min read
    Introduction to NLTK: Tokenization, Stemming, Lemmatization, POS Tagging
    Natural Language Toolkit (NLTK) is one of the largest Python libraries for performing various Natural Language Processing tasks. From rudimentary tasks such as text pre-processing to tasks like vectorized representation of text - NLTK's API has covered everything. In this article, we will accustom o
    5 min read
    Word Embeddings in NLP
    Word Embeddings are numeric representations of words in a lower-dimensional space, that capture semantic and syntactic information. They play a important role in Natural Language Processing (NLP) tasks. Here, we'll discuss some traditional and neural approaches used to implement Word Embeddings, suc
    14 min read
    Introduction to Recurrent Neural Networks
    Recurrent Neural Networks (RNNs) differ from regular neural networks in how they process information. While standard neural networks pass information in one direction i.e from input to output, RNNs feed information back into the network at each step.Imagine reading a sentence and you try to predict
    10 min read
    Recurrent Neural Networks Explanation
    Today, different Machine Learning techniques are used to handle different types of data. One of the most difficult types of data to handle and the forecast is sequential data. Sequential data is different from other types of data in the sense that while all the features of a typical dataset can be a
    8 min read
    Sentiment Analysis with an Recurrent Neural Networks (RNN)
    Recurrent Neural Networks (RNNs) are used in sequence tasks such as sentiment analysis due to their ability to capture context from sequential data. In this article we will be apply RNNs to analyze the sentiment of customer reviews from Swiggy food delivery platform. The goal is to classify reviews
    5 min read
    Short term Memory
    In the wider community of neurologists and those who are researching the brain, It is agreed that two temporarily distinct processes contribute to the acquisition and expression of brain functions. These variations can result in long-lasting alterations in neuron operations, for instance through act
    5 min read
    What is LSTM - Long Short Term Memory?
    Long Short-Term Memory (LSTM) is an enhanced version of the Recurrent Neural Network (RNN) designed by Hochreiter and Schmidhuber. LSTMs can capture long-term dependencies in sequential data making them ideal for tasks like language translation, speech recognition and time series forecasting. Unlike
    5 min read
    Long Short Term Memory Networks Explanation
    Prerequisites: Recurrent Neural Networks To solve the problem of Vanishing and Exploding Gradients in a Deep Recurrent Neural Network, many variations were developed. One of the most famous of them is the Long Short Term Memory Network(LSTM). In concept, an LSTM recurrent unit tries to "remember" al
    7 min read
    LSTM - Derivation of Back propagation through time
    Long Short-Term Memory (LSTM) are a type of neural network designed to handle long-term dependencies by handling the vanishing gradient problem. One of the fundamental techniques used to train LSTMs is Backpropagation Through Time (BPTT) where we have sequential data. In this article we see how BPTT
    4 min read
    Text Generation using Recurrent Long Short Term Memory Network
    LSTMs are a type of neural network that are well-suited for tasks involving sequential data such as text generation. They are particularly useful because they can remember long-term dependencies in the data which is crucial when dealing with text that often has context that spans over multiple words
    4 min read
geeksforgeeks-footer-logo
Corporate & Communications Address:
A-143, 7th Floor, Sovereign Corporate Tower, Sector- 136, Noida, Uttar Pradesh (201305)
Registered Address:
K 061, Tower K, Gulshan Vivante Apartment, Sector 137, Noida, Gautam Buddh Nagar, Uttar Pradesh, 201305
GFG App on Play Store GFG App on App Store
Advertise with us
@GeeksforGeeks, Sanchhaya Education Private Limited, All rights reserved