```python
import os
import sys
import re
import gzip
import json
import hashlib
import shutil
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from tqdm import tqdm
import tarfile
import urllib3
import argparse
import logging
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
# Disable SSL warnings
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
VERSION = "v2.3.2"
# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
    handlers=[
        logging.FileHandler("docker_puller.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
def parse_image_input(image_input):
    """Safely parse an image name into (repository, image, tag)."""
    try:
        parts = image_input.split('/')
        if len(parts) == 1:
            repo = 'library'
            img_tag = parts[0]
        else:
            repo = '/'.join(parts[:-1])
            img_tag = parts[-1]
        if ':' in img_tag:
            img, tag = img_tag.split(':', 1)
        else:
            img, tag = img_tag, 'latest'
        return repo, img, tag
    except Exception as e:
        logger.error(f"Failed to parse image format: {image_input}")
        raise ValueError(f"Invalid image format: {image_input}") from e
def create_session():
    """Create a session with retry support."""
    session = requests.Session()
    retry = Retry(
        total=5,
        backoff_factor=1,
        status_forcelist=[429, 500, 502, 503, 504],
        allowed_methods=['GET', 'POST']
    )
    adapter = HTTPAdapter(max_retries=retry)
    session.mount('http://', adapter)
    session.mount('https://', adapter)
    return session
def get_auth_head(session, registry, repository):
    """Generic token authentication handling."""
    try:
        base_url = f"https://{registry}/v2/"
        resp = session.get(base_url, verify=False, timeout=30)
        if resp.status_code == 401:
            auth_header = resp.headers.get('Www-Authenticate', '')
            matches = re.findall(r'(\w+)=["]([^"]+)["]', auth_header)
            auth_params = {k: v for k, v in matches}
            if 'realm' not in auth_params:
                raise ValueError("Auth header is missing the required realm parameter")
            token_url = f"{auth_params['realm']}?service={auth_params.get('service', registry)}&scope=repository:{repository}:pull"
            token_resp = session.get(token_url, verify=False, timeout=30)
            token_resp.raise_for_status()
            return {
                'Authorization': f"Bearer {token_resp.json()['token']}",
                'Accept': 'application/vnd.docker.distribution.manifest.v2+json'
            }
        return {}
    except Exception as e:
        logger.error(f"Authentication failed: {str(e)}")
        raise
def fetch_manifest(session, registry, repository, tag, auth_head):
    """Fetch the image manifest."""
    try:
        url = f'https://{registry}/v2/{repository}/manifests/{tag}'
        headers = auth_head.copy()
        headers['Accept'] = 'application/vnd.docker.distribution.manifest.v2+json'
        resp = session.get(url, headers=headers, verify=False, timeout=30)
        resp.raise_for_status()
        manifest = resp.json()
        if manifest.get('schemaVersion') == 1:
            layers = [layer['blobSum'] for layer in manifest['fsLayers']]
        elif manifest.get('schemaVersion') == 2:
            layers = [layer['digest'] for layer in manifest['layers']]
        else:
            raise ValueError(f"Unsupported manifest version: {manifest.get('schemaVersion')}")
        return {'manifest': manifest, 'layers': layers}
    except requests.exceptions.RequestException as e:
        logger.error(f"Manifest request failed: {str(e)}")
        raise
class DownloadWorker:
    def __init__(self, session, registry, repo, auth_head):
        self.session = session
        self.registry = registry
        self.repo = repo
        self.auth_head = auth_head

    def download_blob(self, blob_digest):
        """Download a single image layer (fixes the progress bar lock issue)."""
        layer_file = None
        try:
            layer_dir = os.path.join(f"{self.repo.replace('/', '_')}_layers",
                                     blob_digest.split(':')[-1])
            os.makedirs(layer_dir, exist_ok=True)
            layer_file = os.path.join(layer_dir, "layer.tar")
            if os.path.exists(layer_file):
                logger.info(f"Skipping existing layer: {blob_digest[:12]}")
                return layer_dir
            url = f'https://{self.registry}/v2/{self.repo}/blobs/{blob_digest}'
            resp = self.session.get(url, headers=self.auth_head,
                                    stream=True, verify=False, timeout=60)
            resp.raise_for_status()
            total_size = int(resp.headers.get('content-length', 0))
            # Fixed progress bar parameters
            with tqdm(total=total_size, unit='B', unit_scale=True,
                      desc=blob_digest[:12], leave=False) as pbar:
                with open(layer_file, 'wb') as f:
                    for chunk in resp.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                            pbar.update(len(chunk))
            # Verify integrity
            self._verify_layer(blob_digest, layer_file)
            return layer_dir
        except Exception as e:
            logger.error(f"Layer download failed: {blob_digest[:12]} - {str(e)}")
            if layer_file and os.path.exists(layer_file):
                os.remove(layer_file)
            raise

    def _verify_layer(self, digest, file_path):
        """Verify layer file integrity."""
        hash_algo, expected_hash = digest.split(':')
        sha256 = hashlib.sha256()
        with open(file_path, 'rb') as f:
            for chunk in iter(lambda: f.read(8192), b''):
                sha256.update(chunk)
        if sha256.hexdigest() != expected_hash:
            raise ValueError(f"Checksum mismatch: {expected_hash[:12]} vs {sha256.hexdigest()[:12]}")
def main():
    parser = argparse.ArgumentParser(description=f"Docker image download tool {VERSION}")
    parser.add_argument("image", help="Image name (format: repository/image:tag or image:tag)")
    parser.add_argument("-r", "--registry", default="registry.hub.docker.com",
                        help="Docker registry address (default: registry.hub.docker.com)")
    parser.add_argument("-t", "--threads", type=int, default=5, help="Number of download threads (default: 5)")
    args = parser.parse_args()

    try:
        # Parse the image argument
        repo, img, tag = parse_image_input(args.image)
        full_repo = f"{repo}/{img}" if repo != 'library' else img
        logger.info(f"Starting download of image {full_repo}:{tag} from registry {args.registry}")
        # Initialize the session
        session = create_session()
        auth_head = get_auth_head(session, args.registry, full_repo)
        manifest_data = fetch_manifest(session, args.registry, full_repo, tag, auth_head)
        logger.info(f"Found {len(manifest_data['layers'])} layers to download")
        # Multi-threaded download
        with ThreadPoolExecutor(max_workers=args.threads) as executor:
            worker = DownloadWorker(session, args.registry, full_repo, auth_head)
            futures = [executor.submit(worker.download_blob, blob)
                       for blob in manifest_data['layers']]
            try:
                for future in concurrent.futures.as_completed(futures):
                    future.result()
            except KeyboardInterrupt:
                logger.error("Interrupted by user, cleaning up...")
                executor.shutdown(wait=False)
                sys.exit(1)
        logger.info("All layers downloaded, packaging...")
        # Packaging logic...
    except Exception as e:
        logger.error(f"Execution failed: {str(e)}", exc_info=True)
        sys.exit(1)

if __name__ == "__main__":
    main()
```
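The packaging step above is only a placeholder. A minimal sketch of what a hypothetical `pack_layers` helper might look like, assuming the on-disk layout produced by `DownloadWorker.download_blob` (`<repo>_layers/<digest>/layer.tar`); note it only bundles the raw layer tars, while a fully `docker load`-compatible image would also need the config blob and a `manifest.json`, which this script does not fetch:

```python
import os
import tarfile

def pack_layers(repo, tag):
    """Bundle the downloaded layer tars into a single archive (sketch only).

    Assumes the directory layout created by DownloadWorker.download_blob:
    <repo with '/' replaced by '_'>_layers/<digest>/layer.tar
    """
    layers_root = f"{repo.replace('/', '_')}_layers"
    out_name = f"{repo.replace('/', '_')}_{tag}.tar"
    with tarfile.open(out_name, "w") as tar:
        for digest_dir in sorted(os.listdir(layers_root)):
            layer_path = os.path.join(layers_root, digest_dir, "layer.tar")
            if os.path.isfile(layer_path):
                # Store each layer under its digest directory inside the archive
                tar.add(layer_path, arcname=os.path.join(digest_dir, "layer.tar"))
    return out_name
```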
Further optimized version
```python
import os
import sys
import re
import json
import hashlib
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from tqdm import tqdm
import urllib3
import argparse
import logging
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
import threading  # required by ThreadSafeProgress below
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
VERSION = "v2.3.3"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
    handlers=[
        logging.FileHandler("docker_puller.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
class ThreadSafeProgress:
    """Thread-safe progress bar manager."""
    def __init__(self, total_layers):
        self.progress_bars = {}
        self.lock = threading.Lock()
        self.total = total_layers

    def create_bar(self, digest):
        with self.lock:
            bar = tqdm(total=100,
                       desc=digest[:12],
                       leave=False,
                       position=len(self.progress_bars))
            self.progress_bars[digest] = bar
            return bar
class DownloadWorker:
    def __init__(self, session, registry, repo, auth_head, progress_mgr):
        self.session = session
        self.registry = registry
        self.repo = repo
        self.auth_head = auth_head
        self.progress_mgr = progress_mgr

    def download_blob(self, blob_digest):
        layer_file = None
        try:
            layer_dir = os.path.join(f"{self.repo.replace('/', '_')}_layers",
                                     blob_digest.split(':')[-1])
            os.makedirs(layer_dir, exist_ok=True)
            layer_file = os.path.join(layer_dir, "layer.tar")
            if os.path.exists(layer_file):
                return layer_dir
            url = f'https://{self.registry}/v2/{self.repo}/blobs/{blob_digest}'
            resp = self.session.get(url, headers=self.auth_head,
                                    stream=True, verify=False, timeout=60)
            resp.raise_for_status()
            total_size = int(resp.headers.get('content-length', 0))
            progress_bar = self.progress_mgr.create_bar(blob_digest)
            downloaded = 0
            with open(layer_file, 'wb') as f:
                for chunk in resp.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        downloaded += len(chunk)
                        if total_size:  # avoid division by zero when content-length is missing
                            progress_bar.n = min(100, int(downloaded / total_size * 100))
                            progress_bar.refresh()
            self._verify_layer(blob_digest, layer_file)
            return layer_dir
        except Exception as e:
            logger.error(f"Download failed {blob_digest[:12]}: {str(e)}")
            if layer_file and os.path.exists(layer_file):
                os.remove(layer_file)
            raise

    def _verify_layer(self, digest, file_path):
        hash_algo, expected_hash = digest.split(':')
        sha256 = hashlib.sha256()
        with open(file_path, 'rb') as f:
            while chunk := f.read(8192):
                sha256.update(chunk)
        if sha256.hexdigest() != expected_hash:
            raise ValueError(f"Checksum mismatch: {expected_hash[:12]}")
def main():
    # ... [other functions unchanged] ...
    try:
        # ... [authentication and manifest fetching] ...
        # Initialize the progress manager
        progress_mgr = ThreadSafeProgress(len(manifest_data['layers']))
        with ThreadPoolExecutor(max_workers=args.threads) as executor:
            worker = DownloadWorker(session, args.registry, full_repo, auth_head, progress_mgr)
            futures = [executor.submit(worker.download_blob, blob)
                       for blob in manifest_data['layers']]
            try:
                for future in concurrent.futures.as_completed(futures):
                    future.result()
            except KeyboardInterrupt:
                logger.error("Operation cancelled")
                executor.shutdown(wait=False)
                sys.exit(1)
        logger.info("All layers downloaded")
    except Exception as e:
        logger.error(f"Execution failed: {str(e)}", exc_info=True)
        sys.exit(1)

if __name__ == "__main__":
    main()
```
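Either version is invoked the same way from the command line, for example `python docker_puller.py nginx:latest -r registry.hub.docker.com -t 8` (the filename `docker_puller.py` is an assumption here, matching the log file name used in the scripts). Per `parse_image_input`, a bare name such as `nginx` defaults the repository to `library` and the tag to `latest`.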