给你一个 m x n
的二元矩阵 matrix
,且所有值被初始化为 0
。请你设计一个算法,随机选取一个满足 matrix[i][j] == 0
的下标 (i, j)
,并将它的值变为 1
。所有满足 matrix[i][j] == 0
的下标 (i, j)
被选取的概率应当均等。
尽量最少调用内置的随机函数,并且优化时间和空间复杂度。
实现 Solution
类:
Solution(int m, int n)
使用二元矩阵的大小m
和n
初始化该对象int[] flip()
返回一个满足matrix[i][j] == 0
的随机下标[i, j]
,并将其对应格子中的值变为1
void reset()
将矩阵中所有的值重置为0
示例:
输入 ["Solution", "flip", "flip", "flip", "reset", "flip"] [[3, 1], [], [], [], [], []] 输出 [null, [1, 0], [2, 0], [0, 0], null, [2, 0]] 解释 Solution solution = new Solution(3, 1); solution.flip(); // 返回 [1, 0],此时返回 [0,0]、[1,0] 和 [2,0] 的概率应当相同 solution.flip(); // 返回 [2, 0],因为 [1,0] 已经返回过了,此时返回 [2,0] 和 [0,0] 的概率应当相同 solution.flip(); // 返回 [0, 0],根据前面已经返回过的下标,此时只能返回 [0,0] solution.reset(); // 所有值都重置为 0 ,并可以再次选择下标返回 solution.flip(); // 返回 [2, 0],此时返回 [0,0]、[1,0] 和 [2,0] 的概率应当相同
提示:
1 <= m, n <= 104
- 每次调用
flip
时,矩阵中至少存在一个值为 0 的格子。 - 最多调用
1000
次flip
和reset
方法。
分析,实现,【掉坑】,改进
直接模拟,思路
- 声明一个
m * n
的二维数组,随机翻转一个 - 但是接下来的问题就是,如何随机保证横轴和纵轴的坐标的随机概率相等,所以我们可以换个思路,直接声明一个
m * n
长度的一维数组,在这个长度内随机,而每一个随机数都可以映射回原来的矩阵中,从而保证每一个格子的随机概率是相等的 - 下一个问题,之前已经翻转到过的位置不能再随机到。
- 那么我们就需要记录一下剩余多少位置可以随机,从而在这个范围内进行随机操作
- 如果在这个范围内随机到的数字是之前原数组中已经翻转过的,那么就从这个位置开始一直往后找到未被翻转的过的那个位置
- 最终将随机到的位置坐标重新映射回矩阵坐标中返回即可
代码
class Solution {
int[] matrix;
int m;
int n;
Random random;
int total;
public Solution(int m, int n) {
random = new Random();
this.m = m;
this.n = n;
reset();
}
public int[] flip() {
int idx = random.nextInt(total);
while (matrix[idx]==1){
idx++;
}
total--;
matrix[idx] = 1;
return new int[]{idx/n,idx%n};
}
public void reset() {
matrix = new int[m * n];
total = matrix.length;
}
}
提交,跑到第19个测试用例,报OOM了,看下入参是10000 * 10000
,按照道理java中数组的最大长度根据具体数据类型和JVM配置实际情况应该是一个接近Integer.MAX_VALU - 8
的值,最大就是这么多,此时OOM的话,必然应该是限制了内存大小了
那么我们不妨再换个思路,题面中给出了 最多调用 1000 次 flip 和 reset 方法
,我们就可以不用记录整个数组的情况,而是换一种角度,只记录哪些位置被翻转过,使用一个HashSet来存储这些位置信息,那么修改后代码如下
代码
class Solution {
int m;
int n;
Random random;
int total;
HashSet<Integer> hashSet;
public Solution(int m, int n) {
random = new Random();
this.m = m;
this.n = n;
reset();
}
public int[] flip() {
int idx = random.nextInt(total);
while (hashSet.contains(idx)){
idx++;
}
total--;
hashSet.add(idx);
return new int[]{idx/n,idx%n};
}
public void reset() {
hashSet = new HashSet<>();
total = m * n;
}
}