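// "Show original code" — appears to be a stray page label captured when this listing was copied; kept as a comment so the file compiles.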
#include <cstdio>
#include <cstdlib>
#include <cassert>
#include <algorithm>
#include <cmath>
#include <string>
#include <vector>
#include <queue>
#include <stack>
#include <map>
#include <unordered_map>
#include <set>
#include <iostream>
#include <cstring>
#include <bitset>
#include <ctime>
#include <functional>
#include <numeric>
#include <random>
#include <chrono>
#include <array>
#include <utility>
// Shorthand integer types: signed 64-bit, unsigned 64-bit, and the GCC/Clang
// 128-bit signed extension type.
typedef int64_t LL;
typedef uint64_t uLL;
typedef __int128_t sLL;
// Short spellings for pair members and common container member functions.
// These are textual macros: they rewrite every later token with that name.
#define fir first
#define sec second
#define eb emplace_back
#define em emplace
#define pb push_back
#define ppb pop_back
// `pii` is a macro, not a type alias, so it is expanded at each use site.
#define pii std::pair<int, int>
#define mkp(a, b) std::make_pair(a, b)
// Set-bit count and parity (1 when the number of set bits is odd) via compiler
// builtins; the `ll` forms operate on unsigned long long.
#define bitcount(x) __builtin_popcount(x)
#define bitcountll(x) __builtin_popcountll(x)
#define bitparity(x) __builtin_parity(x)
#define bitparityll(x) __builtin_parityll(x)
// Reads the next decimal integer from stdin.
// Any non-digit characters before the first digit are skipped; a '-' seen among
// them makes the result negative. The character right after the last digit is
// consumed. Returns 0 if end of input is reached before any digit. (The previous
// version stored getchar() in a `char` and looped forever at EOF.)
int read() {
    int s = 0, w = 1;
    int ch = getchar();  // int, not char: keeps EOF distinguishable from data bytes
    while (ch < '0' || ch > '9') {
        if (ch == EOF)
            return 0;
        if (ch == '-')
            w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') s = s * 10 + (ch - '0'), ch = getchar();
    return s * w;
}
// Reads the next decimal integer from stdin as a 64-bit value.
// Any non-digit characters before the first digit are skipped; a '-' seen among
// them makes the result negative. The character right after the last digit is
// consumed. Returns 0 if end of input is reached before any digit. (The previous
// version stored getchar() in a `char` and looped forever at EOF.)
LL readl() {
    LL s = 0, w = 1;
    int ch = getchar();  // int, not char: keeps EOF distinguishable from data bytes
    while (ch < '0' || ch > '9') {
        if (ch == EOF)
            return 0;
        if (ch == '-')
            w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') s = s * 10 + (ch - '0'), ch = getchar();
    return s * w;
}
template <typename T>
void debugv(const T &vec, const char *s = NULL) {
    // Dumps every element of an iterable as "name = [e0, e1, ..., ]" plus a
    // newline; each element is followed by ", " (including the last one).
    // A null name is printed as an empty string.
    const char *name = (s != NULL) ? s : "";
    printf("%s = [", name);
    for (const auto &item : vec)
        std::cout << item << ", ";
    printf("]\n");
}
template <typename key, typename val>
void debugmp(const std::map<key, val> &mp, const char *s = NULL) {
    // Prints each entry as "< key, value >, " between "name = [" and "]\n",
    // in the map's key order. A null name is printed as an empty string.
    printf("%s = [", s == NULL ? "" : s);
    for (auto it = mp.begin(); it != mp.end(); ++it) {
        std::cout << "< " << it->first << ", " << it->second << " >, ";
    }
    printf("]\n");
}
template <typename T>
void debuga(T *arr, int l, int r, const char *s = NULL) {
    // Prints the inclusive slice arr[l..r] as "name = [x, y, ..., ]" plus a
    // newline; the brackets are empty when l > r. A null name prints as "".
    const char *name = s ? s : "";
    printf("%s = [", name);
    int i = l;
    while (i <= r) {
        std::cout << arr[i] << ", ";
        ++i;
    }
    printf("]\n");
}
struct custom_hash {
    // SplitMix64 mixing function: add the golden-ratio increment, then apply two
    // xor-shift/multiply rounds and a final xor-shift.
    static uint64_t splitmix64(uint64_t x) {
        uint64_t z = x + 0x9e3779b97f4a7c15;
        z ^= z >> 30;
        z *= 0xbf58476d1ce4e5b9;
        z ^= z >> 27;
        z *= 0x94d049bb133111eb;
        z ^= z >> 31;
        return z;
    }
    // Hashes x after offsetting it by a per-process seed taken once (on first
    // call) from the steady clock, so the hash values differ between runs.
    size_t operator()(uint64_t x) const {
        static const uint64_t seed = std::chrono::steady_clock::now().time_since_epoch().count();
        const uint64_t shifted = x + seed;
        return splitmix64(shifted);
    }
};
// Capacity of every fixed-size buffer in this file: rev, INV's scratch array,
// and (by contract) the coefficient arrays callers pass to INV.
const int maxn = 300001;  // 3e5 + 1
// 998244353 = 119 * 2^23 + 1; g = 3 is used below as the generator for the
// NTT roots of unity.
const int mod = 998244353, g = 3;

// Modular exponentiation: returns x^y mod `mod`.
// The base is first reduced into [0, mod), so any int64_t base (negative or
// >= mod) is accepted without overflowing x * x. Requires y >= 0; a negative
// exponent would never terminate the shift loop.
int64_t ksm(int64_t x, int64_t y) {
    assert(y >= 0);
    x %= mod;
    if (x < 0)
        x += mod;
    int64_t ret = 1;
    for (; y; y >>= 1, x = x * x % mod)
        if (y & 1)
            ret = ret * x % mod;
    return ret;
}

// Bit-reversal permutation for the current transform length; filled by INV.
int64_t rev[maxn];

// In-place iterative NTT of a[0..n) over Z_mod, where n is a power of two.
// typ == 1 runs the forward transform; any other value runs the inverse
// transform, including the final multiplication by n^{-1}.
// Precondition: rev[0..n) holds the bit-reversal permutation for n, and the
// entries of a are in [0, mod).
void NTT(int64_t *a, int64_t n, int64_t typ) {
    for (int64_t i = 0; i < n; i++)
        if (i < rev[i])
            std::swap(a[i], a[rev[i]]);
    for (int64_t half = 1; half < n; half <<= 1) {
        // Root of unity of order 2 * half.
        const int64_t wn = ksm(g, (mod - 1) / (half << 1));
        for (int64_t j = 0; j < n; j += half << 1) {
            int64_t w = 1;
            for (int64_t k = 0; k < half; k++, w = w * wn % mod) {
                const int64_t x = a[j + k];
                const int64_t y = a[half + j + k] * w % mod;
                a[j + k] = (x + y) % mod;
                a[half + j + k] = (x - y + mod) % mod;
            }
        }
    }
    if (typ == 1)
        return;
    // Reversing a[1..n) turns the forward transform into one with inverted
    // roots; scaling by n^{-1} completes the inverse transform.
    const int64_t inv = ksm(n, mod - 2);
    std::reverse(a + 1, a + n);
    for (int64_t i = 0; i < n; i++) a[i] = a[i] * inv % mod;
}

// Power-series inverse modulo x^n: writes b so that a * b == 1 (mod x^n).
// It uses Newton iteration b <- b * (2 - a * b), which doubles the number of
// correct coefficients per level.
// Reads only a[0..n); a[0] must be nonzero modulo `mod`. b must have room for
// len entries, where len is the smallest power of two >= 2n (and len <= maxn).
// On return b[0..n) holds the inverse; for n > 1, b[n..len) is left zero.
// Unlike before, b does not need to be pre-zeroed by the caller.
void INV(int64_t *b, int64_t *a, int64_t n) {
    if (n == 1) {
        assert(a[0] % mod != 0);
        b[0] = ksm(a[0], mod - 2);
        return;
    }
    const int64_t half = (n + 1) >> 1;
    INV(b, a, half);
    static int64_t c[maxn];
    int64_t len = 1, p = -1;
    while (len < (n << 1)) len <<= 1, p++;
    // rev, c and the caller's b are all indexed up to len - 1.
    assert(len <= maxn);
    for (int64_t i = 1; i <= len - 1; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << p);
    std::copy(a, a + n, c);
    std::fill(c + n, c + len, 0);
    // Only b[0..half) is meaningful here; clear the rest so stale values in the
    // caller's buffer cannot leak into the transform.
    std::fill(b + half, b + len, 0);
    NTT(c, len, 1), NTT(b, len, 1);
    for (int64_t i = 0; i <= len - 1; i++) b[i] = (2 - b[i] * c[i] % mod + mod) % mod * b[i] % mod;
    NTT(b, len, 0);
    std::fill(b + n, b + len, 0);
}
// Input sizes read by main: n, then k followed by the k values a[1..k].
int n, k;
// f, gg: coefficient buffers passed to INV in main.
// frac / ifrac: factorials and inverse factorials modulo `mod`, filled by init().
LL f[maxn], gg[maxn], frac[maxn], ifrac[maxn];
// The k requested values, stored 1-indexed in a[1..k].
int a[maxn];
// Precomputes frac[i] = i! and ifrac[i] = (i!)^{-1} modulo `mod` for every
// index in [0, maxn).
void init() {
    frac[0] = 1;
    for (int i = 1; i != maxn; ++i) frac[i] = frac[i - 1] * i % mod;
    // Invert only the largest factorial (Fermat), then step down using
    // 1/(i-1)! = i * (1/i!).
    const int top = maxn - 1;
    ifrac[top] = ksm(frac[top], mod - 2);
    for (int i = top; i > 0; --i) ifrac[i - 1] = ifrac[i] * i % mod;
}
// In-place modular addition: x = (x + val) folded back into [0, mod) with one
// correction step. This assumes x is already in [0, mod) and |val| < mod.
void add(LL &x, int val) {
    x += val;
    if (x >= mod) {
        x -= mod;
    } else if (x < 0) {
        x += mod;
    }
}
int32_t main() {
init();
n = read();
k = read();
for (int i = 1; i <= k; i++) {
a[i] = read();
if (a[i] == 2) {
printf("0\n");
return 0;
}
}
int mx = *std::max_element(a + 1, a + 1 + k);
for (int i = 0; i <= n; i++) {
gg[i] = ksm(2, (LL)i * (i - 1) / 2) * ifrac[i] % mod;
}
INV(f, gg, n + 1);
for (int i = 0; i <= n; i++) {
f[i] = mod - f[i];
if (f[i] == mod)
f[i] = 0;
}
add(f[0], 1);
for (int i = mx; i <= n; i++) {
f[i] = 0;
}
for (int i = 0; i <= n; i++) {
f[i] = mod - f[i];
if (f[i] == mod)
f[i] = 0;
}
add(f[0], 1);
memset(gg, 0, sizeof(gg));
INV(gg, f, n + 1);
LL ans = ksm(2, (LL)n * (n - 1) / 2);
add(ans, -gg[n] * frac[n] % mod);
printf("%lld\n", ans);
return 0;
}