Object Detection in React Native App using TensorFlow.js

In this post, we are going to build a React Native app for detecting objects in an image using TensorFlow.js.

TensorFlow.js is a JavaScript library for training and deploying machine learning models in the browser and in Node.js. It provides many pre-trained models that ease the time-consuming task of training a new machine learning model from scratch.

Solution Architecture:

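An image captured with the camera or selected from the file system is sent to API Gateway, which triggers a Lambda function. The Lambda function stores the image in an S3 bucket and returns the stored image's URL. The app then downloads the image from that URL and classifies it with the MobileNet model.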

Installing Dependencies:

Go to the React Native docs and select React Native CLI Quickstart. As we are building an Android application, select your Development OS and set the Target OS to Android.

Follow the docs for installing dependencies, then create a new React Native Application. Use the command-line interface to generate a new React Native project called ObjectDetection.

npx react-native init ObjectDetection

Preparing the Android device:

We need an Android device to run our React Native Android app. If you have a physical Android device, you can use it for development by connecting it to your computer with a USB cable and following the Running On Device instructions in the React Native docs.

Now go to the command line and run the following commands to enter your React Native app directory and start the app.

cd ObjectDetection
npx react-native run-android

If everything is set up correctly, you should see your new app running on your physical device. Next, we need to install the react-native-image-picker package to capture or select an image. To install the package run the following command inside the project directory.

npm install react-native-image-picker --save

We also need a few other packages. To install them, run the following commands inside the project directory.

npm install expo-permissions --save
npm install expo-constants --save
npm install jpeg-js --save
npm install aws-amplify --save

The expo-permissions package lets us prompt the user for permission to access device sensors, the camera, the camera roll, and so on.

The expo-constants package provides system information that remains constant throughout the lifetime of the app.

The jpeg-js package is used to decode the data from the image.

The aws-amplify package provides the API module that App.js uses to call the API Gateway endpoint.

Integrating TensorFlow.js in our React Native App:

Follow the setup instructions for the @tensorflow/tfjs-react-native package to integrate TensorFlow.js into our React Native app. After that, we must also install @tensorflow-models/mobilenet. To install it, run the following command inside the project directory.

npm install @tensorflow-models/mobilenet --save

We also need to set up an API in the AWS console and create a Lambda function that stores the image in an S3 bucket and returns the stored image URL.

API Creation in AWS Console:

Before going further, create an API in your AWS console by following the Working with API Gateway section of the following post:

https://medium.com/zenofai/serverless-web-application-architecture-using-react-with-amplify-part1-5b4d89f384f7

Once you have created the API, come back to the React Native application. Go to your project directory and replace your App.js file with the following code.

App.js

import React from 'react';
import {
  StyleSheet,
  Text,
  View,
  ScrollView,
  TouchableHighlight,
  Image
} from 'react-native';
import * as tf from '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';
import { fetch } from '@tensorflow/tfjs-react-native';
import Constants from 'expo-constants';
import * as Permissions from 'expo-permissions';
import * as jpeg from 'jpeg-js';
import ImagePicker from "react-native-image-picker";
import Amplify, { API } from "aws-amplify";

Amplify.configure({
  API: {
    endpoints: [
      {
        name: "<Your-API-Name>",
        endpoint: "<Your-API-Endpoint>"
      }
    ]
  }
});

class App extends React.Component {

  state = {
    isTfReady: false,
    isModelReady: false,
    predictions: null,
    image: null,
    base64String: '',
    capturedImage: '',
    imageSubmitted: false,
    s3ImageUrl: ''
  }

  async componentDidMount() {
    // Wait for tf to be ready.
    await tf.ready();
    // Signal to the app that tensorflow.js can now be used.
    this.setState({
      isTfReady: true
    });
    this.model = await mobilenet.load();
    this.setState({ isModelReady: true });
    this.askCameraPermission();
  }

  askCameraPermission = async () => {
    if (Constants.platform.android) {
      const { status } = await Permissions.askAsync(Permissions.CAMERA_ROLL);
      if (status !== 'granted') {
        alert('Please provide camera roll permissions to make this work!');
      }
    }
  }

  imageToTensor(rawImageData) {
    const TO_UINT8ARRAY = true;
    const { width, height, data } = jpeg.decode(rawImageData, TO_UINT8ARRAY);
    // Drop the alpha channel info for mobilenet
    const buffer = new Uint8Array(width * height * 3);
    let offset = 0 ; // offset into original data
    for (let i = 0; i < buffer.length; i += 3) {
      buffer[i] = data[offset];
      buffer[i + 1] = data[offset + 1];
      buffer[i + 2] = data[offset + 2];

      offset += 4;
    }

    return tf.tensor3d(buffer, [height, width, 3]);
  }

  classifyImage = async () => {
    try {
      const imageAssetPath = this.state.s3ImageUrl;
      const response = await fetch(imageAssetPath, {}, { isBinary: true });
      const rawImageData = await response.arrayBuffer();
      const imageTensor = this.imageToTensor(rawImageData);
      const predictions = await this.model.classify(imageTensor);
      this.setState({ predictions });
    } catch (error) {
      console.log(error);
    }
  }
  
  renderPrediction = prediction => {
    return (
      <Text key={prediction.className} style={styles.text}>
        {prediction.className}
      </Text>
    )
  }

  captureImageButtonHandler = () => {
    this.setState({
      imageSubmitted: false,
      predictions: null
    });
    ImagePicker.showImagePicker({ title: "Pick an Image", maxWidth: 800, maxHeight: 600 }, (response) => {
      if (response.didCancel) {
        console.log('User cancelled image picker');
      } else if (response.error) {
        console.log('ImagePicker Error: ', response.error);
      } else if (response.customButton) {
        console.log('User tapped custom button: ', response.customButton);
      } else {
        // You can also display the image using data:
        const source = { uri: 'data:image/jpeg;base64,' + response.data };
        this.setState({ capturedImage: response.uri, base64String: source.uri });
      }
    });
  }

  submitButtonHandler = () => {
    if (this.state.capturedImage == '' || this.state.capturedImage == undefined || this.state.capturedImage == null) {
      alert("Please Capture the Image");
    } else {
      this.setState({
        imageSubmitted: true
      });
      const apiName = "<Your-API-Name>";
      const path = "<Your-API-Path>";
      const init = {
        headers: {
          'Accept': 'application/json',
          "Content-Type": "application/x-amz-json-1.1"
        },
        body: JSON.stringify({
          Image: this.state.base64String,
          name: "testImage.jpg"
        })
      }

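      API.post(apiName, path, init).then(response => {
        // Classify only after the state update with the S3 URL has been applied.
        this.setState({ s3ImageUrl: response }, () => {
          if (this.state.s3ImageUrl !== '') {
            this.classifyImage();
          }
        });
      }).catch(error => {
        console.log(error);
      });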
    }
  }

  render() {
    const { isModelReady, predictions } = this.state
    const capturedImageUri = this.state.capturedImage;
    const imageSubmittedCheck = this.state.imageSubmitted;

    return (
      <View style={styles.MainContainer}>
        <ScrollView>
          <Text style={{ fontSize: 20, color: "#000", textAlign: 'center', marginBottom: 15, marginTop: 10 }}>Object Detection</Text>

          {this.state.capturedImage !== "" && <View style={styles.imageholder} >
            <Image source={{ uri: this.state.capturedImage }} style={styles.previewImage} />
          </View>}

          {this.state.capturedImage != '' && imageSubmittedCheck && (
            <View style={styles.predictionWrapper}>
              {isModelReady && capturedImageUri && imageSubmittedCheck && (
                <Text style={styles.text}>
                  Predictions: {predictions ? '' : 'Loading...'}
                </Text>
              )}
              {isModelReady &&
                predictions &&
                predictions.map(p => this.renderPrediction(p))}
            </View>
          )
          }

          <TouchableHighlight style={[styles.buttonContainer, styles.captureButton]} onPress={this.captureImageButtonHandler}>
            <Text style={styles.buttonText}>Capture Image</Text>
          </TouchableHighlight>

          <TouchableHighlight style={[styles.buttonContainer, styles.submitButton]} onPress={this.submitButtonHandler}>
            <Text style={styles.buttonText}>Submit</Text>
          </TouchableHighlight>

        </ScrollView>
      </View>
    );
  }
}

const styles = StyleSheet.create({
  MainContainer: {
    flex: 1,
    backgroundColor: 'white',
  },
  text: {
    color: '#000000',
    fontSize: 16
  },
  predictionWrapper: {
    height: 100,
    width: '100%',
    flexDirection: 'column',
    alignItems: 'center'
  },
  buttonContainer: {
    height: 45,
    flexDirection: 'row',
    alignItems: 'center',
    justifyContent: 'center',
    marginBottom: 20,
    width: "80%",
    borderRadius: 30,
    marginTop: 20,
    marginLeft: 30,
  },
  captureButton: {
    backgroundColor: "#337ab7",
    width: 350,
  },
  buttonText: {
    color: 'white',
    fontWeight: 'bold',
  },
  submitButton: {
    backgroundColor: "#C0C0C0",
    width: 350,
    marginTop: 5,
  },
  imageholder: {
    borderWidth: 1,
    borderColor: "grey",
    backgroundColor: "#eee",
    width: "50%",
    height: 150,
    marginTop: 10,
    marginLeft: 100,
    flexDirection: 'row',
    alignItems: 'center'
  },
  previewImage: {
    width: "100%",
    height: "100%",
  }
})

export default App;

In the above code, we configure Amplify with the name and endpoint URL of the API you created, as shown below.

Amplify.configure({
 API: {
   endpoints: [
     {
       name: '<Your-API-Name>',
       endpoint: '<Your-API-Endpoint-URL>',
     },
   ],
 },
});

The capture button will trigger the captureImageButtonHandler function. It will then ask the user to take a picture or select an image from the file system. We will store that image in the state as shown below.

captureImageButtonHandler = () => {
    this.setState({
      imageSubmitted: false,
      predictions: null
    });

    ImagePicker.showImagePicker({ title: "Pick an Image", maxWidth: 800, maxHeight: 600 }, (response) => {
      if (response.didCancel) {
        console.log('User cancelled image picker');
      } else if (response.error) {
        console.log('ImagePicker Error: ', response.error);
      } else if (response.customButton) {
        console.log('User tapped custom button: ', response.customButton);
      } else {
        const source = { uri: 'data:image/jpeg;base64,' + response.data };
        this.setState({ capturedImage: response.uri, base64String: source.uri });
      }
    });
  }
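
The branching inside the picker callback can also be read as a small pure function. The sketch below is only an illustration: pickerResponseToState is a hypothetical helper that is not part of App.js or of react-native-image-picker, and it assumes the response carries the same didCancel, error, customButton, uri and data fields used above.

```javascript
// Hypothetical helper: map an image picker response to the state fields
// the app stores, or to null when there is no image to keep.
function pickerResponseToState(response) {
  // Cancelled, failed, or custom-button responses carry no image.
  if (response.didCancel || response.error || response.customButton) {
    return null;
  }
  // Both the file URI (for the preview) and the base64 data (for the upload) are needed.
  if (!response.uri || !response.data) {
    return null;
  }
  return {
    capturedImage: response.uri,
    base64String: 'data:image/jpeg;base64,' + response.data,
  };
}

// Example with made-up values standing in for a real picker response.
const nextState = pickerResponseToState({
  uri: 'file:///tmp/photo.jpg',
  data: 'aGVsbG8=',
});
console.log(nextState.base64String); // data:image/jpeg;base64,aGVsbG8=
```

Keeping this logic outside the component would make it possible to check it without a device, at the cost of one more function to maintain.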

After the image is captured, the app shows a preview of it.

Clicking the submit button triggers the submitButtonHandler function, which sends the image to the endpoint as shown below.

submitButtonHandler = () => {
    if (this.state.capturedImage == '' || this.state.capturedImage == undefined || this.state.capturedImage == null) {
      alert("Please Capture the Image");
    } else {
      this.setState({
        imageSubmitted: true
      });
      const apiName = "<Your-API-Name>";
      const path = "<Path-to-your-API>";
      const init = {
        headers: {
          'Accept': 'application/json',
          "Content-Type": "application/x-amz-json-1.1"
        },
        body: JSON.stringify({
          Image: this.state.base64String,
          name: "testImage.jpg"
        })
      }

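      API.post(apiName, path, init).then(response => {
        // Classify only after the state update with the S3 URL has been applied.
        this.setState({ s3ImageUrl: response }, () => {
          if (this.state.s3ImageUrl !== '') {
            this.classifyImage();
          }
        });
      }).catch(error => {
        console.log(error);
      });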
    }
  }

After the image is submitted, API Gateway triggers the Lambda function. The Lambda function stores the submitted image in the S3 bucket and returns its URL, which is sent back in the response. The received URL is then stored in the s3ImageUrl state variable and the classifyImage function is called, as shown above.
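
The ordering described here (upload first, classify only once a URL has come back) can be sketched independently of React. In the example below, submitAndClassify, postImage and classifyUrl are hypothetical names introduced for illustration. The two callbacks stand in for the API call and for classifyImage, and are stubbed so that the flow can be followed without AWS or TensorFlow.js.

```javascript
// Hypothetical sketch of the upload-then-classify sequence.
// postImage and classifyUrl are injected so they can be stubbed.
async function submitAndClassify(base64Image, postImage, classifyUrl) {
  if (!base64Image) {
    throw new Error('Please capture an image first');
  }
  // Step 1: upload the image and wait for the stored image URL.
  const s3ImageUrl = await postImage({ Image: base64Image, name: 'testImage.jpg' });
  if (typeof s3ImageUrl !== 'string' || s3ImageUrl === '') {
    throw new Error('Upload did not return an image URL');
  }
  // Step 2: classify only after the URL is known.
  const predictions = await classifyUrl(s3ImageUrl);
  return { s3ImageUrl, predictions };
}

// Stubs with made-up values standing in for the real API and model.
const fakePostImage = async payload => 'https://example-bucket.s3.amazonaws.com/' + payload.name;
const fakeClassifyUrl = async url => [{ className: 'example', probability: 0.5 }];

submitAndClassify('data:image/jpeg;base64,aGVsbG8=', fakePostImage, fakeClassifyUrl)
  .then(result => console.log(result.s3ImageUrl, result.predictions.length))
  .catch(error => console.log(error.message));
```

Passing the URL along as a value, rather than reading it back from component state, avoids depending on when a state update becomes visible.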

The classifyImage function reads the raw image data and classifies it, yielding the results as predictions. The image is read from S3 using the URL we stored in the state of the App component. Because classification is asynchronous, the results must also be saved once they arrive, so we store them in the predictions state variable.

classifyImage = async () => {
    try {
      const imageAssetPath = this.state.s3ImageUrl;
      const response = await fetch(imageAssetPath, {}, { isBinary: true });
      const rawImageData = await response.arrayBuffer();
      const imageTensor = this.imageToTensor(rawImageData);
      const predictions = await this.model.classify(imageTensor);
      this.setState({ predictions });
    } catch (error) {
      console.log(error);
    }
  }
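
The MobileNet model's classify method resolves to an array of objects, each with a className and a probability. As a minimal sketch of how such results could be prepared for display, the hypothetical helper below (formatPredictions is not part of App.js) sorts them by probability and formats each one as a label with a percentage. The sample values are made up.

```javascript
// Hypothetical helper: turn classification results into display strings,
// most probable first, keeping at most `limit` entries.
function formatPredictions(predictions, limit = 3) {
  return predictions
    .slice() // copy so the caller's array is not reordered
    .sort((a, b) => b.probability - a.probability)
    .slice(0, limit)
    .map(p => `${p.className} (${(p.probability * 100).toFixed(1)}%)`);
}

// Made-up sample in the { className, probability } shape.
const samplePredictions = [
  { className: 'Egyptian cat', probability: 0.125 },
  { className: 'tabby, tabby cat', probability: 0.75 },
];
console.log(formatPredictions(samplePredictions).join('\n'));
// tabby, tabby cat (75.0%)
// Egyptian cat (12.5%)
```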

The package jpeg-js decodes the width, height, and binary data from the image inside the handler method imageToTensor, which accepts a parameter of the raw image data.

imageToTensor(rawImageData) {
    const TO_UINT8ARRAY = true;
    const { width, height, data } = jpeg.decode(rawImageData, TO_UINT8ARRAY);
    // Drop the alpha channel info for mobilenet
    const buffer = new Uint8Array(width * height * 3);
    let offset = 0 ; // offset into original data
    for (let i = 0; i < buffer.length; i += 3) {
      buffer[i] = data[offset];
      buffer[i + 1] = data[offset + 1];
      buffer[i + 2] = data[offset + 2];

      offset += 4;
    }

    return tf.tensor3d(buffer, [height, width, 3]);
  }

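Here, TO_UINT8ARRAY is a boolean flag passed to jpeg.decode so that the decoded pixel data is returned as a Uint8Array, that is, an array of 8-bit unsigned integers. The decoded data holds four bytes per pixel (red, green, blue, alpha), so the loop copies the first three bytes of each pixel and skips the alpha byte.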
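
To make the alpha-dropping step easier to follow in isolation, here is a small sketch that performs the same copy on a hand-written buffer, without jpeg-js or TensorFlow.js. The function name rgbaToRgb is hypothetical and only used for this illustration.

```javascript
// Hypothetical helper: convert RGBA pixel data (4 bytes per pixel)
// into RGB data (3 bytes per pixel) by skipping every alpha byte.
function rgbaToRgb(data, width, height) {
  const pixelCount = width * height;
  if (data.length < pixelCount * 4) {
    throw new Error('RGBA buffer is smaller than width * height * 4');
  }
  const rgb = new Uint8Array(pixelCount * 3);
  for (let pixel = 0; pixel < pixelCount; pixel++) {
    const src = pixel * 4; // start of this pixel in the RGBA data
    const dst = pixel * 3; // start of this pixel in the RGB data
    rgb[dst] = data[src];         // red
    rgb[dst + 1] = data[src + 1]; // green
    rgb[dst + 2] = data[src + 2]; // blue
  }
  return rgb;
}

// A 2x1 image: one red pixel and one blue pixel, both fully opaque.
const rgbaPixels = new Uint8Array([255, 0, 0, 255, 0, 0, 255, 255]);
console.log(Array.from(rgbaToRgb(rgbaPixels, 2, 1)).join(','));
// 255,0,0,0,0,255
```

The resulting buffer has width * height * 3 entries, which matches the [height, width, 3] shape passed to tf.tensor3d.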

Lambda Function:

Add the code below to the Lambda function that you created earlier in your AWS console. It stores the captured image in the S3 bucket and returns the URL of the image.

const AWS = require('aws-sdk');
var s3BucketName = "<Your-S3-BucketName>";
// region is a client-level option; Bucket is bound as a default request parameter.
var s3Bucket = new AWS.S3( { region: "<Your-S3-Bucket-Region>", params: {Bucket: s3BucketName} } );

exports.handler = (event, context, callback) => {
    let parsedData = JSON.parse(event);
    let encodedImage = parsedData.Image;
    var filePath = parsedData.name;
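    // Strip the data URI prefix, then decode the base64 payload (Buffer.from replaces the deprecated new Buffer).
    let buf = Buffer.from(encodedImage.replace(/^data:image\/\w+;base64,/, ""),'base64');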
    var data = {
        Key: filePath, 
        Body: buf,
        ContentEncoding: 'base64',
        ContentType: 'image/jpeg'
    };
    s3Bucket.putObject(data, function(err, data){
        if (err) { 
            callback(err, null);
        } else {
            var s3Url = "https://" + s3BucketName + '.' + "s3.amazonaws.com/" + filePath;
            callback(null, s3Url);
        }
    });
};
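
The two data transformations inside the handler, decoding the submitted data URI and composing the object URL, can be tried out locally without AWS. The sketch below assumes Node.js. decodeDataUriImage and buildS3Url are hypothetical helper names, and the bucket name in the example is made up.

```javascript
// Hypothetical helper: strip the "data:image/...;base64," prefix
// and decode the remaining base64 text into binary data.
function decodeDataUriImage(dataUri) {
  const base64 = dataUri.replace(/^data:image\/\w+;base64,/, '');
  return Buffer.from(base64, 'base64');
}

// Hypothetical helper: compose the object URL the same way the handler does.
function buildS3Url(bucketName, key) {
  return 'https://' + bucketName + '.s3.amazonaws.com/' + key;
}

// 'aGVsbG8=' is the base64 encoding of the text "hello", used here
// in place of real JPEG bytes.
const decoded = decodeDataUriImage('data:image/jpeg;base64,aGVsbG8=');
console.log(decoded.toString('utf8')); // hello
console.log(buildS3Url('example-bucket', 'testImage.jpg'));
// https://example-bucket.s3.amazonaws.com/testImage.jpg
```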

Running the App:

Run the application by executing the npx react-native run-android command from the terminal window. Below are the screenshots of the app running on an Android device.

That’s all folks! I hope it was helpful. If you have any queries, please drop them in the comments section.

This story is authored by Dheeraj Kumar. He is a software engineer specializing in React Native and React based frontend development.
