TensorFlow.js

ブラウザとNode.js環境で動作するJavaScript機械学習ライブラリ。ブラウザ内での推論実行、WebGL/WebGPU高速化、プライバシー保護型AI実装が可能。クライアントサイドMLの標準ソリューション。

JavaScriptTypeScript機械学習ディープラーニングブラウザNode.jsWebGLWebGPU

GitHub概要

tensorflow/tfjs

A WebGL accelerated JavaScript library for training and deploying ML models.

スター18,929
ウォッチ325
フォーク1,985
作成日:2018年3月5日
言語:TypeScript
ライセンス:Apache License 2.0

トピックス

deep-learningdeep-neural-networkgpu-accelerationjavascriptmachine-learningneural-networktypescriptwasmweb-assemblywebgl

スター履歴

tensorflow/tfjs Star History
データ取得日時: 2025/8/13 01:43

フレームワーク

TensorFlow.js

概要

TensorFlow.jsは、ブラウザやNode.js環境でマシンラーニングモデルの学習と推論を可能にするJavaScriptライブラリです。

詳細

TensorFlow.js(テンソルフロージェーエス)は、Googleが開発したTensorFlowエコシステムの一部として、JavaScript環境でマシンラーニングモデルを実行するためのオープンソースライブラリです。2018年に公開され、ブラウザ内でのリアルタイム機械学習、Node.jsでのサーバーサイド推論、モバイルアプリでの機械学習機能などを可能にします。WebGLやWebGPUを活用したGPUアクセラレーション、事前学習済みモデルの豊富なライブラリ、PythonのTensorFlowモデルからの変換機能、ブラウザ内でのモデル学習機能などが特徴です。画像認識、音声処理、自然言語処理、リアルタイム物体検出、姿勢推定など幅広い用途で活用されています。プライバシー保護(データがサーバーに送信されない)、低レイテンシ(ネットワーク通信不要)、クロスプラットフォーム対応(ブラウザ、Node.js、React Native)などの利点により、Webアプリケーションでの機械学習実装のスタンダードとなっています。

メリット・デメリット

メリット

  • クライアントサイド実行: サーバーなしでブラウザ内で機械学習が実行可能
  • プライバシー保護: ユーザーデータがデバイスから外部に送信されない
  • 低レイテンシ: ネットワーク通信が不要でリアルタイム処理が可能
  • GPU アクセラレーション: WebGL/WebGPUによる高速処理
  • 豊富な事前学習済みモデル: すぐに使える高品質なモデルライブラリ
  • Python TensorFlow互換: 既存のTensorFlowモデルを簡単に変換・利用可能
  • 転移学習対応: ブラウザ内でのモデルのファインチューニングが可能

デメリット

  • モデルサイズ制限: 大きなモデルはダウンロード時間とメモリ使用量が課題
  • ブラウザ依存: GPU性能やWebGL/WebGPUサポートがブラウザに依存
  • デバッグの困難さ: ブラウザ環境でのモデルデバッグが複雑
  • パフォーマンス制限: サーバーサイドのGPUと比較して処理速度が劣る
  • モデル変換の制約: 一部のTensorFlow操作は変換時にサポートされない
  • メモリ管理: JavaScriptのガベージコレクションとTensorの手動管理が必要

主要リンク

書き方の例

Hello World

// TensorFlow.jsの基本的な使用方法
import * as tf from '@tensorflow/tfjs';

console.log("TensorFlow.js version:", tf.version.tfjs);

// 基本的なテンソル操作
const a = tf.tensor([[1, 2], [3, 4]]);
const b = tf.tensor([[5, 6], [7, 8]]);

console.log("テンソル a:");
a.print();

// 数学的操作
const c = a.add(b);
console.log("a + b:");
c.print();

// 行列積
const d = tf.matMul(a, b);
console.log("行列積 a × b:");
d.print();

// メモリリークを防ぐためのリソース解放
a.dispose();
b.dispose();
c.dispose();
d.dispose();

ブラウザでのモデル読み込みと推論

import * as tf from '@tensorflow/tfjs';

// 事前学習済みモデルの読み込み
async function loadAndUseModel() {
  try {
    // MobileNetモデルを読み込み
    const model = await tf.loadLayersModel('https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/2');
    
    // 画像要素を取得
    const img = document.getElementById('image');
    
    // 画像を前処理
    const preprocessed = tf.browser.fromPixels(img)
      .resizeBilinear([224, 224])  // 224x224にリサイズ
      .expandDims(0)               // バッチ次元を追加
      .toFloat()
      .div(127.5)                  // 正規化
      .sub(1);
    
    // 推論実行
    const predictions = model.predict(preprocessed);
    
    // 結果を表示
    console.log("予測結果:");
    predictions.print();
    
    // トップ5の予測を取得
    const topK = await tf.topk(predictions, 5);
    const indices = await topK.indices.data();
    const values = await topK.values.data();
    
    for (let i = 0; i < indices.length; i++) {
      console.log(`クラス ${indices[i]}: 確率 ${values[i].toFixed(4)}`);
    }
    
    // メモリ解放
    preprocessed.dispose();
    predictions.dispose();
    topK.indices.dispose();
    topK.values.dispose();
    
  } catch (error) {
    console.error("モデル読み込みエラー:", error);
  }
}

loadAndUseModel();

リアルタイム Webカメラ画像認識

import * as tf from '@tensorflow/tfjs';

class RealTimeClassifier {
  constructor() {
    this.model = null;
    this.video = null;
    this.canvas = null;
    this.ctx = null;
  }
  
  async initialize() {
    // モデルを読み込み
    this.model = await tf.loadLayersModel('/models/my_model.json');
    
    // Webカメラの設定
    this.video = document.getElementById('video');
    this.canvas = document.getElementById('canvas');
    this.ctx = this.canvas.getContext('2d');
    
    try {
      const stream = await navigator.mediaDevices.getUserMedia({ 
        video: { width: 640, height: 480 } 
      });
      this.video.srcObject = stream;
      
      this.video.addEventListener('loadeddata', () => {
        this.canvas.width = this.video.videoWidth;
        this.canvas.height = this.video.videoHeight;
        this.startClassification();
      });
      
    } catch (error) {
      console.error("カメラアクセスエラー:", error);
    }
  }
  
  async startClassification() {
    const classify = async () => {
      if (this.video.readyState === 4) {
        // カメラ映像をキャンバスに描画
        this.ctx.drawImage(this.video, 0, 0);
        
        // 画像を前処理
        const imageTensor = tf.browser.fromPixels(this.canvas)
          .resizeBilinear([224, 224])
          .expandDims(0)
          .toFloat()
          .div(255.0);
        
        // 推論実行
        const predictions = this.model.predict(imageTensor);
        const probabilities = await predictions.data();
        
        // 結果を表示
        this.displayResults(probabilities);
        
        // メモリ解放
        imageTensor.dispose();
        predictions.dispose();
      }
      
      // 次のフレームで再実行
      requestAnimationFrame(classify);
    };
    
    classify();
  }
  
  displayResults(probabilities) {
    const classes = ['クラス1', 'クラス2', 'クラス3']; // クラス名
    const resultsDiv = document.getElementById('results');
    
    let html = '<h3>予測結果:</h3>';
    probabilities.forEach((prob, index) => {
      if (index < classes.length) {
        const percentage = (prob * 100).toFixed(1);
        html += `<p>${classes[index]}: ${percentage}%</p>`;
      }
    });
    
    resultsDiv.innerHTML = html;
  }
}

// 使用例
const classifier = new RealTimeClassifier();
classifier.initialize();

転移学習によるカスタムモデル作成

import * as tf from '@tensorflow/tfjs';

class TransferLearning {
  constructor() {
    this.baseModel = null;
    this.model = null;
    this.xs = null;
    this.ys = null;
  }
  
  async loadBaseModel() {
    // 事前学習済みMobileNetを読み込み
    const mobilenet = await tf.loadLayersModel('https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/2');
    
    // 最後の層を除いた特徴抽出器として使用
    this.baseModel = tf.model({
      inputs: mobilenet.input,
      outputs: mobilenet.layers[mobilenet.layers.length - 2].output
    });
    
    // 基底モデルの重みを凍結
    this.baseModel.trainable = false;
  }
  
  createCustomModel(numClasses) {
    // カスタム分類ヘッドを追加
    const input = tf.input({ shape: [224, 224, 3] });
    const features = this.baseModel.apply(input);
    const flatten = tf.layers.flatten().apply(features);
    const dense1 = tf.layers.dense({ units: 128, activation: 'relu' }).apply(flatten);
    const dropout = tf.layers.dropout({ rate: 0.5 }).apply(dense1);
    const output = tf.layers.dense({ units: numClasses, activation: 'softmax' }).apply(dropout);
    
    this.model = tf.model({ inputs: input, outputs: output });
    
    // モデルをコンパイル
    this.model.compile({
      optimizer: tf.train.adam(0.001),
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy']
    });
  }
  
  async addTrainingData(imageElement, label) {
    // 画像を前処理
    const img = tf.browser.fromPixels(imageElement)
      .resizeBilinear([224, 224])
      .toFloat()
      .div(255.0)
      .expandDims(0);
    
    // ラベルをワンホットエンコーディング
    const labelTensor = tf.oneHot(label, 3).expandDims(0);
    
    // データを蓄積
    if (this.xs == null) {
      this.xs = img;
      this.ys = labelTensor;
    } else {
      this.xs = tf.concat([this.xs, img], 0);
      this.ys = tf.concat([this.ys, labelTensor], 0);
    }
    
    img.dispose();
    labelTensor.dispose();
  }
  
  async train() {
    if (this.xs == null || this.ys == null) {
      throw new Error("学習データが設定されていません");
    }
    
    console.log("モデル学習を開始します...");
    
    const history = await this.model.fit(this.xs, this.ys, {
      epochs: 20,
      batchSize: 32,
      validationSplit: 0.2,
      shuffle: true,
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          console.log(`エポック ${epoch + 1}: 損失=${logs.loss.toFixed(4)}, 精度=${logs.acc.toFixed(4)}`);
        }
      }
    });
    
    return history;
  }
  
  async predict(imageElement) {
    if (this.model == null) {
      throw new Error("モデルが学習されていません");
    }
    
    const img = tf.browser.fromPixels(imageElement)
      .resizeBilinear([224, 224])
      .toFloat()
      .div(255.0)
      .expandDims(0);
    
    const prediction = this.model.predict(img);
    const probabilities = await prediction.data();
    
    img.dispose();
    prediction.dispose();
    
    return probabilities;
  }
  
  async saveModel(path) {
    if (this.model) {
      await this.model.save(path);
      console.log(`モデルを ${path} に保存しました`);
    }
  }
}

// 使用例
async function runTransferLearning() {
  const tl = new TransferLearning();
  
  // ベースモデルを読み込み
  await tl.loadBaseModel();
  
  // カスタムモデルを作成(3クラス分類)
  tl.createCustomModel(3);
  
  // 学習データを追加(例)
  const img1 = document.getElementById('sample1');
  await tl.addTrainingData(img1, 0); // クラス0
  
  const img2 = document.getElementById('sample2');
  await tl.addTrainingData(img2, 1); // クラス1
  
  // 学習実行
  await tl.train();
  
  // モデル保存
  await tl.saveModel('localstorage://my-custom-model');
  
  console.log("転移学習完了!");
}

Node.js でのモデル推論

import * as tf from '@tensorflow/tfjs-node';
import fs from 'fs';
import path from 'path';

class NodeMLPredictor {
  constructor() {
    this.model = null;
  }
  
  async loadModel(modelPath) {
    try {
      // ローカルファイルからモデルを読み込み
      this.model = await tf.loadLayersModel(`file://${modelPath}`);
      console.log("モデルの読み込み完了");
      
      // モデル情報を表示
      console.log("入力形状:", this.model.inputs[0].shape);
      console.log("出力形状:", this.model.outputs[0].shape);
      
    } catch (error) {
      console.error("モデル読み込みエラー:", error);
      throw error;
    }
  }
  
  async preprocessImage(imagePath) {
    // 画像ファイルを読み込み
    const imageBuffer = fs.readFileSync(imagePath);
    
    // TensorFlow.jsで画像をデコード
    const imageTensor = tf.node.decodeImage(imageBuffer, 3) // 3チャンネル(RGB)
      .resizeBilinear([224, 224])
      .toFloat()
      .div(255.0)
      .expandDims(0);
    
    return imageTensor;
  }
  
  async predict(imagePath) {
    if (!this.model) {
      throw new Error("モデルが読み込まれていません");
    }
    
    try {
      // 画像を前処理
      const imageTensor = await this.preprocessImage(imagePath);
      
      // 推論実行
      const startTime = Date.now();
      const predictions = this.model.predict(imageTensor);
      const inferenceTime = Date.now() - startTime;
      
      // 結果を取得
      const probabilities = await predictions.data();
      const maxProbIndex = predictions.argMax(1);
      const maxIndex = await maxProbIndex.data();
      
      console.log(`推論時間: ${inferenceTime}ms`);
      console.log(`予測クラス: ${maxIndex[0]}`);
      console.log("確率分布:", Array.from(probabilities));
      
      // メモリ解放
      imageTensor.dispose();
      predictions.dispose();
      maxProbIndex.dispose();
      
      return {
        predictedClass: maxIndex[0],
        probabilities: Array.from(probabilities),
        inferenceTime: inferenceTime
      };
      
    } catch (error) {
      console.error("予測エラー:", error);
      throw error;
    }
  }
  
  async batchPredict(imageDirectory) {
    const imageFiles = fs.readdirSync(imageDirectory)
      .filter(file => /\.(jpg|jpeg|png)$/i.test(file));
    
    const results = [];
    
    for (const imageFile of imageFiles) {
      const imagePath = path.join(imageDirectory, imageFile);
      console.log(`処理中: ${imageFile}`);
      
      try {
        const result = await this.predict(imagePath);
        results.push({
          filename: imageFile,
          ...result
        });
      } catch (error) {
        console.error(`${imageFile} の処理でエラー:`, error.message);
      }
    }
    
    return results;
  }
}

// 使用例
async function runBatchInference() {
  const predictor = new NodeMLPredictor();
  
  // モデルを読み込み
  await predictor.loadModel('./models/my_model.json');
  
  // 単一画像の予測
  const singleResult = await predictor.predict('./images/sample.jpg');
  console.log("単一予測結果:", singleResult);
  
  // バッチ予測
  const batchResults = await predictor.batchPredict('./images/batch/');
  console.log("バッチ予測結果:", batchResults);
  
  // 結果をJSONファイルに保存
  fs.writeFileSync('./results.json', JSON.stringify(batchResults, null, 2));
  console.log("結果を results.json に保存しました");
}

runBatchInference().catch(console.error);

TensorFlow Python モデルの変換と利用

// 1. Python側でモデルを変換
/*
Python コマンド:
pip install tensorflowjs

# Keras H5 モデルの変換
tensorflowjs_converter --input_format=keras \
                       --output_format=tfjs_layers_model \
                       ./my_model.h5 \
                       ./tfjs_model/

# SavedModel の変換
tensorflowjs_converter --input_format=tf_saved_model \
                       --output_format=tfjs_graph_model \
                       ./saved_model/ \
                       ./tfjs_model/
*/

// 2. JavaScript側で変換されたモデルを利用
import * as tf from '@tensorflow/tfjs';

class ConvertedModelHandler {
  constructor() {
    this.model = null;
  }
  
  async loadConvertedModel(modelUrl) {
    try {
      // 変換されたモデルを読み込み
      this.model = await tf.loadLayersModel(modelUrl);
      console.log("変換モデルの読み込み完了");
      
      // モデル構造を表示
      console.log("レイヤー数:", this.model.layers.length);
      this.model.summary();
      
    } catch (error) {
      console.error("変換モデル読み込みエラー:", error);
      throw error;
    }
  }
  
  async warmUp() {
    if (!this.model) return;
    
    // ダミー入力でモデルをウォームアップ(初回推論の高速化)
    const inputShape = this.model.inputs[0].shape.slice(1); // バッチ次元を除く
    const dummyInput = tf.randomNormal([1, ...inputShape]);
    
    console.log("モデルウォームアップ中...");
    const startTime = Date.now();
    
    const warmupPrediction = this.model.predict(dummyInput);
    await warmupPrediction.data(); // データを取得して処理完了を待つ
    
    const warmupTime = Date.now() - startTime;
    console.log(`ウォームアップ完了: ${warmupTime}ms`);
    
    // メモリ解放
    dummyInput.dispose();
    warmupPrediction.dispose();
  }
  
  async predictWithPreprocessing(inputData) {
    if (!this.model) {
      throw new Error("モデルが読み込まれていません");
    }
    
    // 入力データの前処理(例:正規化)
    const preprocessed = tf.tidy(() => {
      let processed = inputData;
      
      // 型変換
      if (processed.dtype !== 'float32') {
        processed = processed.toFloat();
      }
      
      // 正規化(0-255 → 0-1)
      if (processed.max().dataSync()[0] > 1) {
        processed = processed.div(255.0);
      }
      
      // バッチ次元の追加
      if (processed.shape.length === 3) {
        processed = processed.expandDims(0);
      }
      
      return processed;
    });
    
    try {
      // 推論実行
      const predictions = this.model.predict(preprocessed);
      
      // 後処理(例:softmax -> 確率)
      const postprocessed = tf.tidy(() => {
        if (predictions.shape[1] > 1) {
          // 多クラス分類の場合
          return tf.softmax(predictions);
        } else {
          // 二値分類の場合
          return tf.sigmoid(predictions);
        }
      });
      
      // 結果を取得
      const result = await postprocessed.data();
      
      // メモリ解放
      preprocessed.dispose();
      predictions.dispose();
      postprocessed.dispose();
      
      return Array.from(result);
      
    } catch (error) {
      preprocessed.dispose();
      throw error;
    }
  }
  
  async benchmarkPerformance(numRuns = 100) {
    if (!this.model) {
      throw new Error("モデルが読み込まれていません");
    }
    
    const inputShape = this.model.inputs[0].shape.slice(1);
    const testInput = tf.randomNormal([1, ...inputShape]);
    
    // ウォームアップ
    await this.warmUp();
    
    const times = [];
    
    console.log(`パフォーマンステスト開始(${numRuns}回実行)...`);
    
    for (let i = 0; i < numRuns; i++) {
      const startTime = performance.now();
      
      const prediction = this.model.predict(testInput);
      await prediction.data();
      
      const endTime = performance.now();
      times.push(endTime - startTime);
      
      prediction.dispose();
      
      if ((i + 1) % 20 === 0) {
        console.log(`進捗: ${i + 1}/${numRuns} 完了`);
      }
    }
    
    testInput.dispose();
    
    // 統計計算
    const avgTime = times.reduce((a, b) => a + b, 0) / times.length;
    const minTime = Math.min(...times);
    const maxTime = Math.max(...times);
    const medianTime = times.sort()[Math.floor(times.length / 2)];
    
    const stats = {
      average: avgTime.toFixed(2),
      minimum: minTime.toFixed(2),
      maximum: maxTime.toFixed(2),
      median: medianTime.toFixed(2),
      totalRuns: numRuns
    };
    
    console.log("パフォーマンステスト結果:", stats);
    return stats;
  }
}

// 使用例
async function runConvertedModelDemo() {
  const handler = new ConvertedModelHandler();
  
  // 変換されたモデルを読み込み
  await handler.loadConvertedModel('./models/converted_model/model.json');
  
  // ウォームアップ
  await handler.warmUp();
  
  // サンプル画像で予測
  const sampleImage = tf.randomNormal([224, 224, 3]);
  const predictions = await handler.predictWithPreprocessing(sampleImage);
  
  console.log("予測結果:", predictions);
  
  // パフォーマンステスト
  await handler.benchmarkPerformance(50);
  
  // メモリ解放
  sampleImage.dispose();
}

runConvertedModelDemo().catch(console.error);