출처 : 종만북 - 07 분할정복 예제 : 병합 정렬과 퀵 정렬

코드설명

주어진 수열을 크기순서대로 정렬하는 문제 중, 분할정복 패러다임을 사용하는 병합 정렬(Merge Sort)와 퀵 정렬(Quick Sort)에 대해 알아봅니다.

 

병합 정렬 알고리즘은 주어진 수열을 가운데에서 쪼개 비슷한 크기의 수열 두개로 만든 뒤 이들을 재귀 호출을 이용해 각각 정렬합니다. 그후 정렬된 배열을 하나로 합침으로써 정렬된 전체 수열을 얻습니다.

즉, 각 수열의 크기가 1이 될떄까지 절반씩 쪼개 나간 뒤, 정렬된 부분 배열들을 합쳐 나가는 것을 볼 수 있습니다. 병합 정렬의 분할 방식은 아주 단순하고 효율적입니다. 주어진 배열을 가운데서 절반으로 그냥 나누는 것 입니다. 따라서 이 과정을 상수 시간인 O(1)만에 수행할 수 있습니다. 그러나 각각 나눠서 정렬한 배열들을 하나의 배열로 합치기 위해 별도의 병합 과정을 실행해야 합니다. 여기에 O(n)의 시간이 걸립니다. 

 

반대로, 퀵 정렬 알고리즘은 배열을 단순하게 가운데에서 쪼개는 대신, 병합과정이 필요없도록 한쪽의 배열에 포함된 수가 다른쪽 배열의 수보다 항상 작도록 배열을 분할합니다.

이를 위해 퀵 정렬은 파티션(parition)이라고 부르는 단계를 도입하는데, 이는 배열에 있는 수 중 임의의 '기준 수(pivot)'를 지정한 후 기준보다 작거나 같은 숫자를 왼쪽, 더 큰 숫자를 오른쪽으로 보내는 과정입니다.

퀵 정렬은 각 부분 수열의 맨 처음에 있는 수를 기준으로 삼고, 이들보다 작은 수를 왼쪽으로, 큰 것을 오른쪽으로 가게끔 문제를 분해합니다. 이렇게 이전 수를 분할하는데 사용된 기준을 'pivot'이라하며 이 과정이 파티션(parition) 단계입니다. 이 분할은 O(n)의 시간이 걸리는 복잡한 작업인데다, 우리가 어떤 기준을 선택하느냐에 따라서 (27, 9, 3, 10)이 (9, 3, 10)과 (27)로 나누어지는 것처럼 비효율적인 분할이 발생할 수 있지만(퀵 정렬의 동작 효율을 위해서 항상 분할이 정확히 절반으로 이루어져야 합니다. 재귀함수의 호출 횟수가 최소한(lg N)으로 발생하기 위함입니다), 그 과정에서 이미 부분배열이 이미 정렬한 상태가 되어 별도의 병합작업이 필요없습니다.

 

이 두 알고리즘은 같은 아이디어로 정렬을 수행하지만 시간이 많이 걸리는 작업을 분할단계(퀵정렬)에서 하느냐, 병합단계에서 하느냐(병합정렬)가 다릅니다. 

 

 

병합정렬을 구현하다보면, 가끔 인터넷에 array[]에 데이터를 중첩으로 정렬하는 것처럼 보이는 경우가 있습니다.

사실, 합병하는 마지막 레벨까지 해당 array[]는 필요없습니다.

모든 병합이 이루어지고, 합병까지 마지막 한번 빼고 모두 이루어지고, 마지막 한번의 합병 때 array[]에 데이터 값이 들어가는 경우가 정렬된 경우입니다.

즉, left와 right의 길이 합이 array.length와 같아야합니다.

이러한 이유는, left와 right는 새로 합병될때 피차일반 다시 정렬되어야 하기 때문입니다.

해당 사항을 알게된다면, 코드이해가 훨씬 쉬워집니다.

 

입력예시 1

8
8 2 6 4 7 3 9 5

결과예시 1

2 3 4 5 6 7 8 9

 

입력예시 2

7
38 27 43 9 3 82 10

결과예시 2

3 9 10 27 38 43 82

코드

병합정렬을 구현한 경우입니다. (void 타입으로 구현하다보면, array[]에 데이터를 중첩으로 정렬하는것으로 보입니다. void타입의 경우에선 사실, 합병하는 마지막 레벨까지 해당 array[]는 필요없습니다. left는 참조로 들어가서 곧바로 변환되기에 상관 없습니다. 아래의 코드에서 모든 병합이 이루어지고, 합병까지 마지막 한번 빼고 모두 이루어지고, 마지막 한번의 합병 때 array[]에 데이터 값이 들어가는 경우가 정답이 됩니다. 즉, 그 마지막 합병 외에는 단순히 mergeSort는 의미없이 실행됩니다.)

하지만, 반환타입을 void가 아닌 int[]로 하게된다면, 정렬된 배열값이 left가 되고 right값이 반환한 정렬값을 반환하며 훨씬 직관적으로 작성됩니다. 아래에 해당 사항의 경우도 작성했습니다.

package algorhythm;
 
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.StringTokenizer;
 
public class Main {
	public static int C, N, M;
	public static int INF = 0;
	public static int[] input;
	public static int answer = 0;
	public static void main(String[] args) throws IOException{
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st = new StringTokenizer(br.readLine());
		
		N = Integer.parseInt(st.nextToken());
		st = new StringTokenizer(br.readLine());
		input = new int[N];
		for(int i=0;i<N;i++) {
			input[i] = Integer.parseInt(st.nextToken());
		}
		mergeSort(input);
		for(int a : input) {
			System.out.print(" "+a);
		}
		
	}
	
	public static void mergeSort(int[] array) {
		if(array.length == 1) {
			return ;
		}
		
		int mid = array.length / 2;
		int[] left = new int[mid];
		int[] right = new int[array.length - mid];
		
		for(int i=0;i<mid;i++) {
			left[i] = array[i];
		}
		for(int i=mid; i<array.length;i++) {
			right[i - mid] = array[i];
		}
		
		mergeSort(left);
		mergeSort(right);
		
		merge(array, left, right);
	}
	
	public static void merge(int[] array, int[] left, int[] right) {
		int leftIdx = 0, rightIdx = 0, arrayIdx = 0;
		while( leftIdx < left.length && rightIdx < right.length) {
			if(left[leftIdx] <= right[rightIdx]) {
				array[arrayIdx++] = left[leftIdx++];
			} else if(left[leftIdx] > right[rightIdx]) {
				array[arrayIdx++] = right[rightIdx++];
			}
		}
		
		while(leftIdx < left.length) {
			array[arrayIdx++] = left[leftIdx++];
		}
		while(rightIdx < right.length) {
			array[arrayIdx++] = right[rightIdx++];
		}
		
	}
}

 

합병정렬의 2번쨰 코드입니다. 반환형식을 명확하게 사용함으로써 모든 변수들이 의미있게 사용됩니다. 이 코드와 3번쨰 코드와의 차이점은 이 코드는 재귀 Tree의 레벨이 한칸 더 깊습니다. 즉, logN + 1 의 레벨을 가지고 있고, (이유는 array.length에서 만들어서 들어가기 떄문) 3번째 코드는 1일떄 새로 반환되기에 차이점이 존재합니다.

package algorhythm;
 
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.StringTokenizer;
 
public class Main {
	public static int C, N, M;
	public static int INF = 0;
	public static int[] input;
	public static int answer = 0;
	public static void main(String[] args) throws IOException{
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st = new StringTokenizer(br.readLine());
		
		N = Integer.parseInt(st.nextToken());
		st = new StringTokenizer(br.readLine());
		input = new int[N];
		for(int i=0;i<N;i++) {
			input[i] = Integer.parseInt(st.nextToken());
		}
		input = mergeSort(input);
		for(int a : input) {
			System.out.print(" "+a);
		}
		
	}
	
	public static int[] mergeSort(int[] array) {
		if(array.length == 1) {
			return array;
		}
		
		int mid = array.length / 2;
		int[] left = new int[mid];
		int[] right = new int[array.length - mid];
		
		for(int i=0;i<mid;i++) {
			left[i] = array[i];
		}
		for(int i=mid; i<array.length;i++) {
			right[i - mid] = array[i];
		}
		
		left = mergeSort(left);
		right = mergeSort(right);
		
		return merge(left, right);
	}
	
	public static int[] merge(int[] left, int[] right) {
		int[] array = new int[left.length + right.length];
		int leftIdx = 0, rightIdx = 0, arrayIdx = 0;
		while( leftIdx < left.length && rightIdx < right.length) {
			if(left[leftIdx] <= right[rightIdx]) {
				array[arrayIdx++] = left[leftIdx++];
			} else if(left[leftIdx] > right[rightIdx]) {
				array[arrayIdx++] = right[rightIdx++];
			}
		}
		
		while(leftIdx < left.length) {
			array[arrayIdx++] = left[leftIdx++];
		}
		while(rightIdx < right.length) {
			array[arrayIdx++] = right[rightIdx++];
		}
		
		return array;
	}
}

 

3번쨰 병합정렬 코드입니다. 2번쨰 병합정렬 코드와 거의 비슷하지만, lo와 hi를 통해 연산이 더욱 편해졌습니다.

구현하면서 실수했었던 점은,

mergeSort의 기저사례에서 A[hi] 새로운 배열 1의 크기를 반환하지 않았었습니다. 잘못해서 그대로 A를 반환했습니다.

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.StringTokenizer;
 
public class Main {
	public static int C, N, K; 
	public static int answer = 0;
	public static int[] arr;
	public static void main(String[] args) throws IOException{
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st = new StringTokenizer(br.readLine());
		
		N = Integer.parseInt(st.nextToken());
		arr = new int[N];
		st = new StringTokenizer(br.readLine());
		for(int i=0; i<N;i++) {
			arr[i] = Integer.parseInt(st.nextToken());
		}
		int[] temp = mergeSort(arr, 0, N - 1);
		for(int v : temp) {
			System.out.println(v+" ");
		}
		
	}
	
	public static int[] mergeSort(int[] A, int lo, int hi) {
		//기저사례 1 : 길이가 1개가 되었다면, 그 배열을 반환한다.
		if(lo == hi) {
			int[] K = new int[] { A[hi] };
			return K;
		}
		
		int mid = (lo + hi) / 2;
		int[] left = mergeSort(A, lo, mid);
		int[] right = mergeSort(A, mid + 1, hi);
		
		//분할들이 모두 끝난 후에 합쳐지는 과정이다.
		int[] merged = merge(left, right);
		return merged;
	}
	
	public static int[] merge(int[] A, int[] B) {
		
		int left = 0, right = 0;
		int[] mergedArr = new int[A.length + B.length];
		int mergedArrPos = 0;
		while(left < A.length && right < B.length ) {
			if(A[left] < B[right]) {
				mergedArr[mergedArrPos++] = A[left++];
			}
			else {
				mergedArr[mergedArrPos++] = B[right++];
			}
		}
		
		while(left < A.length) {
			mergedArr[mergedArrPos++] = A[left++];
		}
		
		while(right < B.length) {
			mergedArr[mergedArrPos++] = B[right++];
		}
		
		return mergedArr;
	}
	
	
}

 

퀵정렬코드입니다.

package Main;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;
 
public class Main {
	public static int C, N, M;
	public static int INF = 0;
	public static int[] input;
	public static int answer = 0;
	public static void main(String[] args) throws IOException{
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st = new StringTokenizer(br.readLine());
		
		N = Integer.parseInt(st.nextToken());
		st = new StringTokenizer(br.readLine());
		input = new int[N];
		for(int i=0;i<N;i++) {
			input[i] = Integer.parseInt(st.nextToken());
		}
		quickSort(input, 0, input.length - 1);
		for(int a : input) {
			System.out.print(" "+a);
		}
		
	}
	
	public static void quickSort(int[] A, int start, int end) {
		if(start >= end) return ;
		
		int partitionIdx = partition(A, start, end);
		quickSort(A, start, partitionIdx - 1);
		quickSort(A, partitionIdx + 1, end);
		
	}
	
	public static int partition(int[] A, int start, int end) {
		int pivot = A[start];
		int left = start + 1;
		int right = end;
		
		while(left <= right) {
			
			//왼쪽에서 오른쪽 작업, pivot값이 더 클동안만 진행
			while(left <= end && pivot >= A[left]) {
				left++;
			}
			
			//오른쪽->왼쪽 작업, pivot값이 더 작을동안 진행
			while(right >= start && pivot < A[right]) {
				right--;
			}
			
			if(left < right) {
				swap(A, left, right);
			}
		}
		swap(A, start, right);
		return right;
	}
	
	public static void swap(int[] A, int left, int right) {
		int temp = A[left];
		A[left] = A[right];
		A[right] = temp;
	}
	
}

 

퀵정렬 코드 2. 항상 이런 코드 작성할시 lo, hi로 사용할지 아니면 left 와 right로 사용할지 고민됩니다. start와 end도 포함.

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.StringTokenizer;
 
public class Main {
	public static int C, N, K; 
	public static int answer = 0;
	public static int[] arr;
	public static void main(String[] args) throws IOException{
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st = new StringTokenizer(br.readLine());
		
		N = Integer.parseInt(st.nextToken());
		arr = new int[N];
		for(int i=0; i<N;i++) {
			st = new StringTokenizer(br.readLine());
			arr[i] = Integer.parseInt(st.nextToken());
		}
		
		quickSort(arr, 0, N-1);
		for(int v : arr) {
			System.out.println(v+" ");
		}
		
	}
	
	public static void quickSort(int[] A, int lo, int hi) {
		if(lo >= hi) {
			return ;
		}
		int partitionIdx = partition(A, lo, hi);
		quickSort(A, lo, partitionIdx - 1);
		quickSort(A, partitionIdx + 1, hi);
		
		return ;
	}
	
	public static int partition(int[] A, int lo, int hi) {
		int pivot = A[lo];
		int left = lo + 1;
		int right = hi;
		
		while(left <= right) {
			
			//pivot의 왼쪽에는 pivot보다 더 작은값만 냅두어야 하므로, 더 작을동안만 움직입니다.
			while(left <= hi && pivot >= A[left]) {
				left++;
			}
			
			//pivot의 오른쪽에는 pivot보다 더 큰값만 냅두어야 하므로, 더 클동안만 움직입니다.
			while(right > lo && pivot < A[right]) { //이 부분에서 >= lo여도 상관은없지만, 엄연히는 > lo로도 작동한다. (이유는 오름차순 정렬이기에 그렇다?? 좀 더 명확한 이유를 아직 표현못하겠습니다) 직접 예제 테스트 시 명확히 알 수 있을 것 입니다.
				right--;
			}
			
			if(left < right) {
				swap(A, left, right);
			}
			
		}
		swap(A, lo, right);
		return right;
	}
	
	public static void swap(int[] A, int left, int right) {
		int temp = A[left];
		A[left] = A[right];
		A[right] = temp; 
	}
	
}

+ Recent posts