diff options
Diffstat (limited to 'asn.py')
-rwxr-xr-x | asn.py | 227 |
1 files changed, 144 insertions, 83 deletions
@@ -8,17 +8,109 @@ import socket import sys import sqlite3 import threading +from collections import OrderedDict from glob import glob +from threading import Thread import git +from flask import Flask, json, request, Response +from waitress import serve logging.basicConfig(stream=sys.stdout, format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S') log = logging.getLogger('asn') log.setLevel(logging.DEBUG) -class Listener: +class Common: + def get_netblock(self, ip): + net = None + try: + net = ipaddress.ip_network(ip) + except: + return net + + if net.version == 4: + net = net.supernet(new_prefix=24) + elif net.version == 6: + net = net.supernet(new_prefix=64) + + return net + + def is_invalid(self, ip): + net = ipaddress.ip_network(ip) + + return net.is_loopback or net.is_private or net.is_multicast + + def get_announcements(self, hosts): + announcements = [] + for host in hosts: + if self.is_invalid(host): + continue + + n = self.get_netblock(host) + if not n: + continue + + res = self.db.query(n) + if res: + res = list(res[0]) + res.insert(0, str(host)) + announcements.extend([res]) + + return announcements + +class HTTPListener(Common): def __init__(self, db, host, port): + Common.__init__(self) + self.db = db + + self._app = FlaskWrapper("asn") + self._app.add_endpoint(endpoint="/", name="index", handler=self._handler) + self._app.run(host=host, port=port) + + def _handler(self): + if request.headers.getlist("X-Forwarded-For"): + ip = request.headers.getlist("X-Forwarded-For")[0] + else: + ip = request.remote_addr + + log.info(f'{ip} {request.path}') + data = Common.get_announcements(self, [ipaddress.ip_address(ip)]) + if data: + data = data[0] + else: + return "no announcement found", 404 + + res = OrderedDict() + res["host"] = data[0] + res["hostname"] = socket.gethostbyaddr(data[0])[0] + res["org"] = f"AS{data[1]} {data[3]}" + res["announcement"] = data[4] + + return Response(json.dumps(res, indent=2), mimetype="application/json") + +class FlaskWrapper: + def __init__(self, name): + self.app = Flask(name) + self.app.config['JSON_SORT_KEYS'] = False + + def run(self, **kwargs): + serve(self.app, **kwargs) + + def add_endpoint(self, endpoint=None, name=None, handler=None): + self.app.add_url_rule(endpoint, name, HTTPHandler(handler)) + +class HTTPHandler: + def __init__(self, action): + self.action = action + + def __call__(self): + res = self.action() + return res + +class WHOISListener(Common): + def __init__(self, db, host, port): + Common.__init__(self) self.db = db self._listen(host, port) @@ -35,52 +127,38 @@ class Listener: def _handler(self, conn, addr): resp = '' try: - recv_data = conn.recv(1024) - recv_data = str(recv_data, 'utf-8').strip() - except ConnectionResetError: - log.info(f'{addr[0]} connection reset') - except UnicodeDecodeError: - log.info(f'{addr[0]} could not decode to utf-8') - except Exception as err: - log.info(f'{addr[0]} {err}') + recv = conn.recv(1024) + recv = str(recv, 'utf-8').strip() + except Exception: + pass else: - log.info(f'{addr[0]} {recv_data}') + log.info(f'{addr[0]} {recv}') - announcements = self._get_announcements(recv_data) - if announcements: - resp = self._pretty(announcements) + hosts = set() + try: + ip = ipaddress.ip_address(recv) + hosts.add(ip) + except ValueError: + try: + hosts = self._resolve(recv) + except: + pass + finally: + announcements = Common.get_announcements(self, hosts) + if announcements: + resp = self._pretty(announcements) finally: conn.sendall(bytes(resp, 'utf-8')) conn.shutdown(socket.SHUT_RDWR) conn.close() - def _get_announcements(self, recv): + def _resolve(self, hostname): + info = socket.getaddrinfo(hostname, 80, proto=socket.IPPROTO_TCP) hosts = set() - try: - ip = ipaddress.ip_address(recv) - hosts.add(ip) - except ValueError: - try: - hosts = self._resolve(recv) - except: - return [] - finally: - announcements = [] - for host in hosts: - if self._is_invalid(host): - continue - - n = self._get_netblock(host) - if not n: - continue - - res = self.db.query(n) - if res: - res = list(res[0]) - res.insert(0, str(host)) - announcements.extend([res]) + for i in info: + hosts.add(i[4][0]) - return announcements + return hosts def _pretty(self, announces): announces = sorted(announces, key=lambda x: ipaddress.ip_network(x[4]).version) @@ -102,33 +180,6 @@ class Listener: return out - def _resolve(self, hostname): - info = socket.getaddrinfo(hostname, 80, proto=socket.IPPROTO_TCP) - hosts = set() - for i in info: - hosts.add(i[4][0]) - - return hosts - - def _is_invalid(self, ip): - net = ipaddress.ip_network(ip) - - return net.is_loopback or net.is_private or net.is_multicast - - def _get_netblock(self, ip): - net = None - try: - net = ipaddress.ip_network(ip) - except: - return net - - if net.version == 4: - net = net.supernet(new_prefix=24) - elif net.version == 6: - net = net.supernet(new_prefix=64) - - return net - class DB: def __init__(self): self.repo_path = os.path.dirname(os.path.abspath(__file__)) @@ -179,8 +230,8 @@ class DB: def update(self): if not self._submodule_pull(): return False - else: - return True + + return True def _submodule_pull(self): repo = git.Repo(self.repo_path) @@ -200,7 +251,7 @@ class DB: def _get_entries(self, txt): with open(txt, 'r') as f: - kv = dict() + kv = {} while True: try: line = next(f) @@ -213,7 +264,7 @@ class DB: if kv.get('descr'): kv['name'] = kv.pop('descr') self._add(kv) - kv = dict() + kv = {} continue @@ -252,12 +303,19 @@ class DB: if __name__ == '__main__': desc = 'asn: map hosts to their corresponding ASN via WHOIS' - parser = argparse.ArgumentParser(description=desc) - parser.add_argument('--host', dest='host', type=str, action='store', - help='IP address to listen on', - required=False) - parser.add_argument('--port', dest='port', type=int, action='store', - help='Port to listen on', + parser = argparse.ArgumentParser(description=desc, + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--whois-host', dest='whois_host', type=str, action='store', + help='IP to listen on for WHOIS service', + default="127.0.0.1", required=False) + parser.add_argument('--whois-port', dest='whois_port', type=int, action='store', + help='Port to listen on for WHOIS service', + default=4343, required=False) + parser.add_argument('--http-host', dest='http_host', type=str, action='store', + help='IP to listen on for HTTP service', + default="127.0.0.1", required=False) + parser.add_argument('--http-port', dest='http_port', type=int, action='store', + help='Port to listen on for HTTP service', default=8080, required=False) parser.add_argument('--update', dest='update', action='store_true', help='Update dataset submodule and create/populate cache', @@ -266,11 +324,7 @@ if __name__ == '__main__': help='Create and populate cache from current dataset', required=False) args = parser.parse_args() - db = DB() - if not len(sys.argv) > 1: - parser.print_help(sys.stderr) - sys.exit() if args.populate and args.update: log.error('--populate and --update used; redundant, use one') @@ -288,7 +342,14 @@ if __name__ == '__main__': else: log.info('no changes since last update') - if args.host and args.port: - log.info(f'listening on {args.host}:{args.port}') - listen = Listener(db, args.host, args.port) - + if args.whois_host and args.whois_port: + log.info(f'WHOIS: listening on {args.whois_host}:{args.whois_port}') + whois = Thread(target=WHOISListener, args=(db, args.whois_host, + args.whois_port)) + whois.start() + + if args.http_host and args.http_port: + log.info(f'HTTP: listening on {args.http_host}:{args.http_port}') + http = Thread(target=HTTPListener, args=(db, args.http_host, + args.http_port)) + http.start() |