☻Blog("Laziji")

System.out.print("辣子鸡的博客");

简介

线段树算法是一种快速查询一段区间内的信息的算法, 由于其实现简单, 所以广泛应用于程序设计竞赛中。
线段树是一棵完美二叉树, 即所有的叶子节点的深度均相同, 并且所有的非叶子节点都有两个子节点。每个节点维护一个区间, 这个区间为父节点二分后的子区间, 根节点维护整个区间, 叶子节点维护单个元素, 当元素个数为n时, 对区间的操作都可以在O(log n)的时间内完成, 因为此时树的深度为log2 n + 1, 每次操作只需从叶子节点开始, 往上更新至根节点, 每层只需更新相关的一个区间即可, 操作次数log2 n + 1, 即在O(log n)的时间内可完成。

可实现的功能

线段树可以提供不同的功能, 例如最常见的求区间内的最大最小值和求区间内的和, 还有其他类似的功能, 实现思路基本相同

求区间最小值(最小值)

给定任意数列[a0, a1,...,an-1], 在O(log n)的时间内完成下列的两种操作

  • query(s, t)[as,as+1,...,at-1] 内的最小值(最小值)
  • update(i, x)ai 的值改为 x

求区间的和

给定初始值全为0的数列[a0, a1,...,an-1], 在O(log n)的时间内完成下列的两种操作

  • query(s, t)[as,as+1,...,at-1] 内的和
  • add(i, x) 执行 ai += x

代码实现

这里我们以求区间最小值内的最小值为例, 用Python来实现原始的一棵线段树

初始化

这里创建一个数组dat[]并赋予初始最大值, 为了让其成为一棵完美的二叉树, 便于计算, 我们把n扩大到2的幂, 由于我们在数组中填充了int32的最大整数2147483647, 所以多余出来的的元素总是最大值, 不会影响原来区间的结果

1
2
3
4
5
6
7
8
def init(self, n):
self.INT_MAX = 2147483647
self.n = 1

while self.n < n:
self.n *= 2

self.dat = [self.INT_MAX for i in range(2 * self.n - 1)]

更新元素

我们把一棵完美二叉树压成一个数组, 下标为i的子节点为i*2+1 和 i*2+2, a0为根节点, 每次更新时, 首先更新叶子节点, 之后一层层往上更新, 节点a[k] = min(a[k * 2 + 1],a[k * 2 + 2]), 操作在O(log n)的时间内完成

1
2
3
4
5
6
def update(self, k, a):
k += self.n - 1
self.dat[k] = a
while k > 0:
k = (k - 1) // 2
self.dat[k] = min(self.dat[k * 2 + 1],self.dat[k * 2 + 2])

查询元素

query的功能为查询[a, b)区间内的最小值, 参数k, l, r是辅助参数

  • k 当前计算的节点
  • l, r 当前节点区间的范围

[a,b), 不在k节点管理的区间[l, r)内时, 直接返回INT_MAX
[a,b), 重合于k节点管理的区间[l, r)时, 直接返回k节点的值
否则, 递归k的两个子节点, 返回其中的最小值

1
2
3
4
5
6
7
8
9
10
def query(self, a, b, k, l, r):
if r <= a or b <= l:
return self.INT_MAX

if a <= l and r <= b:
return self.dat[k]
else:
vl = self.query(a, b, k * 2 + 1, l, (l + r) // 2)
vr = self.query(a, b, k * 2 + 2, (l + r) // 2, r)
return min(vl, vr)

结尾

至此我们就简单地实现了一棵线段树, 这只是线段树的其中一种形式, 线段树还有其他的变体。线段树的使用实例可以看我的另一篇文章https://laboo.top/2018/11/02/acm-lc-45/#more

源码

digit-recognizer

demo

https://github-laziji.github.io/digit-recognizer/
演示开始时需要加载大概100M的训练数据, 稍等片刻

调整训练集的大小, 观察测试结果的准确性

数据来源

数据来源与 https://www.kaggle.com 中的一道题目 digit-recognizer
题目给出42000条训练数据(包含图片和标签)以及28000条测试数据(只包含图片)
要求给这些测试数据打上标签[0-9]描述该图像显示的是哪个数字, 要尽可能的准确

网站中还有许多其他的机器学习的题目以及数据, 是个很好的练手的地方

实现

TensorFlow是一个开源的机器学习库, 利用这个库我们可以快速地构建机器学习项目
这里我们使用TensorFlow.js来实现识别手写数字

创建模型

卷积神经网络的第一层有两种作用, 它既是输入层也是执行层, 接收IMAGE_H * IMAGE_W大小的黑白像素
最后一层是输出层, 有10个输出单元, 代表着0-9这十个值的概率分布, 例如 Label=2 , 输出为[0.02,0.01,0.9,...,0.01]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
function createConvModel() {
const model = tf.sequential();

model.add(tf.layers.conv2d({
inputShape: [IMAGE_H, IMAGE_W, 1],
kernelSize: 3,
filters: 16,
activation: 'relu'
}));

model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
model.add(tf.layers.flatten({}));

model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));

return model;
}

训练模型

我们选择适当的优化器和损失函数, 来编译模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
async function train() {

ui.trainLog('Create model...');
model = createConvModel();

ui.trainLog('Compile model...');
const optimizer = 'rmsprop';
model.compile({
optimizer,
loss: 'categoricalCrossentropy',
metrics: ['accuracy'],
});
const trainData = Data.getTrainData(ui.getTrainNum());

ui.trainLog('Training model...');
await model.fit(trainData.xs, trainData.labels, {});

ui.trainLog('Completed!');
ui.trainCompleted();
}

测试

这里测试一组测试数据, 返回对应的标签, 即十个输出单元中概率最高的下标

1
2
3
4
5
6
7
8
9
10
11
12
13
function testOne(xs){
if(!model){
ui.viewLog('Need to train the model first');
return;
}
ui.viewLog('Testing...');
let output = model.predict(xs);
ui.viewLog('Completed!');
output.print();
const axis = 1;
const predictions = output.argMax(axis).dataSync();
return predictions[0];
}
0%