TensorFlow.js
A JavaScript machine learning library that runs in browser and Node.js environments. It enables in-browser inference, WebGL/WebGPU acceleration, and privacy-preserving AI applications, and is a standard solution for client-side ML.
GitHub Overview
tensorflow/tfjs
A WebGL accelerated JavaScript library for training and deploying ML models.
Topics
Star History
Framework
TensorFlow.js
Overview
TensorFlow.js is a JavaScript library that enables machine learning model training and inference in browsers and Node.js environments.
Details
TensorFlow.js is an open-source library developed by Google as part of the TensorFlow ecosystem, designed to run machine learning models in JavaScript environments. Released in 2018, it enables real-time machine learning in browsers, server-side inference in Node.js, and ML capabilities in mobile applications. Key features include GPU acceleration through WebGL and WebGPU, a rich library of pre-trained models, conversion of Python TensorFlow models, and in-browser model training. It is widely used for image recognition, audio processing, natural language processing, real-time object detection, and pose estimation. Its privacy protection (data doesn't leave the device), low latency (no network round trips required), and cross-platform support (browsers, Node.js, React Native) have made it a standard choice for implementing machine learning in web applications.
Pros and Cons
Pros
- Client-side Execution: Machine learning runs directly in browsers without servers
- Privacy Protection: User data does not need to leave the device
- Low Latency: Real-time processing without network round trips
- GPU Acceleration: High-speed processing through WebGL/WebGPU
- Rich Pre-trained Models: High-quality model library ready for immediate use
- Python TensorFlow Compatibility: Easy conversion and use of existing TensorFlow models
- Transfer Learning Support: Model fine-tuning capabilities directly in browsers
Cons
- Model Size Limitations: Large models present challenges in download time and memory usage
- Browser Dependencies: Performance depends on browser GPU capabilities and WebGL/WebGPU support
- Debugging Complexity: Model debugging in browser environments can be challenging
- Performance Constraints: Processing is generally slower than on server-side GPUs
- Model Conversion Limitations: Some TensorFlow operations are not supported during conversion
- Memory Management: Tensor memory (notably on the WebGL backend) is not reclaimed by JavaScript garbage collection and must be released manually with dispose() or tf.tidy()
Key Links
- TensorFlow.js Official Site
- TensorFlow.js Official Documentation
- TensorFlow.js GitHub Repository
- TensorFlow.js Tutorials
- TensorFlow.js Model Hub
- TensorFlow.js Demos
Code Examples
Hello World
// Basic TensorFlow.js usage
import * as tf from '@tensorflow/tfjs';
console.log("TensorFlow.js version:", tf.version.tfjs);

// Two 2x2 input matrices
const a = tf.tensor([[1, 2], [3, 4]]);
const b = tf.tensor([[5, 6], [7, 8]]);
console.log("Tensor a:");
a.print();

// Element-wise sum
const c = a.add(b);
console.log("a + b:");
c.print();

// Matrix product a · b
const d = tf.matMul(a, b);
console.log("Matrix multiplication a × b:");
d.print();

// Tensor memory must be released explicitly (it is not reclaimed by GC on the
// WebGL backend); tf.dispose frees every tensor in the list.
tf.dispose([a, b, c, d]);
Browser Model Loading and Inference
import * as tf from '@tensorflow/tfjs';
// Loading and using pre-trained models
/**
 * Loads a MobileNet v2 image classifier from TF Hub, classifies the
 * <img id="image"> element on the page, and logs the top-k classes.
 *
 * Loading and inference errors are caught and logged, not rethrown. Every
 * tensor (and the model) created here is disposed on success and on failure.
 *
 * @param {string} [modelUrl] - TF Hub URL of the classification model.
 * @param {number} [topKCount=5] - Number of highest-scoring classes to log.
 * @returns {Promise<void>}
 */
async function loadAndUseModel(
  modelUrl = 'https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/2',
  topKCount = 5,
) {
  let model;
  let preprocessed;
  let logits;
  let probabilities;
  let topK;
  try {
    // TF Hub serves graph models rather than Layers models, so they are loaded
    // with loadGraphModel and the fromTFHub flag.
    model = await tf.loadGraphModel(modelUrl, { fromTFHub: true });
    const img = document.getElementById('image');
    // tidy() disposes the chain's intermediate tensors and keeps only the
    // returned one. TF Hub image models expect RGB values in [0, 1].
    preprocessed = tf.tidy(() =>
      tf.browser.fromPixels(img)
        .resizeBilinear([224, 224]) // Resize to 224x224
        .expandDims(0) // Add batch dimension
        .toFloat()
        .div(255.0),
    );
    // TF Hub classification models output logits; softmax converts them to
    // probabilities so the values logged below really are probabilities.
    logits = model.predict(preprocessed);
    probabilities = tf.softmax(logits);
    console.log("Prediction results:");
    probabilities.print();
    // tf.topk is synchronous; only the data downloads are awaited.
    topK = tf.topk(probabilities, topKCount);
    const [indices, values] = await Promise.all([
      topK.indices.data(),
      topK.values.data(),
    ]);
    for (let i = 0; i < indices.length; i++) {
      console.log(`Class ${indices[i]}: Probability ${values[i].toFixed(4)}`);
    }
  } catch (error) {
    console.error("Model loading error:", error);
  } finally {
    // tf.dispose skips undefined entries and walks the {values, indices} object.
    tf.dispose([preprocessed, logits, probabilities, topK]);
    model?.dispose();
  }
}
loadAndUseModel().catch(console.error);
Real-time Webcam Image Recognition
import * as tf from '@tensorflow/tfjs';
/**
 * Classifies live webcam frames with a Layers model and renders per-class
 * percentages into the element with id "results" on every animation frame.
 *
 * Expects the page to contain <video id="video">, <canvas id="canvas"> and an
 * element with id "results".
 */
class RealTimeClassifier {
  /**
   * @param {string[]} [classNames] - Display names for the model outputs, in
   *   output order. Outputs beyond the end of this list are not rendered.
   */
  constructor(classNames = ['Class1', 'Class2', 'Class3']) {
    this.model = null;
    this.video = null;
    this.canvas = null;
    this.ctx = null;
    this.classNames = classNames;
  }

  /**
   * Loads the model, opens the webcam, and starts classification once the
   * first video frame has loaded.
   *
   * @param {string} [modelUrl='/models/my_model.json'] - Layers model URL.
   * @returns {Promise<void>} Rejects if the model fails to load; camera
   *   errors are logged instead of rethrown.
   */
  async initialize(modelUrl = '/models/my_model.json') {
    this.model = await tf.loadLayersModel(modelUrl);
    this.video = document.getElementById('video');
    this.canvas = document.getElementById('canvas');
    this.ctx = this.canvas.getContext('2d');
    try {
      const stream = await navigator.mediaDevices.getUserMedia({
        video: { width: 640, height: 480 },
      });
      this.video.srcObject = stream;
      // `once` keeps a repeated `loadeddata` event from starting a second,
      // concurrent classification loop.
      this.video.addEventListener('loadeddata', () => {
        this.canvas.width = this.video.videoWidth;
        this.canvas.height = this.video.videoHeight;
        this.startClassification();
      }, { once: true });
    } catch (error) {
      console.error("Camera access error:", error);
    }
  }

  /**
   * Classifies one frame per animation frame. The loop stops (after logging)
   * on the first error instead of leaving an unhandled rejection.
   */
  async startClassification() {
    const classify = async () => {
      try {
        // Only classify once the video has enough data to render a frame.
        if (this.video.readyState === HTMLMediaElement.HAVE_ENOUGH_DATA) {
          await this.classifyFrame();
        }
      } catch (error) {
        console.error("Classification error:", error);
        return;
      }
      requestAnimationFrame(classify);
    };
    await classify();
  }

  /**
   * Draws the current video frame to the canvas, runs the model on it, and
   * renders the result. Tensors created here are disposed even on error.
   */
  async classifyFrame() {
    this.ctx.drawImage(this.video, 0, 0);
    // tidy() frees the intermediates of the preprocessing chain every frame
    // and keeps only the final [1, 224, 224, 3] tensor scaled to [0, 1].
    const imageTensor = tf.tidy(() =>
      tf.browser.fromPixels(this.canvas)
        .resizeBilinear([224, 224])
        .expandDims(0)
        .toFloat()
        .div(255.0),
    );
    let predictions;
    try {
      predictions = this.model.predict(imageTensor);
      const probabilities = await predictions.data();
      this.displayResults(probabilities);
    } finally {
      tf.dispose([imageTensor, predictions]);
    }
  }

  /**
   * Renders one paragraph per known class with its score as a percentage.
   * Uses textContent so class names are never interpreted as HTML.
   *
   * @param {ArrayLike<number>} probabilities - Model output scores.
   */
  displayResults(probabilities) {
    const resultsDiv = document.getElementById('results');
    const heading = document.createElement('h3');
    heading.textContent = 'Prediction Results:';
    const rows = Array.from(probabilities)
      .slice(0, this.classNames.length)
      .map((prob, index) => {
        const row = document.createElement('p');
        row.textContent = `${this.classNames[index]}: ${(prob * 100).toFixed(1)}%`;
        return row;
      });
    resultsDiv.replaceChildren(heading, ...rows);
  }
}
// Usage example
const classifier = new RealTimeClassifier();
// initialize() rejects if the model cannot be loaded; report it instead of
// leaving an unhandled promise rejection.
classifier.initialize().catch(console.error);
Transfer Learning for Custom Model Creation
import * as tf from '@tensorflow/tfjs';
/**
 * Transfer learning on top of a frozen, truncated MobileNet: collects labelled
 * images in memory, trains a small classification head, predicts, and saves.
 *
 * Call order: loadBaseModel() → createCustomModel() → addTrainingData()* →
 * train() → predict()/saveModel().
 */
class TransferLearning {
  constructor() {
    this.baseModel = null;
    this.model = null;
    this.numClasses = null;
    this.xs = null; // accumulated images, [N, 224, 224, 3]
    this.ys = null; // accumulated one-hot labels, [N, numClasses]
  }

  /**
   * Loads a Layers-format MobileNet and truncates it at an intermediate
   * activation so it can act as a frozen feature extractor.
   *
   * loadLayersModel needs a Layers model.json whose layers can be indexed;
   * TF Hub URLs serve graph models, which cannot be truncated this way.
   *
   * @param {string} [modelUrl] - URL of a Layers-format MobileNet model.json.
   * @param {string} [featureLayerName='conv_pw_13_relu'] - Layer whose output
   *   becomes the feature vector.
   */
  async loadBaseModel(
    modelUrl = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json',
    featureLayerName = 'conv_pw_13_relu',
  ) {
    const mobilenet = await tf.loadLayersModel(modelUrl);
    this.baseModel = tf.model({
      inputs: mobilenet.inputs,
      outputs: mobilenet.getLayer(featureLayerName).output,
    });
    // Only the new head is trained.
    this.baseModel.trainable = false;
  }

  /**
   * Builds and compiles base model + dense classification head.
   *
   * @param {number} numClasses - Number of output classes; also the one-hot
   *   depth used by addTrainingData().
   * @throws {Error} If loadBaseModel() has not completed.
   */
  createCustomModel(numClasses) {
    if (this.baseModel == null) {
      throw new Error("Base model not loaded; call loadBaseModel() first");
    }
    this.numClasses = numClasses;
    const input = tf.input({ shape: [224, 224, 3] });
    const features = this.baseModel.apply(input);
    const flatten = tf.layers.flatten().apply(features);
    const dense1 = tf.layers.dense({ units: 128, activation: 'relu' }).apply(flatten);
    const dropout = tf.layers.dropout({ rate: 0.5 }).apply(dense1);
    const output = tf.layers.dense({ units: numClasses, activation: 'softmax' }).apply(dropout);
    this.model = tf.model({ inputs: input, outputs: output });
    this.model.compile({
      optimizer: tf.train.adam(0.001),
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy'],
    });
  }

  /**
   * Converts an image element to a [1, 224, 224, 3] float tensor scaled to
   * [-1, 1], the input range MobileNet was trained on. Intermediates are
   * released by tf.tidy; the caller owns (and must dispose) the result.
   *
   * @param {HTMLImageElement|HTMLCanvasElement|HTMLVideoElement|ImageData} imageElement
   * @returns {tf.Tensor4D}
   */
  preprocess(imageElement) {
    return tf.tidy(() =>
      tf.browser.fromPixels(imageElement)
        .resizeBilinear([224, 224])
        .toFloat()
        .div(127.5)
        .sub(1)
        .expandDims(0),
    );
  }

  /**
   * Appends one labelled example to the in-memory training set.
   *
   * @param {*} imageElement - Any source accepted by tf.browser.fromPixels.
   * @param {number} label - Integer class index in [0, numClasses).
   * @throws {Error} If createCustomModel() has not been called.
   * @throws {RangeError} If `label` is not a valid class index.
   */
  async addTrainingData(imageElement, label) {
    if (this.numClasses == null) {
      throw new Error("Number of classes unknown; call createCustomModel() first");
    }
    if (!Number.isInteger(label) || label < 0 || label >= this.numClasses) {
      throw new RangeError(`label must be an integer in [0, ${this.numClasses}), got ${label}`);
    }
    // tidy() disposes every tensor created inside except the returned ones, so
    // the per-sample tensors are freed once they have been concatenated.
    const { xs, ys } = tf.tidy(() => {
      const img = this.preprocess(imageElement);
      const labelTensor = tf.oneHot(tf.tensor1d([label], 'int32'), this.numClasses);
      if (this.xs == null) {
        return { xs: img, ys: labelTensor };
      }
      return {
        xs: tf.concat([this.xs, img], 0),
        ys: tf.concat([this.ys, labelTensor], 0),
      };
    });
    // The previous accumulators were created outside this tidy scope and are
    // superseded by the concatenated tensors, so release them explicitly.
    const previous = [this.xs, this.ys];
    this.xs = xs;
    this.ys = ys;
    tf.dispose(previous);
  }

  /**
   * Trains the classification head on the accumulated examples.
   *
   * @returns {Promise<tf.History>}
   * @throws {Error} If no training data was added or no model was created.
   */
  async train() {
    if (this.xs == null || this.ys == null) {
      throw new Error("Training data not set");
    }
    if (this.model == null) {
      throw new Error("Model not created; call createCustomModel() first");
    }
    console.log("Starting model training...");
    const history = await this.model.fit(this.xs, this.ys, {
      epochs: 20,
      batchSize: 32,
      validationSplit: 0.2,
      shuffle: true,
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          console.log(`Epoch ${epoch + 1}: Loss=${logs.loss.toFixed(4)}, Accuracy=${logs.acc.toFixed(4)}`);
        },
      },
    });
    return history;
  }

  /**
   * Classifies one image with the trained model.
   *
   * @param {*} imageElement - Any source accepted by tf.browser.fromPixels.
   * @returns {Promise<Float32Array|Int32Array|Uint8Array>} Class probabilities.
   * @throws {Error} If no model exists.
   */
  async predict(imageElement) {
    if (this.model == null) {
      throw new Error("Model not trained");
    }
    const img = this.preprocess(imageElement);
    let prediction;
    try {
      prediction = this.model.predict(img);
      return await prediction.data();
    } finally {
      tf.dispose([img, prediction]);
    }
  }

  /**
   * Saves the trained model (no-op when no model exists).
   *
   * @param {string} path - TF.js save URL, e.g. 'localstorage://name'.
   */
  async saveModel(path) {
    if (this.model) {
      await this.model.save(path);
      console.log(`Model saved to ${path}`);
    }
  }
}
// Usage example
/**
 * End-to-end demo: load the base model, build a 3-class head, add one sample
 * each for classes 0 and 1 from page elements, train, and save to
 * localStorage.
 */
async function runTransferLearning() {
  const learner = new TransferLearning();
  await learner.loadBaseModel();
  learner.createCustomModel(3);

  // [element id, class index] pairs, added in this order.
  const samples = [
    ['sample1', 0],
    ['sample2', 1],
  ];
  for (const [elementId, classIndex] of samples) {
    await learner.addTrainingData(document.getElementById(elementId), classIndex);
  }

  await learner.train();
  await learner.saveModel('localstorage://my-custom-model');
  console.log("Transfer learning completed!");
}
Node.js Model Inference
import * as tf from '@tensorflow/tfjs-node';
import fs from 'fs';
import path from 'path';
/** Matches the image file extensions batchPredict() processes. */
const IMAGE_FILE_PATTERN = /\.(jpg|jpeg|png)$/i;

/**
 * Runs a Layers model on image files from disk in Node.js (tfjs-node).
 */
class NodeMLPredictor {
  constructor() {
    this.model = null;
  }

  /**
   * Loads a Layers model from the local filesystem and logs its I/O shapes.
   *
   * @param {string} modelPath - Path to model.json.
   * @throws Rethrows any loading error after logging it.
   */
  async loadModel(modelPath) {
    try {
      this.model = await tf.loadLayersModel(`file://${modelPath}`);
      console.log("Model loading completed");
      console.log("Input shape:", this.model.inputs[0].shape);
      console.log("Output shape:", this.model.outputs[0].shape);
    } catch (error) {
      console.error("Model loading error:", error);
      throw error;
    }
  }

  /**
   * Reads and decodes an image file into a [1, 224, 224, 3] float tensor
   * scaled to [0, 1]. The caller owns (and must dispose) the result.
   *
   * @param {string} imagePath
   * @returns {Promise<tf.Tensor4D>}
   */
  async preprocessImage(imagePath) {
    const imageBuffer = await fs.promises.readFile(imagePath);
    // tidy() frees the decoded/resized intermediates and keeps the final tensor.
    return tf.tidy(() =>
      tf.node.decodeImage(imageBuffer, 3) // 3 channels (RGB)
        .resizeBilinear([224, 224])
        .toFloat()
        .div(255.0)
        .expandDims(0),
    );
  }

  /**
   * Classifies a single image file.
   *
   * @param {string} imagePath
   * @returns {Promise<{predictedClass: number, probabilities: number[], inferenceTime: number}>}
   *   `inferenceTime` is in milliseconds and covers prediction plus reading
   *   the output values.
   * @throws {Error} If no model is loaded, or rethrows preprocessing/inference errors.
   */
  async predict(imagePath) {
    if (!this.model) {
      throw new Error("Model not loaded");
    }
    let imageTensor;
    let predictions;
    let maxProbIndex;
    try {
      imageTensor = await this.preprocessImage(imagePath);
      const startTime = Date.now();
      predictions = this.model.predict(imageTensor);
      const probabilities = await predictions.data();
      // Stop the clock only after data() resolves: on asynchronous backends
      // predict() can return before the computation has finished.
      const inferenceTime = Date.now() - startTime;
      maxProbIndex = predictions.argMax(1);
      const maxIndex = await maxProbIndex.data();
      console.log(`Inference time: ${inferenceTime}ms`);
      console.log(`Predicted class: ${maxIndex[0]}`);
      console.log("Probability distribution:", Array.from(probabilities));
      return {
        predictedClass: maxIndex[0],
        probabilities: Array.from(probabilities),
        inferenceTime,
      };
    } catch (error) {
      console.error("Prediction error:", error);
      throw error;
    } finally {
      // Release tensors on both success and failure.
      tf.dispose([imageTensor, predictions, maxProbIndex]);
    }
  }

  /**
   * Classifies every .jpg/.jpeg/.png file in a directory, sequentially.
   * Per-file failures are logged and skipped.
   *
   * @param {string} imageDirectory
   * @returns {Promise<Array<{filename: string}>>} One entry per successfully
   *   classified file, merged with that file's predict() result.
   */
  async batchPredict(imageDirectory) {
    const entries = await fs.promises.readdir(imageDirectory);
    const imageFiles = entries.filter((file) => IMAGE_FILE_PATTERN.test(file));
    const results = [];
    for (const imageFile of imageFiles) {
      const imagePath = path.join(imageDirectory, imageFile);
      console.log(`Processing: ${imageFile}`);
      try {
        const result = await this.predict(imagePath);
        results.push({
          filename: imageFile,
          ...result,
        });
      } catch (error) {
        console.error(`Error processing ${imageFile}:`, error.message);
      }
    }
    return results;
  }
}
// Usage example
/**
 * Demo: load a local model, classify one image, classify a directory of
 * images, and write the batch results to ./results.json.
 */
async function runBatchInference() {
  const predictor = new NodeMLPredictor();
  await predictor.loadModel('./models/my_model.json');

  const singleResult = await predictor.predict('./images/sample.jpg');
  console.log("Single prediction result:", singleResult);

  const batchResults = await predictor.batchPredict('./images/batch/');
  console.log("Batch prediction results:", batchResults);

  // Pretty-printed with 2-space indentation.
  const serialized = JSON.stringify(batchResults, null, 2);
  fs.writeFileSync('./results.json', serialized);
  console.log("Results saved to results.json");
}
runBatchInference().catch(console.error);
TensorFlow Python Model Conversion and Usage
// 1. Convert model on Python side
/*
Python commands:
pip install tensorflowjs
# Convert Keras H5 model
tensorflowjs_converter --input_format=keras \
--output_format=tfjs_layers_model \
./my_model.h5 \
./tfjs_model/
# Convert SavedModel
tensorflowjs_converter --input_format=tf_saved_model \
--output_format=tfjs_graph_model \
./saved_model/ \
./tfjs_model/
*/
// 2. Use converted model on JavaScript side
import * as tf from '@tensorflow/tfjs';
/**
 * Loads a Layers model produced by tensorflowjs_converter and provides
 * warm-up, preprocessed prediction, and latency benchmarking helpers.
 */
class ConvertedModelHandler {
  constructor() {
    this.model = null;
  }

  /**
   * Loads a converted Layers model and logs its structure.
   *
   * NOTE(review): models converted with --output_format=tfjs_graph_model must
   * be loaded with tf.loadGraphModel instead; graph models have no `layers`
   * property or `summary()` method.
   *
   * @param {string} modelUrl - URL or path of the converted model.json.
   * @throws Rethrows any loading error after logging it.
   */
  async loadConvertedModel(modelUrl) {
    try {
      this.model = await tf.loadLayersModel(modelUrl);
      console.log("Converted model loading completed");
      console.log("Number of layers:", this.model.layers.length);
      this.model.summary();
    } catch (error) {
      console.error("Converted model loading error:", error);
      throw error;
    }
  }

  /**
   * Runs one prediction on random input so later calls do not pay one-time
   * setup costs. Does nothing if no model is loaded.
   *
   * Assumes a single model input whose non-batch dimensions are all fixed.
   */
  async warmUp() {
    if (!this.model) return;
    const inputShape = this.model.inputs[0].shape.slice(1); // Exclude batch dimension
    const dummyInput = tf.randomNormal([1, ...inputShape]);
    let warmupPrediction;
    console.log("Warming up model...");
    const startTime = Date.now();
    try {
      warmupPrediction = this.model.predict(dummyInput);
      await warmupPrediction.data(); // Wait for the computation to finish
      const warmupTime = Date.now() - startTime;
      console.log(`Warmup completed: ${warmupTime}ms`);
    } finally {
      tf.dispose([dummyInput, warmupPrediction]);
    }
  }

  /**
   * Normalizes `inputData`, runs the model, and converts outputs to scores.
   *
   * Preprocessing: cast to float32, divide by 255 when the maximum value
   * exceeds 1, and add a batch dimension to rank-3 input.
   * Post-processing: softmax when the output has more than one column,
   * sigmoid otherwise.
   * NOTE(review): if the model already ends in softmax/sigmoid this applies
   * the activation a second time; confirm against the converted model.
   *
   * `inputData` is never disposed here; the caller keeps ownership of it.
   *
   * @param {tf.Tensor} inputData
   * @returns {Promise<number[]>} Flattened scores.
   * @throws {Error} If no model is loaded.
   */
  async predictWithPreprocessing(inputData) {
    if (!this.model) {
      throw new Error("Model not loaded");
    }
    const preprocessed = tf.tidy(() => {
      let processed = inputData;
      if (processed.dtype !== 'float32') {
        processed = processed.toFloat();
      }
      // Normalization (0-255 → 0-1)
      if (processed.max().dataSync()[0] > 1) {
        processed = processed.div(255.0);
      }
      if (processed.shape.length === 3) {
        processed = processed.expandDims(0);
      }
      return processed;
    });
    // When no step applied, tidy() hands back inputData itself. That tensor
    // belongs to the caller and must not be disposed here.
    const ownsPreprocessed = preprocessed !== inputData;
    let predictions;
    let postprocessed;
    try {
      predictions = this.model.predict(preprocessed);
      postprocessed = tf.tidy(() => (
        predictions.shape[1] > 1
          ? tf.softmax(predictions) // Multi-class classification
          : tf.sigmoid(predictions) // Binary classification
      ));
      const result = await postprocessed.data();
      return Array.from(result);
    } finally {
      if (ownsPreprocessed) {
        preprocessed.dispose();
      }
      tf.dispose([predictions, postprocessed]);
    }
  }

  /**
   * Summarizes latency samples.
   *
   * @param {number[]} times - Per-run latencies in milliseconds (not mutated).
   * @returns {{average: string, minimum: string, maximum: string, median: string, totalRuns: number}}
   *   Millisecond values formatted to two decimals; the median of an even
   *   number of samples is the mean of the two middle values.
   * @throws {RangeError} If `times` is empty.
   */
  static summarizeTimings(times) {
    if (times.length === 0) {
      throw new RangeError("At least one timing sample is required");
    }
    // A numeric comparator is required: the default sort compares as strings.
    const sorted = [...times].sort((x, y) => x - y);
    const mid = Math.floor(sorted.length / 2);
    const median = sorted.length % 2 === 1
      ? sorted[mid]
      : (sorted[mid - 1] + sorted[mid]) / 2;
    const average = sorted.reduce((sum, t) => sum + t, 0) / sorted.length;
    return {
      average: average.toFixed(2),
      minimum: sorted[0].toFixed(2),
      maximum: sorted[sorted.length - 1].toFixed(2),
      median: median.toFixed(2),
      totalRuns: times.length,
    };
  }

  /**
   * Measures prediction latency on random input after a warm-up run.
   *
   * @param {number} [numRuns=100] - Number of timed predictions (integer ≥ 1).
   * @returns {Promise<object>} Statistics from summarizeTimings().
   * @throws {Error} If no model is loaded.
   * @throws {RangeError} If `numRuns` is not a positive integer.
   */
  async benchmarkPerformance(numRuns = 100) {
    if (!this.model) {
      throw new Error("Model not loaded");
    }
    if (!Number.isInteger(numRuns) || numRuns < 1) {
      throw new RangeError(`numRuns must be a positive integer, got ${numRuns}`);
    }
    // Warm up before allocating the test input so a failed warm-up leaks nothing.
    await this.warmUp();
    const inputShape = this.model.inputs[0].shape.slice(1);
    const testInput = tf.randomNormal([1, ...inputShape]);
    const times = [];
    console.log(`Starting performance test (${numRuns} runs)...`);
    try {
      for (let i = 0; i < numRuns; i++) {
        const startTime = performance.now();
        const prediction = this.model.predict(testInput);
        try {
          await prediction.data();
          times.push(performance.now() - startTime);
        } finally {
          prediction.dispose();
        }
        if ((i + 1) % 20 === 0) {
          console.log(`Progress: ${i + 1}/${numRuns} completed`);
        }
      }
    } finally {
      testInput.dispose();
    }
    const stats = ConvertedModelHandler.summarizeTimings(times);
    console.log("Performance test results:", stats);
    return stats;
  }
}
// Usage example
/**
 * Demo: load a converted model, warm it up, classify a random stand-in image,
 * and benchmark latency. The sample tensor is released even if a step fails.
 */
async function runConvertedModelDemo() {
  const handler = new ConvertedModelHandler();
  await handler.loadConvertedModel('./models/converted_model/model.json');
  await handler.warmUp();
  // Random [224, 224, 3] tensor standing in for a real image.
  const sampleImage = tf.randomNormal([224, 224, 3]);
  try {
    const predictions = await handler.predictWithPreprocessing(sampleImage);
    console.log("Prediction results:", predictions);
    await handler.benchmarkPerformance(50);
  } finally {
    sampleImage.dispose();
  }
}
runConvertedModelDemo().catch(console.error);