Skip to content

Commit 90639c4

Browse files
Add Dijkstra algorithm implementation and tests
1 parent 1ee7c81 commit 90639c4

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

graph/dijkstra.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from typing import Dict, List, Tuple
2+
import heapq
3+
4+
def dijkstra(graph: Dict[int, List[Tuple[int, int]]], start: int) -> Dict[int, int]:
5+
"""
6+
Implements Dijkstra's algorithm for finding the shortest path in a graph.
7+
8+
Args:
9+
graph (Dict[int, List[Tuple[int, int]]]): A dictionary representing the graph.
10+
Keys are nodes, values are lists of (neighbor, weight) tuples.
11+
start (int): The starting node.
12+
13+
Returns:
14+
Dict[int, int]: A dictionary with nodes as keys and shortest distances from start as values.
15+
"""
16+
distances = {node: float('infinity') for node in graph}
17+
distances[start] = 0
18+
pq = [(0, start)]
19+
20+
while pq:
21+
current_distance, current_node = heapq.heappop(pq)
22+
23+
if current_distance > distances[current_node]:
24+
continue
25+
26+
for neighbor, weight in graph[current_node]:
27+
distance = current_distance + weight
28+
if distance < distances[neighbor]:
29+
distances[neighbor] = distance
30+
heapq.heappush(pq, (distance, neighbor))
31+
32+
return distances
33+
34+
# Example usage
35+
if __name__ == "__main__":
36+
# Example graph
37+
graph = {
38+
0: [(1, 4), (2, 1)],
39+
1: [(3, 1)],
40+
2: [(1, 2), (3, 5)],
41+
3: [(4, 3)],
42+
4: []
43+
}
44+
45+
start_node = 0
46+
shortest_paths = dijkstra(graph, start_node)
47+
print(f"Shortest paths from node {start_node}:")
48+
for node, distance in shortest_paths.items():
49+
print(f"To node {node}: {distance}")

graph/test_dijkstra.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import unittest
2+
from dijkstra import dijkstra
3+
4+
class TestDijkstra(unittest.TestCase):
5+
def test_simple_graph(self):
6+
graph = {
7+
0: [(1, 4), (2, 1)],
8+
1: [(3, 1)],
9+
2: [(1, 2), (3, 5)],
10+
3: [(4, 3)],
11+
4: []
12+
}
13+
start_node = 0
14+
expected = {0: 0, 1: 3, 2: 1, 3: 4, 4: 7}
15+
self.assertEqual(dijkstra(graph, start_node), expected)
16+
17+
def test_disconnected_graph(self):
18+
graph = {
19+
0: [(1, 1)],
20+
1: [(0, 1)],
21+
2: [(3, 1)],
22+
3: [(2, 1)]
23+
}
24+
start_node = 0
25+
expected = {0: 0, 1: 1, 2: float('infinity'), 3: float('infinity')}
26+
self.assertEqual(dijkstra(graph, start_node), expected)
27+
28+
def test_single_node_graph(self):
29+
graph = {0: []}
30+
start_node = 0
31+
expected = {0: 0}
32+
self.assertEqual(dijkstra(graph, start_node), expected)
33+
34+
def test_complex_graph(self):
35+
graph = {
36+
0: [(1, 4), (2, 2)],
37+
1: [(2, 1), (3, 5)],
38+
2: [(3, 8), (4, 10)],
39+
3: [(4, 2), (5, 6)],
40+
4: [(5, 3)],
41+
5: []
42+
}
43+
start_node = 0
44+
expected = {0: 0, 1: 4, 2: 2, 3: 9, 4: 11, 5: 14}
45+
self.assertEqual(dijkstra(graph, start_node), expected)
46+
47+
def test_start_node_not_in_graph(self):
48+
graph = {0: [(1, 1)], 1: [(0, 1)]}
49+
start_node = 2
50+
with self.assertRaises(KeyError):
51+
dijkstra(graph, start_node)
52+
53+
if __name__ == '__main__':
54+
unittest.main()

0 commit comments

Comments
 (0)