import networkx as nx

# Load the grid from ../res/10/input: each non-empty line becomes one row of
# single-digit integers. Blank lines (such as an extra trailing newline) are
# skipped so they do not turn into empty rows, which would break the
# neighbour lookups performed on the grid later in the script.
with open("../res/10/input") as f:
    lines = [[int(c) for c in row] for row in map(str.strip, f) if row]

# Build a directed graph with one node per grid cell and an edge from each
# cell to every orthogonal neighbour whose value is exactly one greater.
G = nx.DiGraph()

starts = []  # coordinates of cells holding 0
ends = []    # coordinates of cells holding 9

# Register every cell up front so cells without any edge are still nodes.
for i, line in enumerate(lines):
    for j, c in enumerate(line):
        G.add_node((i, j), pos=(i, j))

# Add the edges and collect the start/end cells in the same sweep.
for i, line in enumerate(lines):
    for j, c in enumerate(line):
        # Neighbour offsets in the order up, left, down, right.
        for di, dj in ((-1, 0), (0, -1), (1, 0), (0, 1)):
            ni, nj = i + di, j + dj
            # The column bound is checked against the neighbour's own row so
            # rows of unequal length cannot cause an IndexError.
            if (
                0 <= ni < len(lines)
                and 0 <= nj < len(lines[ni])
                and lines[ni][nj] == c + 1
            ):
                G.add_edge((i, j), (ni, nj), weight=c)
        if c == 0:
            starts.append((i, j))
        if c == 9:
            ends.append((i, j))



#pos = nx.get_node_attributes(G, 'pos')
#nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500, font_size=16)
#plt.show()

# Part 1: count the (start, end) pairs for which the end is reachable from
# the start. A single traversal per start gathers every reachable node;
# intersecting that with the set of ends avoids a separate search per pair.
end_cells = set(ends)
result = 0
for start in starts:
    result += len(nx.descendants(G, start) & end_cells)

print("Result 1:",result)

# Part 2: count the distinct paths from any start to any end.
# Every edge leads to a cell whose value is exactly one greater, so the graph
# has no cycles. The number of paths reaching a node is therefore the sum of
# the counts of its predecessors, which one sweep in topological order
# computes without enumerating (or storing) the individual paths.
paths_into = dict.fromkeys(G, 0)
for start in starts:
    paths_into[start] = 1

for node in nx.topological_sort(G):
    reaching = paths_into[node]
    if reaching:
        for successor in G.successors(node):
            paths_into[successor] += reaching

result = sum(paths_into[end] for end in ends)

print("Result 2:",result)