diff options
Diffstat (limited to 'asn.py')
-rwxr-xr-x | asn.py | 259 |
1 files changed, 259 insertions, 0 deletions
@@ -0,0 +1,259 @@ +#!/usr/bin/env python3 + +import argparse +import ipaddress +import logging +import os +import socket +import sys +import sqlite3 +import threading + +import git + +logging.basicConfig(stream=sys.stdout, format='%(asctime)s %(message)s', + datefmt='%m/%d/%Y %H:%M:%S') +log = logging.getLogger('asn') +log.setLevel(logging.DEBUG) + +class Listener: + def __init__(self, host, port, ipv6_enabled=False): + self.ipv6_enabled = ipv6_enabled + self._listen(host, port) + + def _listen(self, host, port): + with socket.socket() as _socket: + _socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + _socket.bind((host, port)) + _socket.listen() + while True: + conn, addr = _socket.accept() + threading.Thread(target=self._handler, + args=(conn,addr,), daemon=True).start() + + def _handler(self, conn, addr): + recv_data = conn.recv(1024) + if not recv_data: + conn.close() + + recv_data = str(recv_data, 'utf-8').strip() + log.info(f'{addr[0]} {recv_data}') + + announcements = self._get_announcements(recv_data) + announcements = self._pretty(announcements) + + if not announcements.strip(): + announcements = 'no valid hostname or IP discovered' + + conn.sendall(bytes(announcements, 'utf-8')) + conn.shutdown(socket.SHUT_RDWR) + conn.close() + + def _get_announcements(self, recv): + db = DB() + hosts = set() + try: + ip = ipaddress.ip_address(recv) + hosts.add(ip) + except ValueError: + try: + hosts = self._resolve(recv) + except: + return [] + finally: + announcements = [] + for host in hosts: + if self._is_invalid(host): + return [] + + n = self._get_netblock(host) + if n: + announcements.extend(db.query(n)) + + return announcements + + def _pretty(self, announcements): + out = [] + for x in announcements: + out.append(' '.join(('AS'+str(x[0]), x[1], x[2], x[3]))) + + return '\n'.join(out) + '\n' + + def _resolve(self, hostname): + info = socket.getaddrinfo(hostname, 80, proto=socket.IPPROTO_TCP) + hosts = set() + for i in info: + hosts.add(i[4][0]) + + return hosts + + def _is_invalid(self, ip): + net = ipaddress.ip_network(ip) + + return net.is_loopback or net.is_private or net.is_multicast + + def _get_netblock(self, ip): + net = None + try: + net = ipaddress.ip_network(ip) + except: + return net + + if net.version == 4: + net = net.supernet(new_prefix=24) + elif net.version == 6: + if self.ipv6_enabled: + net = net.supernet(new_prefix=64) + else: + return None + + return net + +class DB: + def __init__(self): + self.repo_path = os.path.dirname(os.path.abspath(__file__)) + self.db_path = os.path.join(self.repo_path, 'cache.db') + self.con = sqlite3.connect(self.db_path) + + loc = os.path.join(self.repo_path, 'location-database') + self.db_txt = os.path.join(loc, 'database.txt') + + def populate_db(self): + with self.con: + self.con.execute('PRAGMA foreign_keys=OFF') + self.con.execute('DROP TABLE IF EXISTS net') + self.con.execute('DROP TABLE IF EXISTS asn') + self.con.execute(''' + CREATE TABLE IF NOT EXISTS asn ( + aut_num INTEGER NOT NULL PRIMARY KEY, + name TEXT + ) + ''') + self.con.execute(''' + CREATE UNIQUE INDEX idx_aut_num ON asn(aut_num) + ''') + self.con.execute(''' + CREATE TABLE IF NOT EXISTS net ( + id integer NOT NULL PRIMARY KEY, + aut_num INTEGER, + net TEXT, + country TEXT, + FOREIGN KEY(aut_num) REFERENCES asn(aut_num) + ) + ''') + self.con.execute('PRAGMA foreign_keys=ON') + + self._get_entries() + + def update(self): + if not self._submodule_pull(): + return False + else: + return True + + def _submodule_pull(self): + repo = git.Repo(self.repo_path) + + updated = False + for module in repo.submodules: + module.module().git.checkout('master') + + current = module.module().head.commit + log.info(f'current location-db commit: {current}') + + module.module().remotes.origin.pull() + if current != module.module().head.commit: + updated = True + + return updated + + def _get_entries(self): + with open(self.db_txt, 'r') as f: + kv = dict() + while True: + try: + line = next(f) + except StopIteration: + break + + if not line.strip() or line.strip().startswith('#'): + if kv: + self._add(kv) + kv = dict() + + continue + + (k, v) = (x.strip() for x in line.split(':', 1)) + kv[k] = v + + + def _add(self, kv): + # ASN information block + if kv.get('aut-num') and kv['aut-num'].startswith('AS'): + self.con.execute(''' + INSERT OR REPLACE INTO asn(aut_num, name) VALUES(?,?) + ''', (kv['aut-num'][2:], kv.get('name'))) + + if kv.get('net'): + self.con.execute(''' + INSERT OR REPLACE INTO net(aut_num, net, country) + VALUES((SELECT aut_num FROM asn WHERE aut_num = ?),?,?) + ''', (kv.get('aut-num'), kv.get('net'), kv.get('country'))) + + def query(self, net): + announcements = [] + while True: + rows = self.con.execute(''' + SELECT net.aut_num, net.country, asn.name, net.net + FROM net + INNER JOIN asn on asn.aut_num = ( + SELECT aut_num FROM net WHERE net = ? + ) + WHERE net.net = ? + ''', (str(net), str(net))).fetchall() + if len(rows) != 0: + announcements.extend(rows) + break + if net.prefixlen > 0: + net = net.supernet() + else: + break + + return announcements + +if __name__ == '__main__': + desc = 'asn: map hosts to their corresponding ASN via WHOIS' + parser = argparse.ArgumentParser(description=desc) + parser.add_argument('--host', dest='host', type=str, action='store', + help='IP address to listen on', + required=False) + parser.add_argument('--port', dest='port', type=int, action='store', + help='Port to listen on', + required=False) + parser.add_argument('--ipv6', dest='ipv6_enabled', action='store_true', + help='Support queries for IPv6 hosts', + required=False) + parser.add_argument('--update', dest='update', action='store_true', + help='Update dataset submodule and create/populate cache', + required=False) + parser.add_argument('--populate', dest='populate', action='store_true', + help='Create and populate cache from current dataset', + required=False) + args = parser.parse_args() + + if args.host and args.port: + listen = Listener(args.host, args.port, args.ipv6_enabled) + elif args.update: + db = DB() + log.info('checking remote repository for new dataset...') + if db.update(): + log.info('dataset updated, creating/populating cache...') + db.populate_db() + else: + log.info('no changes since last update') + elif args.populate: + db = DB() + log.info('creating/populating cache...') + db.populate_db() + else: + parser.print_help(sys.stderr) |