
Spark: Find pairs having at least n common attributes?

An answer to this question on Stack Overflow.

Question

I have a dataset of (sensor_id, timestamp, data) rows: sensor_id is the ID of an IoT device, timestamp is a UNIX time, and data is an MD5 hash of the device's output at that time. There is no primary key on the table, but each row is unique.

I need to find all pairs of sensor_ids s1 and s2 such that the two sensors have at least n (n=50) (timestamp, data) entries in common, i.e., on at least n different occasions they emitted the same data at the same timestamp.

To give a sense of scale: I have 10B rows and ~50M distinct sensor_ids, and I believe there are around 5M pairs of sensor_ids that emitted the same data at the same timestamp at least 50 times.

What's the best way to do this in Spark? I tried various approaches (grouping by (timestamp, data) and/or self-joining), but they are prohibitively expensive.

Answer

Here's how I'd do it.

First, generate some fake data:

#!/usr/bin/env python3
import random

#Write roughly 1M space-separated "timestamp sensor_id data" rows
with open('test_data.csv','w') as fout:
  i=0
  for x in range(100000):
    if i>=1000000:
      break
    #Emit a random number (0-100) of readings at timestamp x
    for y in range(random.randint(0,100)):
      i         = i + 1
      timestamp = x
      sensor_id = random.randint(0,50)
      data      = random.randint(0,1000)
      fout.write("{} {} {}\n".format(timestamp,sensor_id,data))

Now, you can process the data as follows.

If you let N be the number of lines, T the number of distinct keys the lines are grouped by, and S the expected number of sensors per key, then the complexity of each operation is as noted in the code comments.
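The S^2 factors in the pair-generation steps come from expanding each key's sensor set into every unordered pair: a key shared by S sensors yields S(S-1)/2 pairs, so a few large groups can dominate the total cost. As a rough illustration, here is a hypothetical helper (`candidate_pair_count`, not part of the answer's code) that just does that arithmetic for a list of group sizes:

```python
from math import comb

def candidate_pair_count(group_sizes):
    """Number of sensor pairs emitted when each key's sensor set is expanded
    into all unordered pairs (hypothetical helper for illustration)."""
    # A key with s sensors yields "s choose 2" = s*(s-1)/2 pairs
    return sum(comb(s, 2) for s in group_sizes)

# One key with 100 sensors already yields 4,950 pairs,
# while 100 keys with 2 sensors each yield only 100.
print(candidate_pair_count([100]))      # 4950
print(candidate_pair_count([2] * 100))  # 100
```

So the same number of rows can cost very different amounts depending on how they are spread across keys.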

import itertools
from pyspark import SparkContext

#Reuse the SparkContext from the pyspark shell, or create one
sc = SparkContext.getOrCreate()

#Minimum number of shared (timestamp, data) entries. The question uses n=50;
#the small fake dataset needs a lower threshold.
n = 5

#Turn a set into a list of all unique unordered pairs in the set, without
#including self-pairs. Sorting first puts each pair in a canonical order, so
#(a,b) and (b,a) are counted as the same pair.
def Pairs(x):
  return list(itertools.combinations(sorted(x), 2))

#Load data
#O(N) time to load data
fin        = sc.textFile("file:///z/test_data.csv")
#Split each line at whitespace into [timestamp, sensor_id, data]
#O(N) time to split each line of data
lines      = fin.map(lambda line: line.split())
#Key each line by (timestamp, data), since a match requires the same data at
#the same timestamp, and wrap its sensor_id in a one-element set
#O(N) time to make each line into a key-hashset pair
monosets   = lines.map(lambda line: ((line[0], line[2]), {line[1]}))
#Combine sets by key to get, for each (timestamp, data), all sensors that
#emitted that data at that timestamp
#O(TS) time to place each line into a hash table of size O(T) where each
#entry in the hash table is a hashset of the sensors sharing that key
timegroups = monosets.reduceByKey(lambda a,b: a | b)
#Convert the set at each key into a list of all pairs of sensors sharing it
#O(T S^2) time to do all pairs for each key
shared     = timegroups.flatMap(lambda tg: Pairs(tg[1]))
#Associate each sensor pair with a value of one
#O(T S^2) time
monoshared = shared.map(lambda x: (x,1))
#Sum by sensor pair
#O(T S^2) time
paircounts = monoshared.reduceByKey(lambda a,b: a+b)
#Keep only pairs that share at least n entries
#O(number of distinct pairs) time, which is at most O(T S^2)
good       = paircounts.filter(lambda x: x[1]>=n)
#Display results
print(good.count())
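One subtlety when emitting pairs from sets: Python doesn't guarantee a set's iteration order (it depends on hash values, which for strings are randomized per process unless PYTHONHASHSEED is fixed, and on insertion history). If the same two sensors came out as ('s1', 's2') at one key and ('s2', 's1') at another, reduceByKey would count them as two different pairs. Sorting each set before pairing avoids that. A minimal sketch of the idea, using a hypothetical `canonical_pairs` helper and a local `Counter` in place of reduceByKey:

```python
import itertools
from collections import Counter

def canonical_pairs(sensors):
    """All unordered pairs from a set of sensor ids, each as a sorted tuple,
    so (a, b) and (b, a) always come out as the same key."""
    return list(itertools.combinations(sorted(sensors), 2))

# Two keys whose sensor sets contain the same two ids
counts = Counter()
for sensors in ({"s1", "s2"}, {"s2", "s1"}):
    counts.update(canonical_pairs(sensors))

print(counts)  # Counter({('s1', 's2'): 2})
```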

The time complexities are a little hand-wavy, as I'm writing this answer rather late, but the bottlenecks (the O(T S^2) steps that generate and count sensor pairs) should at least be visible.
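For checking a Spark job like this on small inputs such as the fake data, a plain-Python version of the same group/pair/count/filter pipeline can serve as a reference. This is a sketch that assumes everything fits in memory, so it's only meant for testing; `common_pairs` and `read_rows` are hypothetical names. It groups by (timestamp, data), which is what the question's definition of a match requires.

```python
import itertools
from collections import Counter, defaultdict

def common_pairs(rows, n):
    """Single-machine reference for small inputs (hypothetical helper).

    rows: iterable of (timestamp, sensor_id, data) tuples.
    Returns {(s1, s2): count} for sensor pairs that emitted the same data
    at the same timestamp at least n times, with s1 < s2.
    """
    # Group sensors by (timestamp, data), like the first reduceByKey step
    groups = defaultdict(set)
    for timestamp, sensor_id, data in rows:
        groups[(timestamp, data)].add(sensor_id)
    # Count each sorted pair once per shared (timestamp, data) key
    counts = Counter()
    for sensors in groups.values():
        counts.update(itertools.combinations(sorted(sensors), 2))
    # Keep pairs with at least n shared entries
    return {pair: c for pair, c in counts.items() if c >= n}

def read_rows(path):
    """Parse the fake-data file: one "timestamp sensor_id data" row per line."""
    with open(path) as f:
        for line in f:
            timestamp, sensor_id, data = line.split()
            yield timestamp, sensor_id, data

rows = [(1, "a", "x"), (1, "b", "x"), (1, "c", "y"),
        (2, "a", "z"), (2, "b", "z"), (2, "c", "z")]
print(common_pairs(rows, 2))  # {('a', 'b'): 2}
```

On the fake data, `len(common_pairs(read_rows('test_data.csv'), n))` should agree with the Spark count as long as both group by (timestamp, data) and apply the same threshold test.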