diff --git a/graphs/page_rank.py b/graphs/page_rank.py index c0ce3a94c76b..7452ad8df2db 100644 --- a/graphs/page_rank.py +++ b/graphs/page_rank.py @@ -31,22 +31,35 @@ def __repr__(self): return f"" -def page_rank(nodes, limit=3, d=0.85): - ranks = {} - for node in nodes: - ranks[node.name] = 1 +def page_rank(nodes, limit=None, d=0.85, tol=1e-8, max_iter=100): + if not nodes: + return {} + + if limit is not None: + max_iter = limit + + n = len(nodes) + ranks = {node.name: 1.0 / n for node in nodes} + outbounds = {node.name: len(node.outbound) for node in nodes} + + for _ in range(max_iter): + new_ranks = {} + dangling_sum = sum( + ranks[node.name] for node in nodes if outbounds[node.name] == 0 + ) + + for node in nodes: + inbound_rank = sum( + ranks[inbound_node] / outbounds[inbound_node] + for inbound_node in node.inbound + ) + new_ranks[node.name] = (1 - d) / n + d * (inbound_rank + dangling_sum / n) - outbounds = {} - for node in nodes: - outbounds[node.name] = len(node.outbound) + if sum(abs(new_ranks[name] - ranks[name]) for name in ranks) < tol: + return new_ranks + ranks = new_ranks - for i in range(limit): - print(f"======= Iteration {i + 1} =======") - for _, node in enumerate(nodes): - ranks[node.name] = (1 - d) + d * sum( - ranks[ib] / outbounds[ib] for ib in node.inbound - ) - print(ranks) + return ranks def main(): @@ -64,7 +77,8 @@ def main(): for node in nodes: print(node) - page_rank(nodes) + print("======= Page Rank =======") + print(page_rank(nodes)) if __name__ == "__main__": diff --git a/graphs/tests/test_page_rank.py b/graphs/tests/test_page_rank.py new file mode 100644 index 000000000000..174df8738d90 --- /dev/null +++ b/graphs/tests/test_page_rank.py @@ -0,0 +1,32 @@ +import math + +from graphs.page_rank import Node, page_rank + + +def add_edge(nodes, source, destination): + nodes[destination].add_inbound(nodes[source].name) + nodes[source].add_outbound(nodes[destination].name) + + +def test_page_rank_scores_are_normalized(): + nodes = [Node("A"), Node("B"), Node("C")] + add_edge(nodes, 0, 1) + add_edge(nodes, 0, 2) + add_edge(nodes, 1, 2) + add_edge(nodes, 2, 0) + + ranks = page_rank(nodes, max_iter=100) + + assert math.isclose(sum(ranks.values()), 1.0, abs_tol=1e-8) + + +def test_page_rank_handles_dangling_nodes(): + nodes = [Node("A"), Node("B"), Node("C")] + add_edge(nodes, 0, 1) + add_edge(nodes, 1, 2) + + ranks = page_rank(nodes, max_iter=100) + + assert math.isclose(sum(ranks.values()), 1.0, abs_tol=1e-8) + assert math.isclose(ranks["C"], 0.474412, rel_tol=1e-5) + assert ranks["C"] > ranks["B"] > ranks["A"]