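不难发现，第一个限制条件把 $n$ 个点连成若干棵不超过 $k$ 叉的树（若 $a_i=\lfloor a_j/k\rfloor$，则 $i$ 是 $j$ 的父节点），且父节点的取值需要大于子节点的取值。
令 $f_{i,j}$ 表示点 $i$ 的儿子状态已经确定时，点 $i$ 恰好取 $j$ 的方案数（统计的是 $i$ 整棵子树的方案）。
令 $g_{i,j}=\sum_{t=1}^{j} f_{i,t}$，即点 $i$ 的取值不超过 $j$ 的方案数。
考虑到父节点取的数需要大于子节点，所以可以得到转移方程 $f_{i,j}=\prod_{v\in\mathrm{son}(i)} g_{v,j-1}$（特别地，$f_{i,1}=0$）。
注意到叶子节点满足 $f_{i,j}=1$，因此 $g_{i,j}=j$，
所以叶子节点的 $g_{i,j}$ 是关于 $j$ 的一阶多项式。
因为 $d$ 阶多项式的前缀和是 $d+1$ 阶多项式（可以严格证明，有兴趣可以自己证明），
又因为 $d$ 阶多项式和 $e$ 阶多项式的积是 $d+e$ 阶多项式，
所以可以归纳证明：对于任意一个点 $i$，$g_{i,j}$ 都是一个关于 $j$ 的多项式。
可以严格证明它的阶数恰好等于点 $i$ 的子树大小，因此对于任意 $i$，阶数都不超过 $n$。
所以我们可以用 dp 算出每个点 $i$ 在 $j=1,2,\dots,n+1$ 处的 $g_{i,j}$ 的值。
最后我们要求的答案是 $\prod_{r\ \text{为根}} g_{r,m}$。
因为每个 $g_{r,j}$ 都是一个阶数不超过 $n$ 的多项式，
而我们知道，$n+1$ 个点可以唯一确定一个阶数不超过 $n$ 的多项式，
所以我们用刚才求出的这 $n+1$ 个值，就可以用拉格朗日插值求出该多项式在 $j=m$ 处的值，
把所有根的结果相乘，也就是答案。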
以上就是这题大概的题解
想要了解更具体的情况或有什么问题的,可以联系我的QQ:碳钾钨
以下是代码(来自zyx学长)
#include <bits/stdc++.h>
using ll = long long;
using ull = unsigned long long;
using namespace std;
/// Binary exponentiation: returns a raised to the power b using O(log b)
/// multiplications of T. T must be constructible from 1 and support `*=`.
/// b is expected to be non-negative.
template <class T>
constexpr T power(T a, ll b) {
    T result = 1;
    while (b != 0) {
        if (b % 2 != 0) {
            result *= a;  // the current low bit of b is set: fold in a
        }
        a *= a;  // a now holds the base raised to the next power of two
        b /= 2;
    }
    return result;
}
/// Returns (a * b) mod p, normalised into [0, p), without 128-bit integers.
///
/// The quotient a * b / p is estimated in long double. The residual
/// a * b - q * p is then small, so it can be recovered from its value
/// modulo 2^64. Both products overflow 64 bits, so the residual is computed in
/// unsigned arithmetic (well-defined wrap-around) rather than signed arithmetic
/// (undefined behaviour on overflow). The final `%` and sign fix absorb a
/// quotient estimate that is off by a little in either direction.
///
/// NOTE: exactness for operands near 1e18 assumes long double carries more
/// precision than double (e.g. x86 extended precision). Where long double is
/// plain double, large operands can give wrong results.
constexpr long long mul(long long a, long long b, long long p) {
    const long long q = static_cast<long long>(1.L * a * b / p);
    const unsigned long long residual =
        static_cast<unsigned long long>(a) * static_cast<unsigned long long>(b) -
        static_cast<unsigned long long>(q) * static_cast<unsigned long long>(p);
    long long res = static_cast<long long>(residual) % p;
    if (res < 0) {
        res += p;
    }
    return res;
}
template <ll P>
struct MLong {
ll x;
constexpr MLong() : x{} {}
constexpr MLong(ll x) : x{ norm(x % getMod()) } {}
static ll Mod;
constexpr static ll getMod() {
if (P > 0) {
return P;
} else {
return Mod;
}
}
constexpr static void setMod(ll Mod_) { Mod = Mod_; }
constexpr ll norm(ll x) const {
if (x < 0) {
x += getMod();
}
if (x >= getMod()) {
x -= getMod();
}
return x;
}
constexpr ll val() const { return x; }
explicit constexpr operator ll() const { return x; }
constexpr MLong operator-() const {
MLong res;
res.x = norm(getMod() - x);
return res;
}
constexpr MLong inv() const {
assert(x != 0);
return power(*this, getMod() - 2);
}
constexpr MLong &operator*=(MLong rhs) & {
x = mul(x, rhs.x, getMod());
return *this;
}
constexpr MLong &operator+=(MLong rhs) & {
x = norm(x + rhs.x);
return *this;
}
constexpr MLong &operator-=(MLong rhs) & {
x = norm(x - rhs.x);
return *this;
}
constexpr MLong &operator/=(MLong rhs) & { return *this *= rhs.inv(); }
friend constexpr MLong operator*(MLong lhs, MLong rhs) {
MLong res = lhs;
res *= rhs;
return res;
}
friend constexpr MLong operator+(MLong lhs, MLong rhs) {
MLong res = lhs;
res += rhs;
return res;
}
friend constexpr MLong operator-(MLong lhs, MLong rhs) {
MLong res = lhs;
res -= rhs;
return res;
}
friend constexpr MLong operator/(MLong lhs, MLong rhs) {
MLong res = lhs;
res /= rhs;
return res;
}
friend constexpr std::istream &operator>>(std::istream &is, MLong &a) {
ll v;
is >> v;
a = MLong(v);
return is;
}
friend constexpr std::ostream &operator<<(std::ostream &os, const MLong &a) { return os << a.val(); }
friend constexpr bool operator==(MLong lhs, MLong rhs) { return lhs.val() == rhs.val(); }
friend constexpr bool operator!=(MLong lhs, MLong rhs) { return lhs.val() != rhs.val(); }
};
template <>
ll MLong<0LL>::Mod = ll(1E18) + 9;
template <int P>
struct MInt {
int x;
constexpr MInt() : x{} {}
constexpr MInt(ll x) : x{ norm(x % getMod()) } {}
static int Mod;
constexpr static int getMod() {
if (P > 0) {
return P;
} else {
return Mod;
}
}
constexpr static void setMod(int Mod_) { Mod = Mod_; }
constexpr int norm(int x) const {
if (x < 0) {
x += getMod();
}
if (x >= getMod()) {
x -= getMod();
}
return x;
}
constexpr int val() const { return x; }
explicit constexpr operator int() const { return x; }
constexpr MInt operator-() const {
MInt res;
res.x = norm(getMod() - x);
return res;
}
constexpr MInt inv() const {
assert(x != 0);
return power(*this, getMod() - 2);
}
constexpr MInt &operator*=(MInt rhs) & {
x = 1LL * x * rhs.x % getMod();
return *this;
}
constexpr MInt &operator+=(MInt rhs) & {
x = norm(x + rhs.x);
return *this;
}
constexpr MInt &operator-=(MInt rhs) & {
x = norm(x - rhs.x);
return *this;
}
constexpr MInt &operator/=(MInt rhs) & { return *this *= rhs.inv(); }
friend constexpr MInt operator*(MInt lhs, MInt rhs) {
MInt res = lhs;
res *= rhs;
return res;
}
friend constexpr MInt operator+(MInt lhs, MInt rhs) {
MInt res = lhs;
res += rhs;
return res;
}
friend constexpr MInt operator-(MInt lhs, MInt rhs) {
MInt res = lhs;
res -= rhs;
return res;
}
friend constexpr MInt operator/(MInt lhs, MInt rhs) {
MInt res = lhs;
res /= rhs;
return res;
}
friend constexpr std::istream &operator>>(std::istream &is, MInt &a) {
ll v;
is >> v;
a = MInt(v);
return is;
}
friend constexpr std::ostream &operator<<(std::ostream &os, const MInt &a) { return os << a.val(); }
friend constexpr bool operator==(MInt lhs, MInt rhs) { return lhs.val() == rhs.val(); }
friend constexpr bool operator!=(MInt lhs, MInt rhs) { return lhs.val() != rhs.val(); }
};
template <>
int MInt<0>::Mod = 998244353;
template <int V, int P>
constexpr MInt<P> CInv = MInt<P>(V).inv();
constexpr int P = 1e9 + 7;
using Z = MInt<P>;
void solve() {
int n, m, k;
cin >> n >> m >> k;
vector<ll> a(n);
for (int i = 0; i < n; i++) {
cin >> a[i];
}
if (k == 1) {
map<int, int> mp;
for (int i = 0; i < n; i++) {
mp[a[i]] += 1;
}
cout << power(Z(m), mp.size()) << "\n";
return;
}
vector<vector<int>> adj(n);
vector<int> d(n);
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
if (a[i] == a[j] / k) {
adj[i].push_back(j);
d[j] += 1;
}
}
}
vector<Z> invfac(1e5);
invfac[0] = 1;
for (int i = 1; i < 1e5; i++) {
invfac[i] = invfac[i - 1] / Z(i);
}
auto lagrange = [&](vector<Z> &f, int x) -> Z {
int k = f.size() - 1;
if (0 <= x && x <= k) {
return f[x];
}
vector<Z> pre(k + 1), suf(k + 1);
pre[0] = 1;
suf[k] = 1;
for (int i = 0; i < k; i++) {
pre[i + 1] = pre[i] * (x - i);
suf[k - i - 1] = suf[k - i] * (x - (k - i));
}
Z res = 0;
for (int i = 0, d = k % 2 ? -1 : 1; i <= k; i++, d *= -1) {
Z num = f[i] * pre[i] * suf[i];
Z den = invfac[i] * invfac[k - i] * d;
res += num * den;
}
return res;
};
const int N = n + 1;
vector<vector<Z>> dp(n, vector<Z>(N));
Z ans = 1;
for (int i = 0; i < n; i++)
if (d[i] == 0) {
auto dfs = [&](auto self, int x) -> void {
if (adj[x].size() == 0) {
iota(dp[x].begin(), dp[x].end(), 1);
return;
}
dp[x].assign(N, 1);
dp[x][0] = 0;
for (auto y : adj[x]) {
self(self, y);
for (int i = 1; i < N; i++) {
dp[x][i] *= dp[y][i - 1];
}
}
// if(x == 0) {
// for(int i = 0; i < n; i++) {
// cerr << dp[x][i] << " \n"[i == n - 1];
// }
// }
for (int i = 1; i < N; i++) {
dp[x][i] += dp[x][i - 1];
}
};
dfs(dfs, i);
ans *= lagrange(dp[i], m - 1);
// cerr << i << " " << ans << "\n";
}
cout << ans << "\n";
}
/// Program entry: switches to unsynchronised stream I/O and runs solve() once
/// per test case (one case unless the test-count read below is enabled).
signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    int tt = 1;
    // Uncomment to read the number of test cases first:
    // std::cin >> tt;
    for (; tt > 0; --tt) {
        solve();
    }
}