题意
给定一棵 $n$ 个点的树和 $q$ 次询问，每次询问给出 $m$ 个特殊点。树上每个点由离它最近的特殊点控制；若有多个特殊点距离相同，则由编号最小的那个控制。对每次询问，按输入顺序求出每个特殊点控制的点数。
$n, q, \sum m \le 3\times 10^5$。
题解
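对每次询问的特殊点建虚树（特殊点加上按 dfn 排序后相邻两点的 LCA）。然后在虚树上做树形 DP：先自底向上、再自顶向下各扫一遍，求出每个虚树节点最近的特殊点及其距离（距离相同取编号小的）。最后统计答案：对每条虚树边 $(v, u)$（$v$ 是父亲），用倍增在原树的链上找到两端控制范围的分界点，把链上的点以及挂在链上的子树按分界点分给 $u$、$v$ 各自最近的特殊点；每个虚树节点子树中不属于任何虚树边的部分归它自己最近的特殊点，虚树根子树以外的 $n - siz$ 个点归根最近的特殊点。细节很多。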
代码
// N: array capacity, the 3e5 bound on n / q / sum m plus a few slack slots.
// INF: distance used for "no special point reached yet"; it leaves headroom so that
//      adding a tree depth (<= 3e5) to it still fits in an int.
constexpr int N = 300000 + 5;
constexpr int INF = 0x3f3f3f3f;
// Reads one (optionally negative) decimal integer from stdin.
// Any characters before the first '-' or digit are skipped.
// Returns 0 if end of input is reached before a number starts. The previous
// version stored getchar() in a char and never tested EOF, so it spun forever
// on truncated input.
inline ll Read() {
    ll x = 0, f = 1;
    int c = getchar();  // int, not char, so EOF stays distinguishable from any byte
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == '-') f = -f, c = getchar();
    while (c >= '0' && c <= '9') x = x * 10 + (c - '0'), c = getchar();
    return x * f;
}
namespace Main {
int n, m;
int head[N], tot;
struct Edge {
int to, nxt;
} e[N << 1];
void add_edge (int u, int v) { e[++tot] = (Edge) {v, head[u]}, head[u] = tot; }
int dfn[N], dep[N], fa[N][30], cnt, idfn[N], siz[N];
void dfs (int u, int fa_) {
dfn[u] = ++cnt; fa[u][0] = fa_;
dep[u] = dep[fa_] + 1;
siz[u] = 1;
for (int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to; if (v == fa_) continue;
dfs (v, u);
siz[u] += siz[v];
}
idfn[u] = cnt;
}
int LCA (int u, int v) {
if (dep[u] < dep[v]) swap(u, v);
int d = dep[u] - dep[v];
for (int j = 21; j >= 0; j--)
if ((d >> j) & 1) u = fa[u][j];
if (u == v) return u;
for (int j = 21; j >= 0; j--)
if (fa[u][j] != fa[v][j]) u = fa[u][j], v = fa[v][j];
return fa[u][0];
}
int Jmp (int u, int d) { for (int i = 0; i <= 21; i++) if ((d >> i) & 1) u = fa[u][i]; return u;}
int a[N], b[N];
bool vis[N];
int ans[N], val[N];
namespace VirtualTree {
int stk[N], fa[N];
#define Pair pair<int, int>
#define st first
#define nd second
#define mk make_pair
Pair f[N];
void Build() {
memset (fa, 0, sizeof fa);
sort (a + 1, a + 1 + m, [](int a, int b) { return dfn[a] < dfn[b]; });
for (int i = 1, n = m; i < n; i++) a[++m] = LCA(a[i], a[i + 1]);
sort (a + 1, a + 1 + m, [](int a, int b) { return dfn[a] < dfn[b]; });
m = unique(a + 1, a + 1 + m) - a - 1;
int top = 0; stk[++top] = a[1];
if (!vis[a[1]]) f[a[1]] = mk (INF, 0);
else f[a[1]] = mk(0, a[1]);
for (int i = 2; i <= m; i++) {
for (; top && idfn[stk[top]] < dfn[a[i]]; top--);
if (top) fa[a[i]] = stk[top];
if (!vis[a[i]]) f[a[i]] = mk (INF, 0);
else f[a[i]] = mk(0, a[i]);
ans[a[i]] = 0;
stk[++top] = a[i];
}
}
int dt[N];
void Solve() {
for (int i = m; i >= 2; i--) {
int u = a[i], v = fa[u];
dt[u] = dep[u] - dep[v];
Pair tmp = mk (f[u].st + dt[u], f[u].nd);
if (tmp < f[v]) f[v] = tmp;
}
for (int i = 2; i <= m; i++) {
int u = a[i], v = fa[u];
Pair tmp = mk (f[v].st + dt[u], f[v].nd);
if (tmp < f[u]) f[u] = tmp;
}
for (int i = 1; i <= m; i++) {
int u = a[i], v = fa[u];
val[u] = siz[u];
if (i == 1) {
ans[f[u].nd] += n - siz[u];
continue;
}
int son = Jmp (u, dep[u] - dep[v] - 1);
int calc = siz[son] - siz[u];
val[v] -= siz[son];
if (f[u].nd == f[v].nd) ans[f[u].nd] += calc;
else {
int z = f[u].st - f[v].st + dep[u] + dep[v] + 1 >> 1;
if (f[v].nd < f[u].nd && f[v].st + z - dep[v] == f[u].st + dep[u] - z) ++z;
z = siz[Jmp(u, dep[u] - z)] - siz[u];
ans[f[u].nd] += z;
ans[f[v].nd] += calc - z;
}
}
for (int i = 1; i <= m; i++) ans[f[a[i]].nd] += val[a[i]];
}
}
int main () {
n = Read();
for (int i = 1; i < n; i++) {
int u = Read(), v = Read();
add_edge(u, v);
add_edge(v, u);
}
dfs (1, 0);
for (int j = 1; j <= 21; j++)
for (int i = 1; i <= n; i++)
fa[i][j] = fa[fa[i][j - 1]][j - 1];
for (int q = Read(); q--; ) {
int Om = m = Read();
memset (vis, 0, sizeof vis);
memset (ans, 0, sizeof ans);
for (int i = 1; i <= m; i++) vis[b[i] = a[i] = Read()] = 1;
VirtualTree::Build();
VirtualTree::Solve();
for (int i = 1; i <= Om; i++) printf ("%d ", ans[b[i]]); putchar(10);
}
return 0;
}
}
int main () {
    // Local-testing redirections, left disabled for judge submission:
    // freopen(".in", "r", stdin);
    // freopen(".out", "w", stdout);
    return Main::main();  // Main::main always returns 0
}