본문 바로가기
백준(Python) 풀이

백준 1238번. 파티 (Python / 파이썬)

by yewonnie 2022. 4. 6.

문제

N개의 숫자로 구분된 각각의 마을에 한 명의 학생이 살고 있다.
어느 날 이 N명의 학생이 X (1 ≤ X ≤ N)번 마을에 모여서 파티를 벌이기로 했다. 이 마을 사이에는 총 M개의 단방향 도로들이 있고 i번째 길을 지나는데 Ti(1 ≤ Ti ≤ 100)의 시간을 소비한다.
각각의 학생들은 파티에 참석하기 위해 걸어가서 다시 그들의 마을로 돌아와야 한다. 하지만 이 학생들은 워낙 게을러서 최단 시간에 오고 가기를 원한다.
이 도로들은 단방향이기 때문에 아마 그들이 오고 가는 길이 다를지도 모른다. N명의 학생들 중 오고 가는데 가장 많은 시간을 소비하는 학생은 누구일지 구하여라.

입력

첫째 줄에 N(1 ≤ N ≤ 1,000), M(1 ≤ M ≤ 10,000), X가 공백으로 구분되어 입력된다. 두 번째 줄부터 M+1번째 줄까지 i번째 도로의 시작점, 끝점, 그리고 이 도로를 지나는데 필요한 소요시간 Ti가 들어온다. 시작점과 끝점이 같은 도로는 없으며, 시작점과 한 도시 A에서 다른 도시 B로 가는 도로의 개수는 최대 1개이다.
모든 학생들은 집에서 X에 갈수 있고, X에서 집으로 돌아올 수 있는 데이터만 입력으로 주어진다.

출력

첫 번째 줄에 N명의 학생들 중 오고 가는데 가장 오래 걸리는 학생의 소요시간을 출력한다.

문제 풀이

파티 문제는 파티가 열리는 X번 마을까지 갔다가, 다시 되돌아오는 최소비용 경로에서

가장 많은 시간을 소비하는 학생을 구하는 문제입니다.

특정 마을에서 X까지, X마을에서 다시 특정 마을까지 되돌아가는 최소비용을 각각 구해

그 중 최대 비용을 더해주면 됩니다.

따라서 이 문제는 최소 비용을 구하는 Dijkstra 알고리즘을 이용하면 됩니다.


My Code

import heapq
import sys
input = sys.stdin.readline
INF = int(1e9)

n, m, x = map(int,input().split()) # 마을 수, 도로 수, 파티가 열리는 마을

graph = [[] for _ in range(n + 1)]
for _ in range(m):
    a, b, cost = map(int,input().split()) # 시작점, 끝점, 소요시간
    graph[a].append((b, cost))  # a에서 b까지 가는 비용이 cost라는 뜻

# dijkstra Algorithm
def dijkstra(start):
    distance = [INF] * (n + 1) # 최소 비용을 저장할 list
    q = []
    heapq.heappush(q, (0, start)) # 비용과 노드를 heap에 저장
    distance[start] = 0 # 출발 마을의 비용은 0
    while q: # 큐가 빌 때까지 반복
        dist, now = heapq.heappop(q) 
        if distance[now] < dist: # 이미 최소비용이라면 continue
            continue
        for i in graph[now]:
            cost = dist + i[1]   
            if cost < distance[i[0]]: # 계산한 비용이 더 작다면 갱신
                distance[i[0]] = cost
                heapq.heappush(q, (cost, i[0])) # heap에 삽입
    return distance

max_value = -1e9
for i in range(1, n + 1):
    dist1 = dijkstra(i)  # 각 마을에서 출발 했을 때 distance를 저장
    dist2 = dijkstra(x)  # x에서 출발 했을 때 distance를 저장

    # 각 마을에서 출발 했을 때 x까지의 최소비용 + x에서 출발 했을 때 각 마을 까지의 최소 비용
    # 더한 값중 max 값을 저장
    max_value = max(max_value, dist1[x] + dist2[i])

print(max_value)

댓글