우선 출처를 밝힌다.
https://programmer.group/cdq-divide-and-conquer-learning-notes.html
중국인분께서 작성하신 글이다. 예제가 친절해서 이해하기 좋았다. 아마 다른 용어로 다들 알고있겠지만.. CDQ 알고리즘이라는 이름으로 번역된 버전은 없는 것 같아서 국문으로 포스팅한다.
CDQ 알고리즘은 Chen Danqi라는 사람이 고안한 것이라고 한다. 그 분의 이름을 본따 지은 알고리즘이다.
기본적으로 분할 정복을 이용하여 오프라인 쿼리를 처리하는데 사용하는 것 같다.
1.
분할 정복에서 가장 대표적인 예시라고 하면 병합정렬(merge sort)이다.
좌구간, 우구간으로 분리하여 각각을 정렬시키고 합치는 과정이다.
좌우 구간을 합치는 과정에서 필요한 시간복잡도가 $O(n)$이라 총괄 복잡도는 $O(n \log n)$이다.
좌우 구간을 합치는 과정에서 추가적인 처리를 할 수 있다. 대표적으로 역전 개수 세기(counting inversion)가 있다.
배열 $A[1:n]$에서 역전(inversion)이란, $i < j$이면서 $A[i] > A[j]$를 만족하는 $(i, j)$ 튜플을 의미한다.
그 개수를 2중 for 문을 통하여 완전 탐색하면 $O(n^2)$의 시간복잡도가 걸린다.
하지만 좌우 구간을 분할하여 각각을 정렬시키고나서 합치면 다음과 같은 최적화가 가능하다.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAX_N = 1e5 + 20;
int N, A[MAX_N];
int T[MAX_N]; // 임시 저장공간
ll CountInversion;
void merge_sort(int l, int r) {
if (l == r) return;
int m = (l + r) / 2;
merge_sort(l, m);
merge_sort(m+1, r);
int p1 = l, p2 = m+1, p = l;
while (p1 <= m && p2 <= r) {
if (A[p1] <= A[p2]) {
T[p++] = A[p1++];
} else {
CountInversion += m + 1 - p1;
T[p++] = A[p2++];
}
}
while (p1 <= m) T[p++] = A[p1++];
while (p2 <= r) T[p++] = A[p2++];
for (int i=l; i<=r; i++) A[i] = T[i];
}
int main() {
scanf("%d", &N);
for (int i=1; i<=N; i++) scanf("%d", &A[i]);
merge_sort(1, N);
printf("%lld\n", CountInversion);
}
$A$ 배열이 정렬되는 것은 그냥 side effect라고 생각하면 될 것 같다.
너무 well-known이라 간단히 코드만 첨부하고 넘어가겠다.
2.
사실 역전 개수 세기 문제는 굳이 병합정렬을 사용하지 않아도 괜찮다.
펜윅트리를 이용하여 $i = 1, 2, \dots, n$ 순서를 차례대로 읽어가면서, $A[1:i-1]$ 원소 중에 $A[i]$보다 큰 것들의 개수를 세면 되기 때문이다.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define ALL(x) (x).begin(), (x).end()
#define UNIQUE(x) (x).erase(unique(ALL(x)), (x).end())
#define NTH(v, x) (lower_bound(ALL(v), (x)) - (v).begin())
const int MAX_N = 1e5 + 20;
namespace FW {
int F[MAX_N];
void add(int p, int v) {
for (; p<MAX_N; p+=p&-p) F[p] += v;
}
int sum(int p) {
int s = 0;
for (; p; p&=p-1) s += F[p];
return s;
}
int sum(int l, int r) {
return sum(r) - sum(l-1);
}
};
int N, A[MAX_N];
ll CountInversion;
int main() {
scanf("%d", &N);
vector<int> cA; // 좌표압축을 위해
for (int i=1; i<=N; i++) {
scanf("%d", &A[i]);
cA.push_back(A[i]);
}
sort(ALL(cA));
UNIQUE(cA);
for (int i=1; i<=N; i++) {
A[i] = 1 + NTH(cA, A[i]);
CountInversion += FW::sum(A[i]+1, MAX_N-1);
FW::add(A[i], 1);
}
printf("%lld\n", CountInversion);
}
3.
다음과 같은 문제를 생각해보자.
2차원 좌표 평면에 $n$개의 점 $(A[i], B[i])$가 있다. (편의상 점들의 좌표가 모두 다르다고 하자.)
$i \neq j$이면서 $A[i] \leq A[j]$, $B[i] \leq B[j]$를 만족하는 튜플 $(i, j)$의 개수를 구하여라.
$A$, $B$ 배열을 한 데 모은 다음 사전순으로 정렬한다.
왼쪽에서 오른쪽으로 원소를 차례대로 읽으면 $A$ 값이 비감소함은 보장되어있다.
나머지는 펜윅트리를 이용하면 어렵지 않다.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
#define ALL(x) (x).begin(), (x).end()
#define UNIQUE(x) (x).erase(unique(ALL(x)), (x).end())
#define NTH(v, x) (lower_bound(ALL(v), (x)) - (v).begin())
const int MAX_N = 1e5 + 20;
namespace FW {
int F[MAX_N];
void add(int p, int v) {
for (; p<MAX_N; p+=p&-p) F[p] += v;
}
int sum(int p) {
int s = 0;
for (; p; p&=p-1) s += F[p];
return s;
}
int sum(int l, int r) {
return sum(r) - sum(l-1);
}
};
int N, A[MAX_N], B[MAX_N];
pii AB[MAX_N];
int main() {
scanf("%d", &N);
vector<int> cA, cB; // 좌표압축을 위해
for (int i=1; i<=N; i++) {
scanf("%d %d", &A[i], &B[i]);
cA.push_back(A[i]);
cB.push_back(B[i]);
}
sort(ALL(cA));
sort(ALL(cB));
UNIQUE(cA);
UNIQUE(cB);
ll cnt = 0;
for (int i=1; i<=N; i++) {
A[i] = 1 + NTH(cA, A[i]);
B[i] = 1 + NTH(cB, B[i]);
AB[i] = {A[i], B[i]};
}
sort(AB+1, AB+1+N);
for (int i=1; i<=N; i++) {
int a, b; tie(a, b) = AB[i];
cnt += FW::sum(b);
FW::add(b, 1);
}
printf("%lld\n", cnt);
}
4.
한 차원 더 높여보자.
3차원 좌표 평면에 $n$개의 점 $(A[i], B[i], C[i])$가 있다. (편의상 점들의 좌표가 모두 다르다고 하자.)
$i \neq j$이면서 $A[i] \leq A[j]$, $B[i] \leq B[j]$, $C[i] \leq C[j]$를 만족하는 튜플 $(i, j)$의 개수를 구하여라.
이를 사전순으로 정렬하면 $A$ 배열의 원소 값은 우선적으로 정렬이 된다.
분할 정복을 사용해보자. 병합정렬이 stable sort라는 점을 상기하자.
이미 $A$를 기준으로 정렬이 완료된 상태에서 병합정렬을 통해 $B$를 기준으로 정렬을 시도할 것이다.
$A$에 대한 순서는 헝클어지지만, 문제를 전략적으로 풀 수 있다.
1.에 나온 역전 개수 세기와 같은 전략을 사용하여 최적화를 시켜보자.
아래 코드를 보면 잘 이해될 것이다.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef tuple<int,int,int> piii;
#define ALL(x) (x).begin(), (x).end()
#define UNIQUE(x) (x).erase(unique(ALL(x)), (x).end())
#define NTH(v, x) (lower_bound(ALL(v), (x)) - (v).begin())
const int MAX_N = 1e5 + 20;
namespace FW {
int F[MAX_N];
void add(int p, int v) {
for (; p<MAX_N; p+=p&-p) F[p] += v;
}
int sum(int p) {
int s = 0;
for (; p; p&=p-1) s += F[p];
return s;
}
int sum(int l, int r) {
return sum(r) - sum(l-1);
}
};
int N, A[MAX_N], B[MAX_N], C[MAX_N];
piii ABC[MAX_N], T[MAX_N];
ll Count;
void CDQ(int l, int r) {
// ABC[l:r] 부분배열을 B 값을 기준으로 병합정렬한다.
if (l == r) return;
int m = (l + r) / 2;
CDQ(l, m);
CDQ(m+1, r);
int p1 = l, p2 = m+1, p = l;
while (p1 <= m && p2 <= r) {
if (get<1>(ABC[p1]) <= get<1>(ABC[p2])) { // B 값을 비교
FW::add(get<2>(ABC[p1]), 1);
T[p++] = ABC[p1++];
} else {
Count += FW::sum(get<2>(ABC[p2])); // !!!!
T[p++] = ABC[p2++];
}
}
int p1_backup = p1;
while (p1 <= m) T[p++] = ABC[p1++];
while (p2 <= r) {
Count += FW::sum(get<2>(ABC[p2]));
T[p++] = ABC[p2++];
}
for (int i=l; i<p1_backup; i++) {
FW::add(get<2>(ABC[i]), -1); // 펜윅트리 초기화
}
for (int i=l; i<=r; i++) {
ABC[i] = T[i];
}
}
int main() {
scanf("%d", &N);
vector<int> cA, cB, cC; // 좌표압축을 위해
for (int i=1; i<=N; i++) {
scanf("%d %d %d", &A[i], &B[i], &C[i]);
cA.push_back(A[i]);
cB.push_back(B[i]);
cC.push_back(C[i]);
}
sort(ALL(cA));
sort(ALL(cB));
sort(ALL(cC));
UNIQUE(cA);
UNIQUE(cB);
UNIQUE(cC);
ll cnt = 0;
for (int i=1; i<=N; i++) {
A[i] = 1 + NTH(cA, A[i]);
B[i] = 1 + NTH(cB, B[i]);
C[i] = 1 + NTH(cC, C[i]);
ABC[i] = {A[i], B[i], C[i]};
}
sort(ABC+1, ABC+1+N);
CDQ(1, N);
printf("%lld\n", Count);
}
주석으로 !!!! 표시된 줄을 보면 원리를 알 수 있다.
우선 초기에 사전순으로 정렬했고, 분할 정복의 결과이므로 병합 이전 $ABC[l:m]$ 원소의 $A$ 값은 $ABC[m+1:r]$ 원소의 $A$ 값보다 항상 작거나 같다.
CDQ 함수자체가 $B$ 값을 기준으로 병합정렬을 취하는 것이다. 그러므로 $ABC[l:m]$과 $ABC[m+1:r]$의 $B$ 값의 원소를 비교하면서 작은 것을 우선적으로 임시배열 $T$에 저장한다. 즉 먼저 $T$ 배열에 저장된 원소의 $B$ 값은 나중에 삽입되는 원소의 것보다 항상 작거나 같다.
마지막으로 $C$ 배열값을 펜윅트리로 처리하였다. 그러므로 $T$ 배열에 먼저 삽입된 원소 중 $C$ 값이 더 작거나 같은 원소들의 개수도 바로 구할 수 있다.
펜윅트리 연산에 소요되는 시간 복잡도가 $O(\log n)$ 이므로, 전체는 $O(n \log^2 n)$ 이다.
이게 뭐지 싶은데 응용을 보면 아주 신기하다. 생각보다 기초 개념과의 거리가 멀어서.. 꼭 연습문제를 같이 풀어봐야하는 것 같다.
사실 이를 이용한 아주 재밌는 문제를 오늘 처음 풀었다. 글을 나누어 포스팅해야겠다.
https://www.acmicpc.net/problem/16336
'Problem Solving' 카테고리의 다른 글
Dijkstra 알고리즘에 관하여 (0) | 2020.03.16 |
---|---|
백준 온라인 저지 - 16336. Points and Rectangles (0) | 2020.03.16 |
백준 온라인 저지 - 8232. Leveling Ground (0) | 2020.03.05 |
Codeforces Round #625 - D. Reachable Strings (0) | 2020.03.04 |
백준 온라인 저지 - 18297. Pixels (2) | 2020.03.01 |