使

使用 TensorFlow.js (ReactJS) 在浏览器中运行机器学习模型

2025-06-10

使用 TensorFlow.js (ReactJS) 在浏览器中运行机器学习模型

TensorFlow.js(简称 tfjs)是一个库,可让您使用 JavaScript 创建、训练和使用已训练好的机器学习模型!
它的主要目标是让 JavaScript 开发人员能够通过创建酷炫且智能的 Web 应用程序进入机器学习和深度学习的世界,这些应用程序可以使用 JavaScript 在任何主流浏览器或Node.js服务器上运行。

介绍

TensorFlow.js 几乎可以在任何地方运行,包括所有主流浏览器、服务器、手机,甚至物联网设备。这充分展现了这个库的巨大潜力。TensorFlow.js 后端可以通过WebGLAPI在设备 GPU 上运行,WebGLAPI 允许 JavaScript 代码在 GPU 上运行,这意味着即使在浏览器上运行,TensorFlow.js 也能拥有出色的性能。

阅读完本文后,您将:

  • 了解 TensorFlow.js 及其使用方法。
  • 了解如何将机器学习模型加载到您的 Javascript 项目中并开始使用它。
  • 获得自己创建此类项目的技能
  • 最后,获得更多有关机器学习的知识。

那么,它是如何工作的呢?

我们可以选择以下几种方案:
tensorflowjs-模型

1.运行现有模型:

TensorFlow.js 提供了一些优秀的预训练模型,我们可以将它们导入到项目中,提供输入,并根据需求使用输出。您可以在这里探索它们为我们提供的模型:TensorFlow.js 模型,并且随着时间的推移,它们会不断添加更多模型。
除此之外,您还可以在整个网络上找到由 TensorFlow.js 社区开发的许多优秀的预训练模型。

2.重新训练现有模型:

此选项使我们能够针对特定用例改进现有模型。我们可以通过使用一种称为“迁移学习”的方法来实现这一点。
迁移学习是通过迁移已学习的相关任务中的知识来改进新任务的学习。
例如,在现实世界中,骑自行车时学到的平衡逻辑可以迁移到学习驾驶其他两轮车辆。同样,在机器学习中,迁移学习可用于将算法逻辑从一个机器学习模型迁移到另一个模型。

3. 使用 JavaScript 开发 ML:

第三个选项将用于开发人员想要从头开始创建新的机器学习模型的情况,使用 TensorFlow.js API,就像常规 TensorFlow 版本一样。

现在让我们开始用 JavaScript 进行机器学习

在本文中,我们主要关注如何在 JavaScript 项目中添加并运行预先训练的机器学习模型。您将看到安装、加载和运行机器学习模型的预测是多么简单。

那么让我们开始吧!

我构建了一个应用程序,演示了如何使用由 Tensorflow.js 团队创建的预训练图像标签分类模型。该模型名为 MobileNet,您可以在此处找到有关它的更多信息。

该演示应用程序使用React.js构建,并使用 Ant Design作为 UI 组件。

React 是一个用于构建用户界面或 UI 组件的开源前端 JavaScript 库。

让我们一起了解一下该应用程序的主要部分:

首先,依赖关系

设置好 React 应用程序后,我们需要通过运行以下命令来安装 tfjs 和图像分类模型 — mobilenet:

$ npm i @tensorflow-models/mobilenet
$ npm i @tensorflow/tfjs
Enter fullscreen mode Exit fullscreen mode

现在,安装完软件包后,我们可以将它们导入到我们的App.js文件中:

import "@tensorflow/tfjs";
import * as mobileNet from "@tensorflow-models/mobilenet
view raw import_tfjs.js hosted with ❤ by GitHub
import "@tensorflow/tfjs";
import * as mobileNet from "@tensorflow-models/mobilenet
view raw import_tfjs.js hosted with ❤ by GitHub
import "@tensorflow/tfjs";
import * as mobileNet from "@tensorflow-models/mobilenet
view raw import_tfjs.js hosted with ❤ by GitHub

我们导入了图像分类模型和 TensorFlow.js 引擎,每次调用模型时,它都会在后台运行机器学习模型。

接下来,我们需要将模型加载到组件中以供将来使用。请注意,model.load() 函数是一个异步函数,因此我们需要等待它完成。

const [model, setModel] = useState(null);
useEffect(() => {
const loadModel = async () => {
const model = await mobileNet.load();
setModel(model);
};
loadModel();
}, []);
view raw load_mode.js hosted with ❤ by GitHub
const [model, setModel] = useState(null);
useEffect(() => {
const loadModel = async () => {
const model = await mobileNet.load();
setModel(model);
};
loadModel();
}, []);
view raw load_mode.js hosted with ❤ by GitHub
const [model, setModel] = useState(null);
useEffect(() => {
const loadModel = async () => {
const model = await mobileNet.load();
setModel(model);
};
loadModel();
}, []);
view raw load_mode.js hosted with ❤ by GitHub

mobileNet模型有一个方法叫classify,我们加载模型之后可以调用这个方法:

model.classify(
img: tf.Tensor3D | ImageData | HTMLImageElement |
HTMLCanvasElement | HTMLVideoElement,
topk?: number
)
model.classify(
img: tf.Tensor3D | ImageData | HTMLImageElement |
HTMLCanvasElement | HTMLVideoElement,
topk?: number
)
model.classify(
img: tf.Tensor3D | ImageData | HTMLImageElement |
HTMLCanvasElement | HTMLVideoElement,
topk?: number
)

此方法接受 2 个参数:

  • img:用于进行分类的张量或图像元素。
  • topk:返回最高概率的数量。默认为 3。

下一步,我们要读取用户输入的图像,并将上传的文件加载到 HTMLCanvasElement 类型的画布元素中

const onImageChange = async ({ target }) => {
// Load the image into a canvas element
const canvas = canvasRef.current;
const ctx = canvas.getContext("2d");
drawImageOnCanvas(target, canvas, ctx);
}
view raw draw_image.js hosted with ❤ by GitHub
const onImageChange = async ({ target }) => {
// Load the image into a canvas element
const canvas = canvasRef.current;
const ctx = canvas.getContext("2d");
drawImageOnCanvas(target, canvas, ctx);
}
view raw draw_image.js hosted with ❤ by GitHub
const onImageChange = async ({ target }) => {
// Load the image into a canvas element
const canvas = canvasRef.current;
const ctx = canvas.getContext("2d");
drawImageOnCanvas(target, canvas, ctx);
}
view raw draw_image.js hosted with ❤ by GitHub

将图像加载到画布后,我们就可以运行分类方法。

const onImageChange = async ({ target }) => {
// Load the image into a canvas element
const canvas = canvasRef.current;
const ctx = canvas.getContext("2d");
drawImageOnCanvas(target, canvas, ctx);
// Classify the image
const predictions = await model.classify(canvas, 5);
// Set the results to the componenet's state
setPredictions(predictions);
};
view raw classify.js hosted with ❤ by GitHub
const onImageChange = async ({ target }) => {
// Load the image into a canvas element
const canvas = canvasRef.current;
const ctx = canvas.getContext("2d");
drawImageOnCanvas(target, canvas, ctx);
// Classify the image
const predictions = await model.classify(canvas, 5);
// Set the results to the componenet's state
setPredictions(predictions);
};
view raw classify.js hosted with ❤ by GitHub
const onImageChange = async ({ target }) => {
// Load the image into a canvas element
const canvas = canvasRef.current;
const ctx = canvas.getContext("2d");
drawImageOnCanvas(target, canvas, ctx);
// Classify the image
const predictions = await model.classify(canvas, 5);
// Set the results to the componenet's state
setPredictions(predictions);
};
view raw classify.js hosted with ❤ by GitHub

model.classify 方法的输出是一个包含分类标签及其预测分数的数组。输出如下所示:

[{
className: "tiger, Panthera tigris",
probability: 0.6370824575424194
}, {
className: "tiger cat",
probability: 0.3609316051006317
}, {
className: "jaguar, panther, Panthera onca, Felis onca",
probability: 0.0009806138696148992
}]
view raw result.js hosted with ❤ by GitHub
[{
className: "tiger, Panthera tigris",
probability: 0.6370824575424194
}, {
className: "tiger cat",
probability: 0.3609316051006317
}, {
className: "jaguar, panther, Panthera onca, Felis onca",
probability: 0.0009806138696148992
}]
view raw result.js hosted with ❤ by GitHub
[{
className: "tiger, Panthera tigris",
probability: 0.6370824575424194
}, {
className: "tiger cat",
probability: 0.3609316051006317
}, {
className: "jaguar, panther, Panthera onca, Felis onca",
probability: 0.0009806138696148992
}]
view raw result.js hosted with ❤ by GitHub

一旦我们在组件中保存了预测数组,我们就可以循环遍历该数组并将它们渲染到屏幕上:

<div className="tags-container">
{predictions.map(
({ className, probability }) =>
probability.toFixed(3) > 0 && (
<Tag className="tag" key={className} color="geekblue">
{className.split(",")[0]} {probability.toFixed(3)}
</Tag>
)
)}
</div>
<div className="tags-container">
{predictions.map(
({ className, probability }) =>
probability.toFixed(3) > 0 && (
<Tag className="tag" key={className} color="geekblue">
{className.split(",")[0]} {probability.toFixed(3)}
</Tag>
)
)}
</div>
<div className="tags-container">
{predictions.map(
({ className, probability }) =>
probability.toFixed(3) > 0 && (
<Tag className="tag" key={className} color="geekblue">
{className.split(",")[0]} {probability.toFixed(3)}
</Tag>
)
)}
</div>

就是这样,我们的浏览器中有一个活生生的机器学习模型,恭喜!

欢迎您访问以下链接:

您可以上传自己的图像,获得预测,甚至可以发挥创造力,尝试自己添加新功能!

结论

毫无疑问,机器学习的应用正在持续增长。随着 JavaScript 开发日益普及,TensorFlow.js 社区也将不断发展壮大。我认为,我们将看到越来越多的生产级应用程序在浏览器或 Node.js 服务器上运行 TensorFlow.js,用于执行机器学习模型能够解决的简单、轻量级任务。

在大家了解了将 TensorFlow.js 集成到 Javascript 应用程序是多么快速和容易之后,我邀请大家亲自尝试一下,创建一些很酷的项目并与社区分享。

鏂囩珷鏉ユ簮锛�https://dev.to/omrigm/run-machine-learning-models-in-your-browser-with-tensorflow-js-reactjs-48pe
PREV
像绝地武士一样掌握 Web 性能
NEXT
给初学者的 5 个 git 技巧