在Stackoverflow上看到了这个问题:Identifying root parents and all their children in trees——如何在树中识别根节点和所有的子节点

问题

如下,有这样的一个树:(原问题是一个pandas dataframe)

1
2
3
4
5
6
7
parent   child   parent_level   child_level
A B 0 1
B C 1 2
B D 1 2
X Y 0 2
X D 0 2
Y Z 2 3

它表示的树是这样的:

1
2
3
4
5
6
7
    A  X
/ / \
B / \
/\ / \
C D Y
|
Z

如何将这个树表示为:

1
2
3
root    children
A [B,C,D]
X [D,Y,Z]

解决办法

方便起见,我们把上面的数据表示为一个数组

1
2
3
4
5
6
7
# parent   child   parent_level   child_level
data = [['A', 'B', 0, 1],
['B', 'C', 1, 2],
['B', 'D', 1, 2],
['X', 'Y', 0, 2],
['X', 'D', 0, 2],
['Y', 'Z', 2, 3]]

递归

一个简单的思路就是我们先找到所有子节点的父节点,可以用一个字典来表示,因此,data可以表示为{'B': {'A'}, 'C': {'B'}, 'D': {'B', 'X'}, 'Y': {'X'}, 'Z': {'Y'}}

然后再通过递归的方式,查看每个子节点的父节点是否在这个字典里

  • 若在,说明该父节点还有父节点,则继续递归查找
  • 若不再,说明该父节点已经是根节点,结束

Python代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
data = [['A', 'B', 0, 1],
['B', 'C', 1, 2],
['B', 'D', 1, 2],
['X', 'Y', 0, 2],
['X', 'D', 0, 2],
['Y', 'Z', 2, 3]]

def process(data):
tree = {}
for l in data:
parent, child = l[0], l[1]
tree.setdefault(child, set()).add(parent)
# tree: {'B': {'A'}, 'C': {'B'}, 'D': {'B', 'X'}, 'Y': {'X'}, 'Z': {'Y'}}
descendents = {}
for child in tree:
for parent in find_root(tree, child):
descendents.setdefault(parent, set()).add(child)
return descendents

def find_root(tree, child):
# 返回根节点列表,注意可能会有多个根节点
if child in tree:
return {p for parent in tree[child] for p in find_root(tree, parent)}
return {child}

if __name__ == "__main__":
print(process(data))

# {'A': {'B', 'C', 'D'}, 'X': {'Y', 'Z', 'D'}}

当然,我们也可以把find_root改写为生成器

1
2
3
4
5
6
def find_root(tree, child):
if child in tree:
for x in tree[child]:
yield from find_root(tree, x)
else:
yield child

也可以把递归改为栈,来避免递归深度的问题,可以使用 “stack of iterators” pattern

1
2
3
4
5
6
7
8
9
10
11
12
def find_root(tree, child):
stack = [iter([child])]
while stack:
for node in stack[-1]:
if node in tree:
stack.append(iter(tree[node]))
else:
yield node
break
# yes! that is an `else` clause on a for loop
else:
stack.pop()

关于for…else见此

利用第三方库:networkx

因为这是一个图问题,所以也可以使用networkx来解决这个问题,特别是 descendants(G, source)函数,可以返回有向无环图GG中的所有可以达到sourcesource的节点,对于这个问题,就是可以获得所有可以到达根节点的子节点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import networkx as nx
import pandas as pd

data = [['A', 'B', 0, 1],
['B', 'C', 1, 2],
['B', 'D', 1, 2],
['X', 'Y', 0, 2],
['X', 'D', 0, 2],
['Y', 'Z', 2, 3]]

df = pd.DataFrame(data=data, columns=['parent', 'child', 'parent_level', 'child_level'])

roots = df.parent[df.parent_level.eq(0)].unique()
dg = nx.from_pandas_edgelist(df, source='parent', target='child', create_using=nx.DiGraph)

result = pd.DataFrame(data=[[root, nx.descendants(dg, root)] for root in roots], columns=['root', 'children'])
print(result)

Reference

Identifying root parents and all their children in trees