[백준] 1517 – 버블 소트

문제

#1517: 버블 정렬(acmicpc.net)

#1517: 버블 정렬

첫 번째 줄은 N(1 ≤ N ≤ 500,000)을 지정합니다. 다음 행에는 A(N)이 주어졌을 때 N개의 정수 A(1), A(2), …가 포함됩니다. 모든 A(i) 0 ≤ |A(i)| ≤ 1,000,000,000.

www.acmicpc.net

설명

아래 코드는 제가 먼저 보낸 코드인데 시간복잡도가 O(N^{2})여서 만료되었습니다.

from sys import stdin
input = lambda : stdin.readline().strip()

N = int(input())
A = list(map(int, input().split()))

count = 0
for i in range(N) :
    for j in range(N-1) :
        if A(j) > A(j+1) :
            A(j), A(j+1) = A(j+1), A(j)
            count += 1

print(count)

해결책을 찾지 못해 다른 사람들이 풀었던 코드를 살펴보니 버블 정렬 대신 병합 정렬을 사용했습니다.

병합, 정렬

이 문제를 시간 초과하지 않으려면 버블 정렬 대신 병합 정렬을 사용하는 것이 좋습니다.

버블정렬 과정을 보면 왼쪽에서 오른쪽으로 정렬을 하고 생각해보고 머지정렬 과정에서 오른쪽 숫자가 왼쪽으로 몇 자리 이동하는지 계산해보면 답을 찾을 수 있을 것이다.

예제에서 (3, 2, 1) 정렬을 병합해 봅시다.

먼저 merge_sort 함수에 대한 인수로 첫 번째 인덱스(0)와 마지막 인덱스(N – 1)를 입력합니다.

이것을 (3), (2, 1)(가운데)로 나누고 나눌 수 있으면 다시 나눕니다. (재귀 merge_sort)

분할 배열에서 첫 번째 인덱스를 front_index로, 마지막 인덱스를 back_index로 정의한 다음 new_A라는 배열을 만들어 새 배열을 저장합니다.

이제 front_index가 mid보다 작거나 같고 back_index가 end보다 작거나 같을 때까지 반복합니다.

A의 front_index가 back_index보다 작으면 정렬할 필요가 없으므로 new_A에 바로 front_index를 추가하고 개수를 증가시키지 않습니다.

위의 경우가 아닌 경우 더 작은 값이 뒤에 있으므로 카운터를 1씩 증가시킵니다.

from sys import stdin
input = lambda : stdin.readline().strip()

def merge_sort(start, end):
    global count

    if start < end :
        mid = (start + end) // 2
        merge_sort(start, mid)
        merge_sort(mid + 1, end)

        front_index = start
        back_index = mid + 1
        new_A = ()

        while front_index <= mid and back_index <= end :
            if A(front_index) <= A(back_index) :
                new_A.append(A(front_index))
                front_index += 1
            else :
                new_A.append(A(back_index))
                back_index += 1
                count += mid - front_index + 1

        if front_index <= mid :
            new_A += A(front_index : mid + 1)
        if back_index <= end :
            new_A += A(back_index : end + 1)

        for i in range(len(new_A)) :
            A(start + i) = new_A(i)

count = 0
N = int(input())
A = list(map(int, input().split()))

merge_sort(0, N - 1)
print(count)


즉, Python 3과 PyPy3 사이의 시간 차이는 중요합니다.

제출용으로 PyPy3을 권장합니다.