조화 수열의 특징을 이용해 원하는 합을 빠르게 구해보자.

1부터 n까지 ceil(1/1)+ceil(1/2)+(1/3)+…+ceil(1/n)의 합을 어떻게 빨리 구할 수 있을까?(ceil은 내림)

1부터 10까지를 생각해 보자. 쭉 전개하면 10 5 3 2 2 1 1 1 1 1으로 2와 1처럼 겹치는 부분이 발생한다. 이 겹치는 부분의 길이를 알 수 있다면 겹치는 부분의 길이 * (n/i)를 다 더해주면 값을 빠르게 구할 수 있다. j는 n/i와 n/j가 같은 최대 j로 정의하자. 그럼 i가 3이면 j는 3이고, i가 4이면 j는 5이다. 이 j를 안다면 sum+=(n/i)*(j-i+1)를 해주면 답을 구할 수 있다. 그러면 j를 어떻게 알 수 있을까?

ceil(n/i)=k라고 하자. ceil(n/x)=k인 실수 x의 최댓값은 n/k이다. 마찬가지로 정수 j의 최댓값은 ceil(n/k)임을 알 수 있다. k가 ceil(n/i)이므로 j는 ceil(n/ceil(n/i))일 것이다.

이를 바탕으로 1부터 n까지 n/i를 빠르게 구하는 코드를 작성해 보자.

int ans=0, j;
for(i=1;i<=n;i=j+1)
{
	j=n/(n/i);
	ans+=n/i*(j-i+1);
}

그럼 이 알고리즘의 시간복잡도는 어떻게 될까? 총 ceil(n/i)의 개수만큼 연산을 하므로 시간복잡도는 서로 다른 n/i들의 총 개수이다.(말하자면 n이 10이면 수열에는 10, 5, 3, 2, 1이 나오므로 5번 for문을 돌게 된다)

1부터 sqrt(n)까지 서로 다른 ceil(n/i)는 최대 sqrt(n)개이다. i의 개수가 sqrt(n)개이기 때문이다. sqrt(n)+1부터 n까지 서로 다른 ceil(n/i)는 최대 sqrt(n)개이다. 모든 i에 대해 ceil(n/i)가 sqrt(n)이하이기 때문이다. 따라서 ceil(n/i)의 개수가 2*sqrt(n)에 바운드되고, 시간복잡도는 sqrt(n)이다.

아래는 관련 문제와 코드이다.

https://www.acmicpc.net/problem/15897

그대로 구현하면 된다.

#include <bits/stdc++.h>
#define int long long
#define float long double
#define pp pair<int, int>
#define tp tuple<int, int, int>
#define mtp make_tuple
#define mp make_pair
using namespace std;
using lf = __float128;
using ld = __int128;
float pi = 3.141592653589793238462643383279502884197169399375105Q;
int dx[4] = {0, 0, 1, -1}, dy[4] = {1, -1, 0, 0};
int ddx[8] = {0, 0, 1, -1, 1, 1, -1, 1}, ddy[8] = {1, -1, 0, 0, 1, -1, 1, -1};

signed main() {
    ios_base::sync_with_stdio(0);
    cin.tie(nullptr);

    int n;
    cin >> n;
    int ans = 0;

    int j = 1;
    n--;
    for (int i = 1; i <= n; i = j + 1) {
        j = n / (n / i);
        ans += (j - i + 1) * (n / i + 1);
    }
    cout << ans + 1;

    return 0;
}

https://www.acmicpc.net/problem/17417 Q번 N,S,E가 주어진다.

Q가 10^5이하이다. Q가 작은 경우는 N이 달라지고, Q의 제한이 없는 경우는 N이 동일하게 나온다. 따라서 두 가지 경우를 다르게 구현하면 된다. (1) Q가 2000이하고 N이 다르게 나오는 경우: 위의 내용을 이용해 O(Qsqrt(N))으로 구현하면 된다.

(2) Q의 제한이 없고 N이 같게 나오는 경우: 누적합과 이분탐색을 이용해 O(Qlog(N))으로 구현하면 된다.

둘 중 하나의 방법으로만 코드를 짜면 풀 수 없는 독특한 문제이다.

#include <bits/stdc++.h>
#define int long long
#define float long double
#define pp pair<int, int>
#define tp tuple<int, int, int>
#define mtp make_tuple
#define mp make_pair
using namespace std;
using lf = __float128;
using ld = __int128;
float pi = 3.141592653589793238462643383279502884197169399375105Q;
int dx[4] = {0, 0, 1, -1}, dy[4] = {1, -1, 0, 0};
int ddx[8] = {0, 0, 1, -1, 1, 1, -1, 1}, ddy[8] = {1, -1, 0, 0, 1, -1, 1, -1};

int E=-1, N;

int bsearch(int val, int sum, vector<pp>& sums) {
    if (val == 0)
        return 0;
    else if (val >= N)
        return sum;
    else {
        pp t = {val, 0};
        vector<pp>::iterator its = lower_bound(sums.begin(), sums.end(), t);

        if ((*its).first == val)
            return (*its).second;
        else {
            its--;
            return (*its).second + (val - (*its).first) * (N / val);
        }
    }
}

signed main() {
    ios_base::sync_with_stdio(0);
    cin.tie(nullptr);

    int q;
    cin >> q;
    vector<array<int, 3>> arr(101010);
    bool flag = 1;
    for (int t = 0; t < q;t++){
        int n, s, e;
        cin >> n >> s >> e;
        arr[t] = {n, s, e};
        E = max(e, E);
        N = n;
        if(t>=1 && arr[t-1][0]!=arr[t][0]){
            flag = 0;
        }
    }
    if(flag==0){
        for (int t = 0; t < q;t++){
            int j;
            int sum = 0;
            auto [n,s,e] = arr[t];
            for (int i = s; i <= min(n,e);i=j+1){
                j = min(n/(n / i), e);
                sum += (n / i) * (j - i + 1);
            }
            cout << sum << '\\n';
        }
    }
    if(flag==1){
        int j;
        int sum = 0;
        vector<pp> sums;
        for (int i = 1; i <= min(N, E);i=j+1){
            j = N / (N / i);
            sum += (N / i) * (j - i + 1);
            sums.push_back({j, sum});
        }
        for (int t = 0; t < q; t++) {
            auto [n, s, e] = arr[t];
            if (s > e)
                cout << 0 << '\\n';
            else{
                cout << bsearch(e,sum,sums) - bsearch(s-1,sum,sums)<<'\\n';
            }
        }
    }
        return 0;
}