题解

admin 2024-12-21 21:21:41

不难发现,第一个限制条件可以把 $n$ 个点连成若干棵不超过 $k$ 叉的树(即一个森林),父节点的取值需要大于子节点的取值。

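设 $f_{i,j}$ 表示:点 $i$ 的子树中所有点的取值都已确定、并且点 $i$ 的取值不超过 $j$ 时的方案数。

考虑到父节点取的数需要大于子节点:若点 $i$ 取 $t$,则它的每个儿子都只能取不超过 $t-1$ 的数。所以可以得到转移方程

$$f_{i,j}=\sum_{t=1}^{j}\prod_{v\in \mathrm{son}(i)} f_{v,t-1}$$

注意到叶子节点没有儿子(空积为 $1$),可以取 $1\sim j$ 中任意一个数,即 $f_{\text{leaf},j}=j$,

所以它是关于 $j$ 的一次多项式。

因为 $d$ 次多项式的前缀和是 $d+1$ 次多项式(可以严格证明,有兴趣可以自己证明),

又因为 $d$ 次多项式和 $p$ 次多项式的积是 $d+p$ 次多项式(把自变量从 $t$ 平移成 $t-1$ 不改变次数),

所以可以归纳证明:对于任意点 $i$,$f_{i,j}$ 都是一个关于 $j$ 的多项式,且 $\deg f_i = 1+\sum_{v\in \mathrm{son}(i)}\deg f_v$,恰好等于 $i$ 的子树大小。

因此对于任意 $i$,次数都不超过 $n$(用更宽松的上界,例如 $n+2$,也可以,只是要多算几个点值)。

所以我们可以用 dp 算出每个点在 $j=1,2,\dots,n+1$ 处的值,共 $n+1$ 个(代码中的 `dp[x][j]` 对应 $f_{x,j+1}$)。

最后我们要求的答案是

$$\prod_{r\ \text{是根}} f_{r,m}$$

因为每个 $f_r$ 都是次数不超过 $n$ 的多项式,

而 $n+1$ 个点值可以唯一确定一个次数不超过 $n$ 的多项式,

所以我们用刚才求出的 $n+1$ 个值,用拉格朗日插值求出每个 $f_r$ 在 $j=m$ 处的值,

再把它们乘起来,也就是答案。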

以上就是这题大概的题解

想要了解更具体的情况或有什么问题的,可以联系我的QQ:碳钾钨

以下是代码(来自zyx学长)

#include <bits/stdc++.h>
using ll = long long;
using ull = unsigned long long;
using namespace std;
// Binary exponentiation: returns a^b using O(log b) multiplications of T.
// T must be constructible from 1 and support `*=`; b is expected to be >= 0.
template <class T>
constexpr T power(T a, ll b) {
    T res = 1;
    for (; b != 0; b /= 2) {
        if (b % 2 != 0) {
            res *= a;
        }
        // Square only while more bits remain. A squaring after the last bit
        // would be unused, and for built-in signed T it can overflow
        // (undefined behaviour) even when a^b itself fits.
        if (b / 2 != 0) {
            a *= a;
        }
    }
    return res;
}

// Returns (a * b) mod p in [0, p) without a 128-bit type.
// q approximates a * b / p in long double. Because q is close to the true
// quotient, the exact value a*b - q*p is small and can be recovered by computing
// both products in wrapping unsigned arithmetic; signed `ll` multiplication
// would overflow (undefined behaviour) whenever a * b exceeds LLONG_MAX.
// NOTE(review): assumes reduced operands (|a|, |b| < p) and enough long double
// precision for the rounding error of q to stay small (true for x87 80-bit
// long double with p around 1e18; a 53-bit long double is not enough there).
constexpr ll mul(ll a, ll b, ll p) {
    const unsigned long long q = static_cast<unsigned long long>(static_cast<ll>(1.L * a * b / p));
    ll res = static_cast<ll>(static_cast<unsigned long long>(a) * static_cast<unsigned long long>(b) -
                             q * static_cast<unsigned long long>(p));
    // q may be off by one, so fold the remainder back into [0, p).
    res %= p;
    if (res < 0) {
        res += p;
    }
    return res;
}
/// Integer modulo a 64-bit modulus: the compile-time P when P > 0, otherwise
/// the runtime-settable static Mod (defined out of class for P == 0).
/// Products go through mul(), so moduli near 1e18 are usable. The stored
/// residue is kept in [0, mod); inv() uses Fermat's little theorem, so it is
/// only valid for a prime modulus.
template <ll P>
struct MLong {
    ll x;  // residue in [0, getMod())

    constexpr MLong() : x{} {}
    constexpr MLong(ll v) : x{ norm(v % getMod()) } {}

    static ll Mod;  // modulus used when P == 0
    constexpr static ll getMod() { return P > 0 ? P : Mod; }
    constexpr static void setMod(ll newMod) { Mod = newMod; }

    // Single correction step: brings v from [-mod, 2 * mod) into [0, mod).
    constexpr ll norm(ll v) const {
        const ll mod = getMod();
        return v < 0 ? v + mod : (v >= mod ? v - mod : v);
    }

    constexpr ll val() const { return x; }
    explicit constexpr operator ll() const { return x; }

    constexpr MLong operator-() const {
        MLong neg;
        neg.x = norm(getMod() - x);
        return neg;
    }
    constexpr MLong inv() const {
        assert(x != 0);
        return power(*this, getMod() - 2);
    }

    constexpr MLong &operator*=(MLong rhs) & {
        x = mul(x, rhs.x, getMod());
        return *this;
    }
    constexpr MLong &operator+=(MLong rhs) & {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr MLong &operator-=(MLong rhs) & {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr MLong &operator/=(MLong rhs) & { return *this *= rhs.inv(); }

    // Binary operators are hidden friends taking both sides by value, so an
    // integer on either side converts implicitly through MLong(ll).
    friend constexpr MLong operator*(MLong lhs, MLong rhs) { return lhs *= rhs; }
    friend constexpr MLong operator+(MLong lhs, MLong rhs) { return lhs += rhs; }
    friend constexpr MLong operator-(MLong lhs, MLong rhs) { return lhs -= rhs; }
    friend constexpr MLong operator/(MLong lhs, MLong rhs) { return lhs /= rhs; }

    friend constexpr std::istream &operator>>(std::istream &is, MLong &a) {
        ll raw;
        is >> raw;
        a = MLong(raw);
        return is;
    }
    friend constexpr std::ostream &operator<<(std::ostream &os, const MLong &a) { return os << a.val(); }
    friend constexpr bool operator==(MLong lhs, MLong rhs) { return lhs.val() == rhs.val(); }
    friend constexpr bool operator!=(MLong lhs, MLong rhs) { return !(lhs == rhs); }
};

// Default runtime modulus of MLong<0> (1e18 + 9); replaceable via MLong<0>::setMod.
template <>
ll MLong<0LL>::Mod = 1'000'000'000'000'000'009LL;

/// Integer modulo a 32-bit modulus: the compile-time P when P > 0, otherwise
/// the runtime-settable static Mod (defined out of class for P == 0).
/// The stored residue is kept in [0, mod); inv() uses Fermat's little
/// theorem, so it is only valid for a prime modulus.
template <int P>
struct MInt {
    int x;  // residue in [0, getMod())

    constexpr MInt() : x{} {}
    constexpr MInt(ll v) : x{ norm(v % getMod()) } {}

    static int Mod;  // modulus used when P == 0
    constexpr static int getMod() { return P > 0 ? P : Mod; }
    constexpr static void setMod(int newMod) { Mod = newMod; }

    // Single correction step: brings v from [-mod, 2 * mod) into [0, mod).
    constexpr int norm(int v) const {
        const int mod = getMod();
        return v < 0 ? v + mod : (v >= mod ? v - mod : v);
    }

    constexpr int val() const { return x; }
    explicit constexpr operator int() const { return x; }

    constexpr MInt operator-() const {
        MInt neg;
        neg.x = norm(getMod() - x);
        return neg;
    }
    constexpr MInt inv() const {
        assert(x != 0);
        return power(*this, getMod() - 2);
    }

    constexpr MInt &operator*=(MInt rhs) & {
        // Widen before multiplying: the product of two residues needs 64 bits.
        x = static_cast<int>(static_cast<ll>(x) * rhs.x % getMod());
        return *this;
    }
    constexpr MInt &operator+=(MInt rhs) & {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr MInt &operator-=(MInt rhs) & {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr MInt &operator/=(MInt rhs) & { return *this *= rhs.inv(); }

    // Binary operators are hidden friends taking both sides by value, so an
    // integer on either side converts implicitly through MInt(ll).
    friend constexpr MInt operator*(MInt lhs, MInt rhs) { return lhs *= rhs; }
    friend constexpr MInt operator+(MInt lhs, MInt rhs) { return lhs += rhs; }
    friend constexpr MInt operator-(MInt lhs, MInt rhs) { return lhs -= rhs; }
    friend constexpr MInt operator/(MInt lhs, MInt rhs) { return lhs /= rhs; }

    friend constexpr std::istream &operator>>(std::istream &is, MInt &a) {
        ll raw;
        is >> raw;
        a = MInt(raw);
        return is;
    }
    friend constexpr std::ostream &operator<<(std::ostream &os, const MInt &a) { return os << a.val(); }
    friend constexpr bool operator==(MInt lhs, MInt rhs) { return lhs.val() == rhs.val(); }
    friend constexpr bool operator!=(MInt lhs, MInt rhs) { return !(lhs == rhs); }
};

// Default runtime modulus of MInt<0>; replaceable via MInt<0>::setMod.
template <>
int MInt<0>::Mod = 998'244'353;

template <int V, int P>
constexpr MInt<P> CInv = MInt<P>(V).inv();

// Every answer is reported modulo this prime; Z is the modular integer used by solve().
constexpr int P = 1'000'000'007;
using Z = MInt<P>;
void solve() {
    int n, m, k;
    cin >> n >> m >> k;

    vector<ll> a(n);
    for (int i = 0; i < n; i++) {
        cin >> a[i];
    }

    if (k == 1) {
        map<int, int> mp;
        for (int i = 0; i < n; i++) {
            mp[a[i]] += 1;
        }
        cout << power(Z(m), mp.size()) << "\n";
        return;
    }

    vector<vector<int>> adj(n);
    vector<int> d(n);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            if (a[i] == a[j] / k) {
                adj[i].push_back(j);
                d[j] += 1;
            }
        }
    }

    vector<Z> invfac(1e5);
    invfac[0] = 1;
    for (int i = 1; i < 1e5; i++) {
        invfac[i] = invfac[i - 1] / Z(i);
    }

    auto lagrange = [&](vector<Z> &f, int x) -> Z {
        int k = f.size() - 1;
        if (0 <= x && x <= k) {
            return f[x];
        }
        vector<Z> pre(k + 1), suf(k + 1);
        pre[0] = 1;
        suf[k] = 1;
        for (int i = 0; i < k; i++) {
            pre[i + 1] = pre[i] * (x - i);
            suf[k - i - 1] = suf[k - i] * (x - (k - i));
        }
        Z res = 0;
        for (int i = 0, d = k % 2 ? -1 : 1; i <= k; i++, d *= -1) {
            Z num = f[i] * pre[i] * suf[i];
            Z den = invfac[i] * invfac[k - i] * d;
            res += num * den;
        }
        return res;
    };

    const int N = n + 1;
    vector<vector<Z>> dp(n, vector<Z>(N));
    Z ans = 1;
    for (int i = 0; i < n; i++)
        if (d[i] == 0) {
            auto dfs = [&](auto self, int x) -> void {
                if (adj[x].size() == 0) {
                    iota(dp[x].begin(), dp[x].end(), 1);
                    return;
                }
                dp[x].assign(N, 1);
                dp[x][0] = 0;
                for (auto y : adj[x]) {
                    self(self, y);
                    for (int i = 1; i < N; i++) {
                        dp[x][i] *= dp[y][i - 1];
                    }
                }
                // if(x == 0) {
                //     for(int i = 0; i < n; i++) {
                //         cerr << dp[x][i] << " \n"[i == n - 1];
                //     }
                // }
                for (int i = 1; i < N; i++) {
                    dp[x][i] += dp[x][i - 1];
                }
            };
            dfs(dfs, i);
            ans *= lagrange(dp[i], m - 1);
            // cerr << i << " " << ans << "\n";
        }

    cout << ans << "\n";
}
// Entry point: unsynchronised, untied iostreams, then solve() once per test case
// (single test by default; enable the commented read for multi-test input).
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    int testCount = 1;
    // std::cin >> testCount;
    for (; testCount; --testCount) {
        solve();
    }
}