298 lines
11 KiB
Python
298 lines
11 KiB
Python
import os
|
|
import sys
|
|
import time
|
|
import json
|
|
import uuid
|
|
import socket
|
|
import pwd
|
|
import grp
|
|
import requests
|
|
import logging
|
|
from datetime import datetime
|
|
from logging.handlers import RotatingFileHandler
|
|
|
|
# Configuration
#
# Every on-disk artifact of the agent lives under a single base directory;
# the individual paths below are derived from it.
BASE_DIR = '/opt/dnsblock'
CONFIG_FILE = os.path.join(BASE_DIR, 'config.json')
STATE_FILE = os.path.join(BASE_DIR, 'rpz_state.json')
LOG_DIR = os.path.join(BASE_DIR, 'logs')
LOG_FILE = os.path.join(LOG_DIR, 'agent.log')
RPZ_FILE_DEFAULT = os.path.join(BASE_DIR, 'rpz.dnsblock.zone')
|
|
|
|
# Ensure Log Directory Exists before the rotating file handler opens LOG_FILE.
# exist_ok=True avoids the race of a separate exists() check, and still raises
# (FileExistsError) when LOG_DIR exists but is not a directory, so that case is
# reported here instead of crashing later inside the logging handler.
try:
    os.makedirs(LOG_DIR, exist_ok=True)
except OSError as e:
    # Logging is not configured yet, so report directly on stderr.
    print(f"Failed to create log directory: {e}", file=sys.stderr)
    sys.exit(1)
|
|
|
|
# Setup Logging with Rotation.
# One shared format feeds two sinks on the root logger: a size-rotated file
# (5 MiB per file, 5 backups) and stdout, which systemd captures in the journal.
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

handler = RotatingFileHandler(LOG_FILE, maxBytes=5 * 1024 * 1024, backupCount=5)
handler.setFormatter(formatter)

console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)

logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.addHandler(console_handler)
|
|
|
|
def get_machine_id():
    """Return an identifier for this host, sent to the API as X-Machine-ID.

    Uses the contents of /etc/machine-id when that file exists and is
    non-empty; otherwise falls back to ``str(uuid.getnode())``.

    Note: per the ``uuid`` docs, ``getnode()`` returns a random number when no
    hardware address can be obtained, so the fallback may not be stable across
    restarts on such hosts.

    Exits the process with status 1 if reading the ID fails unexpectedly
    (e.g. /etc/machine-id exists but is unreadable).
    """
    try:
        # Try reading /etc/machine-id (Linux standard)
        if os.path.exists('/etc/machine-id'):
            with open('/etc/machine-id', 'r') as f:
                machine_id = f.read().strip()
            # Freshly built images/containers often ship an empty machine-id;
            # returning "" would give all of them the same (empty) identity.
            if machine_id:
                return machine_id

        # Fallback to hardware UUID
        return str(uuid.getnode())
    except Exception as e:
        logging.error(f"Error getting machine ID: {e}")
        sys.exit(1)
|
|
|
|
def load_config(path=None):
    """Load the agent configuration from a JSON file.

    Args:
        path: Config file to read. Defaults to ``CONFIG_FILE``.

    Returns:
        The parsed configuration dict.

    Exits the process with status 1 when the file is missing, cannot be
    read, is not valid JSON, or its top-level value is not a JSON object
    (callers read it with ``config.get(...)``).
    """
    if path is None:
        path = CONFIG_FILE

    if not os.path.exists(path):
        logging.error(f"Config file not found at {path}")
        sys.exit(1)

    try:
        with open(path, 'r') as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logging.error(f"Error loading config: {e}")
        sys.exit(1)

    if not isinstance(config, dict):
        logging.error(
            f"Error loading config: {path} must contain a JSON object, "
            f"got {type(config).__name__}"
        )
        sys.exit(1)

    return config
|
|
|
|
def _domain_from_name(name):
    """Return the part of a dotted host name after its first label, or None.

    e.g. "server.gtecnet.com.br" -> "gtecnet.com.br". A trailing root dot is
    ignored; names with no remainder or with empty labels in the remainder
    ("host", "host.", "host..com") yield None.
    """
    if not name:
        return None
    labels = name.rstrip('.').split('.')
    if len(labels) < 2 or not all(labels[1:]):
        return None
    return '.'.join(labels[1:])


def get_system_domain():
    """Detect the system's DNS domain, used in the RPZ SOA/NS records.

    Tries the FQDN from ``socket.getfqdn()`` first, then
    ``socket.gethostname()``, taking everything after the first label.
    Returns "localhost" when neither yields a usable domain or when a
    lookup raises.
    """
    try:
        domain = _domain_from_name(socket.getfqdn())
        if domain:
            logging.info(f"Detected system domain: {domain}")
            return domain

        # Try hostname as fallback
        domain = _domain_from_name(socket.gethostname())
        if domain:
            logging.info(f"Detected system domain from hostname: {domain}")
            return domain

        logging.info("Could not detect system domain, using localhost")
        return "localhost"
    except Exception as e:
        logging.warning(f"Error detecting system domain: {e}, using localhost")
        return "localhost"
|
|
|
|
def load_rpz_state(path=None):
    """Load the persisted RPZ serial state.

    Args:
        path: State file to read. Defaults to ``STATE_FILE``.

    Returns:
        The parsed state dict (normally with 'last_serial' and 'last_date').
        A fresh ``{'last_serial': None, 'last_date': None}`` is returned when
        the file is missing, unreadable, not valid JSON, or does not hold a
        JSON object. Problems are logged as warnings, never raised.
    """
    if path is None:
        path = STATE_FILE

    if not os.path.exists(path):
        return {'last_serial': None, 'last_date': None}

    try:
        with open(path, 'r') as f:
            state = json.load(f)
    except (OSError, ValueError) as e:
        logging.warning(f"Error loading RPZ state: {e}")
        return {'last_serial': None, 'last_date': None}

    # Callers use state.get(...); anything but an object (e.g. `null`) would
    # crash them, so treat it like a corrupt file.
    if not isinstance(state, dict):
        logging.warning(
            f"Error loading RPZ state: expected a JSON object, got {type(state).__name__}"
        )
        return {'last_serial': None, 'last_date': None}

    return state
|
|
|
|
def save_rpz_state(state, path=None):
    """Persist the RPZ serial state as JSON, replacing the file atomically.

    The JSON is written and fsynced to ``<path>.tmp``, which is then renamed
    over *path*. A crash or serialization error therefore never leaves a
    truncated state file. If load_rpz_state read one back, it would fall back
    to "no state" and restart the day's serial at sequence 00.

    Args:
        state: JSON-serializable state dict.
        path: Target file. Defaults to ``STATE_FILE``.

    Errors are logged, not raised; the previous file is left untouched.
    """
    if path is None:
        path = STATE_FILE
    tmp_path = f"{path}.tmp"

    try:
        with open(tmp_path, 'w') as f:
            json.dump(state, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)
    except Exception as e:
        logging.error(f"Error saving RPZ state: {e}")
        # Don't leave a half-written temp file behind.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
|
|
|
|
def generate_serial():
    """Return the next SOA serial in YYYYMMDDNN form and persist it.

    The new serial is ``max(<today>00, last_serial + 1)``. In the normal
    case this matches the classic scheme: NN restarts at 00 on a new day and
    is the previous NN + 1 within the same day. It additionally keeps the
    serial strictly increasing:

    * after the 100th update of a day it rolls into the next day's range
      instead of sticking at NN = 99, and
    * if the system clock moves backwards it keeps counting up from the last
      serial instead of restarting lower.

    DNS consumers that compare serials (RFC 1982 arithmetic) only pick up a
    zone whose serial increased, so a repeated or lower serial would hide
    changes. A missing or non-numeric stored serial is treated as "none".
    """
    today = datetime.now().strftime('%Y%m%d')
    state = load_rpz_state()

    raw_last = state.get('last_serial') if isinstance(state, dict) else None
    try:
        last_serial = int(raw_last) if raw_last else None
    except (TypeError, ValueError):
        last_serial = None

    day_start = int(f"{today}00")
    if last_serial is None:
        new_serial = day_start
    else:
        new_serial = max(day_start, last_serial + 1)

    # Save state
    save_rpz_state({'last_serial': new_serial, 'last_date': today})

    return new_serial
|
|
|
|
def apply_file_ownership(path, file_owner):
    """Change the owner and group of *path* according to *file_owner*.

    Args:
        path: File or directory whose ownership is changed.
        file_owner: ``"user:group"`` or ``"user"``. When the group part is
            omitted or empty (``"unbound"`` or ``"unbound:"``), the group with
            the same name as the user is used. A falsy value means "leave
            ownership unchanged".

    Returns:
        True if ownership was applied or nothing was requested. False if the
        user or group does not exist, permission was denied, or another error
        occurred. Failures are logged, never raised.
    """
    if not file_owner:
        return True

    try:
        parts = file_owner.split(':')
        user = parts[0]
        # "user:" has an empty group part; fall back to the user name instead
        # of looking up a group named "" (which would not be found).
        group = parts[1] if len(parts) > 1 and parts[1] else user

        uid = pwd.getpwnam(user).pw_uid
        gid = grp.getgrnam(group).gr_gid

        os.chown(path, uid, gid)
        logging.debug(f"Applied ownership {file_owner} to {path}")
        return True
    except KeyError as e:
        logging.warning(f"User or group not found for file_owner '{file_owner}': {e}")
        return False
    except PermissionError as e:
        logging.warning(f"Permission denied setting ownership on {path}: {e}")
        return False
    except Exception as e:
        logging.error(f"Error applying ownership to {path}: {e}")
        return False
|
|
|
|
# Characters allowed inside one DNS label written to the RPZ zone file.
_RPZ_LABEL_CHARS = frozenset(
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_"
)


def _clean_rpz_domains(domains):
    """Return *domains* as validated, de-duplicated RPZ owner names (None -> []).

    Each entry is stripped of surrounding whitespace, a trailing root dot and
    a leading "*." (the wildcard record is added by the renderer). Entries
    that are not strings, or that have a label outside 1-63 characters of
    letters, digits, '-' or '_', are skipped with a warning. A malformed API
    entry, e.g. one containing a newline or space, therefore cannot inject
    extra records or break the zone file's syntax. Duplicates are compared
    case-insensitively and only the first occurrence is kept.
    """
    cleaned = []
    seen = set()
    for entry in domains or []:
        if not isinstance(entry, str):
            logging.warning(f"Skipping non-string domain entry: {entry!r}")
            continue
        name = entry.strip().rstrip('.')
        if name.startswith('*.'):
            name = name[2:]
        labels = name.split('.')
        if not name or not all(
            0 < len(label) <= 63 and _RPZ_LABEL_CHARS.issuperset(label)
            for label in labels
        ):
            logging.warning(f"Skipping invalid domain entry: {entry!r}")
            continue
        key = name.lower()
        if key in seen:
            continue
        seen.add(key)
        cleaned.append(name)
    return cleaned


def _render_rpz_zone(system_domain, serial, domains):
    """Return the zone text: SOA/NS header, then two CNAME-to-root records per domain."""
    lines = [
        "$TTL 1H\n",
        f"@ IN SOA localhost. {system_domain}. (\n",
        f" {serial} ; Serial\n",
        " 1h ; Refresh\n",
        " 15m ; Retry\n",
        " 30d ; Expire\n",
        " 2h ; Negative Cache TTL\n",
        " )\n",
        f" NS {system_domain}.\n",
        "\n",
        "; RPZ block hosts\n",
        "\n",
    ]
    for domain in domains:
        # CNAME to . (NXDOMAIN) or specific sinkhole
        lines.append(f"{domain} CNAME .\n")
        lines.append(f"*.{domain} CNAME .\n")
    return "".join(lines)


def _replace_file_atomically(path, content, file_owner=None):
    """Replace *path* with *content* via temp file + rename, so it is never seen half-written."""
    import stat

    # Write through a symlinked path to its real target, as open(path, 'w') did.
    target = os.path.realpath(path)
    tmp_path = f"{target}.tmp"
    try:
        with open(tmp_path, 'w') as f:
            f.write(content)
            f.flush()
            os.fsync(f.fileno())

        try:
            previous = os.stat(target)
        except FileNotFoundError:
            previous = None
        if previous is not None:
            # Truncating in place kept the old file's mode and owner; carry
            # them over to the replacement. The owner is copied only when no
            # explicit owner is configured, and only as far as permitted.
            os.chmod(tmp_path, stat.S_IMODE(previous.st_mode))
            if not file_owner:
                try:
                    os.chown(tmp_path, previous.st_uid, previous.st_gid)
                except PermissionError:
                    pass
        if file_owner:
            apply_file_ownership(tmp_path, file_owner)

        os.replace(tmp_path, target)
    except Exception:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def generate_rpz(domains, output_file, file_owner=None):
    """Generate the RPZ zone file with an SOA/NS header and one block per domain.

    For every valid domain two records are written: ``<domain> CNAME .`` and
    ``*.<domain> CNAME .``. Invalid entries are skipped with a warning (see
    ``_clean_rpz_domains``). The file is replaced atomically, so a failed
    write leaves the previous zone in place instead of a truncated one.
    Generating the file consumes (and persists) a new SOA serial.

    Args:
        domains: Iterable of domain names from the API; None is treated as empty.
        output_file: Path of the zone file to (re)write; its directory is
            created if missing.
        file_owner: Optional ``"user[:group]"`` applied to the directory and file.

    Returns:
        True on success, False on any error (logged).
    """
    try:
        # Ensure directory exists
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            # Apply ownership to directory
            if file_owner:
                apply_file_ownership(output_dir, file_owner)

        entries = _clean_rpz_domains(domains)

        # Get system domain and generate serial
        system_domain = get_system_domain()
        serial = generate_serial()

        zone_text = _render_rpz_zone(system_domain, serial, entries)
        _replace_file_atomically(output_file, zone_text, file_owner)

        logging.info(
            f"RPZ file generated at {output_file} with {len(entries)} domains (serial: {serial})."
        )
        return True
    except Exception as e:
        logging.error(f"Error generating RPZ: {e}")
        return False
|
|
|
|
def _run_reload_command(reload_command, timeout=300):
    """Run the configured reload command through the shell; log and return success."""
    import subprocess

    logging.info(f"Executing reload command: {reload_command}")
    try:
        # shell=True keeps the os.system() semantics for the operator-supplied
        # command string from the local config file (pipes, &&, ...). Unlike
        # os.system, returncode is the real exit code, not a raw wait status.
        result = subprocess.run(reload_command, shell=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        logging.error(f"Reload command timed out after {timeout} seconds")
        return False
    except OSError as e:
        logging.error(f"Reload command could not be started: {e}")
        return False

    if result.returncode == 0:
        logging.info("Service reloaded successfully.")
        return True
    logging.error(f"Reload command failed with exit code {result.returncode}")
    return False


def _error_message(response, default='Access denied'):
    """Return the 'error' field of a JSON error body, else the first 100 chars of the body."""
    try:
        body = response.json()
    except ValueError:
        return response.text[:100]
    if isinstance(body, dict):
        return body.get('error', default)
    return response.text[:100]


def _apply_domain_list(response, last_checksum, rpz_file, file_owner, reload_command):
    """Regenerate the RPZ (and reload) if the domain list changed; return the checksum now in effect."""
    try:
        data = response.json()
    except ValueError:
        # ValueError covers json/simplejson decode errors and requests'
        # own JSONDecodeError, whichever JSON backend requests uses.
        logging.error(f"Failed to decode JSON response. Content preview: {response.text[:200]}")
        logging.debug(f"Full response: {response.text}")
        return last_checksum

    if not isinstance(data, dict):
        logging.error(f"Unexpected JSON payload type: {type(data).__name__}")
        return last_checksum

    current_checksum = data.get('checksum')
    if current_checksum is None:
        # Without a server checksum, None == None would report "no change"
        # forever and the RPZ would never be written; derive one locally.
        import hashlib
        payload = json.dumps(data.get('domains'), sort_keys=True, default=str)
        current_checksum = 'local:' + hashlib.sha256(payload.encode('utf-8')).hexdigest()

    if current_checksum == last_checksum:
        logging.debug("No changes in domain list.")
        return last_checksum

    logging.info("Change detected in domain list. Updating RPZ...")
    if not generate_rpz(data.get('domains') or [], rpz_file, file_owner):
        # Keep the old checksum so the next cycle retries the generation.
        return last_checksum

    if reload_command:
        _run_reload_command(reload_command)
    return current_checksum


def main():
    """Run the DNSBlock agent: poll the API every 60 seconds and keep the RPZ in sync.

    Reads the config and exits with status 1 if serial_key or api_url is
    missing. It then loops forever:

    * GET ``<api_url>/api/v1/domains``, sending the serial key as a bearer
      token and this host's machine ID;
    * on 200, regenerate the RPZ file when the domain list changed, then run
      ``reload_command`` if one is configured;
    * log 401/403/other statuses and request failures, and retry next cycle.

    Anything escaping the loop is logged as fatal and exits with status 1.
    """
    try:
        logging.info("DNSBlock Agent Starting...")

        config = load_config()
        serial_key = config.get('serial_key')
        api_url = config.get('api_url')
        rpz_file = config.get('rpz_file', RPZ_FILE_DEFAULT)
        reload_command = config.get('reload_command')
        file_owner = config.get('file_owner')  # Ex: 'unbound:unbound'

        if not serial_key or not api_url:
            logging.error("Missing serial_key or api_url in config.")
            sys.exit(1)

        machine_id = get_machine_id()
        logging.info(f"Machine ID: {machine_id}")

        # Request parameters never change during the agent's lifetime.
        # rstrip avoids "//api/..." when api_url is configured with a trailing slash.
        endpoint = f"{str(api_url).rstrip('/')}/api/v1/domains"
        headers = {
            'Authorization': f'Bearer {serial_key}',
            'X-Machine-ID': machine_id
        }

        last_checksum = None
        was_error = False  # Track if we had an error in the previous cycle

        while True:
            try:
                logging.debug("Starting sync cycle...")
                response = requests.get(endpoint, headers=headers, timeout=10)

                if response.status_code == 200:
                    # Log reconnection if we had an error before
                    if was_error:
                        logging.info("Connection restored. Sync resumed successfully.")
                    was_error = False
                    last_checksum = _apply_domain_list(
                        response, last_checksum, rpz_file, file_owner, reload_command
                    )
                elif response.status_code == 401:
                    logging.error("Unauthorized: Invalid Serial Key.")
                    was_error = True
                elif response.status_code == 403:
                    error_msg = _error_message(response)
                    logging.warning(f"Access denied (403): {error_msg}. Will retry in 60 seconds...")
                    was_error = True
                else:
                    # Truncate: error pages (e.g. proxy HTML) can be very large.
                    logging.warning(f"API Error: {response.status_code} - {response.text[:200]}")
                    was_error = True
            except requests.RequestException as e:
                logging.error(f"Connection Error: {e}")
                was_error = True
            except Exception as e:
                # Not a network problem: keep the daemon alive but keep the traceback.
                logging.error(f"Sync Error: {e}", exc_info=True)
                was_error = True

            # Wait for 60 seconds before next check
            time.sleep(60)

    except Exception as e:
        logging.critical(f"Fatal Error: {e}", exc_info=True)
        sys.exit(1)
|
|
|
|
# Start the agent only when executed as a script; importing the module
# (e.g. for tests) defines the functions without running the sync loop.
if __name__ == "__main__":
    main()
|