原文见此:https://johnnylee-sde.github.io/Fast-numeric-string-to-int/

利用位运算和 64 位 CPU 的优势,实现从 string 到 int 的快速转换,并使用 Go 来验证。

最简单的版本

  • 常见的字符串转数字的代码如下所示:
1
2
3
4
5
6
// given num[] - ASCII chars containing decimal digits 0-9
int sum = 0;
for (int i = 0; i <= 7; i++)
{
sum = (sum * 10) + (num[i] - '0');
}
  • 一种非常直接优化的方式是将循环展开
1
2
3
4
5
6
7
8
9
int sum;
sum = (num[0] - '0') * 10000000 +
(num[1] - '0') * 1000000 +
(num[2] - '0') * 100000 +
(num[3] - '0') * 10000 +
(num[4] - '0') * 1000 +
(num[5] - '0') * 100 +
(num[6] - '0') * 10 +
(num[7] - '0');
  • 用Golang测试下,循环展开的版本是否有优化
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
package string2int

import (
"strconv"
"testing"
)

// loop converts the first 8 bytes of str, which are assumed to be ASCII
// digits '0'-'9', to an int, one digit per iteration. It panics if
// len(str) < 8 and ignores any bytes after the eighth; non-digit bytes are
// not detected and produce a meaningless result.
func loop(str string) int {
	// Reslicing once up front lets the loop run to len(digits), a bound the
	// compiler can prove, so digits[i] needs no per-iteration bounds check.
	// With the original `i < 8` bound every str[i] was checked, which is the
	// likely reason loop2 benchmarked faster than loop.
	digits := str[:8]
	num := 0
	for i := 0; i < len(digits); i++ {
		num = num*10 + int(digits[i]-'0')
	}
	return num
}

// loop2 folds every byte of str, each assumed to be an ASCII digit, into an
// int using Horner's rule (multiply the running value by ten, add the next
// digit). An empty string yields 0; non-digit bytes are not rejected.
func loop2(str string) int {
	result := 0
	// Ranging over []byte(str) is a form the compiler recognizes; it does
	// not copy the string.
	for _, c := range []byte(str) {
		result = 10*result + int(c-'0')
	}
	return result
}

// unrollLoop converts the first 8 bytes of str, assumed to be ASCII digits,
// to an int by weighting each digit with its power of ten directly instead
// of looping. It panics if len(str) < 8; bytes after the eighth are ignored
// and non-digit bytes are not detected.
func unrollLoop(str string) int {
	// Bounds-check hint: once index 7 is known to be valid, the compiler can
	// drop the individual checks on the eight index expressions below.
	_ = str[7]
	return int(str[0]-'0')*10000000 +
		int(str[1]-'0')*1000000 +
		int(str[2]-'0')*100000 +
		int(str[3]-'0')*10000 +
		int(str[4]-'0')*1000 +
		int(str[5]-'0')*100 +
		int(str[6]-'0')*10 +
		int(str[7]-'0')
}

// Test_String2Int checks loop, loop2 and unrollLoop against strconv.Atoi on
// several 8-digit inputs, including the all-zero and all-nine extremes.
func Test_String2Int(t *testing.T) {
	inputs := []string{"12345678", "00000000", "99999999", "87654321", "10000001"}
	for _, str := range inputs {
		want, err := strconv.Atoi(str)
		if err != nil {
			t.Fatalf("strconv.Atoi(%q): %v", str, err)
		}
		if got := loop(str); got != want {
			t.Errorf("loop(%q) = %v, want %v", str, got, want)
		}
		if got := loop2(str); got != want {
			t.Errorf("loop2(%q) = %v, want %v", str, got, want)
		}
		if got := unrollLoop(str); got != want {
			t.Errorf("unrollLoop(%q) = %v, want %v", str, got, want)
		}
	}
}

var (
	// benchInput is the string every benchmark converts. It is a
	// package-level variable rather than a literal so the compiler cannot
	// treat the input as a compile-time constant and fold the conversions.
	benchInput = "12345678"

	// benchSink receives every result so that no call can be discarded as
	// dead code.
	benchSink int
)

// Benchmark_String2Int compares strconv.Atoi with the hand-written
// converters on an 8-digit input.
func Benchmark_String2Int(b *testing.B) {
	str := benchInput
	b.Run("strconv.Atoi", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			// The error is ignored: benchInput holds a valid number.
			benchSink, _ = strconv.Atoi(str)
		}
	})
	b.Run("loop", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			benchSink = loop(str)
		}
	})
	b.Run("loop2", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			benchSink = loop2(str)
		}
	})
	b.Run("unroll loop", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			benchSink = unrollLoop(str)
		}
	})
}

--------------------------------------------------------------------
goos: darwin
goarch: amd64
pkg: mine/mock/benchmark/string2int
cpu: Intel(R) Core(TM) i7-8750H CPU @ 2.20GHz
Benchmark_String2Int
Benchmark_String2Int/strconv.Atoi
Benchmark_String2Int/strconv.Atoi-12 138586400 8.996 ns/op
Benchmark_String2Int/loop
Benchmark_String2Int/loop-12 338002887 3.550 ns/op
Benchmark_String2Int/loop2
Benchmark_String2Int/loop2-12 410646234 3.205 ns/op
Benchmark_String2Int/unroll_loop
Benchmark_String2Int/unroll_loop-12 724936510 1.436 ns/op
PASS

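// 循环展开的版本优势是明显的,并且比strconv要快很多,因为strconv要做一些检查(符号、长度、非法字符等)。
// 有意思的地方:loop2要比loop快。主要原因是边界检查消除(BCE):loop2的循环条件是 i < len(str),编译器能证明 str[i] 不会越界,于是去掉了每次迭代的边界检查;而loop的条件是 i < 8,编译器无法证明 len(str) >= 8,每次访问 str[i] 都要做一次检查。可以用 go build -gcflags=-d=ssa/check_bce/debug=1 查看哪些下标访问保留了边界检查。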
  • 两种方法实际上都是O(n)的,n为字符串长度
  • 可以通过统计 loads、adds/subtracts、shifts 和 multiplies 的次数来估计代码的开销:
    • 对于展开循环的版本:

      • 8 loads
      • 8 subtracts, 7 adds, 8 adds to index into num array
      • 7 multiplies
    • 这里假设除了multiplies外每个操作的开销一致,一般multiplies开销会更大一些。

    • 因此是,31次操作 + 7 multiplies

更快的转换

现在 64 位 CPU 和操作系统很常见,我们可以在这个问题上释放 64 位 CPU 寄存器的全部力量。

概念1

  • 在 ASCII 字符集中,数字字符('0' ~ '9')的范围是 0x30 ~ 0x39 (48-57)
  • 如果将每个数字 ASCII 字符与 0x0F 按位与,会将数字 ASCII 字符转换为该数字字符对应的十进制数。
1
2
'0' ~ '9'             =>   0x30 ~ 0x39
(0x30 ~ 0x39) & 0x0F => 0x0 ~ 0x9
  • 将 8 位数字字符串加载到 64 位 CPU 寄存器中,在 Intel CPU(little-endian)上表现如下
1
2
3
// given the string "12345678", 
// on little-endian Intel CPUs we see the reversed:
sum = 0x3837363534333231
  • 与 0x0F0F0F0F0F0F0F0F 按位与后
1
2
3
4
// given the string "12345678",
// bitwise-AND with 0x0F0F0F0F0F0F0F0F
sum = *((long long *)num) & 0x0F0F0F0F0F0F0F0F;
sum == 0x0807060504030201

概念2

  • 由于Intel CPU为little-endian,加载进来的数字实际上低位在前,高位在后。
1
2
3
// given the string "12345678", 
// on little-endian Intel CPUs we see the reversed:
sum = 0x3837363534333231
  • 因此我们需要做一些调整:把相邻的两个数字合并成一个两位数。每对数字中,十位数字(字符串中靠前的字符)位于较低的字节,个位数字位于较高的字节
    • 用掩码取出所有十位数字并乘以 10
    • 把个位数字右移 8 位,移到与十位数字相同的位置
    • 取上面两个数的和
1
2
3
4
5
6
7
8
9
10
// isolate the high digit, multiply by 10,
// shift over the low digit and add in
sum = ((sum & 0x000F000F000F000F) * 10) +
((sum >> 8) & 0x000F000F000F000F);

// sum = 0x0807060504030201
// (sum & 0x000F000F000F000F) * 10 = 0x0007000500030001 * 10
// (sum >> 8) & 0x000F000F000F000F = 0x0008000600040002
// 取两数之和后: [78]代表10进制数字对应的16进制数字
// sum = 0x00[78]00[56]00[34]00[12]
  • 扩展到更大的范围
1
2
3
4
5
6
7
8
9
10
11
12
13
// numbers are in range 0-99 (0x0-0x63) now
// - isolate high number (use 0x7F which encompasses number range)
// - multiply by 100 to move high number into
// thousands & hundreds position
// - shift low number over to tens and ones position
// - add the two numbers together
sum = ((sum & 0x0000007F0000007F) * 100) + ((sum >> 16) & 0x0000007F0000007F);

// notation: [NN] stands for a byte/lane whose value is the decimal number NN
// sum = 0x00[78]00[56]00[34]00[12]
// (sum & 0x0000007F0000007F) * 100 = 0x000000[56]000000[12] * 100
// (sum >> 16) & 0x0000007F0000007F = 0x000000[78]000000[34]
// after adding the two:
// sum = 0x0000[5678]0000[1234]
  • 继续扩展
1
2
3
4
5
6
7
8
9
10
11
12
// numbers are in range 0-9,999 (0x0-0x270F) now
// isolate high number (use 0x3FFF which covers number range)
// then multiply by 10000 to move high number into position
// shift low number over and isolate
// add the two numbers together
sum = ((sum & 0x3FFF) * 10000) + ((sum >> 32) & 0x3FFF);

// sum = 0x0000[5678]0000[1234]
// (sum & 0x3FFF) * 10000 = 0x000000000000[1234] * 10000
// (sum >> 32) & 0x3FFF = 0x000000000000[5678]
// 取两数之和后:
// sum = 0x00000000[12345678]

最终的算法

1
2
3
4
5
6
// given num[] - ASCII chars containing decimal digits 0-9 (at least 8 of them)
// Assumes a little-endian CPU: num[0], the most significant digit, ends up in
// the lowest byte of sum. The pointer-cast load also assumes the platform
// tolerates an unaligned 8-byte read; memcpy(&sum, num, 8) expresses the same
// single load portably.
long long sum;
sum = *((long long*)num) & 0x0F0F0F0F0F0F0F0F;
// 8 digits -> 4 numbers in 0..99 (one per 16-bit lane)
sum = ((sum & 0x000F000F000F000F) * 10) + ((sum >> 8) & 0x000F000F000F000F);
// 4 numbers -> 2 numbers in 0..9999 (one per 32-bit lane)
sum = ((sum & 0x0000007F0000007F) * 100) + ((sum >> 16) & 0x0000007F0000007F);
// 2 numbers -> the final 8-digit value
sum = ((sum & 0x3FFF) * 10000) + ((sum >> 32) & 0x3FFF);
  • 现在的时间复杂度为O(lg n), 会执行
    • 1 load
    • 7 bitwise ANDs
    • 3 right shifts
    • 3 adds
    • 3 multiplies
  • 和循环展开的版本的比较一下
1
2
3
Algorithm       |  Ops   | Multiplies
Unrolled loop | 31 | 7
SIMD | 14 | 3

再快一点

1
2
3
4
sum = *(long long *)num;
sum = (sum & 0x0F0F0F0F0F0F0F0F) * 2561 >> 8;
sum = (sum & 0x00FF00FF00FF00FF) * 6553601 >> 16;
sum = (sum & 0x0000FFFF0000FFFF) * 42949672960001 >> 32;

这些魔法数字是哪来的?

魔法数字

  • 上面的方法是通过右移不断的累加高位数字
  • 这里其实也是类似的,只是通过左移来处理高位数字的
  • 最后右移,去掉引入的部分
1
2
3
4
5
6
7
8
9
10
11
12
// 2561

sum = (((256 * 10) * sum) + 1 * sum);
// multiply by 256 is the same as left shift by 8
== ((10 * sum) << 8) + (1 * sum);
sum = sum >> 8;

// sum = 0x0807060504030201
// (256 * 10) * sum = (sum << 8) * 10 = 0x0706050403020100 * 10
// (256 * 10) * sum + 1 * sum = 0x0706050403020100 * 10 + 0x0807060504030201
// = 0x[78][67][56][45][34][23][12][01]
// sum >> 8 = 0x00[78][67][56][45][34][23][12]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// 6553601

// number groups are in range 0-99 now
sum = (sum & 0x00FF00FF00FF00FF) * 6553601 >> 16;

// sum = 0x00[78][67][56][45][34][23][12]
// sum & 0x00FF00FF00FF00FF = 0x00[78]00[56]00[34]00[12]

// sum = 0x00[78]00[56]00[34]00[12]
// (65536 * 100) * sum = (sum << 16) * 100 = 0x00[56]00[34]00[12]0000 * 100
// (65536 * 100) * sum + 1 * sum = 0x00[56]00[34]00[12]0000 * 100 + 0x00[78]00[56]00[34]00[12]
// = 0x[5678][3456][1234][0012]
// sum >> 16 = 0x0000[5678][3456][1234]

// 42949672960001 = 10000 * 4294967296 + 1, i.e. (10000 << 32) + 1

// number groups are in range 0-9,999 now
sum = (sum & 0x0000FFFF0000FFFF) * 42949672960001 >> 32;

// sum = 0x0000[5678][3456][1234]
// sum & 0x0000FFFF0000FFFF = 0x0000[5678]0000[1234]

// sum = 0x0000[5678]0000[1234]
// (4294967296 * 10000) * sum = (sum << 32) * 10000 = 0x0000[1234]00000000 * 10000
// (4294967296 * 10000) * sum + 1 * sum = 0x0000[1234]00000000 * 10000 + 0x0000[5678]0000[1234]
// = 0x[12345678][00001234]
// sum >> 32 = 0x00000000[12345678]
  • 这样进一步减少了操作的步骤
    • 1 load
    • 3 bitwise ANDs
    • 3 right shifts
    • 3 multiplies
  • 和之前的版本比较一下
1
2
3
4
Algorithm       |  Ops  | Multiplies
Unrolled loop | 31 | 7
SIMD | 14 | 3
SIMD 2 | 7 | 3

用Golang模拟

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
goos: darwin
goarch: amd64
pkg: mine/mock/benchmark/string2int
cpu: Intel(R) Core(TM) i7-8750H CPU @ 2.20GHz
Benchmark_String2Int
Benchmark_String2Int/strconv.Aoti
Benchmark_String2Int/strconv.Aoti-12 127020 8282 ns/op
Benchmark_String2Int/loop
Benchmark_String2Int/loop-12 327012 3719 ns/op
Benchmark_String2Int/loop2
Benchmark_String2Int/loop2-12 406800 2974 ns/op
Benchmark_String2Int/unroll_loop
Benchmark_String2Int/unroll_loop-12 802428 1455 ns/op
Benchmark_String2Int/simd
Benchmark_String2Int/simd-12 4299562 330.2 ns/op
Benchmark_String2Int/simd2
Benchmark_String2Int/simd2-12 3641539 301.2 ns/op
PASS
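  • 从数字上看 SIMD 和 SIMD2 的优势非常明显。但要注意,这里的 simd()/simd2() 并没有读取字符串,而是直接使用编译期常量 0x3837363534333231 作为输入。函数很小,会被内联;内联后整个计算会被常量折叠,结果又没有被使用。所以这两项测到的基本只是空循环的开销(约 0.3 ns/次,每次迭代只有一两个时钟周期),并不能代表真实的转换速度
  • 发现 SIMD2 相比于 SIMD 并没有明显优势(benchmark 比较的是 1000 次转换的开销),原因正是上面这一点:两者都被折叠成了常量,测到的是同一个空循环。要得到有意义的对比,需要让函数在运行时从字符串中读取 8 个字节,并把结果写入包级变量,防止被编译器消除
    • 最初的猜测是:SIMD 2 减少了常规操作,但是引入了相对大的整数乘法(不过在 x86-64 上,64 位乘法的延迟与乘数的大小无关)
    • 也有可能是常规操作确实开销太小了,不明显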
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package string2int

import (
"strconv"
"testing"
)

// loop converts the first 8 bytes of str, which are assumed to be ASCII
// digits '0'-'9', to an int, one digit per iteration. It panics if
// len(str) < 8 and ignores any bytes after the eighth; non-digit bytes are
// not detected and produce a meaningless result.
func loop(str string) int {
	// Reslicing once up front lets the loop run to len(digits), a bound the
	// compiler can prove, so digits[i] needs no per-iteration bounds check.
	// With the original `i < 8` bound every str[i] was checked, which is the
	// likely reason loop2 benchmarked faster than loop.
	digits := str[:8]
	num := 0
	for i := 0; i < len(digits); i++ {
		num = num*10 + int(digits[i]-'0')
	}
	return num
}

// loop2 folds every byte of str, each assumed to be an ASCII digit, into an
// int using Horner's rule (multiply the running value by ten, add the next
// digit). An empty string yields 0; non-digit bytes are not rejected.
func loop2(str string) int {
	result := 0
	// Ranging over []byte(str) is a form the compiler recognizes; it does
	// not copy the string.
	for _, c := range []byte(str) {
		result = 10*result + int(c-'0')
	}
	return result
}

// unrollLoop converts the first 8 bytes of str, assumed to be ASCII digits,
// to an int by weighting each digit with its power of ten directly instead
// of looping. It panics if len(str) < 8; bytes after the eighth are ignored
// and non-digit bytes are not detected.
func unrollLoop(str string) int {
	// Bounds-check hint: once index 7 is known to be valid, the compiler can
	// drop the individual checks on the eight index expressions below.
	_ = str[7]
	return int(str[0]-'0')*10000000 +
		int(str[1]-'0')*1000000 +
		int(str[2]-'0')*100000 +
		int(str[3]-'0')*10000 +
		int(str[4]-'0')*1000 +
		int(str[5]-'0')*100 +
		int(str[6]-'0')*10 +
		int(str[7]-'0')
}

// simdInput is the string converted by simd. It is a variable rather than a
// literal so the compiler cannot constant-fold simd away, which would leave
// the simd benchmark timing an empty loop.
var simdInput = "12345678"

// simd converts simdInput ("12345678") with the register-level technique
// described in the article and returns 12345678.
func simd() int {
	return simdString(simdInput)
}

// simdString converts the first 8 bytes of str, which must be ASCII digits,
// to an int by processing all 8 digits in one 64-bit word. Adjacent digits
// are combined into 2-digit numbers, those into 4-digit numbers, and the two
// halves into the final 8-digit value. It panics if len(str) < 8; bytes
// after the eighth are ignored and non-digit bytes are not detected.
//
// The arithmetic is done in uint64 so the 64-bit masks are valid on every
// platform (an int constant such as 0x0f0f0f0f0f0f0f0f does not compile where
// int is 32 bits). The result is at most 99999999, so it always fits an int.
func simdString(str string) int {
	// Bounds-check hint: one check covers the eight loads below.
	_ = str[7]
	// Assemble the bytes little-endian, as binary.LittleEndian.Uint64 does:
	// str[0] (the most significant digit) lands in the lowest byte, so
	// "12345678" becomes 0x3837363534333231.
	num := uint64(str[0]) | uint64(str[1])<<8 | uint64(str[2])<<16 | uint64(str[3])<<24 |
		uint64(str[4])<<32 | uint64(str[5])<<40 | uint64(str[6])<<48 | uint64(str[7])<<56
	// '0'..'9' (0x30..0x39) -> 0..9 in every byte.
	num &= 0x0f0f0f0f0f0f0f0f
	// 8 digits -> 4 numbers in 0..99, one per 16-bit lane.
	num = (num&0x000f000f000f000f)*10 + (num>>8)&0x000f000f000f000f
	// 4 numbers -> 2 numbers in 0..9999, one per 32-bit lane.
	num = (num&0x0000007f0000007f)*100 + (num>>16)&0x0000007f0000007f
	// 2 numbers -> the final 8-digit value.
	num = (num&0x3fff)*10000 + (num>>32)&0x3fff
	return int(num)
}

// simd2Input is the string converted by simd2. It is a variable rather than
// a literal so the compiler cannot constant-fold simd2 away, which would
// leave the simd2 benchmark timing an empty loop.
var simd2Input = "12345678"

// simd2 converts simd2Input ("12345678") with the magic-multiplier variant
// of the register-level technique and returns 12345678.
func simd2() int {
	return simd2String(simd2Input)
}

// simd2String converts the first 8 bytes of str, which must be ASCII digits,
// to an int. Each step multiplies by (k << shift) + 1, which adds k times
// the lower lane into the next lane up, then shifts and masks. It panics if
// len(str) < 8; bytes after the eighth are ignored and non-digit bytes are
// not detected.
//
// The arithmetic is done in uint64: the 64-bit masks then compile on every
// platform, and the multiplies may overflow past bit 63, which wraps
// harmlessly because those bits are never used.
func simd2String(str string) int {
	// Bounds-check hint: one check covers the eight loads below.
	_ = str[7]
	// Little-endian assembly, as binary.LittleEndian.Uint64 does:
	// "12345678" becomes 0x3837363534333231.
	num := uint64(str[0]) | uint64(str[1])<<8 | uint64(str[2])<<16 | uint64(str[3])<<24 |
		uint64(str[4])<<32 | uint64(str[5])<<40 | uint64(str[6])<<48 | uint64(str[7])<<56
	// 2561 = 10<<8 + 1: digits -> 2-digit numbers in every other byte.
	num = (num & 0x0f0f0f0f0f0f0f0f) * 2561 >> 8
	// 6553601 = 100<<16 + 1: 2-digit -> 4-digit numbers.
	num = (num & 0x00ff00ff00ff00ff) * 6553601 >> 16
	// 42949672960001 = 10000<<32 + 1: 4-digit halves -> 8-digit value.
	num = (num & 0x0000ffff0000ffff) * 42949672960001 >> 32
	return int(num)
}

// Test_String2Int checks every converter against strconv.Atoi. The
// string-taking converters are run on several 8-digit inputs, including the
// all-zero and all-nine extremes.
func Test_String2Int(t *testing.T) {
	inputs := []string{"12345678", "00000000", "99999999", "87654321", "10000001"}
	for _, str := range inputs {
		want, err := strconv.Atoi(str)
		if err != nil {
			t.Fatalf("strconv.Atoi(%q): %v", str, err)
		}
		if got := loop(str); got != want {
			t.Errorf("loop(%q) = %v, want %v", str, got, want)
		}
		if got := loop2(str); got != want {
			t.Errorf("loop2(%q) = %v, want %v", str, got, want)
		}
		if got := unrollLoop(str); got != want {
			t.Errorf("unrollLoop(%q) = %v, want %v", str, got, want)
		}
	}

	// simd and simd2 take no argument; they convert the fixed sample
	// "12345678" (0x3837363534333231 when loaded little-endian).
	want, err := strconv.Atoi("12345678")
	if err != nil {
		t.Fatalf("strconv.Atoi(%q): %v", "12345678", err)
	}
	if got := simd(); got != want {
		t.Errorf("simd() = %v, want %v", got, want)
	}
	if got := simd2(); got != want {
		t.Errorf("simd2() = %v, want %v", got, want)
	}
}

var (
	// benchInput is the string every benchmark converts. It is a
	// package-level variable rather than a literal so the compiler cannot
	// treat the input as a compile-time constant and fold the conversions.
	benchInput = "12345678"

	// benchSink receives every result so that no call can be discarded as
	// dead code.
	benchSink int
)

// Benchmark_String2Int compares strconv.Atoi with the hand-written
// converters. Each b.N iteration performs 1000 conversions, so the reported
// ns/op is the cost of 1000 calls.
//
// NOTE(review): simd and simd2 take no argument and cannot be fed
// benchInput. If they compute from a compile-time constant, the compiler
// folds them and their sub-benchmarks time only the loop itself.
func Benchmark_String2Int(b *testing.B) {
	str := benchInput
	b.Run("strconv.Atoi", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			for j := 0; j < 1000; j++ {
				// The error is ignored: benchInput holds a valid number.
				benchSink, _ = strconv.Atoi(str)
			}
		}
	})
	b.Run("loop", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			for j := 0; j < 1000; j++ {
				benchSink = loop(str)
			}
		}
	})
	b.Run("loop2", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			for j := 0; j < 1000; j++ {
				benchSink = loop2(str)
			}
		}
	})
	b.Run("unroll loop", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			for j := 0; j < 1000; j++ {
				benchSink = unrollLoop(str)
			}
		}
	})
	b.Run("simd", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			for j := 0; j < 1000; j++ {
				benchSink = simd()
			}
		}
	})
	b.Run("simd2", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			for j := 0; j < 1000; j++ {
				benchSink = simd2()
			}
		}
	})
}