TensorFlow.js
ブラウザとNode.js環境で動作するJavaScript機械学習ライブラリ。ブラウザ内での推論実行、WebGL/WebGPU高速化、プライバシー保護型AI実装が可能。クライアントサイドMLの標準ソリューション。
GitHub概要
tensorflow/tfjs
A WebGL accelerated JavaScript library for training and deploying ML models.
スター18,929
ウォッチ325
フォーク1,985
作成日:2018年3月5日
言語:TypeScript
ライセンス:Apache License 2.0
トピックス
deep-learningdeep-neural-networkgpu-accelerationjavascriptmachine-learningneural-networktypescriptwasmweb-assemblywebgl
スター履歴
データ取得日時: 2025/8/13 01:43
フレームワーク
TensorFlow.js
概要
TensorFlow.jsは、ブラウザやNode.js環境でマシンラーニングモデルの学習と推論を可能にするJavaScriptライブラリです。
詳細
TensorFlow.js(テンソルフロージェーエス)は、Googleが開発したTensorFlowエコシステムの一部として、JavaScript環境でマシンラーニングモデルを実行するためのオープンソースライブラリです。2018年に公開され、ブラウザ内でのリアルタイム機械学習、Node.jsでのサーバーサイド推論、モバイルアプリでの機械学習機能などを可能にします。WebGLやWebGPUを活用したGPUアクセラレーション、事前学習済みモデルの豊富なライブラリ、PythonのTensorFlowモデルからの変換機能、ブラウザ内でのモデル学習機能などが特徴です。画像認識、音声処理、自然言語処理、リアルタイム物体検出、姿勢推定など幅広い用途で活用されています。プライバシー保護(データがサーバーに送信されない)、低レイテンシ(ネットワーク通信不要)、クロスプラットフォーム対応(ブラウザ、Node.js、React Native)などの利点により、Webアプリケーションでの機械学習実装のスタンダードとなっています。
メリット・デメリット
メリット
- クライアントサイド実行: サーバーなしでブラウザ内で機械学習が実行可能
- プライバシー保護: ユーザーデータがデバイスから外部に送信されない
- 低レイテンシ: ネットワーク通信が不要でリアルタイム処理が可能
- GPU アクセラレーション: WebGL/WebGPUによる高速処理
- 豊富な事前学習済みモデル: すぐに使える高品質なモデルライブラリ
- Python TensorFlow互換: 既存のTensorFlowモデルを簡単に変換・利用可能
- 転移学習対応: ブラウザ内でのモデルのファインチューニングが可能
デメリット
- モデルサイズ制限: 大きなモデルはダウンロード時間とメモリ使用量が課題
- ブラウザ依存: GPU性能やWebGL/WebGPUサポートがブラウザに依存
- デバッグの困難さ: ブラウザ環境でのモデルデバッグが複雑
- パフォーマンス制限: サーバーサイドのGPUと比較して処理速度が劣る
- モデル変換の制約: 一部のTensorFlow操作は変換時にサポートされない
- メモリ管理: JavaScriptのガベージコレクションとTensorの手動管理が必要
主要リンク
- TensorFlow.js公式サイト
- TensorFlow.js公式ドキュメント
- TensorFlow.js GitHub リポジトリ
- TensorFlow.js チュートリアル
- TensorFlow.js モデルハブ
- TensorFlow.js デモ
書き方の例
Hello World
// TensorFlow.jsの基本的な使用方法
import * as tf from '@tensorflow/tfjs';
console.log("TensorFlow.js version:", tf.version.tfjs);
// 基本的なテンソル操作
const a = tf.tensor([[1, 2], [3, 4]]);
const b = tf.tensor([[5, 6], [7, 8]]);
console.log("テンソル a:");
a.print();
// 数学的操作
const c = a.add(b);
console.log("a + b:");
c.print();
// 行列積
const d = tf.matMul(a, b);
console.log("行列積 a × b:");
d.print();
// メモリリークを防ぐためのリソース解放
a.dispose();
b.dispose();
c.dispose();
d.dispose();
ブラウザでのモデル読み込みと推論
import * as tf from '@tensorflow/tfjs';
// 事前学習済みモデルの読み込み
async function loadAndUseModel() {
try {
// MobileNetモデルを読み込み
const model = await tf.loadLayersModel('https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/2');
// 画像要素を取得
const img = document.getElementById('image');
// 画像を前処理
const preprocessed = tf.browser.fromPixels(img)
.resizeBilinear([224, 224]) // 224x224にリサイズ
.expandDims(0) // バッチ次元を追加
.toFloat()
.div(127.5) // 正規化
.sub(1);
// 推論実行
const predictions = model.predict(preprocessed);
// 結果を表示
console.log("予測結果:");
predictions.print();
// トップ5の予測を取得
const topK = await tf.topk(predictions, 5);
const indices = await topK.indices.data();
const values = await topK.values.data();
for (let i = 0; i < indices.length; i++) {
console.log(`クラス ${indices[i]}: 確率 ${values[i].toFixed(4)}`);
}
// メモリ解放
preprocessed.dispose();
predictions.dispose();
topK.indices.dispose();
topK.values.dispose();
} catch (error) {
console.error("モデル読み込みエラー:", error);
}
}
loadAndUseModel();
リアルタイム Webカメラ画像認識
import * as tf from '@tensorflow/tfjs';
class RealTimeClassifier {
constructor() {
this.model = null;
this.video = null;
this.canvas = null;
this.ctx = null;
}
async initialize() {
// モデルを読み込み
this.model = await tf.loadLayersModel('/models/my_model.json');
// Webカメラの設定
this.video = document.getElementById('video');
this.canvas = document.getElementById('canvas');
this.ctx = this.canvas.getContext('2d');
try {
const stream = await navigator.mediaDevices.getUserMedia({
video: { width: 640, height: 480 }
});
this.video.srcObject = stream;
this.video.addEventListener('loadeddata', () => {
this.canvas.width = this.video.videoWidth;
this.canvas.height = this.video.videoHeight;
this.startClassification();
});
} catch (error) {
console.error("カメラアクセスエラー:", error);
}
}
async startClassification() {
const classify = async () => {
if (this.video.readyState === 4) {
// カメラ映像をキャンバスに描画
this.ctx.drawImage(this.video, 0, 0);
// 画像を前処理
const imageTensor = tf.browser.fromPixels(this.canvas)
.resizeBilinear([224, 224])
.expandDims(0)
.toFloat()
.div(255.0);
// 推論実行
const predictions = this.model.predict(imageTensor);
const probabilities = await predictions.data();
// 結果を表示
this.displayResults(probabilities);
// メモリ解放
imageTensor.dispose();
predictions.dispose();
}
// 次のフレームで再実行
requestAnimationFrame(classify);
};
classify();
}
displayResults(probabilities) {
const classes = ['クラス1', 'クラス2', 'クラス3']; // クラス名
const resultsDiv = document.getElementById('results');
let html = '<h3>予測結果:</h3>';
probabilities.forEach((prob, index) => {
if (index < classes.length) {
const percentage = (prob * 100).toFixed(1);
html += `<p>${classes[index]}: ${percentage}%</p>`;
}
});
resultsDiv.innerHTML = html;
}
}
// 使用例
const classifier = new RealTimeClassifier();
classifier.initialize();
転移学習によるカスタムモデル作成
import * as tf from '@tensorflow/tfjs';
class TransferLearning {
constructor() {
this.baseModel = null;
this.model = null;
this.xs = null;
this.ys = null;
}
async loadBaseModel() {
// 事前学習済みMobileNetを読み込み
const mobilenet = await tf.loadLayersModel('https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/2');
// 最後の層を除いた特徴抽出器として使用
this.baseModel = tf.model({
inputs: mobilenet.input,
outputs: mobilenet.layers[mobilenet.layers.length - 2].output
});
// 基底モデルの重みを凍結
this.baseModel.trainable = false;
}
createCustomModel(numClasses) {
// カスタム分類ヘッドを追加
const input = tf.input({ shape: [224, 224, 3] });
const features = this.baseModel.apply(input);
const flatten = tf.layers.flatten().apply(features);
const dense1 = tf.layers.dense({ units: 128, activation: 'relu' }).apply(flatten);
const dropout = tf.layers.dropout({ rate: 0.5 }).apply(dense1);
const output = tf.layers.dense({ units: numClasses, activation: 'softmax' }).apply(dropout);
this.model = tf.model({ inputs: input, outputs: output });
// モデルをコンパイル
this.model.compile({
optimizer: tf.train.adam(0.001),
loss: 'categoricalCrossentropy',
metrics: ['accuracy']
});
}
async addTrainingData(imageElement, label) {
// 画像を前処理
const img = tf.browser.fromPixels(imageElement)
.resizeBilinear([224, 224])
.toFloat()
.div(255.0)
.expandDims(0);
// ラベルをワンホットエンコーディング
const labelTensor = tf.oneHot(label, 3).expandDims(0);
// データを蓄積
if (this.xs == null) {
this.xs = img;
this.ys = labelTensor;
} else {
this.xs = tf.concat([this.xs, img], 0);
this.ys = tf.concat([this.ys, labelTensor], 0);
}
img.dispose();
labelTensor.dispose();
}
async train() {
if (this.xs == null || this.ys == null) {
throw new Error("学習データが設定されていません");
}
console.log("モデル学習を開始します...");
const history = await this.model.fit(this.xs, this.ys, {
epochs: 20,
batchSize: 32,
validationSplit: 0.2,
shuffle: true,
callbacks: {
onEpochEnd: (epoch, logs) => {
console.log(`エポック ${epoch + 1}: 損失=${logs.loss.toFixed(4)}, 精度=${logs.acc.toFixed(4)}`);
}
}
});
return history;
}
async predict(imageElement) {
if (this.model == null) {
throw new Error("モデルが学習されていません");
}
const img = tf.browser.fromPixels(imageElement)
.resizeBilinear([224, 224])
.toFloat()
.div(255.0)
.expandDims(0);
const prediction = this.model.predict(img);
const probabilities = await prediction.data();
img.dispose();
prediction.dispose();
return probabilities;
}
async saveModel(path) {
if (this.model) {
await this.model.save(path);
console.log(`モデルを ${path} に保存しました`);
}
}
}
// 使用例
async function runTransferLearning() {
const tl = new TransferLearning();
// ベースモデルを読み込み
await tl.loadBaseModel();
// カスタムモデルを作成(3クラス分類)
tl.createCustomModel(3);
// 学習データを追加(例)
const img1 = document.getElementById('sample1');
await tl.addTrainingData(img1, 0); // クラス0
const img2 = document.getElementById('sample2');
await tl.addTrainingData(img2, 1); // クラス1
// 学習実行
await tl.train();
// モデル保存
await tl.saveModel('localstorage://my-custom-model');
console.log("転移学習完了!");
}
Node.js でのモデル推論
import * as tf from '@tensorflow/tfjs-node';
import fs from 'fs';
import path from 'path';
class NodeMLPredictor {
constructor() {
this.model = null;
}
async loadModel(modelPath) {
try {
// ローカルファイルからモデルを読み込み
this.model = await tf.loadLayersModel(`file://${modelPath}`);
console.log("モデルの読み込み完了");
// モデル情報を表示
console.log("入力形状:", this.model.inputs[0].shape);
console.log("出力形状:", this.model.outputs[0].shape);
} catch (error) {
console.error("モデル読み込みエラー:", error);
throw error;
}
}
async preprocessImage(imagePath) {
// 画像ファイルを読み込み
const imageBuffer = fs.readFileSync(imagePath);
// TensorFlow.jsで画像をデコード
const imageTensor = tf.node.decodeImage(imageBuffer, 3) // 3チャンネル(RGB)
.resizeBilinear([224, 224])
.toFloat()
.div(255.0)
.expandDims(0);
return imageTensor;
}
async predict(imagePath) {
if (!this.model) {
throw new Error("モデルが読み込まれていません");
}
try {
// 画像を前処理
const imageTensor = await this.preprocessImage(imagePath);
// 推論実行
const startTime = Date.now();
const predictions = this.model.predict(imageTensor);
const inferenceTime = Date.now() - startTime;
// 結果を取得
const probabilities = await predictions.data();
const maxProbIndex = predictions.argMax(1);
const maxIndex = await maxProbIndex.data();
console.log(`推論時間: ${inferenceTime}ms`);
console.log(`予測クラス: ${maxIndex[0]}`);
console.log("確率分布:", Array.from(probabilities));
// メモリ解放
imageTensor.dispose();
predictions.dispose();
maxProbIndex.dispose();
return {
predictedClass: maxIndex[0],
probabilities: Array.from(probabilities),
inferenceTime: inferenceTime
};
} catch (error) {
console.error("予測エラー:", error);
throw error;
}
}
async batchPredict(imageDirectory) {
const imageFiles = fs.readdirSync(imageDirectory)
.filter(file => /\.(jpg|jpeg|png)$/i.test(file));
const results = [];
for (const imageFile of imageFiles) {
const imagePath = path.join(imageDirectory, imageFile);
console.log(`処理中: ${imageFile}`);
try {
const result = await this.predict(imagePath);
results.push({
filename: imageFile,
...result
});
} catch (error) {
console.error(`${imageFile} の処理でエラー:`, error.message);
}
}
return results;
}
}
// 使用例
async function runBatchInference() {
const predictor = new NodeMLPredictor();
// モデルを読み込み
await predictor.loadModel('./models/my_model.json');
// 単一画像の予測
const singleResult = await predictor.predict('./images/sample.jpg');
console.log("単一予測結果:", singleResult);
// バッチ予測
const batchResults = await predictor.batchPredict('./images/batch/');
console.log("バッチ予測結果:", batchResults);
// 結果をJSONファイルに保存
fs.writeFileSync('./results.json', JSON.stringify(batchResults, null, 2));
console.log("結果を results.json に保存しました");
}
runBatchInference().catch(console.error);
TensorFlow Python モデルの変換と利用
// 1. Python側でモデルを変換
/*
Python コマンド:
pip install tensorflowjs
# Keras H5 モデルの変換
tensorflowjs_converter --input_format=keras \
--output_format=tfjs_layers_model \
./my_model.h5 \
./tfjs_model/
# SavedModel の変換
tensorflowjs_converter --input_format=tf_saved_model \
--output_format=tfjs_graph_model \
./saved_model/ \
./tfjs_model/
*/
// 2. JavaScript側で変換されたモデルを利用
import * as tf from '@tensorflow/tfjs';
class ConvertedModelHandler {
constructor() {
this.model = null;
}
async loadConvertedModel(modelUrl) {
try {
// 変換されたモデルを読み込み
this.model = await tf.loadLayersModel(modelUrl);
console.log("変換モデルの読み込み完了");
// モデル構造を表示
console.log("レイヤー数:", this.model.layers.length);
this.model.summary();
} catch (error) {
console.error("変換モデル読み込みエラー:", error);
throw error;
}
}
async warmUp() {
if (!this.model) return;
// ダミー入力でモデルをウォームアップ(初回推論の高速化)
const inputShape = this.model.inputs[0].shape.slice(1); // バッチ次元を除く
const dummyInput = tf.randomNormal([1, ...inputShape]);
console.log("モデルウォームアップ中...");
const startTime = Date.now();
const warmupPrediction = this.model.predict(dummyInput);
await warmupPrediction.data(); // データを取得して処理完了を待つ
const warmupTime = Date.now() - startTime;
console.log(`ウォームアップ完了: ${warmupTime}ms`);
// メモリ解放
dummyInput.dispose();
warmupPrediction.dispose();
}
async predictWithPreprocessing(inputData) {
if (!this.model) {
throw new Error("モデルが読み込まれていません");
}
// 入力データの前処理(例:正規化)
const preprocessed = tf.tidy(() => {
let processed = inputData;
// 型変換
if (processed.dtype !== 'float32') {
processed = processed.toFloat();
}
// 正規化(0-255 → 0-1)
if (processed.max().dataSync()[0] > 1) {
processed = processed.div(255.0);
}
// バッチ次元の追加
if (processed.shape.length === 3) {
processed = processed.expandDims(0);
}
return processed;
});
try {
// 推論実行
const predictions = this.model.predict(preprocessed);
// 後処理(例:softmax -> 確率)
const postprocessed = tf.tidy(() => {
if (predictions.shape[1] > 1) {
// 多クラス分類の場合
return tf.softmax(predictions);
} else {
// 二値分類の場合
return tf.sigmoid(predictions);
}
});
// 結果を取得
const result = await postprocessed.data();
// メモリ解放
preprocessed.dispose();
predictions.dispose();
postprocessed.dispose();
return Array.from(result);
} catch (error) {
preprocessed.dispose();
throw error;
}
}
async benchmarkPerformance(numRuns = 100) {
if (!this.model) {
throw new Error("モデルが読み込まれていません");
}
const inputShape = this.model.inputs[0].shape.slice(1);
const testInput = tf.randomNormal([1, ...inputShape]);
// ウォームアップ
await this.warmUp();
const times = [];
console.log(`パフォーマンステスト開始(${numRuns}回実行)...`);
for (let i = 0; i < numRuns; i++) {
const startTime = performance.now();
const prediction = this.model.predict(testInput);
await prediction.data();
const endTime = performance.now();
times.push(endTime - startTime);
prediction.dispose();
if ((i + 1) % 20 === 0) {
console.log(`進捗: ${i + 1}/${numRuns} 完了`);
}
}
testInput.dispose();
// 統計計算
const avgTime = times.reduce((a, b) => a + b, 0) / times.length;
const minTime = Math.min(...times);
const maxTime = Math.max(...times);
const medianTime = times.sort()[Math.floor(times.length / 2)];
const stats = {
average: avgTime.toFixed(2),
minimum: minTime.toFixed(2),
maximum: maxTime.toFixed(2),
median: medianTime.toFixed(2),
totalRuns: numRuns
};
console.log("パフォーマンステスト結果:", stats);
return stats;
}
}
// 使用例
async function runConvertedModelDemo() {
const handler = new ConvertedModelHandler();
// 変換されたモデルを読み込み
await handler.loadConvertedModel('./models/converted_model/model.json');
// ウォームアップ
await handler.warmUp();
// サンプル画像で予測
const sampleImage = tf.randomNormal([224, 224, 3]);
const predictions = await handler.predictWithPreprocessing(sampleImage);
console.log("予測結果:", predictions);
// パフォーマンステスト
await handler.benchmarkPerformance(50);
// メモリ解放
sampleImage.dispose();
}
runConvertedModelDemo().catch(console.error);