## Google Interview Question

SDE1s **Country:** United States

Solution for the first part.

Follow-up: Yes, it becomes simpler to solve, because searching for a node in the forest is faster. For a BST, only the `find_node` function in the implementation below needs to change: it can follow a single root-to-node path by comparing values instead of searching both subtrees. Everything else remains the same.

```
class TreeNode():
    """Binary-tree node holding a value and optional left/right children."""

    def __init__(self, x, left = None, right = None):
        self.val = x
        self.left = left
        self.right = right

    def print_tree(self, level=0):
        """Print this subtree in pre-order, one node per line.

        Each line is ``level`` spaces followed by ``str(self.val)``; the
        children (left first, then right) are printed at ``level + 1``.
        """
        print(level*" " + str(self.val))
        if self.left:
            self.left.print_tree(level+1)
        if self.right:
            self.right.print_tree(level+1)
def T(x, l=None, r=None):
    """Convenience constructor: ``T(x, l, r)`` is ``TreeNode(x, l, r)``."""
    return TreeNode(x, l, r)
def main():
    """Run both demo scenarios, printing the forest each one produces."""
    print("TEST 1: Erasing [1, 5]")
    test1()
    print("TEST 2: Erasing [2, 3]")
    test2()


def _sample_tree():
    """Build the demo tree shared by the tests:

            1
          2   3
         4 5 6 7
    """
    return T(1, T(2, T(4), T(5)), T(3, T(6), T(7)))


def _run_test(to_be_erased):
    """Erase ``to_be_erased`` from a fresh sample tree and print the forest."""
    forest = break_tree(_sample_tree(), to_be_erased)
    print_forest(forest)


def test1():
    _run_test([1, 5])


def test2():
    _run_test([2, 3])
def print_forest(forest):
    """Print every tree of ``forest`` under a 1-based "Forest #k" heading.

    Each tree is printed with its ``print_tree`` method and followed by a
    blank line.
    """
    for number, tree in enumerate(forest, start=1):
        print("Forest #" + str(number))
        tree.print_tree()
        print()
def break_tree(root, to_be_erased):
    """Erase the nodes whose values are in ``to_be_erased``; return the forest.

    For each value, ``find_node`` is run on the trees of the current forest
    until one contains it. That node is detached from its parent (or, if it
    was a tree root, dropped from the forest), and its non-empty children are
    added as new trees. Only the first match found is erased, so node values
    are presumably expected to be unique.

    Args:
        root: Root of the tree to break up; the tree is modified in place.
        to_be_erased: Iterable of node values to remove.

    Returns:
        List of the surviving tree roots (never containing ``None``).
    """
    forest = [root]
    for num in to_be_erased:
        for i, tree in enumerate(forest):
            if tree is None:
                # Root erased by an earlier value; nothing left to search.
                continue
            node, parent, which = find_node(tree, num)
            if node is None:
                continue
            # The erased node's children become roots of their own trees.
            forest.extend(child for child in (node.left, node.right) if child)
            if parent is None:
                # The node was itself a root in the forest: drop it.
                forest[i] = None
            elif which == "left":
                parent.left = None
            elif which == "right":
                parent.right = None
            # The node has been handled: stop scanning, which also avoids
            # re-searching the trees just appended to ``forest``.
            break
    return [tree for tree in forest if tree]
def find_node(root, num):
    """Locate the first node, in pre-order, whose value equals ``num``.

    Returns:
        ``(node, parent, which)``. ``parent`` is the matched node's parent
        (``None`` if the match is ``root`` itself). ``which`` is ``"left"`` or
        ``"right"`` and names the child slot of ``parent`` holding the match
        (``None`` for the root). Returns ``(None, None, None)`` when nothing
        matches or ``root`` is ``None``.
    """
    def find_node_and_parent(node, num, parent, which):
        if node is None:
            return None, None, None
        if node.val == num:
            return node, parent, which
        # Search the left subtree first, then the right one.
        for child, side in ((node.left, "left"), (node.right, "right")):
            found = find_node_and_parent(child, num, node, side)
            if found[0] is not None:
                return found
        return None, None, None

    return find_node_and_parent(root, num, None, None)
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
```

Output:

```
TEST 1: Erasing [1, 5]
Forest #1
2
4
Forest #2
3
6
7
TEST 2: Erasing [2, 3]
Forest #1
1
Forest #2
4
Forest #3
5
Forest #4
6
Forest #5
7
```

```
def break_tree(parent, node, to_be_erased, forest):
    """Erase the nodes whose values are in ``to_be_erased``, collecting roots.

    Pre-order traversal. A surviving node starts a new tree when it has no
    parent (it is the original root) or when its parent is erased. After a
    node's subtrees are processed, its links to erased children are cut.

    Args:
        parent: Parent of ``node``; ``None`` for the initial call on the root.
        node: Root of the current subtree; may be ``None``.
        to_be_erased: Container of values to erase. A ``set`` gives O(1)
            membership tests; a list makes each test O(K).
        forest: List that receives the roots of the resulting trees, in the
            order they are discovered.
    """
    if node is None:
        return
    erased = node.val in to_be_erased
    # Root node, or child of an erased node, that is not erased itself.
    if not erased and (parent is None or parent.val in to_be_erased):
        forest.append(node)
    break_tree(node, node.left, to_be_erased, forest)
    break_tree(node, node.right, to_be_erased, forest)
    # Cut links to erased children only after recursing: unlinking first
    # would skip an erased child's subtree and lose its surviving nodes.
    if node.left is not None and node.left.val in to_be_erased:
        node.left = None
    if node.right is not None and node.right.val in to_be_erased:
        node.right = None
```

Traverse the tree in pre-order, so every node is processed before its children. Here is the high-level algorithm.

```
// Erase every node for which shouldBeErased(node) holds and return the forest.
isRootErased := false
forest := []

eraseNodes(node, parent):
    if node is null
        return
    if shouldBeErased(node):
        // Unlink the erased node from its parent.
        if parent is null:
            isRootErased := true
        else if parent.left is node:
            parent.left := null
        else:
            parent.right := null
        // Its children become new roots -- unless they are erased too
        // (otherwise an erased child would end up in the forest).
        if node.left is not null and not shouldBeErased(node.left):
            forest.push(node.left)
        if node.right is not null and not shouldBeErased(node.right):
            forest.push(node.right)
    eraseNodes(node.left, node)
    eraseNodes(node.right, node)

eraseNodes(root, null)
// An empty tree contributes nothing to the forest.
if root is not null and not isRootErased:
    forest.push(root)
return forest
```

Processing each node takes O(1), assuming `shouldBeErased` is an O(1) check such as a hash-set lookup. So the time complexity is O(N), where N is the number of nodes.

Regarding the follow-up, it depends. Suppose you are given a list of K nodes to erase. In a balanced BST, finding each of them takes O(log N), so the overall time is O(K log N), which beats O(N) when K is small. In the worst case K = N, giving O(N log N), which is worse than the O(N) traversal.

```
/// Binary-tree node: an int payload plus raw links to its two children.
/// Both child links start out empty (nullptr).
class node
{
public:
    int val;
    node* left = nullptr;
    node* right = nullptr;

    node(int v) : val{v} {}
};
/// Removes every node for which `shouldBeErased` returns true from the
/// subtree rooted at `root`, collecting the sub-trees orphaned by removals.
///
/// Nodes are processed bottom-up (post-order): children are pruned before
/// their parent is examined. So when an erased node's surviving children
/// are pushed into `result`, they no longer contain any erased descendant.
///
/// @param root            Subtree to process; may be nullptr.
/// @param result          Receives the roots of sub-trees that were cut off
///                        because their parent was erased. The surviving
///                        top-level root (the return value) is NOT added;
///                        the caller must add it.
/// @param shouldBeErased  Predicate selecting the nodes to remove. Taken by
///                        const reference so lambdas/temporaries can be passed.
/// @return `root` if it survives, otherwise nullptr.
///
/// Every erased node is `delete`d, so all nodes must have been allocated
/// with `new`.
node* eraseNodes(node* root, std::vector<node*>& result,
                 const std::function<bool(node*)>& shouldBeErased)
{
    if (!root) return nullptr;

    // Prune the children first so the pointers handed to `result` below
    // never reference an erased node.
    root->left = eraseNodes(root->left, result, shouldBeErased);
    root->right = eraseNodes(root->right, result, shouldBeErased);

    if (!shouldBeErased(root)) return root;

    // The surviving children become roots of new trees in the forest.
    if (root->left) result.push_back(root->left);
    if (root->right) result.push_back(root->right);
    // Erased nodes are owned by the tree: free inner nodes as well as leaves.
    delete root;
    return nullptr;
}
```

- Alex December 08, 2017