#!/usr/bin/env python

import sys

import nsysstats

class CUDAGPUStarvation(nsysstats.Report):
    """Report that lists the time ranges during which a GPU stays idle.

    The SQL query is assembled in setup() from the templates below: the
    kernel/memcpy/memset tables present in the database are unioned, CUDA
    profiling overhead ranges are optionally merged in, and the gaps between
    consecutive operations of the same process/device are extracted.
    """

    # Default minimum gap duration, in milliseconds (see "gap=" in usage).
    THRESHOLD = 500
    # Default maximum number of rows returned (see "rows=" in usage).
    ROW_LIMIT = 50

    usage = f"""{{SCRIPT}}[:rows=<limit>][:gap=<threshold>] -- GPU Starvation (kernels + memory operations)

    Options:
        rows=<limit> - Maximum number of rows returned by the query.
            Default is {ROW_LIMIT}.

        gap=<threshold> - Minimum duration of GPU gaps in milliseconds.
            Default is {THRESHOLD}ms.

    Output: All time values default to nanoseconds
        Row# : Row number of GPU gap
        Duration : Duration of GPU gap
        Start : Start time of GPU gap
        PID : Process identifier
        Device ID : GPU device identifier

    This rule identifies time regions where a GPU is idle for longer than a set
    threshold. For each process, each GPU device is examined, and gaps are
    found within the time range that starts with the beginning of the first GPU
    operation on that device and ends with the end of the last GPU operation on
    that device.

    GPU gaps that cannot be addressed by the user are excluded. This includes:
        1. CUDA profiling overhead in the middle of a GPU gap.
        2. The initial gap starting before the first GPU operation, which
           corresponds to the CUDA initialization overhead.
        3. The final gap starting after the last GPU operation, which
           corresponds to the flushing of CUPTI buffers when stopping the
           collection.
"""

    # Outermost query: number the gaps (longest first) and expose the columns
    # under their display names. {GAPS} is replaced by the formatted query_gap.
    query_format_columns = """
    SELECT
        ROW_NUMBER() OVER(ORDER BY duration DESC, gapStart) AS "Row#",
        duration AS "Duration:dur_ns",
        gapStart AS "Start:ts_ns",
        pid AS "PID",
        deviceId AS "Device ID",
        globalPid AS "_Global PID"
    FROM
        ({GAPS})
"""

    # Find gaps.
    # "ops" is the table containing kernel + memory + CUDA profiling overhead.
    # 1. CTE "starts": Give a rowNum, SRi, to each start (ordered by start time).
    # 2. CTE "ends": Give a rowNum, ERj, to each end (ordered by end time).
    # 3. Reconstruct intervals [ERj, SRj+1] by putting together an end ERj from ends
    #    with the next start SRj+1 from starts i.e. start_rowNum - 1 = end_rowNum.
    #    For example, if an end has 2 as rowNum it will be joined with
    #    a start that has 3 as rowNum.
    # 4. Keep only those intervals [ERj, SRj+1] that are valid i.e. ERj < SRj+1.
    # 5. Sort the remaining gaps by decreasing duration before applying the row
    #    limit, so that the limit keeps the longest gaps rather than an
    #    arbitrary subset.
    # Assume that we have the following intervals:
    #
    # SR1                          ER2
    #  |--------------a-------------|
    #      SR2                ER1
    #       |---------b--------|
    #                                         SR3              ER3
    #                                          |--------c-------|
    # With step 3, we get:
    # 1. ER1 joined with SR2.
    # 2. ER2 joined with SR3.
    #
    #      SR2                 ER1
    #       |---------a'--------|
    #                               ER2        SR3
    #                                |----b'----|
    #
    # However, only the second interval (b') meets the condition end < start of step 4 and
    # will be considered as a gap.
    # The first one (a') will be discarded and the query will return:
    #
    #                               ER2        SR3
    #                                |----b'----|
    #
    # ER2 will be the start and SR3 will be the end of the gap.
    query_gap = """
    WITH
        ops AS (
            {OPS_ALL}
        ),
        starts AS (
            SELECT
                ROW_NUMBER() OVER(ORDER BY pid, deviceId, start) AS rowNum,
                start,
                pid,
                globalPid,
                deviceId
            FROM
                ops
        ),
        ends AS (
            SELECT
                ROW_NUMBER() OVER(ORDER BY pid, deviceId, end) AS rowNum,
                end,
                pid,
                globalPid,
                deviceId
            FROM
                ops
        )
    SELECT
        start - end AS duration,
        end AS gapStart,
        start AS gapEnd,
        starts.pid,
        starts.globalPid,
        starts.deviceId
    FROM
        starts
    JOIN
        ends
        ON
                starts.rowNum - 1 = ends.rowNum
            AND starts.deviceId = ends.deviceId
            AND starts.pid = ends.pid
    WHERE
            duration > {THRESHOLD} * 1000000
        AND gapStart < gapEnd
    ORDER BY
        duration DESC,
        gapStart
    LIMIT {ROW_LIMIT}
"""

    # Select columns of kernel/memory operations.
    query_select = """
    SELECT
        start,
        end,
        (globalPid >> 24) & 0x00FFFFFF AS pid,
        globalPid,
        deviceId
    FROM
        {GPU_OPERATION}
"""

    # Combine kernel/memory operations.
    query_union = """
        UNION ALL
"""

    # Add the profiler overhead to the GPU operation table returned by
    # "query_union".
    # 1. CTE "range": Get [min(start), max(end)] for each deviceId/PID. It will be
    #    used as the clipping range for overheads.
    # 2. CTE "cudaoverhead": Select CUDA profiling overhead that we want to take
    #    into account.
    # 3. Duplicate overhead rows for each deviceId/PID. This will create a deviceId
    #    column that is not initially in the PROFILER_OVERHEAD table. i.e., a CUDA
    #    profiling overhead on one thread affects all GPUs of the same process.
    # 4. The overhead rows are combined with GPU operation rows.
    query_overhead = """
    WITH
        gpuops AS (
            {GPU_OPS_ALL}
        ),
        range AS (
            SELECT
                min(start) AS start,
                max(end) AS end,
                pid,
                globalPid,
                deviceId
            FROM
                gpuops
            GROUP BY deviceId, pid
        ),
        cudaoverheadID AS (
            SELECT
                id
            FROM
                StringIds
            WHERE
                value = 'CUDA profiling data flush overhead'
                OR value = 'CUDA profiling stop overhead'
                OR value = 'CUDA profiling overhead'
        ),
        cudaoverhead AS (
            SELECT
                po.start,
                po.end,
                (po.globalTid >> 24) & 0x00FFFFFF AS pid
            FROM
                PROFILER_OVERHEAD AS po
            JOIN
                cudaoverheadID AS co
                ON co.id = po.nameId
        )
    SELECT
        co.start,
        co.end,
        co.pid,
        range.globalPid,
        range.deviceId
    FROM
        cudaoverhead AS co
    JOIN
        range
        ON
                co.pid = range.pid
            AND co.start > range.start
            AND co.end < range.end
    UNION ALL
    SELECT
        *
    FROM
        gpuops
"""

    def setup(self):
        """Parse the report arguments and build ``self.query``.

        Recognized arguments are ``rows=<limit>`` and ``gap=<threshold>``.
        Both values must be non-negative decimal integers; any other argument
        terminates the script with ``self.EXIT_INVALID_ARG``.

        Returns:
            The error returned by the parent ``setup()`` if there is one, an
            error message if none of the CUDA kernel/memcpy/memset tables
            exists, and ``None`` once ``self.query`` has been assembled.
        """
        err = super().setup()
        if err is not None:
            return err

        row_limit = self.ROW_LIMIT
        threshold = self.THRESHOLD
        for arg in self.args:
            name, _, value = arg.partition('=')
            # isdecimal() (unlike isdigit()) only accepts characters that int()
            # can convert, and converting guarantees that what gets spliced
            # into the SQL text is a plain integer literal.
            if value.isdecimal():
                if name == 'rows':
                    row_limit = int(value)
                    continue
                if name == 'gap':
                    threshold = int(value)
                    continue
            sys.exit(self.EXIT_INVALID_ARG)

        gpu_tables = (
            'CUPTI_ACTIVITY_KIND_KERNEL',
            'CUPTI_ACTIVITY_KIND_MEMCPY',
            'CUPTI_ACTIVITY_KIND_MEMSET',
        )
        # Only the tables present in the database take part in the union.
        sub_queries = [
            self.query_select.format(GPU_OPERATION=table)
            for table in gpu_tables
            if self.table_exists(table)
        ]

        if not sub_queries:
            return "{DBFILE} could not be analyzed because it does not contain CUDA trace data."

        ops = self.query_union.join(sub_queries)

        # Merge the profiling overhead ranges in when they were collected, so
        # that idle time spent in the profiler is not reported as a gap.
        if self.table_exists('PROFILER_OVERHEAD'):
            ops = self.query_overhead.format(GPU_OPS_ALL=ops)

        gaps = self.query_gap.format(
            OPS_ALL=ops,
            ROW_LIMIT=row_limit,
            THRESHOLD=threshold)

        self.query = self.query_format_columns.format(GAPS=gaps)

# Script entry point: run the report only when this file is executed directly,
# not when it is imported. Main() is not defined in this file; presumably it is
# inherited from nsysstats.Report -- confirm against that module.
if __name__ == "__main__":
    CUDAGPUStarvation.Main()
