import numpy as np
from datetime import datetime
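
# The solver below repeatedly applies the soft-thresholding operator
# prox_{t||.||_1}(v) = sign(v) * max(|v| - t, 0), the proximal map of the
# L1 term. A minimal standalone sketch (hypothetical helper, kept inline
# in RSAG_nonsmooth rather than called from it):
#
#     def soft_threshold(v, t):
#         return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)
#
#     soft_threshold(np.array([-2.0, 0.5, 3.0]), 1.0)  # -> [-1.,  0.,  2.]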
def RSAG_nonsmooth(w, X, Y, compute_loss, compute_gradient, **kwargs):
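    """RSAG-style accelerated stochastic proximal gradient for the composite
    problem min_w f(w) + l1_weight * ||w||_1, with the smooth part f
    evaluated through the user-supplied `compute_loss` / `compute_gradient`.

    Each iteration extrapolates w_k = (1 - alpha_k) * y_{k-1} + alpha_k * x_{k-1},
    draws a mini-batch gradient g_k at w_k, and takes two soft-thresholded
    steps, x_k = prox(x_{k-1} - lambda_k * g_k) and y_k = prox(w_k - beta * g_k);
    cf. the RSAG scheme of Ghadimi & Lan for nonconvex stochastic composite
    optimization. The `alpha` and `tau` kwargs are forwarded to the loss and
    gradient callbacks unchanged.

    Returns the per-iteration histories
    (loss_collector, grad_norm_collector, total_samples_collector, times_collector).
    """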
    print('--- RSAG with L1 Regularizer ---\n')
    alpha = kwargs.get("alpha", 1e-3)
    tau = kwargs.get("tau", 1e-10)
    max_iterations = kwargs.get("max_iterations", 100)
    m = kwargs.get("inner_loop_m", 1)  # parsed but not used in this variant
    batch_size = kwargs.get("batch_size", 1)
    tol_grad = kwargs.get("tol_grad", 1e-5)
    eta = kwargs.get("eta", 1)
    momentum = kwargs.get("momentum", 2)
    l1_weight = kwargs.get("l1_weight", 0.1)
    n, d = X.shape
    # collectors for logging; the first grad-norm entry is the raw gradient
    # norm, later entries use the proximal-gradient mapping norm (see below)
    loss_collector = [compute_loss(w, X, Y, alpha, tau, l1_weight)]
    grad_norm_collector = [np.linalg.norm(compute_gradient(w, X, Y, alpha, tau))]
    total_samples_collector = [0]
    times_collector = [0]
start_time = datetime.now()
    x = y = w.copy()  # x: proximal iterate, y: aggregated iterate
    _lambda = 1       # prox step size for the x-update (reset every iteration)
    _beta = eta       # prox step size for the y-update
    _alpha = 2        # momentum weight (reset every iteration)
for k in range(max_iterations):
        # we use w for the paper's z: the extrapolated (mid) point
        _alpha = momentum / (k + momentum)  # momentum weight alpha_k
        _lambda = (1 + _alpha) * _beta      # step size lambda_k in [beta, (1 + alpha_k) * beta]
        # _alpha = 0  # uncomment to disable momentum (plain proximal SGD)
        w = (1 - _alpha) * y + _alpha * x
        # draw a mini-batch (with replacement) and compute a stochastic gradient
        sample_size = batch_size * 2  # samples charged to this iteration
        idx_sample = np.random.randint(0, n, size=batch_size)
        _X = X[idx_sample]
        _Y = Y[idx_sample]
        gradient = compute_gradient(w, _X, _Y, alpha, tau)
        # prox steps: prox_{t * l1_weight * ||.||_1}(v) = sign(v) * max(|v| - t * l1_weight, 0)
        v_x = x - _lambda * gradient
        x = np.sign(v_x) * np.maximum(np.abs(v_x) - l1_weight * _lambda, 0.0)
        v_y = w - _beta * gradient
        y = np.sign(v_y) * np.maximum(np.abs(v_y) - l1_weight * _beta, 0.0)
        # logging: track the full loss and the proximal-gradient mapping norm
        # ||(w - prox(w - eta * grad_f(w))) / eta|| as the stationarity measure
        loss_collector.append(compute_loss(w, X, Y, alpha, tau, l1_weight))
        true_gradient = compute_gradient(w, X, Y, alpha, tau)
        v_w = w - eta * true_gradient
        next_w = np.sign(v_w) * np.maximum(np.abs(v_w) - l1_weight * eta, 0.0)
        grad_norm_collector.append(np.linalg.norm((w - next_w) / eta))
        total_samples_collector.append(sample_size + total_samples_collector[-1])
        times_collector.append((datetime.now() - start_time).total_seconds())
        # progress report
        print(f"Iteration: {k}  Loss: {loss_collector[-1]:.6e}  "
              f"Grad Norm: {grad_norm_collector[-1]:.6e}  Batch Size: {sample_size}  "
              f"Total Samples: {total_samples_collector[-1]}  Time: {times_collector[-1]:.2f}s")
if grad_norm_collector[-1] <= tol_grad:
break
    return loss_collector, grad_norm_collector, total_samples_collector, times_collector
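

# --- Usage sketch: a hypothetical smoke test, not part of the solver ---
# Assumes L1-regularized least squares with an extra `alpha` ridge term;
# `tau` is accepted and ignored here (real callbacks may use it).
if __name__ == "__main__":
    def toy_loss(w, X, Y, alpha, tau, l1_weight):
        residual = X @ w - Y
        return (0.5 * np.mean(residual ** 2) + 0.5 * alpha * w @ w
                + l1_weight * np.sum(np.abs(w)))

    def toy_gradient(w, X, Y, alpha, tau):
        # gradient of the smooth part only; the L1 term is handled by the prox
        return X.T @ (X @ w - Y) / X.shape[0] + alpha * w

    rng = np.random.default_rng(0)
    X_toy = rng.standard_normal((200, 10))
    w_star = np.zeros(10)
    w_star[:3] = [1.0, -2.0, 0.5]
    Y_toy = X_toy @ w_star + 0.01 * rng.standard_normal(200)
    RSAG_nonsmooth(np.zeros(10), X_toy, Y_toy, toy_loss, toy_gradient,
                   eta=0.1, batch_size=32, max_iterations=50, l1_weight=0.01)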