问题

在Stackoverflow上看到了这个问题:Recursively combine dictionaries

就是说,对于这样的两个字典:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
groups = {
'servers': ['unix_servers', 'windows_servers'],
'unix_servers': ['server_a', 'server_b', 'server_group'],
'windows_servers': ['server_c', 'server_d'],
'server_group': ['server_e', 'server_f']
}

hosts = {
'server_a': '10.0.0.1',
'server_b': '10.0.0.2',
'server_c': '10.0.0.3',
'server_d': '10.0.0.4',
'server_e': '10.0.0.5',
'server_f': '10.0.0.6'
}

如何递归组合成如下形式:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
{
'servers': {
'unix_servers': {
'server_a': '10.0.0.1',
'server_b': '10.0.0.2',
'server_group': {
'server_e': '10.0.0.5',
'server_f': '10.0.0.6'
}
},
'windows_servers': {
'server_c': '10.0.0.3',
'server_d': '10.0.0.4'
}
}
}

思路与代码

思路在问题中实际已经提及了,就是用递归来实现,但是需要注意一些细节:会有一些非根结点也出现在字典中,所以要记录在servers中出现的那些非根的结点,然后在输出前将其去除。

1
2
3
4
5
6
7
8
9
10
11
12
13
def resolve(groups, hosts):
# Groups that have already been solved
resolved_groups = {}
# Group names that are not root
non_root = set()
# Make dict with resolution of each group
result = {}
for name in groups:
result[name] = _resolve_rec(name, groups, hosts, resolved_groups, non_root)
for name in groups:
if name in non_root:
del result[name]
return result

这部分很简单,下面是递归部分:

实际上,就是对group里的每一层,都进行查找,若出现在hosts中,则完成组合;如不在hosts中,说明group中递归还未结束。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def _resolve_rec(name, groups, hosts, resolved_groups, non_root):
# If group has already been resolved, finish
if name in resolved_groups:
return resolved_groups[name]
# If it is a host, finish
if name in hosts:
return hosts[name]
# new group resolution
resolved = {}
for child in groups[name]:
# Resolve each child
resolved[child] = _resolve_rec(child, groups, hosts, resolved_groups, non_root)
# Mark child as non_root
non_root.add(child)
# Save to resolved groups
resolved_groups[name] = resolved
return resolved

看看输出的结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
groups = {
'servers': ['unix_servers', 'windows_servers'],
'unix_servers': ['server_a', 'server_b', 'server_group'],
'windows_servers': ['server_c', 'server_d'],
'server_group': ['server_e', 'server_f']
}

hosts = {
'server_a': '10.0.0.1',
'server_b': '10.0.0.2',
'server_c': '10.0.0.3',
'server_d': '10.0.0.4',
'server_e': '10.0.0.5',
'server_f': '10.0.0.6'
}

print(resolve(groups, hosts))

和预想的一致:

1
{'servers': {'unix_servers': {'server_a': '10.0.0.1', 'server_b': '10.0.0.2', 'server_group': {'server_e': '10.0.0.5', 'server_f': '10.0.0.6'}}, 'windows_servers': {'server_c': '10.0.0.3', 'server_d': '10.0.0.4'}}}

不过需要注意的是,若group A中包含group B,且group B也包含group A的话可能会出现无限递归

另一个更简单直观的代码

1
2
3
4
def build(val): 
return {i:build(i) for i in groups[val]} if val in groups else hosts[val]

{"servers":build('servers')}