Multiprocessing with Pandas dataframe

Updated: Mar 27



Recently, I was handling large datasets stored in Pandas dataframes. Here are some learning points:


[1] The native method df.iterrows() is slow, so if you ever need to iterate through the rows, consider vectorised operations first, and fall back to df.apply() only if the operation cannot be vectorised. These methods are covered in many Medium articles, for example this one.


[2] We can speed up data manipulation that splits into independent chunks with multiprocessing. The speed-up is at best roughly equal to the number of cores you have, since process start-up and data transfer eat into it. But in The Algorithm Design Manual, Skiena makes the great point that a more efficient algorithm can be many-fold faster than anything extra computing power can offer. In short, use the best algorithms provided in APIs, and if there are none, design our own.


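[3] Python does have threads (the threading module), but in CPython the Global Interpreter Lock (GIL) allows only one thread to execute Python bytecode at a time, so threads do not speed up CPU-bound work like this. Therefore we rely on multiprocessing, where each process has its own interpreter and its own GIL.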


[4] The Queue classes in the multiprocessing module are "process-safe", so we don't have to worry about race conditions when several processes put items on the same queue. It is, however, prudent to run tests and check that we obtain the expected output without any loss of data.


[5] The canonical Queue object (multiprocessing.Queue()) comes with this warning in the Python documentation: "if a child process has put items on a queue... process will not terminate until all buffered items have been flushed to the pipe." This means that if you put a lot of data on the queue, you have to dequeue it before joining the processes. Otherwise join() waits for a process that cannot exit, and the program hangs instead of finishing. To avoid this, either call get() for every result before join(), or use multiprocessing.Manager().Queue() instead, whose items are held by a separate manager process rather than buffered inside the child.
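To illustrate point [1], here is one per-row calculation written three ways. The calculation (the length of each region) is a made-up example chosen because it can be vectorised. The read-depth calculation later in this post calls pysam once per row and cannot be, which is presumably why it falls back to df.apply().

```python
import pandas as pd

# Small example dataframe with the same columns used later in this post
df = pd.DataFrame({
    "contig_name": ["chr1", "chr1", "chr1"],
    "start": [94000000, 94000200, 94000001],
    "end": [94000100, 94000300, 94000101],
})

# 1. Row-by-row loop with iterrows(): every row is turned into a Series (slow)
lengths_loop = []
for _, row in df.iterrows():
    lengths_loop.append(row["end"] - row["start"] + 1)

# 2. df.apply() along axis=1: still calls a Python function once per row
lengths_apply = df.apply(lambda row: row["end"] - row["start"] + 1, axis=1)

# 3. Vectorised: one operation on whole columns, run in compiled code
lengths_vec = df["end"] - df["start"] + 1

print(lengths_vec.tolist())  # [101, 101, 101]
```

All three give the same answer. On large dataframes, the vectorised form is usually much faster than the other two.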
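To illustrate point [3], here is a minimal sketch of spreading CPU-bound work over processes with the standard library's concurrent.futures.ProcessPoolExecutor. This is an alternative to the Process-and-queue template used in this post, not the method the post uses. cpu_bound() is a made-up workload. A ThreadPoolExecutor would return the same answers, but in CPython its threads would take turns holding the GIL rather than running in parallel.

```python
from concurrent.futures import ProcessPoolExecutor


def cpu_bound(n: int) -> int:
    # Pure-Python arithmetic: holds the GIL for as long as it runs
    return sum(i * i for i in range(n))


def run_parallel(inputs):
    # Each call runs in a separate worker process with its own interpreter
    # and GIL, so calls can execute on different cores at the same time
    with ProcessPoolExecutor() as executor:
        return list(executor.map(cpu_bound, inputs))


if __name__ == "__main__":
    # The guard stops worker processes from re-running this block on import
    print(run_parallel([10, 100, 1000]))  # [285, 328350, 332833500]
```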
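To illustrate point [5], here is a minimal sketch of the first fix with a plain multiprocessing.Queue: call get() once for each worker before any join(). The worker function and item counts are hypothetical. The point is the order of get() and join().

```python
import multiprocessing as mp


def worker(worker_id, n_items, queue):
    # Put one large list on a plain multiprocessing.Queue
    queue.put([(worker_id, i) for i in range(n_items)])


def collect(n_workers=4, n_items=100_000):
    queue = mp.Queue()
    processes = [mp.Process(target=worker, args=(w, n_items, queue))
                 for w in range(n_workers)]
    for p in processes:
        p.start()
    # Drain first: each worker puts exactly one item, so call get() once per
    # worker. get() blocks until an item arrives, which lets each worker's
    # buffered data be flushed so the worker can exit.
    results = []
    for _ in processes:
        results.extend(queue.get())
    # Only now is it safe to join
    for p in processes:
        p.join()
    return results


if __name__ == "__main__":
    print(len(collect()))  # 400000
```

Swapping the two loops (join first, then get) is the pattern the documentation warns can deadlock once the items are large.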


Here is a template for using multiprocessing to manipulate a Pandas dataframe row by row. The goal is to use the contig_name, start and end positions stored in each row to derive the read depth in that region (using pysam's count_coverage() function). Instead of using df.iterrows() on a single core, I used df.apply() on all the cores in the computer, which I estimated to give roughly a 4 × (number of cores) speed-up.


Let's start by importing modules, initialising empty lists and loading our Pandas dataframe from a CSV file.

from multiprocessing import cpu_count, Process, Manager
import pandas as pd
import pysam
from time import sleep

processes = []
output_list = []
bam_file_name = 'Sample1_S1_L001.sorted.bam'

# Load dataframe
df = pd.read_csv('data.csv')
print(df.head())

Here are the first few rows of the dataframe.

  contig_name     start       end
0        chr1  94000000  94000100
1        chr1  94000200  94000300
2        chr1  94000001  94000101
3        chr1  94000201  94000301
4        chr1  94000002  94000102

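With the name of the contig (or chromosome) and the start and end positions, we can measure how deeply the region is covered by mapped reads using pysam's count_coverage(). count_coverage() returns four arrays holding the counts of A, C, G and T at every position, so we add the four counts at each position (which I did with a list comprehension) and then sum over the region. Note that the result is the per-base depth summed over all positions, not the number of reads: a read covering 100 bases of the region contributes up to 100 (one per base that passes the quality threshold). Here's the function: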

def get_depth(contig, start, end, bam_file_name):
    '''
    This function extracts read depth from bam file at every position from 
    start to end (inclusive) for the specified contig and returns the total
    
    Input
    [1] contig, type str: Name of chromosome/ contig
    [2] start, type int: Start position to extract read depth
    [3] end, type int: End position to extract read depth (inclusive)
    [4] bam_file_name, type str: Directory and file name of bam

    Output
    total_depth, type int: Per-base read depth summed over the region
    '''
    # Open the BAM file and make sure it is closed again afterwards
    # (count_coverage() needs the BAM index, .bai, next to the file)
    with pysam.AlignmentFile(bam_file_name, 'rb') as bam:
        # Four arrays (A, C, G, T), one count per position in [start, end + 1)
        coverage = bam.count_coverage(str(contig), 
                                      start = int(start), 
                                      stop = int(end) + 1, 
                                      quality_threshold = 10,
                                      read_callback = 'nofilter')
    # Add up A, C, G and T at each position, then sum over the region
    depth_list = [sum(x) for x in zip(*coverage)]
    total_depth = sum(depth_list)
    return total_depth
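To see what the list comprehension in get_depth() does, here is a pysam-free sketch. The four arrays are made-up numbers standing in for count_coverage() output, which pysam's documentation describes as four arrays in A, C, G, T order.

```python
from array import array

# Hypothetical stand-in for count_coverage() over a 3-base region:
# four arrays (A, C, G, T), one count per position
coverage = (
    array("L", [10, 0, 2]),  # A
    array("L", [0, 12, 0]),  # C
    array("L", [1, 0, 9]),   # G
    array("L", [0, 1, 0]),   # T
)

# Per-position depth: add the four counts at each position
per_position = [sum(x) for x in zip(*coverage)]
print(per_position)  # [11, 13, 11]

# Total over the region: the same as summing every array directly
total_depth = sum(per_position)
assert total_depth == sum(sum(arr) for arr in coverage)
print(total_depth)  # 35
```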

Next, we write the function that each process will run. It selects the rows in its index range, applies get_depth() to each of them with df.apply(), and puts the output on the multiprocessing manager's queue.

def process_func(df, index_range, queue, bam_file_name):
    '''
    Target function for each multiprocessing process.
    
    Input
    [1] df, type Pandas dataframe
    [2] index_range, type Range object: Range of dataframe rows to process
    [3] queue, type Queue: Multiprocessing.Manager().Queue() object
    [4] bam_file_name, type str: Directory and file name of bam

    Output (put on the queue; nothing is returned)
    List of tuple: [(df index, total depth), (df index, total depth)... ]
    where each tuple represents a row in the dataframe, df index points to 
    the exact row and total depth is the summed per-base read depth of
    the chromosomal region in the same row
    '''
    # Process only this chunk's rows (selected by position), not the whole df
    df_chunk = df.iloc[index_range.start:index_range.stop]
    depth_list = list(df_chunk.apply(lambda row: get_depth(row['contig_name'], row['start'], row['end'], bam_file_name), axis = 1))
    # Pair each result with its row's index label
    depth_list = list(zip(df_chunk.index, depth_list))
    queue.put(depth_list)
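Here is a small, pysam-free sketch of what each worker should do with its index_range: slice out its own rows first, then apply the per-row function only to those rows. fake_depth() is a hypothetical placeholder for get_depth(), and process_chunk() returns its list instead of putting it on a queue so it can be tried directly.

```python
import pandas as pd


def fake_depth(contig, start, end):
    # Hypothetical stand-in for get_depth(): no BAM file needed
    return end - start


def process_chunk(df, index_range):
    # Select only the rows of this chunk by position, then apply row-wise
    df_chunk = df.iloc[index_range.start:index_range.stop]
    depths = df_chunk.apply(
        lambda row: fake_depth(row["contig_name"], row["start"], row["end"]),
        axis=1,
    )
    # Pair each value with its index label so it can be merged back later
    return list(zip(df_chunk.index, depths))


df = pd.DataFrame({
    "contig_name": ["chr1"] * 4,
    "start": [0, 10, 20, 30],
    "end": [5, 17, 29, 31],
})
print(process_chunk(df, range(2, 4)))  # [(2, 9), (3, 1)]
```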

If you're wondering what the queue is for: it is a data structure that collects the output from every process in a process-safe manner. We use multiprocessing.Manager().Queue() instead of multiprocessing.Queue() because of point [5] discussed above.

# Create queue
queue = Manager().Queue()

# Automatically retrieve the max number of cores in your computer
cores = cpu_count()

Now we calculate the size of each chunk of data that each process will handle. If there are 101 rows in the dataframe and we have four cores, each chunk should hold 25 rows. Instead of splitting the dataframe itself, we describe each chunk by the range of its row positions (that's why I used range) and store these ranges in the index_ranges list.

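# Calculate the size of each chunk (at least 1, so that range() below never
# receives a step of 0 when there are fewer rows than cores)
total_len    = len(df)
chunk_size   = max(1, total_len // cores)
index_ranges = [range(i, i + chunk_size) 
                for i in range(0, total_len, chunk_size)
                if (i + chunk_size <= total_len)]

If you noticed, four 25-row chunks cover rows 0 to 99, leaving 1 of the 101 rows unaccounted for. Here we add the leftover rows to the index_ranges list, counting from where the last full chunk ends. Counting from total_len % cores instead would process some rows twice whenever the full chunks already cover the whole dataframe (for example, 10 rows on 4 cores gives five 2-row chunks).

# Account for leftover rows after the last full chunk
covered = index_ranges[-1].stop if index_ranges else 0
if covered < total_len:
    index_ranges.append(range(covered, total_len))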
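A compact way to build such ranges, sketched here as a hypothetical helper make_index_ranges(), lets the last range absorb the leftover rows directly. It is worth testing against row counts that do not divide evenly by the number of cores.

```python
def make_index_ranges(total_len, cores):
    """Split row positions 0..total_len-1 into contiguous, non-overlapping ranges.

    Chunks hold total_len // cores rows (at least 1); the last range holds
    any leftover rows, so there can be one more range than cores.
    """
    chunk_size = max(1, total_len // cores)
    return [range(i, min(i + chunk_size, total_len))
            for i in range(0, total_len, chunk_size)]


print(make_index_ranges(101, 4))
# [range(0, 25), range(25, 50), range(50, 75), range(75, 100), range(100, 101)]
print(make_index_ranges(10, 4))
# [range(0, 2), range(2, 4), range(4, 6), range(6, 8), range(8, 10)]
```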

Now we can start one process per chunk of data. join() then makes the main program wait until every process has finished before the remaining code runs. Each process runs the process_func() we wrote earlier, and recall that its output is put on the queue.

# Create a process for every chunk and run process
print('Multiprocessing start')
for ranges in index_ranges:
    proc = Process(target = process_func, 
                   args   = (df, ranges, queue, bam_file_name)) 
    # Lowercase 'daemon': workers are terminated if the main process exits
    proc.daemon = True
    proc.start()
    processes.append(proc)

# Block the main process until every worker has finished
for proc in processes:
    proc.join()
print('Multiprocessing end')

We can now extract the output from the queue. To do so safely, we add a sentinel (None) to mark the end of the queue and pull each output into output_list until we hit the sentinel. No sleep is needed between iterations: queue.get() blocks until an item is available, and after join() every worker's result is already on the queue.

# Insert sentinel to clearly delineate the end of the queue
queue.put(None)

# Retrieve items from queue and put into output_list
# (get() blocks until an item arrives, so no sleep is required)
for item in iter(queue.get, None):
    output_list += item
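Following point [4], a cheap sanity check at this stage is to confirm that every dataframe index came back exactly once. The helper below is a hypothetical addition, not part of the original script. It assumes the default 0..n-1 index that pd.read_csv() produces.

```python
from collections import Counter


def check_complete(output_list, n_rows):
    """Return (missing, duplicated) row indices in a list of (index, value) tuples."""
    counts = Counter(index for index, _ in output_list)
    missing = sorted(set(range(n_rows)) - set(counts))
    duplicated = sorted(i for i, c in counts.items() if c > 1)
    return missing, duplicated


# Hypothetical output where row 2 was lost and row 3 was processed twice
output_list = [(0, 5), (1, 7), (3, 2), (3, 2)]
print(check_complete(output_list, 4))  # ([2], [3])
```

In the script, something like `assert check_complete(output_list, len(df)) == ([], [])` before building depth_df would stop the run early rather than write an incomplete output.csv.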

The output looks like this: [(df index, total_depth), (df index, total_depth)... ], so we can use it to create a new dataframe whose index maps to the respective total_depth. We then concatenate this new dataframe with the original one along the columns (axis = 1). pandas aligns the rows on the index, so the order in which the processes finished does not matter. The final product is the original dataframe with a new column, 'total_depth', holding the summed per-base read depth of each row's chromosomal region.

# Create new df with the output list and merge with original df
depth_df = pd.DataFrame([x[1] for x in output_list], 
                        index = [x[0] for x in output_list],
                        columns = ['total_depth'])
df = pd.concat([df, depth_df], axis = 1)

# Output dataframe
df.to_csv('output.csv', index = False)
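Here is a small sketch with made-up numbers showing why the order of output_list does not matter: pd.concat(..., axis=1) matches rows by index label, not by position.

```python
import pandas as pd

df = pd.DataFrame({"contig_name": ["chr1"] * 4, "start": [0, 10, 20, 30]})

# Hypothetical worker output, in whatever order the processes finished
output_list = [(2, 7), (3, 4), (0, 9), (1, 5)]

depth_df = pd.DataFrame([x[1] for x in output_list],
                        index=[x[0] for x in output_list],
                        columns=["total_depth"])

# axis=1 aligns rows on their index labels
merged = pd.concat([df, depth_df], axis=1)
print(merged.loc[0, "total_depth"], merged.loc[2, "total_depth"])  # 9 7
```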

Full script. Here the functions are defined before the code that uses them, because Python looks up process_func when the loop runs. The main code also sits under if __name__ == '__main__':, which multiprocessing requires on platforms that start processes with the spawn method (Windows, and macOS by default since Python 3.8):

from multiprocessing import cpu_count, Process, Manager
import pandas as pd
import pysam

# Functions
def get_depth(contig, start, end, bam_file_name):
    '''
    This function extracts read depth from bam file at every position from 
    start to end (inclusive) for the specified contig and returns the total
    
    Input
    [1] contig, type str: Name of chromosome/ contig
    [2] start, type int: Start position to extract read depth
    [3] end, type int: End position to extract read depth (inclusive)
    [4] bam_file_name, type str: Directory and file name of bam

    Output
    total_depth, type int: Per-base read depth summed over the region
    '''
    # Open the BAM file and make sure it is closed again afterwards
    with pysam.AlignmentFile(bam_file_name, 'rb') as bam:
        # Four arrays (A, C, G, T), one count per position in [start, end + 1)
        coverage = bam.count_coverage(str(contig), 
                                      start = int(start), 
                                      stop = int(end) + 1, 
                                      quality_threshold = 10,
                                      read_callback = 'nofilter')
    # Add up A, C, G and T at each position, then sum over the region
    depth_list = [sum(x) for x in zip(*coverage)]
    total_depth = sum(depth_list)
    return total_depth

def process_func(df, index_range, queue, bam_file_name):
    '''
    Target function for each multiprocessing process.
    
    Input
    [1] df, type Pandas dataframe
    [2] index_range, type Range object: Range of dataframe rows to process
    [3] queue, type Queue: Multiprocessing.Manager().Queue() object
    [4] bam_file_name, type str: Directory and file name of bam

    Output (put on the queue; nothing is returned)
    List of tuple: [(df index, total depth), (df index, total depth)... ]
    where each tuple represents a row in the dataframe, df index points to 
    the exact row and total depth is the summed per-base read depth of
    the chromosomal region in the same row
    '''
    # Process only this chunk's rows (selected by position), not the whole df
    df_chunk = df.iloc[index_range.start:index_range.stop]
    depth_list = list(df_chunk.apply(lambda row: get_depth(row['contig_name'], row['start'], row['end'], bam_file_name), axis = 1))
    # Pair each result with its row's index label
    depth_list = list(zip(df_chunk.index, depth_list))
    queue.put(depth_list)

if __name__ == '__main__':
    processes = []
    output_list = []
    bam_file_name = 'Sample1_S1_L001.sorted.bam'

    # Load dataframe
    df = pd.read_csv('data.csv')

    # Create queue
    queue = Manager().Queue()

    # Automatically retrieve the max number of cores in your computer
    cores = cpu_count()

    # Calculate the size of each chunk (at least 1 row)
    total_len    = len(df)
    chunk_size   = max(1, total_len // cores)
    index_ranges = [range(i, i + chunk_size) 
                    for i in range(0, total_len, chunk_size)
                    if (i + chunk_size <= total_len)]

    # Account for leftover rows after the last full chunk
    covered = index_ranges[-1].stop if index_ranges else 0
    if covered < total_len:
        index_ranges.append(range(covered, total_len))

    # Create a process for every chunk and run process
    print('Multiprocessing start')
    for ranges in index_ranges:
        proc = Process(target = process_func, 
                       args   = (df, ranges, queue, bam_file_name)) 
        proc.daemon = True
        proc.start()
        processes.append(proc)

    # Block the main process until every worker has finished
    for proc in processes:
        proc.join()
    print('Multiprocessing end')

    # Insert sentinel to clearly delineate the end of the queue
    queue.put(None)

    # Retrieve items from queue and put into output_list
    for item in iter(queue.get, None):
        output_list += item

    # Create new df with the output list and merge with original df
    depth_df = pd.DataFrame([x[1] for x in output_list], 
                            index = [x[0] for x in output_list],
                            columns = ['total_depth'])
    df = pd.concat([df, depth_df], axis = 1)

    # Output dataframe
    df.to_csv('output.csv', index = False)

