TensorFlow.js

A JavaScript machine learning library that runs in the browser and in Node.js. It enables in-browser inference, WebGL/WebGPU acceleration, and privacy-preserving AI, and has become the standard choice for client-side ML.

JavaScript, TypeScript, Machine Learning, Deep Learning, Browser, Node.js, WebGL, WebGPU

GitHub Overview

tensorflow/tfjs

A WebGL accelerated JavaScript library for training and deploying ML models.

Stars: 18,929
Watchers: 325
Forks: 1,985
Created: March 5, 2018
Language: TypeScript
License: Apache License 2.0

Topics

deep-learning, deep-neural-network, gpu-acceleration, javascript, machine-learning, neural-network, typescript, wasm, web-assembly, webgl

Star History

Star history chart for tensorflow/tfjs (data as of August 13, 2025)

Framework

TensorFlow.js

Overview

TensorFlow.js is a JavaScript library that enables machine learning model training and inference in browsers and Node.js environments.

Details

TensorFlow.js is an open-source library developed by Google as part of the TensorFlow ecosystem, designed to run machine learning models in JavaScript environments. Released in 2018, it enables real-time machine learning in browsers, server-side inference in Node.js, and ML capabilities in mobile applications. Key features include GPU acceleration through WebGL and WebGPU, a rich library of pre-trained models, conversion capabilities from Python TensorFlow models, and in-browser model training functionality. It's widely used for image recognition, audio processing, natural language processing, real-time object detection, and pose estimation. The advantages of privacy protection (data doesn't leave the device), low latency (no network communication required), and cross-platform support (browsers, Node.js, React Native) have made it the standard for implementing machine learning in web applications.
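
TensorFlow.js selects an execution backend automatically, but it can also be set explicitly. A minimal sketch of backend selection (the WebGPU backend requires the separate @tensorflow/tfjs-backend-webgpu package and a browser that supports it; the fallback chain below is illustrative):

import * as tf from '@tensorflow/tfjs';

async function initBackend() {
  // Prefer the WebGL backend for GPU acceleration
  const ok = await tf.setBackend('webgl');
  if (!ok) {
    // Fall back to the pure-JavaScript CPU backend
    await tf.setBackend('cpu');
  }
  await tf.ready(); // wait for the backend to finish initializing
  console.log('Active backend:', tf.getBackend());
}

initBackend();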

Pros and Cons

Pros

  • Client-side Execution: Machine learning runs directly in the browser, with no server required
  • Privacy Protection: User data never leaves the device
  • Low Latency: Real-time processing with no network round-trips
  • GPU Acceleration: High-speed processing through WebGL/WebGPU
  • Rich Pre-trained Models: A library of high-quality, ready-to-use models (see the sketch after this list)
  • Python TensorFlow Compatibility: Existing TensorFlow models are easy to convert and reuse
  • Transfer Learning Support: Models can be fine-tuned directly in the browser
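
Many pre-trained models are also published as ready-to-use wrapper packages. A minimal sketch using the @tensorflow-models/mobilenet package (an <img> element with id "image" is assumed to exist in the page):

import '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';

async function classifyImage() {
  // Load the pre-trained MobileNet classifier
  const model = await mobilenet.load();

  // Classify an image element already present in the page
  const img = document.getElementById('image');
  const predictions = await model.classify(img);

  predictions.forEach(p => {
    console.log(`${p.className}: ${(p.probability * 100).toFixed(1)}%`);
  });
}

classifyImage();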

Cons

  • Model Size Limitations: Large models are problematic for download time and memory usage
  • Browser Dependencies: Performance depends on the browser's GPU capabilities and WebGL/WebGPU support
  • Debugging Complexity: Debugging models in browser environments can be challenging
  • Performance Constraints: Processing is slower than on server-side GPUs
  • Model Conversion Limitations: Some TensorFlow operations are not supported during conversion
  • Memory Management: Tensors are not garbage-collected automatically and must be released with dispose() or tf.tidy() (see the sketch after this list)
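
A minimal sketch of manual tensor memory management: tf.tidy() disposes every intermediate tensor created inside its callback, and tf.memory() reports how many tensors are currently allocated:

import * as tf from '@tensorflow/tfjs';

// Only the tensor returned from tf.tidy() survives; the caller disposes it.
const result = tf.tidy(() => {
  const x = tf.tensor([[1, 2], [3, 4]]);
  const y = x.square();      // intermediate, cleaned up by tidy
  return y.sum();
});

result.print();
result.dispose();

console.log('Live tensors:', tf.memory().numTensors);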

Code Examples

Hello World

// Basic TensorFlow.js usage
import * as tf from '@tensorflow/tfjs';

console.log("TensorFlow.js version:", tf.version.tfjs);

// Basic tensor operations
const a = tf.tensor([[1, 2], [3, 4]]);
const b = tf.tensor([[5, 6], [7, 8]]);

console.log("Tensor a:");
a.print();

// Mathematical operations
const c = a.add(b);
console.log("a + b:");
c.print();

// Matrix multiplication
const d = tf.matMul(a, b);
console.log("Matrix multiplication a × b:");
d.print();

// Resource cleanup to prevent memory leaks
a.dispose();
b.dispose();
c.dispose();
d.dispose();

Browser Model Loading and Inference

import * as tf from '@tensorflow/tfjs';

// Loading and using pre-trained models
async function loadAndUseModel() {
  try {
    // Load a publicly hosted layers-format MobileNet
    // (TF Hub classification handles point to SavedModels, which tf.loadLayersModel cannot load)
    const model = await tf.loadLayersModel(
      'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json'
    );
    
    // Get image element
    const img = document.getElementById('image');
    
    // Preprocess image
    const preprocessed = tf.browser.fromPixels(img)
      .resizeBilinear([224, 224])  // Resize to 224x224
      .expandDims(0)               // Add batch dimension
      .toFloat()
      .div(127.5)                  // Normalize
      .sub(1);
    
    // Run inference
    const predictions = model.predict(preprocessed);
    
    // Display results
    console.log("Prediction results:");
    predictions.print();
    
    // Get top 5 predictions
    const topK = await tf.topk(predictions, 5);
    const indices = await topK.indices.data();
    const values = await topK.values.data();
    
    for (let i = 0; i < indices.length; i++) {
      console.log(`Class ${indices[i]}: Probability ${values[i].toFixed(4)}`);
    }
    
    // Memory cleanup
    preprocessed.dispose();
    predictions.dispose();
    topK.indices.dispose();
    topK.values.dispose();
    
  } catch (error) {
    console.error("Model loading error:", error);
  }
}

loadAndUseModel();

Real-time Webcam Image Recognition

import * as tf from '@tensorflow/tfjs';

class RealTimeClassifier {
  constructor() {
    this.model = null;
    this.video = null;
    this.canvas = null;
    this.ctx = null;
  }
  
  async initialize() {
    // Load model
    this.model = await tf.loadLayersModel('/models/my_model.json');
    
    // Setup webcam
    this.video = document.getElementById('video');
    this.canvas = document.getElementById('canvas');
    this.ctx = this.canvas.getContext('2d');
    
    try {
      const stream = await navigator.mediaDevices.getUserMedia({ 
        video: { width: 640, height: 480 } 
      });
      this.video.srcObject = stream;
      
      this.video.addEventListener('loadeddata', () => {
        this.canvas.width = this.video.videoWidth;
        this.canvas.height = this.video.videoHeight;
        this.startClassification();
      });
      
    } catch (error) {
      console.error("Camera access error:", error);
    }
  }
  
  async startClassification() {
    const classify = async () => {
      if (this.video.readyState === 4) {
        // Draw camera feed to canvas
        this.ctx.drawImage(this.video, 0, 0);
        
        // Preprocess image
        const imageTensor = tf.browser.fromPixels(this.canvas)
          .resizeBilinear([224, 224])
          .expandDims(0)
          .toFloat()
          .div(255.0);
        
        // Run inference
        const predictions = this.model.predict(imageTensor);
        const probabilities = await predictions.data();
        
        // Display results
        this.displayResults(probabilities);
        
        // Memory cleanup
        imageTensor.dispose();
        predictions.dispose();
      }
      
      // Continue on next frame
      requestAnimationFrame(classify);
    };
    
    classify();
  }
  
  displayResults(probabilities) {
    const classes = ['Class1', 'Class2', 'Class3']; // Class names
    const resultsDiv = document.getElementById('results');
    
    let html = '<h3>Prediction Results:</h3>';
    probabilities.forEach((prob, index) => {
      if (index < classes.length) {
        const percentage = (prob * 100).toFixed(1);
        html += `<p>${classes[index]}: ${percentage}%</p>`;
      }
    });
    
    resultsDiv.innerHTML = html;
  }
}

// Usage example
const classifier = new RealTimeClassifier();
classifier.initialize();

Transfer Learning for Custom Model Creation

import * as tf from '@tensorflow/tfjs';

class TransferLearning {
  constructor() {
    this.baseModel = null;
    this.model = null;
    this.xs = null;
    this.ys = null;
  }
  
  async loadBaseModel() {
    // Load a pre-trained, layers-format MobileNet
    // (the TF Hub feature-vector handle is a SavedModel and cannot be loaded with tf.loadLayersModel)
    const mobilenet = await tf.loadLayersModel(
      'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json'
    );
    
    // Use as feature extractor without final layer
    this.baseModel = tf.model({
      inputs: mobilenet.input,
      outputs: mobilenet.layers[mobilenet.layers.length - 2].output
    });
    
    // Freeze base model weights
    this.baseModel.trainable = false;
  }
  
  createCustomModel(numClasses) {
    // Add custom classification head
    const input = tf.input({ shape: [224, 224, 3] });
    const features = this.baseModel.apply(input);
    const flatten = tf.layers.flatten().apply(features);
    const dense1 = tf.layers.dense({ units: 128, activation: 'relu' }).apply(flatten);
    const dropout = tf.layers.dropout({ rate: 0.5 }).apply(dense1);
    const output = tf.layers.dense({ units: numClasses, activation: 'softmax' }).apply(dropout);
    
    this.model = tf.model({ inputs: input, outputs: output });
    
    // Compile model
    this.model.compile({
      optimizer: tf.train.adam(0.001),
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy']
    });
  }
  
  async addTrainingData(imageElement, label) {
    // Preprocess image
    const img = tf.browser.fromPixels(imageElement)
      .resizeBilinear([224, 224])
      .toFloat()
      .div(255.0)
      .expandDims(0);
    
    // One-hot encode label
    const labelTensor = tf.oneHot(label, 3).expandDims(0);
    
    // Accumulate data (dispose superseded tensors to avoid memory leaks)
    if (this.xs == null) {
      // First example: keep the tensors directly (do not dispose them here)
      this.xs = img;
      this.ys = labelTensor;
    } else {
      const oldXs = this.xs;
      const oldYs = this.ys;
      this.xs = tf.concat([oldXs, img], 0);
      this.ys = tf.concat([oldYs, labelTensor], 0);
      // Release the previous accumulations and the per-example tensors
      oldXs.dispose();
      oldYs.dispose();
      img.dispose();
      labelTensor.dispose();
    }
  }
  
  async train() {
    if (this.xs == null || this.ys == null) {
      throw new Error("Training data not set");
    }
    
    console.log("Starting model training...");
    
    const history = await this.model.fit(this.xs, this.ys, {
      epochs: 20,
      batchSize: 32,
      validationSplit: 0.2,
      shuffle: true,
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          console.log(`Epoch ${epoch + 1}: Loss=${logs.loss.toFixed(4)}, Accuracy=${logs.acc.toFixed(4)}`);
        }
      }
    });
    
    return history;
  }
  
  async predict(imageElement) {
    if (this.model == null) {
      throw new Error("Model not trained");
    }
    
    const img = tf.browser.fromPixels(imageElement)
      .resizeBilinear([224, 224])
      .toFloat()
      .div(255.0)
      .expandDims(0);
    
    const prediction = this.model.predict(img);
    const probabilities = await prediction.data();
    
    img.dispose();
    prediction.dispose();
    
    return probabilities;
  }
  
  async saveModel(path) {
    if (this.model) {
      await this.model.save(path);
      console.log(`Model saved to ${path}`);
    }
  }
}

// Usage example
async function runTransferLearning() {
  const tl = new TransferLearning();
  
  // Load base model
  await tl.loadBaseModel();
  
  // Create custom model (3-class classification)
  tl.createCustomModel(3);
  
  // Add training data (example)
  const img1 = document.getElementById('sample1');
  await tl.addTrainingData(img1, 0); // Class 0
  
  const img2 = document.getElementById('sample2');
  await tl.addTrainingData(img2, 1); // Class 1
  
  // Execute training
  await tl.train();
  
  // Save model
  await tl.saveModel('localstorage://my-custom-model');
  
  console.log("Transfer learning completed!");
}

Node.js Model Inference

import * as tf from '@tensorflow/tfjs-node';
import fs from 'fs';
import path from 'path';

class NodeMLPredictor {
  constructor() {
    this.model = null;
  }
  
  async loadModel(modelPath) {
    try {
      // Load model from local file
      this.model = await tf.loadLayersModel(`file://${modelPath}`);
      console.log("Model loading completed");
      
      // Display model information
      console.log("Input shape:", this.model.inputs[0].shape);
      console.log("Output shape:", this.model.outputs[0].shape);
      
    } catch (error) {
      console.error("Model loading error:", error);
      throw error;
    }
  }
  
  async preprocessImage(imagePath) {
    // Read image file
    const imageBuffer = fs.readFileSync(imagePath);
    
    // Decode image with TensorFlow.js
    const imageTensor = tf.node.decodeImage(imageBuffer, 3) // 3 channels (RGB)
      .resizeBilinear([224, 224])
      .toFloat()
      .div(255.0)
      .expandDims(0);
    
    return imageTensor;
  }
  
  async predict(imagePath) {
    if (!this.model) {
      throw new Error("Model not loaded");
    }
    
    try {
      // Preprocess image
      const imageTensor = await this.preprocessImage(imagePath);
      
      // Run inference
      const startTime = Date.now();
      const predictions = this.model.predict(imageTensor);
      const inferenceTime = Date.now() - startTime;
      
      // Get results
      const probabilities = await predictions.data();
      const maxProbIndex = predictions.argMax(1);
      const maxIndex = await maxProbIndex.data();
      
      console.log(`Inference time: ${inferenceTime}ms`);
      console.log(`Predicted class: ${maxIndex[0]}`);
      console.log("Probability distribution:", Array.from(probabilities));
      
      // Memory cleanup
      imageTensor.dispose();
      predictions.dispose();
      maxProbIndex.dispose();
      
      return {
        predictedClass: maxIndex[0],
        probabilities: Array.from(probabilities),
        inferenceTime: inferenceTime
      };
      
    } catch (error) {
      console.error("Prediction error:", error);
      throw error;
    }
  }
  
  async batchPredict(imageDirectory) {
    const imageFiles = fs.readdirSync(imageDirectory)
      .filter(file => /\.(jpg|jpeg|png)$/i.test(file));
    
    const results = [];
    
    for (const imageFile of imageFiles) {
      const imagePath = path.join(imageDirectory, imageFile);
      console.log(`Processing: ${imageFile}`);
      
      try {
        const result = await this.predict(imagePath);
        results.push({
          filename: imageFile,
          ...result
        });
      } catch (error) {
        console.error(`Error processing ${imageFile}:`, error.message);
      }
    }
    
    return results;
  }
}

// Usage example
async function runBatchInference() {
  const predictor = new NodeMLPredictor();
  
  // Load model
  await predictor.loadModel('./models/my_model.json');
  
  // Single image prediction
  const singleResult = await predictor.predict('./images/sample.jpg');
  console.log("Single prediction result:", singleResult);
  
  // Batch prediction
  const batchResults = await predictor.batchPredict('./images/batch/');
  console.log("Batch prediction results:", batchResults);
  
  // Save results to JSON file
  fs.writeFileSync('./results.json', JSON.stringify(batchResults, null, 2));
  console.log("Results saved to results.json");
}

runBatchInference().catch(console.error);

TensorFlow Python Model Conversion and Usage

// 1. Convert model on Python side
/*
Python commands:
pip install tensorflowjs

# Convert Keras H5 model
tensorflowjs_converter --input_format=keras \
                       --output_format=tfjs_layers_model \
                       ./my_model.h5 \
                       ./tfjs_model/

# Convert SavedModel
tensorflowjs_converter --input_format=tf_saved_model \
                       --output_format=tfjs_graph_model \
                       ./saved_model/ \
                       ./tfjs_model/
*/

// 2. Use converted model on JavaScript side
import * as tf from '@tensorflow/tfjs';

class ConvertedModelHandler {
  constructor() {
    this.model = null;
  }
  
  async loadConvertedModel(modelUrl) {
    try {
      // Load the converted model (tf.loadLayersModel for tfjs_layers_model output;
      // models converted with --output_format=tfjs_graph_model are loaded with tf.loadGraphModel)
      this.model = await tf.loadLayersModel(modelUrl);
      console.log("Converted model loading completed");
      
      // Display model structure
      console.log("Number of layers:", this.model.layers.length);
      this.model.summary();
      
    } catch (error) {
      console.error("Converted model loading error:", error);
      throw error;
    }
  }
  
  async warmUp() {
    if (!this.model) return;
    
    // Warm up model with dummy input (accelerate first inference)
    const inputShape = this.model.inputs[0].shape.slice(1); // Exclude batch dimension
    const dummyInput = tf.randomNormal([1, ...inputShape]);
    
    console.log("Warming up model...");
    const startTime = Date.now();
    
    const warmupPrediction = this.model.predict(dummyInput);
    await warmupPrediction.data(); // Get data and wait for completion
    
    const warmupTime = Date.now() - startTime;
    console.log(`Warmup completed: ${warmupTime}ms`);
    
    // Memory cleanup
    dummyInput.dispose();
    warmupPrediction.dispose();
  }
  
  async predictWithPreprocessing(inputData) {
    if (!this.model) {
      throw new Error("Model not loaded");
    }
    
    // Input data preprocessing (example: normalization)
    const preprocessed = tf.tidy(() => {
      let processed = inputData;
      
      // Type conversion
      if (processed.dtype !== 'float32') {
        processed = processed.toFloat();
      }
      
      // Normalization (0-255 → 0-1)
      if (processed.max().dataSync()[0] > 1) {
        processed = processed.div(255.0);
      }
      
      // Add batch dimension
      if (processed.shape.length === 3) {
        processed = processed.expandDims(0);
      }
      
      return processed;
    });
    
    try {
      // Run inference
      const predictions = this.model.predict(preprocessed);
      
      // Post-processing (example: softmax -> probabilities)
      const postprocessed = tf.tidy(() => {
        if (predictions.shape[1] > 1) {
          // Multi-class classification
          return tf.softmax(predictions);
        } else {
          // Binary classification
          return tf.sigmoid(predictions);
        }
      });
      
      // Get results
      const result = await postprocessed.data();
      
      // Memory cleanup
      preprocessed.dispose();
      predictions.dispose();
      postprocessed.dispose();
      
      return Array.from(result);
      
    } catch (error) {
      preprocessed.dispose();
      throw error;
    }
  }
  
  async benchmarkPerformance(numRuns = 100) {
    if (!this.model) {
      throw new Error("Model not loaded");
    }
    
    const inputShape = this.model.inputs[0].shape.slice(1);
    const testInput = tf.randomNormal([1, ...inputShape]);
    
    // Warmup
    await this.warmUp();
    
    const times = [];
    
    console.log(`Starting performance test (${numRuns} runs)...`);
    
    for (let i = 0; i < numRuns; i++) {
      const startTime = performance.now();
      
      const prediction = this.model.predict(testInput);
      await prediction.data();
      
      const endTime = performance.now();
      times.push(endTime - startTime);
      
      prediction.dispose();
      
      if ((i + 1) % 20 === 0) {
        console.log(`Progress: ${i + 1}/${numRuns} completed`);
      }
    }
    
    testInput.dispose();
    
    // Calculate statistics
    const avgTime = times.reduce((a, b) => a + b, 0) / times.length;
    const minTime = Math.min(...times);
    const maxTime = Math.max(...times);
    // Default Array.sort() compares lexicographically; sort numerically on a copy
    const medianTime = times.slice().sort((a, b) => a - b)[Math.floor(times.length / 2)];
    
    const stats = {
      average: avgTime.toFixed(2),
      minimum: minTime.toFixed(2),
      maximum: maxTime.toFixed(2),
      median: medianTime.toFixed(2),
      totalRuns: numRuns
    };
    
    console.log("Performance test results:", stats);
    return stats;
  }
}

// Usage example
async function runConvertedModelDemo() {
  const handler = new ConvertedModelHandler();
  
  // Load converted model
  await handler.loadConvertedModel('./models/converted_model/model.json');
  
  // Warmup
  await handler.warmUp();
  
  // Predict with sample image
  const sampleImage = tf.randomNormal([224, 224, 3]);
  const predictions = await handler.predictWithPreprocessing(sampleImage);
  
  console.log("Prediction results:", predictions);
  
  // Performance test
  await handler.benchmarkPerformance(50);
  
  // Memory cleanup
  sampleImage.dispose();
}

runConvertedModelDemo().catch(console.error);