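"""ASVRG (momentum-accelerated SVRG) for problems with an l1 regularizer.

The smooth part of the objective is supplied by the caller through
compute_loss / compute_gradient; the l1 term is handled with proximal
(soft-thresholding) steps.
"""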
import numpy as np
from datetime import datetime
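

# Proximal operator of lam*||.||_1 (elementwise soft-thresholding):
#   prox(v) = sign(v) * max(|v| - lam, 0).
# This is the prox step used by the accelerated inner update and by the
# gradient-mapping stationarity checks below.
def soft_threshold(v, lam):
    return np.sign(v) * np.maximum(np.abs(v) - lam, 0.0)
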
def ASVRG_nonsmooth(w, X, Y, compute_loss, compute_gradient, **kwargs):
print('--- ASVRG with l1 regularizer ---\n')
    alpha = kwargs.get("alpha", 1e-3)  # passed through to compute_loss / compute_gradient
    tau = kwargs.get("tau", 1e-10)  # passed through to compute_loss / compute_gradient
    max_iterations = kwargs.get('max_iterations', 100)  # total number of inner iterations
    m = kwargs.get('inner_loop_m', 1)  # inner-loop length per epoch
    batch_size = kwargs.get("batch_size", 1)  # mini-batch size
    tol_grad = kwargs.get("tol_grad", 1e-5)  # stopping tolerance on the gradient-mapping norm
    eta = kwargs.get("eta", 1)  # step size
    l1_weight = kwargs.get("l1_weight", 0.1)  # weight of the l1 regularizer
    L = 20/3  # hard-coded constant, presumably a Lipschitz estimate of the smooth part
    # Initial momentum parameter; it lies in (0, 1) only when L*eta < 1/2.
    momentum_w = 1 - L*eta/(1-L*eta)
    # momentum_w = 0.3  # fixed-momentum alternative
    rho = 1  # inner-loop growth factor; rho = 1 keeps m fixed across epochs
n, d = X.shape
    y = w  # auxiliary (momentum) sequence
    sum_w = w  # running sum of inner iterates; reset at the start of each epoch
# output data
loss_collector = [compute_loss(w, X, Y, alpha, tau, l1_weight)]
    # note: the first entry is the plain gradient norm; later entries record the gradient-mapping norm
    grad_norm_collector = [np.linalg.norm(compute_gradient(w, X, Y, alpha, tau))]
total_samples_collector = [0]
times_collector = [0]
start_time = datetime.now()
    max_outerloop = int(max_iterations/m)  # number of epochs (outer loops)
for s in range(max_outerloop):
        # Full gradient at the snapshot point (the variance-reduction anchor)
        outer_gradient = compute_gradient(w, X, Y, alpha, tau)
        outer_w = w
        w = (1-momentum_w)*outer_w + momentum_w*y
for t in range(m):
            # Per-iteration sample cost: two mini-batch gradient evaluations, plus
            # one full pass for the snapshot gradient at the start of each epoch.
            sample_size = batch_size * 2
            if t == 0:
                sample_size += n
                sum_w = np.zeros(w.shape)
            # Draw a mini-batch by sampling indices with replacement
            idx_sample = np.random.randint(0, n, size=batch_size)
            _X = X[idx_sample]
            _Y = Y[idx_sample]
            # SVRG variance-reduced stochastic gradient
            gradient = (compute_gradient(w, _X, _Y, alpha, tau)
                        - compute_gradient(outer_w, _X, _Y, alpha, tau)
                        + outer_gradient)
            # Proximal (soft-thresholding) step on the auxiliary sequence,
            # taken with the enlarged step size eta/momentum_w
            y = soft_threshold(y - eta/momentum_w*gradient, l1_weight*eta/momentum_w)
w = outer_w + momentum_w*(y-outer_w)
sum_w += w
            # save data
            loss_collector.append(compute_loss(w, X, Y, alpha, tau, l1_weight))
            # Stationarity measure: norm of the proximal gradient mapping
            # G_eta(w) = (w - prox_{eta*l1}(w - eta*grad)) / eta
            true_gradient = compute_gradient(w, X, Y, alpha, tau)
            next_w = soft_threshold(w - eta*true_gradient, l1_weight*eta)
            grad_norm_collector.append(np.linalg.norm((w - next_w)/eta))
total_samples_collector.append(sample_size + total_samples_collector[-1])
times_collector.append((datetime.now() - start_time).total_seconds())
            # print progress
            print("Iteration:", t + s*m, "Loss:", loss_collector[-1],
                  "Grad Norm:", grad_norm_collector[-1], "Batch Size:", sample_size,
                  "Total Sample:", total_samples_collector[-1], "Time:", times_collector[-1])
if grad_norm_collector[-1] <= tol_grad:
break
        w = sum_w/m  # the average of the inner iterates becomes the next snapshot
        # Nesterov-style momentum decrease, solving theta_next^2 = (1 - theta_next)*theta^2
        momentum_w = (np.sqrt(momentum_w**4 + 4*momentum_w**2) - momentum_w**2)/2
        m = min(int(rho*m), m)  # a no-op while rho = 1; the inner-loop length stays fixed
        # save data at the averaged point
        loss_collector.append(compute_loss(w, X, Y, alpha, tau, l1_weight))
        # gradient-mapping norm at the averaged point
        true_gradient = compute_gradient(w, X, Y, alpha, tau)
        next_w = soft_threshold(w - eta*true_gradient, l1_weight*eta)
        grad_norm_collector.append(np.linalg.norm((w - next_w)/eta))
        total_samples_collector.append(total_samples_collector[-1])  # averaging draws no new samples
times_collector.append((datetime.now() - start_time).total_seconds())
if grad_norm_collector[-1] <= tol_grad:
break
return loss_collector, grad_norm_collector, total_samples_collector, times_collector
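

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the original interface.
    # It assumes compute_loss(w, X, Y, alpha, tau, l1_weight) returns the full
    # objective and compute_gradient(w, X, Y, alpha, tau) the gradient of the
    # smooth part only; alpha is treated here as an l2 weight and tau is
    # accepted but unused (assumptions, since this file does not define them).
    rng = np.random.default_rng(0)
    n, d = 200, 50
    X = rng.standard_normal((n, d))
    w_true = np.zeros(d)
    w_true[:5] = 1.0  # sparse ground truth
    Y = X @ w_true + 0.01 * rng.standard_normal(n)

    def compute_loss(w, X, Y, alpha, tau, l1_weight):
        r = X @ w - Y
        return 0.5/X.shape[0] * (r @ r) + 0.5*alpha * (w @ w) + l1_weight * np.abs(w).sum()

    def compute_gradient(w, X, Y, alpha, tau):
        return X.T @ (X @ w - Y)/X.shape[0] + alpha*w

    # eta is kept small so that L*eta < 1/2 and the initial momentum lies in (0, 1)
    losses, grad_norms, total_samples, times = ASVRG_nonsmooth(
        np.zeros(d), X, Y, compute_loss, compute_gradient,
        eta=0.05, max_iterations=200, inner_loop_m=20, batch_size=10, l1_weight=0.01)
    print("final loss:", losses[-1], "final grad-mapping norm:", grad_norms[-1])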