Python网络关系可视化工具Networkx

Tags: Python, Networkx, matplotlib

Networkx是用Python语言开发的图论与复杂网络建模工具,内置了常用的图与复杂网络分析算法,可以方便的进行复杂网络数据分析、仿真建模等工作。

本文主要实现用networkx画有向图,检测是否有回环,每个节点的前节点、后节点。这里已经封装好了相关的实现类。

# -*- coding:utf-8 -*-

import networkx as nx
import matplotlib.pyplot as plt
import copy
from networkx.algorithms.cycles import *

class GetGraph:

    def __init__(self):
        pass

    @staticmethod
    def create_directed_graph(data_dict):
        my_graph = nx.DiGraph()
        my_graph.clear()
        for front_node, back_node_list in data_dict.items():
            if back_node_list:
                for back_node in back_node_list:
                    my_graph.add_edge(front_node, back_node)
            else:
                my_graph.add_node(front_node)
        return my_graph

    @staticmethod
    def draw_directed_graph(my_graph, name='out'):
        nx.draw_networkx(my_graph, pos=nx.circular_layout(my_graph), vmin=10,
                         vmax=20, width=2, font_size=8, edge_color='black')
        picture_name = name + ".png"
        plt.savefig(picture_name)
        # print('save success: ', picture_name)
        # plt.show()

    @staticmethod
    def get_next_node(my_graph):
        nodes = my_graph.nodes
        next_node_dict = {}
        for n in nodes:
            value_list = list(my_graph.successors(n))
            next_node_dict[n] = value_list
        return copy.deepcopy(next_node_dict)

    @staticmethod
    def get_front_node(my_graph):
        nodes = my_graph.nodes
        front_node_dict = {}
        for n in nodes:
            value_list = list(my_graph.predecessors(n))
            front_node_dict[n] = value_list
        return copy.deepcopy(front_node_dict)

    @staticmethod
    def get_loop_node(my_graph):
        loop = (list(simple_cycles(my_graph)))
        return copy.deepcopy(loop)

if __name__ == '__main__':
    comp_graph_object = GetGraph()
    comp_statement = {'CT_UPDATE_POS': ['CT_STATE'], 'CT_STATE': ['CT_MOVE'], 'CT_MOVE': ['CT_STATE', 'CT_FLUSH_VISUAL'],
      'CT_VISUAL': [], 'CT_FLUSH_VISUAL': ['CT_MOVE'], 'CT_INPUT': []}

    print('self.comp_statement_ct_map:', comp_statement)
    graph = comp_graph_object.create_directed_graph(comp_statement)
    comp_next_node = comp_graph_object.get_next_node(graph)
    comp_front_node = comp_graph_object.get_front_node(graph)
    comp_loop_list = comp_graph_object.get_loop_node(graph)
    comp_graph_object.draw_directed_graph(graph)

如果有警告提示:

FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison
  if self._edgecolors == str('face'):

无需理会。效果图为:

12522897-08b43b83ea0e76e0

注意点

节点的位置排列官方给了几种办法,选用合适的即可。

bipartite_layout(G, nodes[, align, scale, …]) Position nodes in two straight lines.
circular_layout(G[, scale, center, dim]) Position nodes on a circle.
kamada_kawai_layout(G[, dist, pos, weight, …]) Position nodes using Kamada-Kawai path-length cost-function.
planar_layout(G[, scale, center, dim]) Position nodes without edge intersections.
random_layout(G[, center, dim, seed]) Position nodes uniformly at random in the unit square.
rescale_layout(pos[, scale]) Returns scaled position array to (-scale, scale) in all axes.
shell_layout(G[, nlist, scale, center, dim]) Position nodes in concentric circles.
spring_layout(G[, k, pos, fixed, …]) Position nodes using Fruchterman-Reingold force-directed algorithm.
spectral_layout(G[, weight, scale, center, dim]) Position nodes using the eigenvectors of the graph Laplacian.
spiral_layout(G[, scale, center, dim, …]) Position nodes in a spiral layout.

本文这里使用的是将节点画在了同心圆上,这样节点之间交叉较少。

查阅源码可知,

nx.circular_layout(my_graph)

虽然有设置高维(3维、4维...)的参数,但是还是只能实现画在2维平面。思路主要是将1等距离平分n份,然后变换成2π的角度,圆半径已知,求得圆周上各节点之间的位置。

但是源码的所有节点只能分布在同一圆上,如果节点很多,便不再适用,因此本文在此基础上改为分布在同心圆上。

import  matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt

def get_node_pos(node_list, radius=1, step=1, step_num=8, center=(0, 0), dim=2):
    if dim < 2:
        raise ValueError('cannot handle dimensions < 2')
    paddims = max(0, (dim - 2))
    odd_all_num = len(node_list)
    node_pos_list = []
    while odd_all_num > 0:
        cur_lever_num = radius * step_num
        if odd_all_num < cur_lever_num:
            cur_lever_num = odd_all_num
        odd_all_num -= cur_lever_num

        theta = np.linspace(0, 1, cur_lever_num + 1)[:-1] * 2 * np.pi
        theta = theta.astype(np.float32)
        pos = np.column_stack([np.cos(theta) * radius, np.sin(theta) * radius,
                               np.zeros((cur_lever_num, paddims))])
        pos = pos.tolist()
        node_pos_list.extend(pos)
        radius += 1
    all_pos = dict(zip(node_list, node_pos_list))
    return all_pos

if __name__ == '__main__':
    node =range(1,30,1)
    print('node:', node)
    pos = get_node_pos(node)

    # fig, ax = plt.subplots(figsize=(10,10))
    for name, pos in pos.items():
        plt.scatter(pos[0], pos[1])

思路主要是:圆的半径是radius ,圆上最大节点数为step_num,当节点数超过step_num, 求出新圆的半径radius+=step,该圆的最大节点数为radius*step_num。也就是说,圆的半径和最大节点数成正比。默认参数为,从里到外圆上最大节点数依次为8个、16个、8n个。

这里用matplotlib来显示效果:

12522897-387dd540afab19f8

但是发现,明明等分的圆,这些点却明显是椭圆。纠结了好久,才发现原来matplotlib默认得到的宽、高是不等距离的。因此设置图像的宽和高相等即可。

fig, ax = plt.subplots(figsize=(10,10))

最终效果图为:

12522897-404806d5b3286c29

可以看出,我们将节点等分在了同心圆上。

参考