MITM 알고리즘
개요
Meet int the middle (중간에서 만나기), 줄여서 MITM 알고리즘이란 부르트포스 알고리즘을 사용하여 문제를 해결할 때 시간복잡도를 줄이기 위해서 사용되는 테크닉으로, 쉽게 설명하면 문제를 반으로 나누어 각각 해결한 후 중간에서 만나 최종적인 문제를 해결하는 알고리즘이다.
BOJ 2295 (세 수의 합)
https://www.acmicpc.net/problem/2295
보통 쉽게 생각하는 방식은 x, y, z번째 수
를 합한 집합을 만든 후 입력된 수들 중 최댓값을 찾는 방식을 생각할 것이다. 이는 아무리 짧게 잡아도 최소 $O(n^3)$의 시간 복잡도가 나오게 될 것이다.
여기서 문제를 약간만 변형해보자
“x번째 수 + y번째 수 = k번째 수 - z번째 수
를 만족하는 k번째 수
중 최댓값을 구하시오”
x번째 수 + y번째 수
인 집합을 생성하고, k번째 수 - z번째 수
인 집합을 생성하는 방식을 생각하게 되지 않는가?
두 집합을 생성하는데에 $O(n^2)$, 이분탐색 또는 해시를 사용하여 탐색을 수행하면 $O(n\cdot logn)$, 또는 $O(n^2)$의 시간 복잡도가 걸리므로 총합 $O(n^2)$의 시간복잡도를 가지게 된다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include <iostream>
#include <algorithm>
#include <vector>
#include <unordered_set>
using namespace std;
int main() {
int n, result;
cin >> n;
vector <int> nums(n);
for (auto &el : nums) {
cin >> el;
}
sort(nums.begin(), nums.end());
unordered_set <int> setA;
for (int x = 0; x < n; x++) {
for (int y = x; y < n; y++) {
setA.insert(nums[x] + nums[y]);
}
}
result = 0;
for (int z = 0; z < n; z++) {
for (int k = z; k < n; k++) {
if (setA.find(nums[k] - nums[z]) != setA.end()) {
result = max(nums[k], result);
}
}
}
cout << result;
return 0;
}
BOJ 1208 (부분수열의 합 2)
https://www.acmicpc.net/problem/1208
해당 문제는 정석적으로 생각한다면 입력받은 숫자들을 경우의 수에 따라 구해진 부분수열의 합들을 s
와 비교해 볼 것이다.
그러나 해당 방식은 부분수열의 합이 저장된 배열을 구하는 시간복잡도가 $O(2^n)$일 뿐더러, 배열의 최대 크기 또한 $O(2^n)$으로, 만약 n
이 40
으로 입력된다면 대략 1조번에 이르는 연산이 수행되어야 한다.
그렇다면 입력받은 배열을 반으로 갈라 나누어진 두 개의 배열을 각각 집합 A
와 집합 B
로 나누어 부분수열의 합을 구한 다음,
“집합 A의 원소 + 집합 B의 원소 = s
를 만족하는 경우의 수를 구하시오” 라는 문제로 변형하여 풀어보자.
각 집합을 구하는 데에는 $O(2^{\frac{n}{2}})$이고, 위의 식을 만족하는 경우의 수를 구하는 데에 정렬과 투포인터 또는 이분탐색 등을 사용하면 최종적으로는 $O(n\cdot 2^{\frac{n}{2}})$ 또는 $O(2^{\frac{n}{2}})$의 시간 복잡도로 해결할 수 있다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;
vector <int> nums;
void dfs (int start, int end, int sum, vector <int> &vec) {
if (start == end) {
vec.push_back(sum);
return;
}
dfs(start + 1, end, sum, vec);
dfs(start + 1, end, sum + nums[start], vec);
}
int main() {
int n, target;
cin >> n >> target;
nums.resize(n);
for (int &el : nums) {
cin >> el;
}
int mid = n / 2;
vector <int> setA;
vector <int> setB;
dfs(0, mid, 0, setA);
dfs(mid, n, 0, setB);
sort(setA.begin(), setA.end());
sort(setB.begin(), setB.end());
long long result = 0;
for (int &x : setA) {
result += upper_bound(setB.begin(), setB.end(), target - x)
- lower_bound(setB.begin(), setB.end(), target - x);
}
if (target == 0) result--;
cout << result;
return 0;
}