#!/usr/bin/env python3
"""asn: map hosts to their corresponding ASN.

Network announcements from the ``location-database`` git submodule are cached
in a local SQLite file (``cache.db``) and served over two front ends: a
WHOIS-style TCP service (send an IP address or hostname, get a text table
back) and an HTTP service that reports on the requesting client's address.
"""
import argparse
import ipaddress
import logging
import os
import socket
import sqlite3
import sys
import threading
from collections import OrderedDict
from glob import glob
from threading import Thread

import git
from flask import Flask, json, request, Response
from waitress import serve

logging.basicConfig(stream=sys.stdout, format='%(asctime)s %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S')
log = logging.getLogger('asn')
log.setLevel(logging.DEBUG)


class Common:
    """Lookup helpers shared by the WHOIS and HTTP listeners.

    Subclasses must set ``self.db`` to a :class:`DB` before calling
    :meth:`get_announcements`.
    """

    def get_netblock(self, ip):
        """Return the netblock used as the starting key for a lookup.

        IPv4 input is widened to its /24 and IPv6 input to its /64; input
        that is already that wide or wider is returned unchanged.

        :param ip: an address/network object or string.
        :returns: an ``ipaddress`` network, or ``None`` if *ip* cannot be
            parsed.
        """
        try:
            net = ipaddress.ip_network(ip)
        except (ValueError, TypeError):
            return None
        new_prefix = 24 if net.version == 4 else 64
        # supernet(new_prefix=...) raises if the network is already wider.
        if net.prefixlen > new_prefix:
            net = net.supernet(new_prefix=new_prefix)
        return net

    def is_invalid(self, ip):
        """Return True for loopback, private or multicast addresses.

        :raises ValueError: if *ip* is not a valid address/network.
        """
        net = ipaddress.ip_network(ip)
        return net.is_loopback or net.is_private or net.is_multicast

    def get_announcements(self, hosts):
        """Look up the announcement covering each host.

        Loopback, private, multicast and unparseable hosts are skipped, as
        are hosts with no match in the cache.

        :param hosts: iterable of address objects or address strings.
        :returns: list of ``[host, aut_num, country, as_name, net]`` rows,
            with *host* as a string.
        """
        announcements = []
        for host in hosts:
            try:
                if self.is_invalid(host):
                    continue
            except ValueError:
                # Not an address we can look up.
                continue
            netblock = self.get_netblock(host)
            if netblock is None:
                continue
            rows = self.db.query(netblock)
            if rows:
                announcements.append([str(host), *rows[0]])
        return announcements


class HTTPListener(Common):
    """HTTP front end reporting the announcement for the requesting client.

    Construction registers the handler and hands control to waitress, which
    serves requests from then on; ``__main__`` runs it on its own thread.
    """

    def __init__(self, db, host, port):
        super().__init__()
        self.db = db
        self._app = FlaskWrapper("asn")
        self._app.add_endpoint(endpoint="/", name="index", handler=self._handler)
        self._app.run(host=host, port=port)

    def _client_ip(self):
        """Return the client address, preferring the X-Forwarded-For header."""
        forwarded = request.headers.getlist("X-Forwarded-For")
        if forwarded:
            # The header may carry a comma-separated proxy chain; the
            # left-most entry is the originating client.
            return forwarded[0].split(',')[0].strip()
        return request.remote_addr

    @staticmethod
    def _reverse_lookup(host):
        """Return the PTR hostname for *host*, or ``None`` if there is none."""
        try:
            return socket.gethostbyaddr(host)[0]
        except OSError:
            # socket.herror / gaierror: no reverse record or lookup failure.
            return None

    def _handler(self):
        """Return a JSON description of the announcement covering the client.

        Responds 400 if the client address cannot be parsed and 404 if no
        announcement covers it.
        """
        ip = self._client_ip()
        log.info(f'{ip} {request.path}')
        try:
            address = ipaddress.ip_address(ip)
        except ValueError:
            return "invalid client address", 400

        data = self.get_announcements([address])
        if not data:
            return "no announcement found", 404
        host, aut_num, _country, as_name, announcement = data[0]

        res = OrderedDict()
        res["host"] = host
        res["hostname"] = self._reverse_lookup(host)
        res["org"] = f"AS{aut_num} {as_name}"
        res["announcement"] = announcement
        # Pass sort_keys explicitly: newer Flask releases may ignore the
        # JSON_SORT_KEYS config, which would reorder the keys above.
        return Response(json.dumps(res, indent=2, sort_keys=False),
                        mimetype="application/json")


class FlaskWrapper:
    """Small wrapper that registers handlers on a Flask app and serves it."""

    def __init__(self, name):
        self.app = Flask(name)
        # Keep the insertion order of JSON responses.
        self.app.config['JSON_SORT_KEYS'] = False

    def run(self, **kwargs):
        """Serve the app with waitress; *kwargs* are passed to ``serve``."""
        serve(self.app, **kwargs)

    def add_endpoint(self, endpoint=None, name=None, handler=None):
        """Route URL rule *endpoint*, registered as *name*, to *handler*."""
        self.app.add_url_rule(endpoint, name, HTTPHandler(handler))


class HTTPHandler:
    """Callable adapter that invokes *action* with no arguments."""

    def __init__(self, action):
        self.action = action

    def __call__(self):
        return self.action()


class WHOISListener(Common):
    """WHOIS-style TCP front end.

    Each connection sends one query (an IP address or hostname); the reply
    is a plain-text table of announcements, or empty if nothing matches.
    Construction does not return: it accepts connections forever, handling
    each on a daemon thread.
    """

    # Seconds a client may stay idle before its connection is dropped, so a
    # silent client cannot pin a handler thread forever.
    client_timeout = 10

    def __init__(self, db, host, port):
        super().__init__()
        self.db = db
        self._listen(host, port)

    def _listen(self, host, port):
        """Bind *host*:*port* and dispatch every accepted connection."""
        with socket.socket() as _socket:
            _socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            _socket.bind((host, port))
            _socket.listen()
            while True:
                conn, addr = _socket.accept()
                threading.Thread(target=self._handler, args=(conn, addr),
                                 daemon=True).start()

    def _handler(self, conn, addr):
        """Answer one WHOIS query on *conn* and close the connection.

        An empty reply is sent if the request cannot be read or nothing
        matches.
        """
        resp = ''
        with conn:
            try:
                conn.settimeout(self.client_timeout)
                query = str(conn.recv(1024), 'utf-8').strip()
            except (OSError, UnicodeDecodeError):
                # Timeout, reset or undecodable request: reply empty.
                query = None
            if query is not None:
                log.info(f'{addr[0]} {query}')
                try:
                    resp = self._lookup(query)
                except Exception:
                    # Thread boundary: log and still answer the client.
                    log.exception(f'{addr[0]} lookup failed for {query!r}')
            try:
                conn.sendall(bytes(resp, 'utf-8'))
                conn.shutdown(socket.SHUT_RDWR)
            except OSError:
                # Client already gone; the with-block still closes the socket.
                pass

    def _lookup(self, query):
        """Return the reply text for *query* (an IP address or hostname)."""
        if not query:
            return ''
        try:
            hosts = {ipaddress.ip_address(query)}
        except ValueError:
            try:
                hosts = self._resolve(query)
            except (OSError, ValueError):
                # Unresolvable name (gaierror) or one the IDNA codec rejects.
                hosts = set()
        announcements = self.get_announcements(hosts)
        return self._pretty(announcements) if announcements else ''

    def _resolve(self, hostname):
        """Return the set of address strings *hostname* resolves to."""
        info = socket.getaddrinfo(hostname, 80, proto=socket.IPPROTO_TCP)
        return {sockaddr[0] for *_, sockaddr in info}

    def _pretty(self, announces):
        """Render announcement rows as an aligned table, IPv4 rows first."""
        head = ('IP Address', 'AS Number', 'Country', 'AS Name', 'Announcement')
        rows = sorted(announces, key=lambda row: ipaddress.ip_network(row[4]).version)
        # Stringify every cell up front so widths are measured on the text
        # actually printed (cells may be ints or None).
        table = [list(head)] + [[str(cell) for cell in row] for row in rows]
        widths = [max(len(cell) for cell in column) for column in zip(*table)]

        lines = [' | '.join(title.ljust(w) for w, title in zip(widths, head)),
                 '-+-'.join('-' * w for w in widths)]
        for row in table[1:]:
            lines.append(' | '.join(cell.ljust(w) for w, cell in zip(widths, row)))
        return '\n'.join(lines) + '\n'


class DB:
    """SQLite cache of the location database (ASNs and announced networks)."""

    def __init__(self):
        self.repo_path = os.path.dirname(os.path.abspath(__file__))
        self.db_path = os.path.join(self.repo_path, 'cache.db')
        # The connection is shared with the listener threads.
        self.con = sqlite3.connect(self.db_path, check_same_thread=False)
        loc = os.path.join(self.repo_path, 'location-database')
        self.dataset = os.path.join(loc, 'database.txt')
        self.overrides_path = os.path.join(loc, 'overrides')
        self.overrides = self._find_overrides()

    def _find_overrides(self):
        """Return every ``*.txt`` file below the overrides directory, sorted."""
        return sorted(
            path
            for root, _dirs, _files in os.walk(self.overrides_path)
            for path in glob(os.path.join(root, '*.txt'))
        )

    def populate_db(self):
        """Drop and rebuild the ``asn``/``net`` tables from the text files."""
        # Re-scan so override files added/removed by update() are honoured.
        self.overrides = self._find_overrides()
        with self.con:
            self.con.execute('PRAGMA foreign_keys=OFF')
            self.con.execute('DROP TABLE IF EXISTS net')
            self.con.execute('DROP TABLE IF EXISTS asn')
            self.con.execute('''
                CREATE TABLE IF NOT EXISTS asn (
                    aut_num INTEGER NOT NULL PRIMARY KEY,
                    name TEXT
                )
            ''')
            self.con.execute('''
                CREATE UNIQUE INDEX idx_aut_num ON asn(aut_num)
            ''')
            self.con.execute('''
                CREATE TABLE IF NOT EXISTS net (
                    id integer NOT NULL PRIMARY KEY,
                    aut_num INTEGER,
                    net TEXT,
                    country TEXT,
                    FOREIGN KEY(aut_num) REFERENCES asn(aut_num)
                )
            ''')
            self.con.execute('''
                CREATE UNIQUE INDEX idx_net ON net(net)
            ''')
            self.con.execute('PRAGMA foreign_keys=ON')
            for txt in self.overrides:
                self._get_entries(txt)
            # NOTE(review): rows are written with INSERT OR REPLACE and the
            # main dataset is loaded last, so it replaces override entries
            # for the same ASN/network -- confirm this precedence is intended.
            self._get_entries(self.dataset)

    def update(self):
        """Pull the dataset submodule(s); return True if any commit changed."""
        return self._submodule_pull()

    def _submodule_pull(self):
        """Check out ``master`` in each submodule, pull, and report changes."""
        repo = git.Repo(self.repo_path)
        updated = False
        for module in repo.submodules:
            sub_repo = module.module()
            sub_repo.git.checkout('master')
            current = sub_repo.head.commit
            log.info(f'current location-db commit: {current}')
            sub_repo.remotes.origin.pull()
            if current != sub_repo.head.commit:
                updated = True
        return updated

    def _get_entries(self, txt):
        """Parse one location-database text file and insert its records.

        Records are runs of ``key: value`` lines separated by blank lines or
        ``#`` comment lines. Lines without a ``:`` are logged and skipped.
        """
        kv = {}
        with open(txt, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if not stripped or stripped.startswith('#'):
                    self._flush_record(kv)
                    kv = {}
                    continue
                key, sep, value = stripped.partition(':')
                if not sep:
                    log.warning(f'{txt}: ignoring malformed line {stripped!r}')
                    continue
                kv[key.strip()] = value.strip()
        # A file need not end with a separator; store its final record too.
        self._flush_record(kv)

    def _flush_record(self, kv):
        """Insert one parsed record, if non-empty."""
        if not kv:
            return
        # key correction for overrides; uses descr
        if kv.get('descr'):
            kv['name'] = kv.pop('descr')
        self._add(kv)

    def _add(self, kv):
        """Insert an ASN record and/or a network record from *kv*."""
        aut_num = kv.get('aut-num')
        # ASN information block: "aut-num: AS<number>".
        if aut_num and aut_num.startswith('AS'):
            self.con.execute('''
                INSERT OR REPLACE INTO asn(aut_num, name) VALUES(?,?)
            ''', (aut_num[2:], kv.get('name')))
        if kv.get('net'):
            # The sub-select yields NULL when the ASN is unknown, keeping the
            # foreign key satisfied.
            self.con.execute('''
                INSERT OR REPLACE INTO net(aut_num, net, country)
                VALUES((SELECT aut_num FROM asn WHERE aut_num = ?),?,?)
            ''', (aut_num, kv.get('net'), kv.get('country')))

    def query(self, net):
        """Find the announcement for *net*, walking up to wider prefixes.

        Starting at *net*, each step widens the prefix by one bit until a
        stored network with a known ASN is found.

        :param net: an ``ipaddress`` network object.
        :returns: list of ``(aut_num, country, as_name, net)`` rows, or
            ``None`` if nothing matches up to the /0 network.
        """
        while True:
            rows = self.con.execute('''
                SELECT net.aut_num, net.country, asn.name, net.net
                FROM net
                INNER JOIN asn ON asn.aut_num = net.aut_num
                WHERE net.net = ?
            ''', (str(net),)).fetchall()
            if rows:
                return rows
            if net.prefixlen == 0:
                return None
            net = net.supernet()


def main():
    """Parse arguments, optionally refresh the cache, and start listeners."""
    desc = 'asn: map hosts to their corresponding ASN via WHOIS'
    parser = argparse.ArgumentParser(description=desc,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--whois-host', dest='whois_host', type=str, action='store',
                        help='IP to listen on for WHOIS service',
                        default="127.0.0.1", required=False)
    parser.add_argument('--whois-port', dest='whois_port', type=int, action='store',
                        help='Port to listen on for WHOIS service',
                        default=4343, required=False)
    parser.add_argument('--http-host', dest='http_host', type=str, action='store',
                        help='IP to listen on for HTTP service',
                        default="127.0.0.1", required=False)
    parser.add_argument('--http-port', dest='http_port', type=int, action='store',
                        help='Port to listen on for HTTP service',
                        default=8080, required=False)
    parser.add_argument('--update', dest='update', action='store_true',
                        help='Update dataset submodule and create/populate cache',
                        required=False)
    parser.add_argument('--populate', dest='populate', action='store_true',
                        help='Create and populate cache from current dataset',
                        required=False)
    args = parser.parse_args()

    if args.populate and args.update:
        log.error('--populate and --update used; redundant, use one')
        # Non-zero status: this is a usage error.
        sys.exit(1)

    db = DB()

    if args.populate:
        log.info('creating and populating db cache...')
        db.populate_db()

    if args.update:
        log.info('checking remote repository for new dataset...')
        if db.update():
            log.info('dataset fetched and updated')
            db.populate_db()
        else:
            log.info('no changes since last update')

    if args.whois_host and args.whois_port:
        log.info(f'WHOIS: listening on {args.whois_host}:{args.whois_port}')
        whois = Thread(target=WHOISListener, args=(db, args.whois_host, args.whois_port))
        whois.start()

    if args.http_host and args.http_port:
        log.info(f'HTTP: listening on {args.http_host}:{args.http_port}')
        http = Thread(target=HTTPListener, args=(db, args.http_host, args.http_port))
        http.start()


if __name__ == '__main__':
    main()