Switch to multithreaded downloading (a complete beginner iterating with AI) #10

@duma520

Description

@duma520

008.pdf

009.pdf

```python
import os
import sys
import re
import gzip
import json
import hashlib
import shutil
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from tqdm import tqdm
import tarfile
import urllib3
import argparse
import logging
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor

# Disable SSL warnings
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
VERSION = "v2.3.2"

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
    handlers=[
        logging.FileHandler("docker_puller.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

def parse_image_input(image_input):
    """Safely parse the image name"""
    try:
        parts = image_input.split('/')
        if len(parts) == 1:
            repo = 'library'
            img_tag = parts[0]
        else:
            repo = '/'.join(parts[:-1])
            img_tag = parts[-1]

        if ':' in img_tag:
            img, tag = img_tag.split(':', 1)
        else:
            img, tag = img_tag, 'latest'
        return repo, img, tag
    except Exception as e:
        logger.error(f"Failed to parse image format: {image_input}")
        raise ValueError(f"Invalid image format: {image_input}") from e

def create_session():
    """Create a session with retry support"""
    session = requests.Session()
    retry = Retry(
        total=5,
        backoff_factor=1,
        status_forcelist=[429, 500, 502, 503, 504],
        allowed_methods=['GET', 'POST']
    )
    adapter = HTTPAdapter(max_retries=retry)
    session.mount('http://', adapter)
    session.mount('https://', adapter)
    return session

def get_auth_head(session, registry, repository):
    """Generic authentication handling"""
    try:
        base_url = f"https://{registry}/v2/"
        resp = session.get(base_url, verify=False, timeout=30)

        if resp.status_code == 401:
            auth_header = resp.headers.get('Www-Authenticate', '')
            matches = re.findall(r'(\w+)=["]([^"]+)["]', auth_header)
            auth_params = {k: v for k, v in matches}

            if 'realm' not in auth_params:
                raise ValueError("Auth header is missing the required realm parameter")

            token_url = f"{auth_params['realm']}?service={auth_params.get('service', registry)}&scope=repository:{repository}:pull"
            token_resp = session.get(token_url, verify=False, timeout=30)
            token_resp.raise_for_status()

            return {
                'Authorization': f"Bearer {token_resp.json()['token']}",
                'Accept': 'application/vnd.docker.distribution.manifest.v2+json'
            }
        return {}
    except Exception as e:
        logger.error(f"Authentication failed: {str(e)}")
        raise

def fetch_manifest(session, registry, repository, tag, auth_head):
    """Fetch the image manifest"""
    try:
        url = f'https://{registry}/v2/{repository}/manifests/{tag}'
        headers = auth_head.copy()
        headers['Accept'] = 'application/vnd.docker.distribution.manifest.v2+json'

        resp = session.get(url, headers=headers, verify=False, timeout=30)
        resp.raise_for_status()

        manifest = resp.json()
        if manifest.get('schemaVersion') == 1:
            layers = [layer['blobSum'] for layer in manifest['fsLayers']]
        elif manifest.get('schemaVersion') == 2:
            layers = [layer['digest'] for layer in manifest['layers']]
        else:
            raise ValueError(f"Unsupported manifest version: {manifest.get('schemaVersion')}")

        return {'manifest': manifest, 'layers': layers}
    except requests.exceptions.RequestException as e:
        logger.error(f"Manifest request failed: {str(e)}")
        raise

class DownloadWorker:
    def __init__(self, session, registry, repo, auth_head):
        self.session = session
        self.registry = registry
        self.repo = repo
        self.auth_head = auth_head

    def download_blob(self, blob_digest):
        """Download a single image layer (fixes the progress-bar lock issue)"""
        layer_file = None
        try:
            layer_dir = os.path.join(f"{self.repo.replace('/', '_')}_layers",
                                     blob_digest.split(':')[-1])
            os.makedirs(layer_dir, exist_ok=True)
            layer_file = os.path.join(layer_dir, "layer.tar")

            if os.path.exists(layer_file):
                logger.info(f"Skipping existing layer: {blob_digest[:12]}")
                return layer_dir

            url = f'https://{self.registry}/v2/{self.repo}/blobs/{blob_digest}'
            resp = self.session.get(url, headers=self.auth_head,
                                    stream=True, verify=False, timeout=60)
            resp.raise_for_status()

            total_size = int(resp.headers.get('content-length', 0))

            # Fixed progress-bar parameters
            with tqdm(total=total_size, unit='B', unit_scale=True,
                      desc=blob_digest[:12], leave=False) as pbar:
                with open(layer_file, 'wb') as f:
                    for chunk in resp.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                            pbar.update(len(chunk))

            # Verify integrity
            self._verify_layer(blob_digest, layer_file)
            return layer_dir

        except Exception as e:
            logger.error(f"Layer download failed: {blob_digest[:12]} - {str(e)}")
            if layer_file and os.path.exists(layer_file):
                os.remove(layer_file)
            raise

    def _verify_layer(self, digest, file_path):
        """Verify layer file integrity"""
        hash_algo, expected_hash = digest.split(':')
        sha256 = hashlib.sha256()
        with open(file_path, 'rb') as f:
            for chunk in iter(lambda: f.read(8192), b''):
                sha256.update(chunk)
        if sha256.hexdigest() != expected_hash:
            raise ValueError(f"Checksum mismatch: {expected_hash[:12]} vs {sha256.hexdigest()[:12]}")

def main():
    parser = argparse.ArgumentParser(description=f"Docker image download tool {VERSION}")
    parser.add_argument("image", help="Image name (format: repo/image:tag or image:tag)")
    parser.add_argument("-r", "--registry", default="registry.hub.docker.com",
                        help="Docker registry address (default: registry.hub.docker.com)")
    parser.add_argument("-t", "--threads", type=int, default=5, help="Number of download threads (default 5)")
    args = parser.parse_args()

    try:
        # Parse the image argument
        repo, img, tag = parse_image_input(args.image)
        full_repo = f"{repo}/{img}" if repo != 'library' else img
        logger.info(f"Starting download of image: {full_repo}:{tag} from registry: {args.registry}")

        # Initialize session
        session = create_session()
        auth_head = get_auth_head(session, args.registry, full_repo)
        manifest_data = fetch_manifest(session, args.registry, full_repo, tag, auth_head)
        logger.info(f"Detected {len(manifest_data['layers'])} layers to download")

        # Multithreaded download
        with ThreadPoolExecutor(max_workers=args.threads) as executor:
            worker = DownloadWorker(session, args.registry, full_repo, auth_head)
            futures = [executor.submit(worker.download_blob, blob)
                       for blob in manifest_data['layers']]

            try:
                for future in concurrent.futures.as_completed(futures):
                    future.result()
            except KeyboardInterrupt:
                logger.error("Interrupted by user, cleaning up...")
                executor.shutdown(wait=False)
                sys.exit(1)

        logger.info("All layers downloaded, packaging...")
        # Packaging logic...

    except Exception as e:
        logger.error(f"Execution failed: {str(e)}", exc_info=True)
        sys.exit(1)

if name == "main":
main()
`
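
For reference, the CLI defined above takes an image name plus optional `-r/--registry` and `-t/--threads` flags, e.g. `python docker_puller.py bitnami/redis:7.2 -t 8` (the filename `docker_puller.py` is only an assumption here). A minimal sketch of how `parse_image_input` splits common inputs:

```python
# Minimal sketch, assuming the parse_image_input() defined above is in scope.
print(parse_image_input("nginx"))              # -> ('library', 'nginx', 'latest')
print(parse_image_input("bitnami/redis:7.2"))  # -> ('bitnami', 'redis', '7.2')
```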

Further optimized version

```python
import os
import sys
import re
import json
import hashlib
import threading
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from tqdm import tqdm
import urllib3
import argparse
import logging
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
VERSION = "v2.3.3"

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
    handlers=[
        logging.FileHandler("docker_puller.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

class ThreadSafeProgress:
    """Thread-safe progress manager"""
    def __init__(self, total_layers):
        self.progress_bars = {}
        self.lock = threading.Lock()
        self.total = total_layers

    def create_bar(self, digest):
        with self.lock:
            bar = tqdm(total=100,
                       desc=digest[:12],
                       leave=False,
                       position=len(self.progress_bars))
            self.progress_bars[digest] = bar
            return bar

class DownloadWorker:
    def __init__(self, session, registry, repo, auth_head, progress_mgr):
        self.session = session
        self.registry = registry
        self.repo = repo
        self.auth_head = auth_head
        self.progress_mgr = progress_mgr

    def download_blob(self, blob_digest):
        layer_file = None
        try:
            layer_dir = os.path.join(f"{self.repo.replace('/', '_')}_layers",
                                     blob_digest.split(':')[-1])
            os.makedirs(layer_dir, exist_ok=True)
            layer_file = os.path.join(layer_dir, "layer.tar")

            if os.path.exists(layer_file):
                return layer_dir

            url = f'https://{self.registry}/v2/{self.repo}/blobs/{blob_digest}'
            resp = self.session.get(url, headers=self.auth_head,
                                    stream=True, verify=False, timeout=60)
            resp.raise_for_status()

            total_size = int(resp.headers.get('content-length', 0))
            progress_bar = self.progress_mgr.create_bar(blob_digest)

            downloaded = 0
            with open(layer_file, 'wb') as f:
                for chunk in resp.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        downloaded += len(chunk)
                        if total_size:
                            progress_bar.n = min(100, int(downloaded / total_size * 100))
                            progress_bar.refresh()

            self._verify_layer(blob_digest, layer_file)
            return layer_dir

        except Exception as e:
            logger.error(f"Download failed {blob_digest[:12]}: {str(e)}")
            if layer_file and os.path.exists(layer_file):
                os.remove(layer_file)
            raise

    def _verify_layer(self, digest, file_path):
        hash_algo, expected_hash = digest.split(':')
        sha256 = hashlib.sha256()
        with open(file_path, 'rb') as f:
            while chunk := f.read(8192):
                sha256.update(chunk)
        if sha256.hexdigest() != expected_hash:
            raise ValueError(f"Checksum mismatch: {expected_hash[:12]}")

def main():
    # ... [other functions unchanged] ...

    try:
        # ... [auth and manifest fetching] ...

        # Initialize the progress manager
        progress_mgr = ThreadSafeProgress(len(manifest_data['layers']))

        with ThreadPoolExecutor(max_workers=args.threads) as executor:
            worker = DownloadWorker(session, args.registry, full_repo, auth_head, progress_mgr)
            futures = [executor.submit(worker.download_blob, blob)
                       for blob in manifest_data['layers']]

            try:
                for future in concurrent.futures.as_completed(futures):
                    future.result()
            except KeyboardInterrupt:
                logger.error("Operation cancelled")
                executor.shutdown(wait=False)
                sys.exit(1)

        logger.info("All layers downloaded")

    except Exception as e:
        logger.error(f"Execution failed: {str(e)}", exc_info=True)
        sys.exit(1)

if __name__ == "__main__":
    main()
```
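
As an aside (not part of the issue's code), one way to sidestep per-layer `position` bookkeeping is to aggregate all layers into a single shared tqdm bar guarded by a lock. `SharedProgress` below is a hypothetical name for such a sketch, assuming total bytes can be summed up front from the manifest's layer `size` fields:

```python
import threading
from tqdm import tqdm

class SharedProgress:
    """One bar for all layers: workers report downloaded bytes under a lock."""
    def __init__(self, total_bytes):
        self.bar = tqdm(total=total_bytes, unit='B', unit_scale=True, desc="layers")
        self.lock = threading.Lock()

    def update(self, nbytes):
        # Called from worker threads after each chunk is written to disk.
        with self.lock:
            self.bar.update(nbytes)

    def close(self):
        self.bar.close()
```

Workers would then call `progress.update(len(chunk))` inside the download loop instead of recomputing a percentage per layer.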
