问题

已有三个整数xx, yy, zz,每个整数都大于等于1,并小于$ n < 10^6 $。

同时满足n=x+y+zn = x+y+z,求它们的余弦值之和最大为多少?

仍然存在的问题:
nn 不同时,输出结果不同,例如 n=1000n=1000n=10000n=10000时,目前并不太确定这是为什么?

方法

首先我们可以想到最简单的方法就是暴力搜索,三次遍历,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from math import cos
import time
import numpy as np

def calc(n):
total = 0
for x in range(n):
for y in range(n):
for z in range(n):
if x + y + z == n:
temp = cos(x) + cos(y) + cos(z)
if temp > total: total = temp
return total

tic = time.clock()
total = calc(1000)
print(time.clock()-tic)

print (total)
73.1499421
2.843394328325828

因为速度非常慢,这里只用了1000来演示,之后使用10000来演示

这样的情况下,时间复杂度为O(N3)O(N^3)

也能很容易的想到将其改为两次循环,代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from math import cos
import time
import numpy as np

def calc(n):
total = 0
for x in range(n):
for y in range(n):
z = n-x-y
temp = cos(x) + cos(y) + cos(z)
if temp > total: total = temp
return total

tic = time.clock()
total = calc(10000)
print(time.clock()-tic)

print (total)
39.927119000000005
1.7604512743833716

这下速度一下就降低了非常多了,输入nn的大小增长了10倍,运行时间减少了2倍。

进步一步的,我们可以注意到三个变量中的至少有一个小于或等于N/3N/3,不妨认为这个数是xx, 且对于剩下的两个数,必有一个数在11(N - x)//2 + 1之间,不妨认为这个数是zz,就可以得到如下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from math import cos
import time
import numpy as np

def calc(n):
total = 0
for x in range(n, int((n/3 - 1)),-1):
for y in range(max(int(((n-x)/2))-1,1),min(int(n-x),int(n/3))):
z = n-x-y
temp = cos(x) + cos(y) + cos(z)
if temp > total: total = temp
return total

tic = time.clock()
total = calc(10000)
print(time.clock()-tic)

print (total)
2.2560140000000075
1.7604476202495472

可以看到运行时间进一步降低

使用Numba

使用Numba.jit进行加速,当然这个就不是算法层次的加速了,而是在语言本身和编译的层次进行加速了,详见 Numba

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from math import cos
import time
import numpy as np
from numba import jit

@jit(nopython=True)
def calc(n):
total = 0
for x in range(n, int((n/3 - 1)),-1):
for y in range(max(int(((n-x)/2))-1,1),min(int(n-x),int(n/3))):
z = n-x-y
temp = cos(x) + cos(y) + cos(z)
if temp > total: total = temp
return total

tic = time.clock()
total = calc(10000)
print(time.clock()-tic)

print (total)
0.3251690000000025
1.7604476202495472

再次加速!