Newman Fast 算法Python实现

rick    2018-11-19 22:09

 

 

之前写的 Newman 快速算法(Newman fast algorithm)实现有错误,现已更正并重新上传代码,希望能有所帮助;如果仍有错误,也请指出。算法原理:先将每个节点看作一个独立的社团,然后每一轮选择合并后模块度 Q 最大的一对相连社团进行合并,直到所有节点合并为一个社团为止;最后取整个过程中模块度 Q 最大的那一轮划分作为结果。

所需包安装

pip install networkx==1.11

# 注意:本文代码基于 networkx 1.11。自 networkx 2.0 起,遍历邻居的函数 G.neighbors(node) 返回的是迭代器,而不再是 list;使用新版本的读者需自行调整。

import os
import time
from functools import wraps

import networkx as nx
# newman快速算法


# 合并社团函数
def cluAssemble(self, other, currentCluList):
    """Merge community ``self`` into community ``other`` inside a partition.

    Both communities are removed from ``currentCluList`` (by equality, via
    ``list.remove``), every member of ``self`` that ``other`` does not yet
    contain is appended to ``other``, and the enlarged ``other`` is appended
    to the end of ``currentCluList``.

    Args:
        self: List of nodes forming the community that is absorbed.
        other: List of nodes forming the community that absorbs ``self``.
            It is modified in place.
        currentCluList: List of communities. It is modified in place.

    Returns:
        The same ``currentCluList`` object, after the merge.

    Raises:
        ValueError: If either community is not present in ``currentCluList``.
    """
    for absorbed in (self, other):
        currentCluList.remove(absorbed)

    # Membership is re-checked against the growing list on every step, so a
    # node that occurs twice in ``self`` is still added only once.
    for member in self:
        if member in other:
            continue
        other.append(member)

    currentCluList.append(other)
    return currentCluList


# 判断两个社团之间是否有边相连
def cluHasEdge(clu1, clu2, graph):
    """Report whether any edge of ``graph`` joins the two communities.

    Args:
        clu1: Iterable of nodes of the first community.
        clu2: Iterable of nodes of the second community.
        graph: Object exposing ``has_edge(u, v)``.

    Returns:
        True as soon as one pair ``(u, v)`` with ``u`` in ``clu1`` and ``v``
        in ``clu2`` satisfies ``graph.has_edge(u, v)``; False otherwise.
    """
    return any(
        graph.has_edge(left, right)
        for left in clu1
        for right in clu2
    )


# 从path中读取图
def load_graph(path):
    """Build an undirected graph from a whitespace-separated edge list file.

    Each data line must start with two integer node ids; any further columns
    on the line are ignored. Blank lines and lines starting with ``#`` are
    skipped.

    Args:
        path: Path of the edge list file.

    Returns:
        An ``nx.Graph`` containing one edge per data line.

    Raises:
        ValueError: If a data line has fewer than two columns or a node id
            is not an integer.
    """
    G = nx.Graph()
    with open(path) as text:
        for lineNo, line in enumerate(text, start=1):
            # ``split()`` without arguments treats any run of spaces or tabs
            # as a single separator, unlike ``split(" ")``.
            vertices = line.split()
            if not vertices or vertices[0].startswith("#"):
                continue
            if len(vertices) < 2:
                raise ValueError(
                    "line %d of %s does not contain two node ids: %r"
                    % (lineNo, path, line)
                )
            sourcePoint = int(vertices[0])
            targetPoint = int(vertices[1])
            G.add_edge(sourcePoint, targetPoint)
    return G


# 计时器函数,该函数是一个装饰器decorator
def fn_timer(function):
    """Decorator that prints how long each call of ``function`` takes.

    The elapsed time is printed as ``time:<seconds> s`` after the wrapped
    call returns; nothing is printed if the call raises.

    Args:
        function: The callable to wrap.

    Returns:
        A wrapper with the metadata of ``function`` (via ``functools.wraps``)
        that forwards all arguments and returns the wrapped call's result.
    """
    @wraps(function)
    def function_timer(*args, **kwargs):
        # perf_counter() is monotonic, so the measured interval cannot be
        # distorted by system clock adjustments the way time.time() can.
        t0 = time.perf_counter()
        result = function(*args, **kwargs)
        t1 = time.perf_counter()
        print("time:%s s" % (str(t1 - t0)))
        return result
    return function_timer
    
    
# 写入所属社团标号
def writeClu(G, currentCluList):
    """Store each node's community number in its ``'groupID'`` attribute.

    Communities are numbered from 1 in the order they appear in
    ``currentCluList``; the number is stored as a string.

    Args:
        G: Graph whose per-node attribute dicts are reachable through
            ``G.node`` (older networkx) or ``G.nodes`` (networkx >= 2.4).
        currentCluList: Iterable of communities, each an iterable of nodes.
    """
    # ``G.node`` was removed in networkx 2.4; ``G.nodes[n]`` is the
    # replacement there. Prefer ``G.node`` when it exists so that older
    # releases, where ``G.nodes`` is a plain method, keep working.
    nodeAttrs = G.node if hasattr(G, 'node') else G.nodes
    for cluID, members in enumerate(currentCluList, start=1):
        label = str(cluID)
        for member in members:
            nodeAttrs[member]['groupID'] = label
    
    
# Q = 社团内部边数/网络总边数 - (社团内所有点度数之和/2倍总边数)平方
def cal_Q(partition, G):
    """Compute the modularity Q of ``partition`` on graph ``G``.

    With ``m`` the number of edges of ``G``, every community contributes

        (ordered node pairs (i, j) inside it with an edge) / (2 * m)
        - ((sum of the neighbour counts of its nodes) / (2 * m)) ** 2

    and Q is the sum of these contributions.

    Args:
        partition: Iterable of communities, each a sequence of nodes of ``G``.
        G: Graph exposing ``edges()``, ``has_edge(u, v)`` and
            ``neighbors(node)``.

    Returns:
        The modularity as a float.

    Raises:
        ZeroDivisionError: If ``G`` has no edges and ``partition`` is not
            empty.
    """
    m = len(G.edges())
    a = 0.0
    e = 0.0

    # Fraction of edge ends that fall inside each community. Every ordered
    # pair is tested, so an edge between two distinct members counts twice.
    for community in partition:
        eii = 0.0
        for i in community:
            for j in community:
                if G.has_edge(i, j):
                    eii += 1
        e += eii / (2 * m)

    # Squared fraction of edge ends attached to each community.
    for community in partition:
        aij = 0.0
        for node in community:
            # Count by iterating: neighbors() returns a list in networkx 1.x
            # but an iterator (which has no len()) from networkx 2.0 on.
            aij += sum(1 for _ in G.neighbors(node))
        a += (aij / (2 * m)) ** 2

    q = e - a
    return q


# 主类
class newmanFast:
    """Greedy agglomerative community detection driven by modularity.

    Every node starts as its own community. Each round merges the pair of
    edge-connected communities whose merged partition has the largest
    modularity, and the partition with the largest modularity seen over all
    rounds is remembered.

    Attributes set by ``__init__``:
        G: The graph being partitioned.
        nodeList: Result of ``G.nodes()``.
        cluList: Current partition, a list of node lists.
        finalClu: Partition of the best round so far (empty until a round
            has completed).
    """

    def __init__(self, graph):
        """Start from the partition in which every node is its own community.

        Args:
            graph: Graph to partition; it must provide ``nodes()`` plus the
                methods used by ``cluHasEdge``, ``cal_Q`` and ``writeClu``.
        """
        self.G = graph
        self.nodeList = self.G.nodes()
        self.cluList = [[node] for node in self.nodeList]
        self.finalClu = []

    @fn_timer
    def execute(self):
        """Run the merge rounds and record the best partition.

        Each round prints its partition and modularity and appends a line to
        ``resver2/<graphName>res.txt``, where ``graphName`` is read from the
        module-level variable of that name (it must be defined before this
        method is called). Whenever a round improves on the best modularity,
        the partition is copied to ``self.finalClu`` and written onto the
        graph's nodes with ``writeClu``.

        Merging stops when a single community is left, or earlier when no
        two remaining communities share an edge (disconnected graph).
        """
        iterTime = 1
        print(self.cluList)
        maxQ = -float('Inf')
        # 0 means "no round has completed yet".
        iterRound = 0
        # Value comparison; ``is not 1`` would test object identity.
        while len(self.cluList) != 1:
            Q = -float('Inf')
            # Reset every round so a stale pair from an earlier round can
            # never be merged again.
            finalCluFrom = None
            finalCluTo = None
            for cluFrom in self.cluList:
                thisClu = self.cluList.copy()
                thisClu.remove(cluFrom)
                for cluTo in thisClu:
                    if cluHasEdge(cluFrom, cluTo, self.G):
                        # Evaluate the merge on copies; the real partition
                        # is only changed once the best pair is known.
                        partition = cluAssemble(cluFrom.copy(), cluTo.copy(), self.cluList.copy())
                        thisQ = cal_Q(partition, self.G)
                        # Remember the best modularity of this round and the
                        # pair producing it (ties go to the later pair).
                        if thisQ >= Q:
                            Q = thisQ
                            finalCluFrom = cluFrom
                            finalCluTo = cluTo
            if finalCluFrom is None:
                # No two communities are connected: nothing left to merge.
                break
            cluAssemble(finalCluFrom.copy(), finalCluTo.copy(), self.cluList)
            print("该轮结束的划分结果:" + str(self.cluList))
            print("该轮结束时的模块度Q:" + str(Q))
            print("################第" + str(iterTime) + "轮结束##################")
            # Make sure the output directory exists, and let the context
            # manager close the log even if the write fails.
            os.makedirs('resver2', exist_ok=True)
            with open('resver2/%sres.txt' % str(graphName), 'a') as file_object:
                file_object.write(str(iterTime) + '|Q:' + str(Q) + '|Result:' + str(self.cluList) + '\n')

            iterTime += 1
            if Q > maxQ:
                maxQ = Q
                iterRound = iterTime - 1
                self.finalClu = self.cluList.copy()
                writeClu(self.G, self.cluList)

        # Only report a best round if at least one merge actually happened.
        if iterRound:
            print("最大Q值出现在:第" + str(iterRound) + "轮。最大Q值为:" + str(maxQ))
        for clu in self.finalClu:
            print(sorted(clu))


if __name__ == "__main__":
    # Data set names accepted at the prompt.
    knownGraphs = ['club', 'dolphin', 'football', 'power', 'sciencenet',
                   'test', 'test2', 'test3', 'dept3', 'facebook']
    # ``graphName`` stays a module-level name: newmanFast.execute() reads it
    # to build the name of its per-round log file.
    graphName = input(
        "请输入数据集名('club', 'dolphin', 'football', 'power', "
        "'sciencenet', 'test', 'test2', 'test3', 'dept3', 'facebook'):"
    )
    if graphName in knownGraphs:
        G = load_graph('your/path/to/dataset/' + graphName + '.txt')
        detector = newmanFast(G)
        detector.execute()
        # Save the graph, including the community labels written on it.
        nx.write_gml(G, graphName + 'res.gml')
    else:
        print("error")

 

Views: 5.3K

[[total]] comments

Post your comment
  1. [[item.time]]
    [[item.user.username]] [[item.floor]]Floor
  2. Click to load more...
  3. Post your comment