Post

MITM 알고리즘

개요

Meet int the middle (중간에서 만나기), 줄여서 MITM 알고리즘이란 부르트포스 알고리즘을 사용하여 문제를 해결할 때 시간복잡도를 줄이기 위해서 사용되는 테크닉으로, 쉽게 설명하면 문제를 반으로 나누어 각각 해결한 후 중간에서 만나 최종적인 문제를 해결하는 알고리즘이다.

BOJ 2295 (세 수의 합)

https://www.acmicpc.net/problem/2295

보통 쉽게 생각하는 방식은 x, y, z번째 수를 합한 집합을 만든 후 입력된 수들 중 최댓값을 찾는 방식을 생각할 것이다. 이는 아무리 짧게 잡아도 최소 $O(n^3)$의 시간 복잡도가 나오게 될 것이다.

여기서 문제를 약간만 변형해보자
x번째 수 + y번째 수 = k번째 수 - z번째 수 를 만족하는 k번째 수 중 최댓값을 구하시오”

x번째 수 + y번째 수인 집합을 생성하고, k번째 수 - z번째 수인 집합을 생성하는 방식을 생각하게 되지 않는가?

두 집합을 생성하는데에 $O(n^2)$, 이분탐색 또는 해시를 사용하여 탐색을 수행하면 $O(n\cdot logn)$, 또는 $O(n^2)$의 시간 복잡도가 걸리므로 총합 $O(n^2)$의 시간복잡도를 가지게 된다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include <iostream>
#include <algorithm>
#include <vector>
#include <unordered_set>

using namespace std;

int main() {
    int n, result;
    cin >> n;
    
    vector <int> nums(n);
    for (auto &el : nums) {
        cin >> el;
    }

    sort(nums.begin(), nums.end());
    
    unordered_set <int> setA;
    for (int x = 0; x < n; x++) {
        for (int y = x; y < n; y++) {
            setA.insert(nums[x] + nums[y]);
        }    
    }
    
    result = 0;
    for (int z = 0; z < n; z++) {
        for (int k = z; k < n; k++) {
            if (setA.find(nums[k] - nums[z]) != setA.end()) {
                result = max(nums[k], result);
            }
        }    
    }
    
    cout << result;
    
    return 0;
}

BOJ 1208 (부분수열의 합 2)

https://www.acmicpc.net/problem/1208

해당 문제는 정석적으로 생각한다면 입력받은 숫자들을 경우의 수에 따라 구해진 부분수열의 합들을 s와 비교해 볼 것이다.

그러나 해당 방식은 부분수열의 합이 저장된 배열을 구하는 시간복잡도가 $O(2^n)$일 뿐더러, 배열의 최대 크기 또한 $O(2^n)$으로, 만약 n40으로 입력된다면 대략 1조번에 이르는 연산이 수행되어야 한다.

그렇다면 입력받은 배열을 반으로 갈라 나누어진 두 개의 배열을 각각 집합 A집합 B로 나누어 부분수열의 합을 구한 다음,
집합 A의 원소 + 집합 B의 원소 = s를 만족하는 경우의 수를 구하시오” 라는 문제로 변형하여 풀어보자.

각 집합을 구하는 데에는 $O(2^{\frac{n}{2}})$이고, 위의 식을 만족하는 경우의 수를 구하는 데에 정렬과 투포인터 또는 이분탐색 등을 사용하면 최종적으로는 $O(n\cdot 2^{\frac{n}{2}})$ 또는 $O(2^{\frac{n}{2}})$의 시간 복잡도로 해결할 수 있다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <iostream>
#include <algorithm>
#include <vector>

using namespace std;

vector <int> nums;

void dfs (int start, int end, int sum, vector <int> &vec) {
    if (start == end) {
        vec.push_back(sum);
        return;
    }
    
    dfs(start + 1, end, sum, vec);
    dfs(start + 1, end, sum + nums[start], vec);
}

int main() {
    int n, target;
    cin >> n >> target;
    
    nums.resize(n);
    for (int &el : nums) {
        cin >> el;
    }
    
    int mid = n / 2;
    vector <int> setA;
    vector <int> setB;
    
    dfs(0, mid, 0, setA);
    dfs(mid, n, 0, setB);
    
    sort(setA.begin(), setA.end());
    sort(setB.begin(), setB.end());
    
    long long result = 0;
    for (int &x : setA) {
        result += upper_bound(setB.begin(), setB.end(), target - x)
                - lower_bound(setB.begin(), setB.end(), target - x);
    }
    
    if (target == 0) result--; 
    cout << result;
    
    return 0;
}

참조

MITM, sqrt Decomposition - 안즈의 소소한 취미생활

This post is licensed under CC BY 4.0 by the author.