aboutsummaryrefslogtreecommitdiff
path: root/asn.py
diff options
context:
space:
mode:
Diffstat (limited to 'asn.py')
-rwxr-xr-xasn.py259
1 files changed, 259 insertions, 0 deletions
diff --git a/asn.py b/asn.py
new file mode 100755
index 0000000..1b0d3ec
--- /dev/null
+++ b/asn.py
@@ -0,0 +1,259 @@
+#!/usr/bin/env python3
+
+import argparse
+import ipaddress
+import logging
+import os
+import socket
+import sys
+import sqlite3
+import threading
+
+import git
+
+logging.basicConfig(stream=sys.stdout, format='%(asctime)s %(message)s',
+ datefmt='%m/%d/%Y %H:%M:%S')
+log = logging.getLogger('asn')
+log.setLevel(logging.DEBUG)
+
+class Listener:
+ def __init__(self, host, port, ipv6_enabled=False):
+ self.ipv6_enabled = ipv6_enabled
+ self._listen(host, port)
+
+ def _listen(self, host, port):
+ with socket.socket() as _socket:
+ _socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ _socket.bind((host, port))
+ _socket.listen()
+ while True:
+ conn, addr = _socket.accept()
+ threading.Thread(target=self._handler,
+ args=(conn,addr,), daemon=True).start()
+
+ def _handler(self, conn, addr):
+ recv_data = conn.recv(1024)
+ if not recv_data:
+ conn.close()
+
+ recv_data = str(recv_data, 'utf-8').strip()
+ log.info(f'{addr[0]} {recv_data}')
+
+ announcements = self._get_announcements(recv_data)
+ announcements = self._pretty(announcements)
+
+ if not announcements.strip():
+ announcements = 'no valid hostname or IP discovered'
+
+ conn.sendall(bytes(announcements, 'utf-8'))
+ conn.shutdown(socket.SHUT_RDWR)
+ conn.close()
+
+ def _get_announcements(self, recv):
+ db = DB()
+ hosts = set()
+ try:
+ ip = ipaddress.ip_address(recv)
+ hosts.add(ip)
+ except ValueError:
+ try:
+ hosts = self._resolve(recv)
+ except:
+ return []
+ finally:
+ announcements = []
+ for host in hosts:
+ if self._is_invalid(host):
+ return []
+
+ n = self._get_netblock(host)
+ if n:
+ announcements.extend(db.query(n))
+
+ return announcements
+
+ def _pretty(self, announcements):
+ out = []
+ for x in announcements:
+ out.append(' '.join(('AS'+str(x[0]), x[1], x[2], x[3])))
+
+ return '\n'.join(out) + '\n'
+
+ def _resolve(self, hostname):
+ info = socket.getaddrinfo(hostname, 80, proto=socket.IPPROTO_TCP)
+ hosts = set()
+ for i in info:
+ hosts.add(i[4][0])
+
+ return hosts
+
+ def _is_invalid(self, ip):
+ net = ipaddress.ip_network(ip)
+
+ return net.is_loopback or net.is_private or net.is_multicast
+
+ def _get_netblock(self, ip):
+ net = None
+ try:
+ net = ipaddress.ip_network(ip)
+ except:
+ return net
+
+ if net.version == 4:
+ net = net.supernet(new_prefix=24)
+ elif net.version == 6:
+ if self.ipv6_enabled:
+ net = net.supernet(new_prefix=64)
+ else:
+ return None
+
+ return net
+
+class DB:
+ def __init__(self):
+ self.repo_path = os.path.dirname(os.path.abspath(__file__))
+ self.db_path = os.path.join(self.repo_path, 'cache.db')
+ self.con = sqlite3.connect(self.db_path)
+
+ loc = os.path.join(self.repo_path, 'location-database')
+ self.db_txt = os.path.join(loc, 'database.txt')
+
+ def populate_db(self):
+ with self.con:
+ self.con.execute('PRAGMA foreign_keys=OFF')
+ self.con.execute('DROP TABLE IF EXISTS net')
+ self.con.execute('DROP TABLE IF EXISTS asn')
+ self.con.execute('''
+ CREATE TABLE IF NOT EXISTS asn (
+ aut_num INTEGER NOT NULL PRIMARY KEY,
+ name TEXT
+ )
+ ''')
+ self.con.execute('''
+ CREATE UNIQUE INDEX idx_aut_num ON asn(aut_num)
+ ''')
+ self.con.execute('''
+ CREATE TABLE IF NOT EXISTS net (
+ id integer NOT NULL PRIMARY KEY,
+ aut_num INTEGER,
+ net TEXT,
+ country TEXT,
+ FOREIGN KEY(aut_num) REFERENCES asn(aut_num)
+ )
+ ''')
+ self.con.execute('PRAGMA foreign_keys=ON')
+
+ self._get_entries()
+
+ def update(self):
+ if not self._submodule_pull():
+ return False
+ else:
+ return True
+
+ def _submodule_pull(self):
+ repo = git.Repo(self.repo_path)
+
+ updated = False
+ for module in repo.submodules:
+ module.module().git.checkout('master')
+
+ current = module.module().head.commit
+ log.info(f'current location-db commit: {current}')
+
+ module.module().remotes.origin.pull()
+ if current != module.module().head.commit:
+ updated = True
+
+ return updated
+
+ def _get_entries(self):
+ with open(self.db_txt, 'r') as f:
+ kv = dict()
+ while True:
+ try:
+ line = next(f)
+ except StopIteration:
+ break
+
+ if not line.strip() or line.strip().startswith('#'):
+ if kv:
+ self._add(kv)
+ kv = dict()
+
+ continue
+
+ (k, v) = (x.strip() for x in line.split(':', 1))
+ kv[k] = v
+
+
+ def _add(self, kv):
+ # ASN information block
+ if kv.get('aut-num') and kv['aut-num'].startswith('AS'):
+ self.con.execute('''
+ INSERT OR REPLACE INTO asn(aut_num, name) VALUES(?,?)
+ ''', (kv['aut-num'][2:], kv.get('name')))
+
+ if kv.get('net'):
+ self.con.execute('''
+ INSERT OR REPLACE INTO net(aut_num, net, country)
+ VALUES((SELECT aut_num FROM asn WHERE aut_num = ?),?,?)
+ ''', (kv.get('aut-num'), kv.get('net'), kv.get('country')))
+
+ def query(self, net):
+ announcements = []
+ while True:
+ rows = self.con.execute('''
+ SELECT net.aut_num, net.country, asn.name, net.net
+ FROM net
+ INNER JOIN asn on asn.aut_num = (
+ SELECT aut_num FROM net WHERE net = ?
+ )
+ WHERE net.net = ?
+ ''', (str(net), str(net))).fetchall()
+ if len(rows) != 0:
+ announcements.extend(rows)
+ break
+ if net.prefixlen > 0:
+ net = net.supernet()
+ else:
+ break
+
+ return announcements
+
+if __name__ == '__main__':
+ desc = 'asn: map hosts to their corresponding ASN via WHOIS'
+ parser = argparse.ArgumentParser(description=desc)
+ parser.add_argument('--host', dest='host', type=str, action='store',
+ help='IP address to listen on',
+ required=False)
+ parser.add_argument('--port', dest='port', type=int, action='store',
+ help='Port to listen on',
+ required=False)
+ parser.add_argument('--ipv6', dest='ipv6_enabled', action='store_true',
+ help='Support queries for IPv6 hosts',
+ required=False)
+ parser.add_argument('--update', dest='update', action='store_true',
+ help='Update dataset submodule and create/populate cache',
+ required=False)
+ parser.add_argument('--populate', dest='populate', action='store_true',
+ help='Create and populate cache from current dataset',
+ required=False)
+ args = parser.parse_args()
+
+ if args.host and args.port:
+ listen = Listener(args.host, args.port, args.ipv6_enabled)
+ elif args.update:
+ db = DB()
+ log.info('checking remote repository for new dataset...')
+ if db.update():
+ log.info('dataset updated, creating/populating cache...')
+ db.populate_db()
+ else:
+ log.info('no changes since last update')
+ elif args.populate:
+ db = DB()
+ log.info('creating/populating cache...')
+ db.populate_db()
+ else:
+ parser.print_help(sys.stderr)