Brain.js 初学者指南
Brain.js
Brain.js 是一个用于创建神经网络的 JavaScript 库,它非常易于学习和实现。神经网络是一种受人脑启发并以人脑命名的数据结构。神经网络由多层节点组成,每层节点都具有相应的权重阈值。如果某个节点的输出高于阈值,则该神经元被激活,并将数据发送到下一层的连接节点。在 Brain.js 中,初始化此数据结构的过程非常简单,只需加载模块并调用构造函数即可。
const brain = require('brain.js');
//initialize neural network
const net = new brain.NeuralNetwork();
训练
这个新构建的对象需要传入一个数据集才能进行预测。Brain.js 简化了这个过程,您只需调用该.train
方法并传入一个对象数组即可——每个对象都包含输入和输出的键值对。基础网络构造函数仅支持 0 和 1 的数组,或 0 和 1 的哈希值;然而,还有其他更复杂的模型,允许将其他数据类型输入到数据集中。
//input training dataset
net.train([
{ input: [0, 0], output: [0] },
{ input: [0, 1], output: [1] },
{ input: [1, 0], output: [1] },
{ input: [1, 1], output: [0] },
]);
跑步
现在你已经输入了数据集并训练了你的人工大脑,你需要.run
在神经网络上调用该方法。这将返回一个包含十进制值的数组,该值对应于预期输出等于 1 的概率。小数点的值越高,神经网络就越有可能认为输出应该为 1。
//outputs a probability
const output1 = net.run([1, 0]);// [ 0.9327429533004761 ]
const output2 = net.run([0, 0]);//) [ 0.05645085498690605 ]
const output3 = net.run([1, 1]);// [ 0.08839469403028488 ]
console.log('output 1:', output1);
console.log('output 2:', output2);
console.log('output 3:', output3);
console.log(Math.round(output1[0]));// => 1
console.log(Math.round(output2[0]));// => 0
console.log(Math.round(output3[0]));// => 0
保存已训练的网络
通过提供更大的数据集,网络预测预期输出的能力会更加准确;但是,如果数据集更大,或者您正在实现更复杂的网络模型(这些模型需要对数据集进行一定次数的迭代才能进行训练),代码执行所需的时间也会增加。可以通过使用 brain.js.toJSON
并将.fromJSON
预训练网络存储为 JSON 格式来解决这个问题。以下是如何使用 node.js 的文件系统模块保存和加载网络的示例。
const brain = require('brain.js');
const data = require('./data.json');
const fs = require('fs');
//init neural network
const net = new brain.recurrent.LSTM();
const trainingData = data.map(item => ({
input: item.message,
output: item.response
}));
//input training data and configuration object
net.train(trainingData, {
log: (err) => console.log(err),
iterations: 500
});
//save trained network to json
const networkState = net.toJSON();
fs.writeFileSync('network_state.json',
JSON.stringify(networkState),
'utf-8');
const brain = require('brain.js');
const fs = require('fs');
//init neural network
const net = new brain.recurrent.LSTM();
//load trained network to json
const networkState = JSON.parse(
fs.readFileSync('network_state.json',
'utf-8'));
net.fromJSON(networkState);
结论
Brain.js 是一种用 JavaScript 学习和实现神经网络的有趣且简单的方法。我希望它能激励更多人用 JavaScript 实现机器学习算法。
文章来源:https://dev.to/gfish94/brainjs-for-beginners-1g77