187 lines
6.9 KiB
Python
187 lines
6.9 KiB
Python
import os
|
|
import sys
|
|
import time
|
|
import json
|
|
import uuid
|
|
import requests
|
|
import logging
|
|
from datetime import datetime
|
|
|
|
import logging
from logging.handlers import RotatingFileHandler

# Configuration
# Hard-coded absolute install paths of the agent. Each constant is defined
# exactly once; a second definition would silently override the first.
CONFIG_FILE = '/opt/dnsblock/config.json'
LOG_DIR = '/opt/dnsblock/logs'
LOG_FILE = os.path.join(LOG_DIR, 'agent.log')
# Zone file written when config.json does not provide "rpz_file".
RPZ_FILE_DEFAULT = '/opt/dnsblock/rpz.dnsblock.zone'
|
|
|
|
# Ensure Log Directory Exists
# The rotating file handler configured below opens LOG_FILE inside LOG_DIR at
# import time, so the directory must exist first. exist_ok=True avoids the race
# between a separate existence check and the creation call.
try:
    os.makedirs(LOG_DIR, exist_ok=True)
except OSError as e:
    # Logging is not configured yet; report the failure on stderr and abort.
    print(f"Failed to create log directory: {e}", file=sys.stderr)
    sys.exit(1)
|
|
|
|
# Setup Logging with Rotation
# The root logger runs at INFO and writes every record, in one shared format,
# to two sinks: a size-rotated file (5 MiB per file, 5 backups) and stdout.
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

handler = RotatingFileHandler(LOG_FILE, maxBytes=5 * 1024 * 1024, backupCount=5)
handler.setFormatter(formatter)

# Also log to stdout for systemd journal
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)

logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(handler)          # file sink first...
logger.addHandler(console_handler)  # ...then stdout, same order as before
|
|
|
|
def get_machine_id(id_paths=('/etc/machine-id', '/var/lib/dbus/machine-id')):
    """Return an identifier for this host, sent to the API as X-Machine-ID.

    The files in *id_paths* are tried in order. The first one that can be read
    and holds a real value wins. A missing or unreadable file is skipped. So is
    one that is blank or contains the systemd placeholder "uninitialized". If
    no file yields an ID, fall back to ``uuid.getnode()`` rendered as a decimal
    string. Per the Python docs, ``getnode()`` may return a random number when
    no hardware address can be found, so that fallback is not guaranteed to be
    stable across restarts.

    Args:
        id_paths: Candidate machine-id files in priority order. Defaults to the
            systemd location, then the legacy D-Bus location.

    Returns:
        The machine ID as a non-empty string.

    Exits:
        Logs an error and calls ``sys.exit(1)`` if even the fallback fails.
    """
    for path in id_paths:
        try:
            with open(path, 'r') as f:
                machine_id = f.read().strip()
        except FileNotFoundError:
            continue
        except OSError as e:
            # E.g. a permission problem: try the next source instead of dying.
            logging.warning(f"Could not read machine ID from {path}: {e}")
            continue
        # Skip blank files and the "uninitialized" first-boot placeholder.
        # Either would otherwise be sent to the API as this host's identity.
        if machine_id and machine_id != 'uninitialized':
            return machine_id

    # Fallback to hardware UUID
    try:
        return str(uuid.getnode())
    except Exception as e:
        logging.error(f"Error getting machine ID: {e}")
        sys.exit(1)
|
|
|
|
def load_config(path=None):
    """Load the agent configuration from a JSON file.

    Args:
        path: Config file to read. ``None`` (the default) means the
            module-level ``CONFIG_FILE``.

    Returns:
        The parsed configuration as a dict.

    Exits:
        Logs an error and calls ``sys.exit(1)`` in any of these cases: the file
        is missing, cannot be read or decoded, is not valid JSON, or its
        top-level value is not a JSON object.
    """
    if path is None:
        path = CONFIG_FILE

    try:
        with open(path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except FileNotFoundError:
        logging.error(f"Config file not found at {path}")
        sys.exit(1)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logging.error(f"Error loading config: {e}")
        sys.exit(1)

    # Callers read settings with config.get(...), so a JSON array/string/number
    # at the top level is unusable. Reject it here with a clear message.
    if not isinstance(config, dict):
        logging.error(
            f"Error loading config: expected a JSON object, got {type(config).__name__}"
        )
        sys.exit(1)
    return config
|
|
|
|
def _normalize_rpz_name(domain, name_pattern):
    """Return *domain* as a safe RPZ owner name, or None if it must be skipped.

    Normalization steps:
    - Trim surrounding whitespace and any trailing root dot, and lower-case.
    - Convert non-ASCII (IDN) names to their IDNA/punycode form.
    - Require the result to match *name_pattern*: dotted hostname labels,
      optionally starting with ``*.``.

    Rejecting whitespace and zone-file syntax characters (``;``, ``(``,
    ``"``, ``$``, ``@``, newlines, ...) stops an API-supplied value from
    injecting extra records or making the zone unloadable.
    """
    if not isinstance(domain, str):
        return None
    name = domain.strip().rstrip('.').lower()
    if not name.isascii():
        try:
            name = name.encode('idna').decode('ascii')
        except UnicodeError:
            return None
    if len(name) > 253 or not name_pattern.fullmatch(name):
        return None
    return name


def generate_rpz(domains, output_file):
    """Write an RPZ zone file that blocks every domain in *domains*.

    Each accepted domain gets two records: the name itself and a ``*.``
    wildcard for its subdomains. Both are ``CNAME .``, i.e. NXDOMAIN in RPZ.
    Some entries are skipped, with one summary warning: non-strings, empty
    names, and names containing characters that are not valid in a hostname.
    Duplicates (after normalization) are written once.

    The whole zone is built in memory before the file is opened. A bad input
    (e.g. ``domains`` is None) therefore fails without truncating the zone
    already on disk.

    NOTE(review): the final write is still not atomic. Writing a sibling temp
    file and ``os.replace``-ing it would be atomic. However, it would drop the
    existing file's owner, permissions and SELinux label, which the name server
    may depend on. Hence it is not done here.

    Args:
        domains: Iterable of domain names to block.
        output_file: Path of the zone file to (over)write. Missing parent
            directories are created; a bare filename is written to the
            current directory.

    Returns:
        True on success, False if anything failed (the error is logged).
    """
    import re

    # Optional leading "*." then 1..63-char labels of [a-z0-9_-].
    name_pattern = re.compile(r'(?:\*\.)?[a-z0-9_-]{1,63}(?:\.[a-z0-9_-]{1,63})*')

    try:
        names = []
        seen = set()
        skipped = 0
        for domain in domains:
            name = _normalize_rpz_name(domain, name_pattern)
            if name is None:
                skipped += 1
            elif name not in seen:
                seen.add(name)
                names.append(name)
        if skipped:
            logging.warning(f"Skipped {skipped} invalid domain entries while generating RPZ.")

        lines = [
            "$TTL 30\n",
            "@ IN SOA rpz.dnsblock.zone. root.rpz.dnsblock.zone. (\n",
            # Unix time as the serial, so every regeneration gets a new value
            # (as long as rewrites are at least one second apart).
            f" {int(time.time())} ; Serial\n",
            " 30 ; Refresh\n",
            " 15 ; Retry\n",
            " 604800 ; Expire\n",
            " 30 ) ; Negative Cache TTL\n",
            ";\n",
            "@ IN NS localhost.\n",
            "@ IN A 127.0.0.1\n",
            ";\n",
            "; Blocked Domains\n",
        ]
        for name in names:
            # CNAME to . (NXDOMAIN) for the domain and all of its subdomains
            lines.append(f"{name} CNAME .\n")
            lines.append(f"*.{name} CNAME .\n")

        # os.path.dirname() is '' for a bare filename, and os.makedirs('')
        # raises, so only create a directory when there is one.
        directory = os.path.dirname(output_file)
        if directory:
            os.makedirs(directory, exist_ok=True)

        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(''.join(lines))

        logging.info(f"RPZ file generated at {output_file} with {len(names)} domains.")
        return True

    except Exception as e:
        logging.error(f"Error generating RPZ: {e}")
        return False
|
|
|
|
def _run_reload(reload_command):
    """Run *reload_command* through the shell; return True if it exited with 0.

    ``subprocess.run(..., shell=True)`` runs the string via ``/bin/sh``, as
    ``os.system`` did. Unlike ``os.system``, its ``returncode`` is the real
    exit status; ``os.system`` returns the encoded wait status, e.g. 256 for
    exit code 1. The command comes from the local config file, presumably
    admin-controlled (TODO confirm file ownership). Never feed API-supplied
    data into it.
    """
    import subprocess

    logging.info(f"Executing reload command: {reload_command}")
    try:
        result = subprocess.run(reload_command, shell=True)
    except OSError as e:
        logging.error(f"Reload command could not be executed: {e}")
        return False
    if result.returncode == 0:
        logging.info("Service reloaded successfully.")
        return True
    logging.error(f"Reload command failed with exit code {result.returncode}")
    return False


def _list_fingerprint(data, domains):
    """Return a value identifying this version of the domain list.

    Prefers the server-provided ``checksum``. If the response omits it, a
    SHA-256 of the JSON-encoded list is used instead. Otherwise a missing
    checksum (None) would compare equal to the initial "never synced" state
    and the zone would never be written.
    """
    checksum = data.get('checksum')
    if checksum is not None:
        return checksum
    import hashlib
    payload = json.dumps(domains, sort_keys=True).encode('utf-8')
    return 'sha256:' + hashlib.sha256(payload).hexdigest()


def _error_message(response):
    """Extract the 'error' field from a JSON error body, else a short text preview."""
    try:
        return response.json().get('error', 'Access denied')
    except (ValueError, AttributeError):
        # The body is not JSON, or is JSON but not an object.
        return response.text[:100]


def _apply_domain_response(response, rpz_file, reload_command, last_checksum):
    """Handle a 200 response from the domains endpoint.

    When the domain list differs from *last_checksum*, regenerate the RPZ
    file and run the optional reload command.

    Returns:
        The fingerprint to remember for the next cycle. It only advances once
        the zone was written AND the reload (if configured) succeeded. A failed
        write or reload is thus retried on the next cycle, rather than leaving
        the name server on a stale zone until the list changes again.
    """
    try:
        data = response.json()
    except ValueError:
        # json.JSONDecodeError and requests' JSONDecodeError are ValueErrors.
        logging.error(f"Failed to decode JSON response. Content preview: {response.text[:200]}")
        logging.debug(f"Full response: {response.text}")
        return last_checksum

    if not isinstance(data, dict):
        logging.error(
            f"Unexpected JSON response (expected an object). Content preview: {response.text[:200]}"
        )
        return last_checksum

    domains = data.get('domains', [])
    current_checksum = _list_fingerprint(data, domains)
    if current_checksum == last_checksum:
        logging.debug("No changes in domain list.")
        return last_checksum

    logging.info("Change detected in domain list. Updating RPZ...")
    if not generate_rpz(domains, rpz_file):
        return last_checksum
    if reload_command and not _run_reload(reload_command):
        return last_checksum
    return current_checksum


def main():
    """Run the DNSBlock agent: load the config, then poll the API forever.

    Every 60 seconds the agent fetches ``{api_url}/api/v1/domains``. The
    request carries the serial key as a Bearer token and this host's machine
    ID. When the list changes, the RPZ file is regenerated and the optional
    reload command is run. Failures within a cycle are logged and the loop
    continues. Startup errors, or an unexpected error outside the loop, end
    the process with exit status 1.
    """
    try:
        logging.info("DNSBlock Agent Starting...")

        config = load_config()
        serial_key = config.get('serial_key')
        api_url = config.get('api_url')
        # Treat an explicit null/empty "rpz_file" like an absent one.
        rpz_file = config.get('rpz_file') or RPZ_FILE_DEFAULT
        reload_command = config.get('reload_command')

        if not serial_key or not api_url:
            logging.error("Missing serial_key or api_url in config.")
            sys.exit(1)

        machine_id = get_machine_id()
        logging.info(f"Machine ID: {machine_id}")

        # The request never changes between cycles, so build it once. Strip a
        # trailing slash from api_url so the URL has no '//api/v1/domains'.
        domains_url = f"{str(api_url).rstrip('/')}/api/v1/domains"
        headers = {
            'Authorization': f'Bearer {serial_key}',
            'X-Machine-ID': machine_id,
        }

        last_checksum = None
        was_error = False  # Track if we had an error in the previous cycle

        while True:
            try:
                logging.debug("Starting sync cycle...")
                response = requests.get(domains_url, headers=headers, timeout=10)

                if response.status_code == 200:
                    # Log reconnection if we had an error before
                    if was_error:
                        logging.info("Connection restored. Sync resumed successfully.")
                    was_error = False
                    last_checksum = _apply_domain_response(
                        response, rpz_file, reload_command, last_checksum
                    )
                elif response.status_code == 401:
                    logging.error("Unauthorized: Invalid Serial Key.")
                    was_error = True
                elif response.status_code == 403:
                    logging.warning(
                        f"Access denied (403): {_error_message(response)}. Will retry in 60 seconds..."
                    )
                    was_error = True
                else:
                    # Bounded preview: proxies may return large HTML error pages.
                    logging.warning(f"API Error: {response.status_code} - {response.text[:200]}")
                    was_error = True

            except requests.RequestException as e:
                logging.error(f"Connection Error: {e}")
                was_error = True
            except Exception:
                # Not a network failure: likely a bug or an unexpected payload.
                # Keep the traceback, but keep the agent running.
                logging.exception("Unexpected error during sync cycle")
                was_error = True

            # Wait for 60 seconds before next check
            time.sleep(60)

    except Exception as e:
        logging.critical(f"Fatal Error: {e}", exc_info=True)
        sys.exit(1)
|
|
|
|
# Script entry point: start the agent only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()
|