Commit 1cf3dca3 authored by Pietro Morerio's avatar Pietro Morerio
Browse files

train classes __init__

parent 71aaa719
#My ignore
*.UBUNTU-PC
*.pyc
MNIST_data/*
.ipynb_checkpoints/*
model*/
train.log*/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# IPython Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# dotenv
.env
# virtualenv
venv/
ENV/
# Spyder project settings
.spyderproject
# Rope project settings
.ropeproject
......@@ -13,7 +13,8 @@ FLAGS = flags.FLAGS
class MyBatchGenerator():
def __init__(self, images, labels):
def __init__(self, images, labels, num_classes=10):
self.num_classes=num_classes
np.random.seed(0)
random.seed(0)
self.labels = labels
......@@ -33,12 +34,12 @@ class MyBatchGenerator():
left = []
right = []
sim = []
num_classes = 5
# genuine ~1/4+1/4 of the batch
for i in range(num_classes):
for i in range(self.num_classes):
l = choice(self.num_idx[i], batch_size, replace=False).tolist()
for j in range(int(batch_size/4./num_classes)):
for j in range(int(batch_size/4./self.num_classes)):
left.append(self.to_img(l.pop()))
right.append(self.to_img(l.pop()))
sim.append([1])
......@@ -48,9 +49,9 @@ class MyBatchGenerator():
sim.append([1])
#impostor ~1/2 of the batch
for i in range(num_classes):
for i in range(self.num_classes):
l = choice(self.num_idx[i], batch_size, replace=False).tolist()
for j in range(int(batch_size/2./num_classes)):
for j in range(int(batch_size/2./self.num_classes)):
left.append(self.to_img(l.pop()))
right.append(add_defect(self.to_img(l.pop())))
sim.append([0])
......
......@@ -7,12 +7,12 @@ from dataset import BatchGenerator, MyBatchGenerator, get_mnist
from model import *
flags.DEFINE_integer('batch_size', 128, 'Batch size.')
flags.DEFINE_integer('train_iter', 25000, 'Total training iter')
flags.DEFINE_integer('train_iter', 50000, 'Total training iter')
flags.DEFINE_integer('step', 500, 'Save after ... iteration')
mnist = get_mnist()
#~ gen = BatchGenerator(mnist.train.images, mnist.train.labels)
gen = MyBatchGenerator(mnist.train.images, mnist.train.labels)
gen = MyBatchGenerator(mnist.train.images, mnist.train.labels, num_classes=5)
test_im = np.array([im.reshape((28,28,1)) for im in mnist.test.images])
c = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff', '#ff00ff', '#990000', '#999900', '#009900', '#009999']
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment