SRM 654 Div1 Medium SuccessiveSubtraction2
区間を 2 つ選んで符号を反転させることができるので、区間和の最小値が計算できれば、あとは簡単に計算できる。
これは、O(N) でやってよいので、配列が更新される度に最初から計算し直せば良い。(解いているときは何も考えずに、StarrySky Tree を貼ってしまい、計算量が余計に log 倍されてしまった。)
どちらの解法でも TLE しているが、不具合が無ければ通るはず。
配列版 O(QN)
#define N 2010 int dp[N], dp2[N]; int s[N], s2[N]; class SuccessiveSubtraction2 { public: vector <int> calc(vector <int> a, vector <int> p, vector <int> v) { int n = a.size(), m = p.size(); vector<int> res; mset(dp, 0); mset(dp2, 0); mset(s, 0); mset(s2, 0); rep(i, m){ a[p[i]] = v[i]; ll sum = a[0]; for(int j = 1; j < n; j++) sum -= a[j]; if(n<=2){ res.pb(sum); continue; } for(int j = 2; j < n; j++) s[j+1] = s[j]+a[j]; for(int j = n-1; j >= 2; j--) s2[j] = s2[j+1]+a[j]; for(int j = 2, mn = 0; j <= n; j++){ chmin(mn, s[j]); dp[j] = max(dp[j-1], s[j]-mn); } for(int j = n-1, mn = 0; j >= 1; j--){ chmin(mn, s2[j]); dp2[j] = max(dp2[j+1], s2[j]-mn); } int r = -INF; for(int j = 2; j < n; j++) chmax(r, dp[j]+dp2[j]); res.pb(sum+2*r); } return res; } };
StarrySky Tree 版 O(QNlogN)
struct starrysky { int seg_n; vector<ll> data, datb; void init(int n, ll initial_value = 0){ seg_n = 1; while(seg_n<n) seg_n<<=1; data = vector<ll>(seg_n*2-1, 0); datb = vector<ll>(seg_n*2-1, 0); } // min(v[a], v[a+1],..., v[b-1]) // to use: query(a, b); ll rmin(int a, int b, int k, int l, int r){ if(r <= a || b <= l) return INF; if(a <= l && r <= b) return data[k]+datb[k]; return min(rmin(a, b, k*2+1, l, (l+r)/2), rmin(a, b, k*2+2, (l+r)/2, r))+data[k]; } ll rmin(int a, int b){ return rmin(a, b, 0, 0, seg_n); } void radd(int a, int b, ll x, int k, int l, int r){ if(r <= a || b <= l) return;//return INF; if(a <= l && r <= b){ data[k] += x; return; } else { radd(a, b, x, k*2+1, l, (l+r)/2); radd(a, b, x, k*2+2, (l+r)/2, r); datb[k] = min(datb[k*2+1]+data[k*2+1], datb[k*2+2]+data[k*2+2]); } } void radd(int a, int b, ll x){ radd(a, b, x, 0, 0, seg_n); } }; #define N 2010 int dp[N], dp2[N]; class SuccessiveSubtraction2 { public: vector <int> calc(vector <int> a, vector <int> p, vector <int> v) { int n = a.size(), m = p.size(); vector<int> res; starrysky x, y; x.init(n+2); y.init(n+2); ll sum = a[0]; for(int i = 1; i < n; i++) sum -= a[i]; for(int i = 2; i < n; i++){ x.radd(i, n+2, a[i]); y.radd(0, i, a[i]); } rep(i, m){ if(p[i]==0) sum += v[i]-a[0]; else sum -= v[i]-a[p[i]]; if(p[i]<2) a[p[i]] = v[i]; else { int d = v[i]-a[p[i]]; x.radd(p[i], n+2, d); y.radd(0, p[i], d); a[p[i]] = v[i]; } int r = -INF; mset(dp, 0); mset(dp2, 0); for(int j = 2; j < n; j++){ int mn = x.rmin(j, j+1)-x.rmin(0, j+1); dp[j] = max(dp[j-1], mn); } for(int j = n-1; j >= 2; j--){ int mn = y.rmin(j, j+1)-y.rmin(j, n); dp2[j] = max(dp2[j+1], mn); } for(int j = 2; j < n; j++) chmax(r, dp[j]+dp2[j+1]); if(n<=2) r = 0; res.pb(sum+r*2); } return res; } };