Segment tree with lazy propagation for multiple of 3 - algorithm

Abridged problem: You're given an array of n elements, initially they are all 0.
You will receive two types of query: 0 index1 index2, in this case you have to increase by one all elements in range index1 index2(included).
Second type: 1 index1 index2, in this case you have to print a number representing how many elements between index1 and index2 (inclusive) are divisible by 3.
Of course, as n is very large(10^6) the good approach is to use segment tree to store intervals, and also to use lazy propagation to update the tree in log n.
But I don't know how to apply lazy propagation here, because you have to take into account three possible states for every number (it may be 3k, 3k+1 or 3k+2), and not just two as in the flipping coins problem.
If I put a flag on some interval that is included in the interval of my query, I have to update it by looking at the original array and its values, but when I have to update a child of this interval I have to do the same again, and this is a waste of time.
Any better idea? I search on the net but found nothing ...
EDIT: I followed your suggestions and wrote the code below (C++). It works for some base cases, but when I submit it I get just 10/100 points. What is wrong with it? (I know it's a bit long and there are not many comments, but it's a simple segment tree with lazy propagation; if you don't understand something, please tell me!)
NOTE: st[p].zero holds the number of elements that are 0 mod 3 in the interval stored at index p, st[p].one the number that are 1 mod 3, and st[p].two the number that are 2 mod 3. When I update, I shift these counters by one position (0->1, 1->2, 2->0) and I use lazy propagation. On updating, I return a pair < int , pair< int, int > >, which is just a simple way to store a triple of numbers; in this way I can return the change in the counts of numbers that are 0, 1 and 2 mod 3.
// Running total that SegmentTree::query adds matching counts into; the caller
// resets it to 0 before each query and reads it afterwards.
int sol;

// Per-node tally of how many elements are congruent to 0, 1 and 2 modulo 3.
struct mod{
    int zero;
    int one;
    int two;

    // A fresh tally counts nothing.
    mod() : zero(0), one(0), two(0) {}
};
class SegmentTree {
public: int lazy[MAX_N];
mod st[MAX_N];
int n;
int left (int p) { return p << 1; }
int right(int p) { return (p << 1) + 1; }
void build(int p, int L, int R){
if(L == R)
st[p].zero=1;
else{
st[p].zero = R - L + 1;
build(left(p), L, (L + R) / 2);
build(right(p), ((L + R) / 2) + 1, R);
}
return;
}
void query(int p, int L, int R, int i, int j) {
if (L > R || i > R || j < L) return;
if(lazy[p]!=0){ // Check if this no has to be updated
for(int k=0;k<lazy[p];k++){
swap(st[p].zero,st[p].two);
swap(st[p].one, st[p].two);
}
if(L != R){
lazy[left(p)] = (lazy[left(p)] + lazy[p]) % 3;
lazy[right(p)] = (lazy[right(p)] + lazy[p]) % 3;
}
lazy[p] = 0;
}
if (L >= i && R <= j) { sol += st[p].zero; return; }
query(left(p) , L , (L+R) / 2, i, j);
query(right(p), (L+R) / 2 + 1, R , i, j);
return;
}
pair < int, ii > update_tree(int p, int L, int R, int i, int j) {
if (L > R || i > R || j < L){
pair< int, pair< int, int > > PP; PP.first=PP.second.first=PP.second.second=INF;
return PP;
}
if(lazy[p]!=0){ // Check if this no has to be updated
for(int k=0;k<lazy[p];k++){
swap(st[p].zero,st[p].two);
swap(st[p].one, st[p].two);
}
if(L != R){
lazy[left(p)] = (lazy[left(p)] + lazy[p]) % 3;
lazy[right(p)] = (lazy[right(p)] + lazy[p]) % 3;
}
lazy[p] = 0;
}
if(L>=i && R<=j){
swap(st[p].zero, st[p].two);
swap(st[p].one, st[p].two);
if(L != R){
lazy[left(p)] = (lazy[left(p)] + 1) % 3;
lazy[right(p)] = (lazy[right(p)] + 1) % 3;
}
pair< int, pair< int, int > > t; t.first = st[p].zero-st[p].one; t.second.first = st[p].one-st[p].two; t.second.second = st[p].two-st[p].zero;
return t;
}
pair< int, pair< int, int > > s = update_tree(left(p), L, (L+R)/2, i, j); // Updating left child
pair< int, pair< int, int > > s2 = update_tree(right(p), 1+(L+R)/2, R, i, j); // Updating right child
pair< int, pair< int, int > > d2;
d2.first = ( (s.first!=INF ? s.first : 0) + (s2.first!=INF ? s2.first : 0) ); // Calculating difference from the ones given by the children
d2.second.first = ( (s.second.first!=INF ? s.second.first : 0) + (s2.second.first!=INF ? s2.second.first : 0) );
d2.second.second = ( (s.second.second!=INF ? s.second.second : 0) + (s2.second.second!=INF ? s2.second.second : 0) );
st[p].zero += d2.first; st[p].one += d2.second.first; st[p].two += d2.second.second; // Updating root
return d2; // Return difference
}
public:
SegmentTree(const vi &_A) {
n = (int)_A.size();
build(1, 0, n - 1);
}
void query(int i, int j) { return query(1, 0, n - 1, i, j); }
pair< int, pair< int, int > > update_tree(int i, int j) {
return update_tree(1, 0, n - 1, i, j); }
};
int N,Q;
int main() {
FILE * in; FILE * out;
in = fopen("input.txt","r"); out = fopen("output.txt","w");
fscanf(in, "%d %d" , &N, &Q);
//cin>>N>>Q;
int arr[N];
vi A(arr,arr+N);
SegmentTree *st = new SegmentTree(A);
for(int i=0;i<Q;i++){
int t,q,q2;
fscanf(in, "%d %d %d " , &t, &q, &q2);
//cin>>t>>q>>q2;
if(q > q2) swap(q, q2);
if(t){
sol=0;
st->query(q,q2);
fprintf(out, "%d\n", sol);
//cout<<sol<<endl;
}
else{
pair<int, pair< int, int > > t = st->update_tree(q,q2);
}
}
fclose(in); fclose(out);
return 0;
}

You can store two values in each node:
1)int count[3] - how many there are 0, 1 and 2 in this node's segment.
2)int shift - shift value(initially zero).
The operations are performed in the following way(I use pseudo code):
// Record one more pending "+1" for every element of v's segment.
add_one(node v)
    v.shift += 1
    v.shift %= 3

// Hand v's pending shift to its children, then rebuild v's own counters
// from the children's (shift-adjusted) counters.
propagate(node v)
    v.left_child.shift += v.shift
    v.left_child.shift %= 3
    v.right_child.shift += v.shift
    v.right_child.shift %= 3
    v.shift = 0
    for i = 0..2:
        v.count[i] = get_count(v.left, i) + get_count(v.right, i)

// Number of elements of v's segment whose current value is `remainder` mod 3.
// An element that is `remainder` now was (remainder - shift) before the
// pending shift, so the stored counter must be read at that index.
get_count(node v, int remainder)
    return v.count[(remainder - v.shift + 3) % 3]
The number of elements divisible by 3 for a node v is get_count(v, 0).
Update for a node is add_one operation. In general, it can be used as an ordinary segment tree(to answer range queries).
The entire tree update looks like that:
// Add one to every element of [left; right] that lies in v's segment.
update(node v, int left, int right)
    if v is fully covered by [left; right]
        add_one(v)
    else:
        propagate(v)
        if [left; right] intersects with the left child:
            update(v.left, left, right)
        if [left; right] intersects with the right child:
            update(v.right, left, right)
        // rebuild v's counters from the (possibly changed) children
        for i = 0..2:
            v.count[i] = get_count(v.left, i) + get_count(v.right, i)
Getting the number of elements divisible by 3 is done in similar manner.

It seems that you never have to care about the values of the elements, only their values modulo 3.
Keep a segment tree, using lazy updates as you suggest. Each node knows the number of things that are 0, 1, and 2 modulo 3 (memoization).
Each update hits log(n) nodes. When an update hits a node, you remember that you have to update the descendants (lazy update) and you cycle the memoized number of things in the subtree that are 0, 1, and 2 modulo 3.
Each query hits log(n) nodes; they're the same nodes an update of the same interval would hit. Whenever a query comes across a lazy update that hasn't been done, it pushes the update down to the descendants before recursing. Apart from that, all it does is it adds up the number of elements that are 0 modulo 3 in each maximal subtree completely contained in the query interval.

Related

How do I convert this recursion to dp

I am trying to solve this problem. I believe my solution does work, but it takes more time. This is my solution - at each step, I calculate the minimum sum if I choose i or i+1 index.
class Solution
{
public int minimumTotal(List<List<Integer>> triangle)
{
return minSum( triangle, 0, 0 );
}
public int minSum( List<List<Integer>> triangle, int row, int index )
{
if( row >= triangle.size() )
return 0;
int valueAtThisRow = triangle.get(row).get(index);
return Math.min( valueAtThisRow + minSum(triangle, row+1, index),
valueAtThisRow + minSum(triangle, row+1, index+1));
}
}
I think more appropriate way is to use DP. Please share any suggestions on how I can convert this to a DP.
I think that bottom-up solution is simpler here and does not require additional memory, we can store intermediate results in the same list/array cells
Walk through levels of triangle starting from the level before last, choosing the best result from two possible for every element, then move to upper level and so on. After that triangle[0][0] will contain minimum sum
// fold each row into the one above it, starting just above the last row
for (row = n - 2; row >= 0; row--)
    for (i = 0; i <= row; i++)
        triangle[row][i] += min(triangle[row+1][i], triangle[row+1][i+1])
(tried python version, accepted)
DP bottom up approach:
class Solution {
public:
    /**
     * Minimum top-to-bottom path sum of `triangle`, where each step goes to
     * column j or j+1 of the next row. Bottom-up DP over a copy of the last
     * row; the input is not modified. Returns 0 for an empty triangle.
     */
    int minimumTotal(std::vector<std::vector<int> > &triangle)
    {
        if (triangle.empty())
            return 0;                       // size() - 1 would wrap around below
        std::vector<int> mini = triangle.back();
        for (std::size_t i = triangle.size() - 1; i-- > 0; ) {
            const std::vector<int> &row = triangle[i];
            for (std::size_t j = 0; j < row.size(); ++j)
                // mini[j + 1] still belongs to the row below at this point
                mini[j] = row[j] + std::min(mini[j], mini[j + 1]);
        }
        return mini[0];
    }
};
Python version:
def minimumTotal(self, triangle: List[List[int]]) -> int:
    """Return the minimum top-to-bottom path sum of ``triangle``.

    Each step moves from column ``c`` to column ``c`` or ``c + 1`` of the
    next row. Works bottom-up on a copy of the last row, so the caller's
    data is left unchanged. An empty triangle yields 0.
    """
    if not triangle:
        return 0
    best = list(triangle[-1])  # copy: do not overwrite the caller's last row
    for r in range(len(triangle) - 2, -1, -1):  # bottom up
        row = triangle[r]
        for c in range(len(row)):
            # best[c + 1] still holds the value for the row below
            best[c] = row[c] + min(best[c], best[c + 1])
    return best[0]
Top down solution: we compute shortest paths (agenda) to each cells row by row starting from the triangle's top.
C# code
/// <summary>
/// Top-down pass: after processing a row, best[c] is the cheapest path sum
/// from the apex to column c of that row. The answer is the smallest entry
/// left after the final row.
/// </summary>
public int MinimumTotal(IList<IList<int>> triangle) {
    int[] best = { triangle[0][0] };

    for (int row = 1; row < triangle.Count; row++) {
        int[] current = triangle[row].ToArray();
        int lastCol = current.Length - 1;

        for (int col = 0; col <= lastCol; col++) {
            int fromAbove;
            if (col == 0) {
                fromAbove = best[col];                              // left edge: single parent
            } else if (col == lastCol) {
                fromAbove = best[col - 1];                          // right edge: single parent
            } else {
                fromAbove = Math.Min(best[col - 1], best[col]);     // interior: cheaper of the two parents
            }
            current[col] += fromAbove;
        }

        best = current;
    }

    return best.Min();
}

Find the number of players cannot win the game?

We are given n players, each player has 3 values assigned A, B and C.
A player i cannot win if there exists another player j with all 3 values A[j] > A[i], B[j] > B[i] and C[j] > C[i]. We are asked to find number of players cannot win.
I tried this problem using brute force, which is a linear search over players array. But it's showing TLE.
For each player i, I am traversing the complete array to find if there exists any other player j for which the above condition holds true.
Code :
/**
 * Counts the players that cannot win: player i cannot win when some other
 * player j is strictly greater in all three values.
 *
 * @param values one row {A, B, C} per player
 * @return number of players dominated by at least one other player
 *
 * Brute force, O(n^2). Every other player is examined; the previous loop
 * condition stopped at j == i and so never looked at players after i.
 */
int count_players_cannot_win(const std::vector<std::vector<int>>& values) {
    const int n = static_cast<int>(values.size());
    int c = 0;
    for (int i = 0; i < n; i++) {
        const std::vector<int>& me = values[i];
        for (int j = 0; j < n; j++) {
            if (j == i)
                continue;
            const std::vector<int>& other = values[j];
            if (me[0] < other[0] && me[1] < other[1] && me[2] < other[2]) {
                c += 1;
                break;      // one dominating player is enough
            }
        }
    }
    return c;
}
And this approach is O(n^2), as for every player we are traversing the complete array. Thus it is giving the TLE.
Sample testcase :
Sample Input
3(number of players)
A B C
1 4 2
4 3 2
2 5 3
Sample Output :
1
Explanation :
Only player1 cannot win as there exists player3 whose all 3 values(A, B and C) are greater than that of player1.
Constraints :
n(number of players) <= 10^5
What would be optimal way to solve this problem?
Solution:
int n;                      // number of players in the current solve() call
const int N = 4e5 + 1;      // capacity of the max-tree (enough for n <= 10^5)
int tree[N];                // max segment tree, root at index 0, over positions [0, n-1]

/**
 * Maximum stored value over positions [L, n-1].
 * (i, l, r) is the current node and its interval. Returns
 * numeric_limits<int>::min() when the node lies outside the suffix.
 */
int get_max(int i, int l, int r, int L) {
    if(r < L || n <= l)
        return std::numeric_limits<int>::min();
    else if(L <= l)
        return tree[i];
    int m = (l + r)/2;
    return std::max(get_max(2*i+1, l, m, L), get_max(2*i+2, m+1, r, L));
}

/** Point update: raises the value stored at position `on` to at least v. */
void update(int i, int l, int r, int on, int v) {
    if(r < on || on < l)
        return;
    else if(l == r) {
        tree[i] = std::max(tree[i], v);
        return;
    }
    int m = (l + r)/2;
    update(2*i+1, l, m, on, v);
    update(2*i+2, m + 1, r, on, v);
    tree[i] = std::max(tree[2*i+1], tree[2*i+2]);
}

/** Orders rows by column 0 descending, ties by column 1 ascending. */
bool comp(const std::vector<int>& a, const std::vector<int>& b) {
    return a[0] != b[0] ? a[0] > b[0] : a[1] < b[1];
}

/**
 * Counts the rows {A, B, C} of v for which another row is strictly greater
 * in all three columns. Sorts v in place. O(n log n) time.
 *
 * Rows are processed in groups of equal A, largest A first. Each row asks
 * the tree for the largest C seen so far among rows with a strictly larger
 * B; a group is inserted only after all of its rows have been checked, so
 * rows with the same A never dominate each other.
 *
 * NOTE(review): requires 4 * v.size() <= N; larger inputs overflow `tree`.
 */
int solve(std::vector<std::vector<int>> &v) {
    n = v.size();
    if(n == 0)
        return 0;
    // Start from "nothing seen" so that earlier calls and non-positive C
    // values cannot produce false hits.
    const int used = (4LL * n < N) ? 4 * n : N;
    std::fill(tree, tree + used, std::numeric_limits<int>::min());

    // Sorted copy of column B: lower_bound() in it is the compressed position.
    std::vector<int> b(n, 0);
    for(int i = 0; i < n; i++) {
        b[i] = v[i][1];
    }
    // sort on 0th col in reverse order
    std::sort(v.begin(), v.end(), comp);
    std::sort(b.begin(), b.end());
    int ans = 0;
    for(int i = 0; i < n;) {
        int j = i;
        // Phase 1: query every row of the current A-group.
        while(j < n && v[j][0] == v[i][0]) {
            int B = v[j][1];
            int pos = std::lower_bound(b.begin(), b.end(), B) - b.begin(); // position of B in b[]
            int mx = get_max(0, 0, n - 1, pos + 1);
            if(mx > v[j][2])
                ans += 1;
            j++;
        }
        // Phase 2: only now make the group visible to later (smaller-A) rows.
        while(i < j) {
            int B = v[i][1];
            int C = v[i][2];
            int pos = std::lower_bound(b.begin(), b.end(), B) - b.begin(); // position of B in b[]
            update(0, 0, n - 1, pos, C);
            i++;
        }
    }
    return ans;
}
This solution uses segment tree, and thus solves the problem in
time O(n*log(n)) and space O(n).
The approach is explained in the accepted answer by @Primusa.
First lets assume that our input comes in the form of a list of tuples T = [(A[0], B[0], C[0]), (A[1], B[1], C[1]) ... (A[N - 1], B[N - 1], C[N - 1])]
The first observation we can make is that we can sort on T[0] (in reverse order). Then for each tuple (a, b, c), to determine if it cannot win, we ask if we've already seen a tuple (d, e, f) such that e > b && f > c. We don't need to check the first element because we are given that d > a* since T is sorted in reverse.
Okay, so now how do we check this second criteria?
We can reframe it like so: out of all tuples (d, e, f), that we've already seen with e > b, what is the maximum value of f? If the max value is greater than c, then we know that this tuple cannot win.
To handle this part we can use a segment tree with max updates and max range queries. When we encounter a tuple (d, e, f), we can set tree[e] = max(tree[e], f). tree[i] will represent the third element with i being the second element.
To answer a query like "what is the maximum value of f such that e > b", we do max(tree[b+1...]), to get the largest third element over a range of possible second elements.
Since we are only doing suffix queries, you can get away with using a modified fenwick tree, but it is easier to explain with a segment tree.
This will give us an O(NlogN) solution, for sorting T and doing O(logN) work with our segment tree for every tuple.
*Note: this should actually be d >= a. However it is easier to explain the algorithm when we pretend everything is unique. The only modification you need to make to accommodate duplicate values of the first element is to process your queries and updates in buckets of tuples of the same value. This means that we will perform our check for all tuples with the same first element, and only then do we update tree[e] = max(tree[e], f) for all of those tuples we performed the check on. This ensures that no tuple with the same first value has updated the tree already when another tuple is querying the tree.

Counting inversions in a segment with updates

I'm trying to solve a problem which goes like this:
Problem
Given an array of integers "arr" of size "n", process two types of queries. There are "q" queries you need to answer.
Query type 1
input: l r
result: output number of inversions in [l, r]
Query type 2
input: x y
result: update the value at arr [x] to y
Inversion
For every index j < i, if arr [j] > arr [i], the pair (j, i) is one inversion.
Input
n = 5
q = 3
arr = {1, 4, 3, 5, 2}
queries:
type = 1, l = 1, r = 5
type = 2, x = 1, y = 4
type = 1, l = 1, r = 5
Output
4
6
Constraints
Time: 4 secs
1 <= n, q <= 100000
1 <= arr [i] <= 40
1 <= l, r, x <= n
1 <= y <= 40
I know how to solve a simpler version of this problem without updates, i.e. to simply count the number of inversions for each position using a segment tree or Fenwick tree in O(N*log(N)). The only solution I have to this problem, other than the trivial O(q*N^2) algorithm, is O(q*N*log(N)) (I think) with a segment tree. This, however, does not fit within the time constraints of the problem. I would like hints towards a better algorithm that solves the problem in O(N*log(N)) (if that is possible) or maybe O(N*log^2(N)).
I first came across this problem two days ago and have been spending a few hours here and there to try and solve it. However, I'm finding it non-trivial to do so and would like to have some help/hints regarding the same. Thanks for your time and patience.
Updates
Solution
With the suggestion, answer and help by Tanvir Wahid, I've implemented the source code for the problem in C++ and would like to share it here for anyone who might stumble across this problem and not have an intuitive idea on how to solve it. Thank you!
Let's build a segment tree with each node containing information about how many inversions exist and the frequency count of elements present in its segment of authority.
node {
    integer inversion_count : 0        // inversions inside this node's segment
    array [40] frequency : {0...0}     // occurrences of each value in the segment
}
Building the segment tree and handling updates
For each leaf node, initialise inversion count to 0 and increase frequency of the represented element from the input array to 1. The frequency of the parent nodes can be calculated by summing up frequencies of the left and right childrens. The inversion count of parent nodes can be calculated by summing up the inversion counts of left and right children nodes added with the new inversions created upon merging the two segments of their authority which can be calculated using the frequencies of elements in each child. This calculation basically finds out the product of frequencies of bigger elements in the left child and frequencies of smaller elements in the right child.
parent.inversion_count = left.inversion_count + right.inversion_count
for i in [39, 0]
    for j in [0, i)
        // a bigger value i on the left paired with a smaller value j on the right
        parent.inversion_count += left.frequency [i] * right.frequency [j]
Updates are handled similarly.
Answering range queries on inversion counts
To answer the query for the number of inversions in the range [l, r], we calculate the inversions using the source code attached below.
Time Complexity: O(q*log(n))
Note
The source code attached does break some good programming habits. The sole purpose of the code is to "solve" the given problem and not to accomplish anything else.
Source Code
/**
Lost Arrow (Aryan V S)
Saturday 2020-10-10
**/
#include "bits/stdc++.h"
using namespace std;
// Summary of one segment: its inversion count and how often each of the 40
// possible (0-based) values occurs in it.
struct node {
    int64_t inv = 0;
    std::vector <int> freq = std::vector <int> (40, 0);

    // Makes this node the summary of segment `l` followed by segment `r`.
    // Cross inversions are pairs (bigger value in l, smaller value in r); a
    // running count of the smaller values in r gives them in one pass over
    // the 40 values. Each index is read before it is written, so calling
    // this with *this aliased to l or r stays well defined.
    void combine (const node& l, const node& r) {
        inv = l.inv + r.inv;
        int64_t smaller_on_right = 0;   // number of elements in r with value < i
        for (int i = 0; i < 40; ++i) {
            inv += 1LL * l.freq [i] * smaller_on_right;
            smaller_on_right += r.freq [i];
            freq [i] = l.freq [i] + r.freq [i];
        }
    }
};
void build (vector <node>& tree, vector <int>& a, int v, int tl, int tr) {
if (tl == tr) {
tree [v].inv = 0;
tree [v].freq [a [tl]] = 1;
}
else {
int tm = (tl + tr) / 2;
build(tree, a, 2 * v + 1, tl, tm);
build(tree, a, 2 * v + 2, tm + 1, tr);
tree [v].combine(tree [2 * v + 1], tree [2 * v + 2]);
}
}
void update (vector <node>& tree, int v, int tl, int tr, int pos, int val) {
if (tl == tr) {
tree [v].inv = 0;
tree [v].freq = vector <int> (40, 0);
tree [v].freq [val] = 1;
}
else {
int tm = (tl + tr) / 2;
if (pos <= tm)
update(tree, 2 * v + 1, tl, tm, pos, val);
else
update(tree, 2 * v + 2, tm + 1, tr, pos, val);
tree [v].combine(tree [2 * v + 1], tree [2 * v + 2]);
}
}
// Returns the summary node (inversions + frequencies) for indices [l, r]
// inside the segment [tl, tr] owned by node v. An empty range gives a
// default node.
node inv_cnt (std::vector <node>& tree, int v, int tl, int tr, int l, int r) {
    if (l > r) {
        return node();
    }
    if (tl == l && tr == r) {
        return tree [v];
    }
    const int mid = (tl + tr) / 2;
    // Clip the request to each half; a half that is not touched becomes an
    // empty range and contributes a default node.
    const node left_part = inv_cnt(tree, 2 * v + 1, tl, mid, l, std::min(r, mid));
    const node right_part = inv_cnt(tree, 2 * v + 2, mid + 1, tr, std::max(l, mid + 1), r);
    node merged;
    merged.combine(left_part, right_part);
    return merged;
}
void solve () {
int n, q;
cin >> n >> q;
vector <int> a (n);
for (int i = 0; i < n; ++i) {
cin >> a [i];
--a [i];
}
vector <node> tree (4 * n);
build(tree, a, 0, 0, n - 1);
while (q--) {
int type, x, y;
cin >> type >> x >> y;
--x; --y;
if (type == 1) {
node result = inv_cnt(tree, 0, 0, n - 1, x, y);
cout << result.inv << '\n';
}
else if (type == 2) {
update(tree, 0, 0, n - 1, x, y);
}
else
assert(false);
}
}
// Entry point: configures the standard streams, then runs solve() once per
// test case (a single case unless the read below is enabled).
int main () {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.precision(10);
    std::cout << std::fixed << std::boolalpha;

    int cases = 1;
    // std::cin >> cases;
    for (; cases != 0; --cases) {
        solve();
    }
    return 0;
}
arr[i] can be at most 40. We can use this to our advantage. What we need is a segment tree. Each node will hold 41 values (A long long int which represents inversions for this range and a array of size 40 for count of each numbers. A struct will do). How do we merge two children of a node. We know inversions for left child and right child. Also know frequency of each numbers in both of them. Inversion of parent node will be summation of inversions of both children plus number of inversions between left and right child. We can easily find inversions between two children from frequency of numbers. Query can be done in similar way. Complexity O(40*qlog(n))

finding minimum number in a integer

Given an n-digit Number and a number 'k'. You have to remove ‘k’ digits from the number and give the shortest number from the remaining ‘n-k’ digits such that the sequence of digits remains same. For example, if the number is 637824 and k = 3. So you have to remove 3 digits from the given number. The number formed from the remaining digits should be the smallest possible and the sequence of digits must not be changed. So the output should be 324.
Approach which i have used same as inclusion exclusion logic:
Input 63119 and K=2:
Select 9 + minimum of(6311) = 19 or
don't select 9, select 1 + minimum of(631) = 11 and finally take minimum of all.
For input 4567813 and k=3:
Select 3 and 1 along with minimum of(45678) = 413
I am using recursion to get this right, but I am not able to implement it in code and I have run out of ideas. I need help with this recursion. I am not looking for a better solution.
/* Smaller of two values. No trailing semicolon, so the macro can be used
 * inside larger expressions. Each argument may be evaluated twice - do not
 * pass expressions with side effects or expensive calls. */
#define min(a, b) ((a)>(b)?(b):(a))

/* Returns the smallest digit (as an int 0-9) among s[i..j], inclusive.
 * Expects i <= j and s[i..j] to hold ASCII digits. */
int minimum(char *s, int i, int j)
{
    int best = s[i] - '0';
    int k;
    for (k = i + 1; k <= j; k++) {
        int digit = s[k] - '0';
        if (digit < best)
            best = digit;
    }
    return best;
}
/* Returns the decimal number spelled by the digits s[i..j], inclusive.
 * The loop runs while i <= j so that the leading digit s[i] is included. */
int add_up(char *s, int i, int j)
{
    int sum = 0, mul = 1;
    while (i <= j) {
        sum = sum + (s[j] - '0')*mul;
        j--; mul *= 10;
    }
    return sum;
}
/* Smallest number that can be formed by keeping k of the digits s[0..j] in
 * their original order. `size` is the index of the last digit of s.
 *
 * Either s[j] is kept as the last digit (with k-1 digits taken from
 * s[0..j-1]) or it is dropped (k digits from s[0..j-1]); the result is the
 * smaller of the two, not their sum. Each recursive call is made once.
 *
 * NOTE(review): the k == 0 case returns the whole prefix as a number, as the
 * original code did - confirm that this is the intended meaning. */
int foo(char *s, int size, int j, int k)
{
    int keep, skip;
    if (k < 0 || j > size || j < 0)
        return 0;
    if ((k == 0) && (j != 0))
        return add_up(s, 0, j);
    if ((k == 1) && (j != 0))
        return minimum(s, 0, j);
    if (k-1 == j)               /* exactly k digits left: keep them all */
        return add_up(s, 0, j);
    keep = (s[j]-'0') + 10*foo(s, size, j-1, k-1);
    skip = foo(s, size, j-1, k);
    return keep < skip ? keep : skip;
}
/* Demo driver: prints foo's result for the digits "4567813" with k = 2. */
int main()
{
    char s[] = {"4567813"};
    int last = strlen(s)-1;     /* index of the final digit */

    printf("%d\n", foo(s, last, last, 2));
    return 0;
}
You said "finally take minimum of all", but your code takes the minimum for each i and then adds those minima together. You need to compute the minimum for each i and keep the smallest of them in sum. See the code below for clarification. There is also a bug in your add_up function: it should keep adding while i<=j.
/* Returns the decimal number spelled by the digits s[i..j], both ends
 * included, building it from the least significant digit upwards. */
int add_up(char *s, int i, int j)
{
    int total = 0;
    int weight = 1;
    int pos;
    for (pos = j; pos >= i; pos--) {
        total += (s[pos] - '0') * weight;
        weight *= 10;
    }
    return total;
}
/* Smallest number that can be formed by keeping k of the digits s[0..j] in
 * their original order. `size` is the index of the last digit of s.
 *
 * Either s[j] is kept as the last digit (k-1 digits come from s[0..j-1]) or
 * it is dropped (k digits come from s[0..j-1]). Both options are evaluated
 * exactly once: the former loop over i repeated the identical computation
 * k+1 times, and passing the recursive calls to the min macro evaluated
 * them a second time. The returned value is unchanged. */
int foo(char *s, int size, int j, int k)
{
    int keep, skip;
    if(k < 0 || j > size || j < 0)
        return 0;
    if((k == 0) && (j != 0))
        return add_up(s, 0, j);
    if((k == 1) && (j != 0))
        return minimum(s, 0, j);
    if(k-1 == j)                /* exactly k digits left: keep them all */
        return add_up(s, 0, j);
    keep = (s[j]-'0') + 10*foo(s, size, j-1, k-1);
    skip = foo(s, size, j-1, k);
    return keep < skip ? keep : skip;
}

Finding multiple entries with binary search

I use standard binary search to quickly return a single object in a sorted list (with respect to a sortable property).
Now I need to modify the search so that ALL matching list entries are returned. How should I best do this?
Well, as the list is sorted, all the entries you are interested in are contiguous. This means you need to find the first item equal to the found item, looking backwards from the index which was produced by the binary search. And the same about last item.
You can simply go backwards from the found index, but this way the solution may be as slow as O(n) if there are a lot of items equal to the found one. So you should better use exponential search: double your jumps as you find more equal items. This way your whole search is still O(log n).
First let's recall the naive binary search code snippet:
// Classic binary search over the sorted slice arr[low..high] (inclusive).
// Returns an index holding `key`, or -1 when the key is absent. With
// duplicates, the index returned is whichever one the probing sequence hits
// first.
int bin_search(int arr[], int key, int low, int high)
{
    while (low <= high) {
        const int mid = low + ((high - low) >> 1);
        const int probe = arr[mid];
        if (probe == key) {
            return mid;
        }
        if (probe > key) {
            high = mid - 1;
        } else {
            low = mid + 1;
        }
    }
    return -1;
}
Quoted from Prof.Skiena:
Suppose we delete the equality test if (s[middle] == key)
return(middle); from the implementation above and return the index low
instead of −1 on each unsuccessful search. All searches will now be
unsuccessful, since there is no equality test. The search will proceed
to the right half whenever the key is compared to an identical array
element, eventually terminating at the right boundary. Repeating the
search after reversing the direction of the binary comparison will
lead us to the left boundary. Each search takes O(lgn) time, so we can
count the occurrences in logarithmic time regardless of the size of
the block.
So, we need two rounds of binary_search to find the lower_bound (find the first number no less than the KEY) and the upper_bound (find the first number bigger than the KEY).
// Index of the first element of the sorted slice arr[low..high] that is not
// less than `key`; high + 1 when every element is smaller.
// There is no equality shortcut: on an equal element the search keeps going
// left, so it always ends on the left boundary of the run of equal values.
int lower_bound(int arr[], int key, int low, int high)
{
    while (low <= high) {
        const int mid = low + ((high - low) >> 1);
        if (arr[mid] < key) {
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }
    return low;
}
// Index of the first element of the sorted slice arr[low..high] that is
// greater than `key`; high + 1 when there is none.
// There is no equality shortcut: on an equal element the search keeps going
// right, so it always ends just past the run of equal values.
int upper_bound(int arr[], int key, int low, int high)
{
    while (low <= high) {
        const int mid = low + ((high - low) >> 1);
        if (arr[mid] <= key) {
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }
    return low;
}
Hope it's helpful :)
If I'm following your question, you have a list of objects which, for the purpose of comparison, look like {1,2,2,3,4,5,5,5,6,7,8,8,9}. A normal search for 5 will hit one of objects that compare as 5, but you want to get them all, is that right?
In that case, I'd suggest a standard binary search which, upon landing on a matching element, starts looking left until it stops matching, and then right (from the first match) again until it stops matching.
Be careful that whatever data structure you are using is not overwriting elements that compare to the same!
Alternatively, consider using a structure that stores elements that compare to the same as a bucket in that position.
I would do two binary searches, one looking for the first element comparing >= the value (in C++ terms, lower_bound) and then one searching for the first element comparing > the value (in C++ terms, upper_bound). The elements from lower_bound to just before upper bound are what you are looking for (in terms of java.util.SortedSet, subset(key, key)).
So you need two different slight modifications to the standard binary search: you still probe and use the comparison at the probe to narrow down the area in which the value you are looking for must lie, but now e.g. for lower_bound if you hit equality, all you know is that the element you are looking for (the first equal value) is somewhere between the first element of the range so far and the value you have just probed - you can't return immediately.
Once you found a match with the bsearch, just recursive bsearch both side until no more match
pseudo code :
// Pseudocode sketch: find one match with a plain binary search, then repeat
// the search on the part left of it and on the part right of it until no
// further match is reported (-1).
// NOTE(review): the searched key is not passed to bsearch anywhere here.
range search (type *array) {
int index = bsearch(array, 0, array.length-1);
// left
// NOTE(review): the bound used for the LEFT scan is named upperBound.
int upperBound = index -1;
int i = upperBound;
do {
upperBound = i;
i = bsearch(array, 0, upperBound);
} while (i != -1)
// right
// NOTE(review): the bound used for the RIGHT scan is named lowerBound, and
// `i` is declared a second time in the same scope.
int lowerBound = index + 1;
int i = lowerBound;
do {
lowerBound = i;
i = bsearch(array, lowerBound, array.length);
} while (i != -1)
// NOTE(review): `UpperBound` differs in case from `upperBound` above.
return range(lowerBound, UpperBound);
}
No corner cases are covered, though. I think this will keep your complexity at O(log N).
This depends on which implementation of the binary search you use:
In Java and .NET, the binary search will give you an arbitrary element; you must search both ways to get the range that you are looking for.
In C++ you can use equal_range method to produce the result that you want in a single call.
To speed up searches in Java and .NET for cases when the equal range is too long for iterating linearly, you can look for a predecessor element and for the successor element, and take values in the middle of the range that you find, exclusive of the ends.
Should this be too slow because of a second binary search, consider writing your own search that looks for both ends at the same time. This may be a bit tedious, but it should run faster.
I'd start by finding the index of a single element given the sortable property (using "normal" binary search) and then start looking to both left-and-right of the element in the list, adding all elements found to meet the search criterion, stopping at one end when an element doesn't meet the criterion or there are no more elements to traverse, and stopping altogether when both the left-and-right ends meet the stop conditions mentioned before.
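This Java code counts the occurrences of the target value in a sorted array in a single recursive pass. It takes O(log N + k) time, where k is the number of matching elements, because both halves are explored once a match is hit. It's easy to modify it to return the list of found indexes: just pass in an ArrayList.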
Idea is to recursively refine e and b bounds until they become lower and upper boundary for contiguous block having target values;
/**
 * Counts how many elements of the sorted slice arr[b..e] (inclusive) equal
 * target.
 *
 * When the probe hits the target, both halves are searched, so the run of
 * equal values is counted from both sides. An empty slice (b > e) yields 0
 * instead of reading arr[b] outside the requested range.
 */
static int countMatching(int[] arr, int b, int e, int target){
    if(b > e){
        return 0;
    }
    if(e-b<2){
        // One or two elements left: check them directly.
        int count = 0;
        if(arr[b] == target){
            count++;
        }
        if(b!=e && arr[e] == target){
            count++;
        }
        return count;
    }
    int m = b + (e-b)/2;    // cannot overflow, unlike (b+e)/2
    if(arr[m] > target){
        return countMatching(arr, b, m-1, target);
    }
    else if(arr[m] < target){
        return countMatching(arr, m+1, e, target);
    }
    else {
        return countMatching(arr, b, m-1, target) + 1
             + countMatching(arr, m+1, e, target);
    }
}
does your binary search return the element, or the index the element is at? Can you get the index?
Since the list is sorted, all matching elements should appear adjacent. If you can get the index of the item returned in the standard search, you just need to search in both directions from that index until you find non-matches.
Here is the solution by Deril Raju (in the answer above), ported to swift:
/// Binary search for `key` in the sorted slice A[first...last].
/// Returns the index of the leftmost match when `searchLow` is true, of the
/// rightmost match otherwise, and -1 when the key does not occur.
func bin_search(_ A: inout [Int], first: Int, last: Int, key: Int, searchLow: Bool) -> Int {
    var found = -1
    var lo = first
    var hi = last
    while lo <= hi {
        let probe = (lo + hi) / 2
        let value = A[probe]
        if value == key {
            // Remember the hit, then keep narrowing towards the wanted end.
            found = probe
            if searchLow {
                hi = probe - 1
            } else {
                lo = probe + 1
            }
        } else if value < key {
            lo = probe + 1
        } else {
            hi = probe - 1
        }
    }
    return found
}
/// Returns (leftmost index, rightmost index) of `key` in A[first...last],
/// or (-1, -1) when the key does not occur. Runs two binary searches.
func bin_search_range(_ A: inout [Int], first: Int, last: Int, key: Int) -> (Int, Int) {
    let leftmost = bin_search(&A, first: first, last: last, key: key, searchLow: true)
    let rightmost = bin_search(&A, first: first, last: last, key: key, searchLow: false)
    return (leftmost, rightmost)
}
/// Self-check for bin_search and bin_search_range on an array with runs of
/// duplicates.
func test() {
    var A = [1, 2, 3, 3, 3, 4, 4, 4, 4, 5]
    let end = A.count - 1

    // run of 3s occupies indices 2...4
    assert(bin_search(&A, first: 0, last: end, key: 3, searchLow: true) == 2)
    assert(bin_search(&A, first: 0, last: end, key: 3, searchLow: false) == 4)
    assert(bin_search_range(&A, first: 0, last: end, key: 3) == (2, 4))

    // run of 4s occupies indices 5...8
    assert(bin_search(&A, first: 0, last: end, key: 4, searchLow: true) == 5)
    assert(bin_search(&A, first: 0, last: end, key: 4, searchLow: false) == 8)
    assert(bin_search_range(&A, first: 0, last: end, key: 4) == (5, 8))

    // single occurrence and missing key
    assert(bin_search_range(&A, first: 0, last: end, key: 5) == (9, 9))
    assert(bin_search_range(&A, first: 0, last: end, key: 0) == (-1, -1))
}
test()
Try this. It works amazingly.
working example, Click here
var arr = [1, 1, 2, 3, "a", "a", "a", "b", "c"]; // It should be sorted array.
// When the key occurs more than once, an array with all of its indexes is returned.
binarySearch(arr, "a", false);

/**
 * Binary search that collects every index holding `key`.
 *
 * Returns an array that starts with the index hit by the binary search,
 * followed by the adjacent matches to its right and then those to its left,
 * or -1 when the key is not found. With `caseInsensitive` set, strings are
 * compared in lower case.
 */
function binarySearch(array, key, caseInsensitive) {
  var matches = [];
  var size = array.length;
  var lo = 0;
  var hi = size - 1;

  // Lower-cases non-empty strings when case-insensitive matching is requested.
  function normalize(value) {
    return caseInsensitive && value && typeof value == "string" ? value.toLowerCase() : value;
  }

  key = normalize(key);

  while (lo <= hi) {
    var mid = Math.floor(lo + (hi - lo) / 2);
    var probe = normalize(array[mid]);

    if (key === probe) {
      matches.push(mid);
      var step;
      // walk right over the neighbours that still match
      for (step = 1; step < size; step++) {
        if (key != normalize(array[mid + step])) {
          break;
        }
        matches.push(mid + step);
      }
      // then walk left over the neighbours that still match
      for (step = 1; step < size; step++) {
        if (key != normalize(array[mid - step])) {
          break;
        }
        matches.push(mid - step);
      }
      return matches;
    }

    if (key > probe) {
      lo = mid + 1;
    } else {
      hi = mid - 1;
    }
  }
  return -1;
}
A very efficient algorithm for this was found recently.
The algorithm has logarithmic time complexity with respect to both variables (the size of the input and the number of searched keys). However, the searched keys have to be sorted as well.
/* Midpoint of the closed interval [left, right], computed without forming
 * left + right. */
#define MIDDLE(left, right) ((left) + (((right) - (left)) >> 1))

/*
 * Binary search for `key` in the sorted range arr[left..right] (inclusive).
 *
 * On a hit, *found is set to true and the index of the matching element is
 * returned. On a miss, *found is set to false and the return value is the
 * position of the first element bigger than `key` (right + 1 when there is
 * none). An empty range (left > right) is a miss and returns `left`.
 */
int bs (const int *arr, int left, int right, int key, bool *found)
{
    while (left <= right)
    {
        int middle = MIDDLE(left, right);

        if (key < arr[middle])
            right = middle - 1;
        else if (key == arr[middle]) {
            *found = true;
            return middle;
        }
        else
            left = middle + 1;
    }

    *found = false;
    /* left points to the position of first bigger element */
    return left;
}

/*
 * Recursive worker of mkbs(): looks up keys[keys_l..keys_r] inside
 * arr[arr_l..arr_r] (both ranges inclusive, both sorted ascending).
 *
 * The middle key is searched first; its position splits the array so that
 * the keys on its left only search the left part and the keys on its right
 * only search the right part. results[i] is written only for keys that are
 * found.
 */
static void _mkbs (const int *arr, int arr_l, int arr_r,
                   const int *keys, int keys_l, int keys_r, int *results)
{
    /* end condition: no keys left */
    if (keys_r - keys_l < 0)
        return;

    /* An empty array range cannot contain any key. Returning here also keeps
     * the arr[arr_l] / arr[arr_r] reads below inside the array. */
    if (arr_l > arr_r)
        return;

    int keys_middle = MIDDLE(keys_l, keys_r);

    /* throw away half of keys, if the key on keys_middle is out */
    if (keys[keys_middle] < arr[arr_l]) {
        _mkbs(arr, arr_l, arr_r, keys, keys_middle + 1, keys_r, results);
        return;
    }

    if (keys[keys_middle] > arr[arr_r]) {
        _mkbs(arr, arr_l, arr_r, keys, keys_l, keys_middle - 1, results);
        return;
    }

    bool found;
    int pos = bs(arr, arr_l, arr_r, keys[keys_middle], &found);

    if (found)
        results[keys_middle] = pos;

    /* Smaller keys can only be left of pos, bigger keys only right of it. */
    _mkbs(arr, arr_l, pos - 1, keys, keys_l, keys_middle - 1, results);
    _mkbs(arr, (found) ? pos + 1 : pos, arr_r, keys, keys_middle + 1, keys_r, results);
}

/*
 * Multi-key binary search.
 *
 * arr     - sorted array of N ints
 * keys    - sorted array of M ints to look up
 * results - array of M ints; results[i] receives the index in arr at which
 *           keys[i] was found. Entries of keys that are not found are left
 *           untouched, so the caller should pre-fill results with a sentinel.
 *
 * NOTE(review): the right-hand recursion skips the matched position, so a
 * key repeated in `keys` may stay unreported - keys are presumably expected
 * to be distinct; confirm against callers.
 */
void mkbs (const int *arr, int N, const int *keys, int M, int *results)
{ _mkbs(arr, 0, N - 1, keys, 0, M - 1, results); }
Here is the implementation in C and a draft paper intended for publication:
https://github.com/juliusmilan/multi_value_binary_search
Can you please share a use case?
You can use the below code for your problem. The main aim here is first to find the lower bound of the key and then to find the upper bound of the same. Later we get the difference of the indices and we get our answer. Rather than having two different functions, we can use a flag which can be used to find the upper bound and lower bound in the same function.
#include <iostream>
#include <bits/stdc++.h>
using namespace std;
/**
 * @brief Finds the first or the last occurrence of `key` in the sorted range
 *        a[low..high] (both bounds inclusive).
 *
 * @param a    Array sorted in ascending order.
 * @param low  Lowest index of the range to search.
 * @param high Highest index of the range to search.
 * @param key  Value to look for.
 * @param flag true: return the lowest matching index;
 *             false: return the highest matching index.
 * @return Index of the selected occurrence, or -1 when `key` is not present
 *         (including when the range is empty, i.e. low > high).
 */
int bin_search(int a[], int low, int high, int key, bool flag){
    int result = -1;
    while (low <= high) {
        // low + (high - low) / 2 cannot overflow for valid indices, unlike
        // (low + high) / 2, which is evaluated in int before any widening.
        const int mid = low + (high - low) / 2;
        if (a[mid] < key) {
            low = mid + 1;
        } else if (a[mid] > key) {
            high = mid - 1;
        } else {
            result = mid;
            if (flag)
                high = mid - 1;  // Go on searching towards left (lower indices)
            else
                low = mid + 1;   // Go on searching towards right (higher indices)
        }
    }
    return result;
}
/**
 * Reads `n` and `k`, then `n` integers, and prints how many of those
 * integers are equal to `k`.
 *
 * The values are sorted first; the count is the distance between the lowest
 * and the highest index at which bin_search() finds `k`.
 *
 * @return 0 on success, 1 when the header line cannot be read or n < 0.
 */
int main() {
    int n = 0;
    int k = 0;
    // k being the required number to find for
    if (!(std::cin >> n >> k) || n < 0) {
        return 1;
    }

    // std::vector replaces the variable-length array `int a[n]`, which is
    // not standard C++.
    std::vector<int> a(n);
    for (int i = 0; i < n; i++) {
        std::cin >> a[i];
    }
    std::sort(a.begin(), a.end());

    int ctr = 0;
    const int lowind = bin_search(a.data(), 0, n - 1, k, true);
    if (lowind != -1) {
        // Only look for the upper end when the key occurs at all.
        const int highind = bin_search(a.data(), 0, n - 1, k, false);
        ctr = highind - lowind + 1;
    }
    std::cout << ctr << '\n';
    return 0;
}
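You can do two searches: one for the index just before the range and one for the index just after it. Because the values before and after the range can themselves repeat, search for a fractional ("unique") float key slightly below and slightly above the target instead of the target itself.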
/**
 * Finds the run of {@code key} in a sorted array with two binary searches.
 *
 * The first search finds the first index whose value is {@code >= key}; the
 * second finds the first index whose value is {@code > key}. These are the
 * positions the "key - 0.2" / "key + 0.2" sentinel searches aim for, but
 * they are computed with exact integer comparisons because a {@code float}
 * cannot represent large {@code int} values exactly.
 *
 * @param arr array sorted in ascending order
 * @param key value to look for
 * @return {@code {from, to}}: {@code from} is the index of the first
 *         occurrence of {@code key} and {@code to} is the index just past
 *         the last one; {@code from == to} when {@code key} is absent
 */
static int[] findFromTo(int[] arr, int key) {
    // First index whose value is >= key.
    int lo = 0;
    int hi = arr.length - 1;
    while (lo <= hi) {
        int mid = lo + (hi - lo) / 2;
        if (arr[mid] >= key)
            hi = mid - 1;
        else
            lo = mid + 1;
    }
    int from = lo;

    // First index whose value is > key; it cannot lie before `from`, so the
    // second search starts there.
    hi = arr.length - 1;
    while (lo <= hi) {
        int mid = lo + (hi - lo) / 2;
        if (arr[mid] > key)
            hi = mid - 1;
        else
            lo = mid + 1;
    }
    return new int[] { from, lo };
}
You can use recursion to solve the problem. First, call binary_search on the list; whenever a matching element is found, its index is stored in the vector and the search recurses into both the left and the right side of that position. Once the recursion finishes, the vector of matching indices is returned.
/**
 * Collects the indices of every occurrence of `key` in the sorted array
 * `arr` of length `n`.
 *
 * The search itself is delegated to helper(), which is started on the whole
 * index range [0, n - 1]; the indices appear in the order helper() pushes
 * them.
 */
vector<int> binary_search(int arr[], int n, int key) {
    vector<int> ans;
    helper(arr, 0, n - 1, key, ans);
    return ans;
}
Recursive function:
/**
 * Appends to `ans` the index of every element equal to `key` inside the
 * sorted range arr[st..last] (both bounds inclusive).
 *
 * The range is narrowed by binary search until a match is hit; that index is
 * recorded first, then the part of the current range left of it is searched,
 * then the part right of it. Nothing is appended when the range is empty or
 * holds no match.
 */
void helper(int arr[], int st, int last, int key, vector<int> &ans) {
    int lo = st;
    int hi = last;
    while (lo <= hi) {
        const int probe = (lo + hi) / 2;
        const int value = arr[probe];
        if (value > key) {
            hi = probe - 1;
            continue;
        }
        if (value < key) {
            lo = probe + 1;
            continue;
        }
        // Match: record it, then look for further matches on both sides.
        ans.push_back(probe);
        helper(arr, lo, probe - 1, key, ans);   // left side call
        helper(arr, probe + 1, hi, key, ans);   // right side call
        return;
    }
}

Resources