
🔗刷题网址: bishipass.com
华子-2026.03.18-算法岗
1. 选择题
、
2. 训练图显存回收规划
显存回收题的本质并不在于 swap 和重计算两个名词,而在于每个张量只会贡献一次显存释放量,代价只取两者较小值。压平以后就是一个“总体积至少达到阈值、总代价最小”的 背包,核心难点其实是建模去噪,而不是状态本身。
3. 终端语音指令近邻识别
这题就是标准三维 KNN 分类,没有隐藏算法难点。关键是把欧式距离写对,并意识到排序时完全可以只比较平方距离;其余就是取前 个邻居做多数表决,属于非常直接的机器学习基础题。
第二题: 训练图显存回收规划
问题描述
LYA 正在做大模型训练图裁剪。
当前模型已经放在 NPU 上运行,但还差至少 的显存空间才能继续完成后续计算。现在一共有 个候选张量可供处理,第 个张量占用的显存大小为 。
对每个候选张量,都可以选择下面两种优化方式之一:
- >
执行 swap:先把张量搬到CPU,需要时再搬回NPU。 - >
执行重计算:释放该张量后,在后续需要时重新算出它。
无论采用哪种方式,只要处理了这个张量,就都可以释放出它本身大小的显存;不同之处在于代价不同。
请从这 个候选张量中选出若干个进行优化,使得释放出的总显存大小不少于 ,并且总代价最小。
如果不存在满足条件的方案,则输出 error。
输入格式
第 行输入一个整数 ,表示至少还需要释放多少显存。
第 行输入一个整数 ,表示候选张量个数。
第 行输入 个整数 ,表示每个张量占用的显存大小。
第 行输入 个整数 ,表示每个张量执行 swap 的代价。
第 行输入 个整数 ,表示每个张量执行重计算的代价。
输出格式
如果不存在合法方案,输出 error。
否则输出一个整数,表示最小总代价。
样例输入
9
4
2 4 5 6
3 2 6 4
1 5 2 3
样例输出
4
数据范围
- >
。 - >
。 - >
。
题解
这道题表面上每个张量有两种操作方式,但真正关键只有一句:
- >
只要决定“处理这个张量”,释放出来的显存都是 。
因此,对于第 个张量,我们根本不需要把 swap 和重计算当成两个不同决策,只需要先把它压成一个真实代价:
这样原问题就变成了:
- >
有 个物品。 - >
第 个物品的“体积”是 。 - >
第 个物品的“费用”是 。 - >
选若干个物品,使总体积至少达到 。 - >
费用和最小。
这就是标准的“至少装满”的 背包。
#状态设计
设:
表示“当前已经处理完若干个张量后,释放显存达到状态 时的最小总代价”。
为了避免状态开到很大,我们直接把所有超过 的状态都压到 :
- >
状态范围只保留 。 - >
若加入某个张量后释放显存超过了 ,统一记成 。
初始时:
其余状态初始化为无穷大。
#状态转移
枚举每个张量,设它的大小为 ,最小代价为 。
对每个当前状态 ,若它可达,则新状态为:
转移方程:
因为每个张量最多只能选一次,所以必须倒序枚举状态,这样才是 背包,而不会把同一个张量重复使用多次。
#为什么这样做是对的
先把两种操作方式压成一个最小代价以后,每个张量只剩下“选”或“不选”两种结果。
题目要求的是:
- >
总释放显存至少为 。 - >
在满足条件的所有方案里总代价最小。
而 恰好记录了“释放到某个显存水平时的最优代价”,并且转移时完整覆盖了“当前张量选还是不选”这两种可能,因此最终的 就是答案。
#边界与坑点
最容易绕弯的地方有两个:
不能把一个张量拆成两个物品。 因为
swap和重计算不是可以同时选的两种贡献,而是同一个张量的两种候选处理方式,真实代价只取较小值。不能把状态开到所有显存和。 题目只关心是否达到 ,所以把超过 的状态统一压成 就够了。
另外,如果所有张量大小之和都还小于 ,那么一定无解,直接输出 error。
#复杂度分析
设目标显存为 ,张量数为 。
- >
时间复杂度:。 - >
空间复杂度:。
由于题目中 ,一维背包可以通过。
参考代码
- >
Python
import sys
input = lambda:sys.stdin.readline().strip()
defsolve():
m_line = input()
ifnot m_line:
return
m = int(m_line)
n = int(input())
siz = list(map(int, input().split()))
swp = list(map(int, input().split()))
rec = list(map(int, input().split()))
# 每个张量只会被处理一次,真实代价取两种方式里的较小值。
cst = [swp[i] if swp[i] < rec[i] else rec[i] for i in range(n)]
# 如果总可释放显存都不够,直接无解。
if sum(siz) < m:
print("error")
return
inf = 10**30
dp = [inf] * (m + 1)
dp[0] = 0
for i in range(n):
w = siz[i]
c = cst[i]
# 倒序枚举,保证每个张量只用一次。
for j in range(m, -1, -1):
if dp[j] >= inf:
continue
nj = j + w
if nj > m:
nj = m
val = dp[j] + c
if val < dp[nj]:
dp[nj] = val
print("error"if dp[m] >= inf else dp[m])
if __name__ == "__main__":
solve()
- >
Cpp
#include<bits/stdc++.h>
usingnamespacestd;
intmain(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int m, n;
if (!(cin >> m)) {
return0;
}
cin >> n;
vector<int> siz(n), swp(n), rec(n);
for (int i = 0; i < n; ++i) {
cin >> siz[i];
}
for (int i = 0; i < n; ++i) {
cin >> swp[i];
}
for (int i = 0; i < n; ++i) {
cin >> rec[i];
}
longlong tot = 0;
vector<int> cst(n);
for (int i = 0; i < n; ++i) {
tot += siz[i];
// 同一个张量的两种操作只能二选一,取更小代价即可。
cst[i] = min(swp[i], rec[i]);
}
if (tot < m) {
cout << "error\n";
return0;
}
constlonglong INF = (longlong)4e18;
vector<longlong> dp(m + 1, INF);
dp[0] = 0;
for (int i = 0; i < n; ++i) {
int w = siz[i];
int c = cst[i];
// 倒序转移,表示 0/1 背包。
for (int j = m; j >= 0; --j) {
if (dp[j] == INF) {
continue;
}
int nj = min(m, j + w);
dp[nj] = min(dp[nj], dp[j] + c);
}
}
if (dp[m] == INF) {
cout << "error\n";
} else {
cout << dp[m] << '\n';
}
return0;
}
- >
Java
import java.io.*;
import java.util.*;
publicclassMain{
staticclassFastScanner{
privatefinal InputStream in = System.in;
privatefinalbyte[] buf = newbyte[1 << 16];
privateint ptr = 0;
privateint len = 0;
privateintread()throws IOException {
if (ptr >= len) {
len = in.read(buf);
ptr = 0;
if (len <= 0) {
return -1;
}
}
return buf[ptr++];
}
intnextInt()throws IOException {
int c;
do {
c = read();
} while (c <= ' ' && c >= 0);
int sign = 1;
if (c == '-') {
sign = -1;
c = read();
}
int val = 0;
while (c > ' ') {
val = val * 10 + c - '0';
c = read();
}
return val * sign;
}
}
publicstaticvoidmain(String[] args)throws Exception {
FastScanner fs = new FastScanner();
int m;
try {
m = fs.nextInt();
} catch (Exception e) {
return;
}
int n = fs.nextInt();
int[] siz = newint[n];
int[] swp = newint[n];
int[] rec = newint[n];
for (int i = 0; i < n; i++) {
siz[i] = fs.nextInt();
}
for (int i = 0; i < n; i++) {
swp[i] = fs.nextInt();
}
for (int i = 0; i < n; i++) {
rec[i] = fs.nextInt();
}
long sum = 0;
int[] cst = newint[n];
for (int i = 0; i < n; i++) {
sum += siz[i];
// 每个张量的真实代价取两种方式里的较小值。
cst[i] = Math.min(swp[i], rec[i]);
}
if (sum < m) {
System.out.println("error");
return;
}
long INF = Long.MAX_VALUE / 4;
long[] dp = newlong[m + 1];
Arrays.fill(dp, INF);
dp[0] = 0;
for (int i = 0; i < n; i++) {
int w = siz[i];
int c = cst[i];
// 倒序转移,避免同一张量被重复使用。
for (int j = m; j >= 0; j--) {
if (dp[j] >= INF) {
continue;
}
int nj = Math.min(m, j + w);
long val = dp[j] + c;
if (val < dp[nj]) {
dp[nj] = val;
}
}
}
if (dp[m] >= INF) {
System.out.println("error");
} else {
System.out.println(dp[m]);
}
}
}
第三题: 终端语音指令近邻识别
问题描述
卢小姐正在做终端语音交互模块的离线验证。
一段语音经过特征提取后,会被表示成一个三维向量。现在已经收集到了若干个带标签的历史样本,需要用 KNN 对一个新的语音特征向量做分类。
对于两个三维向量:
它们之间的欧式距离定义为:
KNN 的分类规则如下:
计算待分类向量与所有已知样本之间的距离。 取距离最近的 个样本作为邻居。 统计这 个邻居中出现次数最多的类别。 输出该类别。
保证最终多数表决结果唯一,不会出现并列情况。
输入格式
第 行输入两个正整数 和 ,分别表示已知样本数量与邻居数量。
接下来 行,每行输入 个数:
- >
前 个数表示一个三维语音特征向量。 - >
最后 个正整数表示该样本所属类别 label。
第 行输入 个数,表示待分类的语音特征向量。
输出格式
输出一个正整数,表示该待分类语音向量的类别。
样例输入
6 3
1.0 1.1 0.9 1
0.9 1.0 1.2 1
1.2 0.8 1.0 1
3.0 3.1 2.9 2
2.9 3.2 3.0 2
3.1 2.8 3.2 2
1.1 0.9 1.1
样例输出
1
数据范围
- >
。 - >
特征值为实数。 - >
类别为正整数。 - >
保证最终分类结果唯一。
题解
这道题就是标准的三维 KNN 模拟。
#核心做法
对每个已知样本,计算它与待分类向量之间的距离,然后按距离从小到大排序,取前 个做多数表决即可。
由于欧式距离里有平方根:
但平方根函数是单调的,所以比较大小时其实没必要真的开方。直接比较平方距离:
结果完全一样,而且还能少一点浮点误差。
#具体步骤
把每个训练样本与查询点之间的平方距离算出来。 把所有样本按平方距离升序排序。 扫前 个样本,统计每个标签出现次数。 取出现次数最多的标签作为答案。
题目已经保证多数表决结果唯一,因此不需要再单独设计并列时的额外判定规则。
#为什么这样做是对的
KNN 的定义就是:
- >
先找最近的 个邻居。 - >
再由这 个邻居做投票。
我们完整执行了这两个步骤,因此得到的结果就是题目要求的分类结果。
#边界与细节
这题最容易忽略的地方有两个:
K可能等于 。 这时就等于让所有训练样本一起投票,代码里不能默认只会取很小的一段前缀。浮点距离不需要真的开方。 只比较平方距离即可,既更快,也更稳。
#复杂度分析
设训练样本数为 。
- >
计算所有距离的复杂度是 。 - >
排序复杂度是 。 - >
统计前 个邻居复杂度是 。
总时间复杂度为:
空间复杂度为:
参考代码
- >
Python
import sys
input = lambda:sys.stdin.readline().strip()
defsolve():
n, k = map(int, input().split())
arr = []
pts = []
for _ in range(n):
x, y, z, lab = input().split()
pts.append((float(x), float(y), float(z), int(lab)))
qx, qy, qz = map(float, input().split())
for x, y, z, lab in pts:
# 只比较平方距离即可,不需要真的开方。
d = (x - qx) * (x - qx) + (y - qy) * (y - qy) + (z - qz) * (z - qz)
arr.append((d, lab))
arr.sort()
cnt = {}
ans = -1
bst = -1
for i in range(k):
lab = arr[i][1]
cnt[lab] = cnt.get(lab, 0) + 1
# 题目保证结果唯一,因此只需维护最大出现次数。
if cnt[lab] > bst:
bst = cnt[lab]
ans = lab
print(ans)
if __name__ == "__main__":
solve()
- >
Cpp
#include<bits/stdc++.h>
usingnamespacestd;
intmain(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, k;
cin >> n >> k;
vector<pair<double, int>> arr;
arr.reserve(n);
vector<array<double, 3>> pts(n);
vector<int> lab(n);
for (int i = 0; i < n; ++i) {
cin >> pts[i][0] >> pts[i][1] >> pts[i][2] >> lab[i];
}
double qx, qy, qz;
cin >> qx >> qy >> qz;
for (int i = 0; i < n; ++i) {
double dx = pts[i][0] - qx;
double dy = pts[i][1] - qy;
double dz = pts[i][2] - qz;
// 比较平方距离即可,顺序与欧式距离一致。
double d2 = dx * dx + dy * dy + dz * dz;
arr.push_back({d2, lab[i]});
}
sort(arr.begin(), arr.end());
unordered_map<int, int> cnt;
int ans = -1;
int bst = -1;
for (int i = 0; i < k; ++i) {
int x = arr[i].second;
int now = ++cnt[x];
// 题目保证最终答案唯一,这里直接维护最高频次即可。
if (now > bst) {
bst = now;
ans = x;
}
}
cout << ans << '\n';
return0;
}
- >
Java
import java.io.*;
import java.util.*;
publicclassMain{
staticclassFastScanner{
privatefinal InputStream in = System.in;
privatefinalbyte[] buf = newbyte[1 << 16];
privateint ptr = 0;
privateint len = 0;
privateintread()throws IOException {
if (ptr >= len) {
len = in.read(buf);
ptr = 0;
if (len <= 0) {
return -1;
}
}
return buf[ptr++];
}
String next()throws IOException {
int c;
do {
c = read();
} while (c <= ' ' && c >= 0);
if (c < 0) {
returnnull;
}
StringBuilder sb = new StringBuilder();
while (c > ' ') {
sb.append((char) c);
c = read();
}
return sb.toString();
}
}
staticclassNode{
double d2;
int label;
Node(double d2, int label) {
this.d2 = d2;
this.label = label;
}
}
publicstaticvoidmain(String[] args)throws Exception {
FastScanner fs = new FastScanner();
int n = Integer.parseInt(fs.next());
int k = Integer.parseInt(fs.next());
double[][] pts = newdouble[n][3];
int[] labels = newint[n];
for (int i = 0; i < n; i++) {
pts[i][0] = Double.parseDouble(fs.next());
pts[i][1] = Double.parseDouble(fs.next());
pts[i][2] = Double.parseDouble(fs.next());
labels[i] = Integer.parseInt(fs.next());
}
double qx = Double.parseDouble(fs.next());
double qy = Double.parseDouble(fs.next());
double qz = Double.parseDouble(fs.next());
ArrayList<Node> arr = new ArrayList<>();
for (int i = 0; i < n; i++) {
double dx = pts[i][0] - qx;
double dy = pts[i][1] - qy;
double dz = pts[i][2] - qz;
// 只保留平方距离,比较大小已经足够。
double d2 = dx * dx + dy * dy + dz * dz;
arr.add(new Node(d2, labels[i]));
}
arr.sort(Comparator.comparingDouble(a -> a.d2));
HashMap<Integer, Integer> cnt = new HashMap<>();
int ans = -1;
int best = -1;
for (int i = 0; i < k; i++) {
int lab = arr.get(i).label;
int now = cnt.getOrDefault(lab, 0) + 1;
cnt.put(lab, now);
// 题目保证答案唯一,直接维护最大频次即可。
if (now > best) {
best = now;
ans = lab;
}
}
System.out.println(ans);
}
}
✨ 写在最后:
网站最近上线了八股和选额的功能啦。
最近卡片刷题已经全员开放了,八股和选择题都能直接刷。 我自己还挺喜欢这种刷法的,先看题、自己想,再翻面看答案,会比一直往下看题库更有“在准备面试”的感觉一点。 如果你最近也在刷八股,准备面试、可以来bishipass试试看哦。



