#!/usr/bin/python3
# @lint-avoid-python-3-compatibility-imports
#
# netsessionaudit  Audit TCP/UDP sessions throughput by host.
#           For Linux, uses BCC, eBPF. Embedded C.
#
# USAGE: netsessionaudit [-h] [--conf CONF] [--period [PERIOD]] [--log LOG] [-c]
#
# This uses dynamic tracing of kernel functions, and will need to be updated
# to match kernel changes.
#
# WARNING: This traces all send/receives at the TCP level, and while it
# summarizes data in-kernel to reduce overhead, there may still be some
# overhead at high TCP send/receive rates (eg, ~13% of one CPU at 100k TCP
# events/sec. This is not the same as packet rate: funccount can be used to
# count the kprobes below to find out the TCP rate). Test in a lab environment
# first. If your send/receive rate is low (eg, <1k/sec) then the overhead is
# expected to be negligible.
#
# Copyright 2016 Netflix, Inc.
# Licensed under the Apache License, Version 2.0 (the "License")
#
# 02-Sep-2016   Brendan Gregg   Created this.
# Copyright 2022 KylinSec, Inc.
# 30-Oct-2022   Lugang Modified this.

from __future__ import print_function
from bpfcc import BPF
#from bcc.containers import filter_by_containers
import argparse
from socket import inet_ntop, AF_INET, AF_INET6
from struct import pack
import time
from collections import namedtuple, defaultdict
import os
import signal
import threading
import configparser
import logging
from logging.handlers import RotatingFileHandler
import operator


def tcpstate2str(state):
    """Return the symbolic name of a kernel TCP state number.

    The names follow include/net/tcp_states.h (1 = ESTABLISHED ...
    12 = NEW_SYN_RECV). A number outside that table is returned as its
    ``str()`` form.
    """
    ordered_names = (
        "ESTABLISHED",
        "SYN_SENT",
        "SYN_RECV",
        "FIN_WAIT1",
        "FIN_WAIT2",
        "TIME_WAIT",
        "CLOSE",
        "CLOSE_WAIT",
        "LAST_ACK",
        "LISTEN",
        "CLOSING",
        "NEW_SYN_RECV",
    )
    # Kernel state numbers start at 1.
    by_number = dict(enumerate(ordered_names, start=1))
    return by_number.get(state, str(state))


'''
#define TCP_EVENT_TYPE_CONNECT     1
#define TCP_EVENT_TYPE_ACCEPT      2
#define TCP_EVENT_TYPE_CLOSE       3
#define TCP_EVENT_TYPE_TRANSFER_S  4
#define TCP_EVENT_TYPE_TRANSFER_R  5

#define UDP_EVENT_TYPE_FIRST_SEND  1
#define UDP_EVENT_TYPE_FIRST_RECV  2
#define UPD_EVENT_TYPE_DESTROY     3
#define UPD_EVENT_TYPE_EXIT        4
'''
def tcptype2str(event_type):
    """Return the audit-log label for a TCP event type code.

    Codes 1-5 correspond to the TCP_EVENT_TYPE_* values used by the BPF
    program (connect, accept, close, transfer on send, transfer on receive).
    Any other code is returned as its ``str()`` form.
    """
    labels = (
        "会话开始 CONN",
        "会话开始 ACCE",
        "会话结束 CLOS",
        "会话转移 SEND",
        "会话转移 RECV",
    )
    # Event codes start at 1.
    by_code = dict(enumerate(labels, start=1))
    return by_code.get(event_type, str(event_type))


def udptype2str(event_type):
    """Return the audit-log label for a UDP event type code.

    Codes 1-4 correspond to the UDP event types used by the BPF program
    (first send, first receive, socket destroy, process exit). Any other
    code is returned as its ``str()`` form.
    """
    labels = (
        "会话开始 SEND",
        "会话开始 RECV",
        "会话结束 DEST",
        "会话结束 PROC",
    )
    # Event codes start at 1.
    by_code = dict(enumerate(labels, start=1))
    return by_code.get(event_type, str(event_type))


# arguments
def range_check(string):
    """argparse ``type=`` callback that accepts only integers >= 1.

    Returns the parsed integer. Raises ``argparse.ArgumentTypeError`` for
    values below 1; a non-numeric string raises ``ValueError`` from
    ``int()``.
    """
    parsed = int(string)
    if parsed >= 1:
        return parsed
    raise argparse.ArgumentTypeError(
        "value must be stricly positive, got %d" % (parsed,))


def get_conf_file(args):
    """Pick the configuration file to load.

    Search order: the ``--conf`` argument, then ``netsessionaudit.conf``
    next to this script (symlinks resolved), then
    ``/etc/netsessionaudit.conf``. Returns the chosen path, or ``None`` when
    no candidate exists. An explicit ``--conf`` value is returned without
    checking that the file exists.
    """
    if args.conf:
        return args.conf

    script_dir = os.path.dirname(os.path.realpath(__file__))
    candidates = (
        os.path.join(script_dir, "netsessionaudit.conf"),
        "/etc/netsessionaudit.conf",
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return None


def get_parser_option(config_parser, section, key):
    """Return option ``key`` of ``section`` as a string, or ``None`` if absent."""
    if not config_parser.has_option(section=section, option=key):
        return None
    return config_parser.get(section=section, option=key)


def get_log_file(args, config_parser):
    """Resolve the audit log path.

    Precedence: ``--log`` argument, then ``log_file`` in the ``[log]``
    section of the configuration, then ``/var/log/netsessionaudit.log``.
    """
    return (args.log
            or get_parser_option(config_parser, "log", "log_file")
            or "/var/log/netsessionaudit.log")


def get_log_max(args, config_parser):
    """Return the rotation threshold of the audit log, in bytes.

    The size comes from ``log_size`` (in MiB) in the ``[log]`` section of the
    configuration. Missing, non-numeric or out-of-range values
    (outside ``1 <= size < 4096``) fall back to the 50 MiB default.

    ``args`` is accepted for signature symmetry with the other getters but is
    not consulted: there is no command-line option for the log size. The
    previous code returned ``args.log`` (the log file *path*) here, which
    handed a string to ``RotatingFileHandler(maxBytes=...)`` whenever
    ``--log`` was used.
    """
    DEFAULT_LOG_MAX = 50 * 1024 * 1024

    value = config_parser.get("log", "log_size", fallback=None)
    if not value:
        return DEFAULT_LOG_MAX

    try:
        max_size = int(value)
    except ValueError:
        # A malformed size must not prevent the auditor from starting.
        return DEFAULT_LOG_MAX

    if 1 <= max_size < 4096:
        return max_size * 1024 * 1024
    return DEFAULT_LOG_MAX


def get_period(args, config_parser):
    """Return the audit period in minutes.

    Precedence: ``--period`` argument (already validated by ``range_check``),
    then ``audit_period`` in the ``[global]`` section of the configuration,
    then the 60 minute default.

    The configuration value gets the same ">= 1" validation as the command
    line: a non-numeric, zero or negative value falls back to the default
    instead of crashing or producing a zero period (the period is later
    substituted into the BPF program as a divisor).
    """
    # 1 hour, 60 minutes
    DEFAULT_AUDIT_PERIOD = 60

    if args.period:
        return args.period

    value = config_parser.get("global", "audit_period", fallback=None)
    if value:
        try:
            period = int(value)
        except ValueError:
            period = 0
        if period >= 1:
            return period
    return DEFAULT_AUDIT_PERIOD


# Epilog shown verbatim under --help (RawDescriptionHelpFormatter keeps the
# line breaks).
examples = """examples:
    ./netsessionaudit                 # audit TCP/UDP sessions by host
    ./netsessionaudit -c              # also print to console
    ./netsessionaudit --conf          # assign conf file, the order: --conf, exepath and then /etc/netsessionaudit.conf
    ./netsessionaudit --log filename  # write audit info to log file assigned
"""

# Command-line interface. Every option is optional; values given here take
# precedence over the configuration file (see get_log_file / get_period).
parser = argparse.ArgumentParser(
    description="Audit TCP/UDP sessions throughput by host",
    formatter_class=argparse.RawDescriptionHelpFormatter,
    epilog=examples)
parser.add_argument("--conf",
                    help="audit conf file")
# nargs="?" makes the value optional: a bare "--period" yields None, which
# get_period() treats the same as the option being absent.
parser.add_argument("--period", nargs="?", type=range_check,
                    help="audit period, in minutes")
parser.add_argument("--log",
                    help="save audit records to the log file")
parser.add_argument("-c", "--console", action="store_true",
                    help="output to console as well as logfile")

# Parsed at import time: the rest of the module-level setup reads `args`.
args = parser.parse_args()
# HOSTNAME is a shell variable that is frequently not exported to child
# processes; fall back to the kernel node name so the concatenation below
# never receives None.
hostname = os.getenv('HOSTNAME') or os.uname().nodename

# The hostname is spliced into a %-style logging format string, so a literal
# '%' in it has to be doubled.
log_format = "%(asctime)s " + hostname.replace("%", "%%") + "  %(message)s"
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
handle_format = logging.Formatter(log_format)

# Optional console mirror of the audit records (-c / --console).
if args.console:
    console = logging.StreamHandler()
    console.setLevel(level=logging.INFO)
    console.setFormatter(handle_format)
    logger.addHandler(console)

conf_file = get_conf_file(args)
config = configparser.ConfigParser()
# get_conf_file() returns None when no configuration file exists, and
# ConfigParser.read(None) raises TypeError. In that case keep the empty
# parser so every getter falls back to its built-in default.
if conf_file:
    config.read(conf_file)

# get log file, value in the args first, then in the config file
log_file = get_log_file(args, config)
log_max_bytes = get_log_max(args, config)
# Size-based rotation, keeping five old files next to the active one.
log_file_handle = RotatingFileHandler(filename=log_file, maxBytes=log_max_bytes,
                                      backupCount=5, encoding="utf-8")
log_file_handle.setLevel(logging.INFO)
log_file_handle.setFormatter(handle_format)
logger.addHandler(log_file_handle)

# convert minutes to seconds
audit_period = get_period(args, config) * 60

# define BPF program
bpf_text = """
#include <uapi/linux/ptrace.h>
#include <linux/tcp.h>
#include <net/sock.h>
#include <bcc/proto.h>

// Event type codes carried in ipv4_event_t.type. The Python side maps the
// same numbers to log labels in tcptype2str() / udptype2str().
#define TCP_EVENT_TYPE_CONNECT     1
#define TCP_EVENT_TYPE_ACCEPT      2
#define TCP_EVENT_TYPE_CLOSE       3
#define TCP_EVENT_TYPE_TRANSFER_S  4
#define TCP_EVENT_TYPE_TRANSFER_R  5

// NOTE: the UPD_ prefix of the last two names is a typo for UDP_. It is kept
// because the probes below use these exact identifiers.
#define UDP_EVENT_TYPE_FIRST_SEND  1
#define UDP_EVENT_TYPE_FIRST_RECV  2
#define UPD_EVENT_TYPE_DESTROY     3
#define UPD_EVENT_TYPE_EXIT        4

// Event record pushed to user space. TCP and UDP share this layout.
//   ts_us        bpf_ktime_get_ns() / 1000 at the time of the event
//   tx_b, rx_b   bytes sent / received that the event carries
//   type         one of the *_EVENT_TYPE_* codes above
//   pid          tgid of the reporting process
//   saddr, daddr IPv4 addresses as read from the socket (the process-exit
//                event stores task->pid in saddr instead)
//   sport, dport ports after ntohs()
//   comm         command name
struct ipv4_event_t {
    u64 ts_us;
    u64 tx_b;
    u64 rx_b;
    u32 type;
    u32 pid;
    u32 saddr;
    u32 daddr;
    u16 sport;
    u16 dport;
    char comm[TASK_COMM_LEN];
};

// One perf buffer per protocol; both carry struct ipv4_event_t records.
BPF_PERF_OUTPUT(tcp_ipv4_events);
BPF_PERF_OUTPUT(udp_ipv4_events);

// tcp_set_state doesn't run in the context of the process that initiated the
// connection, so a map TUPLE -> PID is kept to report the right PID in the
// event.
struct ipv4_tuple_t {
    u32 saddr;
    u32 daddr;
    u16 sport;
    u16 dport;
};

// Owner of a connection as last seen by the probes: tgid and command name.
struct pid_comm_t {
    u32 pid;
    char comm[TASK_COMM_LEN];
};

// tuple2pidcomm: connection tuple -> process currently considered its owner.
BPF_HASH(tuple2pidcomm, struct ipv4_tuple_t, struct pid_comm_t);
// *_connect_sock: pid_tgid -> sock pointer stashed by an entry probe so the
// matching return probe can find the socket again.
BPF_HASH(tcp_connect_sock, u64, struct sock *);
BPF_HASH(udp_connect_sock, u64, struct sock *);

// udp_pid records every process seen using UDP, so that process-exit events
// are sent to user space only for those processes.
BPF_HASH(udp_pid, u32);

// udp_history0/1 are used to recognize the first UDP communication between
// two endpoints. UDP send/receive probes look the tuple up in both maps, and
// store it only in the map selected by the current time slot (see
// find_or_set_history) when the lookup in that map fails.
// According to the original note, user-space Python code clears the two maps
// in turn once per audit period.
BPF_HASH(udp_history0, struct ipv4_tuple_t);
BPF_HASH(udp_history1, struct ipv4_tuple_t);

// BPF hashes for the periodic audit.
// Key: connection tuple plus the tgid that performed the I/O.
struct ipv4_pidtuple_t {
    u32 saddr;
    u32 daddr;
    u16 sport;
    u16 dport;
    u32 pid;
};

// Byte counters per (tuple, pid), incremented by the TCP send/receive probes
// and read from user space.
BPF_HASH(tcp_ipv4_send_bytes, struct ipv4_pidtuple_t);
BPF_HASH(tcp_ipv4_recv_bytes, struct ipv4_pidtuple_t);

// Same counters for UDP.
BPF_HASH(udp_ipv4_send_bytes, struct ipv4_pidtuple_t);
BPF_HASH(udp_ipv4_recv_bytes, struct ipv4_pidtuple_t);

// process audit events

// Fill *tuple with the IPv4 4-tuple of skp. Addresses are stored as read
// from the socket; ports are converted with ntohs().
// Returns 1 on success and 0 when any address or port is zero. The tuple is
// written in both cases, so callers must ignore it when 0 is returned.
static int read_ipv4_tuple(struct ipv4_tuple_t *tuple, struct sock *skp)
{
  u32 saddr = skp->__sk_common.skc_rcv_saddr;
  u32 daddr = skp->__sk_common.skc_daddr;
  // the local port is read through the inet_sock view of the same socket
  struct inet_sock *sockp = (struct inet_sock *)skp;
  u16 sport = sockp->inet_sport;
  u16 dport = skp->__sk_common.skc_dport;

  tuple->saddr = saddr;
  tuple->daddr = daddr;
  tuple->sport = ntohs(sport);
  tuple->dport = ntohs(dport);

  // if addresses or ports are 0, ignore
  if (saddr == 0 || daddr == 0 || sport == 0 || dport == 0) {
      return 0;
  }

  return 1;
}

// Copy the four address/port fields of *tp into *pp. pp->pid is not touched;
// callers set it themselves.
static void tuple_to_pidtuple(struct ipv4_pidtuple_t *pp, struct ipv4_tuple_t *tp)
{
    pp->saddr = tp->saddr;
    pp->daddr = tp->daddr;
    pp->sport = tp->sport;
    pp->dport = tp->dport;
}

// Copy the four address/port fields of *tp into the event *ep. All other
// event fields are left as they are.
static void tuple_to_event(struct ipv4_event_t *ep, struct ipv4_tuple_t *tp)
{
    ep->saddr = tp->saddr;
    ep->daddr = tp->daddr;
    ep->sport = tp->sport;
    ep->dport = tp->dport;
}

// Return true when the address family of sk equals expected_family.
static bool check_family(struct sock *sk, u16 expected_family)
{
  u64 zero = 0;   // unused
  u16 family = sk->__sk_common.skc_family;
  return family == expected_family;
}

// Copy exactly TASK_COMM_LEN bytes from src to dst, byte by byte.
// Always returns TASK_COMM_LEN.
static int comm_copy(char *dst, char *src)
{
      int i;
      for (i = 0; i < TASK_COMM_LEN; i++) {
          dst[i] = src[i];
      }
      return TASK_COMM_LEN;
}

// Move the TCP byte counters stored under *ptp into the event: ep->tx_b and
// ep->rx_b receive the counter values (0 when a counter does not exist) and
// the map entries are deleted.
// Returns 1 when at least one of the two counters existed, else 0.
static int move_tcp_txrx_to_event(struct ipv4_event_t *ep, struct ipv4_pidtuple_t *ptp)
{
    int find = 0;
    u64 *pbytes = tcp_ipv4_send_bytes.lookup(ptp);
    if (pbytes) {
        ep->tx_b = *pbytes;
        tcp_ipv4_send_bytes.delete(ptp);
        find = 1;
    } else {
        ep->tx_b = 0;
    }

    pbytes = tcp_ipv4_recv_bytes.lookup(ptp);
    if (pbytes) {
        ep->rx_b = *pbytes;
        tcp_ipv4_recv_bytes.delete(ptp);
        find = 1;
    } else {
        ep->rx_b = 0;
    }

    return find;
}

// UDP counterpart of move_tcp_txrx_to_event: move the UDP byte counters
// stored under *ptp into ep->tx_b / ep->rx_b (0 when absent) and delete the
// map entries. Returns 1 when at least one counter existed, else 0.
// No probe in this program calls it at the moment; the only call site is
// commented out in kretprobe__udp_destroy_sock.
static int move_udp_txrx_to_event(struct ipv4_event_t *ep, struct ipv4_pidtuple_t *ptp)
{
    int find = 0;
    u64 *pbytes = udp_ipv4_send_bytes.lookup(ptp);
    if (pbytes) {
        ep->tx_b = *pbytes;
        udp_ipv4_send_bytes.delete(ptp);
        find = 1;
    } else {
        ep->tx_b = 0;
    }

    pbytes = udp_ipv4_recv_bytes.lookup(ptp);
    if (pbytes) {
        ep->rx_b = *pbytes;
        udp_ipv4_recv_bytes.delete(ptp);
        find = 1;
    } else {
        ep->rx_b = 0;
    }

    return find;
}

// ##AUDIT_PERIOD## should be replaced by a number from the python script.
// Report whether tuple *t has been seen before, and remember it if not.
// The time since boot selects the "current" history map (udp_history0 for
// even slots, udp_history1 for odd ones). If the tuple is missing from the
// current map it is inserted there, and the other map is then consulted.
// Returns 1 when the tuple was present in either map, 0 when it is new.
// NOTE(review): ts_minute is in minutes, while the Python script substitutes
// the period converted to seconds as the divisor - confirm the intended unit.
static int find_or_set_history(struct ipv4_tuple_t *t)
{
    u32 find = 0;
    u64 ts_us = bpf_ktime_get_ns() / 1000;
    u64 ts_minute = ts_us / (60 * 1000000);
    if ((ts_minute / ##AUDIT_PERIOD## ) % 2 == 0) {
        if (udp_history0.lookup(t)) {
            find = 1;
        } else {
            // remember the tuple in the current map, then fall back to the
            // previous slot's map to decide whether it is really new
            udp_history0.update(t, &ts_us);

            if (udp_history1.lookup(t)) find = 1;
        }
    } else {
        if (udp_history1.lookup(t)) {
            find = 1;
        } else {
            udp_history1.update(t, &ts_us);

            if (udp_history0.lookup(t)) find = 1;
        }
    }

    return find;
}
// Entry probe of tcp_v4_connect(): remember the socket under the current
// pid_tgid so the return probe can inspect it.
int kprobe__tcp_v4_connect(struct pt_regs *ctx, struct sock *sk)
{
    u64 pid = bpf_get_current_pid_tgid();

    // stash the sock ptr for lookup on return
    tcp_connect_sock.update(&pid, &sk);

    return 0;
}

// Return probe of tcp_v4_connect(): on success, record which process (tgid
// and comm) owns the new connection tuple in tuple2pidcomm. No event is sent
// here; kprobe__tcp_set_state reports the connection once it is established.
int kretprobe__tcp_v4_connect(struct pt_regs *ctx)
{
    int ret = PT_REGS_RC(ctx);
    u64 pid = bpf_get_current_pid_tgid();

    struct sock **skpp;
    skpp = tcp_connect_sock.lookup(&pid);
    if (skpp == 0) {
        return 0;       // missed entry
    }

    tcp_connect_sock.delete(&pid);

    if (ret != 0) {
        // failed to send SYN packet, may not have populated
        // socket __sk_common.{skc_rcv_saddr, ...}
        return 0;
    }

    // pull in details
    struct sock *skp = *skpp;
    struct ipv4_tuple_t t = { };
    if (!read_ipv4_tuple(&t, skp)) {
        return 0;
    }

    struct pid_comm_t p = { };
    p.pid = pid >> 32;      // upper 32 bits of pid_tgid hold the tgid
    bpf_get_current_comm(&p.comm, sizeof(p.comm));

    tuple2pidcomm.update(&t, &p);

    return 0;
}

// Probe of tcp_set_state(), handling only two target states:
//  - TCP_CLOSE: drop the tuple2pidcomm entry of the connection;
//  - TCP_ESTABLISHED: if the tuple has an owner recorded in tuple2pidcomm,
//    emit a TCP_EVENT_TYPE_CONNECT event with that owner's pid and comm.
// Only AF_INET sockets with a complete 4-tuple are considered.
int kprobe__tcp_set_state(struct pt_regs *ctx, struct sock *sk, int state)
{
    if (state != TCP_ESTABLISHED && state != TCP_CLOSE) {
        return 0;
    }

    struct ipv4_tuple_t t = { 0 };
    if (check_family(sk, AF_INET) && read_ipv4_tuple(&t, sk)) {
        if (state == TCP_CLOSE) {
            // tuple2pidcomm entries are kept until the socket reaches
            // TCP_CLOSE so that transfers to another pid can be detected
            tuple2pidcomm.delete(&t);
            return 0;
        }

        // process TCP_ESTABLISHED:
        struct pid_comm_t *p;
        p = tuple2pidcomm.lookup(&t);
        if (p == 0) {
            return 0;       // missed entry
        }

        struct ipv4_event_t evt4 = { 0 };
        evt4.ts_us = bpf_ktime_get_ns() / 1000;
        evt4.tx_b = 0;
        evt4.rx_b = 0;
        evt4.type = TCP_EVENT_TYPE_CONNECT;
        evt4.pid = p->pid;
        tuple_to_event(&evt4, &t);
        // use the recorded owner: this probe may run outside the context of
        // the process that initiated the connection
        comm_copy(evt4.comm, p->comm);

        tcp_ipv4_events.perf_submit(ctx, &evt4, sizeof(evt4));
    }

    return 0;
}

// Probe of tcp_close(): emit a TCP_EVENT_TYPE_CLOSE event for established
// AF_INET connections. The event carries the byte counters accumulated
// under the last recorded owner of the tuple (or the current pid when no
// owner is recorded); those counters are removed from the maps. The event's
// pid and comm are those of the closing process.
int kprobe__tcp_close(struct pt_regs *ctx, struct sock *sk)
{
    u32 pid = bpf_get_current_pid_tgid() >> 32;

    u8 oldstate = sk->sk_state;
    // Don't generate close events for connections that were never
    // established in the first place.
    if (oldstate == TCP_SYN_SENT ||
        oldstate == TCP_SYN_RECV ||
        oldstate == TCP_NEW_SYN_RECV)
        return 0;

    struct ipv4_tuple_t t = { 0 };
    if (check_family(sk, AF_INET) && read_ipv4_tuple(&t, sk)) {
        // check pid to see if it was transferred to others such as a child process.
        struct pid_comm_t *pp = tuple2pidcomm.lookup(&t);
        u32 oldpid = pid;
        if (pp) {
            oldpid = pp->pid;
        }

        // counters are keyed by the recorded owner, not by the closer
        struct ipv4_pidtuple_t pt = { 0 };
        tuple_to_pidtuple(&pt, &t);
        pt.pid =  oldpid;

        struct ipv4_event_t evt4 = { 0 };
        // get tx/rx before close action
        move_tcp_txrx_to_event(&evt4, &pt);

        evt4.ts_us = bpf_ktime_get_ns() / 1000;
        evt4.type = TCP_EVENT_TYPE_CLOSE;
        evt4.pid = pid;
        tuple_to_event(&evt4, &t);
        bpf_get_current_comm(&evt4.comm, sizeof(evt4.comm));

        tcp_ipv4_events.perf_submit(ctx, &evt4, sizeof(evt4));
    }

    return 0;
};

// Return probe of inet_csk_accept(): for an accepted AF_INET socket, emit a
// TCP_EVENT_TYPE_ACCEPT event and record the accepting process as the owner
// of the tuple in tuple2pidcomm.
int kretprobe__inet_csk_accept(struct pt_regs *ctx)
{
    struct sock *newsk = (struct sock *)PT_REGS_RC(ctx);
    u32 pid = bpf_get_current_pid_tgid() >> 32;

    // accept returned no socket
    if (newsk == NULL) {
        return 0;
    }

    // pull in details
    struct ipv4_tuple_t t = {};
    if (check_family(newsk, AF_INET) && read_ipv4_tuple(&t, newsk)) {
        struct ipv4_event_t evt4 = { 0 };
        evt4.ts_us = bpf_ktime_get_ns() / 1000;
        evt4.rx_b = 0;
        evt4.tx_b = 0;
        evt4.type = TCP_EVENT_TYPE_ACCEPT;
        evt4.pid = pid;
        tuple_to_event(&evt4, &t);
        bpf_get_current_comm(&evt4.comm, sizeof(evt4.comm));

        tcp_ipv4_events.perf_submit(ctx, &evt4, sizeof(evt4));

        // record pid for tracing, the sock may be transferred to others such as a child process.
        struct pid_comm_t p = { };
        p.pid = pid;
        comm_copy(p.comm, evt4.comm);

        tuple2pidcomm.update(&t, &p);
    }

  return 0;
}

// next process period audit
// Probe of tcp_sendmsg(): add the requested size to the send counter of
// (tuple, current pid). If tuple2pidcomm records a different owner for the
// tuple, the connection is treated as transferred to the current process:
// a TCP_EVENT_TYPE_TRANSFER_S event is emitted carrying the previous owner's
// counters (which are removed from the maps), and the current process
// becomes the recorded owner.
int kprobe__tcp_sendmsg(struct pt_regs *ctx, struct sock *sk,
    struct msghdr *msg, size_t size)
{
    u32 pid = bpf_get_current_pid_tgid() >> 32;
    struct ipv4_tuple_t t = {};

    if (check_family(sk, AF_INET) && read_ipv4_tuple(&t, sk)) {
        struct ipv4_pidtuple_t key = {.pid = pid};
        tuple_to_pidtuple(&key, &t);
        tcp_ipv4_send_bytes.increment(key, size);

        // check pid to see if it was transferred to others such as a child process.
        struct pid_comm_t *pp = tuple2pidcomm.lookup(&t);
        if (!pp) return 0;

        u32 oldpid = pp->pid;
        if (oldpid == pid) return 0;


        // the connection has moved to a new pid: generate an event
        struct ipv4_pidtuple_t oldkey = { };
        oldkey.pid = oldpid;
        tuple_to_pidtuple(&oldkey, &t);

        struct ipv4_event_t evt4 = { 0 };
        move_tcp_txrx_to_event(&evt4, &oldkey);

        evt4.ts_us = bpf_ktime_get_ns() / 1000;
        evt4.type = TCP_EVENT_TYPE_TRANSFER_S;
        evt4.pid = pid;
        tuple_to_event(&evt4, &t);
        bpf_get_current_comm(&evt4.comm, sizeof(evt4.comm));

        tcp_ipv4_events.perf_submit(ctx, &evt4, sizeof(evt4));

        // refresh tuple track
        struct pid_comm_t newpc = {0};
        newpc.pid = pid;
        comm_copy(newpc.comm, evt4.comm);
        tuple2pidcomm.update(&t, &newpc);
    }

    return 0;
}

/*
 * tcp_recvmsg() would be obvious to trace, but is less suitable because:
 * - we'd need to trace both entry and return, to have both sock and size
 * - misses tcp_read_sock() traffic
 * we'd much prefer tracepoints once they are available.
 *
 * Receive-side counterpart of kprobe__tcp_sendmsg: add `copied` to the
 * receive counter of (tuple, current pid) and, when tuple2pidcomm records a
 * different owner, emit a TCP_EVENT_TYPE_TRANSFER_R event with the previous
 * owner's counters and make the current process the recorded owner.
 */
int kprobe__tcp_cleanup_rbuf(struct pt_regs *ctx, struct sock *sk, int copied)
{
    u32 pid = bpf_get_current_pid_tgid() >> 32;

    // dport, family, val and zero are not used below
    u16 dport = 0, family = sk->__sk_common.skc_family;
    u64 *val, zero = 0;

    // nothing was copied to the reader
    if (copied <= 0)
        return 0;

    struct ipv4_tuple_t t = {};
    if (check_family(sk, AF_INET) && read_ipv4_tuple(&t, sk)) {
        struct ipv4_pidtuple_t key = {.pid = pid};
        tuple_to_pidtuple(&key, &t);
        tcp_ipv4_recv_bytes.increment(key, copied);

        // check pid to see if it was transferred to others such as a child process.
        struct pid_comm_t *pp = tuple2pidcomm.lookup(&t);
        if (!pp) return 0;

        u32 oldpid = pp->pid;
        if (oldpid == pid) return 0;

        // the connection has moved to a new pid: generate an event with the old tx/rx
        struct ipv4_pidtuple_t oldkey = { };
        oldkey.pid = oldpid;
        tuple_to_pidtuple(&oldkey, &t);

        struct ipv4_event_t evt4 = { 0 };
        move_tcp_txrx_to_event(&evt4, &oldkey);

        evt4.ts_us = bpf_ktime_get_ns() / 1000;
        evt4.type = TCP_EVENT_TYPE_TRANSFER_R;
        evt4.pid = pid;
        tuple_to_event(&evt4, &t);
        bpf_get_current_comm(&evt4.comm, sizeof(evt4.comm));

        tcp_ipv4_events.perf_submit(ctx, &evt4, sizeof(evt4));

        // refresh tuple track
        struct pid_comm_t newpc = {0};
        newpc.pid = pid;
        comm_copy(newpc.comm, evt4.comm);
        tuple2pidcomm.update(&t, &newpc);
    }

    return 0;
}

// for udp
// Probe of udp_sendmsg(). For AF_INET sockets with a complete 4-tuple:
//  - register the process in udp_pid;
//  - if the tuple is already known (find_or_set_history), add len to the
//    send counter of (tuple, pid);
//  - otherwise emit a UDP_EVENT_TYPE_FIRST_SEND event that carries len in
//    tx_b (the counter map is not updated for this first datagram).
int kprobe__udp_sendmsg(struct pt_regs *ctx, struct sock *sk, struct msghdr *msg, size_t len)
{
    u64 tid = bpf_get_current_pid_tgid();
    u32 pid = tid >> 32;

    struct ipv4_tuple_t t = {};
    if (check_family(sk, AF_INET) && read_ipv4_tuple(&t, sk)) {

        if (!udp_pid.lookup(&pid)) udp_pid.update(&pid, &tid);

        if (find_or_set_history(&t)) {
            struct ipv4_pidtuple_t key = {.pid = pid};
            tuple_to_pidtuple(&key, &t);
            udp_ipv4_send_bytes.increment(key, len);
        } else {
            struct ipv4_event_t evt4 = { 0 };
            evt4.ts_us = bpf_ktime_get_ns() / 1000;
            evt4.rx_b = 0;
            evt4.tx_b = len;
            evt4.type = UDP_EVENT_TYPE_FIRST_SEND;
            evt4.pid = pid;
            tuple_to_event(&evt4, &t);
            bpf_get_current_comm(&evt4.comm, sizeof(evt4.comm));

            udp_ipv4_events.perf_submit(ctx, &evt4, sizeof(evt4));
        }
    }

    return 0;
}

// Entry probe of udp_recvmsg(): stash the socket under the current pid_tgid
// for the return probe and register the process in udp_pid.
// NOTE(review): the parameter list mirrors one particular kernel signature
// of udp_recvmsg(); only sk is used here.
int kprobe__udp_recvmsg(struct pt_regs *ctx, struct sock *sk, struct msghdr *msg, size_t len, int noblock,
        int flags, int *addr_len)
{
    u64 tid = bpf_get_current_pid_tgid();
    // stash the sock ptr for lookup on return
    udp_connect_sock.update(&tid, &sk);
    u32 pid = tid >> 32;
    if (!udp_pid.lookup(&pid)) udp_pid.update(&pid, &tid);

    return 0;
}

// Return probe of udp_recvmsg(): account the bytes actually received.
// For AF_INET sockets with a complete 4-tuple, either add the return value
// to the receive counter of (tuple, pid) when the tuple is already known,
// or emit a UDP_EVENT_TYPE_FIRST_RECV event carrying it in rx_b.
int kretprobe__udp_recvmsg(struct pt_regs *ctx)
{
    u64 tid = bpf_get_current_pid_tgid();
    struct sock **skpp;
    skpp = udp_connect_sock.lookup(&tid);
    if (skpp == 0) return 0;    // missed entry

    struct sock *sk = *skpp;
    udp_connect_sock.delete(&tid);

    // the return value is the number of bytes received, or an error
    int copied = PT_REGS_RC(ctx);
    if (copied <= 0)
        return 0;

    u32 pid = tid >> 32;

    struct ipv4_tuple_t t = {};
    if (check_family(sk, AF_INET) && read_ipv4_tuple(&t, sk)) {
        if (find_or_set_history(&t)) {
            struct ipv4_pidtuple_t key = {.pid = pid};
            tuple_to_pidtuple(&key, &t);
            udp_ipv4_recv_bytes.increment(key, copied);
        } else {
            struct ipv4_event_t evt4 = { 0 };
            evt4.ts_us = bpf_ktime_get_ns() / 1000;
            evt4.rx_b = copied;
            evt4.tx_b = 0;
            evt4.type = UDP_EVENT_TYPE_FIRST_RECV;
            evt4.pid = pid;
            tuple_to_event(&evt4, &t);
            bpf_get_current_comm(&evt4.comm, sizeof(evt4.comm));

            udp_ipv4_events.perf_submit(ctx, &evt4, sizeof(evt4));
        }
    }

    return 0;
}

// The Python script should be notified that the socket is gone, but there
// are concurrency problems between map.increment in the receive path and
// map.lookup in destroy. If user space can read the up-to-date counters
// while processing the event, the Python code can take care of it instead.
//
// Entry probe of udp_destroy_sock(): stash the socket under the current
// pid_tgid for the return probe and make sure the process is registered in
// udp_pid.
int kprobe__udp_destroy_sock(struct pt_regs *ctx, struct sock *sk)
{
    u64 pid = bpf_get_current_pid_tgid();
    // stash the sock ptr for lookup on return
    udp_connect_sock.update(&pid, &sk);
    u32 pid32 = pid >> 32;
    u64 *p = udp_pid.lookup(&pid32);
    if ( p == 0)
        udp_pid.insert(&pid32, &pid);
    return 0;
}

// Return probe of udp_destroy_sock(): removes the socket stashed by the
// entry probe. Emission of the UPD_EVENT_TYPE_DESTROY event is currently
// disabled by the constant-false condition below, so no event is sent.
int kretprobe__udp_destroy_sock(struct pt_regs *ctx) //, struct sock *sk)
{
    u64 pid = bpf_get_current_pid_tgid();
    struct sock **skpp;
    skpp = udp_connect_sock.lookup(&pid);
    if (skpp == 0) return 0;    // missed entry

    struct sock *sk = *skpp;
    udp_connect_sock.delete(&pid);

    pid = pid >> 32;

    struct ipv4_tuple_t t = {};
    // disabled: the original condition is kept in the next comment line
    //if (check_family(sk, AF_INET) && read_ipv4_tuple(&t,sk)) {
    if (0) {
        struct ipv4_pidtuple_t key = {.pid = pid};
        struct ipv4_event_t evt4 = { 0 };
        evt4.ts_us = bpf_ktime_get_ns() / 1000;
        evt4.rx_b = 0;
        evt4.tx_b = 0;
        evt4.type = UPD_EVENT_TYPE_DESTROY;
        evt4.pid = pid;
        tuple_to_event(&evt4, &t);
        // move_udp_txrx_to_event(&evt4, &key);  let python care about it.
        bpf_get_current_comm(&evt4.comm, sizeof(evt4.comm));

        udp_ipv4_events.perf_submit(ctx, &evt4, sizeof(evt4));
    }

    return 0;
}

// When a process exits it is better to audit it immediately.
// Tracepoint sched:sched_process_exit: when the exiting task is a thread
// group leader that is registered in udp_pid, remove it from udp_pid and
// emit a UPD_EVENT_TYPE_EXIT event. The event has no connection tuple;
// saddr carries task->pid instead.
TRACEPOINT_PROBE(sched, sched_process_exit)
{
    struct task_struct *task = (typeof(task))bpf_get_current_task();
    // to see if it is main thread
    if (task->tgid != task->pid) return 0;
    u32 pid = task->tgid;

    if (udp_pid.lookup(&pid)) {
        udp_pid.delete(&pid);

        struct ipv4_event_t evt4 = { 0 };
        evt4.ts_us = bpf_ktime_get_ns() / 1000;
        evt4.type = UPD_EVENT_TYPE_EXIT;
        evt4.saddr = task->pid;
        evt4.pid = pid;
        bpf_get_current_comm(&evt4.comm, sizeof(evt4.comm));

        udp_ipv4_events.perf_submit(args, &evt4, sizeof(evt4));
        return 0;
    }

    return 0;
}

"""

# TCPSessionKey = namedtuple('TCPSession', ['pid', 'laddr', 'lport', 'daddr', 'dport'])
# TCPSessionKey = namedtuple('TCPSession', ['laddr', 'lport', 'daddr', 'dport'])
# User-space session identity: the connection 4-tuple only. The pid is
# deliberately not part of the key.
SessionKey = namedtuple('SessionKey', ['laddr', 'lport', 'daddr', 'dport'])

# Replace the AUDIT_PERIOD placeholder in bpf_text.
# NOTE(review): audit_period is in seconds here, while the BPF code divides a
# value expressed in minutes by it (see find_or_set_history) - confirm the
# intended unit.
bpf_text = bpf_text.replace('##AUDIT_PERIOD##', str(audit_period))

# initialize BPF
b = BPF(text=bpf_text)

# Handles to the in-kernel maps that user space reads.
tcp_ipv4_send_bytes = b["tcp_ipv4_send_bytes"]
tcp_ipv4_recv_bytes = b["tcp_ipv4_recv_bytes"]
udp_ipv4_send_bytes = b["udp_ipv4_send_bytes"]
udp_ipv4_recv_bytes = b["udp_ipv4_recv_bytes"]
udp_history0 = b["udp_history0"]
udp_history1 = b["udp_history1"]

# ipv6_send_bytes = b["ipv6_send_bytes"]
# ipv6_recv_bytes = b["ipv6_recv_bytes"]

# Record layout: [system time, hostname, audit type (extended), process name,
# PID, protocol, local IP:port, remote IP:port, bytes received, bytes sent]
# System time and hostname are provided by the logger format.
audit_format_string = "%-10s %-15s %5d %-3s %-21s %-21s %12d %12d"
# udp_destroy_format_string = "%-10s %-15s %5d %-3s %-21s %-21s"

# period audit should print accumulated tx/rx
# close events also need them
# Values are [tx_bytes, rx_bytes, pid], keyed by SessionKey.
g_tcp_ipv4_throughput = defaultdict(lambda: [0, 0, 0])

# todo: global udp store should be cleaned in a certain way
g_udp_ipv4_throughput = defaultdict(lambda: [0, 0, 0])

# go through the udp hash map when we get a sock destroy event
# those items should not be processed again at next period
g_udp_key_processed = set()


class PidCommHelper(object):
    """Cache of pid -> command name, read from ``/proc/<pid>/comm``.

    The periodic audit may run after a process has exited, when its
    ``/proc`` entry is gone; remembering the name keeps it available.
    ``store`` maps ``pid -> [comm, last_access_time]``; entries are meant to
    be dropped with ``clear_expired`` once they have not been accessed for a
    while (the original note mentions two audit periods).
    """

    def __init__(self):
        self.store = {}

    def get_comm(self, pid):
        """Return ``(name, found)`` for ``pid``.

        ``found`` is True when the name came from the cache or from /proc.
        When /proc cannot be read the result is ``(str(pid), False)`` and
        nothing is cached. Access times come from ``sys_up_time()``, which is
        defined elsewhere in this file.
        """
        access_time = sys_up_time()
        cached = self.store.get(pid)
        if cached is not None:
            # Refresh the access time so live entries do not expire.
            cached[1] = access_time
            return cached[0], True

        try:
            # `with` closes the handle promptly instead of leaking it until
            # garbage collection; errors="replace" keeps a comm containing
            # undecodable bytes from raising UnicodeDecodeError.
            with open("/proc/%d/comm" % pid, "r", errors="replace") as comm_file:
                comm = comm_file.read().rstrip()
        except IOError:
            return str(pid), False

        self.store[pid] = [comm, access_time]
        return comm, True

    def clear_expired(self, now, interval):
        """Drop entries whose last access is more than ``interval`` before ``now``."""
        # Collect first: a dict must not change size while it is iterated.
        expired = [pid for pid, (_comm, past) in self.store.items()
                   if (now - past) > interval]
        for pid in expired:
            del self.store[pid]


# Module-wide pid -> command-name cache shared by the event callbacks and
# the periodic audit.
pid_comm_helper = PidCommHelper()


# process event
def audit_tcp_ipv4_event(cpu, data, size):
    """Perf-buffer callback: log one TCP session event.

    The command name comes from the pid cache, falling back to the ``comm``
    captured in the event. Close (3) and transfer (4, 5) events add the
    bytes already accumulated for the session in ``g_tcp_ipv4_throughput``
    and remove that entry; transfer events then store the new totals again
    under the reporting pid.
    """
    event = b["tcp_ipv4_events"].event(data)

    name, cached = pid_comm_helper.get_comm(event.pid)
    if not cached:
        name = event.comm.decode('utf-8', 'replace')

    sent, received = event.tx_b, event.rx_b
    session = get_ipv4_session_key(event)
    is_transfer = event.type in (4, 5)

    # Close and transfer events carry only the bytes of the last period;
    # merge them with what the periodic audit has accumulated so far.
    if event.type == 3 or is_transfer:
        previous = g_tcp_ipv4_throughput.pop(session, None)
        if previous:
            sent += previous[0]
            received += previous[1]

    # A transferred session keeps accumulating under its new owner.
    if is_transfer:
        g_tcp_ipv4_throughput[session] = [sent, received, event.pid]

    local_end = session.laddr + ":" + str(event.sport)
    remote_end = session.daddr + ":" + str(event.dport)
    record = (tcptype2str(event.type), name, event.pid, "TCP",
              local_end, remote_end, received, sent)
    logger.info(audit_format_string % record)


def audit_udp_ipv4_event(cpu, data, size):
    """Perf-buffer callback: log one UDP session event.

    Event types (see the UDP event defines in the BPF program):

    * 1 / 2 - first send / first receive on a tuple: start the accumulated
      record of the session with the bytes carried by the event and log the
      session start. Type 2 used to be ignored, so sessions opened by an
      incoming datagram were never logged and their first bytes were lost.
    * 3 - socket destroyed: log the accumulated bytes plus whatever is still
      in the kernel counters for the tuple. Those tuples are added to
      ``g_udp_key_processed`` so the next periodic audit can skip them.
    * 4 - process exit: log every UDP session of the exiting pid.
    """
    event = b["udp_ipv4_events"].event(data)
    key = get_ipv4_session_key(event)
    command, found = pid_comm_helper.get_comm(event.pid)
    if not found:
        command = event.comm.decode('utf-8', 'replace')

    tx_b = event.tx_b
    rx_b = event.rx_b

    if event.type in (1, 2):
        # session start: save to global store
        g_udp_ipv4_throughput[key] = [tx_b, rx_b, event.pid]
        audit_msg = (audit_format_string % (udptype2str(event.type), command,
                                            event.pid, "UDP",
                                            key.laddr+":"+str(event.sport), key.daddr+":"+str(event.dport),
                                            rx_b, tx_b))
        logger.info(audit_msg)

    elif event.type == 3:
        # todo: the bpf code cannot bring out the last rx/tx of the map
        #       because of a concurrency condition, so the counters are read
        #       here, remembered, and omitted in the next period audit.
        logger.debug("udp sock destroy event: %s:%-5d %s:%-5d",
                     key.laddr, key.lport, key.daddr, key.dport)
        tx_rx_bytes = g_udp_ipv4_throughput.pop(key, None)
        if tx_rx_bytes:
            tx_b += tx_rx_bytes[0]
            rx_b += tx_rx_bytes[1]

        # Scan every counter entry. The previous unconditional `break` looked
        # at the first item only. Session keys omit the pid, so several map
        # entries can belong to the same session; all of them are summed
        # because marking the key as processed hides them from the next
        # periodic audit.
        for k, v in udp_ipv4_send_bytes.items():
            if get_ipv4_session_key(k) == key:
                logger.debug("udp send item: pid %d, %d bytes", k.pid, v.value)
                tx_b += v.value
                g_udp_key_processed.add(key)

        for k, v in udp_ipv4_recv_bytes.items():
            if get_ipv4_session_key(k) == key:
                logger.debug("udp recv item: pid %d, %d bytes", k.pid, v.value)
                rx_b += v.value
                g_udp_key_processed.add(key)

        audit_msg = (audit_format_string % (udptype2str(event.type), command,
                                            event.pid, "UDP",
                                            key.laddr+":"+str(event.sport), key.daddr+":"+str(event.dport),
                                            rx_b, tx_b))
        logger.info(audit_msg)

    elif event.type == 4:
        # process exit: output all udp records of this pid right away
        keys = set()
        # slot 0 of a throughput record is tx, slot 1 is rx
        for counters, slot in ((udp_ipv4_send_bytes, 0), (udp_ipv4_recv_bytes, 1)):
            for k, v in counters.items():
                if k.pid != event.pid:
                    continue
                session = get_ipv4_session_key(k)
                keys.add(session)
                totals = g_udp_ipv4_throughput[session]
                totals[slot] += v.value
                totals[2] = k.pid
                g_udp_key_processed.add(session)

        # output, busiest session first
        records = sorted(((session, g_udp_ipv4_throughput[session]) for session in keys),
                         key=lambda kv: (kv[1][0] + kv[1][1]),
                         reverse=True)
        for k, (tx_b, rx_b, pid) in records:
            # Every record belongs to event.pid, so reuse `command`: it falls
            # back to the comm captured in the event when /proc is already
            # gone, where another cache lookup would yield the pid string.
            audit_msg = (audit_format_string % (udptype2str(event.type), command,
                                                pid, "UDP", k.laddr + ":" + str(k.lport), k.daddr + ":" + str(k.dport),
                                                rx_b, tx_b))
            logger.info(audit_msg)


def audit_ipv6_event(cpu, data, size):
    """Placeholder perf-buffer callback for IPv6 events; it does nothing."""
    return None


def get_ipv4_session_key(k):
    """Build the SessionKey for an IPv4 BPF map key.

    The pid carried by ``k`` is deliberately left out, so the resulting key
    identifies a session only by (laddr, lport, daddr, dport).

    k: object exposing ``saddr``, ``sport``, ``daddr`` and ``dport``; the
       addresses are packed as native unsigned ints before being rendered
       as dotted-quad text.
    """
    local_ip = inet_ntop(AF_INET, pack("I", k.saddr))
    remote_ip = inet_ntop(AF_INET, pack("I", k.daddr))
    return SessionKey(laddr=local_ip,
                      lport=k.sport,
                      daddr=remote_ip,
                      dport=k.dport)


def get_ipv6_session_key(k):
    """Build the SessionKey for an IPv6 BPF map key.

    As with the IPv4 variant, the pid is not part of the key. ``k.saddr`` and
    ``k.daddr`` are handed to ``inet_ntop`` unchanged.
    """
    local_ip = inet_ntop(AF_INET6, k.saddr)
    remote_ip = inet_ntop(AF_INET6, k.daddr)
    return SessionKey(laddr=local_ip,
                      lport=k.sport,
                      daddr=remote_ip,
                      dport=k.dport)


def _collect_ipv4_bytes(bpf_map, totals, slot, seen, skip=()):
    """Fold one BPF byte-counter map into ``totals`` and then empty the map.

    bpf_map: BPF table whose keys carry pid/saddr/sport/daddr/dport and whose
             values expose the byte count as ``.value``.
    totals:  mapping of session key -> [tx_bytes, rx_bytes, pid], updated in
             place (assumed to create missing entries on access, like a
             defaultdict -- TODO confirm against its definition).
    slot:    0 to add to the tx counter, 1 to add to the rx counter.
    seen:    set that receives every session key touched by this call.
    skip:    session keys to ignore; their counters are discarded when the
             map is cleared.
    """
    for k, v in bpf_map.items():
        key = get_ipv4_session_key(k)
        if key in skip:
            continue
        seen.add(key)
        totals[key][slot] += v.value
        totals[key][2] = k.pid
    bpf_map.clear()


def _log_ipv4_sessions(type_prompt, proto, totals, seen):
    """Log one audit line per session in ``seen``, busiest (tx + rx) first."""
    rows = sorted(((key, totals[key]) for key in seen),
                  key=lambda kv: kv[1][0] + kv[1][1],
                  reverse=True)
    for k, (tx_b, rx_b, pid) in rows:
        # The format string takes rx before tx.
        audit_msg = (audit_format_string % (type_prompt, pid_comm_helper.get_comm(pid)[0],
                                            pid, proto,
                                            k.laddr + ":" + str(k.lport),
                                            k.daddr + ":" + str(k.dport),
                                            rx_b, tx_b))
        logger.info(audit_msg)


def audit_func(type_prompt="周期审计 PERI"):
    """Drain the IPv4 TCP/UDP byte-counter maps and log per-session totals.

    For each protocol the send and receive BPF maps are read and cleared,
    their counts are added to the long-lived ``g_*_ipv4_throughput`` totals,
    and one line is logged for every session that showed traffic in this
    pass. The logged numbers are the accumulated totals, not per-pass deltas.

    type_prompt: label placed in the first column of every logged line.
    Returns None.
    """
    # IPv4 TCP. Both directions feed ``tcp_seen`` so that a session which only
    # received data during this pass is still reported.
    tcp_seen = set()
    _collect_ipv4_bytes(tcp_ipv4_send_bytes, g_tcp_ipv4_throughput, 0, tcp_seen)
    _collect_ipv4_bytes(tcp_ipv4_recv_bytes, g_tcp_ipv4_throughput, 1, tcp_seen)
    _log_ipv4_sessions(type_prompt, "TCP", g_tcp_ipv4_throughput, tcp_seen)

    # IPv4 UDP. Sessions already reported by the event callbacks since the
    # last pass are listed in g_udp_key_processed and skipped here.
    udp_seen = set()
    _collect_ipv4_bytes(udp_ipv4_send_bytes, g_udp_ipv4_throughput, 0, udp_seen,
                        skip=g_udp_key_processed)
    _collect_ipv4_bytes(udp_ipv4_recv_bytes, g_udp_ipv4_throughput, 1, udp_seen,
                        skip=g_udp_key_processed)
    # The skip list only applies to a single pass.
    g_udp_key_processed.clear()
    _log_ipv4_sessions(type_prompt, "UDP", g_udp_ipv4_throughput, udp_seen)

    return


def sys_up_time():
    """Return the first field of /proc/uptime (seconds since boot) as an int.

    The fractional part is truncated.
    """
    with open('/proc/uptime', 'r') as uptime_file:
        first_line = uptime_file.readline()
    uptime_text = first_line.split()[0]
    return int(float(uptime_text))


# period: in seconds.
class PeriodicAudit(object):
    """Self-rescheduling periodic audit driver.

    Constructing an instance runs the first audit immediately. Every audit
    then arms a ``threading.Timer`` for the next one, until ``stop()`` has
    been called.
    """

    def __init__(self, period):
        # period: interval between audits, in seconds
        self.period = period
        self.done = False
        self.next_audit = time.time()
        self.audit()

    def audit(self):
        """Run one audit pass, rotate the UDP history maps, re-arm the timer."""
        # Advance the deadline before doing any work so the schedule stays
        # anchored to the start time instead of drifting with pass duration.
        self.next_audit += self.period
        if self.done:
            return

        audit_func()

        # The two udp history maps are cleared alternately: which one is
        # chosen flips every 2 * period seconds of system uptime.
        uptime = sys_up_time()
        window = 2 * self.period
        phase = int(uptime / window)

        if phase % 2:
            udp_history0.clear()
            pid_comm_helper.clear_expired(uptime, window)
        else:
            udp_history1.clear()

        delay = self.next_audit - time.time()
        next_timer = threading.Timer(delay, self.audit)
        next_timer.start()

    def stop(self):
        """Prevent further audit passes; an already armed timer becomes a no-op."""
        self.done = True


# read events
# Attach the Python callbacks to the two IPv4 perf buffers (page_cnt=64 each).
b["tcp_ipv4_events"].open_perf_buffer(audit_tcp_ipv4_event, page_cnt=64)
b["udp_ipv4_events"].open_perf_buffer(audit_udp_ipv4_event, page_cnt=64)
# IPv6 event delivery is currently disabled:
# b["ipv6_events"].open_perf_buffer(audit_ipv6_event, page_cnt=64)

if args.console:
    print('Auditing... Output every %d secs. Hit Ctrl-C to end' % audit_period)

# period events
# Constructing PeriodicAudit runs the first audit right away and schedules
# the following ones every audit_period seconds.
auditor = PeriodicAudit(audit_period)

# signal events
# Loop flag for the main poll loop below; set to False by signal_handle().
audit_sessions = True


def signal_handle(signum, frame):
    """Signal handler that asks the main polling loop to finish.

    signum, frame: standard handler arguments; neither is used.
    """
    # Only flip the flag here; the loop itself notices it and exits.
    global audit_sessions
    audit_sessions = False


# SIGINT for ctrl+c, SIGTERM for systemctl stop service
signal.signal(signal.SIGINT, signal_handle)
signal.signal(signal.SIGTERM, signal_handle)

# audit sessions, until catch a SIGINT or SIGTERM
while audit_sessions:
    b.perf_buffer_poll()

# Stop re-arming the periodic audit, then emit one final report.
auditor.stop()
audit_func(type_prompt="审计退出 EXIT")

if args.console:
    # os._exit() below does not flush stdio buffers, so flush explicitly;
    # otherwise this message is lost when stdout is a file or a pipe.
    print("The auditor stopped.", flush=True)

# os._exit() runs no cleanup handlers either: flush and close the logging
# handlers by hand so the final audit records reach their destination.
logging.shutdown()

# Hard exit; presumably chosen so that a still-pending (non-daemon) audit
# Timer thread cannot keep the process alive.
os._exit(0)
