题目描述:
给你两个下标从 0 开始的数组 nums
和 cost
,分别包含 n
个 正 整数。
你可以执行下面操作 任意 次:
对第 i
个元素执行一次操作的开销是 cost[i]
。
请你返回使 nums
中所有元素 相等 的 最少 总开销。
数据范围:
$1\le n \le 10^5, 1\le nums[i], cost[i] \le 10^6$
题解:
直接枚举:
考虑对 $nums[i]$ 排序,然后依次枚举每一个 $nums[i]$ 作为最终的数字。
需要考虑的是最终数字从 $nums[i]$ 变为 $nums[i + 1]$ 时代价是如何变化的。
变为 $nums[i + 1]$ 时, $p = nums[i + 1] - nums[i]$ ,则 $[0, i]$ 每个数都需要增加 $cost \times p$ , $[i + 1, n - 1]$ 每个数都需要减少 $cost \times p$ ,即增加 $\sum_{j = 0}^i cost_j \times p$ ,减少 $\sum_{j = i + 1}^{n - 1} cost_j \times p$ 。需要维护区间和,可以使用前缀和。
中位数贪心:
考虑转化,可以把 $nums[i], cost[i]$ 的代价看做 $cost[i]$ 个 $nums[i]$ ,这样的话就可以转化为一般的中位数贪心。可以统计所有的 $\sum cost[i] = sum$ ,然后累加到 $\sum_1 cost[j] \ge sum / 2$ .
代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
| auto optimize_cpp_stdio = []() { std::ios::sync_with_stdio(false); std::cin.tie(nullptr); std::cout.tie(nullptr); return 0; }(); class Solution { public: const static int maxn = 1e5 + 10; const static int maxm = 1e5 + 10; const static long long mod = 1e9 + 7; const long long INF_LL = 0x3f3f3f3f3f3f3f3f; const int INF = 0x3f3f3f3f; long long minCost(vector<int> &nums, vector<int> &cost) { int n = nums.size(); vector<int> id(n); vector<long long> preSum(n + 1); iota(id.begin(), id.end(), 0); sort(id.begin(), id.end(), [&](const int &x, const int &y) { return nums[x] < nums[y]; }); long long ans = 0; for (int i = 0; i < n; ++i) { preSum[i + 1] = preSum[i] + cost[id[i]]; ans += 1ll * (nums[id[i]] - nums[id[0]]) * cost[id[i]]; } long long tmp = 0; for (int i = 2; i <= n; ++i) { int index = id[i - 1]; long long p = nums[index] - nums[id[i - 1 - 1]]; tmp = ans + p * preSum[i - 1] - p * (preSum[n] - preSum[i - 1]); ans = min(ans, tmp); } return ans; } };
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
| auto optimize_cpp_stdio = []() { std::ios::sync_with_stdio(false); std::cin.tie(nullptr); std::cout.tie(nullptr); return 0; }(); class Solution { public: const static int maxn = 1e5 + 10; const static int maxm = 1e5 + 10; const static long long mod = 1e9 + 7; const long long INF_LL = 0x3f3f3f3f3f3f3f3f; const int INF = 0x3f3f3f3f; long long minCost(vector<int> &nums, vector<int> &cost) { int n = nums.size(); vector<int> id(n); iota(id.begin(), id.end(), 0); sort(id.begin(), id.end(), [&](const int &x, const int &y) { return nums[x] < nums[y]; }); long long sum = accumulate(cost.begin(), cost.end(), 0ll); long long cnt = 0; int index; for (int i = 0; i < n; ++i) { cnt += cost[id[i]]; if(cnt >= (sum + 1) / 2) { index = i; break; } } int aim = nums[id[index]]; cout << aim << endl; long long ans = 0; for (int i = 0; i < n; ++i) { ans = ans + 1ll * abs(aim - nums[i]) * 1ll * cost[i]; } return ans; } };
|