[ICPC 2024 Yokohama R] Tree Generators
题意
规定表达式生成树的操作。根据以下过程,从一个表达式生成一棵树。
- 表达式 $\texttt{1}$ 生成一棵仅包含一个标号为 $1$ 的节点的树。
- 对于两个表达式 $E_1$ 和 $E_2$ ,表达式 $(E_1E_2)$ 按如下方式生成一棵树:
- 从 $E_1$ 生成一棵拥有 $n_1$ 个节点的树 $T_1$ ,并从 $E_2$ 生成一棵拥有 $n_2$ 个节点的树 $T_2$ 。
- 将 $T_2$ 中所有节点的标号都加上 $n_1$ 。
- 随机地各从 $T_1$ 和 $T_2$ 中选取一个节点,连接它们形成一条边,从而构造出一棵标号为 $1$ 到 $(n_1 + n_2)$ 的树,该树即为 $(E_1E_2)$ 生成的树。
现在给定两条表达式,求它们都能生成的树个数。答案对 $998244353$ 取模。
$|S|\leq 7\times 10^5$ 。
思路
两棵树要完全相同必然需要一棵树的每个边另一边也有可能连上,因此一棵树的 $\texttt{(11)}$ 限制了另一棵树相同位置的连边。
比如两个表达式 $\texttt{((1(11))1)}$ 和 $\texttt{((11)(11))}$ ,前者的 $\texttt{(11)}$ 就限定了 2 和 3 连在一起,对生成的树来说限定了第一棵树中两棵子树 $\{2\}$ 和 $\{3\}$ 连接方式、第二棵树两棵子树 $\{12\}$ 和 $\{34\}$ 的连接方式。
考虑在此基础上拓展。 $(E_1E_2)$ 意味着两棵树之间必须有条边,有 $|E_1||E_2|$ 种可能。需要考虑这一个表达式是怎么限制另一棵树连边的。假设 $E_1$ 的范围是 $[l_1,i]$ , $E_2$ 的范围是 $[i+1,r_1]$ ,则 $(E_1E_2)$ 只能限制的另一条表达式 $(E_3E_4)$ 满足 $E_3$ 的范围是 $[l_2,i]$ , $E_4$ 的范围是 $[i+1,r_2]$ 。换句话说,两个分割位置相同的表达式互相影响,它们生成的边对应,有 $(i-\max(l_1,l_2)+1)(\min(r_1,r_2)-i)$ 种可能。
那么预处理出 $l_1,l_2,r_1,r_2$ 即可。
代码
const int N = 7e5 + 5;
const int mod = 998244353;
inline ll read() {
ll x = 0, f = 1;
char c = getchar();
while (c != '-' && (c < '0' || c > '9')) c = getchar();
if (c == '-') f = -f, c = getchar();
while (c >= '0' && c <= '9') x = (x << 3) + (x << 1) + c - '0', c = getchar();
return x * f;
}
namespace Main {
int n, m, len;
char s[N];
int l[N], r[N];
int st[N], top;
int ch[N][2], leaf[N];
void mxe (int &x, int y) { if (x < y) x = y; }
void mne (int &x, int y) { if (x > y) x = y; }
pair<int, int> dfs (int p) {
if (leaf[p]) return make_pair(leaf[p], leaf[p]);
auto nd1 = dfs(ch[p][0]), nd2 = dfs(ch[p][1]);
mxe(l[nd1.second], nd1.first);
mne(r[nd1.second], nd2.second);
return make_pair(nd1.first, nd2.second);
}
void solve (bool op = 1) {
scanf ("%s", s + 1);
len = strlen(s + 1);
n = m = top = 0;
memset(ch, 0, sizeof ch);
memset(leaf, 0, sizeof leaf);
for (int i = 1; i <= len; ++i) {
if (s[i] == '(') {
st[++top] = ++n;
ch[st[top - 1]][ch[st[top - 1]][0] != 0] = n;
}
else if (s[i] == ')') st[top--] = 0;
else leaf[ch[st[top]][ch[st[top]][0] != 0] = ++n] = ++m;
}
if (op)
for (int i = 1; i < m; ++i) l[i] = 1, r[i] = m;
dfs(1);
}
int main () {
solve();
solve(0);
ll ans = 1;
for (int i = 1; i < m; ++i)
ans = ans * (i - l[i] + 1) % mod * (r[i] - i) % mod;
printf ("%lld\n", ans);
return 0;
}
}
int main () {
string str = "";
// freopen((str + ".in").c_str(), "r", stdin);
// freopen((str + ".out").c_str(), "w", stdout);
Main::main();
return 0;
}
